ssim.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import List
  18. import torch
  19. from torch import nn
  20. from kornia.filters import filter2d_separable, get_gaussian_kernel1d
  21. from kornia.filters.filter import _compute_padding
  22. def _crop(img: torch.Tensor, cropping_shape: List[int]) -> torch.Tensor:
  23. """Crop out the part of "valid" convolution area."""
  24. return torch.nn.functional.pad(
  25. img, (-cropping_shape[2], -cropping_shape[3], -cropping_shape[0], -cropping_shape[1])
  26. )
  27. def ssim(
  28. img1: torch.Tensor,
  29. img2: torch.Tensor,
  30. window_size: int,
  31. max_val: float = 1.0,
  32. eps: float = 1e-12,
  33. padding: str = "same",
  34. ) -> torch.Tensor:
  35. r"""Compute the Structural Similarity (SSIM) index map between two images.
  36. Measures the (SSIM) index between each element in the input `x` and target `y`.
  37. The index can be described as:
  38. .. math::
  39. \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
  40. {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}
  41. where:
  42. - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
  43. stabilize the division with weak denominator.
  44. - :math:`L` is the dynamic range of the pixel-values (typically this is
  45. :math:`2^{\#\text{bits per pixel}}-1`).
  46. Args:
  47. img1: the first input image with shape :math:`(B, C, H, W)`.
  48. img2: the second input image with shape :math:`(B, C, H, W)`.
  49. window_size: the size of the gaussian kernel to smooth the images.
  50. max_val: the dynamic range of the images.
  51. eps: Small value for numerically stability when dividing.
  52. padding: ``'same'`` | ``'valid'``. Whether to only use the "valid" convolution
  53. area to compute SSIM to match the MATLAB implementation of original SSIM paper.
  54. Returns:
  55. The ssim index map with shape :math:`(B, C, H, W)`.
  56. Examples:
  57. >>> input1 = torch.rand(1, 4, 5, 5)
  58. >>> input2 = torch.rand(1, 4, 5, 5)
  59. >>> ssim_map = ssim(input1, input2, 5) # 1x4x5x5
  60. """
  61. if not isinstance(img1, torch.Tensor):
  62. raise TypeError(f"Input img1 type is not a torch.Tensor. Got {type(img1)}")
  63. if not isinstance(img2, torch.Tensor):
  64. raise TypeError(f"Input img2 type is not a torch.Tensor. Got {type(img2)}")
  65. if not isinstance(max_val, float):
  66. raise TypeError(f"Input max_val type is not a float. Got {type(max_val)}")
  67. if not len(img1.shape) == 4:
  68. raise ValueError(f"Invalid img1 shape, we expect BxCxHxW. Got: {img1.shape}")
  69. if not len(img2.shape) == 4:
  70. raise ValueError(f"Invalid img2 shape, we expect BxCxHxW. Got: {img2.shape}")
  71. if not img1.shape == img2.shape:
  72. raise ValueError(f"img1 and img2 shapes must be the same. Got: {img1.shape} and {img2.shape}")
  73. # prepare kernel
  74. kernel: torch.Tensor = get_gaussian_kernel1d(window_size, 1.5, device=img1.device, dtype=img1.dtype)
  75. # compute coefficients
  76. C1: float = (0.01 * max_val) ** 2
  77. C2: float = (0.03 * max_val) ** 2
  78. # compute local mean per channel
  79. mu1: torch.Tensor = filter2d_separable(img1, kernel, kernel)
  80. mu2: torch.Tensor = filter2d_separable(img2, kernel, kernel)
  81. cropping_shape: List[int] = []
  82. if padding == "valid":
  83. height = width = kernel.shape[-1]
  84. cropping_shape = _compute_padding([height, width])
  85. mu1 = _crop(mu1, cropping_shape)
  86. mu2 = _crop(mu2, cropping_shape)
  87. elif padding == "same":
  88. pass
  89. mu1_sq = mu1**2
  90. mu2_sq = mu2**2
  91. mu1_mu2 = mu1 * mu2
  92. mu_img1_sq = filter2d_separable(img1**2, kernel, kernel)
  93. mu_img2_sq = filter2d_separable(img2**2, kernel, kernel)
  94. mu_img1_img2 = filter2d_separable(img1 * img2, kernel, kernel)
  95. if padding == "valid":
  96. mu_img1_sq = _crop(mu_img1_sq, cropping_shape)
  97. mu_img2_sq = _crop(mu_img2_sq, cropping_shape)
  98. mu_img1_img2 = _crop(mu_img1_img2, cropping_shape)
  99. elif padding == "same":
  100. pass
  101. # compute local sigma per channel
  102. sigma1_sq = mu_img1_sq - mu1_sq
  103. sigma2_sq = mu_img2_sq - mu2_sq
  104. sigma12 = mu_img1_img2 - mu1_mu2
  105. # compute the similarity index map
  106. num: torch.Tensor = (2.0 * mu1_mu2 + C1) * (2.0 * sigma12 + C2)
  107. den: torch.Tensor = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
  108. return num / (den + eps)
  109. class SSIM(nn.Module):
  110. r"""Create a module that computes the Structural Similarity (SSIM) index between two images.
  111. Measures the (SSIM) index between each element in the input `x` and target `y`.
  112. The index can be described as:
  113. .. math::
  114. \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)}
  115. {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)}
  116. where:
  117. - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to
  118. stabilize the division with weak denominator.
  119. - :math:`L` is the dynamic range of the pixel-values (typically this is
  120. :math:`2^{\#\text{bits per pixel}}-1`).
  121. Args:
  122. window_size: the size of the gaussian kernel to smooth the images.
  123. max_val: the dynamic range of the images.
  124. eps: Small value for numerically stability when dividing.
  125. padding: ``'same'`` | ``'valid'``. Whether to only use the "valid" convolution
  126. area to compute SSIM to match the MATLAB implementation of original SSIM paper.
  127. Shape:
  128. - Input: :math:`(B, C, H, W)`.
  129. - Target :math:`(B, C, H, W)`.
  130. - Output: :math:`(B, C, H, W)`.
  131. Examples:
  132. >>> input1 = torch.rand(1, 4, 5, 5)
  133. >>> input2 = torch.rand(1, 4, 5, 5)
  134. >>> ssim = SSIM(5)
  135. >>> ssim_map = ssim(input1, input2) # 1x4x5x5
  136. """
  137. def __init__(self, window_size: int, max_val: float = 1.0, eps: float = 1e-12, padding: str = "same") -> None:
  138. super().__init__()
  139. self.window_size: int = window_size
  140. self.max_val: float = max_val
  141. self.eps = eps
  142. self.padding = padding
  143. def forward(self, img1: torch.Tensor, img2: torch.Tensor) -> torch.Tensor:
  144. return ssim(img1, img2, self.window_size, self.max_val, self.eps, self.padding)