d_lambda.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import torch
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.functional.image.uqi import universal_image_quality_index
  18. from torchmetrics.utilities.distributed import reduce
  19. def _spectral_distortion_index_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
  20. """Update and returns variables required to compute Spectral Distortion Index.
  21. Args:
  22. preds: Low resolution multispectral image
  23. target: High resolution fused image
  24. """
  25. if preds.dtype != target.dtype:
  26. raise TypeError(
  27. f"Expected `ms` and `fused` to have the same data type. Got ms: {preds.dtype} and fused: {target.dtype}."
  28. )
  29. if len(preds.shape) != 4:
  30. raise ValueError(
  31. f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}."
  32. )
  33. if preds.shape[:2] != target.shape[:2]:
  34. raise ValueError(
  35. "Expected `preds` and `target` to have same batch and channel sizes."
  36. f"Got preds: {preds.shape} and target: {target.shape}."
  37. )
  38. return preds, target
  39. def _spectral_distortion_index_compute(
  40. preds: Tensor,
  41. target: Tensor,
  42. p: int = 1,
  43. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  44. ) -> Tensor:
  45. """Compute Spectral Distortion Index (SpectralDistortionIndex_).
  46. Args:
  47. preds: Low resolution multispectral image
  48. target: High resolution fused image
  49. p: a parameter to emphasize large spectral difference
  50. reduction: a method to reduce metric score over labels.
  51. - ``'elementwise_mean'``: takes the mean (default)
  52. - ``'sum'``: takes the sum
  53. - ``'none'``: no reduction will be applied
  54. Example:
  55. >>> from torch import rand
  56. >>> preds = rand([16, 3, 16, 16])
  57. >>> target = rand([16, 3, 16, 16])
  58. >>> preds, target = _spectral_distortion_index_update(preds, target)
  59. >>> _spectral_distortion_index_compute(preds, target)
  60. tensor(0.0234)
  61. """
  62. length = preds.shape[1]
  63. m1 = torch.zeros((length, length), device=preds.device)
  64. m2 = torch.zeros((length, length), device=preds.device)
  65. for k in range(length):
  66. num = length - (k + 1)
  67. if num == 0:
  68. continue
  69. stack1 = target[:, k : k + 1, :, :].repeat(num, 1, 1, 1)
  70. stack2 = torch.cat([target[:, r : r + 1, :, :] for r in range(k + 1, length)], dim=0)
  71. score = [
  72. s.mean() for s in universal_image_quality_index(stack1, stack2, reduction="none").split(preds.shape[0])
  73. ]
  74. m1[k, k + 1 :] = torch.stack(score, 0)
  75. stack1 = preds[:, k : k + 1, :, :].repeat(num, 1, 1, 1)
  76. stack2 = torch.cat([preds[:, r : r + 1, :, :] for r in range(k + 1, length)], dim=0)
  77. score = [
  78. s.mean() for s in universal_image_quality_index(stack1, stack2, reduction="none").split(preds.shape[0])
  79. ]
  80. m2[k, k + 1 :] = torch.stack(score, 0)
  81. m1 = m1 + m1.T
  82. m2 = m2 + m2.T
  83. diff = torch.pow(torch.abs(m1 - m2), p)
  84. # Special case: when number of channels (L) is 1, there will be only one element in M1 and M2. Hence no need to sum.
  85. if length == 1:
  86. output = torch.pow(diff, (1.0 / p))
  87. else:
  88. output = torch.pow(1.0 / (length * (length - 1)) * torch.sum(diff), (1.0 / p))
  89. return reduce(output, reduction)
  90. def spectral_distortion_index(
  91. preds: Tensor,
  92. target: Tensor,
  93. p: int = 1,
  94. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  95. ) -> Tensor:
  96. """Calculate `Spectral Distortion Index`_ (SpectralDistortionIndex_) also known as D_lambda.
  97. Metric is used to compare the spectral distortion between two images.
  98. Args:
  99. preds: Low resolution multispectral image
  100. target: High resolution fused image
  101. p: Large spectral differences
  102. reduction: a method to reduce metric score over labels.
  103. - ``'elementwise_mean'``: takes the mean (default)
  104. - ``'sum'``: takes the sum
  105. - ``'none'``: no reduction will be applied
  106. Return:
  107. Tensor with SpectralDistortionIndex score
  108. Raises:
  109. TypeError:
  110. If ``preds`` and ``target`` don't have the same data type.
  111. ValueError:
  112. If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
  113. ValueError:
  114. If ``p`` is not a positive integer.
  115. Example:
  116. >>> from torch import rand
  117. >>> from torchmetrics.functional.image import spectral_distortion_index
  118. >>> preds = rand([16, 3, 16, 16])
  119. >>> target = rand([16, 3, 16, 16])
  120. >>> spectral_distortion_index(preds, target)
  121. tensor(0.0234)
  122. """
  123. if not isinstance(p, int) or p <= 0:
  124. raise ValueError(f"Expected `p` to be a positive integer. Got p: {p}.")
  125. preds, target = _spectral_distortion_index_update(preds, target)
  126. return _spectral_distortion_index_compute(preds, target, p, reduction)