normalize.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module containing functionals for intensity normalisation."""
  18. from typing import List, Tuple, Union
  19. import torch
  20. from kornia.core import ImageModule as Module
  21. from kornia.core import Tensor
# Public API of this module: the names exported by ``from <module> import *``.
__all__ = ["Denormalize", "Normalize", "denormalize", "normalize", "normalize_min_max"]
  23. class Normalize(Module):
  24. r"""Normalize a tensor image with mean and standard deviation.
  25. .. math::
  26. \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
  27. Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
  28. Args:
  29. mean: Mean for each channel.
  30. std: Standard deviations for each channel.
  31. Shape:
  32. - Input: Image tensor of size :math:`(*, C, ...)`.
  33. - Output: Normalised tensor with same size as input :math:`(*, C, ...)`.
  34. Examples:
  35. >>> x = torch.rand(1, 4, 3, 3)
  36. >>> out = Normalize(0.0, 255.)(x)
  37. >>> out.shape
  38. torch.Size([1, 4, 3, 3])
  39. >>> x = torch.rand(1, 4, 3, 3)
  40. >>> mean = torch.zeros(4)
  41. >>> std = 255. * torch.ones(4)
  42. >>> out = Normalize(mean, std)(x)
  43. >>> out.shape
  44. torch.Size([1, 4, 3, 3])
  45. """
  46. def __init__(
  47. self,
  48. mean: Union[Tensor, Tuple[float], List[float], float],
  49. std: Union[Tensor, Tuple[float], List[float], float],
  50. ) -> None:
  51. super().__init__()
  52. if isinstance(mean, (int, float)):
  53. mean = torch.tensor([mean])
  54. if isinstance(std, (int, float)):
  55. std = torch.tensor([std])
  56. if isinstance(mean, (tuple, list)):
  57. mean = torch.tensor(mean)[None]
  58. if isinstance(std, (tuple, list)):
  59. std = torch.tensor(std)[None]
  60. self.mean = mean
  61. self.std = std
  62. def forward(self, input: Tensor) -> Tensor:
  63. return normalize(input, self.mean, self.std)
  64. def __repr__(self) -> str:
  65. repr = f"(mean={self.mean}, std={self.std})"
  66. return self.__class__.__name__ + repr
  67. def normalize(data: Tensor, mean: Tensor, std: Tensor) -> Tensor:
  68. r"""Normalize an image/video tensor with mean and standard deviation.
  69. .. math::
  70. \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}
  71. Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
  72. Args:
  73. data: Image tensor of size :math:`(B, C, *)`.
  74. mean: Mean for each channel.
  75. std: Standard deviations for each channel.
  76. Return:
  77. Normalised tensor with same size as input :math:`(B, C, *)`.
  78. Examples:
  79. >>> x = torch.rand(1, 4, 3, 3)
  80. >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
  81. >>> out.shape
  82. torch.Size([1, 4, 3, 3])
  83. >>> x = torch.rand(1, 4, 3, 3)
  84. >>> mean = torch.zeros(4)
  85. >>> std = 255. * torch.ones(4)
  86. >>> out = normalize(x, mean, std)
  87. >>> out.shape
  88. torch.Size([1, 4, 3, 3])
  89. """
  90. shape = data.shape
  91. if torch.onnx.is_in_onnx_export():
  92. if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
  93. raise ValueError("Only tensor is accepted when converting to ONNX.")
  94. if mean.shape[0] != 1 or std.shape[0] != 1:
  95. raise ValueError(
  96. "Batch dimension must be one for broadcasting when converting to ONNX."
  97. f"Try changing mean shape and std shape from ({mean.shape}, {std.shape}) to (1, C) or (1, C, 1, 1)."
  98. )
  99. else:
  100. if isinstance(mean, float):
  101. mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)
  102. if isinstance(std, float):
  103. std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)
  104. # Allow broadcast on channel dimension
  105. if mean.shape and mean.shape[0] != 1:
  106. if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
  107. raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
  108. # Allow broadcast on channel dimension
  109. if std.shape and std.shape[0] != 1:
  110. if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
  111. raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
  112. mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
  113. std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
  114. mean = mean[..., None]
  115. std = std[..., None]
  116. out: Tensor = (data.view(shape[0], shape[1], -1) - mean) / std
  117. return out.view(shape)
  118. class Denormalize(Module):
  119. r"""Denormalize a tensor image with mean and standard deviation.
  120. .. math::
  121. \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}
  122. Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
  123. Args:
  124. mean: Mean for each channel.
  125. std: Standard deviations for each channel.
  126. Shape:
  127. - Input: Image tensor of size :math:`(*, C, ...)`.
  128. - Output: Denormalised tensor with same size as input :math:`(*, C, ...)`.
  129. Examples:
  130. >>> x = torch.rand(1, 4, 3, 3)
  131. >>> out = Denormalize(0.0, 255.)(x)
  132. >>> out.shape
  133. torch.Size([1, 4, 3, 3])
  134. >>> x = torch.rand(1, 4, 3, 3, 3)
  135. >>> mean = torch.zeros(1, 4)
  136. >>> std = 255. * torch.ones(1, 4)
  137. >>> out = Denormalize(mean, std)(x)
  138. >>> out.shape
  139. torch.Size([1, 4, 3, 3, 3])
  140. """
  141. def __init__(self, mean: Union[Tensor, float], std: Union[Tensor, float]) -> None:
  142. super().__init__()
  143. self.mean = mean
  144. self.std = std
  145. def forward(self, input: Tensor) -> Tensor:
  146. return denormalize(input, self.mean, self.std)
  147. def __repr__(self) -> str:
  148. repr = f"(mean={self.mean}, std={self.std})"
  149. return self.__class__.__name__ + repr
  150. def denormalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, float]) -> Tensor:
  151. r"""Denormalize an image/video tensor with mean and standard deviation.
  152. .. math::
  153. \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}
  154. Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,
  155. Args:
  156. data: Image tensor of size :math:`(B, C, *)`.
  157. mean: Mean for each channel.
  158. std: Standard deviations for each channel.
  159. Return:
  160. Denormalised tensor with same size as input :math:`(B, C, *)`.
  161. Examples:
  162. >>> x = torch.rand(1, 4, 3, 3)
  163. >>> out = denormalize(x, 0.0, 255.)
  164. >>> out.shape
  165. torch.Size([1, 4, 3, 3])
  166. >>> x = torch.rand(1, 4, 3, 3, 3)
  167. >>> mean = torch.zeros(1, 4)
  168. >>> std = 255. * torch.ones(1, 4)
  169. >>> out = denormalize(x, mean, std)
  170. >>> out.shape
  171. torch.Size([1, 4, 3, 3, 3])
  172. """
  173. shape = data.shape
  174. if torch.onnx.is_in_onnx_export():
  175. if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
  176. raise ValueError("Only tensor is accepted when converting to ONNX.")
  177. if mean.shape[0] != 1 or std.shape[0] != 1:
  178. raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
  179. else:
  180. if isinstance(mean, float):
  181. mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)
  182. if isinstance(std, float):
  183. std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)
  184. # Allow broadcast on channel dimension
  185. if mean.shape and mean.shape[0] != 1:
  186. if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
  187. raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")
  188. # Allow broadcast on channel dimension
  189. if std.shape and std.shape[0] != 1:
  190. if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
  191. raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")
  192. mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
  193. std = torch.as_tensor(std, device=data.device, dtype=data.dtype)
  194. if mean.dim() == 1:
  195. mean = mean.view(1, -1, *([1] * (data.dim() - 2)))
  196. # If the tensor is >1D (e.g., (B, C)), reshape to (B, C, 1, ...)
  197. else:
  198. while len(mean.shape) < data.dim():
  199. mean = mean.unsqueeze(-1)
  200. if std.dim() == 1:
  201. std = std.view(1, -1, *([1] * (data.dim() - 2)))
  202. else:
  203. while len(std.shape) < data.dim():
  204. std = std.unsqueeze(-1)
  205. return torch.addcmul(mean, data, std)
  206. def normalize_min_max(x: Tensor, min_val: float = 0.0, max_val: float = 1.0, eps: float = 1e-6) -> Tensor:
  207. r"""Normalise an image/video tensor by MinMax and re-scales the value between a range.
  208. The data is normalised using the following formulation:
  209. .. math::
  210. y_i = (b - a) * \frac{x_i - \text{min}(x)}{\text{max}(x) - \text{min}(x)} + a
  211. where :math:`a` is :math:`\text{min_val}` and :math:`b` is :math:`\text{max_val}`.
  212. Args:
  213. x: The image tensor to be normalised with shape :math:`(B, C, *)`.
  214. min_val: The minimum value for the new range.
  215. max_val: The maximum value for the new range.
  216. eps: Float number to avoid zero division.
  217. Returns:
  218. The normalised image tensor with same shape as input :math:`(B, C, *)`.
  219. Example:
  220. >>> x = torch.rand(1, 5, 3, 3)
  221. >>> x_norm = normalize_min_max(x, min_val=-1., max_val=1.)
  222. >>> x_norm.min()
  223. tensor(-1.)
  224. >>> x_norm.max()
  225. tensor(1.0000)
  226. """
  227. if not isinstance(x, Tensor):
  228. raise TypeError(f"data should be a tensor. Got: {type(x)}.")
  229. if not isinstance(min_val, float):
  230. raise TypeError(f"'min_val' should be a float. Got: {type(min_val)}.")
  231. if not isinstance(max_val, float):
  232. raise TypeError(f"'b' should be a float. Got: {type(max_val)}.")
  233. if len(x.shape) < 3:
  234. raise ValueError(f"Input shape must be at least a 3d tensor. Got: {x.shape}.")
  235. shape = x.shape
  236. B, C = shape[0], shape[1]
  237. x_reshaped = x.view(B, C, -1)
  238. x_min = x_reshaped.min(-1, keepdim=True)[0] # Shape: (B, C, 1)
  239. x_max = x_reshaped.max(-1, keepdim=True)[0] # Shape: (B, C, 1)
  240. x_out = (max_val - min_val) * (x_reshaped - x_min) / (x_max - x_min + eps) + min_val
  241. return x_out.view(shape)