psnr.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from functools import partial
  16. from typing import Any, Optional, Union
  17. import torch
  18. from torch import Tensor, tensor
  19. from typing_extensions import Literal
  20. from torchmetrics.functional.image.psnr import _psnr_compute, _psnr_update
  21. from torchmetrics.metric import Metric
  22. from torchmetrics.utilities import rank_zero_warn
  23. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  24. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
# The ``PeakSignalNoiseRatio.plot`` docstring example needs matplotlib; when it is not installed,
# list that doctest in ``__doctest_skip__`` (presumably read by the doctest collector, e.g.
# pytest-doctestplus -- confirm against the project's test configuration).
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["PeakSignalNoiseRatio.plot"]
  27. class PeakSignalNoiseRatio(Metric):
  28. r"""`Compute Peak Signal-to-Noise Ratio`_ (PSNR).
  29. .. math:: \text{PSNR}(I, J) = 10 * \log_{10} \left(\frac{\max(I)^2}{\text{MSE}(I, J)}\right)
  30. Where :math:`\text{MSE}` denotes the `mean-squared-error`_ function.
  31. As input to ``forward`` and ``update`` the metric accepts the following input
  32. - ``preds`` (:class:`~torch.Tensor`): Predictions from model of shape ``(N,C,H,W)``
  33. - ``target`` (:class:`~torch.Tensor`): Ground truth values of shape ``(N,C,H,W)``
  34. As output of `forward` and `compute` the metric returns the following output
  35. - ``psnr`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average PSNR value
  36. over sample else returns tensor of shape ``(N,)`` with PSNR values per sample
  37. Args:
  38. data_range:
  39. the range of the data. If a tuple is provided, then the range is calculated as the difference and
  40. input is clamped between the values.
  41. base: a base of a logarithm to use.
  42. reduction: a method to reduce metric score over labels.
  43. - ``'elementwise_mean'``: takes the mean (default)
  44. - ``'sum'``: takes the sum
  45. - ``'none'`` or ``None``: no reduction will be applied
  46. dim:
  47. Dimensions to reduce PSNR scores over, provided as either an integer or a list of integers. Default is
  48. None meaning scores will be reduced across all dimensions and all batches.
  49. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  50. Example:
  51. >>> from torchmetrics.image import PeakSignalNoiseRatio
  52. >>> psnr = PeakSignalNoiseRatio(data_range=3.0)
  53. >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
  54. >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
  55. >>> psnr(preds, target)
  56. tensor(2.5527)
  57. """
  58. is_differentiable: bool = True
  59. higher_is_better: bool = True
  60. full_state_update: bool = False
  61. plot_lower_bound: float = 0.0
  62. data_range: Tensor
  63. def __init__(
  64. self,
  65. data_range: Union[float, tuple[float, float]],
  66. base: float = 10.0,
  67. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  68. dim: Optional[Union[int, tuple[int, ...]]] = None,
  69. **kwargs: Any,
  70. ) -> None:
  71. super().__init__(**kwargs)
  72. if dim is None and reduction != "elementwise_mean":
  73. rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")
  74. if dim is None:
  75. self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
  76. self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
  77. else:
  78. self.add_state("sum_squared_error", default=[], dist_reduce_fx="cat")
  79. self.add_state("total", default=[], dist_reduce_fx="cat")
  80. self.clamping_fn = None
  81. if isinstance(data_range, tuple):
  82. self.add_state("data_range", default=tensor(data_range[1] - data_range[0]), dist_reduce_fx="mean")
  83. self.clamping_fn = partial(torch.clamp, min=data_range[0], max=data_range[1])
  84. else:
  85. self.add_state("data_range", default=tensor(float(data_range)), dist_reduce_fx="mean")
  86. self.base = base
  87. self.reduction = reduction
  88. self.dim = tuple(dim) if isinstance(dim, Sequence) else dim
  89. def update(self, preds: Tensor, target: Tensor) -> None:
  90. """Update state with predictions and targets."""
  91. if self.clamping_fn is not None:
  92. preds = self.clamping_fn(preds)
  93. target = self.clamping_fn(target)
  94. sum_squared_error, num_obs = _psnr_update(preds, target, dim=self.dim)
  95. if self.dim is None:
  96. if not isinstance(self.sum_squared_error, Tensor):
  97. raise TypeError(
  98. f"Expected `self.sum_squared_error` to be a Tensor, but got {type(self.sum_squared_error)}"
  99. )
  100. if not isinstance(self.total, Tensor):
  101. raise TypeError(f"Expected `self.total` to be a Tensor, but got {type(self.total)}")
  102. self.sum_squared_error += sum_squared_error
  103. self.total += num_obs
  104. else:
  105. if not isinstance(self.sum_squared_error, list):
  106. raise TypeError(
  107. f"Expected `self.sum_squared_error` to be a list, but got {type(self.sum_squared_error)}"
  108. )
  109. if not isinstance(self.total, list):
  110. raise TypeError(f"Expected `self.total` to be a list, but got {type(self.total)}")
  111. self.sum_squared_error.append(sum_squared_error)
  112. self.total.append(num_obs)
  113. def compute(self) -> Tensor:
  114. """Compute peak signal-to-noise ratio over state."""
  115. if isinstance(self.sum_squared_error, torch.Tensor):
  116. sum_squared_error = self.sum_squared_error
  117. elif isinstance(self.sum_squared_error, list):
  118. sum_squared_error = torch.cat([value.flatten() for value in self.sum_squared_error])
  119. else:
  120. raise TypeError("Expected sum_squared_error to be a Tensor or a list of Tensors")
  121. if isinstance(self.total, torch.Tensor):
  122. total = self.total
  123. elif isinstance(self.total, list):
  124. total = torch.cat([value.flatten() for value in self.total])
  125. else:
  126. raise TypeError("Expected total to be a Tensor or a list of Tensors")
  127. return _psnr_compute(sum_squared_error, total, self.data_range, base=self.base, reduction=self.reduction)
  128. def plot(
  129. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  130. ) -> _PLOT_OUT_TYPE:
  131. """Plot a single or multiple values from the metric.
  132. Args:
  133. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  134. If no value is provided, will automatically call `metric.compute` and plot that result.
  135. ax: An matplotlib axis object. If provided will add plot to that axis
  136. Returns:
  137. Figure and Axes object
  138. Raises:
  139. ModuleNotFoundError:
  140. If `matplotlib` is not installed
  141. .. plot::
  142. :scale: 75
  143. >>> # Example plotting a single value
  144. >>> import torch
  145. >>> from torchmetrics.image import PeakSignalNoiseRatio
  146. >>> metric = PeakSignalNoiseRatio(data_range=1.0)
  147. >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
  148. >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
  149. >>> metric.update(preds, target)
  150. >>> fig_, ax_ = metric.plot()
  151. .. plot::
  152. :scale: 75
  153. >>> # Example plotting multiple values
  154. >>> import torch
  155. >>> from torchmetrics.image import PeakSignalNoiseRatio
  156. >>> metric = PeakSignalNoiseRatio(data_range=1.0)
  157. >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
  158. >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
  159. >>> values = [ ]
  160. >>> for _ in range(10):
  161. ... values.append(metric(preds, target))
  162. >>> fig_, ax_ = metric.plot(values)
  163. """
  164. return self._plot(val, ax)