d_s.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. import torch
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.image.uqi import universal_image_quality_index
  19. from torchmetrics.utilities.distributed import reduce
  20. from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
# Both listed doctests call the metric without `pan_lr`; that code path raises
# unless torchvision is installed, so the examples are registered in
# `__doctest_skip__` when torchvision is missing.
# NOTE(review): presumably consumed by the doctest runner (pytest-doctestplus
# convention) -- confirm against the project's test configuration.
if not _TORCHVISION_AVAILABLE:
    __doctest_skip__ = ["_spatial_distortion_index_compute", "spatial_distortion_index"]
  23. def _spatial_distortion_index_update(
  24. preds: Tensor, ms: Tensor, pan: Tensor, pan_lr: Optional[Tensor] = None
  25. ) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
  26. """Update and returns variables required to compute Spatial Distortion Index.
  27. Args:
  28. preds: High resolution multispectral image.
  29. ms: Low resolution multispectral image.
  30. pan: High resolution panchromatic image.
  31. pan_lr: Low resolution panchromatic image.
  32. Return:
  33. A tuple of Tensors containing ``preds``, ``ms``, ``pan`` and ``pan_lr``.
  34. Raises:
  35. TypeError:
  36. If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type.
  37. ValueError:
  38. If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW shape``.
  39. ValueError:
  40. If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes.
  41. ValueError:
  42. If ``preds`` and ``pan`` don't have the same dimension.
  43. ValueError:
  44. If ``ms`` and ``pan_lr`` don't have the same dimension.
  45. ValueError:
  46. If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``.
  47. """
  48. if len(preds.shape) != 4:
  49. raise ValueError(f"Expected `preds` to have BxCxHxW shape. Got preds: {preds.shape}.")
  50. if preds.dtype != ms.dtype:
  51. raise TypeError(
  52. f"Expected `preds` and `ms` to have the same data type. Got preds: {preds.dtype} and ms: {ms.dtype}."
  53. )
  54. if preds.dtype != pan.dtype:
  55. raise TypeError(
  56. f"Expected `preds` and `pan` to have the same data type. Got preds: {preds.dtype} and pan: {pan.dtype}."
  57. )
  58. if pan_lr is not None and preds.dtype != pan_lr.dtype:
  59. raise TypeError(
  60. f"Expected `preds` and `pan_lr` to have the same data type."
  61. f" Got preds: {preds.dtype} and pan_lr: {pan_lr.dtype}."
  62. )
  63. if len(ms.shape) != 4:
  64. raise ValueError(f"Expected `ms` to have BxCxHxW shape. Got ms: {ms.shape}.")
  65. if len(pan.shape) != 4:
  66. raise ValueError(f"Expected `pan` to have BxCxHxW shape. Got pan: {pan.shape}.")
  67. if pan_lr is not None and len(pan_lr.shape) != 4:
  68. raise ValueError(f"Expected `pan_lr` to have BxCxHxW shape. Got pan_lr: {pan_lr.shape}.")
  69. if preds.shape[:2] != ms.shape[:2]:
  70. raise ValueError(
  71. f"Expected `preds` and `ms` to have the same batch and channel sizes."
  72. f" Got preds: {preds.shape} and ms: {ms.shape}."
  73. )
  74. if preds.shape[:2] != pan.shape[:2]:
  75. raise ValueError(
  76. f"Expected `preds` and `pan` to have the same batch and channel sizes."
  77. f" Got preds: {preds.shape} and pan: {pan.shape}."
  78. )
  79. if pan_lr is not None and preds.shape[:2] != pan_lr.shape[:2]:
  80. raise ValueError(
  81. f"Expected `preds` and `pan_lr` to have the same batch and channel sizes."
  82. f" Got preds: {preds.shape} and pan_lr: {pan_lr.shape}."
  83. )
  84. preds_h, preds_w = preds.shape[-2:]
  85. ms_h, ms_w = ms.shape[-2:]
  86. pan_h, pan_w = pan.shape[-2:]
  87. if preds_h != pan_h:
  88. raise ValueError(f"Expected `preds` and `pan` to have the same height. Got preds: {preds_h} and pan: {pan_h}")
  89. if preds_w != pan_w:
  90. raise ValueError(f"Expected `preds` and `pan` to have the same width. Got preds: {preds_w} and pan: {pan_w}")
  91. if preds_h % ms_h != 0:
  92. raise ValueError(
  93. f"Expected height of `preds` to be multiple of height of `ms`. Got preds: {preds_h} and ms: {ms_h}."
  94. )
  95. if preds_w % ms_w != 0:
  96. raise ValueError(
  97. f"Expected width of `preds` to be multiple of width of `ms`. Got preds: {preds_w} and ms: {ms_w}."
  98. )
  99. if pan_h % ms_h != 0:
  100. raise ValueError(
  101. f"Expected height of `pan` to be multiple of height of `ms`. Got preds: {pan_h} and ms: {ms_h}."
  102. )
  103. if pan_w % ms_w != 0:
  104. raise ValueError(f"Expected width of `pan` to be multiple of width of `ms`. Got preds: {pan_w} and ms: {ms_w}.")
  105. if pan_lr is not None:
  106. pan_lr_h, pan_lr_w = pan_lr.shape[-2:]
  107. if pan_lr_h != ms_h:
  108. raise ValueError(
  109. f"Expected `ms` and `pan_lr` to have the same height. Got ms: {ms_h} and pan_lr: {pan_lr_h}."
  110. )
  111. if pan_lr_w != ms_w:
  112. raise ValueError(
  113. f"Expected `ms` and `pan_lr` to have the same width. Got ms: {ms_w} and pan_lr: {pan_lr_w}."
  114. )
  115. return preds, ms, pan, pan_lr
  116. def _spatial_distortion_index_compute(
  117. preds: Tensor,
  118. ms: Tensor,
  119. pan: Tensor,
  120. pan_lr: Optional[Tensor] = None,
  121. norm_order: int = 1,
  122. window_size: int = 7,
  123. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  124. ) -> Tensor:
  125. """Compute Spatial Distortion Index (SpatialDistortionIndex_).
  126. Args:
  127. preds: High resolution multispectral image.
  128. ms: Low resolution multispectral image.
  129. pan: High resolution panchromatic image.
  130. pan_lr: Low resolution panchromatic image.
  131. norm_order: Order of the norm applied on the difference.
  132. window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
  133. reduction: A method to reduce metric score over labels.
  134. - ``'elementwise_mean'``: takes the mean (default)
  135. - ``'sum'``: takes the sum
  136. - ``'none'``: no reduction will be applied
  137. Return:
  138. Tensor with SpatialDistortionIndex score
  139. Raises:
  140. ValueError
  141. If ``window_size`` is smaller than dimension of ``ms``.
  142. Example:
  143. >>> from torch import rand
  144. >>> preds = rand([16, 3, 32, 32])
  145. >>> ms = rand([16, 3, 16, 16])
  146. >>> pan = rand([16, 3, 32, 32])
  147. >>> preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan)
  148. >>> _spatial_distortion_index_compute(preds, ms, pan, pan_lr)
  149. tensor(0.0090)
  150. """
  151. length = preds.shape[1]
  152. ms_h, ms_w = ms.shape[-2:]
  153. if window_size >= ms_h or window_size >= ms_w:
  154. raise ValueError(
  155. f"Expected `window_size` to be smaller than dimension of `ms`. Got window_size: {window_size}."
  156. )
  157. if pan_lr is None:
  158. if not _TORCHVISION_AVAILABLE:
  159. raise ValueError(
  160. "When `pan_lr` is not provided as input to metric Spatial distortion index, torchvision should be "
  161. "installed. Please install with `pip install torchvision` or `pip install torchmetrics[image]`."
  162. )
  163. from torchvision.transforms.functional import resize
  164. from torchmetrics.functional.image.utils import _uniform_filter
  165. pan_degraded = _uniform_filter(pan, window_size=window_size)
  166. pan_degraded = resize(pan_degraded, size=ms.shape[-2:], antialias=False)
  167. else:
  168. pan_degraded = pan_lr
  169. m1 = torch.zeros(length, device=preds.device)
  170. m2 = torch.zeros(length, device=preds.device)
  171. for i in range(length):
  172. m1[i] = universal_image_quality_index(ms[:, i : i + 1], pan_degraded[:, i : i + 1])
  173. m2[i] = universal_image_quality_index(preds[:, i : i + 1], pan[:, i : i + 1])
  174. diff = (m1 - m2).abs() ** norm_order
  175. return reduce(diff, reduction) ** (1 / norm_order)
  176. def spatial_distortion_index(
  177. preds: Tensor,
  178. ms: Tensor,
  179. pan: Tensor,
  180. pan_lr: Optional[Tensor] = None,
  181. norm_order: int = 1,
  182. window_size: int = 7,
  183. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  184. ) -> Tensor:
  185. """Calculate `Spatial Distortion Index`_ (SpatialDistortionIndex_) also known as D_s.
  186. Metric is used to compare the spatial distortion between two images.
  187. Args:
  188. preds: High resolution multispectral image.
  189. ms: Low resolution multispectral image.
  190. pan: High resolution panchromatic image.
  191. pan_lr: Low resolution panchromatic image.
  192. norm_order: Order of the norm applied on the difference.
  193. window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
  194. reduction: A method to reduce metric score over labels.
  195. - ``'elementwise_mean'``: takes the mean (default)
  196. - ``'sum'``: takes the sum
  197. - ``'none'``: no reduction will be applied
  198. Return:
  199. Tensor with SpatialDistortionIndex score
  200. Raises:
  201. TypeError:
  202. If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same data type.
  203. ValueError:
  204. If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have ``BxCxHxW shape``.
  205. ValueError:
  206. If ``preds``, ``ms``, ``pan`` and ``pan_lr`` don't have the same batch and channel sizes.
  207. ValueError:
  208. If ``preds`` and ``pan`` don't have the same dimension.
  209. ValueError:
  210. If ``ms`` and ``pan_lr`` don't have the same dimension.
  211. ValueError:
  212. If ``preds`` and ``pan`` don't have dimension which is multiple of that of ``ms``.
  213. ValueError:
  214. If ``norm_order`` is not a positive integer.
  215. ValueError:
  216. If ``window_size`` is not a positive integer.
  217. Example:
  218. >>> from torch import rand
  219. >>> from torchmetrics.functional.image import spatial_distortion_index
  220. >>> preds = rand([16, 3, 32, 32])
  221. >>> ms = rand([16, 3, 16, 16])
  222. >>> pan = rand([16, 3, 32, 32])
  223. >>> spatial_distortion_index(preds, ms, pan)
  224. tensor(0.0090)
  225. """
  226. if not isinstance(norm_order, int) or norm_order <= 0:
  227. raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.")
  228. if not isinstance(window_size, int) or window_size <= 0:
  229. raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.")
  230. preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr)
  231. return _spatial_distortion_index_compute(preds, ms, pan, pan_lr, norm_order, window_size, reduction)