running.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Optional, Union
  16. from torch import Tensor
  17. from torchmetrics.metric import Metric
  18. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  19. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  20. from torchmetrics.wrappers.abstract import WrapperMetric
  21. if not _MATPLOTLIB_AVAILABLE:
  22. __doctest_skip__ = ["Running.plot"]
  23. class Running(WrapperMetric):
  24. """Running wrapper for metrics.
  25. Using this wrapper allows for calculating metrics over a running window of values, instead of the whole history of
  26. values. This is beneficial when you want to get a better estimate of the metric during training and don't want to
  27. wait for the whole training to finish to get epoch level estimates.
  28. The running window is defined by the `window` argument. The window is a fixed size and this wrapper will store a
  29. duplicate of the underlying metric state for each value in the window. Thus memory usage will increase linearly
  30. with window size. Use accordingly. Also note that the running only works with metrics that have the
  31. `full_state_update` set to `False`.
  32. Importantly, the wrapper does not alter the value of the `forward` method of the underlying metric. Thus, forward
  33. will still return the value on the current batch. To get the running value call `compute` instead.
  34. Args:
  35. base_metric: The metric to wrap.
  36. window: The size of the running window.
  37. Example (single metric):
  38. >>> from torch import tensor
  39. >>> from torchmetrics.wrappers import Running
  40. >>> from torchmetrics.aggregation import SumMetric
  41. >>> metric = Running(SumMetric(), window=3)
  42. >>> for i in range(6):
  43. ... current_val = metric(tensor([i]))
  44. ... running_val = metric.compute()
  45. ... total_val = tensor(sum(list(range(i+1)))) # value we would get from `compute` without running
  46. ... print(f"{current_val=}, {running_val=}, {total_val=}")
  47. current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
  48. current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
  49. current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
  50. current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
  51. current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
  52. current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)
  53. Example (metric collection):
  54. >>> from torch import tensor
  55. >>> from torchmetrics.wrappers import Running
  56. >>> from torchmetrics import MetricCollection
  57. >>> from torchmetrics.aggregation import SumMetric, MeanMetric
  58. >>> # note that running is input to collection, not the other way
  59. >>> metric = MetricCollection({"sum": Running(SumMetric(), 3), "mean": Running(MeanMetric(), 3)})
  60. >>> for i in range(6):
  61. ... current_val = metric(tensor([i]))
  62. ... running_val = metric.compute()
  63. ... print(f"{current_val=}, {running_val=}")
  64. current_val={'mean': tensor(0.), 'sum': tensor(0.)}, running_val={'mean': tensor(0.), 'sum': tensor(0.)}
  65. current_val={'mean': tensor(1.), 'sum': tensor(1.)}, running_val={'mean': tensor(0.5000), 'sum': tensor(1.)}
  66. current_val={'mean': tensor(2.), 'sum': tensor(2.)}, running_val={'mean': tensor(1.), 'sum': tensor(3.)}
  67. current_val={'mean': tensor(3.), 'sum': tensor(3.)}, running_val={'mean': tensor(2.), 'sum': tensor(6.)}
  68. current_val={'mean': tensor(4.), 'sum': tensor(4.)}, running_val={'mean': tensor(3.), 'sum': tensor(9.)}
  69. current_val={'mean': tensor(5.), 'sum': tensor(5.)}, running_val={'mean': tensor(4.), 'sum': tensor(12.)}
  70. """
  71. def __init__(self, base_metric: Metric, window: int = 5) -> None:
  72. super().__init__()
  73. if not isinstance(base_metric, Metric):
  74. raise ValueError(
  75. f"Expected argument `metric` to be an instance of `torchmetrics.Metric` but got {base_metric}"
  76. )
  77. if not (isinstance(window, int) and window > 0):
  78. raise ValueError(f"Expected argument `window` to be a positive integer but got {window}")
  79. self.base_metric = base_metric
  80. self.window = window
  81. if base_metric.full_state_update is not False:
  82. raise ValueError(
  83. f"Expected attribute `full_state_update` set to `False` but got {base_metric.full_state_update}"
  84. )
  85. self._num_vals_seen = 0
  86. for key in base_metric._defaults:
  87. for i in range(window):
  88. self.add_state(
  89. name=key + f"_{i}", default=base_metric._defaults[key], dist_reduce_fx=base_metric._reductions[key]
  90. )
  91. def update(self, *args: Any, **kwargs: Any) -> None:
  92. """Update the underlying metric and save state afterwards."""
  93. val = self._num_vals_seen % self.window
  94. self.base_metric.update(*args, **kwargs)
  95. for key in self.base_metric._defaults:
  96. setattr(self, key + f"_{val}", getattr(self.base_metric, key))
  97. self.base_metric.reset()
  98. self._num_vals_seen += 1
  99. def forward(self, *args: Any, **kwargs: Any) -> Any:
  100. """Forward input to the underlying metric and save state afterwards."""
  101. val = self._num_vals_seen % self.window
  102. res = self.base_metric.forward(*args, **kwargs)
  103. for key in self.base_metric._defaults:
  104. setattr(self, key + f"_{val}", getattr(self.base_metric, key))
  105. self.base_metric.reset()
  106. self._num_vals_seen += 1
  107. self._computed = None
  108. return res
  109. def compute(self) -> Any:
  110. """Compute the metric over the running window."""
  111. for i in range(self.window):
  112. self.base_metric._reduce_states({key: getattr(self, key + f"_{i}") for key in self.base_metric._defaults})
  113. self.base_metric._update_count = self._num_vals_seen
  114. val = self.base_metric.compute()
  115. self.base_metric.reset()
  116. return val
  117. def reset(self) -> None:
  118. """Reset metric."""
  119. super().reset()
  120. self._num_vals_seen = 0
  121. def plot(
  122. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  123. ) -> _PLOT_OUT_TYPE:
  124. """Plot a single or multiple values from the metric.
  125. Args:
  126. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  127. If no value is provided, will automatically call `metric.compute` and plot that result.
  128. ax: An matplotlib axis object. If provided will add plot to that axis
  129. Returns:
  130. Figure and Axes object
  131. Raises:
  132. ModuleNotFoundError:
  133. If `matplotlib` is not installed
  134. .. plot::
  135. :scale: 75
  136. >>> # Example plotting a single value
  137. >>> import torch
  138. >>> from torchmetrics.wrappers import Running
  139. >>> from torchmetrics.aggregation import SumMetric
  140. >>> metric = Running(SumMetric(), 2)
  141. >>> metric.update(torch.randn(20, 2))
  142. >>> fig_, ax_ = metric.plot()
  143. .. plot::
  144. :scale: 75
  145. >>> # Example plotting multiple values
  146. >>> import torch
  147. >>> from torchmetrics.wrappers import Running
  148. >>> from torchmetrics.aggregation import SumMetric
  149. >>> metric = Running(SumMetric(), 2)
  150. >>> values = [ ]
  151. >>> for _ in range(3):
  152. ... values.append(metric(torch.randn(20, 2)))
  153. >>> fig_, ax_ = metric.plot(values)
  154. """
  155. return self._plot(val, ax)