qnr.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, List, Optional, Union
  16. from torch import Tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.functional.image.d_lambda import _spectral_distortion_index_compute, _spectral_distortion_index_update
  19. from torchmetrics.functional.image.d_s import _spatial_distortion_index_compute, _spatial_distortion_index_update
  20. from torchmetrics.metric import Metric
  21. from torchmetrics.utilities import rank_zero_warn
  22. from torchmetrics.utilities.data import dim_zero_cat
  23. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE, _TORCHVISION_AVAILABLE
  24. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  25. if not _MATPLOTLIB_AVAILABLE:
  26. __doctest_skip__ = ["QualityWithNoReference.plot"]
  27. if not _TORCHVISION_AVAILABLE:
  28. __doctest_skip__ = ["QualityWithNoReference", "QualityWithNoReference.plot"]
  29. class QualityWithNoReference(Metric):
  30. """Compute Quality with No Reference (QualityWithNoReference_) also now as QNR.
  31. The metric is used to compare the joint spectral and spatial distortion between two images.
  32. As input to ``forward`` and ``update`` the metric accepts the following input
  33. - ``preds`` (:class:`~torch.Tensor`): High resolution multispectral image of shape ``(N,C,H,W)``.
  34. - ``target`` (:class:`~Dict`): A dictionary containing the following keys:
  35. - ``ms`` (:class:`~torch.Tensor`): Low resolution multispectral image of shape ``(N,C,H',W')``.
  36. - ``pan`` (:class:`~torch.Tensor`): High resolution panchromatic image of shape ``(N,C,H,W)``.
  37. - ``pan_lr`` (:class:`~torch.Tensor`): (optional) Low resolution panchromatic image of shape ``(N,C,H',W')``.
  38. where H and W must be multiple of H' and W'.
  39. When ``pan_lr`` is ``None``, a uniform filter will be applied on ``pan`` to produce a degraded image. The degraded
  40. image is then resized to match the size of ``ms`` and served as ``pan_lr`` in the calculation.
  41. As output of `forward` and `compute` the metric returns the following output
  42. - ``qnr`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average QNR value
  43. over sample else returns tensor of shape ``(N,)`` with QNR values per sample
  44. Args:
  45. alpha: Relevance of spectral distortion.
  46. beta: Relevance of spatial distortion.
  47. norm_order: Order of the norm applied on the difference.
  48. window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
  49. reduction: a method to reduce metric score over labels.
  50. - ``'elementwise_mean'``: takes the mean (default)
  51. - ``'sum'``: takes the sum
  52. - ``'none'``: no reduction will be applied
  53. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  54. Example:
  55. >>> from torch import rand
  56. >>> from torchmetrics.image import QualityWithNoReference
  57. >>> preds = rand([16, 3, 32, 32])
  58. >>> target = {
  59. ... 'ms': rand([16, 3, 16, 16]),
  60. ... 'pan': rand([16, 3, 32, 32]),
  61. ... }
  62. >>> qnr = QualityWithNoReference()
  63. >>> qnr(preds, target)
  64. tensor(0.9694)
  65. """
  66. higher_is_better: bool = True
  67. is_differentiable: bool = True
  68. full_state_update: bool = False
  69. plot_lower_bound: float = 0.0
  70. plot_upper_bound: float = 1.0
  71. preds: List[Tensor]
  72. ms: List[Tensor]
  73. pan: List[Tensor]
  74. pan_lr: List[Tensor]
  75. def __init__(
  76. self,
  77. alpha: float = 1,
  78. beta: float = 1,
  79. norm_order: int = 1,
  80. window_size: int = 7,
  81. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  82. **kwargs: Any,
  83. ) -> None:
  84. super().__init__(**kwargs)
  85. rank_zero_warn(
  86. "Metric `QualityWithNoReference` will save all targets and predictions in buffer."
  87. " For large datasets this may lead to large memory footprint."
  88. )
  89. if not isinstance(alpha, (int, float)) or alpha < 0:
  90. raise ValueError(f"Expected `alpha` to be a non-negative real number. Got alpha: {alpha}.")
  91. self.alpha = alpha
  92. if not isinstance(beta, (int, float)) or beta < 0:
  93. raise ValueError(f"Expected `beta` to be a non-negative real number. Got beta: {beta}.")
  94. self.beta = beta
  95. if not isinstance(norm_order, int) or norm_order <= 0:
  96. raise ValueError(f"Expected `norm_order` to be a positive integer. Got norm_order: {norm_order}.")
  97. self.norm_order = norm_order
  98. if not isinstance(window_size, int) or window_size <= 0:
  99. raise ValueError(f"Expected `window_size` to be a positive integer. Got window_size: {window_size}.")
  100. self.window_size = window_size
  101. allowed_reductions = ("elementwise_mean", "sum", "none")
  102. if reduction not in allowed_reductions:
  103. raise ValueError(f"Expected argument `reduction` be one of {allowed_reductions} but got {reduction}")
  104. self.reduction = reduction
  105. self.add_state("preds", default=[], dist_reduce_fx="cat")
  106. self.add_state("ms", default=[], dist_reduce_fx="cat")
  107. self.add_state("pan", default=[], dist_reduce_fx="cat")
  108. self.add_state("pan_lr", default=[], dist_reduce_fx="cat")
  109. def update(self, preds: Tensor, target: dict[str, Tensor]) -> None:
  110. """Update state with preds and target.
  111. Args:
  112. preds: High resolution multispectral image.
  113. target: A dictionary containing the following keys:
  114. - ``'ms'``: low resolution multispectral image.
  115. - ``'pan'``: high resolution panchromatic image.
  116. - ``'pan_lr'``: (optional) low resolution panchromatic image.
  117. Raises:
  118. ValueError:
  119. If ``target`` doesn't have ``ms`` and ``pan``.
  120. """
  121. if "ms" not in target:
  122. raise ValueError(f"Expected `target` to have key `ms`. Got target: {target.keys()}.")
  123. if "pan" not in target:
  124. raise ValueError(f"Expected `target` to have key `pan`. Got target: {target.keys()}.")
  125. ms = target["ms"]
  126. pan = target["pan"]
  127. pan_lr = target.get("pan_lr")
  128. preds, ms = _spectral_distortion_index_update(preds, ms)
  129. preds, ms, pan, pan_lr = _spatial_distortion_index_update(preds, ms, pan, pan_lr)
  130. self.preds.append(preds)
  131. self.ms.append(target["ms"])
  132. self.pan.append(target["pan"])
  133. if "pan_lr" in target:
  134. self.pan_lr.append(target["pan_lr"])
  135. def compute(self) -> Tensor:
  136. """Compute and returns quality with no reference."""
  137. preds = dim_zero_cat(self.preds)
  138. ms = dim_zero_cat(self.ms)
  139. pan = dim_zero_cat(self.pan)
  140. pan_lr = dim_zero_cat(self.pan_lr) if len(self.pan_lr) > 0 else None
  141. d_lambda = _spectral_distortion_index_compute(preds, ms, self.norm_order, self.reduction)
  142. d_s = _spatial_distortion_index_compute(
  143. preds, ms, pan, pan_lr, self.norm_order, self.window_size, self.reduction
  144. )
  145. return (1 - d_lambda) ** self.alpha * (1 - d_s) ** self.beta
  146. def plot(
  147. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  148. ) -> _PLOT_OUT_TYPE:
  149. """Plot a single or multiple values from the metric.
  150. Args:
  151. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  152. If no value is provided, will automatically call `metric.compute` and plot that result.
  153. ax: An matplotlib axis object. If provided will add plot to that axis
  154. Returns:
  155. Figure and Axes object
  156. Raises:
  157. ModuleNotFoundError:
  158. If `matplotlib` is not installed
  159. .. plot::
  160. :scale: 75
  161. >>> # Example plotting a single value
  162. >>> from torch import rand
  163. >>> from torchmetrics.image import QualityWithNoReference
  164. >>> preds = rand([16, 3, 32, 32])
  165. >>> target = {
  166. ... 'ms': rand([16, 3, 16, 16]),
  167. ... 'pan': rand([16, 3, 32, 32]),
  168. ... }
  169. >>> metric = QualityWithNoReference()
  170. >>> metric.update(preds, target)
  171. >>> fig_, ax_ = metric.plot()
  172. .. plot::
  173. :scale: 75
  174. >>> # Example plotting multiple values
  175. >>> from torch import rand
  176. >>> from torchmetrics.image import QualityWithNoReference
  177. >>> preds = rand([16, 3, 32, 32])
  178. >>> target = {
  179. ... 'ms': rand([16, 3, 16, 16]),
  180. ... 'pan': rand([16, 3, 32, 32]),
  181. ... }
  182. >>> metric = QualityWithNoReference()
  183. >>> values = [ ]
  184. >>> for _ in range(10):
  185. ... values.append(metric(preds, target))
  186. >>> fig_, ax_ = metric.plot(values)
  187. """
  188. return self._plot(val, ax)