operations.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. import logging
  2. from typing import List, Optional, Union
  3. import ray
  4. from ray.dag.collective_node import CollectiveOutputNode, _CollectiveOperation
  5. from ray.dag.constants import (
  6. BIND_INDEX_KEY,
  7. COLLECTIVE_OPERATION_KEY,
  8. IS_CLASS_METHOD_OUTPUT_KEY,
  9. PARENT_CLASS_NODE_KEY,
  10. )
  11. from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
  12. from ray.experimental.util.types import (
  13. AllGatherOp,
  14. AllReduceOp,
  15. ReduceOp,
  16. ReduceScatterOp,
  17. _CollectiveOp,
  18. )
  19. from ray.util.collective.types import ReduceOp as RayReduceOp
  20. logger = logging.getLogger(__name__)
  21. def _bind(
  22. inputs: Union[List["ray.dag.DAGNode"], List[List["ray.dag.DAGNode"]]],
  23. op: _CollectiveOp,
  24. transport: Optional[Union[str, Communicator]] = None,
  25. ):
  26. """
  27. Bind inputs (input nodes or lists of input nodes) with a collective operation.
  28. The collective operation is applied to each list of input nodes. The output nodes
  29. will have the same shape as the input nodes.
  30. Example of binding a list of input node:
  31. with InputNode() as inp:
  32. res_comp1 = [actor.comp1.bind(inp) for actor in actors]
  33. res_comp2 = [actor.comp2.bind(inp) for actor in actors]
  34. res_ar = allreduce.bind([res_comp1, res_comp2])
  35. Requirements:
  36. 1. Each input node returns a torch tensor.
  37. 2. Each input node within a list is from a different actor.
  38. 3. If lists of input nodes are provided, the order of actors should
  39. be the same for each nested list.
  40. 4. If a custom transport is specified, its actor set matches the actor
  41. set of the input nodes.
  42. 5. If input nodes are provided, then all tensors have the same shape.
  43. If lists of input nodes are provided, then all tensors in each
  44. list have the same shape.
  45. Requirements 1-3 are checked in the `CollectiveGroup` constructor.
  46. Requirement 4 is not checked yet.
  47. Args:
  48. inputs: A list of DAG nodes or a list of lists of DAG nodes. Each leaf list
  49. should contain one object per actor.
  50. op: The collective operation.
  51. transport: GPU communicator for the collective operation. If not
  52. specified, the default ACCELERATOR is used.
  53. Returns:
  54. A list of collective output nodes or a list of lists of collective output nodes,
  55. with the same shape as the input nodes. Each output node has the same order and
  56. belongs to the same actor as the corresponding input node.
  57. """
  58. if isinstance(inputs[0], list) and not isinstance(op, AllReduceOp):
  59. raise ValueError(
  60. "Currently binding a nested list of dag nodes is only supported for allreduce"
  61. )
  62. # Convert list of DAGNode into nested list for type checking
  63. if not isinstance(inputs[0], list):
  64. inputs = [inputs]
  65. if transport is None:
  66. transport = TorchTensorType.ACCELERATOR
  67. collective_op = _CollectiveOperation(inputs, op, transport)
  68. collective_output_nodes: List[CollectiveOutputNode] = []
  69. if isinstance(op, AllGatherOp):
  70. method_name = "allgather"
  71. elif isinstance(op, AllReduceOp):
  72. method_name = f"allreduce.{op.reduceOp}"
  73. elif isinstance(op, ReduceScatterOp):
  74. method_name = f"reducescatter.{op.reduceOp}"
  75. else:
  76. raise ValueError(f"Expected a collective operation, but got {op}")
  77. for i in range(len(inputs[0])):
  78. input_node_list = [l[i] for l in inputs if l]
  79. actor_handle: Optional["ray.actor.ActorHandle"] = input_node_list[
  80. 0
  81. ]._get_actor_handle()
  82. assert actor_handle is not None
  83. collective_output_node = CollectiveOutputNode(
  84. method_name=method_name,
  85. method_args=tuple(input_node_list),
  86. method_kwargs=dict(),
  87. method_options=dict(),
  88. other_args_to_resolve={
  89. PARENT_CLASS_NODE_KEY: actor_handle,
  90. BIND_INDEX_KEY: actor_handle._ray_dag_bind_index,
  91. COLLECTIVE_OPERATION_KEY: collective_op,
  92. },
  93. )
  94. actor_handle._ray_dag_bind_index += 1
  95. if len(input_node_list) > 1:
  96. output_nodes: List[CollectiveOutputNode] = []
  97. for i in range(len(input_node_list)):
  98. output_node = CollectiveOutputNode(
  99. f"return_idx_{i}",
  100. (collective_output_node, i),
  101. dict(),
  102. dict(),
  103. {
  104. BIND_INDEX_KEY: collective_output_node._get_bind_index(),
  105. IS_CLASS_METHOD_OUTPUT_KEY: True,
  106. PARENT_CLASS_NODE_KEY: actor_handle,
  107. },
  108. )
  109. output_nodes.append(output_node)
  110. collective_output_nodes.append(output_nodes)
  111. else:
  112. collective_output_nodes.append(collective_output_node)
  113. return collective_output_nodes
  114. class AllGatherWrapper:
  115. """Wrapper for NCCL all-gather."""
  116. def bind(
  117. self,
  118. input_nodes: List["ray.dag.DAGNode"],
  119. transport: Optional[Union[str, Communicator]] = None,
  120. ) -> List[CollectiveOutputNode]:
  121. return _bind(input_nodes, AllGatherOp(), transport)
  122. def __call__(
  123. self,
  124. tensor_list,
  125. tensor,
  126. group_name: str = "default",
  127. ):
  128. from ray.util.collective.collective import allgather
  129. return allgather(tensor_list, tensor, group_name)
  130. class AllReduceWrapper:
  131. """Wrapper for NCCL all-reduce."""
  132. def bind(
  133. self,
  134. input_nodes: List["ray.dag.DAGNode"],
  135. op: ReduceOp = ReduceOp.SUM,
  136. transport: Optional[Union[str, Communicator]] = None,
  137. ) -> List[CollectiveOutputNode]:
  138. if not isinstance(op, ReduceOp):
  139. raise ValueError(f"Unexpected operation: {op}")
  140. return _bind(input_nodes, AllReduceOp(reduceOp=op), transport)
  141. def __call__(
  142. self,
  143. tensor,
  144. group_name: str = "default",
  145. op: RayReduceOp = RayReduceOp.SUM,
  146. ):
  147. from ray.util.collective.collective import allreduce
  148. return allreduce(tensor, group_name, op)
  149. class ReduceScatterWrapper:
  150. """Wrapper for NCCL reduce-scatter."""
  151. def bind(
  152. self,
  153. input_nodes: List["ray.dag.DAGNode"],
  154. op: ReduceOp = ReduceOp.SUM,
  155. transport: Optional[Union[str, Communicator]] = None,
  156. ) -> List[CollectiveOutputNode]:
  157. if not isinstance(op, ReduceOp):
  158. raise ValueError(f"Unexpected operation: {op}")
  159. return _bind(input_nodes, ReduceScatterOp(reduceOp=op), transport)
  160. def __call__(
  161. self,
  162. tensor,
  163. group_name: str = "default",
  164. op: RayReduceOp = RayReduceOp.SUM,
  165. ):
  166. from ray.util.collective.collective import reducescatter
  167. return reducescatter(tensor, group_name, op)
  168. allgather = AllGatherWrapper()
  169. allreduce = AllReduceWrapper()
  170. reducescatter = ReduceScatterWrapper()