from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import ray
from ray.experimental.util.types import ReduceOp
from ray.util.annotations import DeveloperAPI

if TYPE_CHECKING:
    import torch


# Signature for a torch.Tensor allocator is:
# (shape: Tuple[int, ...], dtype: torch.dtype) -> torch.Tensor.
# The shape is a variable-length tuple: `Tuple[int]` would mean a tuple of
# exactly one int, which is wrong for tensors of arbitrary rank.
TorchTensorAllocator = Callable[[Tuple[int, ...], "torch.dtype"], "torch.Tensor"]


@DeveloperAPI
class Communicator(ABC):
    """
    Communicator for a group of Compiled Graph actors.

    The Compiled Graph execution leverages this internally to support
    communication between actors in the group. Implementations report their
    transport type (e.g. "gpu" or "cpu") via `get_transport_name()`.
    """

    @abstractmethod
    def initialize(self, rank: int) -> None:
        """
        Initialize the communicator from the actor.

        This is called once by Compiled Graph on each actor to initialize the
        communicator, before any other methods.

        Args:
            rank: The rank of this actor in the group.
        """
        raise NotImplementedError

    @abstractmethod
    def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
        """
        Get handles of all actors for this communicator group.
        """
        raise NotImplementedError

    @abstractmethod
    def get_rank(self, actor: "ray.actor.ActorHandle") -> int:
        """
        Return the given actor's rank in the group.

        Args:
            actor: The actor handle to look up.
        """
        raise NotImplementedError

    @abstractmethod
    def get_self_rank(self) -> Optional[int]:
        """
        Return this actor's rank.

        May return None; presumably when the calling process is not a member
        of the group (TODO: confirm against implementations).
        """
        raise NotImplementedError

    # NOTE(review): unlike every other method here, this is not marked
    # @abstractmethod. It is left that way so existing subclasses that do not
    # override it can still be instantiated; calling it on such a subclass
    # raises NotImplementedError.
    def get_world_size(self) -> int:
        """
        Return the number of ranks in the group.
        """
        raise NotImplementedError

    @abstractmethod
    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        """
        Send a torch.Tensor to a peer.

        This returns when the send kernel has been queued, but the kernel may
        not have completed. Therefore, the caller should ensure that there are
        no concurrent writes to the sent `value` until the send has finished.

        Args:
            value: The torch.Tensor to send. It should already be on this
                actor's default device.
            peer_rank: The rank of the actor to send to.
        """
        raise NotImplementedError

    @abstractmethod
    def recv(
        self,
        shape: Tuple[int, ...],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ) -> "torch.Tensor":
        """
        Receive a torch.Tensor from a peer and synchronize.

        After this call returns, the receive buffer is safe to read from any
        stream. A RayChannelError will be raised if an error occurred (e.g.,
        the remote actor died), in which case the buffer is not safe to read.

        Args:
            shape: The shape of the tensor to receive (any number of
                dimensions).
            dtype: The dtype of the tensor to receive.
            peer_rank: The rank of the actor to receive from.
            allocator: A function to allocate the tensor to receive into.

        Raises:
            RayChannelError: If the receive failed and the buffer must not be
                read.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def recv_stream(self):
        """
        Return the torch stream context used for receiving tensors.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def send_stream(self):
        """
        Return the torch stream context used for sending tensors.
        """
        raise NotImplementedError

    @abstractmethod
    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ) -> None:
        """
        Collectively allgather the tensor across the group.

        Args:
            send_buf: The input torch.Tensor to allgather. It should already be
                on this actor's default device.
            recv_buf: The output torch.Tensor to store the allgather result.
        """
        raise NotImplementedError

    @abstractmethod
    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp,
    ) -> None:
        """
        Collectively allreduce the tensor across the group.

        Args:
            send_buf: The input torch.Tensor to allreduce. It should already be
                on this actor's default device.
            recv_buf: The output torch.Tensor to store the allreduce result.
            op: The reduce operation.
        """
        raise NotImplementedError

    @abstractmethod
    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp,
    ) -> None:
        """
        Collectively reducescatter the tensor across the group.

        Args:
            send_buf: The input torch.Tensor to reducescatter. It should
                already be on this actor's default device.
            recv_buf: The output torch.Tensor to store the reducescatter
                result.
            op: The reduce operation.
        """
        raise NotImplementedError

    @abstractmethod
    def destroy(self) -> None:
        """
        Destroy the communicator.

        Any destruction and cleanup for the communicator should be done here.
        Implement as a no-op if nothing is needed.
        """
        raise NotImplementedError

    @abstractmethod
    def get_transport_name(self) -> str:
        """
        Return the type of the communicator (gpu or cpu).
        """
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def generate_communicator_id(cls) -> str:
        """
        Return the unique id of the communicator.
        """
        raise NotImplementedError