tv.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, List, Optional, Union
  16. import torch
  17. from torch import Tensor, tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.functional.image.tv import _total_variation_compute, _total_variation_update
  20. from torchmetrics.metric import Metric
  21. from torchmetrics.utilities.data import dim_zero_cat
  22. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  23. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  24. if not _MATPLOTLIB_AVAILABLE:
  25. __doctest_skip__ = ["TotalVariation.plot"]
  26. class TotalVariation(Metric):
  27. """Compute Total Variation loss (`TV`_).
  28. As input to ``forward`` and ``update`` the metric accepts the following input
  29. - ``img`` (:class:`~torch.Tensor`): A tensor of shape ``(N, C, H, W)`` consisting of images
  30. As output of `forward` and `compute` the metric returns the following output
  31. - ``sdi`` (:class:`~torch.Tensor`): if ``reduction!='none'`` returns float scalar tensor with average TV value
  32. over sample else returns tensor of shape ``(N,)`` with TV values per sample
  33. Args:
  34. reduction: a method to reduce metric score over samples
  35. - ``'mean'``: takes the mean over samples
  36. - ``'sum'``: takes the sum over samples
  37. - ``None`` or ``'none'``: return the score per sample
  38. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  39. Raises:
  40. ValueError:
  41. If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None``
  42. Example:
  43. >>> from torch import rand
  44. >>> from torchmetrics.image import TotalVariation
  45. >>> tv = TotalVariation()
  46. >>> img = torch.rand(5, 3, 28, 28)
  47. >>> tv(img)
  48. tensor(7546.8018)
  49. """
  50. full_state_update: bool = False
  51. is_differentiable: bool = True
  52. higher_is_better: bool = False
  53. plot_lower_bound: float = 0.0
  54. num_elements: Tensor
  55. score_list: List[Tensor]
  56. score: Tensor
  57. def __init__(self, reduction: Optional[Literal["mean", "sum", "none"]] = "sum", **kwargs: Any) -> None:
  58. super().__init__(**kwargs)
  59. if reduction is not None and reduction not in ("sum", "mean", "none"):
  60. raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None")
  61. self.reduction = reduction
  62. self.add_state("score_list", default=[], dist_reduce_fx="cat")
  63. self.add_state("score", default=tensor(0, dtype=torch.float), dist_reduce_fx="sum")
  64. self.add_state("num_elements", default=tensor(0, dtype=torch.int), dist_reduce_fx="sum")
  65. def update(self, img: Tensor) -> None:
  66. """Update current score with batch of input images."""
  67. score, num_elements = _total_variation_update(img)
  68. if self.reduction is None or self.reduction == "none":
  69. self.score_list.append(score)
  70. else:
  71. self.score += score.sum()
  72. self.num_elements += num_elements
  73. def compute(self) -> Tensor:
  74. """Compute final total variation."""
  75. score = dim_zero_cat(self.score_list) if self.reduction is None or self.reduction == "none" else self.score
  76. return _total_variation_compute(score, self.num_elements, self.reduction)
  77. def plot(
  78. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  79. ) -> _PLOT_OUT_TYPE:
  80. """Plot a single or multiple values from the metric.
  81. Args:
  82. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  83. If no value is provided, will automatically call `metric.compute` and plot that result.
  84. ax: An matplotlib axis object. If provided will add plot to that axis
  85. Returns:
  86. Figure and Axes object
  87. Raises:
  88. ModuleNotFoundError:
  89. If `matplotlib` is not installed
  90. .. plot::
  91. :scale: 75
  92. >>> # Example plotting a single value
  93. >>> import torch
  94. >>> from torchmetrics.image import TotalVariation
  95. >>> metric = TotalVariation()
  96. >>> metric.update(torch.rand(5, 3, 28, 28))
  97. >>> fig_, ax_ = metric.plot()
  98. .. plot::
  99. :scale: 75
  100. >>> # Example plotting multiple values
  101. >>> import torch
  102. >>> from torchmetrics.image import TotalVariation
  103. >>> metric = TotalVariation()
  104. >>> values = [ ]
  105. >>> for _ in range(10):
  106. ... values.append(metric(torch.rand(5, 3, 28, 28)))
  107. >>> fig_, ax_ = metric.plot(values)
  108. """
  109. return self._plot(val, ax)