| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262 |
- from collections.abc import Sequence
- from typing import Any, Optional, Union
- from typing_extensions import Literal
- from torchmetrics.image.d_lambda import SpectralDistortionIndex
- from torchmetrics.image.ergas import ErrorRelativeGlobalDimensionlessSynthesis
- from torchmetrics.image.psnr import PeakSignalNoiseRatio
- from torchmetrics.image.rase import RelativeAverageSpectralError
- from torchmetrics.image.rmse_sw import RootMeanSquaredErrorUsingSlidingWindow
- from torchmetrics.image.sam import SpectralAngleMapper
- from torchmetrics.image.ssim import MultiScaleStructuralSimilarityIndexMeasure, StructuralSimilarityIndexMeasure
- from torchmetrics.image.tv import TotalVariation
- from torchmetrics.image.uqi import UniversalImageQualityIndex
- from torchmetrics.utilities.prints import _deprecated_root_import_class
class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionlessSynthesis):
    """Deprecated root-level alias of ``ErrorRelativeGlobalDimensionlessSynthesis``.

    Building an instance first reports the deprecated import through ``_deprecated_root_import_class`` and then
    initialises the parent metric with exactly the same arguments. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.ergas``).

    >>> from torch import rand
    >>> preds = rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> ergas = _ErrorRelativeGlobalDimensionlessSynthesis()
    >>> ergas(preds, target).round()
    tensor(10.)

    """

    def __init__(
        self,
        ratio: float = 4,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Report the deprecated import before the parent performs any initialisation.
        _deprecated_root_import_class("ErrorRelativeGlobalDimensionlessSynthesis", "image")
        # Everything is forwarded by keyword, untouched.
        super().__init__(reduction=reduction, ratio=ratio, **kwargs)
class _MultiScaleStructuralSimilarityIndexMeasure(MultiScaleStructuralSimilarityIndexMeasure):
    """Deprecated root-level alias of ``MultiScaleStructuralSimilarityIndexMeasure``.

    Instantiation reports the deprecated import via ``_deprecated_root_import_class`` and then delegates every
    argument, unchanged and by keyword, to the parent constructor. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.ssim``).

    >>> from torch import rand
    >>> preds = rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> ms_ssim = _MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
    >>> ms_ssim(preds, target)
    tensor(0.9628)

    """

    def __init__(
        self,
        gaussian_kernel: bool = True,
        kernel_size: Union[int, Sequence[int]] = 11,
        sigma: Union[float, Sequence[float]] = 1.5,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        k1: float = 0.01,
        k2: float = 0.03,
        betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
        normalize: Literal["relu", "simple", None] = "relu",
        **kwargs: Any,
    ) -> None:
        # The deprecation report must precede the parent's own setup.
        _deprecated_root_import_class("MultiScaleStructuralSimilarityIndexMeasure", "image")
        super().__init__(
            gaussian_kernel=gaussian_kernel,
            kernel_size=kernel_size,
            sigma=sigma,
            k1=k1,
            k2=k2,
            data_range=data_range,
            betas=betas,
            normalize=normalize,
            reduction=reduction,
            **kwargs,
        )
class _PeakSignalNoiseRatio(PeakSignalNoiseRatio):
    """Deprecated root-level alias of ``PeakSignalNoiseRatio``.

    On construction the deprecated import is reported through ``_deprecated_root_import_class``; the parent metric
    is then initialised with the very same arguments. Prefer importing the class from the ``torchmetrics.image``
    subpackage (it is defined in ``torchmetrics.image.psnr``).

    >>> from torch import tensor
    >>> psnr = _PeakSignalNoiseRatio()
    >>> preds = tensor([[0.0, 1.0], [2.0, 3.0]])
    >>> target = tensor([[3.0, 2.0], [1.0, 0.0]])
    >>> psnr(preds, target)
    tensor(2.5527)

    """

    def __init__(
        self,
        data_range: Union[float, tuple[float, float]] = 3.0,
        base: float = 10.0,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        dim: Optional[Union[int, tuple[int, ...]]] = None,
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("PeakSignalNoiseRatio", "image")
        # Pure pass-through: no argument is altered on its way to the parent.
        super().__init__(
            dim=dim,
            reduction=reduction,
            base=base,
            data_range=data_range,
            **kwargs,
        )
class _RelativeAverageSpectralError(RelativeAverageSpectralError):
    """Deprecated root-level alias of ``RelativeAverageSpectralError``.

    Constructing this wrapper reports the deprecated import via ``_deprecated_root_import_class`` and then
    initialises the parent metric with the same arguments. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.rase``).

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> rase = _RelativeAverageSpectralError()
    >>> rase(preds, target)
    tensor(5326.40...)

    """

    def __init__(
        self,
        window_size: int = 8,
        **kwargs: Any,
    ) -> None:
        """Report the deprecated import, then forward ``window_size`` and ``kwargs`` to the parent.

        Args:
            window_size: forwarded unchanged to the parent constructor.
            kwargs: any further keyword arguments, forwarded unchanged to the parent constructor.

        """
        # ``**kwargs`` is annotated ``Any`` because an annotation on ``**kwargs`` types each individual value;
        # ``dict[str, Any]`` would (wrongly) demand that every extra keyword argument be a dict.
        _deprecated_root_import_class("RelativeAverageSpectralError", "image")
        super().__init__(window_size=window_size, **kwargs)
class _RootMeanSquaredErrorUsingSlidingWindow(RootMeanSquaredErrorUsingSlidingWindow):
    """Deprecated root-level alias of ``RootMeanSquaredErrorUsingSlidingWindow``.

    Constructing this wrapper reports the deprecated import via ``_deprecated_root_import_class`` and then
    initialises the parent metric with the same arguments. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.rmse_sw``).

    >>> from torch import rand
    >>> preds = rand(4, 3, 16, 16)
    >>> target = rand(4, 3, 16, 16)
    >>> rmse_sw = _RootMeanSquaredErrorUsingSlidingWindow()
    >>> rmse_sw(preds, target)
    tensor(0.4158)

    """

    def __init__(
        self,
        window_size: int = 8,
        **kwargs: Any,
    ) -> None:
        """Report the deprecated import, then forward ``window_size`` and ``kwargs`` to the parent.

        Args:
            window_size: forwarded unchanged to the parent constructor.
            kwargs: any further keyword arguments, forwarded unchanged to the parent constructor.

        """
        # ``**kwargs`` is annotated ``Any`` because an annotation on ``**kwargs`` types each individual value;
        # ``dict[str, Any]`` would (wrongly) demand that every extra keyword argument be a dict.
        _deprecated_root_import_class("RootMeanSquaredErrorUsingSlidingWindow", "image")
        super().__init__(window_size=window_size, **kwargs)
class _SpectralAngleMapper(SpectralAngleMapper):
    """Deprecated root-level alias of ``SpectralAngleMapper``.

    The constructor reports the deprecated import through ``_deprecated_root_import_class`` and then hands
    ``reduction`` plus any extra keyword arguments, unchanged, to the parent. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.sam``).

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> sam = _SpectralAngleMapper()
    >>> sam(preds, target)
    tensor(0.5914)

    """

    def __init__(self, reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean", **kwargs: Any) -> None:
        _deprecated_root_import_class("SpectralAngleMapper", "image")
        super().__init__(
            reduction=reduction,
            **kwargs,
        )
class _SpectralDistortionIndex(SpectralDistortionIndex):
    """Deprecated root-level alias of ``SpectralDistortionIndex``.

    Creating an instance reports the deprecated import via ``_deprecated_root_import_class`` and then builds the
    parent metric from the same ``p``, ``reduction`` and keyword arguments. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.d_lambda``).

    >>> from torch import rand
    >>> preds = rand([16, 3, 16, 16])
    >>> target = rand([16, 3, 16, 16])
    >>> sdi = _SpectralDistortionIndex()
    >>> sdi(preds, target)
    tensor(0.0234)

    """

    def __init__(
        self,
        p: int = 1,
        reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Deprecation report first, then an unmodified hand-off to the parent.
        _deprecated_root_import_class("SpectralDistortionIndex", "image")
        super().__init__(reduction=reduction, p=p, **kwargs)
class _StructuralSimilarityIndexMeasure(StructuralSimilarityIndexMeasure):
    """Deprecated root-level alias of ``StructuralSimilarityIndexMeasure``.

    Instantiation reports the deprecated import through ``_deprecated_root_import_class`` and then initialises the
    parent metric, passing every argument by keyword without modification. Prefer importing the class from the
    ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.ssim``).

    >>> import torch
    >>> preds = torch.rand([3, 3, 256, 256])
    >>> target = preds * 0.75
    >>> ssim = _StructuralSimilarityIndexMeasure(data_range=1.0)
    >>> ssim(preds, target)
    tensor(0.9219)

    """

    def __init__(
        self,
        gaussian_kernel: bool = True,
        sigma: Union[float, Sequence[float]] = 1.5,
        kernel_size: Union[int, Sequence[int]] = 11,
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        data_range: Optional[Union[float, tuple[float, float]]] = None,
        k1: float = 0.01,
        k2: float = 0.03,
        return_full_image: bool = False,
        return_contrast_sensitivity: bool = False,
        **kwargs: Any,
    ) -> None:
        # NOTE: unlike the multi-scale wrapper, this signature lists ``sigma`` before ``kernel_size``; the order is
        # part of the positional interface and is kept as-is.
        _deprecated_root_import_class("StructuralSimilarityIndexMeasure", "image")
        super().__init__(
            gaussian_kernel=gaussian_kernel,
            kernel_size=kernel_size,
            sigma=sigma,
            data_range=data_range,
            k1=k1,
            k2=k2,
            reduction=reduction,
            return_contrast_sensitivity=return_contrast_sensitivity,
            return_full_image=return_full_image,
            **kwargs,
        )
class _TotalVariation(TotalVariation):
    """Deprecated root-level alias of ``TotalVariation``.

    The constructor reports the deprecated import via ``_deprecated_root_import_class`` and then forwards
    ``reduction`` (default ``"sum"``) and any extra keyword arguments to the parent unchanged. Prefer importing the
    class from the ``torchmetrics.image`` subpackage (it is defined in ``torchmetrics.image.tv``).

    >>> from torch import rand
    >>> tv = _TotalVariation()
    >>> img = rand(5, 3, 28, 28)
    >>> tv(img)
    tensor(7546.8018)

    """

    def __init__(
        self,
        reduction: Literal["mean", "sum", "none", None] = "sum",
        **kwargs: Any,
    ) -> None:
        _deprecated_root_import_class("TotalVariation", "image")
        super().__init__(
            reduction=reduction,
            **kwargs,
        )
class _UniversalImageQualityIndex(UniversalImageQualityIndex):
    """Deprecated root-level alias of ``UniversalImageQualityIndex``.

    Building an instance reports the deprecated import through ``_deprecated_root_import_class``; the parent metric
    is then initialised with identical arguments. Prefer importing the class from the ``torchmetrics.image``
    subpackage (it is defined in ``torchmetrics.image.uqi``).

    >>> import torch
    >>> preds = torch.rand([16, 1, 16, 16])
    >>> target = preds * 0.75
    >>> uqi = _UniversalImageQualityIndex()
    >>> uqi(preds, target)
    tensor(0.9216)

    """

    def __init__(
        self,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
        **kwargs: Any,
    ) -> None:
        # Report first, then delegate; the tuple defaults are immutable, so sharing them across calls is safe.
        _deprecated_root_import_class("UniversalImageQualityIndex", "image")
        super().__init__(
            reduction=reduction,
            sigma=sigma,
            kernel_size=kernel_size,
            **kwargs,
        )
|