compiled_dag_ref.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import asyncio
  2. from typing import Any, List, Optional
  3. import ray
  4. from ray.exceptions import (
  5. GetTimeoutError,
  6. RayChannelError,
  7. RayChannelTimeoutError,
  8. RayTaskError,
  9. )
  10. from ray.util.annotations import PublicAPI
  11. def _process_return_vals(return_vals: List[Any], return_single_output: bool):
  12. """
  13. Process return values for return to the DAG caller. Any exceptions found in
  14. return_vals will be raised. If return_single_output=True, it indicates that
  15. the original DAG did not have a MultiOutputNode, so the DAG caller expects
  16. a single return value instead of a list.
  17. """
  18. # Check for exceptions.
  19. if isinstance(return_vals, Exception):
  20. raise return_vals
  21. for val in return_vals:
  22. if isinstance(val, RayTaskError):
  23. raise val.as_instanceof_cause()
  24. if return_single_output:
  25. assert len(return_vals) == 1
  26. return return_vals[0]
  27. return return_vals
  28. @PublicAPI(stability="alpha")
  29. class CompiledDAGRef:
  30. """
  31. A reference to a compiled DAG execution result.
  32. This is a subclass of ObjectRef and resembles ObjectRef. For example,
  33. similar to ObjectRef, ray.get() can be called on it to retrieve the result.
  34. However, there are several major differences:
  35. 1. ray.get() can only be called once per CompiledDAGRef.
  36. 2. ray.wait() is not supported.
  37. 3. CompiledDAGRef cannot be copied, deep copied, or pickled.
  38. 4. CompiledDAGRef cannot be passed as an argument to another task.
  39. """
  40. def __init__(
  41. self,
  42. dag: "ray.experimental.CompiledDAG",
  43. execution_index: int,
  44. channel_index: Optional[int] = None,
  45. ):
  46. """
  47. Args:
  48. dag: The compiled DAG that generated this CompiledDAGRef.
  49. execution_index: The index of the execution for the DAG.
  50. A DAG can be executed multiple times, and execution index
  51. indicates which execution this CompiledDAGRef corresponds to.
  52. actor_execution_loop_refs: The actor execution loop refs that
  53. are used to execute the DAG. This can be used internally to
  54. check the task execution errors in case of exceptions.
  55. channel_index: The index of the DAG's output channel to fetch
  56. the result from. A DAG can have multiple output channels, and
  57. channel index indicates which channel this CompiledDAGRef
  58. corresponds to. If channel index is not provided, this CompiledDAGRef
  59. wraps the results from all output channels.
  60. """
  61. self._dag = dag
  62. self._execution_index = execution_index
  63. self._channel_index = channel_index
  64. # Whether ray.get() was called on this CompiledDAGRef.
  65. self._ray_get_called = False
  66. self._dag_output_channels = dag.dag_output_channels
  67. def __str__(self):
  68. return (
  69. f"CompiledDAGRef({self._dag.get_id()}, "
  70. f"execution_index={self._execution_index}, "
  71. f"channel_index={self._channel_index})"
  72. )
  73. def __copy__(self):
  74. raise ValueError("CompiledDAGRef cannot be copied.")
  75. def __deepcopy__(self, memo):
  76. raise ValueError("CompiledDAGRef cannot be deep copied.")
  77. def __reduce__(self):
  78. raise ValueError("CompiledDAGRef cannot be pickled.")
  79. def __del__(self):
  80. # If the dag is already teardown, it should do nothing.
  81. if self._dag.is_teardown:
  82. return
  83. if self._ray_get_called:
  84. # get() was already called, no further cleanup is needed.
  85. return
  86. self._dag._delete_execution_results(self._execution_index, self._channel_index)
  87. def get(self, timeout: Optional[float] = None):
  88. if self._ray_get_called:
  89. raise ValueError(
  90. "ray.get() can only be called once "
  91. "on a CompiledDAGRef, and it was already called."
  92. )
  93. self._ray_get_called = True
  94. try:
  95. self._dag._execute_until(
  96. self._execution_index, self._channel_index, timeout
  97. )
  98. return_vals = self._dag._get_execution_results(
  99. self._execution_index, self._channel_index
  100. )
  101. except RayChannelTimeoutError:
  102. raise
  103. except RayChannelError as channel_error:
  104. # If we get a channel error, we'd like to call ray.get() on
  105. # the actor execution loop refs to check if this is a result
  106. # of task execution error which could not be passed down
  107. # (e.g., when a pure NCCL channel is used, it is only
  108. # able to send tensors, but not the wrapped exceptions).
  109. # In this case, we'd like to raise the task execution error
  110. # (which is the actual cause of the channel error) instead
  111. # of the channel error itself.
  112. # TODO(rui): determine which error to raise if multiple
  113. # actor task refs have errors.
  114. actor_execution_loop_refs = list(self._dag.worker_task_refs.values())
  115. try:
  116. ray.get(actor_execution_loop_refs, timeout=10)
  117. except GetTimeoutError as timeout_error:
  118. raise Exception(
  119. "Timed out when getting the actor execution loop exception. "
  120. "This should not happen, please file a GitHub issue."
  121. ) from timeout_error
  122. except Exception as execution_error:
  123. # Use 'from None' to suppress the context of the original
  124. # channel error, which is not useful to the user.
  125. raise execution_error from None
  126. else:
  127. raise channel_error
  128. except Exception:
  129. raise
  130. return _process_return_vals(return_vals, True)
  131. @PublicAPI(stability="alpha")
  132. class CompiledDAGFuture:
  133. """
  134. A reference to a compiled DAG execution result, when executed with asyncio.
  135. This differs from CompiledDAGRef in that `await` must be called on the
  136. future to get the result, instead of `ray.get()`.
  137. This resembles async usage of ObjectRefs. For example, similar to
  138. ObjectRef, `await` can be called directly on the CompiledDAGFuture to
  139. retrieve the result. However, there are several major differences:
  140. 1. `await` can only be called once per CompiledDAGFuture.
  141. 2. ray.wait() is not supported.
  142. 3. CompiledDAGFuture cannot be copied, deep copied, or pickled.
  143. 4. CompiledDAGFuture cannot be passed as an argument to another task.
  144. """
  145. def __init__(
  146. self,
  147. dag: "ray.experimental.CompiledDAG",
  148. execution_index: int,
  149. fut: "asyncio.Future",
  150. channel_index: Optional[int] = None,
  151. ):
  152. self._dag = dag
  153. self._execution_index = execution_index
  154. self._fut = fut
  155. self._channel_index = channel_index
  156. def __str__(self):
  157. return (
  158. f"CompiledDAGFuture({self._dag.get_id()}, "
  159. f"execution_index={self._execution_index}, "
  160. f"channel_index={self._channel_index})"
  161. )
  162. def __copy__(self):
  163. raise ValueError("CompiledDAGFuture cannot be copied.")
  164. def __deepcopy__(self, memo):
  165. raise ValueError("CompiledDAGFuture cannot be deep copied.")
  166. def __reduce__(self):
  167. raise ValueError("CompiledDAGFuture cannot be pickled.")
  168. def __await__(self):
  169. if self._fut is None:
  170. raise ValueError(
  171. "CompiledDAGFuture can only be awaited upon once, and it has "
  172. "already been awaited upon."
  173. )
  174. # NOTE(swang): If the object is zero-copy deserialized, then it will
  175. # stay in scope as long as this future is in scope. Therefore, we
  176. # delete self._fut here before we return the result to the user.
  177. fut = self._fut
  178. self._fut = None
  179. if not self._dag._has_execution_results(self._execution_index):
  180. result = yield from fut.__await__()
  181. self._dag._max_finished_execution_index += 1
  182. self._dag._cache_execution_results(self._execution_index, result)
  183. return_vals = self._dag._get_execution_results(
  184. self._execution_index, self._channel_index
  185. )
  186. return _process_return_vals(return_vals, True)
  187. def __del__(self):
  188. if self._dag.is_teardown:
  189. return
  190. if self._fut is None:
  191. # await() was already called, no further cleanup is needed.
  192. return
  193. self._dag._delete_execution_results(self._execution_index, self._channel_index)