| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- from collections.abc import Sequence
- from typing import Optional, Union
- from torch import Tensor
- from typing_extensions import Literal
- from torchmetrics.functional.image.d_lambda import spectral_distortion_index
- from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
- from torchmetrics.functional.image.gradients import image_gradients
- from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
- from torchmetrics.functional.image.rase import relative_average_spectral_error
- from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
- from torchmetrics.functional.image.sam import spectral_angle_mapper
- from torchmetrics.functional.image.ssim import (
- multiscale_structural_similarity_index_measure,
- structural_similarity_index_measure,
- )
- from torchmetrics.functional.image.tv import total_variation
- from torchmetrics.functional.image.uqi import universal_image_quality_index
- from torchmetrics.utilities.prints import _deprecated_root_import_func
def _spectral_distortion_index(
    preds: Tensor,
    target: Tensor,
    p: int = 1,
    reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
) -> Tensor:
    """Deprecated root-import shim around ``spectral_distortion_index``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        p: Forwarded unchanged as ``p`` (defaults to ``1``).
        reduction: Forwarded unchanged as ``reduction``.

    Returns:
        Whatever ``spectral_distortion_index`` returns for these arguments.

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> _spectral_distortion_index(preds, target)
    tensor(0.0234)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("spectral_distortion_index", "image")
    result = spectral_distortion_index(
        preds=preds,
        target=target,
        p=p,
        reduction=reduction,
    )
    return result
def _error_relative_global_dimensionless_synthesis(
    preds: Tensor,
    target: Tensor,
    ratio: float = 4,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Deprecated root-import shim around ``error_relative_global_dimensionless_synthesis``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        ratio: Forwarded unchanged as ``ratio`` (defaults to ``4``).
        reduction: Forwarded unchanged as ``reduction``.

    Returns:
        Whatever ``error_relative_global_dimensionless_synthesis`` returns for these arguments.

    >>> from torch import rand
    >>> preds = rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> _error_relative_global_dimensionless_synthesis(preds, target).round()
    tensor(10.)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image")
    result = error_relative_global_dimensionless_synthesis(
        preds=preds,
        target=target,
        ratio=ratio,
        reduction=reduction,
    )
    return result
def _image_gradients(img: Tensor) -> tuple[Tensor, Tensor]:
    """Deprecated root-import shim around ``image_gradients``.

    Calls ``_deprecated_root_import_func`` with this function's name and the ``"image"``
    domain, then forwards ``img`` by keyword to the real implementation.

    Args:
        img: Forwarded unchanged as ``img``.

    Returns:
        The pair returned by ``image_gradients`` (unpacked as ``dy, dx`` in the example below).

    >>> import torch
    >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
    >>> image = torch.reshape(image, (1, 1, 5, 5))
    >>> dy, dx = _image_gradients(image)
    >>> dy[0, 0, :, :]
    tensor([[5., 5., 5., 5., 5.],
            [5., 5., 5., 5., 5.],
            [5., 5., 5., 5., 5.],
            [5., 5., 5., 5., 5.],
            [0., 0., 0., 0., 0.]])

    """
    # The deprecation hook runs first, before any gradient computation.
    _deprecated_root_import_func("image_gradients", "image")
    gradients = image_gradients(img=img)
    return gradients
def _peak_signal_noise_ratio(
    preds: Tensor,
    target: Tensor,
    data_range: Union[float, tuple[float, float]] = 3.0,
    base: float = 10.0,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    dim: Optional[Union[int, tuple[int, ...]]] = None,
) -> Tensor:
    """Deprecated root-import shim around ``peak_signal_noise_ratio``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        data_range: Forwarded unchanged as ``data_range`` (defaults to ``3.0`` in this wrapper).
        base: Forwarded unchanged as ``base`` (defaults to ``10.0``).
        reduction: Forwarded unchanged as ``reduction``.
        dim: Forwarded unchanged as ``dim`` (defaults to ``None``).

    Returns:
        Whatever ``peak_signal_noise_ratio`` returns for these arguments.

    >>> from torch import tensor
    >>> pred = tensor([[0.0, 1.0], [2.0, 3.0]])
    >>> target = tensor([[3.0, 2.0], [1.0, 0.0]])
    >>> _peak_signal_noise_ratio(pred, target)
    tensor(2.5527)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("peak_signal_noise_ratio", "image")
    result = peak_signal_noise_ratio(
        preds=preds,
        target=target,
        data_range=data_range,
        base=base,
        reduction=reduction,
        dim=dim,
    )
    return result
