mean.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. from typing import Any, Union
  2. import numpy as np
  3. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  4. from ray.rllib.utils.metrics.stats.series import SeriesStats
  5. from ray.util.annotations import DeveloperAPI
# Optional deep-learning framework handles. Only the framework entry of each
# returned tuple is kept; the other entries are discarded into `_`.
# NOTE(review): `push` guards with `if tf and tf.is_tensor(...)`, which implies
# `tf` is falsy (presumably None) when TensorFlow is unavailable — confirm
# against try_import_tf / try_import_torch.
torch, _ = try_import_torch()
_, tf, _ = try_import_tf()
  8. @DeveloperAPI
  9. class MeanStats(SeriesStats):
  10. """A Stats object that tracks the mean of a series of singular values (not vectors).
  11. Note the following limitation: When merging multiple MeanStats objects, the resulting mean is not the true mean of all values.
  12. Instead, it is the mean of the means of the incoming MeanStats objects.
  13. This is because we calculate the mean in parallel components and potentially merge them multiple times in one reduce cycle.
  14. The resulting mean of means may differ significantly from the true mean, especially if some incoming means are the result of few outliers.
  15. Example to illustrate this limitation:
  16. First incoming mean: [1, 2, 3, 4, 5] -> 3
  17. Second incoming mean: [15] -> 15
  18. Mean of both merged means: [3, 15] -> 9
  19. True mean of all values: [1, 2, 3, 4, 5, 15] -> 5
  20. """
  21. stats_cls_identifier = "mean"
  22. def _np_reduce_fn(self, values: np.ndarray) -> float:
  23. return np.nanmean(values)
  24. def _torch_reduce_fn(self, values: "torch.Tensor"):
  25. """Reduce function for torch tensors (stays on GPU)."""
  26. return torch.nanmean(values.float())
  27. def push(self, value: Any) -> None:
  28. """Pushes a value into this Stats object.
  29. Args:
  30. value: The value to be pushed. Can be of any type.
  31. PyTorch GPU tensors are kept on GPU until reduce() or peek().
  32. TensorFlow tensors are moved to CPU immediately.
  33. """
  34. # Convert TensorFlow tensors to CPU immediately, keep PyTorch tensors as-is
  35. if tf and tf.is_tensor(value):
  36. value = value.numpy()
  37. self.values.append(value)
  38. def reduce(self, compile: bool = True) -> Union[Any, "MeanStats"]:
  39. reduced_values = self.window_reduce() # Values are on CPU already after this
  40. self._set_values([])
  41. if compile:
  42. return reduced_values[0]
  43. return_stats = self.clone()
  44. return_stats.values = reduced_values
  45. return return_stats
  46. def __repr__(self) -> str:
  47. return f"MeanStats({self.peek()}; window={self._window}; len={len(self)})"