| 123456789101112131415161718192021222324252627282930313233343536373839404142434445 |
- from typing import Any, Dict, List, Tuple, Union
- import ray
- from ray.dag import DAGNode
- from ray.dag.format_utils import get_dag_node_str
- from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class MultiOutputNode(DAGNode):
    """Ray dag node used in DAG building API to mark the endpoint of DAG.

    The node is constructed from a list (or tuple) of arguments and hands
    that collection to the ``DAGNode`` base class as its positional args.
    Executing the node returns the bound arguments stored on the node.
    """

    def __init__(
        self,
        args: Union[List[DAGNode], Tuple[DAGNode, ...]],
        other_args_to_resolve: Union[Dict[str, Any], None] = None,
    ):
        """Build the node from a collection of output nodes.

        Args:
            args: The outputs to bundle. A tuple is converted to a list
                before being passed to the base class.
            other_args_to_resolve: Extra arguments forwarded to the base
                class under the same keyword. ``None`` (or any falsy
                value) is replaced by an empty dict.

        Raises:
            ValueError: If ``args`` is neither a list nor a tuple.
        """
        # Normalize tuples so the base class always receives a list.
        if isinstance(args, tuple):
            args = list(args)
        if not isinstance(args, list):
            raise ValueError(f"Invalid input type for `args`, {type(args)}.")
        # The two empty dicts are presumably the bound kwargs and options
        # (matching the parameter order of `_copy_impl`); this node never
        # supplies either -- confirm against DAGNode.__init__.
        super().__init__(
            args,
            {},
            {},
            other_args_to_resolve=other_args_to_resolve or {},
        )

    def _execute_impl(
        self, *args, **kwargs
    ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
        """Return the arguments bound to this node.

        The call-time ``args`` and ``kwargs`` are accepted but not used.
        """
        # NOTE(review): the return annotation does not look like it matches
        # `self._bound_args` (a collection of bound arguments, presumably);
        # left unchanged pending confirmation against the DAGNode base class.
        return self._bound_args

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ) -> "DAGNode":
        """Return a copy of this node with the given new args.

        Only ``new_args`` and ``new_other_args_to_resolve`` are used;
        ``new_kwargs`` and ``new_options`` are ignored because the
        constructor takes neither.
        """
        return MultiOutputNode(new_args, new_other_args_to_resolve)

    def __str__(self) -> str:
        """Return the string built by ``get_dag_node_str`` for this node."""
        return get_dag_node_str(self, "__MultiOutputNode__")
|