cpu_communicator.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. import asyncio
  2. from collections import defaultdict
  3. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
  4. import ray
  5. from ray.experimental.channel.communicator import (
  6. Communicator,
  7. ReduceOp,
  8. TorchTensorAllocator,
  9. )
  10. if TYPE_CHECKING:
  11. import torch
  12. @ray.remote(num_cpus=0)
  13. class CPUCommBarrier:
  14. """
  15. Barrier actor that blocks the given number of actors until all actors have
  16. reached the Barrier.
  17. p2p operations are not done here (completed via shared memory channel).
  18. """
  19. def __init__(self, num_actors: int):
  20. self.num_actors = num_actors
  21. self.condition = asyncio.Condition()
  22. # Stores the data for each collective operation
  23. self.collective_data: Dict[int, List["torch.Tensor"]] = defaultdict(list)
  24. # Stores the shape of data for each collective operation
  25. self.collective_data_shape: Dict[int, "torch.Tensor.type"] = {}
  26. # Buffer for the number of actors seen
  27. self.num_actors_seen = defaultdict(int)
  28. # Number of actors who have read the result, and are about to exit the function.
  29. # State is kept so we only garbage collect after the last actor has read the
  30. # relevant data.
  31. self.num_actors_read = defaultdict(int)
  32. async def wait_collective(self, op_id: int, data: "torch.Tensor", op: ReduceOp):
  33. """
  34. Wait at the communicator until all actors have sent `op_id` and `data`.
  35. Once data from all actors is received, execute the collective `op`
  36. on the communicator actor and return the result.
  37. """
  38. async with self.condition:
  39. self.collective_data[op_id].append(data)
  40. self.num_actors_seen[op_id] += 1
  41. if self.num_actors_seen[op_id] == self.num_actors:
  42. # Apply the collective operation across all gathered tensors
  43. data = self._apply_op(op, self.collective_data[op_id])
  44. self.collective_data[op_id] = data
  45. self.condition.notify_all()
  46. else:
  47. await self.condition.wait_for(
  48. lambda: self.num_actors_seen[op_id] == self.num_actors
  49. )
  50. data = self.collective_data[op_id]
  51. self.num_actors_read[op_id] += 1
  52. if self.num_actors_read[op_id] == self.num_actors:
  53. del self.collective_data[op_id]
  54. del self.num_actors_seen[op_id]
  55. del self.num_actors_read[op_id]
  56. return data
  57. def _apply_op(self, op: ReduceOp, tensors: List["torch.Tensor"]) -> "torch.Tensor":
  58. """Apply the specified reduction operation across a list of tensors."""
  59. result = tensors[0].clone()
  60. if op == ReduceOp.SUM:
  61. for tensor in tensors[1:]:
  62. result += tensor
  63. elif op == ReduceOp.PRODUCT:
  64. for tensor in tensors[1:]:
  65. result *= tensor
  66. elif op == ReduceOp.MAX:
  67. for tensor in tensors[1:]:
  68. result = torch.max(result, tensor)
  69. elif op == ReduceOp.MIN:
  70. for tensor in tensors[1:]:
  71. result = torch.min(result, tensor)
  72. elif op == ReduceOp.AVG:
  73. result = sum(tensors) / len(tensors)
  74. else:
  75. raise ValueError(f"Operation {op} not supported")
  76. return result
  77. class CPUCommunicator(Communicator):
  78. """
  79. Uses a CPU-based communicator actor instead of an accelerator group like NCCL.
  80. """
  81. def __init__(self, world_size: int, actor_handles: List["ray.actor.ActorHandle"]):
  82. """We use the op index to synchronize the sender and receiver at the
  83. communicator actor."""
  84. self._world_size = world_size
  85. self._actor_handles = actor_handles
  86. self.num_ops = defaultdict(int)
  87. # For collective communication, one barrier will be created for
  88. # each unique group of participants.
  89. self.barriers = set()
  90. self._rank = None
  91. def send(self, tensor: "torch.Tensor", peer_rank: int):
  92. # p2p operations are done via a shared memory channel, initialized in
  93. # `create_channel` of `TorchTensorType`
  94. pass
  95. def recv(
  96. self,
  97. shape: Tuple[int],
  98. dtype: "torch.dtype",
  99. peer_rank: int,
  100. allocator: Optional[TorchTensorAllocator] = None,
  101. ):
  102. # See the comment on `send`
  103. pass
  104. def allgather(
  105. self,
  106. send_buf: "torch.Tensor",
  107. recv_buf: "torch.Tensor",
  108. ):
  109. raise NotImplementedError
  110. def allreduce(
  111. self,
  112. send_buf: "torch.Tensor",
  113. recv_buf: "torch.Tensor",
  114. op: ReduceOp = ReduceOp.SUM,
  115. ):
  116. all_ranks = [
  117. self.get_rank(actor_handle) for actor_handle in self.get_actor_handles()
  118. ]
  119. barrier_key = "barrier-collective-" + "-".join(map(str, sorted(all_ranks)))
  120. barrier = CPUCommBarrier.options(name=barrier_key, get_if_exists=True).remote(
  121. self._world_size
  122. )
  123. self.barriers.add(barrier)
  124. result = ray.get(
  125. barrier.wait_collective.remote(self.num_ops[barrier_key], send_buf, op)
  126. )
  127. assert recv_buf is not None, "Receiving buffer required for CPUCommunicator"
  128. recv_buf[:] = result[:]
  129. self.num_ops[barrier_key] += 1
  130. def reducescatter(
  131. self,
  132. send_buf: "torch.Tensor",
  133. recv_buf: "torch.Tensor",
  134. op: ReduceOp = ReduceOp.SUM,
  135. ):
  136. raise NotImplementedError
  137. def destroy(self) -> None:
  138. for barrier in self.barriers:
  139. ray.kill(barrier)
  140. def initialize(self, rank: int) -> None:
  141. self._rank = rank
  142. def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
  143. return self._actor_handles
  144. def get_rank(self, actor: ray.actor.ActorHandle) -> int:
  145. """
  146. Return the given actor's rank in the CPU communicator.
  147. Args:
  148. actor: The actor handle to look up.
  149. """
  150. actor_ids = [a._ray_actor_id for a in self._actor_handles]
  151. try:
  152. rank = actor_ids.index(actor._ray_actor_id)
  153. except ValueError:
  154. raise ValueError("Actor is not in the CPUCommunicator group.")
  155. return rank
  156. def get_self_rank(self) -> Optional[int]:
  157. return self._rank
  158. def get_world_size(self) -> int:
  159. """
  160. Return the number of ranks in the CPU communicator.
  161. """
  162. return self._world_size
  163. def get_transport_name(self) -> str:
  164. return "cpu"
  165. def recv_stream(self):
  166. raise NotImplementedError
  167. def send_stream(self):
  168. raise NotImplementedError
  169. @classmethod
  170. def generate_communicator_id(cls) -> str:
  171. import uuid
  172. return str(uuid.uuid4())