collections.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730
  1. # Copyright The Lightning team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # this is just a bypass for this module name collision with built-in one
  15. from collections import OrderedDict
  16. from collections.abc import Hashable, Iterable, Iterator, Mapping, Sequence
  17. from copy import deepcopy
  18. from typing import Any, ClassVar, Dict, List, Optional, Union
  19. import torch
  20. from torch import Tensor
  21. from torch.nn import ModuleDict
  22. from typing_extensions import Literal
  23. from torchmetrics.metric import Metric
  24. from torchmetrics.utilities import rank_zero_warn
  25. from torchmetrics.utilities.data import _flatten, _flatten_dict, allclose
  26. from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
  27. from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE, plot_single_or_multi_val
  28. if not _MATPLOTLIB_AVAILABLE:
  29. __doctest_skip__ = ["MetricCollection.plot", "MetricCollection.plot_all"]
  30. def _remove_prefix(string: str, prefix: str) -> str:
  31. """Patch for older version with missing method `removeprefix`.
  32. >>> _remove_prefix("prefix_string", "prefix_")
  33. 'string'
  34. >>> _remove_prefix("not_prefix_string", "prefix_")
  35. 'not_prefix_string'
  36. """
  37. return string[len(prefix) :] if string.startswith(prefix) else string
  38. def _remove_suffix(string: str, suffix: str) -> str:
  39. """Patch for older version with missing method `removesuffix`.
  40. >>> _remove_suffix("string_suffix", "_suffix")
  41. 'string'
  42. >>> _remove_suffix("string_suffix_missing", "_suffix")
  43. 'string_suffix_missing'
  44. """
  45. return string[: -len(suffix)] if string.endswith(suffix) else string
  46. class MetricCollection(ModuleDict):
  47. """MetricCollection class can be used to chain metrics that have the same call pattern into one single class.
  48. Args:
  49. metrics: One of the following
  50. * list or tuple (sequence): if metrics are passed in as a list or tuple, will use the metrics class name
  51. as key for output dict. Therefore, two metrics of the same class cannot be chained this way.
  52. * arguments: similar to passing in as a list, metrics passed in as arguments will use their metric
  53. class name as key for the output dict.
  54. * dict: if metrics are passed in as a dict, will use each key in the dict as key for output dict.
  55. Use this format if you want to chain together multiple of the same metric with different parameters.
  56. Note that the keys in the output dict will be sorted alphabetically.
  57. prefix: a string to append in front of the keys of the output dict
  58. postfix: a string to append after the keys of the output dict
  59. compute_groups:
  60. By default the MetricCollection will try to reduce the computations needed for the metrics in the collection
  61. by checking if they belong to the same **compute group**. All metrics in a compute group share the same
  62. metric state and are therefore only different in their compute step e.g. accuracy, precision and recall
  63. can all be computed from the true positives/negatives and false positives/negatives. By default,
  64. this argument is ``True`` which enables this feature. Set this argument to `False` for disabling
  65. this behaviour. Can also be set to a list of lists of metrics for setting the compute groups yourself.
  66. .. tip::
  67. The compute groups feature can significantly speedup the calculation of metrics under the right conditions.
  68. First, the feature is only available when calling the ``update`` method and not when calling ``forward`` method
  69. due to the internal logic of ``forward`` preventing this. Secondly, since we compute groups share metric
  70. states by reference, calling ``.items()``, ``.values()`` etc. on the metric collection will break this
  71. reference and a copy of states are instead returned in this case (reference will be reestablished on the next
  72. call to ``update``). Do note that for the time being that if you are manually specifying compute groups in
  73. nested collections, these are not compatible with the compute groups of the parent collection and will be
  74. overridden.
  75. .. important::
  76. Metric collections can be nested at initialization (see last example) but the output of the collection will
  77. still be a single flatten dictionary combining the prefix and postfix arguments from the nested collection.
  78. Raises:
  79. ValueError:
  80. If one of the elements of ``metrics`` is not an instance of ``pl.metrics.Metric``.
  81. ValueError:
  82. If two elements in ``metrics`` have the same ``name``.
  83. ValueError:
  84. If ``metrics`` is not a ``list``, ``tuple`` or a ``dict``.
  85. ValueError:
  86. If ``metrics`` is ``dict`` and additional_metrics are passed in.
  87. ValueError:
  88. If ``prefix`` is set and it is not a string.
  89. ValueError:
  90. If ``postfix`` is set and it is not a string.
  91. Example::
  92. In the most basic case, the metrics can be passed in as a list or tuple. The keys of the output dict will be
  93. the same as the class name of the metric:
  94. >>> from torch import tensor
  95. >>> from pprint import pprint
  96. >>> from torchmetrics import MetricCollection
  97. >>> from torchmetrics.regression import MeanSquaredError
  98. >>> from torchmetrics.classification import MulticlassAccuracy, MulticlassPrecision, MulticlassRecall
  99. >>> target = tensor([0, 2, 0, 2, 0, 1, 0, 2])
  100. >>> preds = tensor([2, 1, 2, 0, 1, 2, 2, 2])
  101. >>> metrics = MetricCollection([MulticlassAccuracy(num_classes=3, average='micro'),
  102. ... MulticlassPrecision(num_classes=3, average='macro'),
  103. ... MulticlassRecall(num_classes=3, average='macro')])
  104. >>> metrics(preds, target) # doctest: +NORMALIZE_WHITESPACE
  105. {'MulticlassAccuracy': tensor(0.1250),
  106. 'MulticlassPrecision': tensor(0.0667),
  107. 'MulticlassRecall': tensor(0.1111)}
  108. Example::
  109. Alternatively, metrics can be passed in as arguments. The keys of the output dict will be the same as the
  110. class name of the metric:
  111. >>> metrics = MetricCollection(MulticlassAccuracy(num_classes=3, average='micro'),
  112. ... MulticlassPrecision(num_classes=3, average='macro'),
  113. ... MulticlassRecall(num_classes=3, average='macro'))
  114. >>> metrics(preds, target) # doctest: +NORMALIZE_WHITESPACE
  115. {'MulticlassAccuracy': tensor(0.1250),
  116. 'MulticlassPrecision': tensor(0.0667),
  117. 'MulticlassRecall': tensor(0.1111)}
  118. Example::
  119. If multiple of the same metric class (with different parameters) should be chained together, metrics can be
  120. passed in as a dict and the output dict will have the same keys as the input dict:
  121. >>> metrics = MetricCollection({'micro_recall': MulticlassRecall(num_classes=3, average='micro'),
  122. ... 'macro_recall': MulticlassRecall(num_classes=3, average='macro')})
  123. >>> same_metric = metrics.clone()
  124. >>> pprint(metrics(preds, target))
  125. {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
  126. >>> pprint(same_metric(preds, target))
  127. {'macro_recall': tensor(0.1111), 'micro_recall': tensor(0.1250)}
  128. Example::
  129. Metric collections can also be nested up to a single time. The output of the collection will still be a single
  130. dict with the prefix and postfix arguments from the nested collection:
  131. >>> metrics = MetricCollection([
  132. ... MetricCollection([
  133. ... MulticlassAccuracy(num_classes=3, average='macro'),
  134. ... MulticlassPrecision(num_classes=3, average='macro')
  135. ... ], postfix='_macro'),
  136. ... MetricCollection([
  137. ... MulticlassAccuracy(num_classes=3, average='micro'),
  138. ... MulticlassPrecision(num_classes=3, average='micro')
  139. ... ], postfix='_micro'),
  140. ... ], prefix='valmetrics/')
  141. >>> pprint(metrics(preds, target)) # doctest: +NORMALIZE_WHITESPACE
  142. {'valmetrics/MulticlassAccuracy_macro': tensor(0.1111),
  143. 'valmetrics/MulticlassAccuracy_micro': tensor(0.1250),
  144. 'valmetrics/MulticlassPrecision_macro': tensor(0.0667),
  145. 'valmetrics/MulticlassPrecision_micro': tensor(0.1250)}
  146. Example::
  147. The `compute_groups` argument allow you to specify which metrics should share metric state. By default, this
  148. will automatically be derived but can also be set manually.
  149. >>> metrics = MetricCollection(
  150. ... MulticlassRecall(num_classes=3, average='macro'),
  151. ... MulticlassPrecision(num_classes=3, average='macro'),
  152. ... MeanSquaredError(),
  153. ... compute_groups=[['MulticlassRecall', 'MulticlassPrecision'], ['MeanSquaredError']]
  154. ... )
  155. >>> metrics.update(preds, target)
  156. >>> pprint(metrics.compute())
  157. {'MeanSquaredError': tensor(2.3750), 'MulticlassPrecision': tensor(0.0667), 'MulticlassRecall': tensor(0.1111)}
  158. >>> pprint(metrics.compute_groups)
  159. {0: ['MulticlassRecall', 'MulticlassPrecision'], 1: ['MeanSquaredError']}
  160. """
  161. _modules: dict[str, Metric] # type: ignore[assignment]
  162. __jit_unused_properties__: ClassVar[list[str]] = ["metric_state"]
  163. def __init__(
  164. self,
  165. metrics: Union[
  166. Metric,
  167. "MetricCollection",
  168. Sequence[Union[Metric, "MetricCollection"]],
  169. dict[str, Union[Metric, "MetricCollection"]],
  170. ],
  171. *additional_metrics: Metric,
  172. prefix: Optional[str] = None,
  173. postfix: Optional[str] = None,
  174. compute_groups: Union[bool, list[list[str]]] = True,
  175. ) -> None:
  176. super().__init__()
  177. self.prefix = self._check_arg(prefix, "prefix")
  178. self.postfix = self._check_arg(postfix, "postfix")
  179. self._enable_compute_groups = compute_groups
  180. self._groups_checked: bool = False
  181. self._state_is_copy: bool = False
  182. self._groups: Dict[int, list[str]] = {}
  183. self.add_metrics(metrics, *additional_metrics)
  184. @property
  185. def metric_state(self) -> dict[str, dict[str, Any]]:
  186. """Get the current state of the metric."""
  187. return {k: m.metric_state for k, m in self.items(keep_base=False, copy_state=False)}
  188. @torch.jit.unused
  189. def forward(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
  190. """Call forward for each metric sequentially.
  191. Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
  192. will be filtered based on the signature of the individual metric.
  193. """
  194. return self._compute_and_reduce("forward", *args, **kwargs)
  195. def update(self, *args: Any, **kwargs: Any) -> None:
  196. """Call update for each metric sequentially.
  197. Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
  198. will be filtered based on the signature of the individual metric.
  199. """
  200. # Use compute groups if already initialized and checked
  201. if self._groups_checked:
  202. # Delete the cache of all metrics to invalidate the cache and therefore recent compute calls, forcing new
  203. # compute calls to recompute
  204. for k in self.keys(keep_base=True):
  205. mi = getattr(self, str(k))
  206. mi._computed = None
  207. for cg in self._groups.values():
  208. # only update the first member
  209. m0 = getattr(self, cg[0])
  210. m0.update(*args, **m0._filter_kwargs(**kwargs))
  211. self._state_is_copy = False
  212. self._compute_groups_create_state_ref()
  213. else: # the first update always do per metric to form compute groups
  214. for m in self.values(copy_state=False):
  215. m_kwargs = m._filter_kwargs(**kwargs)
  216. m.update(*args, **m_kwargs)
  217. if self._enable_compute_groups:
  218. self._merge_compute_groups()
  219. # create reference between states
  220. self._state_is_copy = False
  221. self._compute_groups_create_state_ref()
  222. self._groups_checked = True
  223. def _merge_compute_groups(self) -> None:
  224. """Iterate over the collection of metrics, checking if the state of each metric matches another.
  225. If so, their compute groups will be merged into one. The complexity of the method is approximately
  226. ``O(number_of_metrics_in_collection ** 2)``, as all metrics need to be compared to all other metrics.
  227. """
  228. num_groups = len(self._groups)
  229. while True:
  230. for cg_idx1, cg_members1 in deepcopy(self._groups).items():
  231. for cg_idx2, cg_members2 in deepcopy(self._groups).items():
  232. if cg_idx1 == cg_idx2:
  233. continue
  234. metric1 = getattr(self, cg_members1[0])
  235. metric2 = getattr(self, cg_members2[0])
  236. if self._equal_metric_states(metric1, metric2):
  237. self._groups[cg_idx1].extend(self._groups.pop(cg_idx2))
  238. break
  239. # Start over if we merged groups
  240. if len(self._groups) != num_groups:
  241. break
  242. # Stop when we iterate over everything and do not merge any groups
  243. if len(self._groups) == num_groups:
  244. break
  245. num_groups = len(self._groups)
  246. # Re-index groups
  247. temp = deepcopy(self._groups)
  248. self._groups = {}
  249. for idx, values in enumerate(temp.values()):
  250. self._groups[idx] = values
  251. @staticmethod
  252. def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
  253. """Check if the metric state of two metrics are the same."""
  254. # empty state
  255. if len(metric1._defaults) == 0 or len(metric2._defaults) == 0:
  256. return False
  257. if metric1._defaults.keys() != metric2._defaults.keys():
  258. return False
  259. for key in metric1._defaults:
  260. state1 = getattr(metric1, key)
  261. state2 = getattr(metric2, key)
  262. if type(state1) != type(state2): # noqa: E721
  263. return False
  264. if (
  265. isinstance(state1, Tensor)
  266. and isinstance(state2, Tensor)
  267. and not (state1.shape == state2.shape and allclose(state1, state2))
  268. ):
  269. return False
  270. if (
  271. isinstance(state1, list)
  272. and isinstance(state2, list)
  273. and not (all(s1.shape == s2.shape and allclose(s1, s2) for s1, s2 in zip(state1, state2)))
  274. ):
  275. return False
  276. return True
  277. def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
  278. """Create reference between metrics in the same compute group.
  279. Args:
  280. copy: If `True` the metric state will between members will be copied instead
  281. of just passed by reference
  282. """
  283. if not self._state_is_copy: # only create reference if not already copied
  284. for cg in self._groups.values():
  285. m0 = getattr(self, cg[0])
  286. for i in range(1, len(cg)):
  287. mi = getattr(self, cg[i])
  288. for state in m0._defaults:
  289. m0_state = getattr(m0, state)
  290. # Determine if we just should set a reference or a full copy
  291. setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
  292. mi._update_count = deepcopy(m0._update_count) if copy else m0._update_count
  293. self._state_is_copy = copy
  294. def compute(self) -> dict[str, Any]:
  295. """Compute the result for each metric in the collection."""
  296. return self._compute_and_reduce("compute")
  297. def _compute_and_reduce(
  298. self, method_name: Literal["compute", "forward"], *args: Any, **kwargs: Any
  299. ) -> dict[str, Any]:
  300. """Compute result from collection and reduce into a single dictionary.
  301. Args:
  302. method_name: The method to call on each metric in the collection.
  303. Should be either `compute` or `forward`.
  304. args: Positional arguments to pass to each metric (if method_name is `forward`)
  305. kwargs: Keyword arguments to pass to each metric (if method_name is `forward`)
  306. Raises:
  307. ValueError:
  308. If method_name is not `compute` or `forward`.
  309. """
  310. result = {}
  311. for k, m in self.items(keep_base=True, copy_state=False):
  312. if method_name == "compute":
  313. res = m.compute()
  314. elif method_name == "forward":
  315. res = m(*args, **m._filter_kwargs(**kwargs))
  316. else:
  317. raise ValueError(f"method_name should be either 'compute' or 'forward', but got {method_name}")
  318. result[k] = res
  319. _, duplicates = _flatten_dict(result)
  320. flattened_results = {}
  321. for k, m in self.items(keep_base=True, copy_state=False):
  322. res = result[k]
  323. if isinstance(res, dict):
  324. for key, v in res.items():
  325. # if duplicates of keys we need to add unique prefix to each key
  326. if duplicates:
  327. stripped_k = k.replace(getattr(m, "prefix", ""), "")
  328. stripped_k = stripped_k.replace(getattr(m, "postfix", ""), "")
  329. key = f"{stripped_k}_{key}"
  330. if getattr(m, "_from_collection", None) and m.prefix is not None:
  331. key = f"{m.prefix}{key}"
  332. if getattr(m, "_from_collection", None) and m.postfix is not None:
  333. key = f"{key}{m.postfix}"
  334. flattened_results[key] = v
  335. else:
  336. flattened_results[k] = res
  337. return {self._set_name(k): v for k, v in flattened_results.items()}
  338. def reset(self) -> None:
  339. """Call reset for each metric sequentially."""
  340. for m in self.values(copy_state=False):
  341. m.reset()
  342. if self._enable_compute_groups and self._groups_checked:
  343. # reset state reference
  344. self._compute_groups_create_state_ref()
  345. def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection":
  346. """Make a copy of the metric collection.
  347. Args:
  348. prefix: a string to append in front of the metric keys
  349. postfix: a string to append after the keys of the output dict.
  350. """
  351. mc = deepcopy(self)
  352. if prefix:
  353. mc.prefix = self._check_arg(prefix, "prefix")
  354. if postfix:
  355. mc.postfix = self._check_arg(postfix, "postfix")
  356. return mc
  357. def persistent(self, mode: bool = True) -> None:
  358. """Change if metric states should be saved to its state_dict after initialization."""
  359. for m in self.values(copy_state=False):
  360. m.persistent(mode)
  361. def add_metrics(
  362. self,
  363. metrics: Union[
  364. Metric,
  365. "MetricCollection",
  366. Sequence[Union[Metric, "MetricCollection"]],
  367. dict[str, Union[Metric, "MetricCollection"]],
  368. ],
  369. *additional_metrics: Metric,
  370. ) -> None:
  371. """Add new metrics to Metric Collection."""
  372. if isinstance(metrics, Metric):
  373. # set compatible with original type expectations
  374. metrics = [metrics]
  375. if isinstance(metrics, Sequence):
  376. # prepare for optional additions
  377. metrics = list(metrics)
  378. remain: list = []
  379. for m in additional_metrics:
  380. sel = metrics if isinstance(m, Metric) else remain
  381. sel.append(m)
  382. if remain:
  383. rank_zero_warn(
  384. f"You have passes extra arguments {remain} which are not `Metric` so they will be ignored."
  385. )
  386. elif additional_metrics:
  387. raise ValueError(
  388. f"You have passes extra arguments {additional_metrics} which are not compatible"
  389. f" with first passed dictionary {metrics} so they will be ignored."
  390. )
  391. if isinstance(metrics, dict):
  392. # Check all values are metrics
  393. # Make sure that metrics are added in deterministic order
  394. for name in sorted(metrics.keys()):
  395. metric = metrics[name]
  396. if not isinstance(metric, (Metric, MetricCollection)):
  397. raise ValueError(
  398. f"Value {metric} belonging to key {name} is not an instance of"
  399. " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
  400. )
  401. if isinstance(metric, Metric):
  402. self[name] = metric
  403. else:
  404. for k, v in metric.items(keep_base=False):
  405. v.postfix = metric.postfix
  406. v.prefix = metric.prefix
  407. v._from_collection = True
  408. self[f"{name}_{k}"] = v
  409. elif isinstance(metrics, Sequence):
  410. for metric in metrics:
  411. if not isinstance(metric, (Metric, MetricCollection)):
  412. raise ValueError(
  413. f"Input {metric} to `MetricCollection` is not a instance of"
  414. " `torchmetrics.Metric` or `torchmetrics.MetricCollection`"
  415. )
  416. if isinstance(metric, Metric):
  417. name = metric.__class__.__name__
  418. if name in self:
  419. raise ValueError(f"Encountered two metrics both named {name}")
  420. self[name] = metric
  421. else:
  422. for k, v in metric.items(keep_base=False):
  423. v.postfix = metric.postfix
  424. v.prefix = metric.prefix
  425. v._from_collection = True
  426. self[k] = v
  427. elif isinstance(metrics, MetricCollection):
  428. for name, metric in metrics.items(keep_base=False):
  429. if name in self:
  430. raise ValueError(f"Metric with name '{name}' already exists in the collection.")
  431. self[name] = metric
  432. else:
  433. raise ValueError(
  434. "Unknown input to MetricCollection. Expected, `Metric`, `MetricCollection` or `dict`/`sequence` of the"
  435. f" previous, but got {metrics}"
  436. )
  437. self._groups_checked = False
  438. if self._enable_compute_groups:
  439. self._init_compute_groups()
  440. else:
  441. self._groups = {}
  442. def _init_compute_groups(self) -> None:
  443. """Initialize compute groups.
  444. If user provided a list, we check that all metrics in the list are also in the collection. If set to `True` we
  445. simply initialize each metric in the collection as its own group
  446. """
  447. if isinstance(self._enable_compute_groups, list):
  448. self._groups = dict(enumerate(self._enable_compute_groups))
  449. for v in self._groups.values():
  450. for metric in v:
  451. if metric not in self:
  452. raise ValueError(
  453. f"Input {metric} in `compute_groups` argument does not match a metric in the collection."
  454. f" Please make sure that {self._enable_compute_groups} matches {self.keys(keep_base=True)}"
  455. )
  456. # add metrics not specified in compute groups as their own group
  457. already_in_group = _flatten(self._groups.values()) # type: ignore
  458. counter = len(self._groups)
  459. for k in self.keys(keep_base=True):
  460. if k not in already_in_group:
  461. self._groups[counter] = [k] # type: ignore
  462. counter += 1
  463. self._groups_checked = True
  464. else:
  465. self._groups = {i: [str(k)] for i, k in enumerate(self.keys(keep_base=True))}
  466. @property
  467. def compute_groups(self) -> Dict[int, List[str]]:
  468. """Return a dict with the current compute groups in the collection."""
  469. return self._groups
  470. def _set_name(self, base: str) -> str:
  471. """Adjust name of metric with both prefix and postfix."""
  472. name = base if self.prefix is None else self.prefix + base
  473. return name if self.postfix is None else name + self.postfix
  474. def _to_renamed_dict(self) -> Mapping[str, Metric]:
  475. # self._modules changed from OrderedDict to dict as of PyTorch 2.5.0
  476. dict_modules = OrderedDict() if isinstance(self._modules, OrderedDict) else {}
  477. for k, v in self._modules.items():
  478. dict_modules[self._set_name(k)] = v
  479. return dict_modules
  480. def __iter__(self) -> Iterator[Hashable]: # type: ignore[override]
  481. """Return an iterator over the keys of the MetricDict."""
  482. return iter(self.keys())
  483. # TODO: redefine this as native python dict
  484. def keys(self, keep_base: bool = False) -> Iterable[Hashable]: # type: ignore[override]
  485. r"""Return an iterable of the ModuleDict key.
  486. Args:
  487. keep_base: Whether to add prefix/postfix on the items collection.
  488. """
  489. if keep_base:
  490. return self._modules.keys()
  491. return self._to_renamed_dict().keys()
  492. def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[tuple[str, Metric]]: # type: ignore[override]
  493. r"""Return an iterable of the ModuleDict key/value pairs.
  494. Args:
  495. keep_base: Whether to add prefix/postfix on the collection.
  496. copy_state:
  497. If metric states should be copied between metrics in the same compute group or just passed by reference
  498. """
  499. self._compute_groups_create_state_ref(copy_state)
  500. if keep_base:
  501. return self._modules.items()
  502. return self._to_renamed_dict().items()
  503. def values(self, copy_state: bool = True) -> Iterable[Metric]: # type: ignore[override]
  504. """Return an iterable of the ModuleDict values.
  505. Args:
  506. copy_state:
  507. If metric states should be copied between metrics in the same compute group or just passed by reference
  508. """
  509. self._compute_groups_create_state_ref(copy_state)
  510. return self._modules.values()
  511. def __getitem__(self, key: str, copy_state: bool = True) -> Metric:
  512. """Retrieve a single metric from the collection.
  513. Args:
  514. key: name of metric to retrieve
  515. copy_state:
  516. If metric states should be copied between metrics in the same compute group or just passed by reference
  517. """
  518. self._compute_groups_create_state_ref(copy_state)
  519. if self.prefix:
  520. key = _remove_prefix(key, self.prefix)
  521. if self.postfix:
  522. key = _remove_suffix(key, self.postfix)
  523. return self._modules[key]
  524. @staticmethod
  525. def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
  526. if arg is None or isinstance(arg, str):
  527. return arg
  528. raise ValueError(f"Expected input `{name}` to be a string, but got {type(arg)}")
  529. def __repr__(self) -> str:
  530. """Return the representation of the metric collection including all metrics in the collection."""
  531. repr_str = super().__repr__()[:-2]
  532. if self.prefix:
  533. repr_str += f",\n prefix={self.prefix}{',' if self.postfix else ''}"
  534. if self.postfix:
  535. repr_str += f"{',' if not self.prefix else ''}\n postfix={self.postfix}"
  536. return repr_str + "\n)"
  537. def set_dtype(self, dst_type: Union[str, torch.dtype]) -> "MetricCollection":
  538. """Transfer all metric state to specific dtype. Special version of standard `type` method.
  539. Arguments:
  540. dst_type: the desired type as ``torch.dtype`` or string.
  541. """
  542. for m in self.values(copy_state=False):
  543. m.set_dtype(dst_type)
  544. return self
  545. def plot(
  546. self,
  547. val: Optional[Union[dict, Sequence[dict]]] = None,
  548. ax: Optional[Union[_AX_TYPE, Sequence[_AX_TYPE]]] = None,
  549. together: bool = False,
  550. ) -> Sequence[_PLOT_OUT_TYPE]:
  551. """Plot a single or multiple values from the metric.
  552. The plot method has two modes of operation. If argument `together` is set to `False` (default), the `.plot`
  553. method of each metric will be called individually and the result will be list of figures. If `together` is set
  554. to `True`, the values of all metrics will instead be plotted in the same figure.
  555. Args:
  556. val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
  557. If no value is provided, will automatically call `metric.compute` and plot that result.
  558. ax: Either a single instance of matplotlib axis object or an sequence of matplotlib axis objects. If
  559. provided, will add the plots to the provided axis objects. If not provided, will create a new. If
  560. argument `together` is set to `True`, a single object is expected. If `together` is set to `False`,
  561. the number of axis objects needs to be the same length as the number of metrics in the collection.
  562. together: If `True`, will plot all metrics in the same axis. If `False`, will plot each metric in a separate
  563. Returns:
  564. Either install tuple of Figure and Axes object or an sequence of tuples with Figure and Axes object for each
  565. metric in the collection.
  566. Raises:
  567. ModuleNotFoundError:
  568. If `matplotlib` is not installed
  569. ValueError:
  570. If `together` is not an bool
  571. ValueError:
  572. If `ax` is not an instance of matplotlib axis object or a sequence of matplotlib axis objects
  573. .. plot::
  574. :scale: 75
  575. >>> # Example plotting a single value
  576. >>> import torch
  577. >>> from torchmetrics import MetricCollection
  578. >>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
  579. >>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
  580. >>> metrics.update(torch.rand(10), torch.randint(2, (10,)))
  581. >>> fig_ax_ = metrics.plot()
  582. .. plot::
  583. :scale: 75
  584. >>> # Example plotting multiple values
  585. >>> import torch
  586. >>> from torchmetrics import MetricCollection
  587. >>> from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall
  588. >>> metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])
  589. >>> values = []
  590. >>> for _ in range(10):
  591. ... values.append(metrics(torch.rand(10), torch.randint(2, (10,))))
  592. >>> fig_, ax_ = metrics.plot(values, together=True)
  593. """
  594. if not isinstance(together, bool):
  595. raise ValueError(f"Expected argument `together` to be a boolean, but got {type(together)}")
  596. if ax is not None:
  597. if together and not isinstance(ax, _AX_TYPE):
  598. raise ValueError(
  599. f"Expected argument `ax` to be a matplotlib axis object, but got {type(ax)} when `together=True`"
  600. )
  601. if not together and not (
  602. isinstance(ax, Sequence) and all(isinstance(a, _AX_TYPE) for a in ax) and len(ax) == len(self)
  603. ):
  604. raise ValueError(
  605. f"Expected argument `ax` to be a sequence of matplotlib axis objects with the same length as the "
  606. f"number of metrics in the collection, but got {type(ax)} with len {len(ax)} when `together=False`"
  607. )
  608. val = val or self.compute()
  609. if together:
  610. return plot_single_or_multi_val(val, ax=ax)
  611. fig_axs = []
  612. for i, (k, m) in enumerate(self.items(keep_base=False, copy_state=False)):
  613. if isinstance(val, dict):
  614. f, a = m.plot(val[k], ax=ax[i] if ax is not None else ax)
  615. elif isinstance(val, Sequence):
  616. f, a = m.plot([v[k] for v in val], ax=ax[i] if ax is not None else ax)
  617. fig_axs.append((f, a))
  618. return fig_axs