# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module containing functionals for intensity normalisation."""

from typing import List, Tuple, Union

import torch

from kornia.core import ImageModule as Module
from kornia.core import Tensor

__all__ = ["Denormalize", "Normalize", "denormalize", "normalize", "normalize_min_max"]


def _check_channel_broadcast(name: str, stat: Tensor, data: Tensor, channel_axes: Tuple[int, ...]) -> None:
    """Raise ``ValueError`` when the leading size of ``stat`` cannot be matched against ``data``.

    Args:
        name: Label used in the error message (``"mean"`` or ``"std"``).
        stat: The statistic tensor to validate.
        data: The tensor the statistic is going to be applied to.
        channel_axes: Axes of ``data`` whose size ``stat.shape[0]`` is allowed to equal.

    Raises:
        ValueError: If ``stat`` is neither zero-dimensional, nor has a leading size of one,
            nor shares its first two sizes with ``data``, nor matches one of ``channel_axes``.
    """
    # Zero-dim tensors and a leading dimension of one always broadcast.
    if not stat.shape or stat.shape[0] == 1:
        return
    # Same leading (B, C) sizes as the data: one statistic per sample and channel.
    if stat.shape[:2] == data.shape[:2]:
        return
    for axis in channel_axes:
        if stat.shape[0] == data.shape[axis]:
            return
    raise ValueError(f"{name} length and number of channels do not match. Got {stat.shape} and {data.shape}.")


class Normalize(Module):
    r"""Normalize a tensor image with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        mean: Mean for each channel. Numbers are wrapped into a one-element tensor and
            tuples/lists are converted to a tensor with a leading dimension of one.
        std: Standard deviations for each channel. Converted the same way as ``mean``.

    Shape:
        - Input: Image tensor of size :math:`(*, C, ...)`.
        - Output: Normalised tensor with same size as input :math:`(*, C, ...)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = Normalize(0.0, 255.)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = Normalize(mean, std)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

    """

    def __init__(
        self,
        mean: Union[Tensor, Tuple[float, ...], List[float], float],
        std: Union[Tensor, Tuple[float, ...], List[float], float],
    ) -> None:
        super().__init__()

        if isinstance(mean, (int, float)):
            mean = torch.tensor([mean])

        if isinstance(std, (int, float)):
            std = torch.tensor([std])

        # Sequences get a leading dimension of one so they broadcast over the batch.
        if isinstance(mean, (tuple, list)):
            mean = torch.tensor(mean)[None]

        if isinstance(std, (tuple, list)):
            std = torch.tensor(std)[None]

        self.mean = mean
        self.std = std

    def forward(self, input: Tensor) -> Tensor:
        """Apply :func:`normalize` to ``input`` with the stored mean and std."""
        return normalize(input, self.mean, self.std)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


def normalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, float]) -> Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel. A plain number is repeated for every channel.
        std: Standard deviations for each channel. A plain number is repeated for every channel.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Raises:
        ValueError: During ONNX export, if ``mean`` or ``std`` is not a tensor or its first
            dimension is not one. Otherwise, if the leading size of ``mean`` or ``std`` is
            neither one nor the number of channels and its first two sizes differ from
            those of ``data``.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

    """
    shape = data.shape

    if torch.onnx.is_in_onnx_export():
        if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
            raise ValueError("Only tensor is accepted when converting to ONNX.")
        if mean.shape[0] != 1 or std.shape[0] != 1:
            raise ValueError(
                "Batch dimension must be one for broadcasting when converting to ONNX. "
                f"Try changing mean shape and std shape from ({mean.shape}, {std.shape}) to (1, C) or (1, C, 1, 1)."
            )
    else:
        # Plain Python numbers (ints included) are expanded to one value per channel.
        if isinstance(mean, (int, float)):
            mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

        if isinstance(std, (int, float)):
            std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

        # Allow broadcast on channel dimension
        _check_channel_broadcast("mean", mean, data, (1,))
        _check_channel_broadcast("std", std, data, (1,))

        mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
        std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    # Trailing axis lines the statistics up with the flattened (B, C, N) data.
    mean = mean[..., None]
    std = std[..., None]

    # ``reshape`` instead of ``view`` so non-contiguous inputs are accepted as well.
    out: Tensor = (data.reshape(shape[0], shape[1], -1) - mean) / std

    return out.reshape(shape)


class Denormalize(Module):
    r"""Denormalize a tensor image with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        mean: Mean for each channel. Tuples/lists are converted to a tensor with a
            leading dimension of one; numbers and tensors are stored as given.
        std: Standard deviations for each channel. Converted the same way as ``mean``.

    Shape:
        - Input: Image tensor of size :math:`(*, C, ...)`.
        - Output: Denormalised tensor with same size as input :math:`(*, C, ...)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = Denormalize(0.0, 255.)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3, 3)
        >>> mean = torch.zeros(1, 4)
        >>> std = 255. * torch.ones(1, 4)
        >>> out = Denormalize(mean, std)(x)
        >>> out.shape
        torch.Size([1, 4, 3, 3, 3])

    """

    def __init__(
        self,
        mean: Union[Tensor, Tuple[float, ...], List[float], float],
        std: Union[Tensor, Tuple[float, ...], List[float], float],
    ) -> None:
        super().__init__()

        # Same sequence handling as ``Normalize`` so both modules take the same arguments.
        if isinstance(mean, (tuple, list)):
            mean = torch.tensor(mean)[None]

        if isinstance(std, (tuple, list)):
            std = torch.tensor(std)[None]

        self.mean = mean
        self.std = std

    def forward(self, input: Tensor) -> Tensor:
        """Apply :func:`denormalize` to ``input`` with the stored mean and std."""
        return denormalize(input, self.mean, self.std)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


def denormalize(data: Tensor, mean: Union[Tensor, float], std: Union[Tensor, float]) -> Tensor:
    r"""Denormalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] * std[channel]) + mean[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels,

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel. A plain number is repeated for every channel.
        std: Standard deviations for each channel. A plain number is repeated for every channel.

    Return:
        Denormalised tensor with same size as input :math:`(B, C, *)`.

    Raises:
        ValueError: During ONNX export, if ``mean`` or ``std`` is not a tensor or its first
            dimension is not one. Otherwise, if the leading size of ``mean`` or ``std`` is
            not one, matches neither ``data.shape[1]`` nor ``data.shape[-3]``, and its first
            two sizes differ from those of ``data``.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = denormalize(x, 0.0, 255.)
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3, 3)
        >>> mean = torch.zeros(1, 4)
        >>> std = 255. * torch.ones(1, 4)
        >>> out = denormalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3, 3])

    """
    shape = data.shape

    if torch.onnx.is_in_onnx_export():
        if not isinstance(mean, Tensor) or not isinstance(std, Tensor):
            raise ValueError("Only tensor is accepted when converting to ONNX.")
        if mean.shape[0] != 1 or std.shape[0] != 1:
            raise ValueError("Batch dimension must be one for broadcasting when converting to ONNX.")
    else:
        # Plain Python numbers (ints included) are expanded to one value per channel.
        if isinstance(mean, (int, float)):
            mean = torch.tensor([mean] * shape[1], device=data.device, dtype=data.dtype)

        if isinstance(std, (int, float)):
            std = torch.tensor([std] * shape[1], device=data.device, dtype=data.dtype)

        # Allow broadcast on channel dimension. Axis 1 is the channel axis of the
        # documented (B, C, *) layout; axis -3 is kept so every shape that was
        # accepted before is still accepted.
        _check_channel_broadcast("mean", mean, data, (1, -3))
        _check_channel_broadcast("std", std, data, (1, -3))

        mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
        std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    # A 1-D statistic is placed on axis 1: (C,) -> (1, C, 1, ...).
    if mean.dim() == 1:
        mean = mean.view(1, -1, *([1] * (data.dim() - 2)))
    # If the tensor is >1D (e.g., (B, C)), reshape to (B, C, 1, ...)
    else:
        while len(mean.shape) < data.dim():
            mean = mean.unsqueeze(-1)

    if std.dim() == 1:
        std = std.view(1, -1, *([1] * (data.dim() - 2)))
    else:
        while len(std.shape) < data.dim():
            std = std.unsqueeze(-1)

    # Computes mean + data * std in a single call.
    return torch.addcmul(mean, data, std)


def normalize_min_max(x: Tensor, min_val: float = 0.0, max_val: float = 1.0, eps: float = 1e-6) -> Tensor:
    r"""Normalise an image/video tensor by MinMax and re-scales the value between a range.

    The data is normalised using the following formulation:

    .. math::
        y_i = (b - a) * \frac{x_i - \text{min}(x)}{\text{max}(x) - \text{min}(x)} + a

    where :math:`a` is :math:`\text{min_val}` and :math:`b` is :math:`\text{max_val}`.

    The minimum and maximum are taken per sample and per channel over all remaining
    dimensions.

    Args:
        x: The image tensor to be normalised with shape :math:`(B, C, *)`.
        min_val: The minimum value for the new range.
        max_val: The maximum value for the new range.
        eps: Float number to avoid zero division.

    Returns:
        The normalised image tensor with same shape as input :math:`(B, C, *)`.

    Raises:
        TypeError: If ``x`` is not a tensor, or ``min_val`` or ``max_val`` is not a float.
        ValueError: If ``x`` has fewer than three dimensions.

    Example:
        >>> x = torch.rand(1, 5, 3, 3)
        >>> x_norm = normalize_min_max(x, min_val=-1., max_val=1.)
        >>> x_norm.min()
        tensor(-1.)
        >>> x_norm.max()
        tensor(1.0000)

    """
    if not isinstance(x, Tensor):
        raise TypeError(f"data should be a tensor. Got: {type(x)}.")

    if not isinstance(min_val, float):
        raise TypeError(f"'min_val' should be a float. Got: {type(min_val)}.")

    if not isinstance(max_val, float):
        raise TypeError(f"'max_val' should be a float. Got: {type(max_val)}.")

    if len(x.shape) < 3:
        raise ValueError(f"Input shape must be at least a 3d tensor. Got: {x.shape}.")

    shape = x.shape
    B, C = shape[0], shape[1]

    # ``reshape`` instead of ``view`` so non-contiguous inputs are accepted as well.
    x_reshaped = x.reshape(B, C, -1)
    x_min = x_reshaped.min(-1, keepdim=True)[0]  # Shape: (B, C, 1)
    x_max = x_reshaped.max(-1, keepdim=True)[0]  # Shape: (B, C, 1)

    x_out = (max_val - min_val) * (x_reshaped - x_min) / (x_max - x_min + eps) + min_val

    return x_out.reshape(shape)