scc.py 3.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Optional
  15. import torch
  16. from torch import Tensor, tensor
  17. from torchmetrics.functional.image.scc import _scc_per_channel_compute as _scc_compute
  18. from torchmetrics.functional.image.scc import _scc_update
  19. from torchmetrics.metric import Metric
  20. class SpatialCorrelationCoefficient(Metric):
  21. """Compute Spatial Correlation Coefficient (SCC_).
  22. As input to ``forward`` and ``update`` the metric accepts the following input
  23. - ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)`` or ``(N,H,W)``.
  24. - ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)`` or ``(N,H,W)``.
  25. As output of `forward` and `compute` the metric returns the following output
  26. - ``scc`` (:class:`~torch.Tensor`): Tensor with scc score
  27. Args:
  28. hp_filter: High-pass filter tensor. default: tensor([[-1,-1,-1],[-1,8,-1],[-1,-1,-1]]).
  29. window_size: Local window size integer. default: 8.
  30. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  31. Example:
  32. >>> from torch import randn
  33. >>> from torchmetrics.image import SpatialCorrelationCoefficient as SCC
  34. >>> preds = randn([32, 3, 64, 64])
  35. >>> target = randn([32, 3, 64, 64])
  36. >>> scc = SCC()
  37. >>> scc(preds, target)
  38. tensor(0.0023)
  39. """
  40. is_differentiable = True
  41. higher_is_better = True
  42. full_state_update = False
  43. scc_score: Tensor
  44. total: Tensor
  45. def __init__(self, high_pass_filter: Optional[Tensor] = None, window_size: int = 8, **kwargs: Any) -> None:
  46. super().__init__(**kwargs)
  47. if high_pass_filter is None:
  48. high_pass_filter = tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]])
  49. self.hp_filter = high_pass_filter
  50. self.ws = window_size
  51. self.add_state("scc_score", default=tensor(0.0), dist_reduce_fx="sum")
  52. self.add_state("total", default=tensor(0.0), dist_reduce_fx="sum")
  53. def update(self, preds: Tensor, target: Tensor) -> None:
  54. """Update state with predictions and targets."""
  55. preds, target, hp_filter = _scc_update(preds, target, self.hp_filter, self.ws)
  56. scc_per_channel = [
  57. _scc_compute(preds[:, i, :, :].unsqueeze(1), target[:, i, :, :].unsqueeze(1), hp_filter, self.ws)
  58. for i in range(preds.size(1))
  59. ]
  60. self.scc_score += torch.sum(torch.mean(torch.cat(scc_per_channel, dim=1), dim=[1, 2, 3]))
  61. self.total += preds.size(0)
  62. def compute(self) -> Tensor:
  63. """Compute the VIF score based on inputs passed in to ``update`` previously."""
  64. return self.scc_score / self.total