import asyncio
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import ray
from ray.experimental.channel.communicator import (
    Communicator,
    ReduceOp,
    TorchTensorAllocator,
)

if TYPE_CHECKING:
    import torch


@ray.remote(num_cpus=0)
class CPUCommBarrier:
    """
    Barrier actor that blocks the given number of actors until all actors have
    reached the Barrier.

    p2p operations are not done here (completed via shared memory channel).
    """

    def __init__(self, num_actors: int):
        """
        Args:
            num_actors: Number of participants that must call
                `wait_collective` with the same `op_id` before the
                reduction is executed and the result is released.
        """
        self.num_actors = num_actors
        self.condition = asyncio.Condition()
        # Stores the data for each collective operation. Holds the list of
        # submitted tensors until the last participant arrives, after which
        # the entry is replaced by the reduced result.
        self.collective_data: Dict[int, List["torch.Tensor"]] = defaultdict(list)
        # Stores the shape of data for each collective operation
        self.collective_data_shape: Dict[int, "torch.Tensor.type"] = {}
        # Buffer for the number of actors seen
        self.num_actors_seen = defaultdict(int)
        # Number of actors who have read the result, and are about to exit the
        # function. State is kept so we only garbage collect after the last
        # actor has read the relevant data.
        self.num_actors_read = defaultdict(int)

    async def wait_collective(self, op_id: int, data: "torch.Tensor", op: ReduceOp):
        """
        Wait at the communicator until all actors have sent `op_id` and `data`.
        Once data from all actors is received, execute the collective `op`
        on the communicator actor and return the result.

        Args:
            op_id: Identifier of the collective operation; all participants of
                one collective must pass the same value.
            data: This participant's input tensor.
            op: The reduction to apply across all submitted tensors.

        Returns:
            The reduced tensor produced by `_apply_op`.

        Raises:
            ValueError: Raised by `_apply_op` in the call of the last
                participant to arrive if `op` is not supported.
        """
        async with self.condition:
            self.collective_data[op_id].append(data)
            self.num_actors_seen[op_id] += 1

            if self.num_actors_seen[op_id] == self.num_actors:
                # Apply the collective operation across all gathered tensors
                data = self._apply_op(op, self.collective_data[op_id])
                # Replace the list of inputs with the reduced result so the
                # waiting participants can pick it up.
                self.collective_data[op_id] = data
                self.condition.notify_all()
            else:
                await self.condition.wait_for(
                    lambda: self.num_actors_seen[op_id] == self.num_actors
                )
                data = self.collective_data[op_id]

            # Per-op state is dropped only once every participant has read the
            # result, so late readers never observe a missing entry.
            self.num_actors_read[op_id] += 1
            if self.num_actors_read[op_id] == self.num_actors:
                del self.collective_data[op_id]
                del self.num_actors_seen[op_id]
                del self.num_actors_read[op_id]

            return data

    def _apply_op(self, op: ReduceOp, tensors: List["torch.Tensor"]) -> "torch.Tensor":
        """Apply the specified reduction operation across a list of tensors.

        Args:
            op: The reduction to apply (SUM, PRODUCT, MAX, MIN or AVG).
            tensors: The tensors to reduce; must contain at least one element.

        Returns:
            A new tensor holding the reduction; the inputs are not modified.

        Raises:
            ValueError: If `op` is not one of the supported reductions.
        """
        # `torch` is only imported under TYPE_CHECKING at module level, so it
        # is not bound at runtime. Import it here so that the MAX/MIN branches
        # do not fail with a NameError, while keeping the module importable
        # without torch.
        import torch

        # Clone so the in-place accumulation below never mutates an input.
        result = tensors[0].clone()
        if op == ReduceOp.SUM:
            for tensor in tensors[1:]:
                result += tensor
        elif op == ReduceOp.PRODUCT:
            for tensor in tensors[1:]:
                result *= tensor
        elif op == ReduceOp.MAX:
            for tensor in tensors[1:]:
                result = torch.max(result, tensor)
        elif op == ReduceOp.MIN:
            for tensor in tensors[1:]:
                result = torch.min(result, tensor)
        elif op == ReduceOp.AVG:
            result = sum(tensors) / len(tensors)
        else:
            raise ValueError(f"Operation {op} not supported")
        return result


