legacy_stats.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041
  1. import copy
  2. import heapq
  3. import threading
  4. import time
  5. import uuid
  6. from collections import defaultdict, deque
  7. from typing import Any, Dict, List, Optional, Tuple, Union
  8. import numpy as np
  9. from ray._common.deprecation import Deprecated
  10. from ray.rllib.utils import force_list
  11. from ray.rllib.utils.framework import try_import_torch
  12. from ray.rllib.utils.numpy import convert_to_numpy
  13. from ray.util.annotations import DeveloperAPI
  14. torch, _ = try_import_torch()
  15. @Deprecated(new="rllib.utils.metrics.stats", error=False)
  16. @DeveloperAPI
  17. class Stats:
  18. """A container class holding a number of values and executing reductions over them.
  19. The individual values in a Stats object may be of any type, for example python int
  20. or float, numpy arrays, or more complex structured (tuple, dict) and are stored in
  21. a list under `self.values`. This class is not meant to be interfaced with directly
  22. from application code. Instead, use `MetricsLogger` to log and manipulate Stats.
  23. Stats can be used to store metrics of the same type over time, for example a loss
  24. or a learning rate, and to reduce all stored values applying a certain reduction
  25. mechanism (for example "mean" or "sum").
  26. Available reduction mechanisms are:
  27. - "mean" using EMA with a configurable EMA coefficient.
  28. - "mean" using a sliding window (over the last n stored values).
  29. - "max/min" with an optional sliding window (over the last n stored values).
  30. - "sum" with an optional sliding window (over the last n stored values).
  31. - None: Simply store all logged values to an ever-growing list.
  32. Through the `reduce()` API, one of the above-mentioned reduction mechanisms will
  33. be executed on `self.values`.
  34. """
  def __init__(
      self,
      init_values: Optional[Any] = None,
      reduce: Optional[str] = "mean",
      percentiles: Union[List[int], bool] = False,
      reduce_per_index_on_aggregate: bool = False,
      window: Optional[Union[int, float]] = None,
      ema_coeff: Optional[float] = None,
      clear_on_reduce: bool = False,
      throughput: Union[bool, float] = False,
      throughput_ema_coeff: Optional[float] = None,
  ):
      """Initializes a Stats instance.

      Args:
          init_values: Optional initial values to be placed into `self.values`.
              If None, `self.values` will start empty. If `percentiles` is used,
              provided values must be ordered.
          reduce: The name of the reduce method to be used. Allowed are "mean",
              "min", "max", and "sum" (the validation also accepts
              "percentiles"). Use None to apply no reduction method (leave
              `self.values` as-is when reducing, except for shortening it to
              `window`). If `reduce` is None and there is no finite `window`,
              `clear_on_reduce` must be True (otherwise a ValueError is raised),
              so the values list cannot grow infinitely.
          percentiles: If reduce is `None`, we can compute the percentiles of
              the values list given by `percentiles`. Defaults to
              [0, 50, 75, 90, 95, 99, 100] if set to True. When using
              percentiles, a finite window must be provided. This window should
              be chosen carefully. RLlib computes exact percentiles and the
              computational complexity is O(m*n*log(n/m)) where n is the window
              size and m is the number of parallel metrics loggers involved (for
              example, m EnvRunners). To be safe, choose a window < 1M and less
              than 1000 Stats objects to aggregate. See #52963 for more details.
          reduce_per_index_on_aggregate: If True, when merging Stats objects, we
              reduce incoming values per index such that the new value at index
              `n` will be the reduced value of all incoming values at index `n`.
              If False, when reducing `n` Stats, the first `n` merged values will
              be the reduced value of all incoming values at index `0`, the next
              `n` merged values will be the reduced values of all incoming
              values at index `1`, etc. Only allowed for mean/sum/min/max.
          window: An optional window size to reduce over. If not None, the
              reduction is only applied to the most recent `window` items, and -
              after reduction - the values list is shortened to hold at most
              `window` items (the most recent ones). Must be None if `ema_coeff`
              is not None. If `window` and `ema_coeff` are both None (or
              `window` is inf) and `reduce` is "mean", an EMA with
              `ema_coeff=0.01` is used.
          ema_coeff: An optional EMA coefficient to use if reduce is "mean" and
              no `window` is provided. If both `window` and `ema_coeff` are
              provided, an error is raised. Also, if `ema_coeff` is provided,
              `reduce` must be "mean". The reduction formula for EMA is:
              EMA(t1) = (1.0 - ema_coeff) * EMA(t0) + ema_coeff * new_value
          clear_on_reduce: If True, the Stats object will reset its entire
              values list to an empty one after `self.reduce()` is called.
              However, it will then return from the `self.reduce()` call a new
              Stats object with the properly reduced (not completely emptied)
              new values. Setting this to True is useful for cases, in which the
              internal values list would otherwise grow indefinitely, for
              example if reduce is None and there is no `window` provided.
          throughput: If True, track a throughput estimate together with this
              Stats. This is intended for `reduce=sum` and
              `clear_on_reduce=False` metrics (aka. "lifetime counts"); this is
              not enforced here. The throughput is obtained through
              `Stats.throughput`. If a (non-bool) number, track throughput and
              also seed the throughput estimate with the given value.
          throughput_ema_coeff: An optional EMA coefficient for the internal
              throughput Stats. Only used if throughput tracking is enabled
              (None falls back to the "mean" default of 0.01).

      Raises:
          ValueError: On any invalid combination of the arguments above.
      """
      # Thus far, we only support mean, max, min, and sum.
      if reduce not in [None, "mean", "min", "max", "sum", "percentiles"]:
          raise ValueError(
              "`reduce` must be one of `mean|min|max|sum|percentiles` or None!"
          )
      # At most one of `window` and `ema_coeff` may be given.
      if window is not None and ema_coeff is not None:
          raise ValueError("Only one of `window` or `ema_coeff` can be specified!")
      # If `ema_coeff` is provided, `reduce` must be "mean".
      if ema_coeff is not None and reduce != "mean":
          raise ValueError(
              "`ema_coeff` arg only allowed (not None) when `reduce=mean`!"
          )

      if percentiles is not False:
          if reduce is not None:
              raise ValueError(
                  "`reduce` must be `None` when `percentiles` is not `False`!"
              )
          if window in (None, float("inf")):
              raise ValueError(
                  "A window must be specified when reduce is 'percentiles'!"
              )
          if reduce_per_index_on_aggregate is not False:
              raise ValueError(
                  f"`reduce_per_index_on_aggregate` ({reduce_per_index_on_aggregate})"
                  f" must be `False` when `percentiles` is not `False`!"
              )
          if percentiles is True:
              percentiles = [0, 50, 75, 90, 95, 99, 100]
          else:
              if type(percentiles) not in (bool, list):
                  raise ValueError("`percentiles` must be a list or bool!")
              if isinstance(percentiles, list):
                  if not all(isinstance(p, (int, float)) for p in percentiles):
                      raise ValueError(
                          "`percentiles` must contain only ints or floats!"
                      )
                  if not all(0 <= p <= 100 for p in percentiles):
                      raise ValueError(
                          "`percentiles` must contain only values between 0 and 100!"
                      )
      self._percentiles = percentiles

      # `None` and `inf` both mean "no finite window".
      self._inf_window = window in [None, float("inf")]
      # With an infinite window and no reduction, the values list would grow
      # without bound unless it is cleared on every `reduce()`.
      if self._inf_window and not clear_on_reduce and reduce is None:
          raise ValueError(
              "When using an infinite window without reduction, `clear_on_reduce` must "
              "be set to True!"
          )

      # If reduce=mean AND window=ema_coeff=None, we use EMA by default with a
      # coeff of 0.01 (we do NOT support infinite window sizes for mean as that
      # would mean to keep data in the cache forever).
      if reduce == "mean" and self._inf_window and ema_coeff is None:
          ema_coeff = 0.01

      self._reduce_method = reduce
      self._window = window
      self._ema_coeff = ema_coeff

      if (
          self._reduce_method not in ["mean", "sum", "min", "max"]
          and reduce_per_index_on_aggregate
      ):
          raise ValueError(
              "reduce_per_index_on_aggregate is only supported for mean, sum, min, and max reduction!"
          )
      self._reduce_per_index_on_aggregate = reduce_per_index_on_aggregate

      # Timing functionality (keep start times per thread).
      self._start_times = defaultdict(lambda: None)

      # Simply store this flag for the user of this class.
      self._clear_on_reduce = clear_on_reduce
      self._has_returned_zero = False

      # On each `.reduce()` call (with a reduce method), the result is stored in
      # `self._last_reduced`.
      self._last_reduced = [np.nan]
      # The ID of this Stats instance.
      self.id_ = str(uuid.uuid4())
      self._prev_merge_values = defaultdict(int)

      self._throughput_ema_coeff = throughput_ema_coeff
      self._throughput_stats = None
      # -1 means "no timestamp yet": the first throughput measurement then only
      # records a timestamp. Always set, so the attribute exists on every Stats.
      self._last_throughput_measure_time = -1
      if throughput is not False:
          # `bool` is a subclass of `int` in Python, so exclude it explicitly:
          # only a real number seeds the throughput estimate.
          seed_value = isinstance(throughput, (int, float)) and not isinstance(
              throughput, bool
          )
          self._throughput_stats = Stats(
              init_values=[throughput] if seed_value else None,
              reduce="mean",
              ema_coeff=throughput_ema_coeff,
              window=None,
              clear_on_reduce=False,
              throughput=False,
              throughput_ema_coeff=None,
          )
          if init_values is not None:
              # Values already exist: start the throughput clock now.
              self._last_throughput_measure_time = time.perf_counter()

      # The actual, underlying data in this Stats object.
      self.values: Union[List, deque] = None
      self._set_values(force_list(init_values))

      self._is_tensor = False
      # Track if new values were pushed since last reduce.
      self._has_new_values = init_values is not None
  216. def check_value(self, value: Any) -> None:
  217. # If we have a reduce method, value should always be a scalar
  218. # If we don't reduce, we can keep track of value as it is
  219. if self._reduce_method is not None:
  220. if isinstance(value, np.ndarray) and value.shape == ():
  221. return
  222. elif torch and torch.is_tensor(value):
  223. self._is_tensor = True
  224. if tuple(value.shape) == ():
  225. return
  226. elif type(value) not in (list, tuple, deque):
  227. return
  228. raise ValueError(
  229. f"Value ({value}) is required to be a scalar when using a reduce "
  230. "method!"
  231. )
  232. def push(self, value: Any) -> None:
  233. """Pushes a value into this Stats object.
  234. Args:
  235. value: The value to be pushed. Can be of any type.
  236. """
  237. self.check_value(value)
  238. # If throughput tracking is enabled, calculate it based on time between pushes
  239. if self.has_throughput:
  240. self._recompute_throughput(value)
  241. # Handle different reduction methods
  242. if self._window is not None:
  243. # For windowed operations, append to values and trim if needed
  244. self.values.append(value)
  245. if len(self.values) > self._window:
  246. self.values.popleft()
  247. else:
  248. # For non-windowed operations, use _reduced_values
  249. if len(self.values) == 0:
  250. self._set_values([value])
  251. else:
  252. self.values.append(value)
  253. _, values = self._reduced_values()
  254. self._set_values(values)
  255. # Mark that we have new values
  256. self._has_new_values = True
  257. def __enter__(self) -> "Stats":
  258. """Called when entering a context (with which users can measure a time delta).
  259. Returns:
  260. This Stats instance (self), unless another thread has already entered (and
  261. not exited yet), in which case a copy of `self` is returned. This way, the
  262. second thread(s) cannot mess with the original Stat's (self) time-measuring.
  263. This also means that only the first thread to __enter__ actually logs into
  264. `self` and the following threads' measurements are discarded (logged into
  265. a non-referenced shim-Stats object, which will simply be garbage collected).
  266. """
  267. # In case another thread already is measuring this Stats (timing), simply ignore
  268. # the "enter request" and return a clone of `self`.
  269. thread_id = threading.get_ident()
  270. self._start_times[thread_id] = time.perf_counter()
  271. return self
  272. def __exit__(self, exc_type, exc_value, tb) -> None:
  273. """Called when exiting a context (with which users can measure a time delta)."""
  274. thread_id = threading.get_ident()
  275. assert self._start_times[thread_id] is not None
  276. time_delta_s = time.perf_counter() - self._start_times[thread_id]
  277. self.push(time_delta_s)
  278. del self._start_times[thread_id]
  279. def peek(self, compile: bool = True) -> Union[Any, List[Any]]:
  280. """Returns the result of reducing the internal values list.
  281. Note that this method does NOT alter the internal values list in this process.
  282. Thus, users can call this method to get an accurate look at the reduced value(s)
  283. given the current internal values list.
  284. Args:
  285. compile: If True, the result is compiled into a single value if possible.
  286. Returns:
  287. The result of reducing the internal values list.
  288. """
  289. if self._has_new_values or (not compile and not self._inf_window):
  290. reduced_value, reduced_values = self._reduced_values()
  291. if not compile and not self._inf_window:
  292. return reduced_values
  293. if compile and self._reduce_method:
  294. return reduced_value[0]
  295. if compile and self._percentiles is not False:
  296. return compute_percentiles(reduced_values, self._percentiles)
  297. return reduced_value
  298. else:
  299. return_value = self._last_reduced
  300. if compile:
  301. # We don't need to check for self._reduce_method or percentiles here
  302. # because we only store the reduced value if there is a reduce method.
  303. return_value = return_value[0]
  304. return return_value
  305. @property
  306. def throughput(self) -> float:
  307. """Returns the current throughput estimate per second.
  308. Raises:
  309. ValueError: If throughput tracking is not enabled for this Stats object.
  310. Returns:
  311. The current throughput estimate per second.
  312. """
  313. if not self.has_throughput:
  314. raise ValueError("Throughput tracking is not enabled for this Stats object")
  315. # We can always return the first value here because throughput is a single value
  316. return self._throughput_stats.peek()
  317. @property
  318. def has_throughput(self) -> bool:
  319. """Returns whether this Stats object tracks throughput.
  320. Returns:
  321. True if this Stats object has throughput tracking enabled, False otherwise.
  322. """
  323. return self._throughput_stats is not None
  324. def reduce(self, compile: bool = True) -> Union[Any, List[Any]]:
  325. """Reduces the internal values list according to the constructor settings.
  326. Thereby, the internal values list is changed (note that this is different from
  327. `peek()`, where the internal list is NOT changed). See the docstring of this
  328. class for details on the reduction logic applied to the values list, based on
  329. the constructor settings, such as `window`, `reduce`, etc..
  330. Args:
  331. compile: If True, the result is compiled into a single value if possible.
  332. If it is not possible, the result is a list of values.
  333. If False, the result is a list of one or more values.
  334. Returns:
  335. The reduced value (can be of any type, depending on the input values and
  336. reduction method).
  337. """
  338. len_before_reduce = len(self)
  339. if self._has_new_values:
  340. # Only calculate and update history if there were new values pushed since
  341. # last reduce
  342. reduced, reduced_internal_values_list = self._reduced_values()
  343. # `clear_on_reduce` -> Clear the values list.
  344. if self._clear_on_reduce:
  345. self._set_values([])
  346. else:
  347. self._set_values(reduced_internal_values_list)
  348. else:
  349. reduced_internal_values_list = None
  350. reduced = self._last_reduced
  351. reduced = self._numpy_if_necessary(reduced)
  352. # Shift historic reduced valued by one in our reduce_history.
  353. if self._reduce_method is not None:
  354. # It only makes sense to extend the history if we are reducing to a single
  355. # value. We need to make a copy here because the new_values_list is a
  356. # reference to the internal values list
  357. self._last_reduced = force_list(reduced.copy())
  358. else:
  359. # If there is a window and no reduce method, we don't want to use the reduce
  360. # history to return reduced values in other methods
  361. self._has_new_values = True
  362. if compile and self._reduce_method is not None:
  363. assert (
  364. len(reduced) == 1
  365. ), f"Reduced values list must contain exactly one value, found {reduced}"
  366. reduced = reduced[0]
  367. if not compile and not self._inf_window:
  368. if reduced_internal_values_list is None:
  369. _, reduced_internal_values_list = self._reduced_values()
  370. return_values = self._numpy_if_necessary(
  371. reduced_internal_values_list
  372. ).copy()
  373. elif compile and self._percentiles is not False:
  374. if reduced_internal_values_list is None:
  375. _, reduced_internal_values_list = self._reduced_values()
  376. return_values = compute_percentiles(
  377. reduced_internal_values_list, self._percentiles
  378. )
  379. else:
  380. return_values = reduced
  381. if compile:
  382. return return_values
  383. else:
  384. if len_before_reduce == 0:
  385. # return_values will be be 0 if we reduce a sum over zero elements
  386. # But we don't want to create such a zero out of nothing for our new
  387. # Stats object that we return here
  388. return Stats.similar_to(self)
  389. return Stats.similar_to(self, init_values=return_values)
  390. def merge_on_time_axis(self, other: "Stats") -> None:
  391. """Merges another Stats object's values into this one along the time axis.
  392. Args:
  393. other: The other Stats object to merge values from.
  394. """
  395. self.values.extend(other.values)
  396. # Mark that we have new values since we modified the values list
  397. self._has_new_values = True
  398. def merge_in_parallel(self, *others: "Stats") -> None:
  399. """Merges all internal values of `others` into `self`'s internal values list.
  400. Thereby, the newly incoming values of `others` are treated equally with respect
  401. to each other as well as with respect to the internal values of self.
  402. Use this method to merge other `Stats` objects, which resulted from some
  403. parallelly executed components, into this one. For example: n Learner workers
  404. all returning a loss value in the form of `{"total_loss": [some value]}`.
  405. The following examples demonstrate the parallel merging logic for different
  406. reduce- and window settings:
  407. Args:
  408. others: One or more other Stats objects that need to be parallely merged
  409. into `self, meaning with equal weighting as the existing values in
  410. `self`.
  411. """
  412. win = self._window or float("inf")
  413. # If any of the value lists have a length of 0 or if there is only one value and
  414. # it is nan, we skip
  415. stats_to_merge = [
  416. s
  417. for s in [self, *others]
  418. if not (
  419. len(s) == 0
  420. or (
  421. len(s) == 1 and np.all(np.isnan(self._numpy_if_necessary(s.values)))
  422. )
  423. )
  424. ]
  425. # If there is only one stat to merge, and it is the same as self, return.
  426. if len(stats_to_merge) == 0:
  427. # If none of the stats have values, return.
  428. return
  429. elif len(stats_to_merge) == 1:
  430. if stats_to_merge[0] == self:
  431. # If no incoming stats have values, return.
  432. return
  433. else:
  434. # If there is only one stat with values, and it's incoming, copy its
  435. # values.
  436. self.values = stats_to_merge[0].values
  437. return
  438. # Take turns stepping through `self` and `*others` values, thereby moving
  439. # backwards from last index to beginning and will up the resulting values list.
  440. # Stop as soon as we reach the window size.
  441. new_values = []
  442. tmp_values = []
  443. if self._percentiles is not False:
  444. # Use heapq to sort values (assumes that the values are already sorted)
  445. # and then pick the correct percentiles
  446. lists_to_merge = [list(self.values), *[list(o.values) for o in others]]
  447. merged = list(heapq.merge(*lists_to_merge))
  448. self._set_values(merged)
  449. else:
  450. # Loop from index=-1 backward to index=start until our new_values list has
  451. # at least a len of `win`.
  452. for i in range(1, max(map(len, stats_to_merge)) + 1):
  453. # Per index, loop through all involved stats, including `self` and add
  454. # to `tmp_values`.
  455. for stats in stats_to_merge:
  456. if len(stats) < i:
  457. continue
  458. tmp_values.append(stats.values[-i])
  459. # Now reduce across `tmp_values` based on the reduce-settings of this
  460. # Stats.
  461. if self._reduce_per_index_on_aggregate:
  462. n_values = 1
  463. else:
  464. n_values = len(tmp_values)
  465. if self._ema_coeff is not None:
  466. new_values.extend([np.nanmean(tmp_values)] * n_values)
  467. elif self._reduce_method is None:
  468. new_values.extend(tmp_values)
  469. elif self._reduce_method == "sum":
  470. # We add [sum(tmp_values) / n_values] * n_values to the new values
  471. # list instead of tmp_values, because every incoming element should
  472. # have the same weight.
  473. added_sum = self._reduced_values(values=tmp_values)[0][0]
  474. new_values.extend([added_sum / n_values] * n_values)
  475. if self.has_throughput:
  476. self._recompute_throughput(added_sum)
  477. else:
  478. new_values.extend(
  479. self._reduced_values(values=tmp_values)[0] * n_values
  480. )
  481. tmp_values.clear()
  482. if len(new_values) >= win:
  483. new_values = new_values[:win]
  484. break
  485. self._set_values(list(reversed(new_values)))
  486. # Mark that we have new values since we modified the values list
  487. self._has_new_values = True
  488. def clear_throughput(self) -> None:
  489. """Clears the throughput Stats, if applicable and `self` has throughput.
  490. Also resets `self._last_throughput_measure_time` to -1 such that the Stats
  491. object has to create a new timestamp first, before measuring any new throughput
  492. values.
  493. """
  494. if self.has_throughput:
  495. self._throughput_stats._set_values([])
  496. self._last_throughput_measure_time = -1
  497. def _recompute_throughput(self, value) -> None:
  498. """Recomputes the current throughput value of this Stats instance."""
  499. # Make sure this Stats object does measure throughput.
  500. assert self.has_throughput
  501. # Take the current time stamp.
  502. current_time = time.perf_counter()
  503. # Check, whether we have a previous timestamp (non -1).
  504. if self._last_throughput_measure_time >= 0:
  505. # Compute the time delta.
  506. time_diff = current_time - self._last_throughput_measure_time
  507. # Avoid divisions by zero.
  508. if time_diff > 0:
  509. # Push new throughput value into our throughput stats object.
  510. self._throughput_stats.push(value / time_diff)
  511. # Update the time stamp of the most recent throughput computation (this one).
  512. self._last_throughput_measure_time = current_time
  513. @staticmethod
  514. def _numpy_if_necessary(values):
  515. # Torch tensor handling. Convert to CPU/numpy first.
  516. if torch and len(values) > 0 and torch.is_tensor(values[0]):
  517. # Convert all tensors to numpy values.
  518. values = [v.cpu().numpy() for v in values]
  519. return values
  520. def __len__(self) -> int:
  521. """Returns the length of the internal values list."""
  522. return len(self.values)
  523. def __repr__(self) -> str:
  524. win_or_ema = (
  525. f"; win={self._window}"
  526. if self._window
  527. else f"; ema={self._ema_coeff}"
  528. if self._ema_coeff
  529. else ""
  530. )
  531. return (
  532. f"Stats({self.peek()}; len={len(self)}; "
  533. f"reduce={self._reduce_method}{win_or_ema})"
  534. )
  535. def __int__(self):
  536. if self._reduce_method is None:
  537. raise ValueError(
  538. "Cannot convert Stats object with reduce method `None` to int because "
  539. "it can not be reduced to a single value."
  540. )
  541. else:
  542. return int(self.peek())
  543. def __float__(self):
  544. if self._reduce_method is None:
  545. raise ValueError(
  546. "Cannot convert Stats object with reduce method `None` to float "
  547. "because it can not be reduced to a single value."
  548. )
  549. else:
  550. return float(self.peek())
  551. def __eq__(self, other):
  552. if self._reduce_method is None:
  553. self._comp_error("__eq__")
  554. else:
  555. return float(self) == float(other)
  556. def __le__(self, other):
  557. if self._reduce_method is None:
  558. self._comp_error("__le__")
  559. else:
  560. return float(self) <= float(other)
  561. def __ge__(self, other):
  562. if self._reduce_method is None:
  563. self._comp_error("__ge__")
  564. else:
  565. return float(self) >= float(other)
  566. def __lt__(self, other):
  567. if self._reduce_method is None:
  568. self._comp_error("__lt__")
  569. else:
  570. return float(self) < float(other)
  571. def __gt__(self, other):
  572. if self._reduce_method is None:
  573. self._comp_error("__gt__")
  574. else:
  575. return float(self) > float(other)
  576. def __add__(self, other):
  577. if self._reduce_method is None:
  578. self._comp_error("__add__")
  579. else:
  580. return float(self) + float(other)
  581. def __sub__(self, other):
  582. if self._reduce_method is None:
  583. self._comp_error("__sub__")
  584. else:
  585. return float(self) - float(other)
  586. def __mul__(self, other):
  587. if self._reduce_method is None:
  588. self._comp_error("__mul__")
  589. else:
  590. return float(self) * float(other)
  591. def __format__(self, fmt):
  592. if self._reduce_method is None:
  593. raise ValueError(
  594. "Cannot format Stats object with reduce method `None` because it can "
  595. "not be reduced to a single value."
  596. )
  597. else:
  598. return f"{float(self):{fmt}}"
  599. def _comp_error(self, comp):
  600. raise ValueError(
  601. f"Cannot {comp} Stats object with reduce method `None` to other "
  602. "because it can not be reduced to a single value."
  603. )
  def get_state(self) -> Dict[str, Any]:
      """Returns a state dict that `Stats.from_state` can read back.

      `values` is passed through `convert_to_numpy` so that no tensors are
      returned. If throughput is tracked, the internal throughput Stats' own
      state is nested under "throughput_stats".
      """
      state = dict(
          values=convert_to_numpy(self.values),
          reduce=self._reduce_method,
          percentiles=self._percentiles,
          reduce_per_index_on_aggregate=self._reduce_per_index_on_aggregate,
          window=self._window,
          ema_coeff=self._ema_coeff,
          clear_on_reduce=self._clear_on_reduce,
          _last_reduced=self._last_reduced,
          _is_tensor=self._is_tensor,
      )
      if self._throughput_stats is None:
          return state
      state["throughput_stats"] = self._throughput_stats.get_state()
      return state
  @staticmethod
  def from_state(state: Dict[str, Any]) -> "Stats":
      """Creates a Stats object from a state dict (as produced by `get_state`).

      Also accepts older checkpoint formats: a boolean/float "_throughput" key
      instead of "throughput_stats", and a "_hist" list instead of
      "_last_reduced".

      Args:
          state: The state dict. It is not modified.

      Returns:
          The restored Stats object.
      """
      # If `values` could contain tensors, don't reinstate them (b/c we don't know
      # whether we are on a supported device).
      values = state["values"]
      if state.get("_is_tensor", False):
          values = []

      # Settings shared by all checkpoint formats. `reduce_per_index_on_aggregate`
      # is restored in every branch, since `get_state` always writes it.
      settings = dict(
          reduce=state["reduce"],
          percentiles=state.get("percentiles", False),
          reduce_per_index_on_aggregate=state.get(
              "reduce_per_index_on_aggregate", False
          ),
          window=state["window"],
          ema_coeff=state["ema_coeff"],
          clear_on_reduce=state["clear_on_reduce"],
      )

      if "throughput_stats" in state:
          throughput_stats = Stats.from_state(state["throughput_stats"])
          stats = Stats(
              values,
              throughput=throughput_stats.peek(),
              throughput_ema_coeff=throughput_stats._ema_coeff,
              **settings,
          )
      elif state.get("_throughput", False):
          # Older checkpoints have a _throughput key that is boolean or
          # a float (throughput value). They don't have a throughput_ema_coeff
          # so we use a default of 0.05.
          # TODO(Artur): Remove this after a few Ray releases.
          stats = Stats(
              values,
              throughput=state["_throughput"],
              throughput_ema_coeff=0.05,
              **settings,
          )
      else:
          stats = Stats(
              values,
              throughput=False,
              throughput_ema_coeff=None,
              **settings,
          )

      if "_hist" in state:
          # Compatibility to old checkpoints where a reduce sometimes resulted in
          # a single value instead of a list, such that the history would be a
          # list of scalars instead of a list of lists. Only the last entry is
          # needed; the caller's dict is left untouched.
          # TODO(Artur): Remove this after a few Ray releases.
          hist = state["_hist"]
          if not hist:
              stats._last_reduced = [np.nan]
          elif isinstance(hist[0], list):
              stats._last_reduced = hist[-1]
          else:
              stats._last_reduced = [hist[-1]]
      else:
          stats._last_reduced = state.get("_last_reduced", [np.nan])
      return stats
  679. @staticmethod
  680. def similar_to(
  681. other: "Stats",
  682. init_values: Optional[Any] = None,
  683. ) -> "Stats":
  684. """Returns a new Stats object that's similar to `other`.
  685. "Similar" here means it has the exact same settings (reduce, window, ema_coeff,
  686. etc..). The initial values of the returned `Stats` are empty by default, but
  687. can be set as well.
  688. Args:
  689. other: The other Stats object to return a similar new Stats equivalent for.
  690. init_value: The initial value to already push into the returned Stats.
  691. Returns:
  692. A new Stats object similar to `other`, with the exact same settings and
  693. maybe a custom initial value (if provided; otherwise empty).
  694. """
  695. stats = Stats(
  696. init_values=init_values,
  697. reduce=other._reduce_method,
  698. percentiles=other._percentiles,
  699. reduce_per_index_on_aggregate=other._reduce_per_index_on_aggregate,
  700. window=other._window,
  701. ema_coeff=other._ema_coeff,
  702. clear_on_reduce=other._clear_on_reduce,
  703. throughput=other._throughput_stats.peek()
  704. if other.has_throughput
  705. else False,
  706. throughput_ema_coeff=other._throughput_ema_coeff,
  707. )
  708. stats.id_ = other.id_
  709. stats._last_reduced = other._last_reduced
  710. return stats
  711. def _set_values(self, new_values):
  712. # For stats with window, use a deque with maxlen=window.
  713. # This way, we never store more values than absolutely necessary.
  714. if not self._inf_window:
  715. self.values = deque(new_values, maxlen=self._window)
  716. # For infinite windows, use `new_values` as-is (a list).
  717. else:
  718. self.values = new_values
  719. self._has_new_values = True
  720. def _reduced_values(self, values=None) -> Tuple[Any, Any]:
  721. """Runs a non-committed reduction procedure on given values (or `self.values`).
  722. Note that this method does NOT alter any state of `self` or the possibly
  723. provided list of `values`. It only returns new values as they should be
  724. adopted after a possible, actual reduction step.
  725. Args:
  726. values: The list of values to reduce. If not None, use `self.values`
  727. Returns:
  728. A tuple containing 1) the reduced values and 2) the new internal values list
  729. to be used. If there is no reduciton method, the reduced values will be the same as the values.
  730. """
  731. values = values if values is not None else self.values
  732. # No reduction method. Return list as-is OR reduce list to len=window.
  733. if self._reduce_method is None:
  734. if self._percentiles is not False:
  735. # Sort values
  736. values = list(values)
  737. # (Artur): Numpy can sort faster than Python's built-in sort for large lists. Howoever, if we convert to an array here
  738. # and then sort, this only slightly (<2x) improved the runtime of this method, even for an internal values list of 1M values.
  739. values.sort()
  740. return values, values
  741. # Special case: Internal values list is empty -> return NaN or 0.0 for sum.
  742. elif len(values) == 0:
  743. if self._reduce_method in ["min", "max", "mean"] or self._has_returned_zero:
  744. # We also return np.nan if we have returned zero before.
  745. # This helps with cases where stats are cleared on reduce, but we don't want to log 0's, except for the first time.
  746. return [np.nan], []
  747. else:
  748. return [0], []
  749. # Do EMA (always a "mean" reduction; possibly using a window).
  750. elif self._ema_coeff is not None:
  751. # Perform EMA reduction over all values in internal values list.
  752. mean_value = values[0]
  753. for v in values[1:]:
  754. mean_value = self._ema_coeff * v + (1.0 - self._ema_coeff) * mean_value
  755. if self._inf_window:
  756. return [mean_value], [mean_value]
  757. else:
  758. return [mean_value], values
  759. # Non-EMA reduction (possibly using a window).
  760. else:
  761. # Use the numpy/torch "nan"-prefix to ignore NaN's in our value lists.
  762. if torch and torch.is_tensor(values[0]):
  763. self._is_tensor = True
  764. # Only one item in the
  765. if len(values[0].shape) == 0:
  766. reduced = values[0]
  767. else:
  768. reduce_meth = getattr(torch, "nan" + self._reduce_method)
  769. reduce_in = torch.stack(list(values))
  770. if self._reduce_method == "mean":
  771. reduce_in = reduce_in.float()
  772. reduced = reduce_meth(reduce_in)
  773. else:
  774. reduce_meth = getattr(np, "nan" + self._reduce_method)
  775. if np.all(np.isnan(values)):
  776. # This avoids warnings for taking a mean of an empty array.
  777. reduced = np.nan
  778. else:
  779. reduced = reduce_meth(values)
  780. def safe_isnan(value):
  781. if torch and isinstance(value, torch.Tensor):
  782. return torch.isnan(value)
  783. return np.isnan(value)
  784. # Convert from numpy to primitive python types, if original `values` are
  785. # python types.
  786. if (
  787. not safe_isnan(reduced)
  788. and reduced.shape == ()
  789. and isinstance(values[0], (int, float))
  790. ):
  791. if reduced.dtype in [np.int32, np.int64, np.int8, np.int16]:
  792. reduced = int(reduced)
  793. else:
  794. reduced = float(reduced)
  795. # For window=None|inf (infinite window) and reduce != mean, we don't have to
  796. # keep any values, except the last (reduced) one.
  797. if self._inf_window and self._reduce_method != "mean":
  798. # TODO (sven): What if values are torch tensors? In this case, we
  799. # would have to do reduction using `torch` above (not numpy) and only
  800. # then return the python primitive AND put the reduced new torch
  801. # tensor in the new `self.values`.
  802. return [reduced], [reduced]
  803. else:
  804. # In all other cases, keep the values that were also used for the reduce
  805. # operation.
  806. return [reduced], values
  807. @DeveloperAPI
  808. def compute_percentiles(sorted_list, percentiles):
  809. """Compute percentiles from an already sorted list.
  810. Note that this will not raise an error if the list is not sorted to avoid overhead.
  811. Args:
  812. sorted_list: A list of numbers sorted in ascending order
  813. percentiles: A list of percentile values (0-100)
  814. Returns:
  815. A dictionary mapping percentile values to their corresponding data values
  816. """
  817. n = len(sorted_list)
  818. if n == 0:
  819. return {p: None for p in percentiles}
  820. results = {}
  821. for p in percentiles:
  822. index = (p / 100) * (n - 1)
  823. if index.is_integer():
  824. results[p] = sorted_list[int(index)]
  825. else:
  826. lower_index = int(index)
  827. upper_index = lower_index + 1
  828. weight = index - lower_index
  829. results[p] = (
  830. sorted_list[lower_index] * (1 - weight)
  831. + sorted_list[upper_index] * weight
  832. )
  833. return results
  834. @DeveloperAPI
  835. def merge_stats(base_stats: Optional[Stats], incoming_stats: List[Stats]) -> Stats:
  836. """Merges Stats objects.
  837. If `base_stats` is None, we use the first incoming Stats object as the new base Stats object.
  838. If `base_stats` is not None, we merge all incoming Stats objects into the base Stats object.
  839. Args:
  840. base_stats: The base Stats object to merge into.
  841. incoming_stats: The list of Stats objects to merge.
  842. Returns:
  843. The merged Stats object.
  844. """
  845. if base_stats is None:
  846. new_root_stats = True
  847. else:
  848. new_root_stats = False
  849. # Nothing to be merged
  850. if len(incoming_stats) == 0:
  851. return base_stats
  852. if new_root_stats:
  853. # We need to deepcopy here first because stats from incoming_stats may be altered in the future
  854. base_stats = copy.deepcopy(incoming_stats[0])
  855. base_stats.clear_throughput()
  856. # Note that we may take a mean of means here, which is not the same as a
  857. # mean of all values. In the future, we could implement a weighted mean
  858. # of means here by introducing a new Stats object that counts samples
  859. # for each mean Stats object.
  860. if len(incoming_stats) > 1:
  861. base_stats.merge_in_parallel(*incoming_stats[1:])
  862. if (
  863. base_stats._reduce_method == "sum"
  864. and base_stats._inf_window
  865. and base_stats._clear_on_reduce is False
  866. ):
  867. for stat in incoming_stats:
  868. base_stats._prev_merge_values[stat.id_] = stat.peek()
  869. elif len(incoming_stats) > 0:
  870. # Special case: `base_stats` is a lifetime sum (reduce=sum,
  871. # clear_on_reduce=False) -> We subtract the previous value (from 2
  872. # `reduce()` calls ago) from all to-be-merged stats, so we don't count
  873. # twice the older sum from before.
  874. # Also, for the merged, new throughput value, we need to find out what the
  875. # actual value-delta is between before the last reduce and the current one.
  876. added_sum = 0.0 # Used in `base_stats._recompute_throughput` if applicable.
  877. if (
  878. base_stats._reduce_method == "sum"
  879. and base_stats._inf_window
  880. and base_stats._clear_on_reduce is False
  881. ):
  882. for stat in incoming_stats:
  883. # Subtract "lifetime counts" from the Stat's values to not count
  884. # older "lifetime counts" more than once.
  885. prev_reduction = base_stats._prev_merge_values[stat.id_]
  886. new_reduction = stat.peek(compile=True)
  887. base_stats.values[-1] -= prev_reduction
  888. # Keep track of how many counts we actually gained (for throughput
  889. # recomputation).
  890. added_sum += new_reduction - prev_reduction
  891. base_stats._prev_merge_values[stat.id_] = new_reduction
  892. parallel_merged_stat = copy.deepcopy(incoming_stats[0])
  893. if len(incoming_stats) > 1:
  894. # There are more than one incoming parallel others -> Merge all of
  895. # them in parallel (equal importance).
  896. parallel_merged_stat.merge_in_parallel(*incoming_stats[1:])
  897. # Merge incoming Stats object into base Stats object on time axis
  898. # (giving incoming ones priority).
  899. if base_stats._reduce_method == "mean" and not base_stats._clear_on_reduce:
  900. # If we don't clear values, values that are not cleared would contribute
  901. # to the mean multiple times.
  902. base_stats._set_values(parallel_merged_stat.values.copy())
  903. else:
  904. base_stats.merge_on_time_axis(parallel_merged_stat)
  905. # Keep track of throughput through the sum of added counts.
  906. if base_stats.has_throughput:
  907. base_stats._recompute_throughput(added_sum)
  908. return base_stats