| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147 |
- # Copyright The Lightning team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Optional, Union
- import torch
- from torch import Tensor, tensor
- from typing_extensions import Literal
- from torchmetrics.utilities import rank_zero_warn, reduce
def _psnr_compute(
    sum_squared_error: Tensor,
    num_obs: Tensor,
    data_range: Tensor,
    base: float = 10.0,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Turn accumulated squared errors into a peak signal-to-noise ratio.

    The score is ``10 * log_base(data_range ** 2 / MSE)`` where
    ``MSE = sum_squared_error / num_obs``. It is evaluated with natural logarithms
    and rescaled to the requested ``base`` afterwards.

    Args:
        sum_squared_error: Sum of the squared errors over all observations
        num_obs: Number of observations that contributed to ``sum_squared_error``
        data_range: Range of the data as a tensor, e.g. ``target.max() - target.min()``
        base: Base of the logarithm used for the final score
        reduction: Reduction applied to the PSNR values, forwarded to ``reduce``:

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

    Example:
        >>> preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
        >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
        >>> data_range = target.max() - target.min()
        >>> sum_squared_error, num_obs = _psnr_update(preds, target)
        >>> _psnr_compute(sum_squared_error, num_obs, data_range)
        tensor(2.5527)

    """
    mean_squared_error = sum_squared_error / num_obs
    # ln(data_range ** 2 / MSE), written as a difference of logarithms.
    log_ratio = 2 * torch.log(data_range) - torch.log(mean_squared_error)
    # Change of base: dividing by ln(base) converts the natural log to log_base.
    base_scale = 10 / torch.log(tensor(base))
    scores = log_ratio * base_scale
    return reduce(scores, reduction=reduction)
def _psnr_update(
    preds: Tensor,
    target: Tensor,
    dim: Optional[Union[int, tuple[int, ...]]] = None,
) -> tuple[Tensor, Tensor]:
    """Gather the statistics needed to compute the peak signal-to-noise ratio.

    Args:
        preds: Predicted tensor
        target: Ground truth tensor
        dim: Dimension(s) over which the squared error is summed, given as a single integer or
            a sequence of integers. ``None`` (default) sums over every element.

    Returns:
        A ``(sum_squared_error, num_obs)`` pair. When ``dim`` is ``None`` (or an empty sequence)
        ``num_obs`` is the total element count of ``target`` as a scalar tensor. Otherwise it is the
        product of the sizes of ``target`` along ``dim``, expanded to the shape of
        ``sum_squared_error``.

    """
    # Non-floating inputs are promoted to float32 before any arithmetic.
    preds = preds if preds.is_floating_point() else preds.to(torch.float32)
    target = target if target.is_floating_point() else target.to(torch.float32)

    if dim is None:
        total_error = (preds - target).pow(2).sum()
        return total_error, tensor(target.numel(), device=target.device)

    residual = preds - target
    error_per_slice = (residual * residual).sum(dim=dim)

    reduced_axes = [dim] if isinstance(dim, int) else list(dim)
    if reduced_axes:
        # Count the elements folded into each output entry: the product of the reduced axis sizes.
        shape = tensor(target.size(), device=target.device)
        count = shape[reduced_axes].prod().expand_as(error_per_slice)
    else:
        # No axes listed: the observation count is the full element count of ``target``.
        count = tensor(target.numel(), device=target.device)

    return error_per_slice, count
def peak_signal_noise_ratio(
    preds: Tensor,
    target: Tensor,
    data_range: Union[float, tuple[float, float]],
    base: float = 10.0,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    dim: Optional[Union[int, tuple[int, ...]]] = None,
) -> Tensor:
    """Compute the peak signal-to-noise ratio.

    Args:
        preds: estimated signal
        target: ground truth signal
        data_range:
            the range of the data. If a tuple ``(min, max)`` is provided, the range is calculated as
            ``max - min`` and both ``preds`` and ``target`` are clamped between the two values.
        base: a base of a logarithm to use
        reduction: a method to reduce metric score over labels.

            - ``'elementwise_mean'``: takes the mean (default)
            - ``'sum'``: takes the sum
            - ``'none'`` or ``None``: no reduction will be applied

        dim:
            Dimensions to reduce PSNR scores over, provided as either an integer or a tuple of integers.
            Default is None, meaning scores will be reduced across all dimensions.

    Return:
        Tensor with PSNR score

    Example:
        >>> from torchmetrics.functional.image import peak_signal_noise_ratio
        >>> pred = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
        >>> target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
        >>> peak_signal_noise_ratio(pred, target, data_range=3.0)
        tensor(2.5527)

    .. attention::
        Half precision is only supported on GPU for this metric.

    """
    # With ``dim=None`` a single score is produced, so a non-default reduction is only warned about.
    reduction_has_no_effect = dim is None and reduction != "elementwise_mean"
    if reduction_has_no_effect:
        rank_zero_warn(f"The `reduction={reduction}` will not have any effect when `dim` is None.")

    if not isinstance(data_range, tuple):
        data_range_val = tensor(float(data_range))
    else:
        lower, upper = data_range[0], data_range[1]
        # Both signals are limited to the declared interval before the error is measured.
        preds = torch.clamp(preds, min=lower, max=upper)
        target = torch.clamp(target, min=lower, max=upper)
        data_range_val = tensor(upper - lower)

    sum_squared_error, num_obs = _psnr_update(preds, target, dim=dim)
    return _psnr_compute(sum_squared_error, num_obs, data_range_val, base=base, reduction=reduction)
|