uqi.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Optional
  16. import torch
  17. from torch import Tensor, nn
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.image.utils import _gaussian_kernel_2d
  20. from torchmetrics.utilities.checks import _check_same_shape
  21. from torchmetrics.utilities.distributed import reduce
  22. def _uqi_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
  23. """Update and returns variables required to compute Universal Image Quality Index.
  24. Args:
  25. preds: Predicted tensor
  26. target: Ground truth tensor
  27. """
  28. if preds.dtype != target.dtype:
  29. raise TypeError(
  30. "Expected `preds` and `target` to have the same data type."
  31. f" Got preds: {preds.dtype} and target: {target.dtype}."
  32. )
  33. _check_same_shape(preds, target)
  34. if len(preds.shape) != 4:
  35. raise ValueError(
  36. f"Expected `preds` and `target` to have BxCxHxW shape. Got preds: {preds.shape} and target: {target.shape}."
  37. )
  38. return preds, target
  39. def _uqi_compute(
  40. preds: Tensor,
  41. target: Tensor,
  42. kernel_size: Sequence[int] = (11, 11),
  43. sigma: Sequence[float] = (1.5, 1.5),
  44. reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
  45. ) -> Tensor:
  46. """Compute Universal Image Quality Index.
  47. Args:
  48. preds: estimated image
  49. target: ground truth image
  50. kernel_size: size of the gaussian kernel
  51. sigma: Standard deviation of the gaussian kernel
  52. reduction: a method to reduce metric score over labels.
  53. - ``'elementwise_mean'``: takes the mean (default)
  54. - ``'sum'``: takes the sum
  55. - ``'none'`` or ``None``: no reduction will be applied
  56. Example:
  57. >>> preds = torch.rand([16, 1, 16, 16])
  58. >>> target = preds * 0.75
  59. >>> preds, target = _uqi_update(preds, target)
  60. >>> _uqi_compute(preds, target)
  61. tensor(0.9216)
  62. """
  63. if len(kernel_size) != 2 or len(sigma) != 2:
  64. raise ValueError(
  65. "Expected `kernel_size` and `sigma` to have the length of two."
  66. f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
  67. )
  68. if any(x % 2 == 0 or x <= 0 for x in kernel_size):
  69. raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")
  70. if any(y <= 0 for y in sigma):
  71. raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")
  72. device = preds.device
  73. channel = preds.size(1)
  74. dtype = preds.dtype
  75. kernel = _gaussian_kernel_2d(channel, kernel_size, sigma, dtype, device)
  76. pad_h = (kernel_size[0] - 1) // 2
  77. pad_w = (kernel_size[1] - 1) // 2
  78. preds = nn.functional.pad(preds, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
  79. target = nn.functional.pad(target, (pad_h, pad_h, pad_w, pad_w), mode="reflect")
  80. input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
  81. outputs = nn.functional.conv2d(input_list, kernel, groups=channel)
  82. output_list = outputs.split(preds.shape[0])
  83. mu_pred_sq = output_list[0].pow(2)
  84. mu_target_sq = output_list[1].pow(2)
  85. mu_pred_target = output_list[0] * output_list[1]
  86. # Calculate the variance of the predicted and target images, should be non-negative
  87. sigma_pred_sq = torch.clamp(output_list[2] - mu_pred_sq, min=0.0)
  88. sigma_target_sq = torch.clamp(output_list[3] - mu_target_sq, min=0.0)
  89. sigma_pred_target = output_list[4] - mu_pred_target
  90. upper = 2 * sigma_pred_target
  91. lower = sigma_pred_sq + sigma_target_sq
  92. eps = torch.finfo(sigma_pred_sq.dtype).eps
  93. uqi_idx = ((2 * mu_pred_target) * upper) / ((mu_pred_sq + mu_target_sq) * lower + eps)
  94. uqi_idx = uqi_idx[..., pad_h:-pad_h, pad_w:-pad_w]
  95. return reduce(uqi_idx, reduction)
  96. def universal_image_quality_index(
  97. preds: Tensor,
  98. target: Tensor,
  99. kernel_size: Sequence[int] = (11, 11),
  100. sigma: Sequence[float] = (1.5, 1.5),
  101. reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
  102. ) -> Tensor:
  103. """Universal Image Quality Index.
  104. Args:
  105. preds: estimated image
  106. target: ground truth image
  107. kernel_size: size of the gaussian kernel
  108. sigma: Standard deviation of the gaussian kernel
  109. reduction: a method to reduce metric score over labels.
  110. - ``'elementwise_mean'``: takes the mean (default)
  111. - ``'sum'``: takes the sum
  112. - ``'none'`` or ``None``: no reduction will be applied
  113. Return:
  114. Tensor with UniversalImageQualityIndex score
  115. Raises:
  116. TypeError:
  117. If ``preds`` and ``target`` don't have the same data type.
  118. ValueError:
  119. If ``preds`` and ``target`` don't have ``BxCxHxW shape``.
  120. ValueError:
  121. If the length of ``kernel_size`` or ``sigma`` is not ``2``.
  122. ValueError:
  123. If one of the elements of ``kernel_size`` is not an ``odd positive number``.
  124. ValueError:
  125. If one of the elements of ``sigma`` is not a ``positive number``.
  126. Example:
  127. >>> from torchmetrics.functional.image import universal_image_quality_index
  128. >>> preds = torch.rand([16, 1, 16, 16])
  129. >>> target = preds * 0.75
  130. >>> universal_image_quality_index(preds, target)
  131. tensor(0.9216)
  132. References:
  133. [1] Zhou Wang and A. C. Bovik, "A universal image quality index," in IEEE Signal Processing Letters, vol. 9,
  134. no. 3, pp. 81-84, March 2002, doi: 10.1109/97.995823.
  135. [2] Zhou Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, "Image quality assessment: from error visibility
  136. to structural similarity," in IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612, April 2004,
  137. doi: 10.1109/TIP.2003.819861.
  138. """
  139. preds, target = _uqi_update(preds, target)
  140. return _uqi_compute(preds, target, kernel_size, sigma, reduction)