# sum.py (3.7 KB)
  1. import time
  2. from typing import Any, Dict, Union
  3. import numpy as np
  4. from ray.rllib.utils.framework import try_import_torch
  5. from ray.rllib.utils.metrics.stats.series import SeriesStats
  6. from ray.util.annotations import DeveloperAPI
  7. torch, _ = try_import_torch()
  8. @DeveloperAPI
  9. class SumStats(SeriesStats):
  10. """A Stats object that tracks the sum of a series of singular values (not vectors)."""
  11. stats_cls_identifier = "sum"
  12. def _np_reduce_fn(self, values: np.ndarray) -> float:
  13. return np.nansum(values)
  14. def _torch_reduce_fn(self, values: "torch.Tensor"):
  15. """Reduce function for torch tensors (stays on GPU)."""
  16. # torch.nansum not available, use workaround
  17. clean_values = values[~torch.isnan(values)]
  18. if len(clean_values) == 0:
  19. return torch.tensor(0.0, device=values.device)
  20. return torch.sum(clean_values.float())
  21. def __init__(self, with_throughput: bool = False, **kwargs):
  22. """Initializes a SumStats instance.
  23. Args:
  24. throughput: If True, track a throughput estimate based on the time between consecutive calls to reduce().
  25. """
  26. super().__init__(**kwargs)
  27. self.track_throughput = with_throughput
  28. # We initialize this to the current time which may result in a low first throughput value
  29. # It seems reasonable that starting from a checkpoint or starting an experiment results in a low first throughput value
  30. self._last_throughput_measure_time = time.perf_counter()
  31. def initialize_throughput_reference_time(self, time: float) -> None:
  32. assert (
  33. self.is_root
  34. ), "initialize_throughput_reference_time can only be called on root stats"
  35. self._last_throughput_measure_time = time
  36. @property
  37. def has_throughputs(self) -> bool:
  38. return self.track_throughput
  39. @property
  40. def throughputs(self) -> float:
  41. """Returns the throughput since the last reduce."""
  42. assert (
  43. self.track_throughput
  44. ), "Throughput tracking is not enabled on this Stats object"
  45. return self.peek(compile=True) / (
  46. time.perf_counter() - self._last_throughput_measure_time
  47. )
  48. def reduce(self, compile: bool = True) -> Union[Any, "SumStats"]:
  49. reduce_value = super().reduce(compile=True)
  50. # Update the last throughput measure time for correct throughput calculations
  51. if self.track_throughput:
  52. self._last_throughput_measure_time = time.perf_counter()
  53. if compile:
  54. return reduce_value
  55. return_stats = self.clone()
  56. return_stats.values = [reduce_value]
  57. return return_stats
  58. @staticmethod
  59. def _get_init_args(stats_object=None, state=None) -> Dict[str, Any]:
  60. """Returns the initialization arguments for this Stats object."""
  61. super_args = SeriesStats._get_init_args(stats_object=stats_object, state=state)
  62. if state is not None:
  63. return {
  64. **super_args,
  65. "with_throughput": state["track_throughput"],
  66. }
  67. elif stats_object is not None:
  68. return {
  69. **super_args,
  70. "with_throughput": stats_object.track_throughput,
  71. }
  72. else:
  73. raise ValueError("Either stats_object or state must be provided")
  74. def get_state(self) -> Dict[str, Any]:
  75. """Returns the state of the stats object."""
  76. state = super().get_state()
  77. state["track_throughput"] = self.track_throughput
  78. return state
  79. def set_state(self, state: Dict[str, Any]) -> None:
  80. super().set_state(state)
  81. self.track_throughput = state["track_throughput"]
  82. def __repr__(self) -> str:
  83. return f"SumStats({self.peek()}; window={self._window}; len={len(self)})"