utils.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from collections import deque
  2. from typing import Any, List, Union
  3. import numpy as np
  4. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  5. from ray.util.annotations import DeveloperAPI
  6. torch, _ = try_import_torch()
  7. _, tf, _ = try_import_tf()
  8. @DeveloperAPI
  9. def safe_isnan(value):
  10. """Check if a value is NaN.
  11. Args:
  12. value: The value to check.
  13. Returns:
  14. True if the value is NaN, False otherwise.
  15. """
  16. if torch and torch.is_tensor(value):
  17. return torch.isnan(value)
  18. if tf and tf.is_tensor(value):
  19. return tf.math.is_nan(value)
  20. return np.isnan(value)
  21. @DeveloperAPI
  22. def single_value_to_cpu(value):
  23. """Convert a single value to CPU if it's a tensor.
  24. TensorFlow tensors are always converted to numpy/python values.
  25. PyTorch tensors are converted to python scalars.
  26. """
  27. if torch and isinstance(value, torch.Tensor):
  28. return value.detach().cpu().item()
  29. elif tf and tf.is_tensor(value):
  30. return value.numpy()
  31. return value
  32. @DeveloperAPI
  33. def batch_values_to_cpu(values: Union[List[Any], deque]) -> List[Any]:
  34. """Convert a list or deque of GPU tensors to CPU scalars in a single operation.
  35. This function efficiently processes multiple PyTorch GPU tensors together by
  36. stacking them and performing a single .cpu() call. Assumes all values are either
  37. PyTorch tensors (on same device) or already CPU values.
  38. Args:
  39. values: A list or deque of values that may be GPU tensors.
  40. Returns:
  41. A list of CPU scalar values.
  42. """
  43. if not values:
  44. return []
  45. # Check if first value is a torch tensor - assume all are the same type
  46. if torch and isinstance(values[0], torch.Tensor):
  47. # Stack all tensors and move to CPU in one operation
  48. stacked = torch.stack(list(values))
  49. cpu_tensor = stacked.detach().cpu()
  50. return cpu_tensor.tolist()
  51. # Already CPU values
  52. return list(values)