minmax.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from torchmetrics.metric import Metric
  19. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  20. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  21. from torchmetrics.wrappers.abstract import WrapperMetric
# When matplotlib is not installed, mark the ``MinMaxMetric.plot`` doctest as skipped (its examples call ``plot``).
# NOTE(review): ``__doctest_skip__`` is presumably consumed by the project's doctest plugin — confirm.
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["MinMaxMetric.plot"]
  24. class MinMaxMetric(WrapperMetric):
  25. """Wrapper metric that tracks both the minimum and maximum of a scalar/tensor across an experiment.
  26. The min/max value will be updated each time ``.compute`` is called.
  27. Args:
  28. base_metric:
  29. The metric of which you want to keep track of its maximum and minimum values.
  30. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  31. Raises:
  32. ValueError
  33. If ``base_metric` argument is not a subclasses instance of ``torchmetrics.Metric``
  34. Example::
  35. >>> import torch
  36. >>> from torchmetrics.wrappers import MinMaxMetric
  37. >>> from torchmetrics.classification import BinaryAccuracy
  38. >>> from pprint import pprint
  39. >>> base_metric = BinaryAccuracy()
  40. >>> minmax_metric = MinMaxMetric(base_metric)
  41. >>> preds_1 = torch.Tensor([[0.1, 0.9], [0.2, 0.8]])
  42. >>> preds_2 = torch.Tensor([[0.9, 0.1], [0.2, 0.8]])
  43. >>> labels = torch.Tensor([[0, 1], [0, 1]]).long()
  44. >>> pprint(minmax_metric(preds_1, labels))
  45. {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
  46. >>> pprint(minmax_metric.compute())
  47. {'max': tensor(1.), 'min': tensor(1.), 'raw': tensor(1.)}
  48. >>> minmax_metric.update(preds_2, labels)
  49. >>> pprint(minmax_metric.compute())
  50. {'max': tensor(1.), 'min': tensor(0.7500), 'raw': tensor(0.7500)}
  51. """
  52. full_state_update: Optional[bool] = True
  53. min_val: Tensor
  54. max_val: Tensor
  55. def __init__(
  56. self,
  57. base_metric: Metric,
  58. **kwargs: Any,
  59. ) -> None:
  60. super().__init__(**kwargs)
  61. if not isinstance(base_metric, Metric):
  62. raise ValueError(
  63. f"Expected base metric to be an instance of `torchmetrics.Metric` but received {base_metric}"
  64. )
  65. self._base_metric = base_metric
  66. self.min_val = torch.tensor(float("inf"))
  67. self.max_val = torch.tensor(float("-inf"))
  68. def update(self, *args: Any, **kwargs: Any) -> None:
  69. """Update the underlying metric."""
  70. self._base_metric.update(*args, **kwargs)
  71. def compute(self) -> dict[str, Tensor]:
  72. """Compute the underlying metric as well as max and min values for this metric.
  73. Returns a dictionary that consists of the computed value (``raw``), as well as the minimum (``min``) and maximum
  74. (``max``) values.
  75. """
  76. val = self._base_metric.compute()
  77. if not self._is_suitable_val(val):
  78. raise RuntimeError(f"Returned value from base metric should be a float or scalar tensor, but got {val}.")
  79. self.max_val = val if self.max_val.to(val.device) < val else self.max_val.to(val.device)
  80. self.min_val = val if self.min_val.to(val.device) > val else self.min_val.to(val.device)
  81. return {"raw": val, "max": self.max_val, "min": self.min_val}
  82. def forward(self, *args: Any, **kwargs: Any) -> Any:
  83. """Use the original forward method of the base metric class."""
  84. return super(WrapperMetric, self).forward(*args, **kwargs)
  85. def reset(self) -> None:
  86. """Set ``max_val`` and ``min_val`` to the initialization bounds and resets the base metric."""
  87. super().reset()
  88. self._base_metric.reset()
  89. @staticmethod
  90. def _is_suitable_val(val: Union[float, Tensor]) -> bool:
  91. """Check whether min/max is a scalar value."""
  92. if isinstance(val, (int, float)):
  93. return True
  94. if isinstance(val, Tensor):
  95. return val.numel() == 1
  96. return False
  97. def plot(
  98. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  99. ) -> _PLOT_OUT_TYPE:
  100. """Plot a single or multiple values from the metric.
  101. Args:
  102. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  103. If no value is provided, will automatically call `metric.compute` and plot that result.
  104. ax: An matplotlib axis object. If provided will add plot to that axis
  105. Returns:
  106. Figure and Axes object
  107. Raises:
  108. ModuleNotFoundError:
  109. If `matplotlib` is not installed
  110. .. plot::
  111. :scale: 75
  112. >>> # Example plotting a single value
  113. >>> import torch
  114. >>> from torchmetrics.wrappers import MinMaxMetric
  115. >>> from torchmetrics.classification import BinaryAccuracy
  116. >>> metric = MinMaxMetric(BinaryAccuracy())
  117. >>> metric.update(torch.randint(2, (20,)), torch.randint(2, (20,)))
  118. >>> fig_, ax_ = metric.plot()
  119. .. plot::
  120. :scale: 75
  121. >>> # Example plotting multiple values
  122. >>> import torch
  123. >>> from torchmetrics.wrappers import MinMaxMetric
  124. >>> from torchmetrics.classification import BinaryAccuracy
  125. >>> metric = MinMaxMetric(BinaryAccuracy())
  126. >>> values = [ ]
  127. >>> for _ in range(3):
  128. ... values.append(metric(torch.randint(2, (20,)), torch.randint(2, (20,))))
  129. >>> fig_, ax_ = metric.plot(values)
  130. """
  131. return self._plot(val, ax)