psnr.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional, Union
  15. import torch
  16. from torch import Tensor, tensor
  17. from typing_extensions import Literal
  18. from torchmetrics.utilities import rank_zero_warn, reduce
  19. def _psnr_compute(
  20. sum_squared_error: Tensor,
  21. num_obs: Tensor,
  22. data_range: Tensor,
  23. base: float = 10.0,
  24. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  25. ) -> Tensor:
  26. """Compute peak signal-to-noise ratio.
  27. Args:
  28. sum_squared_error: Sum of square of errors over all observations
  29. num_obs: Number of predictions or observations
  30. data_range: the range of the data. If None, it is determined from the data (max - min).
  31. ``data_range`` must be given when ``dim`` is not None.
  32. base: a base of a logarithm to use
  33. reduction: a method to reduce metric score over labels.
  34. - ``'elementwise_mean'``: takes the mean (default)
  35. - ``'sum'``: takes the sum
  36. - ``'none'`` or ``None``: no reduction will be applied
  37. Example:
  38. >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
  39. >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
  40. >>> data_range = target.max() - target.min()
  41. >>> sum_squared_error, num_obs = _psnr_update(preds, target)
  42. >>> _psnr_compute(sum_squared_error, num_obs, data_range)
  43. tensor(2.5527)
  44. """
  45. psnr_base_e = 2 * torch.log(data_range) - torch.log(sum_squared_error / num_obs)
  46. psnr_vals = psnr_base_e * (10 / torch.log(tensor(base)))
  47. return reduce(psnr_vals, reduction=reduction)
  48. def _psnr_update(
  49. preds: Tensor,
  50. target: Tensor,
  51. dim: Optional[Union[int, tuple[int, ...]]] = None,
  52. ) -> tuple[Tensor, Tensor]:
  53. """Update and return variables required to compute peak signal-to-noise ratio.
  54. Args:
  55. preds: Predicted tensor
  56. target: Ground truth tensor
  57. dim: Dimensions to reduce PSNR scores over provided as either an integer or a list of integers.
  58. Default is None meaning scores will be reduced across all dimensions.
  59. """
  60. if not preds.is_floating_point():
  61. preds = preds.to(torch.float32)
  62. if not target.is_floating_point():
  63. target = target.to(torch.float32)
  64. if dim is None:
  65. sum_squared_error = torch.sum(torch.pow(preds - target, 2))
  66. num_obs = tensor(target.numel(), device=target.device)
  67. return sum_squared_error, num_obs
  68. diff = preds - target
  69. sum_squared_error = torch.sum(diff * diff, dim=dim)
  70. dim_list = [dim] if isinstance(dim, int) else list(dim)
  71. if not dim_list:
  72. num_obs = tensor(target.numel(), device=target.device)
  73. else:
  74. num_obs = tensor(target.size(), device=target.device)[dim_list].prod()
  75. num_obs = num_obs.expand_as(sum_squared_error)
  76. return sum_squared_error, num_obs
  77. def peak_signal_noise_ratio(
  78. preds: Tensor,
  79. target: Tensor,
  80. data_range: Union[float, tuple[float, float]],
  81. base: float = 10.0,
  82. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  83. dim: Optional[Union[int, tuple[int, ...]]] = None,
  84. ) -> Tensor:
  85. """Compute the peak signal-to-noise ratio.
  86. Args:
  87. preds: estimated signal
  88. target: groun truth signal
  89. data_range:
  90. the range of the data. If a tuple is provided then the range is calculated as the difference and
  91. input is clamped between the values.
  92. base: a base of a logarithm to use
  93. reduction: a method to reduce metric score over labels.
  94. - ``'elementwise_mean'``: takes the mean (default)
  95. - ``'sum'``: takes the sum
  96. - ``'none'`` or None``: no reduction will be applied
  97. dim:
  98. Dimensions to reduce PSNR scores over provided as either an integer or a list of integers. Default is
  99. None meaning scores will be reduced across all dimensions.
  100. Return:
  101. Tensor with PSNR score
  102. Example:
  103. >>> from torchmetrics.functional.image import peak_signal_noise_ratio
  104. >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
  105. >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
  106. >>> peak_signal_noise_ratio(pred, target, data_range=3.0)
  107. tensor(2.5527)
  108. .. attention::
  109. Half precision is only support on GPU for this metric.
  110. """
  111. if dim is None and reduction != "elementwise_mean":
  112. rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")
  113. if isinstance(data_range, tuple):
  114. preds = torch.clamp(preds, min=data_range[0], max=data_range[1])
  115. target = torch.clamp(target, min=data_range[0], max=data_range[1])
  116. data_range_val = tensor(data_range[1] - data_range[0])
  117. else:
  118. data_range_val = tensor(float(data_range))
  119. sum_squared_error, num_obs = _psnr_update(preds, target, dim=dim)
  120. return _psnr_compute(sum_squared_error, num_obs, data_range_val, base=base, reduction=reduction)