max.py 988 B

1234567891011121314151617181920212223242526272829
  1. import numpy as np
  2. from ray.rllib.utils.framework import try_import_torch
  3. from ray.rllib.utils.metrics.stats.series import SeriesStats
  4. from ray.util.annotations import DeveloperAPI
  5. torch, _ = try_import_torch()
  6. @DeveloperAPI
  7. class MaxStats(SeriesStats):
  8. """A Stats object that tracks the max of a series of singular values (not vectors)."""
  9. stats_cls_identifier = "max"
  10. def _np_reduce_fn(self, values):
  11. return np.nanmax(values)
  12. def _torch_reduce_fn(self, values):
  13. """Reduce function for torch tensors (stays on GPU)."""
  14. # torch.nanmax not available, use workaround
  15. clean_values = values[~torch.isnan(values)]
  16. if len(clean_values) == 0:
  17. return torch.tensor(float("nan"), device=values.device)
  18. # Cast to float32 to avoid errors from Long tensors
  19. return torch.max(clean_values.float())
  20. def __repr__(self) -> str:
  21. return f"MaxStats({self.peek()}; window={self._window}; len={len(self)})"