| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Any, Optional
- import torch
- from torch import Tensor, tensor
- from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute
- from torchmetrics.functional.image.scc import _scc_update
- from torchmetrics.metric import Metric
class SpatialCorrelationCoefficient(Metric):
    """Compute Spatial Correlation Coefficient (SCC_).

    As input to ``forward`` and ``update`` the metric accepts the following input

    - ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` or ``(N,H,W)``.
    - ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` or ``(N,H,W)``.

    As output of ``forward`` and ``compute`` the metric returns the following output

    - ``scc`` (:class:`~torch.Tensor`): Tensor with the SCC score, averaged over all samples seen so far
      (each sample's score is itself the mean over its channels).

    Args:
        high_pass_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]).
        window_size: Local window size integer. default: 8.
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Example:
        >>> from torch import randn
        >>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC
        >>> preds = randn([32, 3, 64, 64])
        >>> target = randn([32, 3, 64, 64])
        >>> scc = SCC()
        >>> scc(preds, target)
        tensor(0.0023)

    """

    is_differentiable = True
    higher_is_better = True
    full_state_update = False

    # Running sum of per-sample SCC scores.
    scc_score: Tensor
    # Number of samples accumulated into ``scc_score``.
    total: Tensor

    def __init__(self, high_pass_filter: Optional[Tensor] = None, window_size: int = 8, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        # Default kernel: 3x3 with 8 at the centre and -1 everywhere else.
        if high_pass_filter is None:
            high_pass_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])

        # Configuration is kept as plain attributes rather than registered metric states;
        # only the accumulators below are states.
        self.hp_filter = high_pass_filter
        self.ws = window_size

        # Both accumulators are combined with a sum across processes, so the final ratio in
        # ``compute`` is the mean over every sample seen anywhere.
        self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        Every channel is scored independently; the per-channel results are averaged per sample,
        and the per-sample scores are added to the running total.

        Args:
            preds: Predictions of shape ``(N,C,H,W)`` or ``(N,H,W)``.
            target: Ground truth of the same shape as ``preds``.

        """
        # ``_scc_update`` prepares inputs and the filter. The channel split below relies on it
        # returning 4-D ``(N, C, H, W)`` tensors (presumably lifting ``(N, H, W)`` inputs to a
        # single channel) — TODO confirm against the functional implementation.
        preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws)

        # ``split(1, dim=1)`` yields one ``(N, 1, H, W)`` view per channel, identical to
        # ``x[:, i, :, :].unsqueeze(1)`` without manual indexing.
        scc_per_channel = [
            _scc_compute(pred_channel, target_channel, hp_filter, self.ws)
            for pred_channel, target_channel in zip(preds.split(1, dim=1), target.split(1, dim=1))
        ]

        # Re-join along the channel dim, then average over every dim except the batch dim so
        # each sample contributes exactly one score.
        per_sample_scc = torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1, 2, 3])
        self.scc_score += torch.sum(per_sample_scc)
        self.total += preds.size(0)

    def compute(self) -> Tensor:
        """Compute the SCC score based on inputs passed in to ``update`` previously.

        Returns:
            The mean per-sample SCC score. If ``update`` was never called both accumulators are
            zero and the division yields ``nan``.

        """
        return self.scc_score / self.total
|