def _relative_average_spectral_error(
    preds: Tensor,
    target: Tensor,
    window_size: int = 8,
) -> Tensor:
    """Deprecated root-import shim around ``relative_average_spectral_error``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        window_size: Forwarded unchanged as ``window_size`` (defaults to ``8``).

    Returns:
        Whatever ``relative_average_spectral_error`` returns for these arguments.

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> _relative_average_spectral_error(preds, target)
    tensor(5326.40...)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("relative_average_spectral_error", "image")
    result = relative_average_spectral_error(
        preds=preds,
        target=target,
        window_size=window_size,
    )
    return result
def _root_mean_squared_error_using_sliding_window(
    preds: Tensor,
    target: Tensor,
    window_size: int = 8,
    return_rmse_map: bool = False,
) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]:
    """Deprecated root-import shim around ``root_mean_squared_error_using_sliding_window``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        window_size: Forwarded unchanged as ``window_size`` (defaults to ``8``).
        return_rmse_map: Forwarded unchanged as ``return_rmse_map`` (defaults to ``False``).

    Returns:
        Whatever ``root_mean_squared_error_using_sliding_window`` returns for these arguments.

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> _root_mean_squared_error_using_sliding_window(preds, target)
    tensor(0.4158)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("root_mean_squared_error_using_sliding_window", "image")
    result = root_mean_squared_error_using_sliding_window(
        preds=preds,
        target=target,
        window_size=window_size,
        return_rmse_map=return_rmse_map,
    )
    return result
def _spectral_angle_mapper(
    preds: Tensor,
    target: Tensor,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
) -> Tensor:
    """Deprecated root-import shim around ``spectral_angle_mapper``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        reduction: Forwarded unchanged as ``reduction``.

    Returns:
        Whatever ``spectral_angle_mapper`` returns for these arguments.

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> _spectral_angle_mapper(preds, target)
    tensor(0.5914)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("spectral_angle_mapper", "image")
    result = spectral_angle_mapper(
        preds=preds,
        target=target,
        reduction=reduction,
    )
    return result
def _multiscale_structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Sequence[float]] = 1.5,
    kernel_size: Union[int, Sequence[int]] = 11,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    data_range: Optional[Union[float, tuple[float, float]]] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
    normalize: Optional[Literal["relu", "simple"]] = "relu",
) -> Tensor:
    """Deprecated root-import shim around ``multiscale_structural_similarity_index_measure``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        gaussian_kernel: Forwarded unchanged as ``gaussian_kernel`` (defaults to ``True``).
        sigma: Forwarded unchanged as ``sigma`` (defaults to ``1.5``).
        kernel_size: Forwarded unchanged as ``kernel_size`` (defaults to ``11``).
        reduction: Forwarded unchanged as ``reduction``.
        data_range: Forwarded unchanged as ``data_range`` (defaults to ``None``).
        k1: Forwarded unchanged as ``k1`` (defaults to ``0.01``).
        k2: Forwarded unchanged as ``k2`` (defaults to ``0.03``).
        betas: Forwarded unchanged as ``betas``.
        normalize: Forwarded unchanged as ``normalize`` (defaults to ``"relu"``).

    Returns:
        Whatever ``multiscale_structural_similarity_index_measure`` returns for these arguments.

    >>> from torch import rand
    >>> preds = rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
    tensor(0.9628)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image")
    result = multiscale_structural_similarity_index_measure(
        preds=preds,
        target=target,
        gaussian_kernel=gaussian_kernel,
        sigma=sigma,
        kernel_size=kernel_size,
        reduction=reduction,
        data_range=data_range,
        k1=k1,
        k2=k2,
        betas=betas,
        normalize=normalize,
    )
    return result