class CPUCommunicator(Communicator):
    """
    Uses a CPU-based communicator actor instead of an accelerator group like NCCL.
    """

    def __init__(self, world_size: int, actor_handles: List["ray.actor.ActorHandle"]):
        """We use the op index to synchronize the sender and receiver at the
        communicator actor.

        Args:
            world_size: Number of participants in the communicator.
            actor_handles: Handles of the participating actors; an actor's
                position in this list is its rank (see `get_rank`).
        """
        self._world_size = world_size
        self._actor_handles = actor_handles
        # Per-barrier counter of completed collectives; its current value is
        # used as the `op_id` of the next collective on that barrier.
        self.num_ops = defaultdict(int)

        # For collective communication, one barrier will be created for
        # each unique group of participants.
        self.barriers = set()
        self._rank = None

    def send(self, tensor: "torch.Tensor", peer_rank: int):
        """No-op: p2p transfers do not go through this communicator."""
        # p2p operations are done via a shared memory channel, initialized in
        # `create_channel` of `TorchTensorType`
        pass

    def recv(
        self,
        shape: Tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ):
        """No-op: p2p transfers do not go through this communicator."""
        # See the comment on `send`
        pass

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ):
        """Not implemented for the CPU communicator."""
        raise NotImplementedError

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        """Reduce `send_buf` across all participants and write the result
        into `recv_buf`.

        Blocks until every participant has contributed its tensor to the
        shared barrier actor.

        Args:
            send_buf: This participant's input tensor.
            recv_buf: Tensor that receives the reduced result; must not be
                None.
            op: The reduction to apply. Defaults to `ReduceOp.SUM`.
        """
        all_ranks = [
            self.get_rank(actor_handle) for actor_handle in self.get_actor_handles()
        ]
        # The actor name is derived from the sorted ranks, so every
        # participant computes the same name and `get_if_exists=True`
        # resolves to the same named barrier actor.
        barrier_key = "barrier-collective-" + "-".join(map(str, sorted(all_ranks)))
        barrier = CPUCommBarrier.options(name=barrier_key, get_if_exists=True).remote(
            self._world_size
        )
        self.barriers.add(barrier)

        result = ray.get(
            barrier.wait_collective.remote(self.num_ops[barrier_key], send_buf, op)
        )

        assert recv_buf is not None, "Receiving buffer required for CPUCommunicator"
        recv_buf[:] = result[:]
        # Advance the op index so the next collective on this barrier uses a
        # fresh `op_id`.
        self.num_ops[barrier_key] += 1

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ):
        """Not implemented for the CPU communicator."""
        raise NotImplementedError

    def destroy(self) -> None:
        """Kill every barrier actor this communicator has used."""
        for barrier in self.barriers:
            ray.kill(barrier)

    def initialize(self, rank: int) -> None:
        """Record this process's rank in the communicator."""
        self._rank = rank

    def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
        """Return the handles of all participating actors."""
        return self._actor_handles

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """
        Return the given actor's rank in the CPU communicator.

        Args:
            actor: The actor handle to look up.

        Raises:
            ValueError: If `actor` is not one of the communicator's actors.
        """
        actor_ids = [a._ray_actor_id for a in self._actor_handles]
        try:
            rank = actor_ids.index(actor._ray_actor_id)
        except ValueError:
            # Suppress the uninformative "x is not in list" context.
            raise ValueError("Actor is not in the CPUCommunicator group.") from None
        return rank

    def get_self_rank(self) -> Optional[int]:
        """Return the rank set by `initialize`, or None if it was not called."""
        return self._rank

    def get_world_size(self) -> int:
        """
        Return the number of ranks in the CPU communicator.
        """
        return self._world_size

    def get_transport_name(self) -> str:
        """Return the transport identifier of this communicator (`"cpu"`)."""
        return "cpu"

    def recv_stream(self):
        """Not implemented for the CPU communicator."""
        raise NotImplementedError

    def send_stream(self):
        """Not implemented for the CPU communicator."""
        raise NotImplementedError

    @classmethod
    def generate_communicator_id(cls) -> str:
        """Return a new random UUID4 string to identify a communicator."""
        import uuid

        return str(uuid.uuid4())