| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import List
- from kornia.core import Module, Tensor, pad
- from kornia.core.check import KORNIA_CHECK, KORNIA_CHECK_IS_TENSOR, KORNIA_CHECK_SHAPE
- from kornia.filters import filter3d, get_gaussian_kernel3d
- from kornia.filters.filter import _compute_padding
- def _crop(img: Tensor, cropping_shape: List[int]) -> Tensor:
- """Crop out the part of "valid" convolution area."""
- return pad(
- img,
- (
- -cropping_shape[4],
- -cropping_shape[5],
- -cropping_shape[2],
- -cropping_shape[3],
- -cropping_shape[0],
- -cropping_shape[1],
- ),
- )
- def ssim3d(
- img1: Tensor, img2: Tensor, window_size: int, max_val: float = 1.0, eps: float = 1e-12, padding: str = "same"
- ) -> Tensor:
- r"""Compute the Structural Similarity (SSIM) index map between two images.
- Measures the (SSIM) index between each element in the input `x` and target `y`.
- The index can be described as:
- .. math::
- \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
- {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}
- where:
- - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
- stabilize the division with weak denominator.
- - :math:`L` is the dynamic range of the pixel-values (typically this is
- :math:`2^{\#\text{bits per pixel}}-1`).
- Args:
- img1: the first input image with shape :math:`(B, C, D, H, W)`.
- img2: the second input image with shape :math:`(B, C, D, H, W)`.
- window_size: the size of the gaussian kernel to smooth the images.
- max_val: the dynamic range of the images.
- eps: Small value for numerically stability when dividing.
- padding: ``'same'`` | ``'valid'``. Whether to only use the "valid" convolution
- area to compute SSIM to match the MATLAB implementation of original SSIM paper.
- Returns:
- The ssim index map with shape :math:`(B, C, D, H, W)`.
- Examples:
- >>> input1 = torch.rand(1, 4, 5, 5, 5)
- >>> input2 = torch.rand(1, 4, 5, 5, 5)
- >>> ssim_map = ssim3d(input1, input2, 5) # 1x4x5x5x5
- """
- KORNIA_CHECK_IS_TENSOR(img1)
- KORNIA_CHECK_IS_TENSOR(img2)
- KORNIA_CHECK_SHAPE(img1, ["B", "C", "D", "H", "W"])
- KORNIA_CHECK_SHAPE(img2, ["B", "C", "D", "H", "W"])
- KORNIA_CHECK(img1.shape == img2.shape, f"img1 and img2 shapes must be the same. Got: {img1.shape} and {img2.shape}")
- if not isinstance(max_val, float):
- raise TypeError(f"Input max_val type is not a float. Got {type(max_val)}")
- # prepare kernel
- kernel: Tensor = get_gaussian_kernel3d((window_size, window_size, window_size), (1.5, 1.5, 1.5))
- # compute coefficients
- C1: float = (0.01 * max_val) ** 2
- C2: float = (0.03 * max_val) ** 2
- # compute local mean per channel
- mu1: Tensor = filter3d(img1, kernel)
- mu2: Tensor = filter3d(img2, kernel)
- cropping_shape: List[int] = []
- if padding == "valid":
- depth, height, width = kernel.shape[-3:]
- cropping_shape = _compute_padding([depth, height, width])
- mu1 = _crop(mu1, cropping_shape)
- mu2 = _crop(mu2, cropping_shape)
- elif padding == "same":
- pass
- mu1_sq = mu1**2
- mu2_sq = mu2**2
- mu1_mu2 = mu1 * mu2
- mu_img1_sq = filter3d(img1**2, kernel)
- mu_img2_sq = filter3d(img2**2, kernel)
- mu_img1_img2 = filter3d(img1 * img2, kernel)
- if padding == "valid":
- mu_img1_sq = _crop(mu_img1_sq, cropping_shape)
- mu_img2_sq = _crop(mu_img2_sq, cropping_shape)
- mu_img1_img2 = _crop(mu_img1_img2, cropping_shape)
- elif padding == "same":
- pass
- # compute local sigma per channel
- sigma1_sq = mu_img1_sq - mu1_sq
- sigma2_sq = mu_img2_sq - mu2_sq
- sigma12 = mu_img1_img2 - mu1_mu2
- # compute the similarity index map
- num: Tensor = (2.0 * mu1_mu2 + C1) * (2.0 * sigma12 + C2)
- den: Tensor = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
- return num / (den + eps)
- class SSIM3D(Module):
- r"""Create a module that computes the Structural Similarity (SSIM) index between two 3D images.
- Measures the (SSIM) index between each element in the input `x` and target `y`.
- The index can be described as:
- .. math::
- \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
- {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}
- where:
- - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
- stabilize the division with weak denominator.
- - :math:`L` is the dynamic range of the pixel-values (typically this is
- :math:`2^{\#\text{bits per pixel}}-1`).
- Args:
- window_size: the size of the gaussian kernel to smooth the images.
- max_val: the dynamic range of the images.
- eps: Small value for numerically stability when dividing.
- padding: ``'same'`` | ``'valid'``. Whether to only use the "valid" convolution
- area to compute SSIM to match the MATLAB implementation of original SSIM paper.
- Shape:
- - Input: :math:`(B, C, D, H, W)`.
- - Target :math:`(B, C, D, H, W)`.
- - Output: :math:`(B, C, D, H, W)`.
- Examples:
- >>> input1 = torch.rand(1, 4, 5, 5, 5)
- >>> input2 = torch.rand(1, 4, 5, 5, 5)
- >>> ssim = SSIM3D(5)
- >>> ssim_map = ssim(input1, input2) # 1x4x5x5x5
- """
- def __init__(self, window_size: int, max_val: float = 1.0, eps: float = 1e-12, padding: str = "same") -> None:
- super().__init__()
- self.window_size: int = window_size
- self.max_val: float = max_val
- self.eps = eps
- self.padding = padding
- def forward(self, img1: Tensor, img2: Tensor) -> Tensor:
- return ssim3d(img1, img2, self.window_size, self.max_val, self.eps, self.padding)
|