# ema.py (9.2 KB)

  1. import logging
  2. import warnings
  3. from typing import Any, Dict, List, Union
  4. import numpy as np
  5. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  6. from ray.rllib.utils.metrics.stats.base import StatsBase
  7. from ray.rllib.utils.metrics.stats.utils import safe_isnan, single_value_to_cpu
  8. from ray.util import log_once
  9. from ray.util.annotations import DeveloperAPI
  10. logger = logging.getLogger(__name__)
  11. torch, _ = try_import_torch()
  12. _, tf, _ = try_import_tf()
  13. @DeveloperAPI
  14. class EmaStats(StatsBase):
  15. """A Stats object that tracks the exponential average of a series of singular values (not vectors)."""
  16. stats_cls_identifier = "ema"
  17. def __init__(
  18. self,
  19. ema_coeff: float = 0.01,
  20. *args,
  21. **kwargs,
  22. ):
  23. """Initializes a EmaStats instance.
  24. We calculate the EMA in parallel components.
  25. Also, we potentially aggregate them multiple times per reduction cycle.
  26. We therefore aggregate by taking the mean of all collected EMAs.
  27. We do this for simplicity and accept this limitation because EMAs
  28. inherently only approximate.
  29. Example to illustrate this limitation:
  30. Using an ema coefficient of 0.01:
  31. First incoming ema: [1, 2, 3, 4, 5] -> 1.1
  32. Second incoming ema: [15] -> 15
  33. Mean of both merged ema values: [1.1, 15] -> 8.05
  34. True mean of all values: [1, 2, 3, 4, 5, 15] -> 5
  35. Args:
  36. ema_coeff: The EMA coefficient to use. Defaults to 0.01.
  37. """
  38. super().__init__(*args, **kwargs)
  39. self._value = np.nan
  40. if not self.is_leaf:
  41. self._values_to_merge = []
  42. self._ema_coeff = ema_coeff
  43. def _quiet_nanmean(self, values: List[Any]) -> float:
  44. """Compute the nanmean while ignoring warnings if all values are NaN.
  45. Args:
  46. values: The list of values to compute the nanmean of.
  47. Returns:
  48. The nanmean of the values.
  49. """
  50. if torch and isinstance(values[0], torch.Tensor):
  51. stacked = torch.stack(list(values))
  52. return torch.nanmean(stacked)
  53. with warnings.catch_warnings():
  54. warnings.filterwarnings("ignore", "Mean of empty slice", RuntimeWarning)
  55. return np.nanmean(values)
  56. def __len__(self) -> int:
  57. """Returns the length of the internal values list."""
  58. return 1
  59. def merge(self, incoming_stats: List["EmaStats"]) -> None:
  60. """Merges EmaStats objects.
  61. Args:
  62. incoming_stats: The list of EmaStats objects to merge.
  63. Returns:
  64. None. The merge operation modifies self in place.
  65. """
  66. assert (
  67. not self.is_leaf
  68. ), "EmaStats should only be merged at aggregation stages (root or intermediate)"
  69. all_values = [stat._value for stat in incoming_stats]
  70. if len(all_values) == 0:
  71. return
  72. self._values_to_merge.extend(all_values)
  73. # Track merged values for latest_merged_only peek functionality
  74. if not self.is_leaf:
  75. # Store the values that were merged in this operation
  76. self.latest_merged = all_values
  77. def push(self, value: Any) -> None:
  78. """Pushes a value into this Stats object.
  79. Args:
  80. value: The value to be pushed. Can be of any type.
  81. PyTorch GPU tensors are kept on GPU until reduce() or peek().
  82. TensorFlow tensors are moved to CPU immediately.
  83. """
  84. # Convert TensorFlow tensors to CPU immediately
  85. if tf and tf.is_tensor(value):
  86. value = value.numpy()
  87. # If incoming value is NaN, do nothing
  88. if safe_isnan(value):
  89. return
  90. if torch and isinstance(value, torch.Tensor):
  91. # Detach the value from the graph to avoid unnecessary computation
  92. value = value.detach()
  93. # If internal value is NaN, replace it with the incoming value
  94. if safe_isnan(self._value):
  95. self._value = value
  96. else:
  97. # Otherwise, update the internal value using the EMA formula
  98. self._value = (
  99. self._ema_coeff * value + (1.0 - self._ema_coeff) * self._value
  100. )
  101. def _reduce_values_to_merge(self) -> float:
  102. """Reduces the internal values to merge."""
  103. if not np.isnan(self._value) and log_once("ema_stats_merge_push"):
  104. logger.warning(
  105. f"Merging values in {self} but self._value is not NaN. This leads to an inaccurate metric. Not erroring out to avoid breaking older checkpoints."
  106. )
  107. if len(self._values_to_merge) == 0:
  108. return np.nan
  109. return self._quiet_nanmean(self._values_to_merge)
  110. def peek(
  111. self, compile: bool = True, latest_merged_only: bool = False
  112. ) -> Union[Any, List[Any]]:
  113. """Returns the current EMA value.
  114. If value is a GPU tensor, it's converted to CPU.
  115. Args:
  116. compile: If True, the result is compiled into a single value if possible.
  117. latest_merged_only: If True, only considers the latest merged values.
  118. This parameter only works on aggregation stats (root or intermediate nodes).
  119. When enabled, peek() will only use the values from the most recent merge operation.
  120. """
  121. # Check latest_merged_only validity
  122. if latest_merged_only and self.is_leaf:
  123. raise ValueError(
  124. "latest_merged_only can only be used on aggregation stats objects (is_leaf=False)."
  125. )
  126. # If latest_merged_only is True, use only the latest merged values
  127. if latest_merged_only:
  128. if self.latest_merged is None:
  129. # No merged values yet, return NaN
  130. if compile:
  131. return np.nan
  132. else:
  133. return [np.nan]
  134. # Use only the latest merged values
  135. latest_merged = self.latest_merged
  136. if len(latest_merged) == 0:
  137. value = np.nan
  138. else:
  139. # Reduce latest merged values
  140. value = self._quiet_nanmean(latest_merged)
  141. else:
  142. # Normal peek behavior
  143. if hasattr(self, "_values_to_merge"):
  144. # If _values_to_merge is empty, use _value instead
  145. # This can happen after reduce(compile=False) returns a new stats object
  146. if len(self._values_to_merge) == 0:
  147. value = self._value
  148. else:
  149. value = self._reduce_values_to_merge()
  150. else:
  151. value = self._value
  152. value = single_value_to_cpu(value)
  153. return value if compile else [value]
  154. def reduce(self, compile: bool = True) -> Union[Any, "EmaStats"]:
  155. """Reduces the internal value.
  156. If value is a GPU tensor, it's converted to CPU.
  157. Args:
  158. compile: If True, the result is compiled into a single value if possible.
  159. Returns:
  160. The reduced value.
  161. """
  162. if hasattr(self, "_values_to_merge"):
  163. # If _values_to_merge is empty, use _value instead
  164. # This can happen when a non-leaf stats object logs values directly
  165. if len(self._values_to_merge) == 0:
  166. value = self._value
  167. else:
  168. value = self._reduce_values_to_merge()
  169. self._values_to_merge = []
  170. else:
  171. value = self._value
  172. # Convert GPU tensor to CPU
  173. if torch and isinstance(value, torch.Tensor):
  174. value = single_value_to_cpu(value)
  175. self._value = np.nan
  176. if compile:
  177. return value
  178. return_stats = self.clone()
  179. return_stats._value = value
  180. return return_stats
  181. def __repr__(self) -> str:
  182. values_to_merge_len = (
  183. len(self._values_to_merge) if hasattr(self, "_values_to_merge") else 0
  184. )
  185. return (
  186. f"EmaStats({self.peek()}; number_of_values_to_merge=({values_to_merge_len}); "
  187. f"ema_coeff={self._ema_coeff}, value={self._value})"
  188. )
  189. def get_state(self) -> Dict[str, Any]:
  190. state = super().get_state()
  191. state["ema_coeff"] = self._ema_coeff
  192. state["value"] = self._value
  193. if not self.is_leaf:
  194. state["values_to_merge"] = self._values_to_merge
  195. return state
  196. def set_state(self, state: Dict[str, Any]) -> None:
  197. super().set_state(state)
  198. self._ema_coeff = state["ema_coeff"]
  199. self._value = state["value"]
  200. # Handle legacy state that doesn't have values_to_merge
  201. if not self.is_leaf:
  202. self._values_to_merge = state.get("values_to_merge", [])
  203. @staticmethod
  204. def _get_init_args(stats_object=None, state=None) -> Dict[str, Any]:
  205. """Returns the initialization arguments for this Stats object."""
  206. super_args = StatsBase._get_init_args(stats_object=stats_object, state=state)
  207. if state is not None:
  208. return {
  209. **super_args,
  210. "ema_coeff": state["ema_coeff"],
  211. }
  212. if stats_object is not None:
  213. return {
  214. **super_args,
  215. "ema_coeff": stats_object._ema_coeff,
  216. }
  217. else:
  218. raise ValueError("Either stats_object or state must be provided")