min.py 1.0 KB

1234567891011121314151617181920212223242526272829
  1. import numpy as np
  2. from ray.rllib.utils.framework import try_import_torch
  3. from ray.rllib.utils.metrics.stats.series import SeriesStats
  4. from ray.util.annotations import DeveloperAPI
  5. torch, _ = try_import_torch()
  6. @DeveloperAPI
  7. class MinStats(SeriesStats):
  8. """A Stats object that tracks the min of a series of singular values (not vectors)."""
  9. stats_cls_identifier = "min"
  10. def _np_reduce_fn(self, values: np.ndarray) -> float:
  11. return np.nanmin(values)
  12. def _torch_reduce_fn(self, values: "torch.Tensor"):
  13. """Reduce function for torch tensors (stays on GPU)."""
  14. # torch.nanmin not available, use workaround
  15. clean_values = values[~torch.isnan(values)]
  16. if len(clean_values) == 0:
  17. return torch.tensor(float("nan"), device=values.device)
  18. # Cast to float32 to avoid errors from Long tensors
  19. return torch.min(clean_values.float())
  20. def __repr__(self) -> str:
  21. return f"MinStats({self.peek()}; window={self._window}; len={len(self)})"