function_node.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from typing import Any, Dict, List
  2. import ray
  3. from ray.dag.dag_node import DAGNode
  4. from ray.dag.format_utils import get_dag_node_str
  5. from ray.util.annotations import DeveloperAPI
  6. @DeveloperAPI
  7. class FunctionNode(DAGNode):
  8. """Represents a bound task node in a Ray task DAG."""
  9. def __init__(
  10. self,
  11. func_body,
  12. func_args,
  13. func_kwargs,
  14. func_options,
  15. other_args_to_resolve=None,
  16. ):
  17. self._body = func_body
  18. super().__init__(
  19. func_args,
  20. func_kwargs,
  21. func_options,
  22. other_args_to_resolve=other_args_to_resolve,
  23. )
  24. def _copy_impl(
  25. self,
  26. new_args: List[Any],
  27. new_kwargs: Dict[str, Any],
  28. new_options: Dict[str, Any],
  29. new_other_args_to_resolve: Dict[str, Any],
  30. ):
  31. return FunctionNode(
  32. self._body,
  33. new_args,
  34. new_kwargs,
  35. new_options,
  36. other_args_to_resolve=new_other_args_to_resolve,
  37. )
  38. def _execute_impl(self, *args, **kwargs):
  39. """Executor of FunctionNode by ray.remote().
  40. Args and kwargs are to match base class signature, but not in the
  41. implementation. All args and kwargs should be resolved and replaced
  42. with value in bound_args and bound_kwargs via bottom-up recursion when
  43. current node is executed.
  44. """
  45. return (
  46. ray.remote(self._body)
  47. .options(**self._bound_options)
  48. .remote(*self._bound_args, **self._bound_kwargs)
  49. )
  50. def __str__(self) -> str:
  51. return get_dag_node_str(self, str(self._body))