| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- """Module containing functionals for intensity normalisation."""
- from typing import List, Tuple, Union
- import torch
- from kornia.core import ImageModule as Module
- from kornia.core import Tensor
- __all__ = ["Denormalize", "Normalize", "denormalize", "normalize", "normalize_min_max"]
class Normalize(Module):
    r"""Normalize a tensor image with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        mean: Mean for each channel. A scalar is stored as a tensor of shape :math:`(1,)`, a tuple or
            list of `C` values is stored as a tensor of shape :math:`(1, C)`, and a tensor is stored
            unchanged.
        std: Standard deviations for each channel. Converted the same way as `mean`.

    Shape:
        - Input: Image tensor of size :math:`(*, C, ...)`.
        - Output: Normalised tensor with same size as input :math:`(*, C, ...)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = Normalize(0.0, 255.)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = Normalize(mean, std)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

    """

    def __init__(
        self,
        mean: Union[Tensor, Tuple[float, ...], List[float], float],
        std: Union[Tensor, Tuple[float, ...], List[float], float],
    ) -> None:
        super().__init__()
        self.mean = self._stat_to_tensor(mean)
        self.std = self._stat_to_tensor(std)

    @staticmethod
    def _stat_to_tensor(value: Union[Tensor, Tuple[float, ...], List[float], float]) -> Tensor:
        """Convert a scalar or sequence statistic to a tensor; tensors pass through unchanged."""
        if isinstance(value, (int, float)):
            # Scalars become a single-element tensor of shape (1,).
            return torch.tensor([value])
        if isinstance(value, (tuple, list)):
            # Sequences get a leading axis of size one: shape (1, C).
            return torch.tensor(value)[None]
        return value

    def forward(self, input: Tensor) -> Tensor:
        """Apply :func:`normalize` to `input` using the stored statistics."""
        return normalize(input, self.mean, self.std)

    def __repr__(self) -> str:
        """Return the class name followed by the stored mean and std."""
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
def normalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, float]) -> Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel. Either a scalar (applied to every channel) or a tensor whose
            leading dimension is one, equals `C`, or whose first two dimensions equal :math:`(B, C)`.
        std: Standard deviations for each channel. Same accepted forms as `mean`.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Raises:
        ValueError: If `mean` or `std` does not match the number of channels of `data`, or, while
            exporting to ONNX, if either one is not a tensor with a leading dimension of one.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

    """
    shape = data.shape

    if torch.onnx.is_in_onnx_export():
        if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
            raise ValueError("Only tensor is accepted when converting to ONNX.")
        if mean.shape[0] != 1 or std.shape[0] != 1:
            # NOTE(review): a (1, C, 1, 1) statistic gains a fifth axis below and then does not
            # broadcast against the flattened (B, C, N) view -- confirm whether this hint should
            # only recommend (1, C).
            raise ValueError(
                "Batch dimension must be one for broadcasting when converting to ONNX. "
                f"Try changing mean shape and std shape from ({mean.shape}, {std.shape}) to (1, C) or (1, C, 1, 1)."
            )
    else:
        # Expand Python scalars to one value per channel.
        if isinstance(mean, (int, float)):
            mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)
        if isinstance(std, (int, float)):
            std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

        for label, stat in (("mean", mean), ("std", std)):
            # A 0-dim tensor or a leading dimension of one broadcasts over the channels; anything
            # else must match either C or the leading (B, C) of the data.
            if stat.shape and stat.shape[0] != 1:
                if stat.shape[0] != shape[1] and stat.shape[:2] != shape[:2]:
                    raise ValueError(
                        f"{label} length and number of channels do not match. Got {stat.shape} and {data.shape}."
                    )

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    # Trailing axis so the statistics broadcast over the flattened spatial dimension.
    mean = mean[..., None]
    std = std[..., None]

    # reshape (rather than view) so non-contiguous inputs are accepted as well.
    out: Tensor = (data.reshape(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
class Denormalize(Module):
    r"""Denormalize a tensor image with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        mean: Mean for each channel. Floats and tensors are stored unchanged, an int is stored as a
            float, and a tuple or list of `C` values is stored as a tensor of shape :math:`(1, C)`.
        std: Standard deviations for each channel. Converted the same way as `mean`.

    Shape:
        - Input: Image tensor of size :math:`(*, C, ...)`.
        - Output: Denormalised tensor with same size as input :math:`(*, C, ...)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = Denormalize(0.0, 255.)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3, 3)
        >>> mean = torch.zeros(1, 4)
        >>> std = 255. * torch.ones(1, 4)
        >>> out = Denormalize(mean, std)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3, 3])

    """

    def __init__(
        self,
        mean: Union[Tensor, Tuple[float, ...], List[float], float],
        std: Union[Tensor, Tuple[float, ...], List[float], float],
    ) -> None:
        super().__init__()
        self.mean = self._prepare_stat(mean)
        self.std = self._prepare_stat(std)

    @staticmethod
    def _prepare_stat(value: Union[Tensor, Tuple[float, ...], List[float], float]) -> Union[Tensor, float]:
        """Convert sequences to a (1, C) tensor and ints to float; floats and tensors pass through."""
        if isinstance(value, (tuple, list)):
            # Same (1, C) layout that Normalize builds for sequence statistics.
            return torch.tensor(value)[None]
        if isinstance(value, int):
            # Stored as float so the scalar path of `denormalize` handles it.
            return float(value)
        return value

    def forward(self, input: Tensor) -> Tensor:
        """Apply :func:`denormalize` to `input` using the stored statistics."""
        return denormalize(input, self.mean, self.std)

    def __repr__(self) -> str:
        """Return the class name followed by the stored mean and std."""
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"
def denormalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, float]) -> Tensor:
    r"""Denormalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel. Either a scalar (applied to every channel) or a tensor whose
            leading dimension is one, equals `C`, or whose first two dimensions equal :math:`(B, C)`.
        std: Standard deviations for each channel. Same accepted forms as `mean`.

    Return:
        Denormalised tensor with same size as input :math:`(B, C, *)`.

    Raises:
        ValueError: If `mean` or `std` does not match the number of channels of `data`, or, while
            exporting to ONNX, if either one is not a tensor with a leading dimension of one.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = denormalize(x, 0.0, 255.)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3, 3)
        >>> mean = torch.zeros(1, 4)
        >>> std = 255. * torch.ones(1, 4)
        >>> out = denormalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3, 3])

    """
    shape = data.shape

    if torch.onnx.is_in_onnx_export():
        if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
            raise ValueError("Only tensor is accepted when converting to ONNX.")
        if mean.shape[0] != 1 or std.shape[0] != 1:
            raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
    else:
        # Expand Python scalars to one value per channel.
        if isinstance(mean, (int, float)):
            mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)
        if isinstance(std, (int, float)):
            std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

        for label, stat in (("mean", mean), ("std", std)):
            # The channel axis is dim 1, matching the scalar expansion above and the reshape
            # below. Indexing it as -3 only agrees with that for 4-D input and rejected valid
            # statistics for 5-D (B, C, D, H, W) data.
            if stat.shape and stat.shape[0] != 1:
                if stat.shape[0] != shape[1] and stat.shape[:2] != shape[:2]:
                    raise ValueError(
                        f"{label} length and number of channels do not match. Got {stat.shape} and {data.shape}."
                    )

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    # One singleton axis for every dimension that follows (B, C).
    trailing = [1] * (data.dim() - 2)

    if mean.dim() == 1:
        # 1-D statistic: place its values on the channel axis -> (1, C, 1, ...).
        mean = mean.view(1, -1, *trailing)
    else:
        # Any other rank (e.g. (B, C)): append singleton axes up to the rank of the data.
        while mean.dim() < data.dim():
            mean = mean.unsqueeze(-1)

    if std.dim() == 1:
        std = std.view(1, -1, *trailing)
    else:
        while std.dim() < data.dim():
            std = std.unsqueeze(-1)

    # mean + data * std, with mean and std broadcast against data.
    return torch.addcmul(mean, data, std)
def normalize_min_max(x: Tensor, min_val: float = 0.0, max_val: float = 1.0, eps: float = 1e-6) -> Tensor:
    r"""Normalise an image/video tensor by MinMax and re-scales the value between a range.

    The data is normalised using the following formulation:

    .. math::
        y_i = (b - a) * \frac{x_i - \text{min}(x)}{\text{max}(x) - \text{min}(x)} + a

    where :math:`a` is :math:`\text{min_val}` and :math:`b` is :math:`\text{max_val}`.
    The minimum and maximum are taken per sample and per channel over all remaining dimensions.

    Args:
        x: The image tensor to be normalised with shape :math:`(B, C, *)`.
        min_val: The minimum value for the new range.
        max_val: The maximum value for the new range.
        eps: Float number to avoid zero division.

    Returns:
        The normalised image tensor with same shape as input :math:`(B, C, *)`.

    Raises:
        TypeError: If `x` is not a tensor, or `min_val` / `max_val` is not a number.
        ValueError: If `x` has fewer than three dimensions.

    Example:
        >>> x = torch.rand(1, 5, 3, 3)
        >>> x_norm = normalize_min_max(x, min_val=-1., max_val=1.)
        >>> x_norm.min()
        tensor(-1.)
        >>> x_norm.max()
        tensor(1.0000)

    """
    if not isinstance(x, Tensor):
        raise TypeError(f"data should be a tensor. Got: {type(x)}.")

    # Ints are accepted wherever a float bound is expected.
    if not isinstance(min_val, (int, float)):
        raise TypeError(f"'min_val' should be a float. Got: {type(min_val)}.")

    if not isinstance(max_val, (int, float)):
        raise TypeError(f"'max_val' should be a float. Got: {type(max_val)}.")

    if len(x.shape) < 3:
        raise ValueError(f"Input shape must be at least a 3d tensor. Got: {x.shape}.")

    shape = x.shape
    batch, channels = shape[0], shape[1]

    # Flatten everything after (B, C); reshape also handles non-contiguous input.
    flat = x.reshape(batch, channels, -1)
    lowest = flat.min(-1, keepdim=True)[0]  # Shape: (B, C, 1)
    highest = flat.max(-1, keepdim=True)[0]  # Shape: (B, C, 1)

    # eps keeps the denominator non-zero when a channel is constant.
    scaled = (max_val - min_val) * (flat - lowest) / (highest - lowest + eps) + min_val

    return scaled.view(shape)
|