# metrics_logger.py
  1. import logging
  2. import time
  3. from typing import Any, Dict, List, Optional, Tuple, Type, Union
  4. import numpy as np
  5. import tree # pip install dm_tree
  6. from ray._common.deprecation import DEPRECATED_VALUE, Deprecated, deprecation_warning
  7. from ray.rllib.utils import deep_update, force_tuple
  8. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  9. from ray.rllib.utils.metrics.stats import (
  10. EmaStats,
  11. ItemSeriesStats,
  12. ItemStats,
  13. LifetimeSumStats,
  14. MaxStats,
  15. MeanStats,
  16. MinStats,
  17. PercentilesStats,
  18. StatsBase,
  19. SumStats,
  20. )
  21. from ray.util.annotations import DeveloperAPI, PublicAPI
  # Optional deep-learning framework handles, resolved through RLlib's import helpers.
  _, tf, _ = try_import_tf()
  torch, _ = try_import_torch()

  logger = logging.getLogger("ray.rllib")

  # Default lookup from a reduce-method identifier to the Stats class used for it.
  # Pass a different lookup to the MetricsLogger constructor to override it or to
  # add new classes. Any class registered that way can then be selected by passing
  # its identifier as the `reduce` argument when logging.
  DEFAULT_STATS_CLS_LOOKUP = {
      "mean": MeanStats,
      "ema": EmaStats,
      "min": MinStats,
      "max": MaxStats,
      "sum": SumStats,
      "lifetime_sum": LifetimeSumStats,
      "percentiles": PercentilesStats,
      "item": ItemStats,
      "item_series": ItemSeriesStats,
  }
  # Note(Artur): Delete this in a couple of Ray releases.
  @DeveloperAPI
  def stats_from_legacy_state(state: Dict[str, Any], is_root: bool = False) -> StatsBase:
      """Builds a Stats object of the current class family from a legacy state dict.

      Args:
          state: The legacy state. `state["reduce"]` is always read. Depending on
              the branch taken, "values", "window", "ema_coeff", "percentiles",
              "clear_on_reduce" and "throughput_stats" are read as well.
          is_root: Whether the restored Stats object belongs to a root logger.

      Returns:
          Whatever `from_state()` of the class found in DEFAULT_STATS_CLS_LOOKUP
          returns for the converted state.

      Raises:
          KeyError: If a required legacy field is missing, or if the resolved
              identifier is not a key of DEFAULT_STATS_CLS_LOOKUP.
      """
      legacy_reduce = state["reduce"]
      converted = {
          # There is no way to verify how legacy stats were logged, so anything
          # that is not a root stat is treated as a leaf stat.
          "is_leaf": not is_root,
          "is_root": is_root,
          # Always present for compatibility.
          "latest_merged": [],
      }

      keeps_values_on_reduce = state.get("clear_on_reduce", True) is False

      if keeps_values_on_reduce and legacy_reduce == "sum":
          # A legacy sum that was never cleared maps onto a lifetime sum.
          converted["stats_cls_identifier"] = "lifetime_sum"
          # With the new stats, only the root logger tracks values for lifetime sums.
          converted["lifetime_sum"] = np.nansum(state["values"]) if is_root else 0.0
          # Throughput tracking is carried over iff the legacy state recorded any.
          converted["track_throughputs"] = state.get("throughput_stats") is not None
          return DEFAULT_STATS_CLS_LOOKUP["lifetime_sum"].from_state(state=converted)

      if keeps_values_on_reduce:
          # NOTE(review): the message mentions throughput, but this branch is reached
          # for every non-sum legacy stat saved with clear_on_reduce=False - confirm.
          deprecation_warning(
              "Legacy Stats class tracking throughput detected. This is not supported anymore.",
              error=False,
          )

      if legacy_reduce == "mean":
          legacy_coeff = state["ema_coeff"]
          if legacy_coeff is None:
              # Plain (windowed) mean.
              converted["values"] = state["values"]
              converted["window"] = state["window"]
          else:
              # A legacy mean with an EMA coefficient becomes an "ema" stat whose
              # single value is the NaN-ignoring mean of the stored values.
              legacy_reduce = "ema"
              converted["ema_coeff"] = legacy_coeff
              converted["value"] = np.nanmean(state["values"])
      elif legacy_reduce in ("min", "max", "sum"):
          converted["values"] = state["values"]
          converted["window"] = state["window"]
          converted["track_throughput"] = (
              legacy_reduce == "sum" and state.get("throughput_stats") is not None
          )
      elif (
          legacy_reduce is None and state.get("percentiles", False) is not False
      ) or legacy_reduce == "percentiles":
          # Legacy percentiles were stored either as reduce=None plus a
          # `percentiles` entry, or explicitly as reduce="percentiles".
          legacy_reduce = "percentiles"
          converted["values"] = state["values"]
          converted["window"] = state["window"]
          converted["percentiles"] = state["percentiles"]

      # Look the class up first, so an unknown identifier fails before anything else.
      target_cls = DEFAULT_STATS_CLS_LOOKUP[legacy_reduce]
      converted["stats_cls_identifier"] = legacy_reduce
      return target_cls.from_state(state=converted)
  104. @PublicAPI(stability="alpha")
  105. class MetricsLogger:
  106. """A generic class collecting and reducing metrics.
  107. Use this API to log and merge metrics.
  108. Metrics should be logged in parallel components with MetricsLogger.log_value().
  109. RLlib will then aggregate metrics, reduce them and report them.
  110. The MetricsLogger supports logging anything that has a corresponding reduction method.
  111. These are defined natively in the Stats classes, which are used to log the metrics.
  112. Please take a look ray.rllib.utils.metrics.metrics_logger.DEFAULT_STATS_CLS_LOOKUP for the available reduction methods.
  113. You can provide your own reduce methods by extending ray.rllib.utils.metrics.metrics_logger.DEFAULT_STATS_CLS_LOOKUP and passing it to AlgorithmConfig.logging().
  114. Notes on architecture:
  115. In our docstrings we make heavy use of the phrase 'parallel components'.
  116. This pertains to the architecture of the logging system, where we have one 'root' MetricsLogger
  117. that is used to aggregate all metrics of n parallel ('non-root') MetricsLoggers that are used to log metrics for each parallel component.
  118. A parallel component is typically a single Learner, an EnvRunner, or a ConnectorV2 or any other component of which more than one instance is running in parallel.
  119. We also allow intermediate MetricsLoggers that are no root MetricsLogger but are used to aggregate metrics. They are therefore neither root nor leaf.
  120. """
  def __init__(
      self,
      root=False,
      stats_cls_lookup: Optional[
          Dict[str, Type[StatsBase]]
      ] = DEFAULT_STATS_CLS_LOOKUP,
  ):
      """Initializes a MetricsLogger instance.

      Args:
          root: Whether this logger is a root logger. If True, lifetime sums
              (reduce="lifetime_sum") will not be cleared on reduce().
          stats_cls_lookup: A dictionary mapping reduction method names to Stats
              classes. If omitted or None, the default lookup
              (ray.rllib.utils.metrics.metrics_logger.DEFAULT_STATS_CLS_LOOKUP)
              is used. You can provide your own reduce methods by extending that
              default lookup and passing it to AlgorithmConfig.logging().
      """
      # All logged Stats objects live in this (possibly nested) dict.
      self.stats = {}
      # TODO (sven): We use a dummy RLock here for most RLlib algos, however, APPO
      # and IMPALA require this to be an actual RLock (b/c of thread safety reasons).
      # An actual RLock, however, breaks our current OfflineData and
      # OfflinePreLearner logic, in which the Learner (which contains a
      # MetricsLogger) is serialized and deserialized. We will have to fix this
      # offline RL logic first, then can remove this hack here and return to always
      # using the RLock.
      self._threading_lock = _DummyRLock()
      self._is_root_logger = root
      # Timestamp taken at construction time (time.perf_counter()).
      self._time_when_initialized = time.perf_counter()
      # The annotation admits None. Without this fallback an explicit None would be
      # stored as-is and only fail later, with a TypeError, on the first lookup.
      if stats_cls_lookup is None:
          stats_cls_lookup = DEFAULT_STATS_CLS_LOOKUP
      self.stats_cls_lookup = stats_cls_lookup
  def __contains__(self, key: Union[str, Tuple[str, ...]]) -> bool:
      """Implements `key in metrics_logger`.

      Args:
          key: Either a single str (top-level key) or a tuple of str (nested key)
              to look up in self.stats.

      Returns:
          The result of `self._key_in_stats(key)`, i.e. whether `key` was found.
      """
      is_present = self._key_in_stats(key)
      return is_present
  def peek(
      self,
      key: Union[str, Tuple[str, ...], None] = None,
      default=None,
      compile: bool = True,
      throughput: bool = False,
      latest_merged_only: bool = False,
  ) -> Any:
      """Returns reduced values of this MetricsLogger without reducing anything.

      Every Stats object involved is only peeked, so the internal structures stay
      as they are.

      Args:
          key: The key/key sequence of the sub-structure to peek. If None, the
              entire stats structure is peeked.
          default: Value returned if nothing usable is found under `key`. As long
              as `default` is None, the lookup is done with `key_error=True`.
          compile: Forwarded to each Stats object's `peek()`; if True, results are
              compiled into a single value if possible.
          throughput: If True, throughputs are returned instead of the (reduced)
              values. Only allowed on a root logger.
          latest_merged_only: Forwarded to each Stats object's `peek()`; if True,
              only the latest merged values are considered. This only works on
              aggregation loggers (root or intermediate).

      Returns:
          The peeked value of a single Stats object, a nested structure of peeked
          values, or `default`.

      Raises:
          AssertionError: If `throughput` is True on a non-root logger.
          ValueError: If peeking a Stats object inside a nested structure fails.
      """
      if throughput:
          assert (
              self._is_root_logger
          ), "Throughput can only be peeked from a root logger"
          return self._get_throughputs(key=key, default=default)

      def _peek_leaf(path, leaf):
          # Re-raise with the offending path, so failures deep inside a nested
          # structure can be located.
          try:
              return leaf.peek(compile=compile, latest_merged_only=latest_merged_only)
          except Exception as e:
              raise ValueError(
                  f"Error peeking stats {leaf} with compile={compile} at path {path}."
              ) from e

      def _peek_subtree(subtree):
          # Map over a shallow copy of the top-level dict.
          return tree.map_structure_with_path(_peek_leaf, subtree.copy())

      with self._threading_lock:
          if key is None:
              return _peek_subtree(self.stats)

          # Only a missing key without a caller-provided default is an error.
          found = self._get_key(key, key_error=default is None)
          if isinstance(found, StatsBase):
              return found.peek(compile=compile, latest_merged_only=latest_merged_only)
          if isinstance(found, dict) and found:
              return _peek_subtree(found)
          return default
  @staticmethod
  def peek_results(results: Any, compile: bool = True) -> Any:
      """Peeks every Stats leaf of an arbitrarily nested structure.

      Args:
          results: The nested structure whose Stats leaves should be peeked.
          compile: Forwarded to each `peek()` call; if True, the result is
              compiled into a single value if possible.

      Returns:
          A structure of the same shape in which every Stats leaf is replaced by
          its peeked value. All other leaves are passed through unchanged.
      """

      def _peek_if_stats(leaf):
          if isinstance(leaf, StatsBase):
              return leaf.peek(compile=compile)
          return leaf

      return tree.map_structure(_peek_if_stats, results)
  def _maybe_create_stats_object(
      self,
      key: Union[str, Tuple[str, ...]],
      *,
      reduce: str = "ema",
      window: Optional[Union[int, float]] = None,
      ema_coeff: Optional[float] = None,
      percentiles: Optional[Union[List[int], bool]] = None,
      clear_on_reduce: Optional[bool] = DEPRECATED_VALUE,
      with_throughput: Optional[bool] = None,
      throughput_ema_coeff: Optional[float] = DEPRECATED_VALUE,
      reduce_per_index_on_aggregate: Optional[bool] = DEPRECATED_VALUE,
      **kwargs: Dict[str, Any],
  ) -> None:
      """Validates the settings and creates the Stats object under `key` if missing.

      Does nothing if `key` already exists in this logger.

      Args:
          key: The key (or nested key-tuple) under which to create the Stats object.
          reduce: Identifier of the Stats class, looked up in `self.stats_cls_lookup`.
          window: Optional window size. Not supported together with reduce="ema".
          ema_coeff: Optional EMA coefficient. Defaults to 0.01 for reduce="ema";
              not supported together with reduce="mean".
          percentiles: Only supported together with reduce="percentiles".
          clear_on_reduce: Deprecated. Accepted for signature compatibility and
              not used by this method.
          with_throughput: Whether to track a throughput. A warning is issued if
              `reduce` is neither "sum" nor "lifetime_sum".
          throughput_ema_coeff: Deprecated. Passing a value results in a
              deprecation error.
          reduce_per_index_on_aggregate: Deprecated. Passing a value results in a
              deprecation warning.
          **kwargs: Additional keyword arguments passed to the Stats class
              constructor.

      Raises:
          ValueError: If `percentiles` is given for another reduce method, or if
              `reduce` is not a key of `self.stats_cls_lookup`.
      """
      with self._threading_lock:
          # The settings only matter when the Stats object is created.
          if self._key_in_stats(key):
              return

          if reduce == "ema" and ema_coeff is None:
              ema_coeff = 0.01

          if percentiles and reduce != "percentiles":
              raise ValueError("percentiles is only supported for reduce=percentiles")

          if reduce == "ema" and window is not None:
              deprecation_warning(
                  "window is not supported for ema reduction. If you want to use a window, use mean reduction instead.",
                  error=True,
              )
              # Defensive only: reached solely if the call above does not raise.
              window = None

          if reduce_per_index_on_aggregate is not DEPRECATED_VALUE:
              # The trailing space matters: the two literals are concatenated.
              deprecation_warning(
                  "reduce_per_index_on_aggregate is deprecated. Aggregation now happens over all values "
                  "of incoming stats objects, treating each incoming value with equal weight.",
                  error=False,
              )

          if throughput_ema_coeff is not DEPRECATED_VALUE:
              deprecation_warning(
                  "throughput_ema_coeff is deprecated. Throughput is not smoothed with ema anymore "
                  "but calculated once per MetricsLogger.reduce() call.",
                  error=True,
              )

          if reduce == "mean" and ema_coeff is not None:
              deprecation_warning(
                  "ema_coeff is not supported for mean reduction. Use `reduce='ema'` instead.",
                  error=True,
              )

          if with_throughput and reduce not in ("sum", "lifetime_sum"):
              deprecation_warning(
                  "with_throughput=True is only supported for reduce='sum' or reduce='lifetime_sum'. Use reduce='sum' or reduce='lifetime_sum' instead.",
                  error=False,
              )

          try:
              stats_cls = self.stats_cls_lookup[reduce]
          except KeyError:
              # The KeyError adds nothing beyond what the message already says.
              raise ValueError(
                  f"Invalid reduce method '{reduce}' could not be found in stats_cls_lookup"
              ) from None

          # Only pass on settings that were actually provided, so each Stats class
          # can apply its own defaults.
          if window is not None:
              kwargs["window"] = window
          if ema_coeff is not None:
              kwargs["ema_coeff"] = ema_coeff
          if percentiles is not None:
              kwargs["percentiles"] = percentiles
          if with_throughput is not None:
              kwargs["with_throughput"] = with_throughput

          # Only stats at the root logger can be root stats.
          kwargs["is_root"] = self._is_root_logger
          # Any Stats that are created in a logger are leaf stats by definition.
          # If they are aggregated from another logger, they are not leaf stats.
          kwargs["is_leaf"] = True

          self._set_key(key, stats_cls(**kwargs))
  def log_value(
      self,
      key: Union[str, Tuple[str, ...]],
      value: Any,
      *,
      reduce: Optional[str] = None,
      window: Optional[Union[int, float]] = None,
      ema_coeff: Optional[float] = None,
      percentiles: Optional[Union[List[int], bool]] = None,
      clear_on_reduce: Optional[bool] = DEPRECATED_VALUE,
      with_throughput: Optional[bool] = None,
      throughput_ema_coeff: Optional[float] = DEPRECATED_VALUE,
      reduce_per_index_on_aggregate: Optional[bool] = DEPRECATED_VALUE,
      **kwargs: Dict[str, Any],
  ) -> None:
      """Logs a new value or item under a (possibly nested) key to the logger.

      The Stats object under `key` is created on the first call for that key.
      The settings below are only evaluated at that point; on later calls only
      `value` is pushed.

      Args:
          key: The key (or nested key-tuple) to log the `value` under.
          value: A numeric value, an item to log or a StatsObject containing
              multiple values to log.
          reduce: The reduction method to apply when compiling metrics at the root
              logger. By default, the reduction methods to choose from are the keys
              of rllib.utils.metrics.metrics_logger.DEFAULT_STATS_CLS_LOOKUP.
              You can provide your own reduce methods by extending that lookup and
              passing it to AlgorithmConfig.logging(). If None, "mean" is used when
              a `window` is given and "ema" otherwise.
          window: An optional window size to reduce over. Not supported together
              with reduce="ema".
          ema_coeff: An optional EMA coefficient for reduce="ema". Not supported
              together with reduce="mean". The reduction formula for EMA is:
              EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
              Defaults to 0.01.
          percentiles: The percentiles to compute; only supported together with
              reduce="percentiles". The window should be chosen carefully: RLlib
              computes exact percentiles and the computational complexity is
              O(m*n*log(n/m)) where n is the window size and m is the number of
              parallel metrics loggers involved (for example, m EnvRunners).
          clear_on_reduce: Deprecated. Passing any value issues a deprecation
              warning. For compatibility, reduce="sum" combined with
              clear_on_reduce=False is logged as reduce="lifetime_sum".
          with_throughput: Whether to track a throughput estimate together with
              this metric. This is supported by default only for `reduce=sum` and
              `reduce=lifetime_sum`.
          throughput_ema_coeff: Deprecated argument. Throughput is not smoothed
              with ema anymore but calculated once per MetricsLogger.reduce() call.
          reduce_per_index_on_aggregate: Deprecated argument. Aggregation now
              happens over all values of incoming stats objects once per
              MetricsLogger.reduce() call, treating each incoming value with equal
              weight.
          **kwargs: Additional keyword arguments passed on to the Stats class
              constructor when the Stats object under `key` is created.
      """
      # Compatibility with the legacy usage of MetricsLogger: without an explicit
      # reduce method, a window implies a mean; everything else defaults to EMA.
      if reduce is None:
          reduce = "ema" if window is None else "mean"

      if clear_on_reduce is not DEPRECATED_VALUE:
          deprecation_warning(
              "clear_on_reduce is deprecated. Use reduce='lifetime_sum' for sums. Provide a custom reduce method for other cases.",
              error=False,
          )

      # A sum that is never cleared is what "lifetime_sum" expresses today.
      if reduce == "sum" and clear_on_reduce is False:
          reduce = "lifetime_sum"
          clear_on_reduce = None

      # Forward **kwargs as well; without that, extra constructor arguments for
      # (custom) Stats classes would be silently dropped.
      self._maybe_create_stats_object(
          key,
          reduce=reduce,
          window=window,
          ema_coeff=ema_coeff,
          percentiles=percentiles,
          clear_on_reduce=clear_on_reduce,
          with_throughput=with_throughput,
          throughput_ema_coeff=throughput_ema_coeff,
          reduce_per_index_on_aggregate=reduce_per_index_on_aggregate,
          **kwargs,
      )
      self._get_key(key).push(value)
  def log_dict(
      self,
      value_dict,
      *,
      key: Optional[Union[str, Tuple[str, ...]]] = None,
      reduce: Optional[str] = "mean",
      window: Optional[Union[int, float]] = None,
      ema_coeff: Optional[float] = None,
      percentiles: Optional[Union[List[int], bool]] = None,
      clear_on_reduce: Optional[bool] = DEPRECATED_VALUE,
      with_throughput: Optional[bool] = None,
      throughput_ema_coeff: Optional[float] = DEPRECATED_VALUE,
      reduce_per_index_on_aggregate: Optional[bool] = DEPRECATED_VALUE,
  ) -> None:
      """Logs every leaf of a possibly nested dict through `log_value()`.

      Each leaf of `value_dict` is logged under its own path inside the dict,
      prefixed by `key` if given. `value_dict` therefore does not need to mirror
      what has already been logged to this logger; new paths result in new keys.

      Args:
          value_dict: The (possibly nested) dict of values or Stats objects to log.
          key: Optional key (or key-tuple) prepended to the path of every leaf.

      All other arguments are passed unchanged to `MetricsLogger.log_value()` for
      every leaf; see there for details.

      Raises:
          AssertionError: If `value_dict` is not a dict.
      """
      assert isinstance(
          value_dict, dict
      ), f"`stats_dict` ({value_dict}) must be dict!"

      key_prefix = force_tuple(key)
      # The same settings apply to every leaf.
      shared_settings = dict(
          reduce=reduce,
          window=window,
          ema_coeff=ema_coeff,
          percentiles=percentiles,
          clear_on_reduce=clear_on_reduce,
          with_throughput=with_throughput,
          throughput_ema_coeff=throughput_ema_coeff,
          reduce_per_index_on_aggregate=reduce_per_index_on_aggregate,
      )

      def _log_leaf(path, leaf):
          full_key = key_prefix + force_tuple(tree.flatten(path))
          self.log_value(full_key, value=leaf, **shared_settings)

      with self._threading_lock:
          tree.map_structure_with_path(_log_leaf, value_dict)
  @Deprecated(new="aggregate", error=False)
  def merge_and_log_n_dicts(self, *args, **kwargs):
      """Deprecated alias: passes all arguments on to `aggregate()`."""
      outcome = self.aggregate(*args, **kwargs)
      return outcome
  def aggregate(
      self,
      stats_dicts: List[Dict[str, Any]],
      *,
      key: Optional[Union[str, Tuple[str, ...]]] = None,
  ) -> None:
      """Merges n stats dicts from parallel components into this logger's stats.

      For every leaf path found in `stats_dicts`, all incoming Stats objects under
      that path are merged into the Stats object this logger holds there. If this
      logger holds none yet, an aggregator is cloned from the first incoming Stats
      object. How values are combined is defined by the `merge` methods of the
      individual Stats classes.

      Args:
          stats_dicts: List of n stats dicts to be merged and then logged.
          key: Optional top-level key (or key-tuple) under which to log all
              keys/key sequences found in the n `stats_dicts`.

      Raises:
          AssertionError: If a target key already holds a leaf Stats object, i.e.
              one that was created by direct logging.
      """
      leaf_paths = set()

      def _collect(node, path=()):
          # Rebuild dicts, remember the path of every leaf (lists and values) and
          # unwrap lists that hold exactly one element.
          if isinstance(node, dict):
              return {
                  name: _collect(child, path + (name,)) for name, child in node.items()
              }
          leaf_paths.add(path)
          if isinstance(node, list) and len(node) == 1:
              return node[0]
          return node

      def _nest_under(payload, prefix):
          # Wrap `payload` into one dict level per element of `prefix`.
          if isinstance(prefix, str):
              return {prefix: payload}
          nested = {prefix[-1]: payload}
          for part in reversed(prefix[:-1]):
              nested = {part: nested}
          return nested

      # One pass over the incoming dicts: prepend `key` (if any) and collect all
      # paths that lead to leaves.
      prepared = []
      for incoming in stats_dicts:
          if key is not None:
              incoming = _nest_under(incoming, key)
          prepared.append(_collect(incoming))

      for path in leaf_paths:
          # All incoming Stats objects that exist under this path.
          contributions = [
              self._get_key(path, stats=candidate)
              for candidate in prepared
              if self._key_in_stats(path, stats=candidate)
          ]

          # self._get_key returns {} if the key is not found.
          existing = self._get_key(path, stats=self.stats, key_error=False)
          if existing is None or isinstance(existing, dict):
              # Nothing aggregated under this path yet: clone a fresh aggregator
              # (without internal values) from the first incoming Stats object.
              target = contributions[0].clone(
                  init_overrides={"is_root": self._is_root_logger, "is_leaf": False},
              )
              if target.has_throughputs:
                  target.initialize_throughput_reference_time(
                      self._time_when_initialized
                  )
          else:
              target = existing
              # An existing target must stem from a previous aggregation; stats
              # created by direct logging (leaf stats) cannot be aggregated into.
              assert (
                  not target.is_leaf
              ), f"Cannot aggregate into key '{path}' because it was created by direct logging. Aggregation keys must be separate from direct logging keys."

          target.merge(incoming_stats=contributions)
          self._set_key(path, target)
  def log_time(
      self,
      key: Union[str, Tuple[str, ...]],
      *,
      reduce: str = "ema",
      window: Optional[Union[int, float]] = None,
      ema_coeff: Optional[float] = None,
      percentiles: Optional[Union[List[int], bool]] = None,
      clear_on_reduce: Optional[bool] = DEPRECATED_VALUE,
      with_throughput: Optional[bool] = None,
      throughput_ema_coeff: Optional[float] = DEPRECATED_VALUE,
      reduce_per_index_on_aggregate: Optional[bool] = DEPRECATED_VALUE,
  ) -> StatsBase:
      """Measures and logs a time delta value under `key` when used with a with-block.

      Creates the Stats object under `key` if it does not exist yet and returns
      it, so that a `with` clause can enter and exit it.

      Args:
          key: The key (or tuple of keys) to log the measured time delta under.
          reduce: The reduction method to apply when compiling metrics at the root
              logger. By default, the reduction methods to choose from are the keys
              of rllib.utils.metrics.metrics_logger.DEFAULT_STATS_CLS_LOOKUP.
              You can provide your own reduce methods by extending that lookup and
              passing it to AlgorithmConfig.logging().
          window: An optional window size to reduce over. Not supported together
              with reduce="ema".
          ema_coeff: An optional EMA coefficient for reduce="ema". Not supported
              together with reduce="mean". The reduction formula for EMA is:
              EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
          percentiles: The percentiles to compute; only supported together with
              reduce="percentiles". The window should be chosen carefully: RLlib
              computes exact percentiles and the computational complexity is
              O(m*n*log(n/m)) where n is the window size and m is the number of
              parallel metrics loggers involved (for example, m EnvRunners).
          clear_on_reduce: Deprecated. Use reduce="lifetime_sum" instead.
          with_throughput: Whether to track a throughput estimate together with
              this metric. This is only supported for `reduce=sum` and
              `reduce=lifetime_sum`. The current throughput estimate of a key can
              be obtained through: `MetricsLogger.peek(key, throughput=True)`.
          throughput_ema_coeff: Deprecated argument. Throughput is not smoothed
              with ema anymore but calculated once per MetricsLogger.reduce() call.
          reduce_per_index_on_aggregate: Deprecated argument. Aggregation now
              happens over all values of incoming stats objects once per
              MetricsLogger.reduce() call, treating each incoming value with equal
              weight.

      Returns:
          The Stats object stored under `key`.
      """
      # Settings for the Stats object; they only take effect if it is created now.
      creation_settings = dict(
          reduce=reduce,
          window=window,
          ema_coeff=ema_coeff,
          percentiles=percentiles,
          clear_on_reduce=clear_on_reduce,
          with_throughput=with_throughput,
          throughput_ema_coeff=throughput_ema_coeff,
          reduce_per_index_on_aggregate=reduce_per_index_on_aggregate,
      )
      self._maybe_create_stats_object(key, **creation_settings)

      # Hand the Stats object itself to the caller for use in a `with` clause.
      timer_stats = self._get_key(key)
      return timer_stats
  def reduce(self, compile: bool = False) -> Dict:
      """Reduces all logged values based on their settings and returns a result dict.

      Note to user: Do not call this method directly! This should be called only
      by RLlib when aggregating stats.

      Args:
          compile: Forwarded to each Stats object's `reduce()`. If True, the result
              is compiled into a single value if possible. If it is not possible,
              the result is a list of values. If False, the result is a list of one
              or more values.

      Returns:
          A structure shaped like `self.stats`, in which every Stats object is
          replaced by the result of its `reduce()` call.

      Raises:
          ValueError: If reducing any Stats object fails; the message names the
              path of that object.
      """

      def _reduce_leaf(path, leaf):
          try:
              reduced = leaf.reduce(compile=compile)
          except Exception as e:
              raise ValueError(
                  f"Error reducing stats {leaf} with compile={compile} at path {path}."
              ) from e
          return reduced

      with self._threading_lock:
          reduced_structure = tree.map_structure_with_path(_reduce_leaf, self.stats)
      return reduced_structure
  604. @Deprecated(
  605. new="log_value",
  606. help="Use log_value with reduce='item' or another reduction method with a window of the appropriate size.",
  607. error=True,
  608. )
  609. def set_value(self, *args, **kwargs) -> None:
  610. ...
  611. def reset(self) -> None:
  612. """Resets all data stored in this MetricsLogger."""
  613. with self._threading_lock:
  614. self.stats = {}
  615. def delete(self, *key: Tuple[str, ...], key_error: bool = True) -> None:
  616. """Deletes the given `key` from this metrics logger's stats.
  617. Args:
  618. key: The key or key sequence (for nested location within self.stats),
  619. to delete from this MetricsLogger's stats.
  620. key_error: Whether to throw a KeyError if `key` cannot be found in `self`.
  621. Raises:
  622. KeyError: If `key` cannot be found in `self` AND `key_error` is True.
  623. """
  624. self._del_key(key, key_error)
  625. def get_state(self) -> Dict[str, Any]:
  626. """Returns the current state of `self` as a dict.
  627. Note that the state is merely the combination of all states of the individual
  628. `Stats` objects stored under `self.stats`.
  629. """
  630. stats_dict = {}
  631. def _map(path, stats):
  632. # Convert keys to strings for msgpack-friendliness.
  633. stats_dict["--".join(path)] = stats.get_state()
  634. with self._threading_lock:
  635. tree.map_structure_with_path(_map, self.stats)
  636. return {"stats": stats_dict}
  637. def set_state(self, state: Dict[str, Any]) -> None:
  638. """Sets the state of `self` to the given `state`.
  639. Args:
  640. state: The state to set `self` to.
  641. """
  642. with self._threading_lock:
  643. # Reset all existing stats to ensure a clean state transition
  644. self.stats = {}
  645. for flat_key, stats_state in state["stats"].items():
  646. if "stats_cls_identifier" in stats_state:
  647. # Having a stats cls identifier means we are using the new stats classes.
  648. cls_identifier = stats_state["stats_cls_identifier"]
  649. assert (
  650. cls_identifier in self.stats_cls_lookup
  651. ), f"Stats class identifier {cls_identifier} not found in stats_cls_lookup. This can happen if you are loading a stats from a checkpoint that was created with a different stats class lookup."
  652. _cls = self.stats_cls_lookup[cls_identifier]
  653. stats = _cls.from_state(state=stats_state)
  654. else:
  655. # We want to preserve compatibility with old checkpoints
  656. # as much as possible.
  657. stats = stats_from_legacy_state(
  658. state=stats_state, is_root=self._is_root_logger
  659. )
  660. self._set_key(flat_key.split("--"), stats)
  661. def _key_in_stats(self, flat_key, *, stats=None):
  662. flat_key = force_tuple(tree.flatten(flat_key))
  663. _dict = stats if stats is not None else self.stats
  664. for key in flat_key:
  665. if key not in _dict:
  666. return False
  667. _dict = _dict[key]
  668. return True
  669. def _get_key(self, flat_key, *, stats=None, key_error=True):
  670. flat_key = force_tuple(tree.flatten(flat_key))
  671. _dict = stats if stats is not None else self.stats
  672. for key in flat_key:
  673. try:
  674. _dict = _dict[key]
  675. except KeyError as e:
  676. if key_error:
  677. raise e
  678. else:
  679. return {}
  680. return _dict
  681. def _set_key(self, flat_key, stats):
  682. flat_key = force_tuple(tree.flatten(flat_key))
  683. with self._threading_lock:
  684. _dict = self.stats
  685. for i, key in enumerate(flat_key):
  686. # If we are at the end of the key sequence, set
  687. # the key, no matter, whether it already exists or not.
  688. if i == len(flat_key) - 1:
  689. _dict[key] = stats
  690. return
  691. # If an intermediary key in the sequence is missing,
  692. # add a sub-dict under this key.
  693. if key not in _dict:
  694. _dict[key] = {}
  695. _dict = _dict[key]
  696. def _del_key(self, flat_key, key_error=False):
  697. flat_key = force_tuple(tree.flatten(flat_key))
  698. with self._threading_lock:
  699. # Erase the key from the (nested) `self.stats` dict.
  700. _dict = self.stats
  701. try:
  702. for i, key in enumerate(flat_key):
  703. if i == len(flat_key) - 1:
  704. del _dict[key]
  705. return
  706. _dict = _dict[key]
  707. except KeyError as e:
  708. if key_error:
  709. raise e
  710. def _get_throughputs(
  711. self, key: Optional[Union[str, Tuple[str, ...]]] = None, default=None
  712. ) -> Union[Dict, float]:
  713. """Returns throughput values for Stats that have throughput tracking enabled.
  714. If no key is provided, returns a nested dict containing throughput values for all Stats
  715. that have throughput tracking enabled. If a key is provided, returns the throughput value
  716. for that specific key or nested structure.
  717. The throughput values represent the rate of change of the corresponding metrics per second.
  718. For example, if a metric represents the number of steps taken, its throughput value would
  719. represent steps per second.
  720. Args:
  721. key: Optional key or nested key path to get throughput for. If provided, returns just
  722. the throughput value for that key or nested structure. If None, returns a nested dict
  723. with throughputs for all metrics.
  724. default: Default value to return if no throughput values are found.
  725. Returns:
  726. If key is None: A nested dict with the same structure as self.stats but with "_throughput"
  727. appended to leaf keys and throughput values as leaf values. Only includes entries for
  728. Stats objects that have throughput tracking enabled.
  729. If key is provided: The throughput value for that specific key or nested structure.
  730. """
  731. def _nested_throughputs(stats):
  732. """Helper function to calculate throughputs for a nested structure."""
  733. def _transform(path, value):
  734. if isinstance(value, StatsBase) and value.has_throughputs:
  735. # Convert path to tuple for consistent key handling
  736. key = force_tuple(path)
  737. # Add "_throughput" to the last key in the path
  738. return key[:-1] + (key[-1] + "_throughputs",), value.throughputs
  739. return path, value
  740. result = {}
  741. for path, value in tree.flatten_with_path(stats):
  742. new_path, new_value = _transform(path, value)
  743. if isinstance(new_value, float): # Only include throughput values
  744. _dict = result
  745. for k in new_path[:-1]:
  746. if k not in _dict:
  747. _dict[k] = {}
  748. _dict = _dict[k]
  749. _dict[new_path[-1]] = new_value
  750. return result
  751. with self._threading_lock:
  752. if key is not None:
  753. # Get the Stats object or nested structure for the key
  754. stats = self._get_key(key, key_error=False)
  755. if isinstance(stats, StatsBase):
  756. if not stats.has_throughputs:
  757. raise ValueError(
  758. f"Key '{key}' does not have throughput tracking enabled"
  759. )
  760. return stats.throughputs
  761. elif stats == {}:
  762. # If the key is not found, return the default value
  763. return default
  764. else:
  765. # stats is a non-empty dictionary
  766. return _nested_throughputs(stats)
  767. throughputs = {}
  768. def _map(path, stats):
  769. if isinstance(stats, StatsBase) and stats.has_throughputs:
  770. # Convert path to tuple for consistent key handling
  771. key = force_tuple(path)
  772. # Add "_throughput" to the last key in the path
  773. key = key[:-1] + (key[-1] + "_throughput",)
  774. # Set the throughput value in the nested structure
  775. _dict = throughputs
  776. for k in key[:-1]:
  777. if k not in _dict:
  778. _dict[k] = {}
  779. _dict = _dict[k]
  780. _dict[key[-1]] = stats.throughputs
  781. tree.map_structure_with_path(_map, self.stats)
  782. return throughputs if throughputs else default
  783. def compile(self) -> Dict:
  784. """Compiles all current values and throughputs into a single dictionary.
  785. This method combines the results of all stats and throughputs into a single
  786. dictionary, with throughput values having a "_throughput" suffix. This is useful
  787. for getting a complete snapshot of all metrics and their throughputs in one call.
  788. Returns:
  789. A nested dictionary containing both the current values and throughputs for all
  790. metrics. The structure matches self.stats, with throughput values having
  791. "_throughput" suffix in their keys.
  792. """
  793. # Get all throughputs
  794. throughputs = self._get_throughputs()
  795. # Get all current values
  796. values = self.reduce(compile=True)
  797. deep_update(values, throughputs or {}, new_keys_allowed=True)
  798. def traverse_dict(d):
  799. if isinstance(d, dict):
  800. new_dict = {}
  801. for key, value in d.items():
  802. new_dict[key] = traverse_dict(value)
  803. return new_dict
  804. elif isinstance(d, list):
  805. if len(d) == 1:
  806. return d[0]
  807. # If value is a longer list, we should just return the list because there is no reduction method applied
  808. return d
  809. else:
  810. # If the value is not a list, it is a single value and we can yield it
  811. return d
  812. return traverse_dict(values)
  813. @Deprecated(
  814. new="",
  815. help="Tensor mode is not required anymore.",
  816. error=False,
  817. )
  818. def activate_tensor_mode(self):
  819. pass
  820. @Deprecated(
  821. new="",
  822. help="Tensor mode is not required anymore.",
  823. error=False,
  824. )
  825. def deactivate_tensor_mode(self):
  826. pass
  827. class _DummyRLock:
  828. def acquire(self, blocking=True, timeout=-1):
  829. return True
  830. def release(self):
  831. pass
  832. def __enter__(self):
  833. return self
  834. def __exit__(self, exc_type, exc_value, traceback):
  835. pass