import logging
from typing import List, Optional, Union

import ray
from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation
from ray.dag.constants import (
    BIND_INDEX_KEY,
    COLLECTIVE_OPERATION_KEY,
    IS_CLASS_METHOD_OUTPUT_KEY,
    PARENT_CLASS_NODE_KEY,
)
from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
from ray.experimental.util.types import (
    AllGatherOp,
    AllReduceOp,
    ReduceOp,
    ReduceScatterOp,
    _CollectiveOp,
)
from ray.util.collective.types import ReduceOp as RayReduceOp

logger = logging.getLogger(__name__)


def _bind(
    inputs: Union[List["ray.dag.DAGNode"], List[List["ray.dag.DAGNode"]]],
    op: _CollectiveOp,
    transport: Optional[Union[str, Communicator]] = None,
):
    """
    Bind inputs (a list of DAG nodes or a list of lists of DAG nodes) with a
    collective operation.

    The collective operation is applied to each list of input nodes, and one
    collective output node is created per actor.

    Example of binding a nested list of input nodes:

        with InputNode() as inp:
            res_comp1 = [actor.comp1.bind(inp) for actor in actors]
            res_comp2 = [actor.comp2.bind(inp) for actor in actors]
            res_ar = allreduce.bind([res_comp1, res_comp2])

    Requirements:
    1. Each input node returns a torch tensor.
    2. Each input node within a list is from a different actor.
    3. If lists of input nodes are provided, the order of actors is the same
       for each nested list.
    4. If a custom transport is specified, its actor set matches the actor
       set of the input nodes.
    5. If a flat list of input nodes is provided, all tensors have the same
       shape. If lists of input nodes are provided, all tensors within each
       list have the same shape.

    Requirements 1-3 are checked in the `_CollectiveOperation` constructor.
    Requirement 4 is not checked yet.

    Args:
        inputs: A list of DAG nodes or a list of lists of DAG nodes. Each leaf
            list should contain one node per actor.
        op: The collective operation.
        transport: GPU communicator for the collective operation. If not
            specified, `TorchTensorType.ACCELERATOR` is used.

    Returns:
        If a flat list (or a nested list with a single leaf list) is provided,
        a list of collective output nodes, one per actor, in the same actor
        order as the input. If a nested list with more than one leaf list is
        provided, a list with one entry per actor, in the same actor order as
        each leaf list; entry i is a list of actor i's output nodes, one per
        leaf list, in leaf-list order.
    """
    if isinstance(inputs[0], list) and not isinstance(op, AllReduceOp):
        raise ValueError(
            "Currently binding a nested list of DAG nodes is only supported for "
            "allreduce"
        )

    # Wrap a flat list of DAG nodes into a nested list so that both input
    # forms are handled uniformly below.
    if not isinstance(inputs[0], list):
        inputs = [inputs]
    if transport is None:
        transport = TorchTensorType.ACCELERATOR
    collective_op = _CollectiveOperation(inputs, op, transport)
    collective_output_nodes: List[CollectiveOutputNode] = []

    if isinstance(op, AllGatherOp):
        method_name = "allgather"
    elif isinstance(op, AllReduceOp):
        method_name = f"allreduce.{op.reduceOp}"
    elif isinstance(op, ReduceScatterOp):
        method_name = f"reducescatter.{op.reduceOp}"
    else:
        raise ValueError(f"Expected a collective operation, but got {op}")

    # Create one collective output node per actor. Position i in every leaf
    # list refers to the same actor (requirement 3).
    for i in range(len(inputs[0])):
        input_node_list = [node_list[i] for node_list in inputs if node_list]
        actor_handle: Optional["ray.actor.ActorHandle"] = input_node_list[
            0
        ]._get_actor_handle()
        assert actor_handle is not None
        collective_output_node = CollectiveOutputNode(
            method_name=method_name,
            method_args=tuple(input_node_list),
            method_kwargs=dict(),
            method_options=dict(),
            other_args_to_resolve={
                PARENT_CLASS_NODE_KEY: actor_handle,
                BIND_INDEX_KEY: actor_handle._ray_dag_bind_index,
                COLLECTIVE_OPERATION_KEY: collective_op,
            },
        )
        actor_handle._ray_dag_bind_index += 1

        if len(input_node_list) > 1:
            # With several inputs per actor, add one output node per input
            # that selects the j-th return value of the collective node.
            output_nodes: List[CollectiveOutputNode] = []
            for j in range(len(input_node_list)):
                output_node = CollectiveOutputNode(
                    f"return_idx_{j}",
                    (collective_output_node, j),
                    dict(),
                    dict(),
                    {
                        BIND_INDEX_KEY: collective_output_node._get_bind_index(),
                        IS_CLASS_METHOD_OUTPUT_KEY: True,
                        PARENT_CLASS_NODE_KEY: actor_handle,
                    },
                )
                output_nodes.append(output_node)
            collective_output_nodes.append(output_nodes)
        else:
            collective_output_nodes.append(collective_output_node)
    return collective_output_nodes
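

# Illustrative sketch, not part of Ray's API (the helper name is hypothetical):
# the regrouping that ``_bind`` performs, with plain strings standing in for
# DAG nodes. A flat list is wrapped into a one-element nested list, then
# element i of every leaf list becomes an argument of actor i's collective
# node. Because the position inside a leaf list is what identifies the actor,
# every leaf list must use the same actor order (requirement 3). With more than
# one leaf list, ``_bind`` returns its outputs in this same actor-major layout.
def _sketch_group_by_actor(inputs):
    """Return one tuple per actor holding that actor's input from each list."""
    # Normalize a flat list into a nested list, as ``_bind`` does.
    if not isinstance(inputs[0], list):
        inputs = [inputs]
    # Transpose: leaf lists are indexed [list][actor], groups are [actor][list].
    return [
        tuple(node_list[i] for node_list in inputs if node_list)
        for i in range(len(inputs[0]))
    ]


# Usage: _sketch_group_by_actor([["a0_c1", "a1_c1"], ["a0_c2", "a1_c2"]])
# returns [("a0_c1", "a0_c2"), ("a1_c1", "a1_c2")].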


class AllGatherWrapper:
    """Wrapper for NCCL all-gather."""

    def bind(
        self,
        input_nodes: List["ray.dag.DAGNode"],
        transport: Optional[Union[str, Communicator]] = None,
    ) -> List[CollectiveOutputNode]:
        return _bind(input_nodes, AllGatherOp(), transport)

    def __call__(
        self,
        tensor_list,
        tensor,
        group_name: str = "default",
    ):
        from ray.util.collective.collective import allgather

        return allgather(tensor_list, tensor, group_name)


class AllReduceWrapper:
    """Wrapper for NCCL all-reduce."""

    def bind(
        self,
        input_nodes: List["ray.dag.DAGNode"],
        op: ReduceOp = ReduceOp.SUM,
        transport: Optional[Union[str, Communicator]] = None,
    ) -> List[CollectiveOutputNode]:
        if not isinstance(op, ReduceOp):
            raise ValueError(f"Unexpected operation: {op}")
        return _bind(input_nodes, AllReduceOp(reduceOp=op), transport)

    def __call__(
        self,
        tensor,
        group_name: str = "default",
        op: RayReduceOp = RayReduceOp.SUM,
    ):
        from ray.util.collective.collective import allreduce

        return allreduce(tensor, group_name, op)


class ReduceScatterWrapper:
    """Wrapper for NCCL reduce-scatter."""

    def bind(
        self,
        input_nodes: List["ray.dag.DAGNode"],
        op: ReduceOp = ReduceOp.SUM,
        transport: Optional[Union[str, Communicator]] = None,
    ) -> List[CollectiveOutputNode]:
        if not isinstance(op, ReduceOp):
            raise ValueError(f"Unexpected operation: {op}")
        return _bind(input_nodes, ReduceScatterOp(reduceOp=op), transport)

    def __call__(
        self,
        tensor,
        group_name: str = "default",
        op: RayReduceOp = RayReduceOp.SUM,
    ):
        from ray.util.collective.collective import reducescatter

        return reducescatter(tensor, group_name, op)


allgather = AllGatherWrapper()
allreduce = AllReduceWrapper()
reducescatter = ReduceScatterWrapper()