| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- from typing import Any, Dict, List
- import ray
- from ray.dag.dag_node import DAGNode
- from ray.dag.format_utils import get_dag_node_str
- from ray.util.annotations import DeveloperAPI
@DeveloperAPI
class FunctionNode(DAGNode):
    """DAG node that wraps a plain Python function bound to its call inputs.

    When the node executes, the stored function is turned into a Ray remote
    function with ``ray.remote`` and submitted as a task. The submission uses
    the options and arguments bound to this node.
    """

    def __init__(
        self,
        func_body,
        func_args,
        func_kwargs,
        func_options,
        other_args_to_resolve=None,
    ):
        """Create a node for ``func_body`` together with its call inputs.

        Args:
            func_body: Callable stored on the node. It is passed to
                ``ray.remote`` when the node executes.
            func_args: Positional arguments, forwarded to ``DAGNode``.
            func_kwargs: Keyword arguments, forwarded to ``DAGNode``.
            func_options: Task options, forwarded to ``DAGNode``. They are
                presumably exposed as ``_bound_options``; confirm in DAGNode.
            other_args_to_resolve: Extra values forwarded to ``DAGNode``.
                Defaults to ``None``.
        """
        # The body is stored before delegating to the base initializer, as in
        # the original ordering, in case DAGNode.__init__ touches the node.
        self._body = func_body
        super().__init__(
            func_args,
            func_kwargs,
            func_options,
            other_args_to_resolve=other_args_to_resolve,
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new FunctionNode that shares this node's function body.

        Only the args, kwargs, options and other-args-to-resolve come from
        the parameters. The result is always a plain ``FunctionNode``, even
        when this method is invoked on a subclass instance.
        """
        body = self._body
        return FunctionNode(
            body,
            new_args,
            new_kwargs,
            new_options,
            other_args_to_resolve=new_other_args_to_resolve,
        )

    def _execute_impl(self, *args, **kwargs):
        """Submit the stored function as a Ray task.

        ``args`` and ``kwargs`` exist only to match the base-class signature;
        they are ignored here. By the time this runs, upstream nodes are
        expected to have been resolved bottom-up into concrete values held in
        ``self._bound_args`` and ``self._bound_kwargs``.

        Returns:
            Whatever ``.remote(...)`` returns for the submitted task.
        """
        remote_func = ray.remote(self._body)
        configured = remote_func.options(**self._bound_options)
        return configured.remote(*self._bound_args, **self._bound_kwargs)

    def __str__(self) -> str:
        """Format the node with ``get_dag_node_str``, labelled by its body."""
        body_label = str(self._body)
        return get_dag_node_str(self, body_label)
|