ssim.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, List, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.image.ssim import _multiscale_ssim_update, _ssim_check_inputs, _ssim_update
  20. from torchmetrics.metric import Metric
  21. from torchmetrics.utilities.data import dim_zero_cat
  22. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  23. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  24. if not _MATPLOTLIB_AVAILABLE:
  25. __doctest_skip__ = ["StructuralSimilarityIndexMeasure.plot", "MultiScaleStructuralSimilarityIndexMeasure.plot"]
  26. class StructuralSimilarityIndexMeasure(Metric):
  27. """Compute Structural Similarity Index Measure (SSIM_).
  28. As input to ``forward`` and ``update`` the metric accepts the following input
  29. - ``preds`` (:class:`~torch.Tensor`): Predictions from model
  30. - ``target`` (:class:`~torch.Tensor`): Ground truth values
  31. As output of `forward` and `compute` the metric returns the following output
  32. - ``ssim`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average SSIM value
  33. over sample else returns tensor of shape ``(N,)`` with SSIM values per sample
  34. Args:
  35. preds: estimated image
  36. target: ground truth image
  37. gaussian_kernel: If ``True`` (default), a gaussian kernel is used, if ``False`` a uniform kernel is used
  38. sigma: Standard deviation of the gaussian kernel, anisotropic kernels are possible.
  39. Ignored if a uniform kernel is used
  40. kernel_size: the size of the uniform kernel, anisotropic kernels are possible.
  41. Ignored if a Gaussian kernel is used
  42. reduction: a method to reduce metric score over individual batch scores
  43. - ``'elementwise_mean'``: takes the mean
  44. - ``'sum'``: takes the sum
  45. - ``'none'`` or ``None``: no reduction will be applied
  46. data_range:
  47. the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
  48. the range is calculated as the difference and input is clamped between the values.
  49. k1: Parameter of SSIM.
  50. k2: Parameter of SSIM.
  51. return_full_image: If true, the full ``ssim`` image is returned as a second argument.
  52. Mutually exclusive with ``return_contrast_sensitivity``
  53. return_contrast_sensitivity: If true, the constant term is returned as a second argument.
  54. The luminance term can be obtained with luminance=ssim/contrast
  55. Mutually exclusive with ``return_full_image``
  56. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  57. Example:
  58. >>> import torch
  59. >>> from torchmetrics.image import StructuralSimilarityIndexMeasure
  60. >>> preds = torch.rand([3, 3, 256, 256])
  61. >>> target = preds * 0.75
  62. >>> ssim = StructuralSimilarityIndexMeasure(data_range=1.0)
  63. >>> ssim(preds, target)
  64. tensor(0.9219)
  65. """
  66. higher_is_better: bool = True
  67. is_differentiable: bool = True
  68. full_state_update: bool = False
  69. plot_lower_bound: float = 0.0
  70. plot_upper_bound: float = 1.0
  71. preds: List[Tensor]
  72. target: List[Tensor]
  73. def __init__(
  74. self,
  75. gaussian_kernel: bool = True,
  76. sigma: Union[float, Sequence[float]] = 1.5,
  77. kernel_size: Union[int, Sequence[int]] = 11,
  78. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  79. data_range: Optional[Union[float, tuple[float, float]]] = None,
  80. k1: float = 0.01,
  81. k2: float = 0.03,
  82. return_full_image: bool = False,
  83. return_contrast_sensitivity: bool = False,
  84. **kwargs: Any,
  85. ) -> None:
  86. super().__init__(**kwargs)
  87. valid_reduction = ("elementwise_mean", "sum", "none", None)
  88. if reduction not in valid_reduction:
  89. raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
  90. if reduction in ("elementwise_mean", "sum"):
  91. self.add_state("similarity", default=torch.tensor(0.0), dist_reduce_fx="sum")
  92. else:
  93. self.add_state("similarity", default=[], dist_reduce_fx=None)
  94. self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
  95. if return_contrast_sensitivity or return_full_image:
  96. self.add_state("image_return", default=[], dist_reduce_fx="cat")
  97. self.gaussian_kernel = gaussian_kernel
  98. self.sigma = sigma
  99. self.kernel_size = kernel_size
  100. self.reduction = reduction
  101. self.data_range = data_range
  102. self.k1 = k1
  103. self.k2 = k2
  104. self.return_full_image = return_full_image
  105. self.return_contrast_sensitivity = return_contrast_sensitivity
  106. def update(self, preds: Tensor, target: Tensor) -> None:
  107. """Update state with predictions and targets."""
  108. preds, target = _ssim_check_inputs(preds, target)
  109. similarity_pack = _ssim_update(
  110. preds,
  111. target,
  112. self.gaussian_kernel,
  113. self.sigma,
  114. self.kernel_size,
  115. self.data_range,
  116. self.k1,
  117. self.k2,
  118. self.return_full_image,
  119. self.return_contrast_sensitivity,
  120. )
  121. if isinstance(similarity_pack, tuple):
  122. similarity, image = similarity_pack
  123. else:
  124. similarity = similarity_pack
  125. if self.return_contrast_sensitivity or self.return_full_image:
  126. if not isinstance(self.image_return, list):
  127. raise TypeError("Expected `self.image_return` to be a list when returning images.")
  128. self.image_return.append(image)
  129. if self.reduction in ("elementwise_mean", "sum"):
  130. if not isinstance(self.similarity, torch.Tensor): # Ensure it's a Tensor
  131. raise TypeError("Expected `self.similarity` to be a Tensor for reductions.")
  132. self.similarity += similarity.sum()
  133. if not isinstance(self.total, torch.Tensor):
  134. raise TypeError("Expected `self.total` to be a Tensor.")
  135. self.total += preds.shape[0]
  136. else:
  137. if not isinstance(self.similarity, list):
  138. raise TypeError("Expected `self.similarity` to be a list when reduction='none'.")
  139. self.similarity.append(similarity)
  140. def compute(self) -> Union[Tensor, tuple[Tensor, Tensor]]:
  141. """Compute SSIM over state."""
  142. if self.reduction == "elementwise_mean":
  143. if isinstance(self.similarity, Tensor) and isinstance(self.total, Tensor):
  144. similarity = self.similarity / self.total
  145. else:
  146. raise TypeError(
  147. "Expected `self.similarity`and `self.total` to be of type Tensor for elementwise_mean reduction."
  148. )
  149. elif self.reduction == "sum":
  150. if not isinstance(self.similarity, Tensor):
  151. raise TypeError("Expected `self.similarity` to be a Tensor for sum reduction.")
  152. similarity = self.similarity
  153. else:
  154. if isinstance(self.similarity, list):
  155. similarity = dim_zero_cat(self.similarity) # Concatenate list of Tensors
  156. else:
  157. raise TypeError("Expected `self.similarity` to be a list for reduction='none'.")
  158. if self.return_contrast_sensitivity or self.return_full_image:
  159. if isinstance(self.image_return, list):
  160. image_return = dim_zero_cat(self.image_return) # Concatenate list of Tensors
  161. else:
  162. raise TypeError("Expected `self.image_return` to be a list when returning images.")
  163. return similarity, image_return
  164. return similarity
  165. def plot(
  166. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  167. ) -> _PLOT_OUT_TYPE:
  168. """Plot a single or multiple values from the metric.
  169. Args:
  170. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  171. If no value is provided, will automatically call `metric.compute` and plot that result.
  172. ax: An matplotlib axis object. If provided will add plot to that axis
  173. Returns:
  174. Figure and Axes object
  175. Raises:
  176. ModuleNotFoundError:
  177. If `matplotlib` is not installed
  178. .. plot::
  179. :scale: 75
  180. >>> # Example plotting a single value
  181. >>> import torch
  182. >>> from torchmetrics.image import StructuralSimilarityIndexMeasure
  183. >>> preds = torch.rand([3, 3, 256, 256])
  184. >>> target = preds * 0.75
  185. >>> metric = StructuralSimilarityIndexMeasure(data_range=1.0)
  186. >>> metric.update(preds, target)
  187. >>> fig_, ax_ = metric.plot()
  188. .. plot::
  189. :scale: 75
  190. >>> # Example plotting multiple values
  191. >>> import torch
  192. >>> from torchmetrics.image import StructuralSimilarityIndexMeasure
  193. >>> preds = torch.rand([3, 3, 256, 256])
  194. >>> target = preds * 0.75
  195. >>> metric = StructuralSimilarityIndexMeasure(data_range=1.0)
  196. >>> values = [ ]
  197. >>> for _ in range(10):
  198. ... values.append(metric(preds, target))
  199. >>> fig_, ax_ = metric.plot(values)
  200. """
  201. return self._plot(val, ax)
  202. class MultiScaleStructuralSimilarityIndexMeasure(Metric):
  203. """Compute `MultiScaleSSIM`_, Multi-scale Structural Similarity Index Measure.
  204. This metric is is a generalization of Structural Similarity Index Measure by incorporating image details at
  205. different resolution scores.
  206. As input to ``forward`` and ``update`` the metric accepts the following input
  207. - ``preds`` (:class:`~torch.Tensor`): Predictions from model
  208. - ``target`` (:class:`~torch.Tensor`): Ground truth values
  209. As output of `forward` and `compute` the metric returns the following output
  210. - ``msssim`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average MSSSIM
  211. value over sample else returns tensor of shape ``(N,)`` with SSIM values per sample
  212. Args:
  213. gaussian_kernel: If ``True`` (default), a gaussian kernel is used, if false a uniform kernel is used
  214. kernel_size: size of the gaussian kernel
  215. sigma: Standard deviation of the gaussian kernel
  216. reduction: a method to reduce metric score over labels.
  217. - ``'elementwise_mean'``: takes the mean
  218. - ``'sum'``: takes the sum
  219. - ``'none'`` or ``None``: no reduction will be applied
  220. data_range:
  221. the range of the data. If None, it is determined from the data (max - min). If a tuple is provided then
  222. the range is calculated as the difference and input is clamped between the values.
  223. The ``data_range`` must be given when ``dim`` is not None.
  224. k1: Parameter of structural similarity index measure.
  225. k2: Parameter of structural similarity index measure.
  226. betas: Exponent parameters for individual similarities and contrastive sensitivities returned by different image
  227. resolutions.
  228. normalize: When MultiScaleStructuralSimilarityIndexMeasure loss is used for training, it is desirable to use
  229. normalizes to improve the training stability. This `normalize` argument is out of scope of the original
  230. implementation [1], and it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.
  231. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  232. Return:
  233. Tensor with Multi-Scale SSIM score
  234. Raises:
  235. ValueError:
  236. If ``kernel_size`` is not an int or a Sequence of ints with size 2 or 3.
  237. ValueError:
  238. If ``betas`` is not a tuple of floats with length 2.
  239. ValueError:
  240. If ``normalize`` is neither `None`, `ReLU` nor `simple`.
  241. Example:
  242. >>> from torch import rand
  243. >>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
  244. >>> preds = torch.rand([3, 3, 256, 256])
  245. >>> target = preds * 0.75
  246. >>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
  247. >>> ms_ssim(preds, target)
  248. tensor(0.9628)
  249. """
  250. higher_is_better: bool = True
  251. is_differentiable: bool = True
  252. full_state_update: bool = False
  253. plot_lower_bound: float = 0.0
  254. plot_upper_bound: float = 1.0
  255. preds: List[Tensor]
  256. target: List[Tensor]
  257. def __init__(
  258. self,
  259. gaussian_kernel: bool = True,
  260. kernel_size: Union[int, Sequence[int]] = 11,
  261. sigma: Union[float, Sequence[float]] = 1.5,
  262. reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
  263. data_range: Optional[Union[float, tuple[float, float]]] = None,
  264. k1: float = 0.01,
  265. k2: float = 0.03,
  266. betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
  267. normalize: Literal["relu", "simple", None] = "relu",
  268. **kwargs: Any,
  269. ) -> None:
  270. super().__init__(**kwargs)
  271. valid_reduction = ("elementwise_mean", "sum", "none", None)
  272. if reduction not in valid_reduction:
  273. raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
  274. if reduction in ("elementwise_mean", "sum"):
  275. self.add_state("similarity", default=torch.tensor(0.0), dist_reduce_fx="sum")
  276. else:
  277. self.add_state("similarity", default=[], dist_reduce_fx=None)
  278. self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
  279. if not (isinstance(kernel_size, (Sequence, int))):
  280. raise ValueError(
  281. f"Argument `kernel_size` expected to be an sequence or an int, or a single int. Got {kernel_size}"
  282. )
  283. if isinstance(kernel_size, Sequence) and (
  284. len(kernel_size) not in (2, 3) or not all(isinstance(ks, int) for ks in kernel_size)
  285. ):
  286. raise ValueError(
  287. "Argument `kernel_size` expected to be an sequence of size 2 or 3 where each element is an int, "
  288. f"or a single int. Got {kernel_size}"
  289. )
  290. self.gaussian_kernel = gaussian_kernel
  291. self.sigma = sigma
  292. self.kernel_size = kernel_size
  293. self.reduction = reduction
  294. self.data_range = data_range
  295. self.k1 = k1
  296. self.k2 = k2
  297. if not isinstance(betas, tuple):
  298. raise ValueError("Argument `betas` is expected to be of a type tuple.")
  299. if isinstance(betas, tuple) and not all(isinstance(beta, float) for beta in betas):
  300. raise ValueError("Argument `betas` is expected to be a tuple of floats.")
  301. self.betas = betas
  302. if normalize and normalize not in ("relu", "simple"):
  303. raise ValueError("Argument `normalize` to be expected either `None` or one of 'relu' or 'simple'")
  304. self.normalize = normalize
  305. def update(self, preds: Tensor, target: Tensor) -> None:
  306. """Update state with predictions and targets."""
  307. preds, target = _ssim_check_inputs(preds, target)
  308. similarity = _multiscale_ssim_update(
  309. preds,
  310. target,
  311. self.gaussian_kernel,
  312. self.sigma,
  313. self.kernel_size,
  314. self.data_range,
  315. self.k1,
  316. self.k2,
  317. self.betas,
  318. self.normalize,
  319. )
  320. if self.reduction in ("none", None):
  321. if not isinstance(self.similarity, list):
  322. raise TypeError("Expected `self.similarity` to be a list for reduction='none'.")
  323. self.similarity.append(similarity)
  324. else:
  325. if not isinstance(self.similarity, Tensor):
  326. raise TypeError("Expected `self.similarity` to be a Tensor for elementwise_mean or sum reduction.")
  327. self.similarity += similarity.sum()
  328. if not isinstance(self.total, Tensor):
  329. raise TypeError("Expected `self.total` to be a Tensor.")
  330. self.total += torch.tensor(preds.shape[0], dtype=self.total.dtype, device=self.total.device)
  331. def compute(self) -> Tensor:
  332. """Compute MS-SSIM over state."""
  333. if self.reduction in ("none", None):
  334. if isinstance(self.similarity, list):
  335. return dim_zero_cat(self.similarity)
  336. raise TypeError("Expected `self.similarity` to be a list for reduction='none'.")
  337. if self.reduction == "sum":
  338. if isinstance(self.similarity, Tensor):
  339. return self.similarity
  340. raise TypeError("Expected `self.similarity` to be a Tensor for sum reduction.")
  341. if isinstance(self.similarity, Tensor) and isinstance(self.total, Tensor):
  342. return self.similarity / self.total
  343. raise TypeError("Expected `self.similarity` and `self.total` to be Tensors for elementwise_mean reduction.")
  344. def plot(
  345. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  346. ) -> _PLOT_OUT_TYPE:
  347. """Plot a single or multiple values from the metric.
  348. Args:
  349. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  350. If no value is provided, will automatically call `metric.compute` and plot that result.
  351. ax: An matplotlib axis object. If provided will add plot to that axis
  352. Returns:
  353. Figure and Axes object
  354. Raises:
  355. ModuleNotFoundError:
  356. If `matplotlib` is not installed
  357. .. plot::
  358. :scale: 75
  359. >>> # Example plotting a single value
  360. >>> from torch import rand
  361. >>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
  362. >>> preds = rand([3, 3, 256, 256])
  363. >>> target = preds * 0.75
  364. >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
  365. >>> metric.update(preds, target)
  366. >>> fig_, ax_ = metric.plot()
  367. .. plot::
  368. :scale: 75
  369. >>> # Example plotting multiple values
  370. >>> from torch import rand
  371. >>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
  372. >>> preds = rand([3, 3, 256, 256])
  373. >>> target = preds * 0.75
  374. >>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
  375. >>> values = [ ]
  376. >>> for _ in range(10):
  377. ... values.append(metric(preds, target))
  378. >>> fig_, ax_ = metric.plot(values)
  379. """
  380. return self._plot(val, ax)