util.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import threading
  2. from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional
  3. import ray
  4. from ray._raylet import ObjectRef
  5. from ray.experimental.gpu_object_manager.collective_tensor_transport import (
  6. GLOOTensorTransport,
  7. NCCLTensorTransport,
  8. )
  9. from ray.experimental.gpu_object_manager.cuda_ipc_transport import CudaIpcTransport
  10. from ray.experimental.gpu_object_manager.nixl_tensor_transport import (
  11. NixlTensorTransport,
  12. )
  13. from ray.experimental.gpu_object_manager.tensor_transport_manager import (
  14. TensorTransportManager,
  15. TensorTransportMetadata,
  16. )
  17. from ray.util.annotations import PublicAPI
  18. if TYPE_CHECKING:
  19. import torch
class TransportManagerInfo(NamedTuple):
    """Registration record for a single tensor transport.

    Pairs the transport manager class with the device types the transport was
    registered for.
    """

    # Class that is instantiated to obtain the transport's manager.
    transport_manager_class: type[TensorTransportManager]
    # List of supported device types for the transport.
    devices: List[str]
# Registry of known transports, keyed by transport name.
# register_tensor_transport upper-cases the name before inserting it here.
transport_manager_info: Dict[str, TransportManagerInfo] = {}

# Singleton instances of transport managers, keyed by transport name.
transport_managers: Dict[str, TensorTransportManager] = {}

# To protect the singleton instances of transport managers.
transport_managers_lock = threading.Lock()

# Flipped to True when the first custom transport is registered.
has_custom_transports = False
  31. @PublicAPI(stability="alpha")
  32. def register_tensor_transport(
  33. transport_name: str,
  34. devices: List[str],
  35. transport_manager_class: type[TensorTransportManager],
  36. ):
  37. """
  38. Register a new tensor transport for use in Ray. Note that this needs to be called
  39. before you create the actors that will use the transport. The actors also
  40. need to be created in the same process from which you call this function.
  41. Args:
  42. transport_name: The name of the transport protocol.
  43. devices: List of PyTorch device types supported by this transport (e.g., ["cuda", "cpu"]).
  44. transport_manager_class: A class that implements TensorTransportManager.
  45. Raises:
  46. ValueError: If transport_manager_class is not a subclass of TensorTransportManager.
  47. """
  48. global transport_manager_info
  49. global has_custom_transports
  50. transport_name = transport_name.upper()
  51. if transport_name in transport_manager_info:
  52. raise ValueError(f"Transport {transport_name} already registered.")
  53. if not issubclass(transport_manager_class, TensorTransportManager):
  54. raise ValueError(
  55. f"transport_manager_class {transport_manager_class.__name__} must be a subclass of TensorTransportManager."
  56. )
  57. transport_manager_info[transport_name] = TransportManagerInfo(
  58. transport_manager_class, devices
  59. )
  60. if transport_name not in DEFAULT_TRANSPORTS:
  61. has_custom_transports = True
# Names of the built-in transports. register_tensor_transport reads this list
# to decide whether a registration counts as custom, so it has to be defined
# before the registration calls below run.
DEFAULT_TRANSPORTS = ["NIXL", "GLOO", "NCCL", "CUDA_IPC"]

# Register the built-in transports together with the device types each one
# supports.
register_tensor_transport("NIXL", ["cuda", "cpu"], NixlTensorTransport)
register_tensor_transport("GLOO", ["cpu"], GLOOTensorTransport)
register_tensor_transport("NCCL", ["cuda"], NCCLTensorTransport)
register_tensor_transport("CUDA_IPC", ["cuda"], CudaIpcTransport)
  67. def get_tensor_transport_manager(
  68. transport_name: str,
  69. ) -> "TensorTransportManager":
  70. """Get the tensor transport manager for the given tensor transport protocol.
  71. Args:
  72. transport_name: The tensor transport protocol to use for the GPU object.
  73. Returns:
  74. TensorTransportManager: The tensor transport manager for the given tensor transport protocol.
  75. """
  76. global transport_manager_info
  77. global transport_managers
  78. global transport_managers_lock
  79. with transport_managers_lock:
  80. if transport_name in transport_managers:
  81. return transport_managers[transport_name]
  82. if transport_name not in transport_manager_info:
  83. raise ValueError(f"Unsupported tensor transport protocol: {transport_name}")
  84. transport_managers[transport_name] = transport_manager_info[
  85. transport_name
  86. ].transport_manager_class()
  87. return transport_managers[transport_name]
  88. def register_custom_tensor_transports_on_actor(
  89. actor: "ray.actor.ActorHandle",
  90. ) -> Optional[ObjectRef]:
  91. """
  92. If there's no custom transports to register, returns None.
  93. Otherwise returns an object ref for a task on the actor that will register the custom transports.
  94. """
  95. global transport_manager_info
  96. global has_custom_transports
  97. if not has_custom_transports:
  98. return None
  99. def register_transport_on_actor(
  100. self, owner_transport_manager_info: Dict[str, TransportManagerInfo]
  101. ):
  102. from ray.experimental.gpu_object_manager.util import (
  103. register_tensor_transport,
  104. transport_manager_info,
  105. )
  106. for transport_name, transport_info in owner_transport_manager_info.items():
  107. if transport_name not in transport_manager_info:
  108. register_tensor_transport(
  109. transport_name,
  110. transport_info.devices,
  111. transport_info.transport_manager_class,
  112. )
  113. return actor.__ray_call__.options(concurrency_group="_ray_system").remote(
  114. register_transport_on_actor, transport_manager_info
  115. )
  116. def device_match_transport(device: "torch.device", tensor_transport: str) -> bool:
  117. """Check if the device matches the transport."""
  118. if tensor_transport not in transport_manager_info:
  119. raise ValueError(f"Unsupported tensor transport protocol: {tensor_transport}")
  120. return device.type in transport_manager_info[tensor_transport].devices
  121. def normalize_and_validate_tensor_transport(tensor_transport: str) -> str:
  122. tensor_transport = tensor_transport.upper()
  123. if tensor_transport not in transport_manager_info:
  124. raise ValueError(f"Invalid tensor transport: {tensor_transport}")
  125. return tensor_transport
  126. def validate_one_sided(tensor_transport: str, ray_usage_func: str):
  127. if not transport_manager_info[
  128. tensor_transport
  129. ].transport_manager_class.is_one_sided():
  130. raise ValueError(
  131. f"Trying to use two-sided tensor transport: {tensor_transport} for {ray_usage_func}. "
  132. "This is only supported for one-sided transports such as NIXL or the OBJECT_STORE."
  133. )
  134. def create_empty_tensors_from_metadata(
  135. tensor_transport_meta: TensorTransportMetadata,
  136. ) -> List["torch.Tensor"]:
  137. import torch
  138. tensors = []
  139. device = tensor_transport_meta.tensor_device
  140. for meta in tensor_transport_meta.tensor_meta:
  141. shape, dtype = meta
  142. tensor = torch.empty(shape, dtype=dtype, device=device)
  143. tensors.append(tensor)
  144. return tensors