nccl_group.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. import logging
  2. from types import ModuleType
  3. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
  4. import ray
  5. from ray.exceptions import RayChannelError
  6. from ray.experimental.channel.accelerator_context import AcceleratorContext
  7. from ray.experimental.channel.communicator import Communicator, TorchTensorAllocator
  8. from ray.experimental.util.types import ReduceOp
  9. if TYPE_CHECKING:
  10. import torch
  # Module-level logger. Logging should be configured at the program's entry
  # point; Ray itself installs a default configuration at its entry/init points.
  logger = logging.getLogger(name=__name__)
  class _NcclGroup(Communicator):
      """
      Represents an actor's NCCL communicator. This is the default NCCL communicator
      to be used in Compiled Graph if a custom communicator is not provided.

      This class is not thread-safe.
      """

      def __init__(
          self,
          world_size: int,
          comm_id: tuple,
          rank: Optional[int],
          actor_handles: List["ray.actor.ActorHandle"],
          cuda_stream: Optional["torch.cuda.Stream"],
          use_communication_streams: bool = False,
      ):
          """
          Initialize an NCCL communicator that can be used to communicate p2p with
          other GPU actors.

          This method blocks until the same call has been made on all other
          actors in the group, with the same arguments for world_size and
          comm_id.

          NOTE: A concurrent NCCL group can coexist with this one but using the
          two groups concurrently on different CUDA streams may cause deadlock.
          See
          https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html
          #using-multiple-nccl-communicators-concurrently.

          If the user can guarantee that all involved actors execute the same ops
          in the same order, then the other NCCL group should use the given
          `cuda_stream`, and there will not be a concurrency issue. Otherwise,
          the other stream needs to synchronize with the given `cuda_stream`
          before and after it launches NCCL ops, e.g., at the beginning and end
          of a DAG task.

          Args:
              world_size: The number of participating actors/devices.
              comm_id: A unique communicator ID returned by
                  cupy.cuda.nccl.get_unique_id().
              rank: The rank of this actor. If None, then the caller is not a
                  participant of the NCCL group.
              actor_handles: A list of actor handles, in rank order.
              cuda_stream: A raw CUDA stream to dispatch NCCL ops to. If rank is
                  specified, then this must be specified too.
              use_communication_streams: Whether to use dedicated send and recv
                  streams for communication. If True, communication and computation
                  can be overlapped to improve performance.
          """
          self._world_size = world_size
          self._rank: Optional[int] = rank
          self.nccl_util: Optional[ModuleType] = None
          self._actor_handles = actor_handles
          self._use_communication_streams = use_communication_streams

          if rank is not None:
              assert ray.get_gpu_ids(), "NCCL actor has no GPUs assigned"
              assert cuda_stream is not None, "NCCL actor must specify cuda_stream"

              # The rank passed in must agree with this actor's position in
              # `actor_handles`.
              expected_rank = self.get_rank(ray.get_runtime_context().current_actor)
              assert (
                  rank == expected_rank
              ), f"NCCL actor's rank {rank} does not match expected rank {expected_rank}"

              # Only participating actors load the NCCL utilities and create a
              # communicator; the driver (rank None) never imports them.
              from ray.util.collective.collective_group import nccl_util

              self.nccl_util = nccl_util
              self._comm = self.nccl_util.NcclCommunicator(world_size, comm_id, rank)
          else:
              # Driver does not have a rank.
              self._comm = None

          self._cuda_stream: Optional["torch.cuda.Stream"] = None
          self._send_stream: Optional["torch.cuda.Stream"] = None
          self._recv_stream: Optional["torch.cuda.Stream"] = None
          if cuda_stream is not None:
              assert rank is not None, "NCCL actor has no rank assigned"
              self._cuda_stream = cuda_stream

              if use_communication_streams:
                  import torch

                  # TODO(swang): Allow default device to be overridden.
                  device = AcceleratorContext.get().get_accelerator_devices()[0]

                  self._send_stream = torch.cuda.Stream(device=device)
                  self._recv_stream = torch.cuda.Stream(device=device)
              else:
                  # Without dedicated streams, send and recv both run on the
                  # caller-provided stream.
                  self._send_stream = self._cuda_stream
                  self._recv_stream = self._cuda_stream

          self._closed = False

      def initialize(self, rank: int) -> None:
          # No additional initialization is needed.
          pass

      def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
          """Return the actor handles of the group, in rank order."""
          return self._actor_handles

      def get_rank(self, actor: ray.actor.ActorHandle) -> int:
          """
          Return the given actor's rank in the NCCL communicator.

          The rank is the actor's index in `actor_handles`, matched by actor ID.

          Args:
              actor: The actor handle to look up.

          Returns:
              The actor's rank.

          Raises:
              ValueError: If the actor is not part of this group.
          """
          actor_ids = [a._ray_actor_id for a in self._actor_handles]
          try:
              rank = actor_ids.index(actor._ray_actor_id)
          except ValueError:
              # Suppress the internal list.index() context; it adds nothing.
              raise ValueError("Actor is not in the NCCL group.") from None
          return rank

      def get_self_rank(self) -> Optional[int]:
          """
          Return this actor's rank, or None if the caller is not a participant.
          """
          return self._rank

      def get_world_size(self) -> int:
          """
          Return the number of ranks in the NCCL communicator.
          """
          return self._world_size

      def send(self, buf: "torch.Tensor", peer_rank: int) -> None:
          """
          Send a torch.Tensor to a peer.

          This returns when the send kernel has been queued, but the kernel may
          not have completed. Therefore, the caller should ensure that there are
          no concurrent writes to the sent `buf` until the send has finished.
          That is, either all writes should be submitted on the current stream
          (self._cuda_stream) or, if on a different stream, that stream should
          synchronize with the current stream.

          Args:
              buf: The torch.Tensor to send. It should already be on this
                  actor's default device.
              peer_rank: The rank of the actor to send to.

          Raises:
              RayChannelError: If the group has already been destroyed.
          """
          if self._closed:
              raise RayChannelError("NCCL group has been destroyed.")

          if self._use_communication_streams:
              # We observed that if all recv/compute/send operations run on GPU,
              # since there is no synchronization, the CPU execution loop may be
              # far ahead of the GPU operations and lead to runtime failures.
              # To avoid that, we synchronize on the send stream.
              # TODO(rui): find a better approach
              self._send_stream.synchronize()

          # TODO(swang): Handle send/recv async NCCL errors such as network
          # failures.
          self._comm.send(
              self.nccl_util.get_tensor_ptr(buf),
              buf.numel(),
              self.nccl_util.get_nccl_tensor_dtype(buf),
              peer_rank,
              self._send_stream.cuda_stream,
          )

      def recv(
          self,
          shape: Tuple[int, ...],
          dtype: "torch.dtype",
          peer_rank: int,
          allocator: Optional[TorchTensorAllocator] = None,
      ) -> "torch.Tensor":
          """
          Receive a torch.Tensor from a peer.

          A buffer is created with `allocator(shape, dtype)` and the NCCL receive
          is enqueued on the recv stream. Without dedicated communication
          streams, the current stream is then synchronized, so once this call
          returns the buffer is safe to read from any stream.

          NOTE(review): with `use_communication_streams=True` this method does
          not wait for the receive to finish; presumably callers synchronize
          with the recv stream (see the `recv_stream` property) before reading
          the buffer. Confirm this against the callers.

          Args:
              shape: Shape of the tensor to receive.
              dtype: Dtype of the tensor to receive.
              peer_rank: The rank of the actor to receive from.
              allocator: Callable invoked as `allocator(shape, dtype)` to create
                  the receive buffer. It is required.

          Returns:
              The buffer the tensor was received into.

          Raises:
              RayChannelError: If the group was destroyed before the call, or
                  was found destroyed after the receive (e.g., a remote actor
                  died). In that case the buffer is not safe to read.
              AssertionError: If no allocator is provided.
          """
          if self._closed:
              raise RayChannelError("NCCL group has been destroyed.")
          assert allocator is not None, "NCCL group requires a tensor allocator"
          buf = allocator(shape, dtype)

          if self._use_communication_streams:
              # We observed that if all recv/compute/send operations run on GPU,
              # since there is no synchronization, the CPU execution loop may be
              # far ahead of the GPU operations and lead to runtime failures.
              # To avoid that, we synchronize on the recv stream.
              # TODO(rui): find a better approach
              self._recv_stream.synchronize()

          self._comm.recv(
              self.nccl_util.get_tensor_ptr(buf),
              buf.numel(),
              self.nccl_util.get_nccl_tensor_dtype(buf),
              peer_rank,
              self._recv_stream.cuda_stream,
          )

          if not self._use_communication_streams:
              # Buffer values are undefined if NCCL ops are aborted. Therefore, we
              # need to synchronize here and check that the channel is still open to
              # ensure that the receive buffer is valid. In this mode the recv
              # stream is the same object as the current stream.
              # TODO(swang): Avoid CUDA synchronization.
              self._cuda_stream.synchronize()

          if self._closed:
              raise RayChannelError("NCCL group has been destroyed.")
          return buf

      def _exec_collective(
          self,
          send_buf: "torch.Tensor",
          recv_buf: "torch.Tensor",
          operation: "Callable[..., None]",
          *operation_args,
      ):
          """
          Run `operation(*operation_args)` and wait for it on the current stream.

          Shared by allgather/allreduce/reducescatter.

          Raises:
              RayChannelError: If the group is destroyed before the call or is
                  found destroyed after the collective completes.
              AssertionError: If send_buf and recv_buf have different dtypes.
          """
          if self._closed:
              raise RayChannelError("NCCL group has been destroyed.")

          assert send_buf.dtype == recv_buf.dtype, (
              "Ray Compiled Graph derived the dtype of recv_buf from send_buf, "
              "so send_buf and recv_buf must have the same dtype. "
              "If you see this error, please file an issue at Ray repository."
          )

          operation(*operation_args)

          # Buffer values are undefined if NCCL ops are aborted. Therefore, we
          # need to synchronize here and check that the channel is still open to
          # ensure that the receive buffer is valid.
          # TODO(swang): Avoid CUDA synchronization.
          # TODO(wxdeng): This synchronize will be optional after merging the unify PR.
          self._cuda_stream.synchronize()
          if self._closed:
              # NOTE(review): the message names allreduce, but it is raised for
              # every collective routed through this helper.
              raise RayChannelError(
                  "NCCL group has been destroyed during allreduce operation. "
                  "There may be a dtype mismatch between input tensors from "
                  "different ranks."
              )

      def allgather(
          self,
          send_buf: "torch.Tensor",
          recv_buf: "torch.Tensor",
      ):
          """
          Gather `send_buf` from every rank into `recv_buf` on the current stream.

          The element count passed to NCCL is `send_buf.numel()` (per-rank
          count). Per NCCL allGather semantics, `recv_buf` must hold
          world_size * send_buf.numel() elements.
          """
          operation_args = [
              self.nccl_util.get_tensor_ptr(send_buf),
              self.nccl_util.get_tensor_ptr(recv_buf),
              send_buf.numel(),
              self.nccl_util.get_nccl_tensor_dtype(send_buf),
              self._cuda_stream.cuda_stream,
          ]
          self._exec_collective(
              send_buf,
              recv_buf,
              self._comm.allGather,
              *operation_args,
          )

      def allreduce(
          self,
          send_buf: "torch.Tensor",
          recv_buf: "torch.Tensor",
          op: ReduceOp = ReduceOp.SUM,
      ):
          """
          Reduce `send_buf` across all ranks with `op` into `recv_buf`.

          Runs on the current stream with `send_buf.numel()` elements; `op.value`
          is passed to NCCL as the reduction operation.
          """
          operation_args = [
              self.nccl_util.get_tensor_ptr(send_buf),
              self.nccl_util.get_tensor_ptr(recv_buf),
              send_buf.numel(),
              self.nccl_util.get_nccl_tensor_dtype(send_buf),
              op.value,
              self._cuda_stream.cuda_stream,
          ]
          self._exec_collective(
              send_buf,
              recv_buf,
              self._comm.allReduce,
              *operation_args,
          )

      def reducescatter(
          self,
          send_buf: "torch.Tensor",
          recv_buf: "torch.Tensor",
          op: ReduceOp = ReduceOp.SUM,
      ):
          """
          Reduce `send_buf` across ranks with `op` and scatter the result.

          The element count passed to NCCL is `recv_buf.numel()` (the per-rank
          output count). Per NCCL reduceScatter semantics, `send_buf` must hold
          world_size * recv_buf.numel() elements.
          """
          operation_args = [
              self.nccl_util.get_tensor_ptr(send_buf),
              self.nccl_util.get_tensor_ptr(recv_buf),
              recv_buf.numel(),
              self.nccl_util.get_nccl_tensor_dtype(send_buf),
              op.value,
              self._cuda_stream.cuda_stream,
          ]
          self._exec_collective(
              send_buf,
              recv_buf,
              self._comm.reduceScatter,
              *operation_args,
          )

      @property
      def recv_stream(self):
          """A `torch.cuda.StreamContext` that selects the receive stream."""
          import torch

          return torch.cuda.StreamContext(self._recv_stream)

      @property
      def send_stream(self):
          """A `torch.cuda.StreamContext` that selects the send stream."""
          import torch

          return torch.cuda.StreamContext(self._send_stream)

      def destroy(self) -> None:
          """
          Destroy the NCCL group. Calling it again is a no-op.
          """
          if self._closed:
              return

          self._closed = True

          if self._comm is not None:
              logger.info(
                  "Destructing NCCL group on actor: %s",
                  ray.get_runtime_context().current_actor,
              )
              # Abort *after* setting the _closed flag. This ensures that NCCL
              # ops that were blocked on a remote peer will see that the _closed
              # flag is True when they exit from the abort.
              self._comm.abort()
              self._comm.destroy()

      def get_transport_name(self) -> str:
          return "accelerator"

      @classmethod
      def generate_communicator_id(cls) -> str:
          """
          Return a fresh ID from `cupy.cuda.nccl.get_unique_id()`.

          NOTE(review): annotated as `str`, but `__init__` types `comm_id` as
          `tuple`. Confirm which type the base `Communicator` expects.
          """
          from cupy.cuda import nccl

          return nccl.get_unique_id()