def _structural_similarity_index_measure(
    preds: Tensor,
    target: Tensor,
    gaussian_kernel: bool = True,
    sigma: Union[float, Sequence[float]] = 1.5,
    kernel_size: Union[int, Sequence[int]] = 11,
    reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
    data_range: Optional[Union[float, tuple[float, float]]] = None,
    k1: float = 0.01,
    k2: float = 0.03,
    return_full_image: bool = False,
    return_contrast_sensitivity: bool = False,
) -> Union[Tensor, tuple[Tensor, Tensor]]:
    """Deprecated root-import shim around ``structural_similarity_index_measure``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        gaussian_kernel: Forwarded unchanged as ``gaussian_kernel`` (defaults to ``True``).
        sigma: Forwarded unchanged as ``sigma`` (defaults to ``1.5``).
        kernel_size: Forwarded unchanged as ``kernel_size`` (defaults to ``11``).
        reduction: Forwarded unchanged as ``reduction``.
        data_range: Forwarded unchanged as ``data_range`` (defaults to ``None``).
        k1: Forwarded unchanged as ``k1`` (defaults to ``0.01``).
        k2: Forwarded unchanged as ``k2`` (defaults to ``0.03``).
        return_full_image: Forwarded unchanged as ``return_full_image`` (defaults to ``False``).
        return_contrast_sensitivity: Forwarded unchanged as ``return_contrast_sensitivity``
            (defaults to ``False``).

    Returns:
        Whatever ``structural_similarity_index_measure`` returns for these arguments.

    >>> import torch
    >>> preds = torch.rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> _structural_similarity_index_measure(preds, target)
    tensor(0.9219)

    """
    # Report this wrapper's own metric name; it was previously passed as
    # "spectral_angle_mapper", which pointed users at the wrong function.
    _deprecated_root_import_func("structural_similarity_index_measure", "image")
    return structural_similarity_index_measure(
        preds=preds,
        target=target,
        gaussian_kernel=gaussian_kernel,
        sigma=sigma,
        kernel_size=kernel_size,
        reduction=reduction,
        data_range=data_range,
        k1=k1,
        k2=k2,
        return_full_image=return_full_image,
        return_contrast_sensitivity=return_contrast_sensitivity,
    )
def _total_variation(
    img: Tensor,
    reduction: Literal["mean", "sum", "none", None] = "sum",
) -> Tensor:
    """Deprecated root-import shim around ``total_variation``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        img: Forwarded unchanged as ``img``.
        reduction: Forwarded unchanged as ``reduction`` (defaults to ``"sum"``).

    Returns:
        Whatever ``total_variation`` returns for these arguments.

    >>> from torch import rand
    >>> img = rand(5, 3, 28, 28)
    >>> _total_variation(img)
    tensor(7546.8018)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("total_variation", "image")
    result = total_variation(
        img=img,
        reduction=reduction,
    )
    return result
def _universal_image_quality_index(
    preds: Tensor,
    target: Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
    """Deprecated root-import shim around ``universal_image_quality_index``.

    Calls ``_deprecated_root_import_func`` with this metric's name and the ``"image"``
    domain, then forwards every argument by keyword to the real implementation.

    Args:
        preds: Forwarded unchanged as ``preds``.
        target: Forwarded unchanged as ``target``.
        kernel_size: Forwarded unchanged as ``kernel_size`` (defaults to ``(11, 11)``).
        sigma: Forwarded unchanged as ``sigma`` (defaults to ``(1.5, 1.5)``).
        reduction: Forwarded unchanged as ``reduction``.

    Returns:
        Whatever ``universal_image_quality_index`` returns for these arguments.

    >>> import torch
    >>> preds = torch.rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> _universal_image_quality_index(preds, target)
    tensor(0.9216)

    """
    # The deprecation hook runs first, before any metric computation.
    _deprecated_root_import_func("universal_image_quality_index", "image")
    result = universal_image_quality_index(
        preds=preds, target=target, kernel_size=kernel_size, sigma=sigma, reduction=reduction
    )
    return result
|