qnr.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from torch import Tensor
  16. from typing_extensions import Literal
  17. from torchmetrics.functional.image.d_lambda import spectral_distortion_index
  18. from torchmetrics.functional.image.d_s import spatial_distortion_index
  19. from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
  20. if not _TORCHVISION_AVAILABLE:
  21. __doctest_skip__ = ["quality_with_no_reference"]
  22. def quality_with_no_reference(
  23. preds: Tensor,
  24. ms: Tensor,
  25. pan: Tensor,
  26. pan_lr: Optional[Tensor] = None,
  27. alpha: float = 1,
  28. beta: float = 1,
  29. norm_order: int = 1,
  30. window_size: int = 7,
  31. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  32. ) -> Tensor:
  33. """Calculate `Quality with No Reference`_ (QualityWithNoReference_) also known as QNR.
  34. Metric is used to compare the joint spectral and spatial distortion between two images.
  35. Args:
  36. preds: High resolution multispectral image.
  37. ms: Low resolution multispectral image.
  38. pan: High resolution panchromatic image.
  39. pan_lr: Low resolution panchromatic image.
  40. alpha: Relevance of spectral distortion.
  41. beta: Relevance of spatial distortion.
  42. norm_order: Order of the norm applied on the difference.
  43. window_size: Window size of the filter applied to degrade the high resolution panchromatic image.
  44. reduction: A method to reduce metric score over labels.
  45. - ``'elementwise_mean'``: takes the mean (default)
  46. - ``'sum'``: takes the sum
  47. - ``'none'``: no reduction will be applied
  48. Return:
  49. Tensor with QualityWithNoReference score
  50. Raises:
  51. ValueError:
  52. If ``alpha`` or ``beta`` is not a non-negative real number.
  53. Example:
  54. >>> from torch import rand
  55. >>> from torchmetrics.functional.image import quality_with_no_reference
  56. >>> preds = rand([16, 3, 32, 32])
  57. >>> ms = rand([16, 3, 16, 16])
  58. >>> pan = rand([16, 3, 32, 32])
  59. >>> quality_with_no_reference(preds, ms, pan)
  60. tensor(0.9694)
  61. """
  62. if not isinstance(alpha, (int, float)) or alpha < 0:
  63. raise ValueError(f"Expected `alpha` to be a non-negative real number. Got alpha: {alpha}.")
  64. if not isinstance(beta, (int, float)) or beta < 0:
  65. raise ValueError(f"Expected `beta` to be a non-negative real number. Got beta: {beta}.")
  66. d_lambda = spectral_distortion_index(preds, ms, norm_order, reduction)
  67. d_s = spatial_distortion_index(preds, ms, pan, pan_lr, norm_order, window_size, reduction)
  68. return (1 - d_lambda) ** alpha * (1 - d_s) ** beta