output_node.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
from typing import Any, Dict, List, Optional, Tuple, Union

import ray
from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.util.annotations import DeveloperAPI


  6. @DeveloperAPI
  7. class MultiOutputNode(DAGNode):
  8. """Ray dag node used in DAG building API to mark the endpoint of DAG"""
  9. def __init__(
  10. self,
  11. args: Union[List[DAGNode], Tuple[DAGNode]],
  12. other_args_to_resolve: Dict[str, Any] = None,
  13. ):
  14. if isinstance(args, tuple):
  15. args = list(args)
  16. if not isinstance(args, list):
  17. raise ValueError(f"Invalid input type for `args`, {type(args)}.")
  18. super().__init__(
  19. args,
  20. {},
  21. {},
  22. other_args_to_resolve=other_args_to_resolve or {},
  23. )
  24. def _execute_impl(
  25. self, *args, **kwargs
  26. ) -> Union[ray.ObjectRef, "ray.actor.ActorHandle"]:
  27. return self._bound_args
  28. def _copy_impl(
  29. self,
  30. new_args: List[Any],
  31. new_kwargs: Dict[str, Any],
  32. new_options: Dict[str, Any],
  33. new_other_args_to_resolve: Dict[str, Any],
  34. ) -> "DAGNode":
  35. """Return a copy of this node with the given new args."""
  36. return MultiOutputNode(new_args, new_other_args_to_resolve)
  37. def __str__(self) -> str:
  38. return get_dag_node_str(self, "__MultiOutputNode__")