aggregation.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, Callable, Optional, Union
  16. import torch
  17. from torch import Tensor
  18. from typing_extensions import Literal
  19. from torchmetrics.metric import Metric
  20. from torchmetrics.utilities import rank_zero_warn
  21. from torchmetrics.utilities.data import dim_zero_cat
  22. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  23. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
  24. from torchmetrics.wrappers.running import Running
  25. if not _MATPLOTLIB_AVAILABLE:
  26. __doctest_skip__ = ["SumMetric.plot", "MeanMetric.plot", "MaxMetric.plot", "MinMetric.plot"]
  27. class BaseAggregator(Metric):
  28. """Base class for aggregation metrics.
  29. Args:
  30. fn: string specifying the reduction function
  31. default_value: default tensor value to use for the metric state
  32. nan_strategy: options:
  33. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  34. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  35. - ``'ignore'``: all `nan` values are silently removed
  36. - ``'disable'``: disable all `nan` checks
  37. - a float: if a float is provided will impute any `nan` values with this value
  38. state_name: name of the metric state
  39. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  40. Raises:
  41. ValueError:
  42. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  43. """
  44. is_differentiable = None
  45. higher_is_better = None
  46. full_state_update: bool = False
  47. def __init__(
  48. self,
  49. fn: Union[Callable, str],
  50. default_value: Union[Tensor, list],
  51. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "error",
  52. state_name: str = "value",
  53. **kwargs: Any,
  54. ) -> None:
  55. super().__init__(**kwargs)
  56. allowed_nan_strategy = ("error", "warn", "ignore", "disable")
  57. if nan_strategy not in allowed_nan_strategy and not isinstance(nan_strategy, float):
  58. raise ValueError(
  59. f"Arg `nan_strategy` should either be a float or one of {allowed_nan_strategy} but got {nan_strategy}."
  60. )
  61. self.nan_strategy = nan_strategy
  62. self.add_state(state_name, default=default_value, dist_reduce_fx=fn)
  63. self.state_name = state_name
  64. def _cast_and_nan_check_input(
  65. self, x: Union[float, Tensor], weight: Optional[Union[float, Tensor]] = None
  66. ) -> tuple[Tensor, Tensor]:
  67. """Convert input ``x`` to a tensor and check for Nans."""
  68. if not isinstance(x, Tensor):
  69. x = torch.as_tensor(x, dtype=self.dtype, device=self.device)
  70. if weight is not None and not isinstance(weight, Tensor):
  71. weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
  72. if self.nan_strategy != "disable":
  73. nans = torch.isnan(x)
  74. if weight is not None:
  75. nans_weight = torch.isnan(weight)
  76. else:
  77. nans_weight = torch.zeros_like(nans).bool()
  78. weight = torch.ones_like(x)
  79. if nans.any() or nans_weight.any():
  80. if self.nan_strategy == "error":
  81. raise RuntimeError("Encountered `nan` values in tensor")
  82. if self.nan_strategy in ("ignore", "warn"):
  83. if self.nan_strategy == "warn":
  84. rank_zero_warn("Encountered `nan` values in tensor. Will be removed.", UserWarning)
  85. x = x[~(nans | nans_weight)]
  86. weight = weight[~(nans | nans_weight)]
  87. else:
  88. if not isinstance(self.nan_strategy, float):
  89. raise ValueError(f"`nan_strategy` shall be float but you pass {self.nan_strategy}")
  90. x[nans | nans_weight] = self.nan_strategy
  91. weight[nans | nans_weight] = 1
  92. else:
  93. weight = torch.ones_like(x)
  94. return x.to(self.dtype), weight.to(self.dtype)
  95. def update(self, value: Union[float, Tensor]) -> None:
  96. """Overwrite in child class."""
  97. def compute(self) -> Tensor:
  98. """Compute the aggregated value."""
  99. return getattr(self, self.state_name)
  100. class MaxMetric(BaseAggregator):
  101. """Aggregate a stream of value into their maximum value.
  102. As input to ``forward`` and ``update`` the metric accepts the following input
  103. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  104. arbitrary shape ``(...,)``.
  105. As output of `forward` and `compute` the metric returns the following output
  106. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated maximum value over all inputs received
  107. Args:
  108. nan_strategy: options:
  109. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  110. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  111. - ``'ignore'``: all `nan` values are silently removed
  112. - ``'disable'``: disable all `nan` checks
  113. - a float: if a float is provided will impute any `nan` values with this value
  114. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  115. Raises:
  116. ValueError:
  117. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  118. Example:
  119. >>> from torch import tensor
  120. >>> from torchmetrics.aggregation import MaxMetric
  121. >>> metric = MaxMetric()
  122. >>> metric.update(1)
  123. >>> metric.update(tensor([2, 3]))
  124. >>> metric.compute()
  125. tensor(3.)
  126. """
  127. full_state_update: bool = True
  128. max_value: Tensor
  129. def __init__(
  130. self,
  131. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  132. **kwargs: Any,
  133. ) -> None:
  134. super().__init__(
  135. "max",
  136. -torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
  137. nan_strategy,
  138. state_name="max_value",
  139. **kwargs,
  140. )
  141. def update(self, value: Union[float, Tensor]) -> None:
  142. """Update state with data.
  143. Args:
  144. value: Either a float or tensor containing data. Additional tensor
  145. dimensions will be flattened
  146. """
  147. value, _ = self._cast_and_nan_check_input(value)
  148. if value.numel(): # make sure tensor not empty
  149. self.max_value = torch.max(self.max_value, torch.max(value))
  150. def plot(
  151. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  152. ) -> _PLOT_OUT_TYPE:
  153. """Plot a single or multiple values from the metric.
  154. Args:
  155. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  156. If no value is provided, will automatically call `metric.compute` and plot that result.
  157. ax: An matplotlib axis object. If provided will add plot to that axis
  158. Returns:
  159. Figure and Axes object
  160. Raises:
  161. ModuleNotFoundError:
  162. If `matplotlib` is not installed
  163. .. plot::
  164. :scale: 75
  165. >>> # Example plotting a single value
  166. >>> from torchmetrics.aggregation import MaxMetric
  167. >>> metric = MaxMetric()
  168. >>> metric.update([1, 2, 3])
  169. >>> fig_, ax_ = metric.plot()
  170. .. plot::
  171. :scale: 75
  172. >>> # Example plotting multiple values
  173. >>> from torchmetrics.aggregation import MaxMetric
  174. >>> metric = MaxMetric()
  175. >>> values = [ ]
  176. >>> for i in range(10):
  177. ... values.append(metric(i))
  178. >>> fig_, ax_ = metric.plot(values)
  179. """
  180. return self._plot(val, ax)
  181. class MinMetric(BaseAggregator):
  182. """Aggregate a stream of value into their minimum value.
  183. As input to ``forward`` and ``update`` the metric accepts the following input
  184. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  185. arbitrary shape ``(...,)``.
  186. As output of `forward` and `compute` the metric returns the following output
  187. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated minimum value over all inputs received
  188. Args:
  189. nan_strategy: options:
  190. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  191. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  192. - ``'ignore'``: all `nan` values are silently removed
  193. - ``'disable'``: disable all `nan` checks
  194. - a float: if a float is provided will impute any `nan` values with this value
  195. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  196. Raises:
  197. ValueError:
  198. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  199. Example:
  200. >>> from torch import tensor
  201. >>> from torchmetrics.aggregation import MinMetric
  202. >>> metric = MinMetric()
  203. >>> metric.update(1)
  204. >>> metric.update(tensor([2, 3]))
  205. >>> metric.compute()
  206. tensor(1.)
  207. """
  208. full_state_update: bool = True
  209. min_value: Tensor
  210. def __init__(
  211. self,
  212. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  213. **kwargs: Any,
  214. ) -> None:
  215. super().__init__(
  216. "min",
  217. torch.tensor(float("inf"), dtype=torch.get_default_dtype()),
  218. nan_strategy,
  219. state_name="min_value",
  220. **kwargs,
  221. )
  222. def update(self, value: Union[float, Tensor]) -> None:
  223. """Update state with data.
  224. Args:
  225. value: Either a float or tensor containing data. Additional tensor
  226. dimensions will be flattened
  227. """
  228. value, _ = self._cast_and_nan_check_input(value)
  229. if value.numel(): # make sure tensor not empty
  230. self.min_value = torch.min(self.min_value, torch.min(value))
  231. def plot(
  232. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  233. ) -> _PLOT_OUT_TYPE:
  234. """Plot a single or multiple values from the metric.
  235. Args:
  236. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  237. If no value is provided, will automatically call `metric.compute` and plot that result.
  238. ax: An matplotlib axis object. If provided will add plot to that axis
  239. Returns:
  240. Figure and Axes object
  241. Raises:
  242. ModuleNotFoundError:
  243. If `matplotlib` is not installed
  244. .. plot::
  245. :scale: 75
  246. >>> # Example plotting a single value
  247. >>> from torchmetrics.aggregation import MinMetric
  248. >>> metric = MinMetric()
  249. >>> metric.update([1, 2, 3])
  250. >>> fig_, ax_ = metric.plot()
  251. .. plot::
  252. :scale: 75
  253. >>> # Example plotting multiple values
  254. >>> from torchmetrics.aggregation import MinMetric
  255. >>> metric = MinMetric()
  256. >>> values = [ ]
  257. >>> for i in range(10):
  258. ... values.append(metric(i))
  259. >>> fig_, ax_ = metric.plot(values)
  260. """
  261. return self._plot(val, ax)
  262. class SumMetric(BaseAggregator):
  263. """Aggregate a stream of value into their sum.
  264. As input to ``forward`` and ``update`` the metric accepts the following input
  265. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  266. arbitrary shape ``(...,)``.
  267. As output of `forward` and `compute` the metric returns the following output
  268. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received
  269. Args:
  270. nan_strategy: options:
  271. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  272. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  273. - ``'ignore'``: all `nan` values are silently removed
  274. - ``'disable'``: disable all `nan` checks
  275. - a float: if a float is provided will impute any `nan` values with this value
  276. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  277. Raises:
  278. ValueError:
  279. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  280. Example:
  281. >>> from torch import tensor
  282. >>> from torchmetrics.aggregation import SumMetric
  283. >>> metric = SumMetric()
  284. >>> metric.update(1)
  285. >>> metric.update(tensor([2, 3]))
  286. >>> metric.compute()
  287. tensor(6.)
  288. """
  289. sum_value: Tensor
  290. def __init__(
  291. self,
  292. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  293. **kwargs: Any,
  294. ) -> None:
  295. super().__init__(
  296. "sum",
  297. torch.tensor(0.0, dtype=torch.get_default_dtype()),
  298. nan_strategy,
  299. state_name="sum_value",
  300. **kwargs,
  301. )
  302. def update(self, value: Union[float, Tensor]) -> None:
  303. """Update state with data.
  304. Args:
  305. value: Either a float or tensor containing data. Additional tensor
  306. dimensions will be flattened
  307. """
  308. value, _ = self._cast_and_nan_check_input(value)
  309. if value.numel():
  310. self.sum_value += value.sum()
  311. def plot(
  312. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  313. ) -> _PLOT_OUT_TYPE:
  314. """Plot a single or multiple values from the metric.
  315. Args:
  316. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  317. If no value is provided, will automatically call `metric.compute` and plot that result.
  318. ax: An matplotlib axis object. If provided will add plot to that axis
  319. Returns:
  320. Figure and Axes object
  321. Raises:
  322. ModuleNotFoundError:
  323. If `matplotlib` is not installed
  324. .. plot::
  325. :scale: 75
  326. >>> # Example plotting a single value
  327. >>> from torchmetrics.aggregation import SumMetric
  328. >>> metric = SumMetric()
  329. >>> metric.update([1, 2, 3])
  330. >>> fig_, ax_ = metric.plot()
  331. .. plot::
  332. :scale: 75
  333. >>> # Example plotting multiple values
  334. >>> from torch import rand, randint
  335. >>> from torchmetrics.aggregation import SumMetric
  336. >>> metric = SumMetric()
  337. >>> values = [ ]
  338. >>> for i in range(10):
  339. ... values.append(metric([i, i+1]))
  340. >>> fig_, ax_ = metric.plot(values)
  341. """
  342. return self._plot(val, ax)
  343. class CatMetric(BaseAggregator):
  344. """Concatenate a stream of values.
  345. As input to ``forward`` and ``update`` the metric accepts the following input
  346. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  347. arbitrary shape ``(...,)``.
  348. As output of `forward` and `compute` the metric returns the following output
  349. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with concatenated values over all input received
  350. Args:
  351. nan_strategy: options:
  352. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  353. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  354. - ``'ignore'``: all `nan` values are silently removed
  355. - ``'disable'``: disable all `nan` checks
  356. - a float: if a float is provided will impute any `nan` values with this value
  357. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  358. Raises:
  359. ValueError:
  360. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  361. Example:
  362. >>> from torch import tensor
  363. >>> from torchmetrics.aggregation import CatMetric
  364. >>> metric = CatMetric()
  365. >>> metric.update(1)
  366. >>> metric.update(tensor([2, 3]))
  367. >>> metric.compute()
  368. tensor([1., 2., 3.])
  369. """
  370. value: Tensor
  371. def __init__(
  372. self,
  373. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  374. **kwargs: Any,
  375. ) -> None:
  376. super().__init__("cat", [], nan_strategy, **kwargs)
  377. def update(self, value: Union[float, Tensor]) -> None:
  378. """Update state with data.
  379. Args:
  380. value: Either a float or tensor containing data. Additional tensor
  381. dimensions will be flattened
  382. """
  383. value, _ = self._cast_and_nan_check_input(value)
  384. if value.numel():
  385. self.value.append(value)
  386. def compute(self) -> Tensor:
  387. """Compute the aggregated value."""
  388. if isinstance(self.value, list) and self.value:
  389. return dim_zero_cat(self.value)
  390. return self.value
  391. class MeanMetric(BaseAggregator):
  392. """Aggregate a stream of value into their mean value.
  393. As input to ``forward`` and ``update`` the metric accepts the following input
  394. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  395. arbitrary shape ``(...,)``.
  396. - ``weight`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float value with
  397. arbitrary shape ``(...,)``. Needs to be broadcastable with the shape of ``value`` tensor.
  398. As output of `forward` and `compute` the metric returns the following output
  399. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated (weighted) mean over all inputs received
  400. Args:
  401. nan_strategy: options:
  402. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  403. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  404. - ``'ignore'``: all `nan` values are silently removed
  405. - ``'disable'``: disable all `nan` checks
  406. - a float: if a float is provided will impute any `nan` values with this value
  407. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  408. Raises:
  409. ValueError:
  410. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  411. Example:
  412. >>> from torchmetrics.aggregation import MeanMetric
  413. >>> metric = MeanMetric()
  414. >>> metric.update(1)
  415. >>> metric.update(torch.tensor([2, 3]))
  416. >>> metric.compute()
  417. tensor(2.)
  418. """
  419. mean_value: Tensor
  420. weight: Tensor
  421. def __init__(
  422. self,
  423. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  424. **kwargs: Any,
  425. ) -> None:
  426. super().__init__(
  427. "sum",
  428. torch.tensor(0.0, dtype=torch.get_default_dtype()),
  429. nan_strategy,
  430. state_name="mean_value",
  431. **kwargs,
  432. )
  433. self.add_state("weight", default=torch.tensor(0.0, dtype=torch.get_default_dtype()), dist_reduce_fx="sum")
  434. def update(self, value: Union[float, Tensor], weight: Union[float, Tensor, None] = None) -> None:
  435. """Update state with data.
  436. Args:
  437. value: Either a float or tensor containing data. Additional tensor
  438. dimensions will be flattened
  439. weight: Either a float or tensor containing weights for calculating
  440. the average. Shape of weight should be able to broadcast with
  441. the shape of `value`. Default to None corresponding to simple
  442. harmonic average.
  443. """
  444. # broadcast weight to value shape
  445. if not isinstance(value, Tensor):
  446. value = torch.as_tensor(value, dtype=self.dtype, device=self.device)
  447. if weight is None:
  448. weight = torch.ones_like(value)
  449. elif not isinstance(weight, Tensor):
  450. weight = torch.as_tensor(weight, dtype=self.dtype, device=self.device)
  451. weight = torch.broadcast_to(weight, value.shape)
  452. value, weight = self._cast_and_nan_check_input(value, weight)
  453. if value.numel() == 0:
  454. return
  455. self.mean_value += (value * weight).sum()
  456. self.weight += weight.sum()
  457. def compute(self) -> Tensor:
  458. """Compute the aggregated value."""
  459. return self.mean_value / self.weight
  460. def plot(
  461. self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
  462. ) -> _PLOT_OUT_TYPE:
  463. """Plot a single or multiple values from the metric.
  464. Args:
  465. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  466. If no value is provided, will automatically call `metric.compute` and plot that result.
  467. ax: An matplotlib axis object. If provided will add plot to that axis
  468. Returns:
  469. Figure and Axes object
  470. Raises:
  471. ModuleNotFoundError:
  472. If `matplotlib` is not installed
  473. .. plot::
  474. :scale: 75
  475. >>> # Example plotting a single value
  476. >>> from torchmetrics.aggregation import MeanMetric
  477. >>> metric = MeanMetric()
  478. >>> metric.update([1, 2, 3])
  479. >>> fig_, ax_ = metric.plot()
  480. .. plot::
  481. :scale: 75
  482. >>> # Example plotting multiple values
  483. >>> from torchmetrics.aggregation import MeanMetric
  484. >>> metric = MeanMetric()
  485. >>> values = [ ]
  486. >>> for i in range(10):
  487. ... values.append(metric([i, i+1]))
  488. >>> fig_, ax_ = metric.plot(values)
  489. """
  490. return self._plot(val, ax)
  491. class RunningMean(Running):
  492. """Aggregate a stream of value into their mean over a running window.
  493. Using this metric compared to `MeanMetric` allows for calculating metrics over a running window of values, instead
  494. of the whole history of values. This is beneficial when you want to get a better estimate of the metric during
  495. training and don't want to wait for the whole training to finish to get epoch level estimates.
  496. As input to ``forward`` and ``update`` the metric accepts the following input
  497. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  498. arbitrary shape ``(...,)``.
  499. As output of `forward` and `compute` the metric returns the following output
  500. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received
  501. Args:
  502. nan_strategy: options:
  503. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  504. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  505. - ``'ignore'``: all `nan` values are silently removed
  506. - ``'disable'``: disable all `nan` checks
  507. - a float: if a float is provided will impute any `nan` values with this value
  508. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  509. Raises:
  510. ValueError:
  511. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  512. Example:
  513. >>> from torch import tensor
  514. >>> from torchmetrics.aggregation import RunningMean
  515. >>> metric = RunningMean(window=3)
  516. >>> for i in range(6):
  517. ... current_val = metric(tensor([i]))
  518. ... running_val = metric.compute()
  519. ... total_val = tensor(sum(list(range(i+1)))) / (i+1) # total mean over all samples
  520. ... print(f"{current_val=}, {running_val=}, {total_val=}")
  521. current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0.)
  522. current_val=tensor(1.), running_val=tensor(0.5000), total_val=tensor(0.5000)
  523. current_val=tensor(2.), running_val=tensor(1.), total_val=tensor(1.)
  524. current_val=tensor(3.), running_val=tensor(2.), total_val=tensor(1.5000)
  525. current_val=tensor(4.), running_val=tensor(3.), total_val=tensor(2.)
  526. current_val=tensor(5.), running_val=tensor(4.), total_val=tensor(2.5000)
  527. """
  528. def __init__(
  529. self,
  530. window: int = 5,
  531. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  532. **kwargs: Any,
  533. ) -> None:
  534. super().__init__(base_metric=MeanMetric(nan_strategy=nan_strategy, **kwargs), window=window)
  535. class RunningSum(Running):
  536. """Aggregate a stream of value into their sum over a running window.
  537. Using this metric compared to `SumMetric` allows for calculating metrics over a running window of values, instead
  538. of the whole history of values. This is beneficial when you want to get a better estimate of the metric during
  539. training and don't want to wait for the whole training to finish to get epoch level estimates.
  540. As input to ``forward`` and ``update`` the metric accepts the following input
  541. - ``value`` (:class:`~float` or :class:`~torch.Tensor`): a single float or an tensor of float values with
  542. arbitrary shape ``(...,)``.
  543. As output of `forward` and `compute` the metric returns the following output
  544. - ``agg`` (:class:`~torch.Tensor`): scalar float tensor with aggregated sum over all inputs received
  545. Args:
  546. window: The size of the running window.
  547. nan_strategy: options:
  548. - ``'error'``: if any `nan` values are encountered will give a RuntimeError
  549. - ``'warn'``: if any `nan` values are encountered will give a warning and continue
  550. - ``'ignore'``: all `nan` values are silently removed
  551. - ``'disable'``: disable all `nan` checks
  552. - a float: if a float is provided will impute any `nan` values with this value
  553. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
  554. Raises:
  555. ValueError:
  556. If ``nan_strategy`` is not one of ``error``, ``warn``, ``ignore``, ``disable`` or a float
  557. Example:
  558. >>> from torch import tensor
  559. >>> from torchmetrics.aggregation import RunningSum
  560. >>> metric = RunningSum(window=3)
  561. >>> for i in range(6):
  562. ... current_val = metric(tensor([i]))
  563. ... running_val = metric.compute()
  564. ... total_val = tensor(sum(list(range(i+1)))) # total sum over all samples
  565. ... print(f"{current_val=}, {running_val=}, {total_val=}")
  566. current_val=tensor(0.), running_val=tensor(0.), total_val=tensor(0)
  567. current_val=tensor(1.), running_val=tensor(1.), total_val=tensor(1)
  568. current_val=tensor(2.), running_val=tensor(3.), total_val=tensor(3)
  569. current_val=tensor(3.), running_val=tensor(6.), total_val=tensor(6)
  570. current_val=tensor(4.), running_val=tensor(9.), total_val=tensor(10)
  571. current_val=tensor(5.), running_val=tensor(12.), total_val=tensor(15)
  572. """
  573. def __init__(
  574. self,
  575. window: int = 5,
  576. nan_strategy: Union[Literal["error", "warn", "ignore", "disable"], float] = "warn",
  577. **kwargs: Any,
  578. ) -> None:
  579. super().__init__(base_metric=SumMetric(nan_strategy=nan_strategy, **kwargs), window=window)