import logging from typing import List, Optional, Union import ray from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation from ray.dag.constants import ( BIND_INDEX_KEY, COLLECTIVE_OPERATION_KEY, IS_CLASS_METHOD_OUTPUT_KEY, PARENT_CLASS_NODE_KEY, ) from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType from ray.experimental.util.types import ( AllGatherOp, AllReduceOp, ReduceOp, ReduceScatterOp, _CollectiveOp, ) from ray.util.collective.types import ReduceOp as RayReduceOp logger = logging.getLogger(__name__) def _bind( inputs: Union[List["ray.dag.DAGNode"], List[List["ray.dag.DAGNode"]]], op: _CollectiveOp, transport: Optional[Union[str, Communicator]] = None, ): """ Bind inputs (input nodes or lists of input nodes) with a collective operation. The collective operation is applied to each list of input nodes. The output nodes will have the same shape as the input nodes. Example of binding a list of input node: with InputNode() as inp: res_comp1 = [actor.comp1.bind(inp) for actor in actors] res_comp2 = [actor.comp2.bind(inp) for actor in actors] res_ar = allreduce.bind([res_comp1, res_comp2]) Requirements: 1. Each input node returns a torch tensor. 2. Each input node within a list is from a different actor. 3. If lists of input nodes are provided, the order of actors should be the same for each nested list. 4. If a custom transport is specified, its actor set matches the actor set of the input nodes. 5. If input nodes are provided, then all tensors have the same shape. If lists of input nodes are provided, then all tensors in each list have the same shape. Requirements 1-3 are checked in the `CollectiveGroup` constructor. Requirement 4 is not checked yet. Args: inputs: A list of DAG nodes or a list of lists of DAG nodes. Each leaf list should contain one object per actor. op: The collective operation. transport: GPU communicator for the collective operation. If not specified, the default ACCELERATOR is used. Returns: A list of collective output nodes or a list of lists of collective output nodes, with the same shape as the input nodes. Each output node has the same order and belongs to the same actor as the corresponding input node. """ if isinstance(inputs[0], list) and not isinstance(op, AllReduceOp): raise ValueError( "Currently binding a nested list of dag nodes is only supported for allreduce" ) # Convert list of DAGNode into nested list for type checking if not isinstance(inputs[0], list): inputs = [inputs] if transport is None: transport = TorchTensorType.ACCELERATOR collective_op = _CollectiveOperation(inputs, op, transport) collective_output_nodes: List[CollectiveOutputNode] = [] if isinstance(op, AllGatherOp): method_name = "allgather" elif isinstance(op, AllReduceOp): method_name = f"allreduce.{op.reduceOp}" elif isinstance(op, ReduceScatterOp): method_name = f"reducescatter.{op.reduceOp}" else: raise ValueError(f"Expected a collective operation, but got {op}") for i in range(len(inputs[0])): input_node_list = [l[i] for l in inputs if l] actor_handle: Optional["ray.actor.ActorHandle"] = input_node_list[ 0 ]._get_actor_handle() assert actor_handle is not None collective_output_node = CollectiveOutputNode( method_name=method_name, method_args=tuple(input_node_list), method_kwargs=dict(), method_options=dict(), other_args_to_resolve={ PARENT_CLASS_NODE_KEY: actor_handle, BIND_INDEX_KEY: actor_handle._ray_dag_bind_index, COLLECTIVE_OPERATION_KEY: collective_op, }, ) actor_handle._ray_dag_bind_index += 1 if len(input_node_list) > 1: output_nodes: List[CollectiveOutputNode] = [] for i in range(len(input_node_list)): output_node = CollectiveOutputNode( f"return_idx_{i}", (collective_output_node, i), dict(), dict(), { BIND_INDEX_KEY: collective_output_node._get_bind_index(), IS_CLASS_METHOD_OUTPUT_KEY: True, PARENT_CLASS_NODE_KEY: actor_handle, }, ) output_nodes.append(output_node) collective_output_nodes.append(output_nodes) else: collective_output_nodes.append(collective_output_node) return collective_output_nodes class AllGatherWrapper: """Wrapper for NCCL all-gather.""" def bind( self, input_nodes: List["ray.dag.DAGNode"], transport: Optional[Union[str, Communicator]] = None, ) -> List[CollectiveOutputNode]: return _bind(input_nodes, AllGatherOp(), transport) def __call__( self, tensor_list, tensor, group_name: str = "default", ): from ray.util.collective.collective import allgather return allgather(tensor_list, tensor, group_name) class AllReduceWrapper: """Wrapper for NCCL all-reduce.""" def bind( self, input_nodes: List["ray.dag.DAGNode"], op: ReduceOp = ReduceOp.SUM, transport: Optional[Union[str, Communicator]] = None, ) -> List[CollectiveOutputNode]: if not isinstance(op, ReduceOp): raise ValueError(f"Unexpected operation: {op}") return _bind(input_nodes, AllReduceOp(reduceOp=op), transport) def __call__( self, tensor, group_name: str = "default", op: RayReduceOp = RayReduceOp.SUM, ): from ray.util.collective.collective import allreduce return allreduce(tensor, group_name, op) class ReduceScatterWrapper: """Wrapper for NCCL reduce-scatter.""" def bind( self, input_nodes: List["ray.dag.DAGNode"], op: ReduceOp = ReduceOp.SUM, transport: Optional[Union[str, Communicator]] = None, ) -> List[CollectiveOutputNode]: if not isinstance(op, ReduceOp): raise ValueError(f"Unexpected operation: {op}") return _bind(input_nodes, ReduceScatterOp(reduceOp=op), transport) def __call__( self, tensor, group_name: str = "default", op: RayReduceOp = RayReduceOp.SUM, ): from ray.util.collective.collective import reducescatter return reducescatter(tensor, group_name, op) allgather = AllGatherWrapper() allreduce = AllReduceWrapper() reducescatter = ReduceScatterWrapper()