communicator.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. from abc import ABC, abstractmethod
  2. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
  3. import ray
  4. from ray.experimental.util.types import ReduceOp
  5. from ray.util.annotations import DeveloperAPI
  6. if TYPE_CHECKING:
  7. import torch
  8. # Signature for a torch.Tensor allocator is:
  9. # (shape: Tuple[int], dtype: torch.dtype) -> torch.Tensor.
  10. TorchTensorAllocator = Callable[[Tuple[int], "torch.dtype"], "torch.Tensor"]
  11. @DeveloperAPI
  12. class Communicator(ABC):
  13. """
  14. Communicator for a group of Compiled Graph actors on NVIDIA GPU.
  15. The Compiled Graph execution leverages this internally to support communication
  16. between actors in the group.
  17. """
  18. @abstractmethod
  19. def initialize(self, rank: int) -> None:
  20. """
  21. Initialize the communicator from the actor.
  22. This is called once by Compiled Graph on each actor to initialize the
  23. communicator,before any other methods.
  24. Args:
  25. rank: The rank of this actor in the group.
  26. """
  27. raise NotImplementedError
  28. @abstractmethod
  29. def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
  30. """
  31. Get handles of all actors for this communicator group.
  32. """
  33. raise NotImplementedError
  34. @abstractmethod
  35. def get_rank(self, actor: ray.actor.ActorHandle) -> int:
  36. """
  37. Return the given actor's rank in the group.
  38. Args:
  39. actor: The actor handle to look up.
  40. """
  41. raise NotImplementedError
  42. @abstractmethod
  43. def get_self_rank(self) -> Optional[int]:
  44. """
  45. Return this actor's rank.
  46. """
  47. raise NotImplementedError
  48. def get_world_size(self) -> int:
  49. """
  50. Return the number of ranks in the group.
  51. """
  52. raise NotImplementedError
  53. @abstractmethod
  54. def send(self, value: "torch.Tensor", peer_rank: int) -> None:
  55. """
  56. Send a torch.Tensor to a peer.
  57. This returns when the send kernel has been queued, but the kernel may
  58. not have completed. Therefore, the caller should ensure that there are
  59. no concurrent writes to the sent `value` until the send has finished.
  60. Args:
  61. value: The torch.Tensor to send. It should already be on this
  62. actor's default device.
  63. peer_rank: The rank of the actor to send to.
  64. """
  65. raise NotImplementedError
  66. @abstractmethod
  67. def recv(
  68. self,
  69. shape: Tuple[int],
  70. dtype: "torch.dtype",
  71. peer_rank: int,
  72. allocator: Optional[TorchTensorAllocator] = None,
  73. ) -> "torch.Tensor":
  74. """
  75. Receive a torch.Tensor from a peer and synchronize.
  76. After this call returns, the receive buffer is safe to read from from
  77. any stream. An RayChannelError will be raised if an error occurred (e.g.,
  78. remote actor died), and the buffer is not safe to read.
  79. Args:
  80. shape: The shape of the tensor to receive.
  81. dtype: The dtype of the tensor to receive.
  82. peer_rank: The rank of the actor to receive from.
  83. allocator: A function to allocate the tensor to receive into.
  84. """
  85. raise NotImplementedError
  86. @property
  87. @abstractmethod
  88. def recv_stream(self):
  89. """
  90. Return the torch stream context used for receiving tensors.
  91. """
  92. raise NotImplementedError
  93. @property
  94. @abstractmethod
  95. def send_stream(self):
  96. """
  97. Return the torch stream context used for sending tensors.
  98. """
  99. raise NotImplementedError
  100. @abstractmethod
  101. def allgather(
  102. self,
  103. send_buf: "torch.Tensor",
  104. recv_buf: "torch.Tensor",
  105. ) -> None:
  106. """
  107. Collectively allgather the tensor across the group.
  108. Args:
  109. send_buf: The input torch.tensor to allgather. It should already be
  110. on this actor's default device.
  111. recv_buf: The output torch.tensor to store the allgather result.
  112. """
  113. raise NotImplementedError
  114. @abstractmethod
  115. def allreduce(
  116. self,
  117. send_buf: "torch.Tensor",
  118. recv_buf: "torch.Tensor",
  119. op: ReduceOp,
  120. ) -> None:
  121. """
  122. Collectively allreduce the tensor across the group.
  123. Args:
  124. send_buf: The input torch.tensor to allreduce. It should already be
  125. on this actor's default device.
  126. recv_buf: The output torch.tensor to store the allreduce result.
  127. op: The reduce operation.
  128. """
  129. raise NotImplementedError
  130. @abstractmethod
  131. def reducescatter(
  132. self,
  133. send_buf: "torch.Tensor",
  134. recv_buf: "torch.Tensor",
  135. op: ReduceOp,
  136. ) -> None:
  137. """
  138. Collectively reducescatter the tensor across the group.
  139. Args:
  140. send_buf: The input torch.tensor to reducescatter. It should already be
  141. on this actor's default device.
  142. recv_buf: The output torch.tensor to store the reducescatter result.
  143. op: The reduce operation.
  144. """
  145. raise NotImplementedError
  146. @abstractmethod
  147. def destroy(self) -> None:
  148. """
  149. Destroy the GPU communicator.
  150. Any destruction and cleanup for the GPU communicator should be
  151. done here. Implement as a noop is nothing is needed.
  152. """
  153. raise NotImplementedError
  154. @abstractmethod
  155. def get_transport_name(self) -> str:
  156. """
  157. Return the type of the communicator (gpu or cpu).
  158. """
  159. raise NotImplementedError
  160. @classmethod
  161. @abstractmethod
  162. def generate_communicator_id(cls) -> str:
  163. """
  164. Return the unique id of the communicator.
  165. """
  166. raise NotImplementedError