base.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. import threading
  2. import time
  3. from abc import ABCMeta, abstractmethod
  4. from collections import deque
  5. from typing import Any, Dict, List, Optional, Union
  6. from ray.rllib.utils.annotations import (
  7. OverrideToImplementCustomLogic,
  8. OverrideToImplementCustomLogic_CallToSuperRecommended,
  9. )
  10. from ray.util.annotations import DeveloperAPI
  @DeveloperAPI
  class StatsBase(metaclass=ABCMeta):
      """A base class for Stats.

      Stats are meant to be used to log values to and then aggregate them in a tree.
      Therefore, we log to stats in two different ways:

      - On a leaf component, we log values directly by pushing.
      - On a non-leaf component, we only aggregate incoming values.

      Additionally, we pay special respect to Stats that live at the root of the tree.
      These may have a different behaviour (example: a lifetime sum).

      Note the tight coupling between StatsBase and MetricsLogger.

      Comparison, arithmetic and formatting operators convert both operands with
      `float()`. For a Stats object, that is `peek(compile=True)` converted to float.
      Because `__eq__` is defined, Stats objects are unhashable.
      """

      # In order to restore from a checkpoint, we need to know the class of the Stats
      # object. Every concrete subclass must set this (asserted in `__init__`).
      stats_cls_identifier: Optional[str] = None

      # Defining `__eq__` already makes Python set `__hash__` to None implicitly;
      # spell it out so the unhashability is visible. Subclasses may still override.
      __hash__ = None

      def __init__(
          self,
          is_root: bool = False,
          is_leaf: bool = True,
      ):
          """Initializes a StatsBase object.

          Args:
              is_root: If True, the Stats object is a root stats object.
              is_leaf: If True, the Stats object is a leaf stats object.

          Note: A stats object can be both root and leaf at the same time.
          Note: A stats object can also be neither root nor leaf ("intermediate"
          stats that only aggregate from other stats but are not at the root).
          """
          self.is_root = is_root
          self.is_leaf = is_leaf
          # Start times of open `with` blocks (see `__enter__`/`__exit__`), keyed by
          # thread id so several threads can measure time deltas in parallel.
          self._start_times: Dict[int, float] = {}
          # For non-leaf stats (root or intermediate), track the latest merged values.
          # This is overwritten on each merge operation.
          if not self.is_leaf:
              self.latest_merged: Union[List[Any], Any] = None
          assert (
              self.stats_cls_identifier is not None
          ), "stats_cls_identifier must be set in the subclass"

      @property
      def has_throughputs(self) -> bool:
          """Returns True if the Stats object has throughput tracking enabled.

          Some Stats classes may have throughput tracking enabled, such as SumStats.
          This base implementation always returns False.
          """
          return False

      @OverrideToImplementCustomLogic
      def initialize_throughput_reference_time(self, time: float) -> None:
          """Sets the reference time for this Stats object.

          This is important because the component that tracks the time between
          reduce cycles is not Stats, but MetricsLogger.

          Args:
              time: The time to establish as the reference time for this Stats
                  object.

          Raises:
              ValueError: If `has_throughputs` is True. Stats classes that track
                  throughputs must override this method.
          """
          # NOTE: the `time` parameter shadows the `time` module inside this method
          # only; the name is kept so keyword callers keep working.
          if self.has_throughputs:
              raise ValueError(
                  "initialize_throughput_reference_time must be overridden for stats objects that have throughputs."
              )

      @abstractmethod
      def __len__(self) -> int:
          """Returns the length of the internal values list."""
          ...

      def __float__(self) -> float:
          """Returns `peek(compile=True)` converted to a float.

          Raises:
              ValueError: If the compiled value is a list, tuple or deque.
          """
          value = self.peek(compile=True)
          if isinstance(value, (list, tuple, deque)):
              raise ValueError(f"Can not convert {self} to float.")
          return float(value)

      def __int__(self) -> int:
          """Returns `peek(compile=True)` converted to an int.

          Raises:
              ValueError: If the compiled value is a list, tuple or deque.
          """
          value = self.peek(compile=True)
          if isinstance(value, (list, tuple, deque)):
              raise ValueError(f"Can not convert {self} to int.")
          return int(value)

      def __eq__(self, other):
          return float(self) == float(other)

      def __le__(self, other):
          return float(self) <= float(other)

      def __ge__(self, other):
          return float(self) >= float(other)

      def __lt__(self, other):
          return float(self) < float(other)

      def __gt__(self, other):
          return float(self) > float(other)

      def __add__(self, other):
          return float(self) + float(other)

      def __radd__(self, other):
          # Reflected form: makes `number + stats` work, and thereby
          # `sum([stats_a, stats_b])`, which starts from the int 0.
          return float(other) + float(self)

      def __sub__(self, other):
          return float(self) - float(other)

      def __rsub__(self, other):
          # Reflected form: `number - stats`.
          return float(other) - float(self)

      def __mul__(self, other):
          return float(self) * float(other)

      def __rmul__(self, other):
          # Reflected form: `number * stats`.
          return float(other) * float(self)

      def __format__(self, fmt):
          return f"{float(self):{fmt}}"

      def __enter__(self) -> "StatsBase":
          """Called when entering a context (with which users can measure a time delta).

          Records the current `time.perf_counter()` value for the calling thread.

          Returns:
              This stats instance.
          """
          thread_id = threading.get_ident()
          self._start_times[thread_id] = time.perf_counter()
          return self

      def __exit__(self, exc_type, exc_value, tb) -> None:
          """Called when exiting a context (with which users can measure a time delta).

          This pushes the time delta since `__enter__` (of the same thread) to this
          Stats object. Exceptions raised inside the `with` block are not
          suppressed.

          Raises:
              KeyError: If the calling thread has no matching `__enter__`.
          """
          # Pop (instead of read-then-delete) so the thread's entry is removed even
          # if `push()` below raises, e.g. when used on a non-leaf Stats object.
          start_time = self._start_times.pop(threading.get_ident())
          time_delta_s = time.perf_counter() - start_time
          self.push(time_delta_s)

      @classmethod
      def from_state(cls, state: Dict[str, Any]) -> "StatsBase":
          """Creates a stats object from a state dictionary.

          Any implementation of this should call this base class's
          `stats_object.set_state()` to set the state of the stats object.

          Args:
              state: The state to set after instantiation.

          Returns:
              A new `cls` instance, built from `cls._get_init_args(state=state)` and
              then populated via `set_state(state)`.
          """
          init_args = cls._get_init_args(state=state)
          stats = cls(**init_args)
          stats.set_state(state)
          return stats

      @OverrideToImplementCustomLogic_CallToSuperRecommended
      def clone(
          self,
          init_overrides: Optional[Dict[str, Any]] = None,
      ) -> "StatsBase":
          """Returns a new stats object with the same settings as `self`.

          Args:
              init_overrides: Optional dict of initialization arguments to override.
                  Can be used to change is_root, is_leaf, etc.

          Returns:
              A new stats object similar to `self` but missing internal values.
          """
          init_args = self.__class__._get_init_args(stats_object=self)
          if init_overrides:
              init_args.update(init_overrides)
          new_stats = self.__class__(**init_args)
          return new_stats

      @OverrideToImplementCustomLogic_CallToSuperRecommended
      def get_state(self) -> Dict[str, Any]:
          """Returns the state of the stats object.

          The base state contains `stats_cls_identifier`, `is_root`, `is_leaf`
          and, for non-leaf stats, `latest_merged`.
          """
          state = {
              "stats_cls_identifier": self.stats_cls_identifier,
              "is_root": self.is_root,
              "is_leaf": self.is_leaf,
          }
          if not self.is_leaf:
              state["latest_merged"] = self.latest_merged
          return state

      @OverrideToImplementCustomLogic_CallToSuperRecommended
      def set_state(self, state: Dict[str, Any]) -> None:
          """Sets the state of the stats object.

          Args:
              state: The state to set on this StatsBase object, in the format
                  produced by `get_state()`.
          """
          # Prevent setting a state with a different stats class identifier. This is
          # checked before any assignment so a mismatching state leaves `self` as is.
          assert self.stats_cls_identifier == state["stats_cls_identifier"]
          self.is_root = state["is_root"]
          self.is_leaf = state["is_leaf"]
          if not self.is_leaf:
              # Legacy states may lack `latest_merged`; fall back to the same
              # default (None) that `__init__` uses.
              self.latest_merged = state.get("latest_merged")

      @OverrideToImplementCustomLogic_CallToSuperRecommended
      @staticmethod
      def _get_init_args(stats_object=None, state=None) -> Dict[str, Any]:
          """Returns the initialization arguments for this Stats object.

          Args:
              stats_object: An existing Stats object to read the arguments from.
              state: A state dict (as produced by `get_state()`) to read the
                  arguments from. Takes precedence over `stats_object`.

          Returns:
              A dict with the `is_root` and `is_leaf` constructor arguments.

          Raises:
              ValueError: If neither `stats_object` nor `state` is provided.
          """
          if state is not None:
              return {
                  "is_root": state["is_root"],
                  "is_leaf": state["is_leaf"],
              }
          if stats_object is not None:
              return {
                  "is_root": stats_object.is_root,
                  "is_leaf": stats_object.is_leaf,
              }
          raise ValueError("Either stats_object or state must be provided")

      @abstractmethod
      def __repr__(self) -> str:
          ...

      @abstractmethod
      def merge(self, incoming_stats: List["StatsBase"]) -> None:
          """Merges StatsBase objects.

          Args:
              incoming_stats: The list of StatsBase objects to merge.
          """

      @abstractmethod
      def push(self, value: Any) -> None:
          """Pushes a value into this Stats object.

          The base implementation only checks that this is a leaf Stats object;
          subclasses can call `super().push(value)` to get that check.

          Args:
              value: The value to push. Can be of any type.
                  GPU tensors are moved to CPU memory.
          """
          assert (
              self.is_leaf
          ), "Cannot push values to non-leaf Stats. Non-leaf Stats can only receive values via merge()."

      @abstractmethod
      def peek(
          self, compile: bool = True, latest_merged_only: bool = False
      ) -> Union[Any, List[Any]]:
          """Returns the result of reducing the internal values list.

          Note that this method does NOT alter the internal values list in this
          process. Thus, users can call this method to get an accurate look at the
          reduced value(s) given the current internal values list.

          Args:
              compile: If True, the result is compiled into a single value if
                  possible.
              latest_merged_only: If True, only considers the latest merged values.
                  This parameter only works on aggregation stats objects
                  (is_leaf=False). When enabled, peek() will only use the values
                  from the most recent merge operation.

          Returns:
              The result of reducing the internal values list on CPU memory.
          """

      @abstractmethod
      def reduce(self, compile: bool = True) -> Union[Any, "StatsBase"]:
          """Reduces the internal values.

          This method should NOT be called directly by users. It can be used as a
          hook to prepare the stats object for sending it to the root metrics
          logger and starting a new 'reduce cycle'.

          The reduction logic depends on the implementation of the subclass, so
          some classes may reduce to a single value, while others do not or don't
          even contain values.

          Args:
              compile: If True, the result is compiled into a single value if
                  possible. If False, the result is a Stats object similar to
                  itself, but with the internal values reduced.

          Returns:
              The reduced value or a Stats object similar to itself, but with the
              internal values reduced.
          """