| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from typing import Optional, Union
- import torch
- from torch import Tensor, tensor
- from torch.nn.functional import conv2d, pad
- from typing_extensions import Literal
- from torchmetrics.utilities.checks import _check_same_shape
- from torchmetrics.utilities.distributed import reduce
def _scc_update(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> tuple[Tensor, Tensor, Tensor]:
    """Validate the Spatial Correlation Coefficient inputs and bring them into a canonical form.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        hp_filter: High-pass filter tensor
        window_size: Local window size integer

    Return:
        Tuple of (preds, target, hp_filter) tensors. ``preds`` and ``target`` are returned as 4D ``float32`` tensors
        (a channel axis is inserted for 3D inputs) and ``hp_filter`` gains two leading singleton dimensions and is
        moved to the dtype and device of ``preds``.

    Raises:
        ValueError:
            If ``preds`` and ``target`` are neither 3D nor 4D
            If ``window_size`` is not greater than zero
            If ``window_size`` is greater than the height or the width of the image

    Shape agreement between ``preds`` and ``target`` is delegated to ``_check_same_shape``; whatever it raises on a
    mismatch is propagated unchanged.

    """
    # Align the dtype of ``target`` with ``preds`` before the shapes are compared.
    if target.dtype != preds.dtype:
        target = target.to(preds.dtype)
    _check_same_shape(preds, target)

    if preds.ndim not in (3, 4):
        raise ValueError(
            "Expected `preds` and `target` to have batch of colored images with BxCxHxW shape"
            " or batch of grayscale images of BxHxW shape."
            f" Got preds: {preds.shape} and target: {target.shape}."
        )

    if preds.ndim == 3:
        # Grayscale batch: insert a singleton channel axis so everything downstream sees BxCxHxW.
        preds, target = preds.unsqueeze(1), target.unsqueeze(1)

    if not window_size > 0:
        raise ValueError(f"Expected `window_size` to be a positive integer. Got {window_size}.")

    height, width = preds.size(2), preds.size(3)
    if window_size > height or window_size > width:
        raise ValueError(
            "Expected `window_size` to be less than or equal to the size of the image."
            f" Got window_size: {window_size} and image size: {height}x{width}."
        )

    # All further computation is carried out in single precision.
    preds, target = preds.to(torch.float32), target.to(torch.float32)
    hp_filter = hp_filter[None, None, :].to(dtype=preds.dtype, device=preds.device)
    return preds, target, hp_filter
- def _symmetric_reflect_pad_2d(input_img: Tensor, pad: Union[int, tuple[int, ...]]) -> Tensor:
- """Applies symmetric padding to the 2D image tensor input using ``reflect`` mode (d c b a | a b c d | d c b a)."""
- if isinstance(pad, int):
- pad = (pad, pad, pad, pad)
- if len(pad) != 4:
- raise ValueError(f"Expected padding to have length 4, but got {len(pad)}")
- left_pad = input_img[:, :, :, 0 : pad[0]].flip(dims=[3])
- right_pad = input_img[:, :, :, -pad[1] :].flip(dims=[3])
- padded = torch.cat([left_pad, input_img, right_pad], dim=3)
- top_pad = padded[:, :, 0 : pad[2], :].flip(dims=[2])
- bottom_pad = padded[:, :, -pad[3] :, :].flip(dims=[2])
- return torch.cat([top_pad, padded, bottom_pad], dim=2)
def _signal_convolve_2d(input_img: Tensor, kernel: Tensor) -> Tensor:
    """Convolve the input with the kernel after symmetric border padding, keeping the spatial size.

    Args:
        input_img: tensor to filter, passed on to ``_symmetric_reflect_pad_2d`` and ``conv2d``.
        kernel: 4D kernel tensor; dimensions 2 and 3 are read as its height and width.

    Return:
        Result of ``conv2d`` applied to the padded input with the spatially flipped kernel.

    """
    kernel_height, kernel_width = kernel.size(2), kernel.size(3)
    # Total padding per axis is ``size - 1``; for even sizes the extra pixel goes to the right/bottom side.
    borders = (
        (kernel_width - 1) // 2,
        kernel_width // 2,
        (kernel_height - 1) // 2,
        kernel_height // 2,
    )
    padded_img = _symmetric_reflect_pad_2d(input_img, pad=borders)
    # ``conv2d`` evaluates a cross-correlation, so the kernel is flipped along both spatial axes first.
    flipped_kernel = kernel.flip([2, 3])
    return conv2d(padded_img, flipped_kernel, stride=1, padding=0)
def _hp_2d_laplacian(input_img: Tensor, kernel: Tensor) -> Tensor:
    """Filter the input with the given high-pass kernel and scale the response by two.

    Args:
        input_img: tensor to filter.
        kernel: high-pass filter kernel forwarded to ``_signal_convolve_2d``.

    Return:
        The filtered tensor multiplied by ``2.0``.

    """
    filtered = _signal_convolve_2d(input_img, kernel)
    return filtered * 2.0
- def _local_variance_covariance(preds: Tensor, target: Tensor, window: Tensor) -> tuple[Tensor, Tensor, Tensor]:
- """Computes local variance and covariance of the input tensors."""
- # This code is inspired by
- # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.
- left_padding = math.ceil((window.size(3) - 1) / 2)
- right_padding = math.floor((window.size(3) - 1) / 2)
- preds = pad(preds, (left_padding, right_padding, left_padding, right_padding))
- target = pad(target, (left_padding, right_padding, left_padding, right_padding))
- preds_mean = conv2d(preds, window, stride=1, padding=0)
- target_mean = conv2d(target, window, stride=1, padding=0)
- preds_var = conv2d(preds**2, window, stride=1, padding=0) - preds_mean**2
- target_var = conv2d(target**2, window, stride=1, padding=0) - target_mean**2
- target_preds_cov = conv2d(target * preds, window, stride=1, padding=0) - target_mean * preds_mean
- return preds_var, target_var, target_preds_cov
def _scc_per_channel_compute(preds: Tensor, target: Tensor, hp_filter: Tensor, window_size: int) -> Tensor:
    """Compute the Spatial Correlation Coefficient map of a single channel.

    Args:
        preds: estimated image of Bx1xHxW shape.
        target: ground truth image of Bx1xHxW shape.
        hp_filter: 2D high-pass filter.
        window_size: size of window for local mean calculation.

    Return:
        Tensor with the per-position Spatial Correlation Coefficient, i.e. the local covariance of the high-pass
        filtered images divided by the product of their local standard deviations. Positions where that product
        is zero are set to zero.

    """
    # This code is inspired by
    # https://github.com/andrewekhalel/sewar/blob/master/sewar/full_ref.py#L187.
    box_window = torch.ones((1, 1, window_size, window_size), dtype=preds.dtype, device=preds.device)
    box_window = box_window / (window_size**2)

    hp_preds = _hp_2d_laplacian(preds, hp_filter)
    hp_target = _hp_2d_laplacian(target, hp_filter)
    var_preds, var_target, cov_target_preds = _local_variance_covariance(hp_preds, hp_target, box_window)

    # Negative variance estimates are clipped to zero so the square roots below stay real.
    var_preds = var_preds.masked_fill(var_preds < 0, 0)
    var_target = var_target.masked_fill(var_target < 0, 0)

    denominator = torch.sqrt(var_target) * torch.sqrt(var_preds)
    degenerate = denominator == 0
    # Divide by one where the denominator vanishes, then overwrite those positions with zero.
    safe_denominator = denominator.masked_fill(degenerate, 1)
    return (cov_target_preds / safe_denominator).masked_fill(degenerate, 0)
def spatial_correlation_coefficient(
    preds: Tensor,
    target: Tensor,
    hp_filter: Optional[Tensor] = None,
    window_size: int = 8,
    reduction: Optional[Literal["mean", "none", None]] = "mean",
) -> Tensor:
    """Compute Spatial Correlation Coefficient (SCC_).

    Args:
        preds: predicted images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
        target: ground truth images of shape ``(N,C,H,W)`` or ``(N,H,W)``.
        hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]])
        window_size: Local window size integer. default: 8,
        reduction: Reduction method for output tensor. If ``None`` or ``"none"``,
            returns a tensor with the per sample results. default: ``"mean"``.

    Return:
        Tensor with scc score

    Raises:
        ValueError:
            If ``reduction`` is not one of ``"mean"``, ``"none"`` or ``None``

    Further errors raised while validating ``preds``, ``target`` and ``window_size`` in ``_scc_update`` are
    propagated unchanged.

    Example:
        >>> from torch import randn
        >>> from torchmetrics.functional.image import spatial_correlation_coefficient as scc
        >>> x = randn(5, 3, 16, 16)
        >>> scc(x, x)
        tensor(1.)
        >>> x = randn(5, 16, 16)
        >>> scc(x, x)
        tensor(1.)
        >>> x = randn(5, 3, 16, 16)
        >>> y = randn(5, 3, 16, 16)
        >>> scc(x, y, reduction="none")
        tensor([0.0223, 0.0256, 0.0616, 0.0159, 0.0170])

    """
    if hp_filter is None:
        hp_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
    if reduction is None:
        reduction = "none"
    if reduction not in ("mean", "none"):
        raise ValueError(f"Expected reduction to be 'mean' or 'none', but got {reduction}")

    preds, target, hp_filter = _scc_update(preds, target, hp_filter, window_size)

    # Each channel is scored independently on a Bx1xHxW slice; the maps are stacked back along the channel axis.
    per_channel = [
        _scc_per_channel_compute(
            preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, window_size
        )
        for i in range(preds.size(1))
    ]
    scc_map = torch.cat(per_channel, dim=1)

    if reduction == "none":
        # One score per sample: average over channels and both spatial axes.
        return torch.mean(scc_map, dim=[1, 2, 3])
    # ``reduction`` was validated above, so the only remaining option is "mean".
    return reduce(scc_map, reduction="elementwise_mean")
|