scc.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. from typing import Optional, Union
  16. import torch
  17. from torch import Tensor, tensor
  18. from torch.nn.functional import conv2d, pad
  19. from typing_extensions import Literal
  20. from torchmetrics.utilities.checks import _check_same_shape
  21. from torchmetrics.utilities.distributed import reduce
  22. def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> tuple[Tensor, Tensor, Tensor]:
  23. """Update and returns variables required to compute Spatial Correlation Coefficient.
  24. Args:
  25. preds: Predicted tensor
  26. target: Ground truth tensor
  27. hp_filter: High-pass filter tensor
  28. window_size: Local window size integer
  29. Return:
  30. Tuple of (preds, target, hp_filter) tensors
  31. Raises:
  32. ValueError:
  33. If ``preds`` and ``target`` have different number of channels
  34. If ``preds`` and ``target`` have different shapes
  35. If ``preds`` and ``target`` have invalid shapes
  36. If ``window_size`` is not a positive integer
  37. If ``window_size`` is greater than the size of the image
  38. """
  39. if preds.dtype != target.dtype:
  40. target = target.to(preds.dtype)
  41. _check_same_shape(preds, target)
  42. if preds.ndim not in (3, 4):
  43. raise ValueError(
  44. "Expected `preds` and `target` to have batch of colored images with BxCxHxW shape"
  45. " or batch of grayscale images of BxHxW shape."
  46. f" Got preds: {preds.shape} and target: {target.shape}."
  47. )
  48. if len(preds.shape) == 3:
  49. preds = preds.unsqueeze(1)
  50. target = target.unsqueeze(1)
  51. if not window_size > 0:
  52. raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.")
  53. if window_size > preds.size(2) or window_size > preds.size(3):
  54. raise ValueError(
  55. f"Expected `window_size` to be less than or equal to the size of the image."
  56. f" Got window_size: {window_size} and image size: {preds.size(2)}x{preds.size(3)}."
  57. )
  58. preds = preds.to(torch.float32)
  59. target = target.to(torch.float32)
  60. hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device)
  61. return preds, target, hp_filter
  62. def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, tuple[int, ...]]) -> Tensor:
  63. """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a)."""
  64. if isinstance(pad, int):
  65. pad = (pad, pad, pad, pad)
  66. if len(pad) != 4:
  67. raise ValueError(f"Expected padding to have length 4, but got {len(pad)}")
  68. left_pad = input_img[:, :, :, 0 : pad[0]].flip(dims=[3])
  69. right_pad = input_img[:, :, :, -pad[1] :].flip(dims=[3])
  70. padded = torch.cat([left_pad, input_img, right_pad], dim=3)
  71. top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2])
  72. bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2])
  73. return torch.cat([top_pad, padded, bottom_pad], dim=2)
  74. def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor:
  75. """Applies 2D signal convolution to the input tensor with the given kernel."""
  76. left_padding = math.floor((kernel.size(3) - 1) / 2)
  77. right_padding = math.ceil((kernel.size(3) - 1) / 2)
  78. top_padding = math.floor((kernel.size(2) - 1) / 2)
  79. bottom_padding = math.ceil((kernel.size(2) - 1) / 2)
  80. padded = _symmetric_reflect_pad_2d(input_img, pad=(left_padding, right_padding, top_padding, bottom_padding))
  81. kernel = kernel.flip([2, 3])
  82. return conv2d(padded, kernel, stride=1, padding=0)
  83. def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor:
  84. """Applies 2-D Laplace filter to the input tensor with the given high pass filter."""
  85. return _signal_convolve_2d(input_img, kernel) * 2.0
  86. def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> tuple[Tensor, Tensor, Tensor]:
  87. """Computes local variance and covariance of the input tensors."""
  88. # This code is inspired by
  89. # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.
  90. left_padding = math.ceil((window.size(3) - 1) / 2)
  91. right_padding = math.floor((window.size(3) - 1) / 2)
  92. preds = pad(preds, (left_padding, right_padding, left_padding, right_padding))
  93. target = pad(target, (left_padding, right_padding, left_padding, right_padding))
  94. preds_mean = conv2d(preds, window, stride=1, padding=0)
  95. target_mean = conv2d(target, window, stride=1, padding=0)
  96. preds_var = conv2d(preds**2, window, stride=1, padding=0) - preds_mean**2
  97. target_var = conv2d(target**2, window, stride=1, padding=0) - target_mean**2
  98. target_preds_cov = conv2d(target * preds, window, stride=1, padding=0) - target_mean * preds_mean
  99. return preds_var, target_var, target_preds_cov
  100. def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tensor:
  101. """Computes per channel Spatial Correlation Coefficient.
  102. Args:
  103. preds: estimated image of Bx1xHxW shape.
  104. target: ground truth image of Bx1xHxW shape.
  105. hp_filter: 2D high-pass filter.
  106. window_size: size of window for local mean calculation.
  107. Return:
  108. Tensor with Spatial Correlation Coefficient score
  109. """
  110. dtype = preds.dtype
  111. device = preds.device
  112. # This code is inspired by
  113. # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.
  114. window = torch.ones(size=(1, 1, window_size, window_size), dtype=dtype, device=device) / (window_size**2)
  115. preds_hp = _hp_2d_laplacian(preds, hp_filter)
  116. target_hp = _hp_2d_laplacian(target, hp_filter)
  117. preds_var, target_var, target_preds_cov = _local_variance_covariance(preds_hp, target_hp, window)
  118. preds_var[preds_var < 0] = 0
  119. target_var[target_var < 0] = 0
  120. den = torch.sqrt(target_var) * torch.sqrt(preds_var)
  121. idx = den == 0
  122. den[den == 0] = 1
  123. scc = target_preds_cov / den
  124. scc[idx] = 0
  125. return scc
  126. def spatial_correlation_coefficient(
  127. preds: Tensor,
  128. target: Tensor,
  129. hp_filter: Optional[Tensor] = None,
  130. window_size: int = 8,
  131. reduction: Optional[Literal["mean", "none", None]] = "mean",
  132. ) -> Tensor:
  133. """Compute Spatial Correlation Coefficient (SCC_).
  134. Args:
  135. preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
  136. target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
  137. hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]])
  138. window_size: Local window size integer. default: 8,
  139. reduction: Reduction method for output tensor. If ``None`` or ``"none"``,
  140. returns a tensor with the per sample results. default: ``"mean"``.
  141. Return:
  142. Tensor with scc score
  143. Example:
  144. >>> from torch import randn
  145. >>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc
  146. >>> x = randn(5, 3, 16, 16)
  147. >>> scc(x, x)
  148. tensor(1.)
  149. >>> x = randn(5, 16, 16)
  150. >>> scc(x, x)
  151. tensor(1.)
  152. >>> x = randn(5, 3, 16, 16)
  153. >>> y = randn(5, 3, 16, 16)
  154. >>> scc(x, y, reduction="none")
  155. tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170])
  156. """
  157. if hp_filter is None:
  158. hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
  159. if reduction is None:
  160. reduction = "none"
  161. if reduction not in ("mean", "none"):
  162. raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}")
  163. preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size)
  164. per_channel = [
  165. _scc_per_channel_compute(
  166. preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size
  167. )
  168. for i in range(preds.size(1))
  169. ]
  170. if reduction == "none":
  171. return torch.mean(torch.cat(per_channel, dim=1), dim=[1, 2, 3])
  172. if reduction == "mean":
  173. return reduce(torch.cat(per_channel, dim=1), reduction="elementwise_mean")
  174. return None