percentiles.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. from collections import deque
  2. from itertools import chain
  3. from typing import Any, Dict, List, Optional, Union
  4. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  5. from ray.rllib.utils.metrics.stats.base import StatsBase
  6. from ray.rllib.utils.metrics.stats.utils import batch_values_to_cpu, safe_isnan
  7. from ray.util.annotations import DeveloperAPI
  8. torch, _ = try_import_torch()
  9. _, tf, _ = try_import_tf()
  10. @DeveloperAPI
  11. class PercentilesStats(StatsBase):
  12. """A Stats object that tracks percentiles of a series of singular values (not vectors)."""
  13. stats_cls_identifier = "percentiles"
  14. def __init__(
  15. self,
  16. percentiles: Union[List[int], bool] = None,
  17. window: Optional[Union[int, float]] = None,
  18. *args,
  19. **kwargs,
  20. ):
  21. """Initializes a PercentilesStats instance.
  22. Percentiles are computed over the last `window` values across all parallel components.
  23. Example: If we have 10 parallel components, and each component tracks 1,000 values, we will track the last 10,000 values across all components.
  24. Be careful to not track too many values because computing percentiles is O(n*log(n)) where n is the window size.
  25. See https://github.com/ray-project/ray/pull/52963 for more details.
  26. Args:
  27. percentiles: The percentiles to track.
  28. If None, track the default percentiles [0, 50, 75, 90, 95, 99, 100].
  29. If a list, track the given percentiles.
  30. """
  31. super().__init__(*args, **kwargs)
  32. self._window = window
  33. self.values: Union[List[Any], deque[Any]] = []
  34. self._set_values([])
  35. if percentiles is None:
  36. # We compute a bunch of default percentiles because computing one is just as expensive as computing all of them.
  37. percentiles = [0, 50, 75, 90, 95, 99, 100]
  38. elif isinstance(percentiles, list):
  39. percentiles = percentiles
  40. else:
  41. raise ValueError("`percentiles` must be a list or None")
  42. self._percentiles = percentiles
  43. def get_state(self) -> Dict[str, Any]:
  44. state = super().get_state()
  45. state["values"] = self.values
  46. state["window"] = self._window
  47. state["percentiles"] = self._percentiles
  48. return state
  49. def set_state(self, state: Dict[str, Any]) -> None:
  50. super().set_state(state)
  51. self._set_values(state["values"])
  52. self._window = state["window"]
  53. self._percentiles = state["percentiles"]
  54. def _set_values(self, new_values):
  55. # For stats with window, use a deque with maxlen=window.
  56. # This way, we never store more values than absolutely necessary.
  57. if self._window and self.is_leaf:
  58. # Window always counts at leafs only (or non-root stats)
  59. self.values = deque(new_values, maxlen=self._window)
  60. # For infinite windows, use `new_values` as-is (a list).
  61. else:
  62. self.values = new_values
  63. def __len__(self) -> int:
  64. """Returns the length of the internal values list."""
  65. return len(self.values)
  66. def __float__(self):
  67. raise ValueError(
  68. "Cannot convert to float because percentiles are not reduced to a single value."
  69. )
  70. def __eq__(self, other):
  71. self._comp_error("__eq__")
  72. def __le__(self, other):
  73. self._comp_error("__le__")
  74. def __ge__(self, other):
  75. self._comp_error("__ge__")
  76. def __lt__(self, other):
  77. self._comp_error("__lt__")
  78. def __gt__(self, other):
  79. self._comp_error("__gt__")
  80. def __add__(self, other):
  81. self._comp_error("__add__")
  82. def __sub__(self, other):
  83. self._comp_error("__sub__")
  84. def __mul__(self, other):
  85. self._comp_error("__mul__")
  86. def _comp_error(self, comp):
  87. raise NotImplementedError()
  88. def __format__(self, fmt):
  89. raise ValueError(
  90. "Cannot format percentiles object because percentiles are not reduced to a single value."
  91. )
  92. def push(self, value: Any) -> None:
  93. """Pushes a value into this Stats object.
  94. Args:
  95. value: The value to be pushed. Can be of any type.
  96. PyTorch GPU tensors are kept on GPU until reduce() or peek().
  97. TensorFlow tensors are moved to CPU immediately.
  98. """
  99. # Convert TensorFlow tensors to CPU immediately, keep PyTorch tensors as-is
  100. if tf and tf.is_tensor(value):
  101. value = value.numpy()
  102. if safe_isnan(value):
  103. raise ValueError("NaN values are not allowed in PercentilesStats")
  104. if torch and isinstance(value, torch.Tensor):
  105. value = value.detach()
  106. self.values.append(value)
  107. def merge(self, incoming_stats: List["PercentilesStats"]) -> None:
  108. """Merges PercentilesStats objects.
  109. This method assumes that the incoming stats have the same percentiles and window size.
  110. It will append the incoming values to the existing values.
  111. Args:
  112. incoming_stats: The list of PercentilesStats objects to merge.
  113. Returns:
  114. None. The merge operation modifies self in place.
  115. """
  116. assert (
  117. not self.is_leaf
  118. ), "PercentilesStats should only be merged at aggregation stages (root or intermediate)"
  119. assert all(
  120. s._percentiles == self._percentiles for s in incoming_stats
  121. ), "All incoming PercentilesStats objects must have the same percentiles"
  122. assert all(
  123. s._window == self._window for s in incoming_stats
  124. ), "All incoming PercentilesStats objects must have the same window size"
  125. new_values = [s.values for s in incoming_stats]
  126. new_values = list(chain.from_iterable(new_values))
  127. all_values = list(self.values) + new_values
  128. self.values = all_values
  129. # Track merged values for latest_merged_only peek functionality
  130. if not self.is_leaf:
  131. # Store the values that were merged in this operation (from incoming_stats only)
  132. self.latest_merged = new_values
  133. def peek(
  134. self, compile: bool = True, latest_merged_only: bool = False
  135. ) -> Union[Any, List[Any]]:
  136. """Returns the result of reducing the internal values list.
  137. Note that this method does NOT alter the internal values list in this process.
  138. Thus, users can call this method to get an accurate look at the reduced value(s)
  139. given the current internal values list.
  140. Args:
  141. compile: If True, the result is compiled into the percentiles list.
  142. latest_merged_only: If True, only considers the latest merged values.
  143. This parameter only works on aggregation stats (root or intermediate nodes).
  144. When enabled, peek() will only use the values from the most recent merge operation.
  145. Returns:
  146. The result of reducing the internal values list on CPU.
  147. """
  148. # Check latest_merged_only validity
  149. if latest_merged_only and self.is_leaf:
  150. raise ValueError(
  151. "latest_merged_only can only be used on aggregation stats objects (is_leaf=False)."
  152. )
  153. # If latest_merged_only is True, use only the latest merged values
  154. if latest_merged_only:
  155. if self.latest_merged is None:
  156. # No merged values yet, return dict with None values
  157. if compile:
  158. return {p: None for p in self._percentiles}
  159. else:
  160. return []
  161. # Use only the latest merged values
  162. latest_merged = self.latest_merged
  163. values = batch_values_to_cpu(latest_merged)
  164. else:
  165. # Normal peek behavior
  166. values = batch_values_to_cpu(self.values)
  167. values.sort()
  168. if compile:
  169. return compute_percentiles(values, self._percentiles)
  170. return values
  171. def reduce(self, compile: bool = True) -> Union[Any, "PercentilesStats"]:
  172. """Reduces the internal values list.
  173. Args:
  174. compile: If True, the result is compiled into a single value if possible.
  175. Returns:
  176. The reduced value on CPU.
  177. """
  178. values = batch_values_to_cpu(self.values)
  179. values.sort()
  180. self._set_values([])
  181. if compile:
  182. return compute_percentiles(values, self._percentiles)
  183. return_stats = self.clone()
  184. return_stats.values = values
  185. return return_stats
  186. def __repr__(self) -> str:
  187. return (
  188. f"PercentilesStats({self.peek()}; window={self._window}; len={len(self)})"
  189. )
  190. @staticmethod
  191. def _get_init_args(stats_object=None, state=None) -> Dict[str, Any]:
  192. """Returns the initialization arguments for this Stats object."""
  193. super_args = StatsBase._get_init_args(stats_object=stats_object, state=state)
  194. if state is not None:
  195. return {
  196. **super_args,
  197. "percentiles": state["percentiles"],
  198. "window": state["window"],
  199. }
  200. elif stats_object is not None:
  201. return {
  202. **super_args,
  203. "percentiles": stats_object._percentiles,
  204. "window": stats_object._window,
  205. }
  206. else:
  207. raise ValueError("Either stats_object or state must be provided")
  208. @DeveloperAPI
  209. def compute_percentiles(sorted_list, percentiles):
  210. """Compute percentiles from an already sorted list.
  211. Note that this will not raise an error if the list is not sorted to avoid overhead.
  212. Args:
  213. sorted_list: A list of numbers sorted in ascending order
  214. percentiles: A list of percentile values (0-100)
  215. Returns:
  216. A dictionary mapping percentile values to their corresponding data values
  217. """
  218. n = len(sorted_list)
  219. if n == 0:
  220. return {p: None for p in percentiles}
  221. results = {}
  222. for p in percentiles:
  223. index = (p / 100) * (n - 1)
  224. if index.is_integer():
  225. results[p] = sorted_list[int(index)]
  226. else:
  227. lower_index = int(index)
  228. upper_index = lower_index + 1
  229. weight = index - lower_index
  230. results[p] = (
  231. sorted_list[lower_index] * (1 - weight)
  232. + sorted_list[upper_index] * weight
  233. )
  234. return results