# input_node.py (11 KB)
  1. from typing import Any, Dict, List, Optional, Union
  2. from ray.dag import DAGNode
  3. from ray.dag.format_utils import get_dag_node_str
  4. from ray.experimental.gradio_utils import type_to_string
  5. from ray.util.annotations import DeveloperAPI
# Key under which InputNode records (via set_context() in __enter__) that it
# has been entered as a context manager; InputNode._in_context_manager() reads
# it back from the node's other-args-to-resolve dict at execution time.
IN_CONTEXT_MANAGER = "__in_context_manager__"
  7. @DeveloperAPI
  8. class InputNode(DAGNode):
  9. r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.
  10. Should only be function or class method. A DAG can have multiple
  11. entrypoints, but only one instance of InputNode exists per DAG, shared
  12. among all DAGNodes.
  13. Example:
  14. .. code-block::
  15. m1.forward
  16. / \
  17. dag_input ensemble -> dag_output
  18. \ /
  19. m2.forward
  20. In this pipeline, each user input is broadcasted to both m1.forward and
  21. m2.forward as first stop of the DAG, and authored like
  22. .. code-block:: python
  23. import ray
  24. @ray.remote
  25. class Model:
  26. def __init__(self, val):
  27. self.val = val
  28. def forward(self, input):
  29. return self.val * input
  30. @ray.remote
  31. def combine(a, b):
  32. return a + b
  33. with InputNode() as dag_input:
  34. m1 = Model.bind(1)
  35. m2 = Model.bind(2)
  36. m1_output = m1.forward.bind(dag_input[0])
  37. m2_output = m2.forward.bind(dag_input.x)
  38. ray_dag = combine.bind(m1_output, m2_output)
  39. # Pass mix of args and kwargs as input.
  40. ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2
  41. # Alternatively user can also pass single data object, list or dict
  42. # and access them via list index, object attribute or dict key str.
  43. ray_dag.execute(UserDataObject(m1=1, m2=2))
  44. # dag_input.m1, dag_input.m2
  45. ray_dag.execute([1, 2])
  46. # dag_input[0], dag_input[1]
  47. ray_dag.execute({"m1": 1, "m2": 2})
  48. # dag_input["m1"], dag_input["m2"]
  49. """
  50. def __init__(
  51. self,
  52. *args,
  53. input_type: Optional[Union[type, Dict[Union[int, str], type]]] = None,
  54. _other_args_to_resolve=None,
  55. **kwargs,
  56. ):
  57. """InputNode should only take attributes of validating and converting
  58. input data rather than the input data itself. User input should be
  59. provided via `ray_dag.execute(user_input)`.
  60. Args:
  61. input_type: Describes the data type of inputs user will be giving.
  62. - if given through singular InputNode: type of InputNode
  63. - if given through InputAttributeNodes: map of key -> type
  64. Used when deciding what Gradio block to represent the input nodes with.
  65. _other_args_to_resolve: Internal only to keep InputNode's execution
  66. context throughput pickling, replacement and serialization.
  67. User should not use or pass this field.
  68. """
  69. if len(args) != 0 or len(kwargs) != 0:
  70. raise ValueError("InputNode should not take any args or kwargs.")
  71. self.input_attribute_nodes = {}
  72. self.input_type = input_type
  73. if input_type is not None and isinstance(input_type, type):
  74. if _other_args_to_resolve is None:
  75. _other_args_to_resolve = {}
  76. _other_args_to_resolve["result_type_string"] = type_to_string(input_type)
  77. super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)
  78. def _copy_impl(
  79. self,
  80. new_args: List[Any],
  81. new_kwargs: Dict[str, Any],
  82. new_options: Dict[str, Any],
  83. new_other_args_to_resolve: Dict[str, Any],
  84. ):
  85. return InputNode(_other_args_to_resolve=new_other_args_to_resolve)
  86. def _execute_impl(self, *args, **kwargs):
  87. """Executor of InputNode."""
  88. # Catch and assert singleton context at dag execution time.
  89. assert self._in_context_manager(), (
  90. "InputNode is a singleton instance that should be only used in "
  91. "context manager for dag building and execution. See the docstring "
  92. "of class InputNode for examples."
  93. )
  94. # If user only passed in one value, for simplicity we just return it.
  95. if len(args) == 1 and len(kwargs) == 0:
  96. return args[0]
  97. return DAGInputData(*args, **kwargs)
  98. def _in_context_manager(self) -> bool:
  99. """Return if InputNode is created in context manager."""
  100. if (
  101. not self._bound_other_args_to_resolve
  102. or IN_CONTEXT_MANAGER not in self._bound_other_args_to_resolve
  103. ):
  104. return False
  105. else:
  106. return self._bound_other_args_to_resolve[IN_CONTEXT_MANAGER]
  107. def set_context(self, key: str, val: Any):
  108. """Set field in parent DAGNode attribute that can be resolved in both
  109. pickle and JSON serialization
  110. """
  111. self._bound_other_args_to_resolve[key] = val
  112. def __str__(self) -> str:
  113. return get_dag_node_str(self, "__InputNode__")
  114. def __getattr__(self, key: str):
  115. assert isinstance(
  116. key, str
  117. ), "Please only access dag input attributes with str key."
  118. if key not in self.input_attribute_nodes:
  119. self.input_attribute_nodes[key] = InputAttributeNode(
  120. self, key, "__getattr__"
  121. )
  122. return self.input_attribute_nodes[key]
  123. def __getitem__(self, key: Union[int, str]) -> Any:
  124. assert isinstance(key, (str, int)), (
  125. "Please only use int index or str as first-level key to "
  126. "access fields of dag input."
  127. )
  128. input_type = None
  129. if self.input_type is not None and key in self.input_type:
  130. input_type = type_to_string(self.input_type[key])
  131. if key not in self.input_attribute_nodes:
  132. self.input_attribute_nodes[key] = InputAttributeNode(
  133. self, key, "__getitem__", input_type
  134. )
  135. return self.input_attribute_nodes[key]
  136. def __enter__(self):
  137. self.set_context(IN_CONTEXT_MANAGER, True)
  138. return self
  139. def __exit__(self, *args):
  140. pass
  141. def get_result_type(self) -> str:
  142. """Get type of the output of this DAGNode.
  143. Generated by ray.experimental.gradio_utils.type_to_string().
  144. """
  145. if "result_type_string" in self._bound_other_args_to_resolve:
  146. return self._bound_other_args_to_resolve["result_type_string"]
  147. @DeveloperAPI
  148. class InputAttributeNode(DAGNode):
  149. """Represents partial access of user input based on an index (int),
  150. object attribute or dict key (str).
  151. Examples:
  152. .. code-block:: python
  153. with InputNode() as dag_input:
  154. a = dag_input[0]
  155. b = dag_input.x
  156. ray_dag = add.bind(a, b)
  157. # This makes a = 1 and b = 2
  158. ray_dag.execute(1, x=2)
  159. with InputNode() as dag_input:
  160. a = dag_input[0]
  161. b = dag_input[1]
  162. ray_dag = add.bind(a, b)
  163. # This makes a = 2 and b = 3
  164. ray_dag.execute(2, 3)
  165. # Alternatively, you can input a single object
  166. # and the inputs are automatically indexed from the object:
  167. # This makes a = 2 and b = 3
  168. ray_dag.execute([2, 3])
  169. """
  170. def __init__(
  171. self,
  172. dag_input_node: InputNode,
  173. key: Union[int, str],
  174. accessor_method: str,
  175. input_type: str = None,
  176. ):
  177. self._dag_input_node = dag_input_node
  178. self._key = key
  179. self._accessor_method = accessor_method
  180. super().__init__(
  181. [],
  182. {},
  183. {},
  184. {
  185. "dag_input_node": dag_input_node,
  186. "key": key,
  187. "accessor_method": accessor_method,
  188. # Type of the input tied to this node. Used by
  189. # gradio_visualize_graph.GraphVisualizer to determine which Gradio
  190. # component should be used for this node.
  191. "result_type_string": input_type,
  192. },
  193. )
  194. def _copy_impl(
  195. self,
  196. new_args: List[Any],
  197. new_kwargs: Dict[str, Any],
  198. new_options: Dict[str, Any],
  199. new_other_args_to_resolve: Dict[str, Any],
  200. ):
  201. return InputAttributeNode(
  202. new_other_args_to_resolve["dag_input_node"],
  203. new_other_args_to_resolve["key"],
  204. new_other_args_to_resolve["accessor_method"],
  205. new_other_args_to_resolve["result_type_string"],
  206. )
  207. def _execute_impl(self, *args, **kwargs):
  208. """Executor of InputAttributeNode.
  209. Args and kwargs are to match base class signature, but not in the
  210. implementation. All args and kwargs should be resolved and replaced
  211. with value in bound_args and bound_kwargs via bottom-up recursion when
  212. current node is executed.
  213. """
  214. if isinstance(self._dag_input_node, DAGInputData):
  215. return self._dag_input_node[self._key]
  216. else:
  217. # dag.execute() is called with only one arg, thus when an
  218. # InputAttributeNode is executed, its dependent InputNode is
  219. # resolved with original user input python object.
  220. user_input_python_object = self._dag_input_node
  221. if isinstance(self._key, str):
  222. if self._accessor_method == "__getitem__":
  223. return user_input_python_object[self._key]
  224. elif self._accessor_method == "__getattr__":
  225. return getattr(user_input_python_object, self._key)
  226. elif isinstance(self._key, int):
  227. return user_input_python_object[self._key]
  228. else:
  229. raise ValueError(
  230. "Please only use int index or str as first-level key to "
  231. "access fields of dag input."
  232. )
  233. def __str__(self) -> str:
  234. return get_dag_node_str(self, f'["{self._key}"]')
  235. def get_result_type(self) -> str:
  236. """Get type of the output of this DAGNode.
  237. Generated by ray.experimental.gradio_utils.type_to_string().
  238. """
  239. if "result_type_string" in self._bound_other_args_to_resolve:
  240. return self._bound_other_args_to_resolve["result_type_string"]
  241. @property
  242. def key(self) -> Union[int, str]:
  243. return self._key
  244. @DeveloperAPI
  245. class DAGInputData:
  246. """If user passed multiple args and kwargs directly to dag.execute(), we
  247. generate this wrapper for all user inputs as one object, accessible via
  248. list index or object attribute key.
  249. """
  250. def __init__(self, *args, **kwargs):
  251. self._args = list(args)
  252. self._kwargs = kwargs
  253. def __getitem__(self, key: Union[int, str]) -> Any:
  254. if isinstance(key, int):
  255. # Access list args by index.
  256. return self._args[key]
  257. elif isinstance(key, str):
  258. # Access kwarg by key.
  259. return self._kwargs[key]
  260. else:
  261. raise ValueError(
  262. "Please only use int index or str as first-level key to "
  263. "access fields of dag input."
  264. )