_deprecated.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. from collections.abc import Sequence
  2. from typing import Optional, Union
  3. from torch import Tensor
  4. from typing_extensions import Literal
  5. from torchmetrics.functional.image.d_lambda import spectral_distortion_index
  6. from torchmetrics.functional.image.ergas import error_relative_global_dimensionless_synthesis
  7. from torchmetrics.functional.image.gradients import image_gradients
  8. from torchmetrics.functional.image.psnr import peak_signal_noise_ratio
  9. from torchmetrics.functional.image.rase import relative_average_spectral_error
  10. from torchmetrics.functional.image.rmse_sw import root_mean_squared_error_using_sliding_window
  11. from torchmetrics.functional.image.sam import spectral_angle_mapper
  12. from torchmetrics.functional.image.ssim import (
  13. multiscale_structural_similarity_index_measure,
  14. structural_similarity_index_measure,
  15. )
  16. from torchmetrics.functional.image.tv import total_variation
  17. from torchmetrics.functional.image.uqi import universal_image_quality_index
  18. from torchmetrics.utilities.prints import _deprecated_root_import_func
  19. def _spectral_distortion_index(
  20. preds: Tensor,
  21. target: Tensor,
  22. p: int = 1,
  23. reduction: Literal["elementwise_mean", "sum", "none"] = "elementwise_mean",
  24. ) -> Tensor:
  25. """Wrapper for deprecated import.
  26. >>> from torch import rand
  27. >>> preds = rand([16, 3, 16, 16])
  28. >>> target = rand([16, 3, 16, 16])
  29. >>> _spectral_distortion_index(preds, target)
  30. tensor(0.0234)
  31. """
  32. _deprecated_root_import_func("spectral_distortion_index", "image")
  33. return spectral_distortion_index(preds=preds, target=target, p=p, reduction=reduction)
  34. def _error_relative_global_dimensionless_synthesis(
  35. preds: Tensor,
  36. target: Tensor,
  37. ratio: float = 4,
  38. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  39. ) -> Tensor:
  40. """Wrapper for deprecated import.
  41. >>> from torch import rand
  42. >>> preds = rand([16, 1, 16, 16])
  43. >>> target = preds * 0.75
  44. >>> _error_relative_global_dimensionless_synthesis(preds, target).round()
  45. tensor(10.)
  46. """
  47. _deprecated_root_import_func("error_relative_global_dimensionless_synthesis", "image")
  48. return error_relative_global_dimensionless_synthesis(preds=preds, target=target, ratio=ratio, reduction=reduction)
  49. def _image_gradients(img: Tensor) -> tuple[Tensor, Tensor]:
  50. """Wrapper for deprecated import.
  51. >>> import torch
  52. >>> image = torch.arange(0, 1*1*5*5, dtype=torch.float32)
  53. >>> image = torch.reshape(image, (1, 1, 5, 5))
  54. >>> dy, dx = _image_gradients(image)
  55. >>> dy[0, 0, :, :]
  56. tensor([[5., 5., 5., 5., 5.],
  57. [5., 5., 5., 5., 5.],
  58. [5., 5., 5., 5., 5.],
  59. [5., 5., 5., 5., 5.],
  60. [0., 0., 0., 0., 0.]])
  61. """
  62. _deprecated_root_import_func("image_gradients", "image")
  63. return image_gradients(img=img)
  64. def _peak_signal_noise_ratio(
  65. preds: Tensor,
  66. target: Tensor,
  67. data_range: Union[float, tuple[float, float]] = 3.0,
  68. base: float = 10.0,
  69. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  70. dim: Optional[Union[int, tuple[int, ...]]] = None,
  71. ) -> Tensor:
  72. """Wrapper for deprecated import.
  73. >>> from torch import tensor
  74. >>> pred = tensor([[0.0, 1.0], [2.0, 3.0]])
  75. >>> target = tensor([[3.0, 2.0], [1.0, 0.0]])
  76. >>> _peak_signal_noise_ratio(pred, target)
  77. tensor(2.5527)
  78. """
  79. _deprecated_root_import_func("peak_signal_noise_ratio", "image")
  80. return peak_signal_noise_ratio(
  81. preds=preds, target=target, data_range=data_range, base=base, reduction=reduction, dim=dim
  82. )
  83. def _relative_average_spectral_error(preds: Tensor, target: Tensor, window_size: int = 8) -> Tensor:
  84. """Wrapper for deprecated import.
  85. >>> from torch import rand
  86. >>> preds = rand(4, 3, 16, 16)
  87. >>> target = rand(4, 3, 16, 16)
  88. >>> _relative_average_spectral_error(preds, target)
  89. tensor(5326.40...)
  90. """
  91. _deprecated_root_import_func("relative_average_spectral_error", "image")
  92. return relative_average_spectral_error(preds=preds, target=target, window_size=window_size)
  93. def _root_mean_squared_error_using_sliding_window(
  94. preds: Tensor, target: Tensor, window_size: int = 8, return_rmse_map: bool = False
  95. ) -> Union[Optional[Tensor], tuple[Optional[Tensor], Tensor]]:
  96. """Wrapper for deprecated import.
  97. >>> from torch import rand
  98. >>> preds = rand(4, 3, 16, 16)
  99. >>> target = rand(4, 3, 16, 16)
  100. >>> _root_mean_squared_error_using_sliding_window(preds, target)
  101. tensor(0.4158)
  102. """
  103. _deprecated_root_import_func("root_mean_squared_error_using_sliding_window", "image")
  104. return root_mean_squared_error_using_sliding_window(
  105. preds=preds, target=target, window_size=window_size, return_rmse_map=return_rmse_map
  106. )
  107. def _spectral_angle_mapper(
  108. preds: Tensor,
  109. target: Tensor,
  110. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  111. ) -> Tensor:
  112. """Wrapper for deprecated import.
  113. >>> from torch import rand
  114. >>> preds = rand([16, 3, 16, 16])
  115. >>> target = rand([16, 3, 16, 16])
  116. >>> _spectral_angle_mapper(preds, target)
  117. tensor(0.5914)
  118. """
  119. _deprecated_root_import_func("spectral_angle_mapper", "image")
  120. return spectral_angle_mapper(preds=preds, target=target, reduction=reduction)
  121. def _multiscale_structural_similarity_index_measure(
  122. preds: Tensor,
  123. target: Tensor,
  124. gaussian_kernel: bool = True,
  125. sigma: Union[float, Sequence[float]] = 1.5,
  126. kernel_size: Union[int, Sequence[int]] = 11,
  127. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  128. data_range: Optional[Union[float, tuple[float, float]]] = None,
  129. k1: float = 0.01,
  130. k2: float = 0.03,
  131. betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
  132. normalize: Optional[Literal["relu", "simple"]] = "relu",
  133. ) -> Tensor:
  134. """Wrapper for deprecated import.
  135. >>> from torch import rand
  136. >>> preds = rand([3, 3, 256, 256])
  137. >>> target = preds * 0.75
  138. >>> _multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
  139. tensor(0.9628)
  140. """
  141. _deprecated_root_import_func("multiscale_structural_similarity_index_measure", "image")
  142. return multiscale_structural_similarity_index_measure(
  143. preds=preds,
  144. target=target,
  145. gaussian_kernel=gaussian_kernel,
  146. sigma=sigma,
  147. kernel_size=kernel_size,
  148. reduction=reduction,
  149. data_range=data_range,
  150. k1=k1,
  151. k2=k2,
  152. betas=betas,
  153. normalize=normalize,
  154. )
  155. def _structural_similarity_index_measure(
  156. preds: Tensor,
  157. target: Tensor,
  158. gaussian_kernel: bool = True,
  159. sigma: Union[float, Sequence[float]] = 1.5,
  160. kernel_size: Union[int, Sequence[int]] = 11,
  161. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  162. data_range: Optional[Union[float, tuple[float, float]]] = None,
  163. k1: float = 0.01,
  164. k2: float = 0.03,
  165. return_full_image: bool = False,
  166. return_contrast_sensitivity: bool = False,
  167. ) -> Union[Tensor, tuple[Tensor, Tensor]]:
  168. """Wrapper for deprecated import.
  169. >>> import torch
  170. >>> preds = torch.rand([3, 3, 256, 256])
  171. >>> target = preds * 0.75
  172. >>> _structural_similarity_index_measure(preds, target)
  173. tensor(0.9219)
  174. """
  175. _deprecated_root_import_func("spectral_angle_mapper", "image")
  176. return structural_similarity_index_measure(
  177. preds=preds,
  178. target=target,
  179. gaussian_kernel=gaussian_kernel,
  180. sigma=sigma,
  181. kernel_size=kernel_size,
  182. reduction=reduction,
  183. data_range=data_range,
  184. k1=k1,
  185. k2=k2,
  186. return_full_image=return_full_image,
  187. return_contrast_sensitivity=return_contrast_sensitivity,
  188. )
  189. def _total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor:
  190. """Wrapper for deprecated import.
  191. >>> from torch import rand
  192. >>> img = rand(5, 3, 28, 28)
  193. >>> _total_variation(img)
  194. tensor(7546.8018)
  195. """
  196. _deprecated_root_import_func("total_variation", "image")
  197. return total_variation(img=img, reduction=reduction)
  198. def _universal_image_quality_index(
  199. preds: Tensor,
  200. target: Tensor,
  201. kernel_size: Sequence[int] = (11, 11),
  202. sigma: Sequence[float] = (1.5, 1.5),
  203. reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
  204. ) -> Tensor:
  205. """Wrapper for deprecated import.
  206. >>> import torch
  207. >>> preds = torch.rand([16, 1, 16, 16])
  208. >>> target = preds * 0.75
  209. >>> _universal_image_quality_index(preds, target)
  210. tensor(0.9216)
  211. """
  212. _deprecated_root_import_func("universal_image_quality_index", "image")
  213. return universal_image_quality_index(
  214. preds=preds,
  215. target=target,
  216. kernel_size=kernel_size,
  217. sigma=sigma,
  218. reduction=reduction,
  219. )