series.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. from abc import ABCMeta
  2. from collections import deque
  3. from itertools import chain
  4. from typing import Any, Dict, List, Optional, Union
  5. import numpy as np
  6. from ray.rllib.utils.annotations import (
  7. OverrideToImplementCustomLogic_CallToSuperRecommended,
  8. )
  9. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  10. from ray.rllib.utils.metrics.stats.base import StatsBase
  11. from ray.rllib.utils.metrics.stats.utils import batch_values_to_cpu, single_value_to_cpu
  12. from ray.util.annotations import DeveloperAPI
  13. torch, _ = try_import_torch()
  14. _, tf, _ = try_import_tf()
  15. @DeveloperAPI
  16. class SeriesStats(StatsBase, metaclass=ABCMeta):
  17. """A base class for Stats that represent a series of singular values (not vectors)."""
  18. # Set by subclasses
  19. _np_reduce_fn = None
  20. # Set by subclasses
  21. _torch_reduce_fn = None
  22. def __init__(
  23. self,
  24. window: Optional[Union[int, float]] = None,
  25. *args,
  26. **kwargs,
  27. ):
  28. """Initializes a SeriesStats instance.
  29. Args:
  30. window: The window size to reduce over.
  31. """
  32. super().__init__(*args, **kwargs)
  33. self._window = window
  34. self.values: Union[List[Any], deque[Any]] = []
  35. self._set_values([])
  36. def get_state(self) -> Dict[str, Any]:
  37. state = super().get_state()
  38. state = {
  39. **state,
  40. "values": batch_values_to_cpu(self.values),
  41. "window": self._window,
  42. }
  43. return state
  44. def set_state(self, state: Dict[str, Any]) -> None:
  45. super().set_state(state)
  46. self._set_values(state["values"])
  47. self._window = state["window"]
  48. @OverrideToImplementCustomLogic_CallToSuperRecommended
  49. @staticmethod
  50. def _get_init_args(stats_object=None, state=None) -> Dict[str, Any]:
  51. super_args = StatsBase._get_init_args(stats_object=stats_object, state=state)
  52. if state is not None:
  53. return {
  54. **super_args,
  55. "window": state["window"],
  56. }
  57. elif stats_object is not None:
  58. return {
  59. **super_args,
  60. "window": stats_object._window,
  61. }
  62. else:
  63. raise ValueError("Either stats_object or state must be provided")
  64. def reduce(self, compile: bool = True) -> Union[Any, "SeriesStats"]:
  65. """Reduces the internal values list according to the constructor settings."""
  66. if self._window is None:
  67. if len(self.values) <= 1 or not compile:
  68. reduced_values = batch_values_to_cpu(self.values)
  69. else:
  70. reduced_values = self.window_reduce()
  71. else:
  72. reduced_values = self.window_reduce()
  73. self._set_values([])
  74. if compile:
  75. if len(reduced_values) == 0:
  76. return np.nan
  77. else:
  78. return reduced_values[0]
  79. return_stats = self.clone()
  80. return_stats.values = reduced_values
  81. return return_stats
  82. def __len__(self) -> int:
  83. """Returns the length of the internal values list."""
  84. return len(self.values)
  85. def _set_values(self, new_values):
  86. # For stats with window, use a deque with maxlen=window.
  87. # This way, we never store more values than absolutely necessary.
  88. if self._window and self.is_leaf:
  89. # Window always counts at leafs only (or non-root stats)
  90. self.values = deque(new_values, maxlen=self._window)
  91. # For infinite windows, use `new_values` as-is (a list).
  92. else:
  93. self.values = new_values
  94. def push(self, value: Any) -> None:
  95. """Pushes a value into this Stats object.
  96. Args:
  97. value: The value to be pushed. Can be of any type.
  98. PyTorch GPU tensors are kept on GPU until reduce() or peek().
  99. TensorFlow tensors are moved to CPU immediately.
  100. """
  101. # Convert TensorFlow tensors to CPU immediately, keep PyTorch tensors as-is
  102. if tf and tf.is_tensor(value):
  103. value = value.numpy()
  104. if torch and isinstance(value, torch.Tensor):
  105. value = value.detach()
  106. if self._window is None:
  107. if not self.values:
  108. self._set_values([value])
  109. else:
  110. self._set_values(self.running_reduce(self.values[0], value))
  111. else:
  112. # For windowed operations, append to values and trim if needed
  113. self.values.append(value)
  114. def merge(self, incoming_stats: List["SeriesStats"]) -> None:
  115. """Merges SeriesStats objects.
  116. Args:
  117. incoming_stats: The list of SeriesStats objects to merge.
  118. Returns:
  119. None. The merge operation modifies self in place.
  120. """
  121. assert (
  122. not self.is_leaf
  123. ), "SeriesStats should only be merged at aggregation stages (root or intermediate)"
  124. if len(incoming_stats) == 0:
  125. return
  126. all_items = [s.values for s in incoming_stats]
  127. all_items = list(chain.from_iterable(all_items))
  128. # Implicitly may convert internal to list.
  129. # That's ok because we don't want to evict items from the deque if we merge in this object's values.
  130. all_items = list(self.values) + list(all_items)
  131. self.values = all_items
  132. # Track merged values for latest_merged_only peek functionality
  133. if not self.is_leaf:
  134. # Store the values that were merged in this operation (from incoming_stats only)
  135. merged_values = list(
  136. chain.from_iterable([s.values for s in incoming_stats])
  137. )
  138. self.latest_merged = merged_values
  139. def peek(
  140. self, compile: bool = True, latest_merged_only: bool = False
  141. ) -> Union[Any, List[Any]]:
  142. """Returns the result of reducing the internal values list.
  143. Note that this method does NOT alter the internal values list.
  144. Args:
  145. compile: If True, the result is compiled into a single value if possible.
  146. latest_merged_only: If True, only considers the latest merged values.
  147. This parameter only works on aggregation stats (root or intermediate nodes, is_leaf=False).
  148. When enabled, peek() will only use the values from the most recent merge operation.
  149. Returns:
  150. The result of reducing the internal values list.
  151. """
  152. # If latest_merged_only is True, use look at the latest merged values
  153. if latest_merged_only:
  154. if self.is_leaf:
  155. raise ValueError(
  156. "latest_merged_only can only be used on aggregation stats objects "
  157. "(is_leaf=False)"
  158. )
  159. if self.latest_merged is None:
  160. # No merged values yet, return NaN or empty list
  161. if compile:
  162. return np.nan
  163. else:
  164. return []
  165. # Use only the latest merged values
  166. latest_merged = self.latest_merged
  167. if len(latest_merged) == 0:
  168. reduced_values = [np.nan]
  169. else:
  170. reduced_values = self.window_reduce(latest_merged)
  171. else:
  172. # Normal peek behavior
  173. if len(self.values) == 1:
  174. # Note that we can not check for window=None here because merged SeriesStats may have multiple values.
  175. reduced_values = self.values
  176. else:
  177. reduced_values = self.window_reduce()
  178. if compile:
  179. if len(reduced_values) == 0:
  180. return np.nan
  181. else:
  182. return reduced_values[0]
  183. else:
  184. return reduced_values
  185. def running_reduce(self, value_1, value_2) -> List[Any]:
  186. """Reduces two values through a reduce function.
  187. If values are PyTorch tensors, reduction happens on GPU.
  188. Result stays on GPU (or CPU if values were CPU).
  189. Args:
  190. value_1: The first value to reduce.
  191. value_2: The second value to reduce.
  192. Returns:
  193. The reduced value (may be GPU tensor).
  194. """
  195. # If values are torch tensors, reduce on GPU
  196. if (
  197. torch
  198. and isinstance(value_1, torch.Tensor)
  199. and hasattr(self, "_torch_reduce_fn")
  200. ):
  201. return [self._torch_reduce_fn(torch.stack([value_1, value_2]))]
  202. # Otherwise use numpy reduction
  203. return [self._np_reduce_fn([value_1, value_2])]
  204. def window_reduce(self, values=None) -> List[Any]:
  205. """Reduces the internal values list according to the constructor settings.
  206. If values are PyTorch GPU tensors, reduction happens on GPU and result
  207. is moved to CPU. Otherwise returns CPU value.
  208. Args:
  209. values: The list of values to reduce. If not None, use `self.values`
  210. Returns:
  211. The reduced value on CPU.
  212. """
  213. values = values if values is not None else self.values
  214. # Special case: Internal values list is empty -> return NaN
  215. if len(values) == 0:
  216. return [np.nan]
  217. # If values are torch tensors, reduce on GPU then move to CPU
  218. if (
  219. torch
  220. and isinstance(values[0], torch.Tensor)
  221. and hasattr(self, "_torch_reduce_fn")
  222. ):
  223. stacked = torch.stack(list(values))
  224. # Check for all NaN
  225. if torch.all(torch.isnan(stacked)):
  226. return [np.nan]
  227. result = self._torch_reduce_fn(stacked)
  228. return [single_value_to_cpu(result)]
  229. # Otherwise use numpy reduction on CPU values
  230. if np.all(np.isnan(values)):
  231. return [np.nan]
  232. else:
  233. return [self._np_reduce_fn(values)]