sam.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.utilities.checks import _check_same_shape
  18. from torchmetrics.utilities.distributed import reduce
  19. def _sam_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
  20. """Update and returns variables required to compute Spectral Angle Mapper.
  21. Args:
  22. preds: Predicted tensor
  23. target: Ground truth tensor
  24. """
  25. if preds.dtype != target.dtype:
  26. raise TypeError(
  27. "Expected `preds` and `target` to have the same data type."
  28. f" Got preds: {preds.dtype} and target: {target.dtype}."
  29. )
  30. _check_same_shape(preds, target)
  31. if len(preds.shape) != 4:
  32. raise ValueError(
  33. f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}."
  34. )
  35. if (preds.shape[1] <= 1) or (target.shape[1] <= 1):
  36. raise ValueError(
  37. "Expected channel dimension of `preds` and `target` to be larger than 1."
  38. f" Got preds: {preds.shape[1]} and target: {target.shape[1]}."
  39. )
  40. return preds, target
  41. def _sam_compute(
  42. preds: Tensor,
  43. target: Tensor,
  44. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  45. ) -> Tensor:
  46. """Compute Spectral Angle Mapper.
  47. Args:
  48. preds: estimated image
  49. target: ground truth image
  50. reduction: a method to reduce metric score over labels.
  51. - ``'elementwise_mean'``: takes the mean (default)
  52. - ``'sum'``: takes the sum
  53. - ``'none'`` or ``None``: no reduction will be applied
  54. Example:
  55. >>> from torch import rand
  56. >>> preds = rand([16, 3, 16, 16])
  57. >>> target = rand([16, 3, 16, 16])
  58. >>> preds, target = _sam_update(preds, target)
  59. >>> _sam_compute(preds, target)
  60. tensor(0.5914)
  61. """
  62. dot_product = (preds * target).sum(dim=1)
  63. preds_norm = preds.norm(dim=1)
  64. target_norm = target.norm(dim=1)
  65. sam_score = torch.clamp(dot_product / (preds_norm * target_norm), -1, 1).acos()
  66. return reduce(sam_score, reduction)
  67. def spectral_angle_mapper(
  68. preds: Tensor,
  69. target: Tensor,
  70. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  71. ) -> Tensor:
  72. """Universal Spectral Angle Mapper.
  73. Args:
  74. preds: estimated image
  75. target: ground truth image
  76. reduction: a method to reduce metric score over labels.
  77. - ``'elementwise_mean'``: takes the mean (default)
  78. - ``'sum'``: takes the sum
  79. - ``'none'`` or ``None``: no reduction will be applied
  80. Return:
  81. Tensor with Spectral Angle Mapper score
  82. Raises:
  83. TypeError:
  84. If ``preds`` and ``target`` don't have the same data type.
  85. ValueError:
  86. If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
  87. Example:
  88. >>> from torch import rand
  89. >>> from torchmetrics.functional.image import spectral_angle_mapper
  90. >>> preds = rand([16, 3, 16, 16],)
  91. >>> target = rand([16, 3, 16, 16])
  92. >>> spectral_angle_mapper(preds, target)
  93. tensor(0.5914)
  94. References:
  95. [1] Roberta H. Yuhas, Alexander F. H. Goetz and Joe W. Boardman, "Discrimination among semi-arid
  96. landscape endmembers using the Spectral Angle Mapper (SAM) algorithm" in PL, Summaries of the Third Annual JPL
  97. Airborne Geoscience Workshop, vol. 1, June 1, 1992.
  98. """
  99. preds, target = _sam_update(preds, target)
  100. return _sam_compute(preds, target, reduction)