# dag_node.py
  1. import asyncio
  2. import copy
  3. import uuid
  4. from itertools import chain
  5. from typing import (
  6. Any,
  7. Callable,
  8. Dict,
  9. List,
  10. Literal,
  11. Optional,
  12. Tuple,
  13. TypeVar,
  14. Union,
  15. )
  16. import ray
  17. from ray.dag.base import DAGNodeBase
  18. from ray.dag.compiled_dag_node import build_compiled_dag_from_ray_dag
  19. from ray.dag.py_obj_scanner import _PyObjScanner
  20. from ray.experimental.channel import ChannelOutputType
  21. from ray.experimental.channel.auto_transport_type import AutoTransportType
  22. from ray.experimental.channel.communicator import Communicator
  23. from ray.experimental.channel.torch_tensor_type import TorchTensorType
  24. from ray.experimental.util.types import Device
  25. from ray.util.annotations import DeveloperAPI
  26. T = TypeVar("T")
  27. @DeveloperAPI
  28. class DAGNode(DAGNodeBase):
  29. """Abstract class for a node in a Ray task graph.
  30. A node has a type (e.g., FunctionNode), data (e.g., function options and
  31. body), arguments (Python values, DAGNodes, and DAGNodes nested within Python
  32. argument values) and options (Ray API .options() used for function, class
  33. or class method)
  34. """
  def __init__(
      self,
      args: Tuple[Any],
      kwargs: Dict[str, Any],
      options: Dict[str, Any],
      other_args_to_resolve: Dict[str, Any],
  ):
      """Bind arguments and options to a new node and link it to its inputs.

      Args:
          args: Bound positional arguments,
              e.g. ``func_or_class.bind(1)``.
          kwargs: Bound keyword arguments,
              e.g. ``func_or_class.bind(a=1)``.
          options: Bound options,
              e.g. ``func_or_class.options(num_cpus=2)``.
          other_args_to_resolve: Bound kwargs that are specific to a
              subclass implementation and are not exposed as args of the
              base class (example: ClassMethodNode).
      """
      # Any falsy input (None or an empty container) is replaced by a fresh
      # empty container.
      self._bound_args: Tuple[Any] = args if args else []
      self._bound_kwargs: Dict[str, Any] = kwargs if kwargs else {}
      self._bound_options: Dict[str, Any] = options if options else {}
      self._bound_other_args_to_resolve: Optional[Dict[str, Any]] = (
          other_args_to_resolve if other_args_to_resolve else {}
      )

      # Nodes that consume this node as one of their arguments.
      self._downstream_nodes: List["DAGNode"] = []

      # Identifier that stays the same across copies of this node.
      self._stable_uuid = uuid.uuid4().hex

      # Set to True when a DAGNode is found nested inside a positional
      # argument. Nested nodes are fine for traditional DAGs, but Ray
      # Compiled Graphs reject them (MultiOutputNode excepted).
      self._args_contain_nested_dag_node = False

      # Nodes that this node consumes as arguments. Collecting them also
      # registers this node as a downstream node of each of them.
      self._upstream_nodes: List["DAGNode"] = self._collect_upstream_nodes()

      # Values cached by the most recent execute() call.
      self.cache_from_last_execute = {}

      self._type_hint: ChannelOutputType = ChannelOutputType()

      # When an AutoTransportType hint gets resolved to a concrete type, the
      # original hint is kept here for debugging. None otherwise.
      self._original_type_hint: Optional[ChannelOutputType] = None

      # True once `experimental_compile` has been called on this node.
      self.is_cgraph_output_node = False
  def _collect_upstream_nodes(self) -> List["DAGNode"]:
      """Retrieve upstream nodes and update their downstream dependencies.

      Every DAGNode found in the bound args, kwargs and
      `other_args_to_resolve` is treated as an upstream node, and `self` is
      appended to each upstream node's `_downstream_nodes`.

      Currently, the DAG assumes that all DAGNodes in `args`, `kwargs`, and
      `other_args_to_resolve` are upstream nodes. However, Ray Compiled
      Graphs builds the upstream/downstream relationship based only on args.
      Be cautious when persisting DAGNodes in `other_args_to_resolve` and
      kwargs in the future.

      As a side effect, `self._args_contain_nested_dag_node` is set to True
      if ANY positional argument is a container holding a DAGNode, and to
      False otherwise.

      TODO (kevin85421): Currently, the upstream nodes and downstream nodes
      have circular references. Therefore, it relies on the garbage collector
      to clean them up instead of reference counting. We should consider
      using weak references to avoid circular references.

      Returns:
          The upstream nodes: those from positional args first (in argument
          order), then those found in kwargs and `other_args_to_resolve`.
      """
      upstream_nodes: List["DAGNode"] = []

      # Ray Compiled Graphs do not allow nested DAG nodes in arguments.
      # Specifically, a DAGNode should not be placed inside any type of
      # container. However, we only know if this is a compiled graph
      # when calling `experimental_compile`. Therefore, we need to check
      # in advance if the arguments contain nested DAG nodes and raise
      # an error after compilation.
      assert hasattr(self._bound_args, "__iter__")
      contains_nested_dag_node = False
      for arg in self._bound_args:
          if isinstance(arg, DAGNode):
              upstream_nodes.append(arg)
              continue
          scanner = _PyObjScanner()
          dag_nodes = scanner.find_nodes(arg)
          upstream_nodes.extend(dag_nodes)
          scanner.clear()
          # Accumulate instead of overwriting: a later argument without
          # nested nodes must not reset a flag raised by an earlier one.
          if len(dag_nodes) > 0:
              contains_nested_dag_node = True
      self._args_contain_nested_dag_node = contains_nested_dag_node

      scanner = _PyObjScanner()
      other_upstream_nodes: List["DAGNode"] = scanner.find_nodes(
          [
              self._bound_kwargs,
              self._bound_other_args_to_resolve,
          ]
      )
      upstream_nodes.extend(other_upstream_nodes)
      scanner.clear()

      # Update dependencies.
      for upstream_node in upstream_nodes:
          upstream_node._downstream_nodes.append(self)
      return upstream_nodes
  def with_tensor_transport(
      self,
      transport: Optional[Union[str, Communicator]] = "auto",
      device: Literal["default", "cpu", "gpu", "cuda"] = "default",
      _static_shape: bool = False,
      _direct_return: bool = False,
  ):
      """Configure the torch tensor transport for this node.

      Args:
          transport: Specifies the tensor transport mechanism.
              - "accelerator": Tensors are communicated using
                accelerator-specific backends (e.g., NCCL, XLA, or
                vendor-provided transport). This is the recommended option
                for most use cases, as it supports extensibility and future
                hardware backends.
              - "nccl": Tensors are passed explicitly via NCCL. This option
                is kept for backwards compatibility and may be removed in
                the future. Use "accelerator" instead unless you have legacy
                requirements.
              - "shm": Tensors are passed via host shared memory and gRPC.
                Typically used when accelerator-based transport is
                unavailable or not suitable.
              - "auto" (default): The system automatically selects the
                appropriate transport mechanism based on the sender and
                receiver, usually preferring accelerator-based transport
                when available.
              - A `Communicator` instance is passed through as the transport.
          device: The target device to use for the tensor transport.
              "default": The tensor will maintain its original device
              placement from the sender.
              "cpu": The tensor will be explicitly moved to CPU device in
              the receiver.
              "gpu" or "cuda": The tensor will be explicitly moved to GPU
              device in the receiver.
          _static_shape: A hint indicating whether the shape(s) and dtype(s)
              of tensor(s) contained in this value always remain the same
              across different executions of the DAG. If this is True, the
              transport will be more efficient.
          _direct_return: Whether the tensor is sent directly or inside of
              other data. If a "nccl" transport is used, this allows the
              sender and receiver to eliminate performance overhead from
              an additional data transfer.

      Returns:
          This node, so that calls can be chained.

      Raises:
          ValueError: If `device` is not a valid `Device` value, or if
              `transport` is neither a known string nor a `Communicator`.
      """
      try:
          device = Device(device)
      except ValueError:
          valid_devices = ", ".join(f"'{d.value}'" for d in Device)
          raise ValueError(
              f"Invalid device '{device}'. Valid options are: {valid_devices}."
          )

      # Keyword arguments shared by every type hint constructed below.
      common = {
          "device": device,
          "_static_shape": _static_shape,
          "_direct_return": _direct_return,
      }

      if transport == "auto":
          hint = AutoTransportType(**common)
      elif transport in ("nccl", "accelerator"):
          # "nccl" is handled exactly like "accelerator".
          hint = TorchTensorType(transport="accelerator", **common)
      elif transport == "shm":
          # No explicit transport: TorchTensorType's default is used.
          hint = TorchTensorType(**common)
      elif isinstance(transport, Communicator):
          hint = TorchTensorType(transport=transport, **common)
      else:
          raise ValueError(
              f"Invalid transport type: {transport}. "
              "Transport must be one of 'auto', 'nccl', 'shm', 'accelerator' or "
              "an instance of Communicator type."
          )

      self._type_hint = hint
      return self
  @property
  def type_hint(self) -> ChannelOutputType:
      """The channel output type hint currently attached to this node."""
      return self._type_hint

  @type_hint.setter
  def type_hint(self, type_hint: ChannelOutputType) -> None:
      """Replace the type hint.

      If the hint being replaced is an AutoTransportType, it is saved in
      `_original_type_hint` before being overwritten.
      """
      previous = self._type_hint
      if isinstance(previous, AutoTransportType):
          self._original_type_hint = previous
      self._type_hint = type_hint
  211. def get_args(self) -> Tuple[Any]:
  212. """Return the tuple of arguments for this node."""
  213. return self._bound_args
  214. def get_kwargs(self) -> Dict[str, Any]:
  215. """Return the dict of keyword arguments for this node."""
  216. return self._bound_kwargs.copy()
  217. def get_options(self) -> Dict[str, Any]:
  218. """Return the dict of options arguments for this node."""
  219. return self._bound_options.copy()
  220. def get_other_args_to_resolve(self) -> Dict[str, Any]:
  221. """Return the dict of other args to resolve arguments for this node."""
  222. return self._bound_other_args_to_resolve.copy()
  223. def get_stable_uuid(self) -> str:
  224. """Return stable uuid for this node.
  225. 1) Generated only once at first instance creation
  226. 2) Stable across pickling, replacement and JSON serialization.
  227. """
  228. return self._stable_uuid
  229. async def get_object_refs_from_last_execute(self) -> Dict[str, Any]:
  230. """Gets cached object refs from the last call to execute().
  231. After this DAG is executed through execute(), retrieves a map between node
  232. UUID to a reference to the return value of the default executor on that node.
  233. """
  234. cache = {}
  235. for node_uuid, value in self.cache_from_last_execute.items():
  236. if isinstance(value, asyncio.Task):
  237. cache[node_uuid] = await value
  238. else:
  239. cache[node_uuid] = value
  240. return cache
  241. def clear_cache(self):
  242. self.cache_from_last_execute = {}
  def experimental_compile(
      self,
      _submit_timeout: Optional[float] = None,
      _buffer_size_bytes: Optional[int] = None,
      enable_asyncio: bool = False,
      _max_inflight_executions: Optional[int] = None,
      _max_buffered_results: Optional[int] = None,
      _overlap_gpu_communication: Optional[bool] = None,
      _default_communicator: Optional[Union[Communicator, str]] = "create",
  ) -> "ray.dag.CompiledDAG":
      """Compile an accelerated execution path for this DAG.

      Args:
          _submit_timeout: The maximum time in seconds to wait for execute()
              calls. None means using default timeout, 0 means immediate
              timeout (immediate success or timeout without blocking), -1
              means infinite timeout (block indefinitely).
          _buffer_size_bytes: The initial buffer size in bytes for messages
              that can be passed between tasks in the DAG. The buffers will
              be automatically resized if larger messages are written to the
              channel. When None, the current DAGContext's
              `buffer_size_bytes` is used.
          enable_asyncio: Whether to enable asyncio for this DAG.
          _max_inflight_executions: The maximum number of in-flight
              executions that can be submitted via `execute` or
              `execute_async` before consuming the output using `ray.get()`.
              If the caller submits more executions,
              `RayCgraphCapacityExceeded` is raised.
          _max_buffered_results: The maximum number of results that can be
              buffered at the driver. If more than this number of results
              are buffered, `RayCgraphCapacityExceeded` is raised. Note that
              when the result corresponding to an execution is retrieved (by
              calling `ray.get()` on a `CompiledDAGRef` or awaiting a
              `CompiledDAGFuture`), results corresponding to earlier
              executions that have not been retrieved yet are buffered.
          _overlap_gpu_communication: (experimental) Whether to overlap GPU
              communication with computation during DAG execution. If True,
              the communication and computation can be overlapped, which can
              improve the performance of the DAG execution. If None, the
              default value will be used.
          _default_communicator: The default communicator to use to transfer
              tensors. Three types of values are valid.
              (1) Communicator: For p2p operations, this is the default
              communicator to use for nodes annotated with
              `with_tensor_transport()` and when shared memory is not the
              desired option (e.g., when transport="nccl", or when
              transport="auto" for communication between two different
              GPUs). For collective operations, this is the default
              communicator to use when a custom communicator is not
              specified.
              (2) "create": for each collective operation without a custom
              communicator specified, a communicator is created and
              initialized on its involved actors, or an already created
              communicator is reused if the set of actors is the same. For
              all p2p operations without a custom communicator specified, it
              reuses an already created collective communicator if the p2p
              actors are a subset. Otherwise, a new communicator is created.
              (3) None: a ValueError will be thrown if a custom communicator
              is not specified.

      Returns:
          A compiled DAG.

      Raises:
          ValueError: If `experimental_compile` was already called on this
              node.
      """
      from ray.dag import DAGContext

      dag_ctx = DAGContext.get_current()
      # Fall back to the context-wide buffer size when none was given.
      buffer_size_bytes = (
          dag_ctx.buffer_size_bytes
          if _buffer_size_bytes is None
          else _buffer_size_bytes
      )

      # A node may be compiled at most once.
      already_compiled = self.is_cgraph_output_node
      if already_compiled:
          raise ValueError(
              "It is not allowed to call `experimental_compile` on the same DAG "
              "object multiple times no matter whether `teardown` is called or not. "
              "Please reuse the existing compiled DAG or create a new one."
          )

      # The output node is only known once `experimental_compile` is called,
      # which is why this flag cannot be set in the constructor.
      self.is_cgraph_output_node = True

      compiled_dag = build_compiled_dag_from_ray_dag(
          self,
          _submit_timeout,
          buffer_size_bytes,
          enable_asyncio,
          _max_inflight_executions,
          _max_buffered_results,
          _overlap_gpu_communication,
          _default_communicator,
      )
      return compiled_dag
  def execute(
      self, *args, _ray_cache_refs: bool = False, **kwargs
  ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
      """Execute this DAG using the Ray default executor _execute_impl().

      Args:
          *args: Forwarded to every node's `_execute_impl`.
          _ray_cache_refs: If true, stores the default executor's return
              values on each node in this DAG in a cache. These should be a
              mix of:
              - ray.ObjectRefs pointing to the outputs of method and
                function nodes
              - Serve handles for class nodes
              - resolved values representing user input at runtime
          **kwargs: Forwarded to every node's `_execute_impl`.

      Returns:
          Whatever `apply_recursive` returns for this node.
      """

      def _run_node(node):
          return node._execute_impl(*args, **kwargs)

      outcome = self.apply_recursive(_run_node)
      if _ray_cache_refs:
          # apply_recursive attaches its per-node result cache to the
          # callable it was given, as the `cache` attribute.
          self.cache_from_last_execute = _run_node.cache
      return outcome
  341. def _get_toplevel_child_nodes(self) -> List["DAGNode"]:
  342. """Return the list of nodes specified as top-level args.
  343. For example, in `f.remote(a, [b])`, only `a` is a top-level arg.
  344. This list of nodes are those that are typically resolved prior to
  345. task execution in Ray. This does not include nodes nested within args.
  346. For that, use ``_get_all_child_nodes()``.
  347. """
  348. # we use List instead of Set here because the hash key of the node
  349. # object changes each time we create it. So if using Set here, the
  350. # order of returned children can be different if we create the same
  351. # nodes and dag one more time.
  352. children = []
  353. for a in self.get_args():
  354. if isinstance(a, DAGNode):
  355. if a not in children:
  356. children.append(a)
  357. for a in self.get_kwargs().values():
  358. if isinstance(a, DAGNode):
  359. if a not in children:
  360. children.append(a)
  361. for a in self.get_other_args_to_resolve().values():
  362. if isinstance(a, DAGNode):
  363. if a not in children:
  364. children.append(a)
  365. return children
  def _get_all_child_nodes(self) -> List["DAGNode"]:
      """Return the nodes referenced by the args, kwargs and
      args_to_resolve of the current node, even when deeply nested.

      Examples:
          f.remote(a, [b]) -> [a, b]
          f.remote(a, [b], key={"nested": [c]}) -> [a, b, c]

      Returns:
          De-duplicated nodes, in the order the scanner reports them.
      """
      scanner = _PyObjScanner()
      discovered = scanner.find_nodes(
          [
              self._bound_args,
              self._bound_kwargs,
              self._bound_other_args_to_resolve,
          ]
      )
      # A list is used instead of a set to keep a reproducible order; see
      # `_get_toplevel_child_nodes` for the reason.
      children = []
      for found in discovered:
          if found in children:
              continue
          children.append(found)
      scanner.clear()
      return children
  def _apply_and_replace_all_child_nodes(
      self, fn: "Callable[[DAGNode], T]"
  ) -> "DAGNode":
      """Apply and replace all immediate child nodes using a given function.

      This is a shallow replacement only. To recursively transform nodes in
      the DAG, use ``apply_recursive()``.

      Args:
          fn: Callable that will be applied once to each child of this node.

      Returns:
          New DAGNode after replacing all child nodes.
      """
      # One scanner instance must serve the whole find & replace cycle for
      # the current layer of DAGNode.
      scanner = _PyObjScanner()
      child_nodes = scanner.find_nodes(
          [
              self._bound_args,
              self._bound_kwargs,
              self._bound_other_args_to_resolve,
          ]
      )

      # Map each distinct child to its replacement; fn runs once per child.
      replacements = {}
      for child in child_nodes:
          if child in replacements:
              continue
          replacements[child] = fn(child)

      replaced = scanner.replace_nodes(replacements)
      new_args, new_kwargs, new_other_args_to_resolve = replaced
      scanner.clear()

      # The result is an updated copy; options are carried over unchanged.
      return self._copy(
          new_args, new_kwargs, self.get_options(), new_other_args_to_resolve
      )
  def apply_recursive(self, fn: "Callable[[DAGNode], T]") -> T:
      """Apply callable on each node in this DAG in a bottom-up tree walk.

      On the outermost call, `fn` is wrapped in a `_CachingFn` that memoizes
      results by node `_stable_uuid`; the wrapper's cache dict is also
      attached to the original callable as `fn.cache`. Recursive calls
      receive the wrapper itself (recognized by its type name) and return
      early when this node's result is already cached.

      Args:
          fn: Callable that will be applied once to each node in the
              DAG. It will be applied recursively bottom-up, so nodes can
              assume the fn has been applied to their args already.

      Returns:
          Return type of the fn after application to the tree.

      Raises:
          AssertionError: If InputNodes with different stable uuids are
              encountered during the walk.
      """
      # The wrapper is detected by type name because the class is defined
      # locally, inside the branch below.
      if not type(fn).__name__ == "_CachingFn":

          class _CachingFn:
              def __init__(self, fn):
                  # Results keyed by node._stable_uuid.
                  self.cache = {}
                  self.fn = fn
                  # Expose the cache on the wrapped callable as well.
                  self.fn.cache = self.cache
                  # Stable uuid of the first InputNode seen, if any.
                  self.input_node_uuid = None

              def __call__(self, node: "DAGNode"):
                  from ray.dag.input_node import InputNode

                  # Run the wrapped fn only on the first visit of a uuid.
                  if node._stable_uuid not in self.cache:
                      self.cache[node._stable_uuid] = self.fn(node)
                  # Enforce a single InputNode per DAG.
                  if isinstance(node, InputNode):
                      if not self.input_node_uuid:
                          self.input_node_uuid = node._stable_uuid
                      elif self.input_node_uuid != node._stable_uuid:
                          raise AssertionError(
                              "Each DAG should only have one unique InputNode."
                          )
                  return self.cache[node._stable_uuid]

          fn = _CachingFn(fn)
      else:
          # Recursive call: skip the walk if this node was already handled.
          if self._stable_uuid in fn.cache:
              return fn.cache[self._stable_uuid]

      # Transform the children first, then apply fn to the copy of this
      # node whose children have been replaced by their results.
      return fn(
          self._apply_and_replace_all_child_nodes(
              lambda node: node.apply_recursive(fn)
          )
      )
  459. def traverse_and_apply(self, fn: "Callable[[DAGNode], T]"):
  460. """
  461. Traverse all nodes in the connected component of the DAG that contains
  462. the `self` node, and apply the given function to each node.
  463. """
  464. visited = set()
  465. queue = [self]
  466. cgraph_output_node: Optional[DAGNode] = None
  467. while queue:
  468. node = queue.pop(0)
  469. if node._args_contain_nested_dag_node:
  470. self._raise_nested_dag_node_error(node._bound_args)
  471. if node not in visited:
  472. if node.is_cgraph_output_node:
  473. # Validate whether there are multiple nodes that call
  474. # `experimental_compile`.
  475. if cgraph_output_node is not None:
  476. raise ValueError(
  477. "The DAG was compiled more than once. The following two "
  478. "nodes call `experimental_compile`: "
  479. f"(1) {cgraph_output_node}, (2) {node}"
  480. )
  481. cgraph_output_node = node
  482. fn(node)
  483. visited.add(node)
  484. """
  485. Add all unseen downstream and upstream nodes to the queue.
  486. This function should be called by the root of the DAG. However,
  487. in some invalid cases, some nodes may not be descendants of the
  488. root. Therefore, we also add upstream nodes to the queue so that
  489. a meaningful error message can be raised when the DAG is compiled.
  490. ```
  491. with InputNode() as inp:
  492. dag = MultiOutputNode([a1.inc.bind(inp), a2.inc.bind(1)])
  493. ```
  494. In the above example, `a2.inc` is not a descendant of inp. If we only
  495. add downstream nodes to the queue, the `a2.inc` node will not be visited
  496. , and the error message will be hard to understand, such as a key error
  497. in the compiled DAG.
  498. """
  499. for neighbor in chain.from_iterable(
  500. [node._downstream_nodes, node._upstream_nodes]
  501. ):
  502. if neighbor not in visited:
  503. queue.append(neighbor)
  504. def _raise_nested_dag_node_error(self, args):
  505. """
  506. Raise an error for nested DAGNodes in Ray Compiled Graphs.
  507. Args:
  508. args: The arguments of the DAGNode.
  509. """
  510. for arg in args:
  511. if isinstance(arg, DAGNode):
  512. continue
  513. else:
  514. scanner = _PyObjScanner()
  515. dag_nodes = scanner.find_nodes([arg])
  516. scanner.clear()
  517. if len(dag_nodes) > 0:
  518. raise ValueError(
  519. f"Found {len(dag_nodes)} DAGNodes from the arg {arg} "
  520. f"in {self}. Please ensure that the argument is a "
  521. "single DAGNode and that a DAGNode is not allowed to "
  522. "be placed inside any type of container."
  523. )
  524. raise AssertionError(
  525. "A DAGNode's args should contain nested DAGNodes as args, "
  526. "but none were found during the compilation process. This is a "
  527. "Ray internal error. Please report this issue to the Ray team."
  528. )
  def _find_root(self) -> "DAGNode":
      """Return the root node of the DAG. The root node must be an InputNode.

      Starting at `self`, the first upstream node is followed repeatedly
      until an InputNode is reached.

      Raises:
          ValueError: If a node with no upstream nodes is reached before any
              InputNode is found.
      """
      from ray.dag.input_node import InputNode

      current = self
      while True:
          if isinstance(current, InputNode):
              return current
          parents = current._upstream_nodes
          if len(parents) == 0:
              raise ValueError(
                  "No InputNode found in the DAG: when traversing upwards, "
                  f"no upstream node was found for {current}."
              )
          # Only the first upstream node is followed.
          current = parents[0]
  def apply_functional(
      self,
      source_input_list: Any,
      predicate_fn: Callable,
      apply_fn: Callable,
  ):
      """Apply a function to the DAGNodes found in `source_input_list`.

      Returns the replaced inputs without mutating or copying any DAGNode.

      Args:
          source_input_list: Source inputs to extract and apply function on
              all children DAGNode instances.
          predicate_fn: Applied on each DAGNode instance found and determine
              if we should apply function to it. Can be used to filter node
              types.
          apply_fn: Function to apply on the node on bound attributes.
              Example::

                  apply_fn = lambda node: node._get_serve_deployment_handle(
                      node._deployment, node._bound_other_args_to_resolve
                  )

      Returns:
          replaced_inputs: Outputs of apply_fn on DAGNodes in
              source_input_list that passes predicate_fn.
      """
      scanner = _PyObjScanner()
      substitutions = {}
      for candidate in scanner.find_nodes(source_input_list):
          # The predicate is evaluated first, for every node the scanner
          # reports; apply_fn runs once per distinct accepted node.
          if not predicate_fn(candidate):
              continue
          if candidate in substitutions:
              continue
          substitutions[candidate] = apply_fn(candidate)

      replaced_inputs = scanner.replace_nodes(substitutions)
      scanner.clear()
      return replaced_inputs
  574. def _execute_impl(
  575. self, *args, **kwargs
  576. ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
  577. """Execute this node, assuming args have been transformed already."""
  578. raise NotImplementedError
  579. def _copy_impl(
  580. self,
  581. new_args: List[Any],
  582. new_kwargs: Dict[str, Any],
  583. new_options: Dict[str, Any],
  584. new_other_args_to_resolve: Dict[str, Any],
  585. ) -> "DAGNode":
  586. """Return a copy of this node with the given new args."""
  587. raise NotImplementedError
  588. def _copy(
  589. self,
  590. new_args: List[Any],
  591. new_kwargs: Dict[str, Any],
  592. new_options: Dict[str, Any],
  593. new_other_args_to_resolve: Dict[str, Any],
  594. ) -> "DAGNode":
  595. """Return a copy of this node with the given new args."""
  596. instance = self._copy_impl(
  597. new_args, new_kwargs, new_options, new_other_args_to_resolve
  598. )
  599. instance._stable_uuid = self._stable_uuid
  600. instance._type_hint = copy.deepcopy(self._type_hint)
  601. instance._original_type_hint = copy.deepcopy(self._original_type_hint)
  602. return instance
  603. def __getstate__(self):
  604. """Required due to overriding `__getattr__` else pickling fails."""
  605. return self.__dict__
  606. def __setstate__(self, d: Dict[str, Any]):
  607. """Required due to overriding `__getattr__` else pickling fails."""
  608. self.__dict__.update(d)
  609. def __getattr__(self, attr: str):
  610. if attr == "bind":
  611. raise AttributeError(f".bind() cannot be used again on {type(self)} ")
  612. elif attr == "remote":
  613. raise AttributeError(
  614. f".remote() cannot be used on {type(self)}. To execute the task "
  615. "graph for this node, use .execute()."
  616. )
  617. else:
  618. return self.__getattribute__(attr)