| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from collections.abc import Sequence
- from typing import List, Optional, Union
- import torch
- from torch import Tensor
- from torch.nn import functional as F # noqa: N812
- from typing_extensions import Literal
- from torchmetrics.functional.image.utils import _gaussian_kernel_2d, _gaussian_kernel_3d, _reflection_pad_3d
- from torchmetrics.utilities.checks import _check_same_shape
- from torchmetrics.utilities.distributed import reduce
def _ssim_check_inputs(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
    """Validate SSIM inputs and align their dtypes.

    ``target`` is cast to ``preds.dtype`` when the two differ. The shapes are then compared with
    ``_check_same_shape``, and both tensors must be 4D (``BxCxHxW``) or 5D (``BxCxDxHxW``).

    Args:
        preds: Predicted tensor
        target: Ground truth tensor

    Returns:
        The ``preds`` tensor and the (possibly dtype-converted) ``target`` tensor.

    Raises:
        ValueError: If the inputs are neither 4D nor 5D.
    """
    target = target if target.dtype == preds.dtype else target.to(preds.dtype)
    _check_same_shape(preds, target)
    if preds.ndim in (4, 5):
        return preds, target
    raise ValueError(
        "Expected `preds` and `target` to have BxCxHxW or BxCxDxHxW shape."
        f" Got preds: {preds.shape} and target: {target.shape}."
    )
- def _ssim_update(
- preds: Tensor,
- target: Tensor,
- gaussian_kernel: bool = True,
- sigma: Union[float, Sequence[float]] = 1.5,
- kernel_size: Union[int, Sequence[int]] = 11,
- data_range: Optional[Union[float, tuple[float, float]]] = None,
- k1: float = 0.01,
- k2: float = 0.03,
- return_full_image: bool = False,
- return_contrast_sensitivity: bool = False,
- ) -> Union[Tensor, tuple[Tensor, Tensor]]:
- """Compute Structural Similarity Index Measure.
- Args:
- preds: estimated image
- target: ground truth image
- gaussian_kernel: If true (default), a gaussian kernel is used, if false a uniform kernel is used
- sigma: Standard deviation of the gaussian kernel, anisotropic kernels are possible.
- Ignored if a uniform kernel is used
- kernel_size: the size of the uniform kernel, anisotropic kernels are possible.
- Ignored if a Gaussian kernel is used
- data_range: Range of the image. If ``None``, it is determined from the image (max - min)
- k1: Parameter of SSIM.
- k2: Parameter of SSIM.
- return_full_image: If true, the full ``ssim`` image is returned as a second argument.
- Mutually exclusive with ``return_contrast_sensitivity``
- return_contrast_sensitivity: If true, the contrast term is returned as a second argument.
- The luminance term can be obtained with luminance=ssim/contrast
- Mutually exclusive with ``return_full_image``
- """
- is_3d = preds.ndim == 5
- if not isinstance(kernel_size, Sequence):
- kernel_size = 3 * [kernel_size] if is_3d else 2 * [kernel_size]
- if not isinstance(sigma, Sequence):
- sigma = 3 * [sigma] if is_3d else 2 * [sigma]
- if len(kernel_size) != len(target.shape) - 2:
- raise ValueError(
- f"`kernel_size` has dimension {len(kernel_size)}, but expected to be two less that target dimensionality,"
- f" which is: {len(target.shape)}"
- )
- if len(kernel_size) not in (2, 3):
- raise ValueError(
- f"Expected `kernel_size` dimension to be 2 or 3. `kernel_size` dimensionality: {len(kernel_size)}"
- )
- if len(sigma) != len(target.shape) - 2:
- raise ValueError(
- f"`kernel_size` has dimension {len(kernel_size)}, but expected to be two less that target dimensionality,"
- f" which is: {len(target.shape)}"
- )
- if len(sigma) not in (2, 3):
- raise ValueError(
- f"Expected `kernel_size` dimension to be 2 or 3. `kernel_size` dimensionality: {len(kernel_size)}"
- )
- if return_full_image and return_contrast_sensitivity:
- raise ValueError("Arguments `return_full_image` and `return_contrast_sensitivity` are mutually exclusive.")
- if any(x % 2 == 0 or x <= 0 for x in kernel_size):
- raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")
- if any(y <= 0 for y in sigma):
- raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
- if data_range is None:
- data_range = max(preds.max() - preds.min(), target.max() - target.min()) # type: ignore[call-overload]
- elif isinstance(data_range, tuple):
- preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
- target = torch.clamp(target, min=data_range[0], max=data_range[1])
- data_range = data_range[1] - data_range[0]
- c1 = pow(k1 * data_range, 2) # type: ignore[operator]
- c2 = pow(k2 * data_range, 2) # type: ignore[operator]
- device = preds.device
- channel = preds.size(1)
- dtype = preds.dtype
- gauss_kernel_size = [int(3.5 * s + 0.5) * 2 + 1 for s in sigma]
- if gaussian_kernel:
- pad_h = (gauss_kernel_size[0] - 1) // 2
- pad_w = (gauss_kernel_size[1] - 1) // 2
- else:
- pad_h = (kernel_size[0] - 1) // 2
- pad_w = (kernel_size[1] - 1) // 2
- if is_3d:
- pad_d = (kernel_size[2] - 1) // 2
- preds = _reflection_pad_3d(preds, pad_d, pad_w, pad_h)
- target = _reflection_pad_3d(target, pad_d, pad_w, pad_h)
- if gaussian_kernel:
- kernel = _gaussian_kernel_3d(channel, gauss_kernel_size, sigma, dtype, device)
- else:
- preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
- target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode="reflect")
- if gaussian_kernel:
- kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device)
- if not gaussian_kernel:
- kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(
- torch.tensor(kernel_size, dtype=dtype, device=device)
- )
- input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
- outputs = F.conv3d(input_list, kernel, groups=channel) if is_3d else F.conv2d(input_list, kernel, groups=channel)
- output_list = outputs.split(preds.shape[0])
- mu_pred_sq = output_list[0].pow(2)
- mu_target_sq = output_list[1].pow(2)
- mu_pred_target = output_list[0] * output_list[1]
- # Calculate the variance of the predicted and target images, should be non-negative
- sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0)
- sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0)
- sigma_pred_target = output_list[4] - mu_pred_target
- upper = 2 * sigma_pred_target.to(dtype) + c2
- lower = (sigma_pred_sq + sigma_target_sq).to(dtype) + c2
- ssim_idx_full_image = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)
- if return_contrast_sensitivity:
- contrast_sensitivity = upper / lower
- if is_3d:
- contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w, pad_d:-pad_d]
- else:
- contrast_sensitivity = contrast_sensitivity[..., pad_h:-pad_h, pad_w:-pad_w]
- return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), contrast_sensitivity.reshape(
- contrast_sensitivity.shape[0], -1
- ).mean(-1)
- if return_full_image:
- return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1), ssim_idx_full_image
- return ssim_idx_full_image.reshape(ssim_idx_full_image.shape[0], -1).mean(-1)
def _ssim_compute(
    similarities: Tensor,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Reduce a batch of per-image SSIM values.

    Args:
        similarities: per image similarities for a batch of images.
        reduction: how to combine the individual batch scores; forwarded unchanged to ``reduce``.

            - ``'elementwise_mean'``: takes the mean
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Returns:
        The reduced SSIM score
    """
    reduced_similarity = reduce(similarities, reduction)
    return reduced_similarity
def structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Sequence[float]] = 1.5,
    kernel_size: Union[int, Sequence[int]] = 11,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    data_range: Optional[Union[float, tuple[float, float]]] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    return_full_image: bool = False,
    return_contrast_sensitivity: bool = False,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """Compute Structural Similarity Index Measure.

    Args:
        preds: estimated image of shape ``BxCxHxW`` or ``BxCxDxHxW``
        target: ground truth image with the same shape as ``preds`` (cast to the dtype of ``preds`` if needed)
        gaussian_kernel: If true (default), a gaussian kernel is used, if false a uniform kernel is used
        sigma: Standard deviation of the gaussian kernel, anisotropic kernels are possible.
            Ignored if a uniform kernel is used
        kernel_size: the size of the uniform kernel, anisotropic kernels are possible.
            Ignored if a Gaussian kernel is used
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

        data_range:
            the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
            the range is calculated as the difference and input is clamped between the values.
        k1: Parameter of SSIM.
        k2: Parameter of SSIM.
        return_full_image: If true, the full ``ssim`` image is returned as a second argument.
            Mutually exclusive with ``return_contrast_sensitivity``
        return_contrast_sensitivity: If true, the contrast term is returned as a second argument.
            The luminance term can be obtained with luminance=ssim/contrast
            Mutually exclusive with ``return_full_image``

    Return:
        Tensor with SSIM score, reduced according to ``reduction``. If ``return_full_image`` or
        ``return_contrast_sensitivity`` is set, this is instead a tuple ``(score, extra)``. Here ``extra`` is the
        full SSIM image or the per-image contrast sensitivity; it is returned as is, without applying ``reduction``.

    Raises:
        ValueError:
            If ``preds`` and ``target`` don't have ``BxCxHxW`` or ``BxCxDxHxW`` shape.
        ValueError:
            If the length of ``kernel_size`` or ``sigma`` does not match the number of spatial dimensions (2 or 3).
        ValueError:
            If both ``return_full_image`` and ``return_contrast_sensitivity`` are set.
        ValueError:
            If one of the elements of ``kernel_size`` is not an ``odd positive number``.
        ValueError:
            If one of the elements of ``sigma`` is not a ``positive number``.

    Example:
        >>> from torchmetrics.functional.image import structural_similarity_index_measure
        >>> preds = torch.rand([3, 3, 256, 256])
        >>> target = preds * 0.75
        >>> structural_similarity_index_measure(preds, target)
        tensor(0.9219)
    """
    preds, target = _ssim_check_inputs(preds, target)
    similarity_pack = _ssim_update(
        preds,
        target,
        gaussian_kernel=gaussian_kernel,
        sigma=sigma,
        kernel_size=kernel_size,
        data_range=data_range,
        k1=k1,
        k2=k2,
        return_full_image=return_full_image,
        return_contrast_sensitivity=return_contrast_sensitivity,
    )

    if isinstance(similarity_pack, tuple):
        # Second element is either the full SSIM image or the per-image contrast sensitivity.
        similarity, extra = similarity_pack
        return _ssim_compute(similarity, reduction), extra

    return _ssim_compute(similarity_pack, reduction)
def _get_normalized_sim_and_cs(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Sequence[float]] = 1.5,
    kernel_size: Union[int, Sequence[int]] = 11,
    data_range: Optional[Union[float, tuple[float, float]]] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    normalize: Optional[Literal["relu", "simple"]] = None,
) -> tuple[Tensor, Tensor]:
    """Compute SSIM together with its contrast-sensitivity term, optionally passed through a ReLU.

    Args:
        preds: estimated image
        target: ground truth image
        gaussian_kernel: If true, a gaussian kernel is used, if false a uniform kernel is used
        sigma: Standard deviation of the gaussian kernel
        kernel_size: size of the uniform kernel
        data_range: Range of the image; forwarded to ``_ssim_update``.
        k1: Parameter of SSIM.
        k2: Parameter of SSIM.
        normalize: With ``"relu"`` both outputs are clamped at zero. Any other value, including ``"simple"`` and
            ``None``, leaves them unchanged here.

    Returns:
        A ``(similarity, contrast_sensitivity)`` pair as produced by ``_ssim_update``.
    """
    similarity, contrast = _ssim_update(
        preds,
        target,
        gaussian_kernel=gaussian_kernel,
        sigma=sigma,
        kernel_size=kernel_size,
        data_range=data_range,
        k1=k1,
        k2=k2,
        return_contrast_sensitivity=True,
    )
    if normalize != "relu":
        return similarity, contrast
    return torch.relu(similarity), torch.relu(contrast)
- def _multiscale_ssim_update(
- preds: Tensor,
- target: Tensor,
- gaussian_kernel: bool = True,
- sigma: Union[float, Sequence[float]] = 1.5,
- kernel_size: Union[int, Sequence[int]] = 11,
- data_range: Optional[Union[float, tuple[float, float]]] = None,
- k1: float = 0.01,
- k2: float = 0.03,
- betas: Union[tuple[float, float, float, float, float], tuple[float, ...]] = (
- 0.0448,
- 0.2856,
- 0.3001,
- 0.2363,
- 0.1333,
- ),
- normalize: Optional[Literal["relu", "simple"]] = None,
- ) -> Tensor:
- """Compute Multi-Scale Structural Similarity Index Measure.
- Adapted from: https://github.com/jorge-pessoa/pytorch-msssim/blob/master/pytorch_msssim/__init__.py.
- Args:
- preds: estimated image
- target: ground truth image
- gaussian_kernel: If true, a gaussian kernel is used, if false a uniform kernel is used
- sigma: Standard deviation of the gaussian kernel
- kernel_size: size of the gaussian kernel
- reduction: a method to reduce metric score over labels.
- - ``'elementwise_mean'``: takes the mean
- - ``'sum'``: takes the sum
- - ``'none'`` or ``None``: no reduction will be applied
- data_range: Range of the image. If ``None``, it is determined from the image (max - min)
- k1: Parameter of structural similarity index measure.
- k2: Parameter of structural similarity index measure.
- betas: Exponent parameters for individual similarities and contrastive sensitives returned by different image
- resolutions.
- normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the
- training stability. This `normalize` argument is out of scope of the original implementation [1], and it is
- adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
- Raises:
- ValueError:
- If the image height or width is smaller then ``2 ** len(betas)``.
- ValueError:
- If the image height is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
- ValueError:
- If the image width is smaller than ``(kernel_size[0] - 1) * max(1, (len(betas) - 1)) ** 2``.
- """
- mcs_list: List[Tensor] = []
- is_3d = preds.ndim == 5
- if not isinstance(kernel_size, Sequence):
- kernel_size = 3 * [kernel_size] if is_3d else 2 * [kernel_size]
- if not isinstance(sigma, Sequence):
- sigma = 3 * [sigma] if is_3d else 2 * [sigma]
- if preds.size()[-1] < 2 ** len(betas) or preds.size()[-2] < 2 ** len(betas):
- raise ValueError(
- f"For a given number of `betas` parameters {len(betas)}, the image height and width dimensions must be"
- f" larger than or equal to {2 ** len(betas)}."
- )
- _betas_div = max(1, (len(betas) - 1)) ** 2
- if preds.size()[-2] // _betas_div <= kernel_size[0] - 1:
- raise ValueError(
- f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[0]},"
- f" the image height must be larger than {(kernel_size[0] - 1) * _betas_div}."
- )
- if preds.size()[-1] // _betas_div <= kernel_size[1] - 1:
- raise ValueError(
- f"For a given number of `betas` parameters {len(betas)} and kernel size {kernel_size[1]},"
- f" the image width must be larger than {(kernel_size[1] - 1) * _betas_div}."
- )
- for _ in range(len(betas)):
- sim, contrast_sensitivity = _get_normalized_sim_and_cs(
- preds, target, gaussian_kernel, sigma, kernel_size, data_range, k1, k2, normalize=normalize
- )
- mcs_list.append(contrast_sensitivity)
- if len(kernel_size) == 2:
- preds = F.avg_pool2d(preds, (2, 2))
- target = F.avg_pool2d(target, (2, 2))
- elif len(kernel_size) == 3:
- preds = F.avg_pool3d(preds, (2, 2, 2))
- target = F.avg_pool3d(target, (2, 2, 2))
- else:
- raise ValueError("length of kernel_size is neither 2 nor 3")
- mcs_list[-1] = sim
- mcs_stack = torch.stack(mcs_list)
- if normalize == "simple":
- mcs_stack = (mcs_stack + 1) / 2
- betas = torch.tensor(betas, device=mcs_stack.device).view(-1, 1)
- mcs_weighted = mcs_stack**betas
- return torch.prod(mcs_weighted, axis=0) # type: ignore[call-overload]
def _multiscale_ssim_compute(
    mcs_per_image: Tensor,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Reduce a batch of per-image multi-scale SSIM values.

    Args:
        mcs_per_image: per image similarities for a batch of images.
        reduction: how to combine the individual batch scores; forwarded unchanged to ``reduce``.

            - ``'elementwise_mean'``: takes the mean
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Returns:
        The reduced multi-scale structural similarity
    """
    reduced_score = reduce(mcs_per_image, reduction)
    return reduced_score
def multiscale_structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Sequence[float]] = 1.5,
    kernel_size: Union[int, Sequence[int]] = 11,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    data_range: Optional[Union[float, tuple[float, float]]] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
    normalize: Optional[Literal["relu", "simple"]] = "relu",
) -> Tensor:
    """Compute `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure.

    This metric is a generalization of Structural Similarity Index Measure by incorporating image details at different
    resolution scores.

    Args:
        preds: Predictions from model of shape ``[N, C, H, W]`` (or ``[N, C, D, H, W]``)
        target: Ground truth values of the same shape as ``preds`` (cast to the dtype of ``preds`` if needed)
        gaussian_kernel: If true, a gaussian kernel is used, if false a uniform kernel is used
        sigma: Standard deviation of the gaussian kernel
        kernel_size: size of the uniform kernel (also used for the minimum image size checks)
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

        data_range:
            the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
            the range is calculated as the difference and input is clamped between the values.
        k1: Parameter of structural similarity index measure.
        k2: Parameter of structural similarity index measure.
        betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image
            resolutions. Must be a tuple of floats.
        normalize: When MultiScaleSSIM loss is used for training, it is desirable to use normalizes to improve the
            training stability. This `normalize` argument is out of scope of the original implementation [1], and it is
            adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.

    Return:
        Tensor with Multi-Scale SSIM score

    Raises:
        ValueError:
            If ``betas`` is not a tuple, or contains non-float values.
        ValueError:
            If ``normalize`` is neither ``None`` nor one of ``'relu'`` / ``'simple'``.
        ValueError:
            If ``preds`` and ``target`` don't have ``BxCxHxW`` or ``BxCxDxHxW`` shape.
        ValueError:
            If the length of ``kernel_size`` or ``sigma`` does not match the number of spatial dimensions (2 or 3).
        ValueError:
            If one of the elements of ``kernel_size`` is not an ``odd positive number``.
        ValueError:
            If one of the elements of ``sigma`` is not a ``positive number``.
        ValueError:
            If the image is too small for the given number of ``betas`` and ``kernel_size``.

    Example:
        >>> from torch import rand
        >>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
        >>> preds = rand([3, 3, 256, 256])
        >>> target = preds * 0.75
        >>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
        tensor(0.9628)

    References:
        [1] Multi-Scale Structural Similarity For Image Quality Assessment by Zhou Wang, Eero P. Simoncelli and Alan C.
        Bovik `MultiScaleSSIM`_

    """
    if not isinstance(betas, tuple):
        raise ValueError("Argument `betas` is expected to be of a type tuple.")
    # `betas` is known to be a tuple here, so only the element types remain to be checked.
    if not all(isinstance(beta, float) for beta in betas):
        raise ValueError("Argument `betas` is expected to be a tuple of floats.")
    if normalize and normalize not in ("relu", "simple"):
        raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'")

    preds, target = _ssim_check_inputs(preds, target)
    mcs_per_image = _multiscale_ssim_update(
        preds, target, gaussian_kernel, sigma, kernel_size, data_range, k1, k2, betas, normalize
    )
    return _multiscale_ssim_compute(mcs_per_image, reduction)
|