collective_node.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
  2. if TYPE_CHECKING:
  3. import torch
  4. import ray
  5. from ray.dag import (
  6. ClassMethodNode,
  7. DAGNode,
  8. )
  9. from ray.dag.constants import COLLECTIVE_OPERATION_KEY, IS_CLASS_METHOD_OUTPUT_KEY
  10. from ray.experimental.channel import ChannelContext
  11. from ray.experimental.channel.torch_tensor_type import Communicator, TorchTensorType
  12. from ray.experimental.util.types import (
  13. AllGatherOp,
  14. AllReduceOp,
  15. ReduceScatterOp,
  16. _CollectiveOp,
  17. )
  18. from ray.util.annotations import DeveloperAPI
  19. class _CollectiveOperation:
  20. """
  21. Represent metadata for a collective communicator collective operation.
  22. Args:
  23. inputs: A list of lists of DAGNode. Each nested list inside
  24. of inputs should contain exactly one object per actor.
  25. If multiple nested lists are provided, then the order of
  26. actors should be the same for each nested list.
  27. op: The collective operation to perform.
  28. transport: The transport to use for the collective operation.
  29. Requirements:
  30. 1. Input nodes are unique.
  31. 2. Actor handles are unique.
  32. 3. Actor handles match the custom communicator group if specified.
  33. """
  34. def __init__(
  35. self,
  36. inputs: List[List[DAGNode]],
  37. op: _CollectiveOp,
  38. transport: Optional[Union[str, Communicator]] = None,
  39. ):
  40. self._actor_handles: List["ray.actor.ActorHandle"] = []
  41. for i, input_nodes in enumerate(inputs):
  42. # Check non-empty input list
  43. if len(input_nodes) == 0:
  44. nested_list_error_msg = f" at index {i}" if len(inputs) > 1 else ""
  45. raise ValueError(
  46. f"Expected non-empty input list{nested_list_error_msg}."
  47. )
  48. # Check input nodes are DAGNode
  49. if not all(isinstance(node, DAGNode) for node in input_nodes):
  50. nested_list_error_msg = (
  51. f" at list at index {i}" if len(inputs) > 1 else ""
  52. )
  53. raise ValueError(
  54. f"Expected all input nodes to be DAGNode{nested_list_error_msg}, "
  55. f"but got {input_nodes}."
  56. )
  57. # Check unique input nodes
  58. if len(set(input_nodes)) != len(input_nodes):
  59. duplicates = [
  60. input_node
  61. for input_node in input_nodes
  62. if input_nodes.count(input_node) > 1
  63. ]
  64. nested_list_error_msg = (
  65. f" at list at index {i}" if len(inputs) > 1 else ""
  66. )
  67. raise ValueError(
  68. f"Expected unique input nodes{nested_list_error_msg}, but found duplicates: "
  69. f"{duplicates}"
  70. )
  71. current_actor_handles = []
  72. for input_node in input_nodes:
  73. actor_handle = input_node._get_actor_handle()
  74. if actor_handle is None:
  75. nested_list_error_msg = (
  76. f" at list at index {i}" if len(inputs) > 1 else ""
  77. )
  78. raise ValueError(
  79. f"Expected an actor handle from the input node{nested_list_error_msg}"
  80. )
  81. current_actor_handles.append(actor_handle)
  82. # Check unique actor handles
  83. if len(set(current_actor_handles)) != len(current_actor_handles):
  84. invalid_input_nodes = [
  85. input_node
  86. for input_node in input_nodes
  87. if current_actor_handles.count(input_node._get_actor_handle()) > 1
  88. ]
  89. nested_list_error_msg = (
  90. f" at list at index {i}" if len(inputs) > 1 else ""
  91. )
  92. raise ValueError(
  93. f"Expected unique actor handles{nested_list_error_msg}, "
  94. "but found duplicate actor handles from input nodes: "
  95. f"{invalid_input_nodes}"
  96. )
  97. if i == 0:
  98. first_actor_handles = current_actor_handles
  99. # Check all lists of DAGNode have the same number of nodes
  100. if len(inputs[0]) != len(inputs[i]):
  101. raise ValueError(
  102. f"Expected all input lists to have the same number of nodes. "
  103. f"List at index 0 has length {len(inputs[0])}, but list at "
  104. f"index {i} has length {len(inputs[i])}."
  105. )
  106. # Check all lists of DAGNode have same set of actor handles
  107. if set(first_actor_handles) != set(current_actor_handles):
  108. raise ValueError(
  109. f"Expected all input lists to have the same set of actor handles. "
  110. f"List at index 0 has actors {set(first_actor_handles)}, but list at "
  111. f"index {i} has actors {set(current_actor_handles)}."
  112. )
  113. # Check all lists of DAGNode have same order of actor handles
  114. for j, (first, current) in enumerate(
  115. zip(first_actor_handles, current_actor_handles)
  116. ):
  117. if first != current:
  118. raise ValueError(
  119. f"Expected all input lists to have the same order of actor handles. "
  120. f"List at index 0 has actor {first} at position {j}, but list at "
  121. f"index {i} has actor {current} at position {j}."
  122. )
  123. self._actor_handles = current_actor_handles
  124. self._op = op
  125. if transport is None:
  126. transport = TorchTensorType.ACCELERATOR
  127. self._type_hint = TorchTensorType(transport=transport, _direct_return=True)
  128. if isinstance(transport, Communicator):
  129. if set(transport.get_actor_handles()) != set(self._actor_handles):
  130. raise ValueError(
  131. "Expected actor handles to match the custom communicator group"
  132. )
  133. def __str__(self) -> str:
  134. return (
  135. f"CollectiveOperation("
  136. f"_actor_handles={self._actor_handles}, "
  137. f"_op={self._op}, "
  138. f"_type_hint={self._type_hint})"
  139. )
  140. @property
  141. def actor_handles(self) -> List["ray.actor.ActorHandle"]:
  142. return self._actor_handles
  143. @property
  144. def type_hint(self) -> TorchTensorType:
  145. return self._type_hint
  146. def get_communicator(self) -> Communicator:
  147. if self._type_hint.communicator_id is not None:
  148. ctx = ChannelContext.get_current()
  149. communicator = ctx.communicators[self._type_hint.communicator_id]
  150. elif self._type_hint.get_custom_communicator() is not None:
  151. communicator = self._type_hint.get_custom_communicator()
  152. else:
  153. raise ValueError("Expected a communicator group")
  154. return communicator
  155. def execute(
  156. self, *send_buf: "torch.Tensor"
  157. ) -> Union["torch.Tensor", Tuple["torch.Tensor", ...]]:
  158. """
  159. Call the collective operation on the input tensor(s). Output tensor(s) are
  160. allocated and returned.
  161. Args:
  162. *send_buf: A variable number of torch tensors to send to the collective
  163. operation. The tensors have the same order as the input nodes.
  164. Returns:
  165. A torch tensor or a tuple of torch tensors containing the results of the
  166. collective operation. The output tensors have the same length and order
  167. as the input node list of the actor of this operation.
  168. """
  169. import torch
  170. if not all(isinstance(t, torch.Tensor) for t in send_buf):
  171. raise ValueError("Expected a torch tensor for each input node")
  172. communicator = self.get_communicator()
  173. if isinstance(self._op, AllGatherOp):
  174. assert len(send_buf) == 1
  175. t = send_buf[0]
  176. world_size = len(self._actor_handles)
  177. recv_buf = torch.empty(
  178. (t.shape[0] * world_size, *t.shape[1:]),
  179. dtype=t.dtype,
  180. device=t.device,
  181. )
  182. communicator.allgather(t, recv_buf)
  183. elif isinstance(self._op, AllReduceOp):
  184. if len(send_buf) == 1:
  185. t = send_buf[0]
  186. recv_buf = torch.empty_like(t)
  187. communicator.allreduce(t, recv_buf, self._op.reduceOp)
  188. else:
  189. if not all(t.dtype == send_buf[0].dtype for t in send_buf):
  190. raise ValueError(
  191. "Expected all input tensors to have the same dtype, "
  192. f"but got {[t.dtype for t in send_buf]}"
  193. )
  194. def unflatten_from(flat_buf, bufs):
  195. views = []
  196. offset = 0
  197. for t in bufs:
  198. numel = t.numel()
  199. t = flat_buf[offset : offset + numel].view(t.shape)
  200. views.append(t)
  201. offset += numel
  202. return tuple(views)
  203. flat_buf = torch.nn.utils.parameters_to_vector(send_buf)
  204. communicator.allreduce(flat_buf, flat_buf, self._op.reduceOp)
  205. recv_buf = unflatten_from(flat_buf, send_buf)
  206. elif isinstance(self._op, ReduceScatterOp):
  207. assert len(send_buf) == 1
  208. t = send_buf[0]
  209. world_size = len(self._actor_handles)
  210. if t.shape[0] % world_size != 0:
  211. raise ValueError(
  212. "Expected the first dimension of the input tensor to be divisible "
  213. f"by the world size {world_size}"
  214. )
  215. recv_buf = torch.empty(
  216. (t.shape[0] // world_size, *t.shape[1:]),
  217. dtype=t.dtype,
  218. device=t.device,
  219. )
  220. communicator.reducescatter(t, recv_buf, self._op.reduceOp)
  221. return recv_buf
  222. @DeveloperAPI
  223. class CollectiveOutputNode(ClassMethodNode):
  224. """Represent an output node from a communicator collective operation in a Ray DAG."""
  225. def __init__(
  226. self,
  227. method_name: str,
  228. method_args: Tuple[
  229. DAGNode,
  230. ],
  231. method_kwargs: Dict[str, Any],
  232. method_options: Dict[str, Any],
  233. other_args_to_resolve: Dict[str, Any],
  234. ):
  235. # Parse the input node(s).
  236. self._inputs = method_args
  237. # Parse the collective operation.
  238. self._collective_op: _CollectiveOperation = other_args_to_resolve.get(
  239. COLLECTIVE_OPERATION_KEY, None
  240. )
  241. self._is_class_method_output: bool = other_args_to_resolve.get(
  242. IS_CLASS_METHOD_OUTPUT_KEY, False
  243. )
  244. if self._collective_op is None and not self._is_class_method_output:
  245. raise ValueError("Expected a collective operation")
  246. super().__init__(
  247. method_name,
  248. method_args,
  249. method_kwargs,
  250. method_options,
  251. other_args_to_resolve,
  252. )
  253. def _copy_impl(
  254. self,
  255. new_args: List[Any],
  256. new_kwargs: Dict[str, Any],
  257. new_options: Dict[str, Any],
  258. new_other_args_to_resolve: Dict[str, Any],
  259. ):
  260. return CollectiveOutputNode(
  261. self._method_name,
  262. new_args,
  263. new_kwargs,
  264. new_options,
  265. other_args_to_resolve=new_other_args_to_resolve,
  266. )
  267. def _execute_impl(self, *args, **kwargs):
  268. raise NotImplementedError(
  269. "CollectiveOutputNode is only supported with dag.experimental_compile()"
  270. )
  271. @property
  272. def collective_op(self) -> _CollectiveOperation:
  273. return self._collective_op