| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307 |
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
- if TYPE_CHECKING:
- import torch
- import ray
- from ray.dag import (
- ClassMethodNode,
- DAGNode,
- )
- from ray.dag.constants import COLLECTIVE_OPERATION_KEY, IS_CLASS_METHOD_OUTPUT_KEY
- from ray.experimental.channel import ChannelContext
- from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
- from ray.experimental.util.types import (
- AllGatherOp,
- AllReduceOp,
- ReduceScatterOp,
- _CollectiveOp,
- )
- from ray.util.annotations import DeveloperAPI
- class _CollectiveOperation:
- """
- Represent metadata for a collective communicator collective operation.
- Args:
- inputs: A list of lists of DAGNode. Each nested list inside
- of inputs should contain exactly one object per actor.
- If multiple nested lists are provided, then the order of
- actors should be the same for each nested list.
- op: The collective operation to perform.
- transport: The transport to use for the collective operation.
- Requirements:
- 1. Input nodes are unique.
- 2. Actor handles are unique.
- 3. Actor handles match the custom communicator group if specified.
- """
- def __init__(
- self,
- inputs: List[List[DAGNode]],
- op: _CollectiveOp,
- transport: Optional[Union[str, Communicator]] = None,
- ):
- self._actor_handles: List["ray.actor.ActorHandle"] = []
- for i, input_nodes in enumerate(inputs):
- # Check non-empty input list
- if len(input_nodes) == 0:
- nested_list_error_msg = f" at index {i}" if len(inputs) > 1 else ""
- raise ValueError(
- f"Expected non-empty input list{nested_list_error_msg}."
- )
- # Check input nodes are DAGNode
- if not all(isinstance(node, DAGNode) for node in input_nodes):
- nested_list_error_msg = (
- f" at list at index {i}" if len(inputs) > 1 else ""
- )
- raise ValueError(
- f"Expected all input nodes to be DAGNode{nested_list_error_msg}, "
- f"but got {input_nodes}."
- )
- # Check unique input nodes
- if len(set(input_nodes)) != len(input_nodes):
- duplicates = [
- input_node
- for input_node in input_nodes
- if input_nodes.count(input_node) > 1
- ]
- nested_list_error_msg = (
- f" at list at index {i}" if len(inputs) > 1 else ""
- )
- raise ValueError(
- f"Expected unique input nodes{nested_list_error_msg}, but found duplicates: "
- f"{duplicates}"
- )
- current_actor_handles = []
- for input_node in input_nodes:
- actor_handle = input_node._get_actor_handle()
- if actor_handle is None:
- nested_list_error_msg = (
- f" at list at index {i}" if len(inputs) > 1 else ""
- )
- raise ValueError(
- f"Expected an actor handle from the input node{nested_list_error_msg}"
- )
- current_actor_handles.append(actor_handle)
- # Check unique actor handles
- if len(set(current_actor_handles)) != len(current_actor_handles):
- invalid_input_nodes = [
- input_node
- for input_node in input_nodes
- if current_actor_handles.count(input_node._get_actor_handle()) > 1
- ]
- nested_list_error_msg = (
- f" at list at index {i}" if len(inputs) > 1 else ""
- )
- raise ValueError(
- f"Expected unique actor handles{nested_list_error_msg}, "
- "but found duplicate actor handles from input nodes: "
- f"{invalid_input_nodes}"
- )
- if i == 0:
- first_actor_handles = current_actor_handles
- # Check all lists of DAGNode have the same number of nodes
- if len(inputs[0]) != len(inputs[i]):
- raise ValueError(
- f"Expected all input lists to have the same number of nodes. "
- f"List at index 0 has length {len(inputs[0])}, but list at "
- f"index {i} has length {len(inputs[i])}."
- )
- # Check all lists of DAGNode have same set of actor handles
- if set(first_actor_handles) != set(current_actor_handles):
- raise ValueError(
- f"Expected all input lists to have the same set of actor handles. "
- f"List at index 0 has actors {set(first_actor_handles)}, but list at "
- f"index {i} has actors {set(current_actor_handles)}."
- )
- # Check all lists of DAGNode have same order of actor handles
- for j, (first, current) in enumerate(
- zip(first_actor_handles, current_actor_handles)
- ):
- if first != current:
- raise ValueError(
- f"Expected all input lists to have the same order of actor handles. "
- f"List at index 0 has actor {first} at position {j}, but list at "
- f"index {i} has actor {current} at position {j}."
- )
- self._actor_handles = current_actor_handles
- self._op = op
- if transport is None:
- transport = TorchTensorType.ACCELERATOR
- self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
- if isinstance(transport, Communicator):
- if set(transport.get_actor_handles()) != set(self._actor_handles):
- raise ValueError(
- "Expected actor handles to match the custom communicator group"
- )
- def __str__(self) -> str:
- return (
- f"CollectiveOperation("
- f"_actor_handles={self._actor_handles}, "
- f"_op={self._op}, "
- f"_type_hint={self._type_hint})"
- )
- @property
- def actor_handles(self) -> List["ray.actor.ActorHandle"]:
- return self._actor_handles
- @property
- def type_hint(self) -> TorchTensorType:
- return self._type_hint
- def get_communicator(self) -> Communicator:
- if self._type_hint.communicator_id is not None:
- ctx = ChannelContext.get_current()
- communicator = ctx.communicators[self._type_hint.communicator_id]
- elif self._type_hint.get_custom_communicator() is not None:
- communicator = self._type_hint.get_custom_communicator()
- else:
- raise ValueError("Expected a communicator group")
- return communicator
- def execute(
- self, *send_buf: "torch.Tensor"
- ) -> Union["torch.Tensor", Tuple["torch.Tensor", ...]]:
- """
- Call the collective operation on the input tensor(s). Output tensor(s) are
- allocated and returned.
- Args:
- *send_buf: A variable number of torch tensors to send to the collective
- operation. The tensors have the same order as the input nodes.
- Returns:
- A torch tensor or a tuple of torch tensors containing the results of the
- collective operation. The output tensors have the same length and order
- as the input node list of the actor of this operation.
- """
- import torch
- if not all(isinstance(t, torch.Tensor) for t in send_buf):
- raise ValueError("Expected a torch tensor for each input node")
- communicator = self.get_communicator()
- if isinstance(self._op, AllGatherOp):
- assert len(send_buf) == 1
- t = send_buf[0]
- world_size = len(self._actor_handles)
- recv_buf = torch.empty(
- (t.shape[0] * world_size, *t.shape[1:]),
- dtype=t.dtype,
- device=t.device,
- )
- communicator.allgather(t, recv_buf)
- elif isinstance(self._op, AllReduceOp):
- if len(send_buf) == 1:
- t = send_buf[0]
- recv_buf = torch.empty_like(t)
- communicator.allreduce(t, recv_buf, self._op.reduceOp)
- else:
- if not all(t.dtype == send_buf[0].dtype for t in send_buf):
- raise ValueError(
- "Expected all input tensors to have the same dtype, "
- f"but got {[t.dtype for t in send_buf]}"
- )
- def unflatten_from(flat_buf, bufs):
- views = []
- offset = 0
- for t in bufs:
- numel = t.numel()
- t = flat_buf[offset : offset + numel].view(t.shape)
- views.append(t)
- offset += numel
- return tuple(views)
- flat_buf = torch.nn.utils.parameters_to_vector(send_buf)
- communicator.allreduce(flat_buf, flat_buf, self._op.reduceOp)
- recv_buf = unflatten_from(flat_buf, send_buf)
- elif isinstance(self._op, ReduceScatterOp):
- assert len(send_buf) == 1
- t = send_buf[0]
- world_size = len(self._actor_handles)
- if t.shape[0] % world_size != 0:
- raise ValueError(
- "Expected the first dimension of the input tensor to be divisible "
- f"by the world size {world_size}"
- )
- recv_buf = torch.empty(
- (t.shape[0] // world_size, *t.shape[1:]),
- dtype=t.dtype,
- device=t.device,
- )
- communicator.reducescatter(t, recv_buf, self._op.reduceOp)
- return recv_buf
- @DeveloperAPI
- class CollectiveOutputNode(ClassMethodNode):
- """Represent an output node from a communicator collective operation in a Ray DAG."""
- def __init__(
- self,
- method_name: str,
- method_args: Tuple[
- DAGNode,
- ],
- method_kwargs: Dict[str, Any],
- method_options: Dict[str, Any],
- other_args_to_resolve: Dict[str, Any],
- ):
- # Parse the input node(s).
- self._inputs = method_args
- # Parse the collective operation.
- self._collective_op: _CollectiveOperation = other_args_to_resolve.get(
- COLLECTIVE_OPERATION_KEY, None
- )
- self._is_class_method_output: bool = other_args_to_resolve.get(
- IS_CLASS_METHOD_OUTPUT_KEY, False
- )
- if self._collective_op is None and not self._is_class_method_output:
- raise ValueError("Expected a collective operation")
- super().__init__(
- method_name,
- method_args,
- method_kwargs,
- method_options,
- other_args_to_resolve,
- )
- def _copy_impl(
- self,
- new_args: List[Any],
- new_kwargs: Dict[str, Any],
- new_options: Dict[str, Any],
- new_other_args_to_resolve: Dict[str, Any],
- ):
- return CollectiveOutputNode(
- self._method_name,
- new_args,
- new_kwargs,
- new_options,
- other_args_to_resolve=new_other_args_to_resolve,
- )
- def _execute_impl(self, *args, **kwargs):
- raise NotImplementedError(
- "CollectiveOutputNode is only supported with dag.experimental_compile()"
- )
- @property
- def collective_op(self) -> _CollectiveOperation:
- return self._collective_op
|