callbacks.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. import contextlib
  2. from collections import defaultdict, namedtuple
  3. from datetime import datetime
  4. from typing import Any, List, Optional
  5. from dask.callbacks import Callback
  6. import ray
  # Names of the Ray-specific callbacks. These are the Ray-specific keyword
  # arguments RayDaskCallback accepts on construction, which makes this tuple
  # the single source-of-truth for which Ray-specific hooks exist.
  CBS = (
      "ray_presubmit",
      "ray_postsubmit",
      "ray_pretask",
      "ray_posttask",
      "ray_postsubmit_all",
      "ray_finish",
  )

  # Attribute/method names under which RayDaskCallback stores each hook: the
  # keyword name with a leading underscore (e.g. "_ray_presubmit").
  CB_FIELDS = tuple("_" + hook_name for hook_name in CBS)

  # Hooks that unpack_ray_callbacks never drops when a RayDaskCallback leaves
  # them unset; their missing entries are kept (as None) instead.
  CBS_DONT_DROP = {"ray_pretask", "ray_posttask"}

  # The hooks of a single RayDaskCallback: one field per name in CBS.
  RayCallback = namedtuple("RayCallback", CBS)

  # The hooks of one or more RayDaskCallbacks, grouped per hook. Each field is
  # the CBS name with a "_cbs" suffix (e.g. "ray_presubmit_cbs").
  RayCallbacks = namedtuple("RayCallbacks", [hook_name + "_cbs" for hook_name in CBS])
  class RayDaskCallback(Callback):
      """
      Extends Dask's `Callback` class with Ray-specific hooks. When instantiating
      or subclassing this class, both the normal Dask hooks (e.g. pretask,
      posttask, etc.) and the Ray-specific hooks can be provided.

      See `dask.callbacks.Callback` for usage.

      Caveats: Any Dask-Ray scheduler must bring the Ray-specific callbacks into
      context using the `local_ray_callbacks` context manager, since the built-in
      `local_callbacks` context manager provided by Dask isn't aware of this
      class.
      """

      # Set of active Ray-specific callbacks (RayCallback namedtuples).
      # It is always accessed through RayDaskCallback itself, never through
      # type(self), because add_ray_callbacks and local_ray_callbacks read and
      # rebind it on this class. A subclass shadowing the attribute must not
      # divert registrations away from the set the scheduler reads.
      ray_active = set()

      def __init__(self, **kwargs):
          """Create the callback, optionally overriding Ray-specific hooks.

          Args:
              **kwargs: Any name from ``CBS`` (e.g. ``ray_presubmit``) mapped to
                  a callable, which replaces the matching ``_ray_*`` method on
                  this instance; ``None`` values are ignored. The remaining
                  keyword arguments are forwarded to ``dask.callbacks.Callback``.
          """
          for cb in CBS:
              cb_func = kwargs.pop(cb, None)
              if cb_func is not None:
                  setattr(self, "_" + cb, cb_func)
          super().__init__(**kwargs)

      @property
      def _ray_callback(self):
          """This instance's Ray hooks packed into a ``RayCallback`` namedtuple."""
          return RayCallback(*[getattr(self, field, None) for field in CB_FIELDS])

      def __enter__(self):
          """Activate both the Ray-specific and the Dask hooks of this callback."""
          self._ray_cm = add_ray_callbacks(self)
          self._ray_cm.__enter__()
          try:
              super().__enter__()
          except BaseException as exc:
              # If __enter__ raises, a `with` statement never calls __exit__.
              # Undo the Ray-side activation here so the hooks don't stay
              # registered forever.
              self._ray_cm.__exit__(type(exc), exc, exc.__traceback__)
              raise
          return self

      def __exit__(self, *args):
          """Deactivate the Dask hooks, then the Ray-specific hooks."""
          try:
              super().__exit__(*args)
          finally:
              # Always deactivate the Ray hooks, even if Dask's exit raises.
              self._ray_cm.__exit__(*args)

      def register(self):
          """Globally activate this callback's Ray-specific and Dask hooks."""
          RayDaskCallback.ray_active.add(self._ray_callback)
          super().register()

      def unregister(self):
          """Globally deactivate this callback's hooks.

          Raises:
              KeyError: If the Ray-specific hooks are not currently in
                  ``RayDaskCallback.ray_active``.
          """
          RayDaskCallback.ray_active.remove(self._ray_callback)
          super().unregister()

      def _ray_presubmit(self, task, key, deps) -> Optional[Any]:
          """Run before submitting a Ray task.

          If this callback returns a non-`None` value, Ray does _not_ create
          a task and uses this value as the would-be task's result value.

          Args:
              task: A Dask task, where the first tuple item is
                  the task function, and the remaining tuple items are
                  the task arguments, which are either the actual argument values,
                  or Dask keys into the deps dictionary whose
                  corresponding values are the argument values.
              key: The Dask graph key for the given task.
              deps: The dependencies of this task.

          Returns:
              Either None, in which case Ray submits a task, or
              a non-None value, in which case Ray doesn't submit
              a task and uses this return value as the
              would-be task result value.
          """
          pass

      def _ray_postsubmit(self, task, key, deps, object_ref: ray.ObjectRef):
          """Run after submitting a Ray task.

          Args:
              task: A Dask task, where the first tuple item is
                  the task function, and the remaining tuple items are
                  the task arguments, which are either the actual argument values,
                  or Dask keys into the deps dictionary whose
                  corresponding values are the argument values.
              key: The Dask graph key for the given task.
              deps: The dependencies of this task.
              object_ref: The object reference for the
                  return value of the Ray task.
          """
          pass

      def _ray_pretask(self, key, object_refs: List[ray.ObjectRef]):
          """Run before executing a Dask task within a Ray task.

          This method executes after Ray submits the task within a Ray
          worker. Ray passes the return value of this task to the
          _ray_posttask callback, if provided.

          Args:
              key: The Dask graph key for the Dask task.
              object_refs: The object references
                  for the arguments of the Ray task.

          Returns:
              A value that Ray passes to the corresponding
              _ray_posttask callback, if the callback is defined.
          """
          pass

      def _ray_posttask(self, key, result, pre_state):
          """Run after executing a Dask task within a Ray task.

          This method executes within a Ray worker. This callback receives the
          return value of the _ray_pretask callback, if provided.

          Args:
              key: The Dask graph key for the Dask task.
              result: The task result value.
              pre_state: The return value of the corresponding
                  _ray_pretask callback, if said callback is defined.
          """
          pass

      def _ray_postsubmit_all(self, object_refs: List[ray.ObjectRef], dsk):
          """Run after Ray submits all tasks.

          Args:
              object_refs: The object references
                  for the output (leaf) Ray tasks of the task graph.
              dsk: The Dask graph.
          """
          pass

      def _ray_finish(self, result):
          """Run after Ray finishes executing all Ray tasks and returns the final
          result.

          Args:
              result: The final result (output) of the Dask
                  computation, before any repackaging is done by
                  Dask collection-specific post-compute callbacks.
          """
          pass
  class add_ray_callbacks:
      """Context manager that activates Ray-specific callbacks for its scope.

      The callbacks are normalized to ``RayCallback`` namedtuples and added to
      ``RayDaskCallback.ray_active`` as soon as the object is constructed (not
      in ``__enter__``). Leaving the context discards them from that set again.

      Args:
          *callbacks: ``RayDaskCallback`` instances and/or ``RayCallback``
              namedtuples.
      """

      def __init__(self, *callbacks):
          normalized = list(map(normalize_ray_callback, callbacks))
          self.callbacks = normalized
          RayDaskCallback.ray_active.update(normalized)

      def __enter__(self):
          return self

      def __exit__(self, *args):
          # Exception info in ``args`` is ignored; exceptions are never
          # suppressed because this returns None.
          active = RayDaskCallback.ray_active
          for ray_cb in self.callbacks:
              active.discard(ray_cb)
  def normalize_ray_callback(cb):
      """Return ``cb`` as a ``RayCallback`` namedtuple.

      Args:
          cb: Either a ``RayDaskCallback`` instance, whose hooks are packed via
              its ``_ray_callback`` property, or a ``RayCallback``, which is
              returned unchanged.

      Raises:
          TypeError: If ``cb`` is neither of those types.
      """
      if isinstance(cb, RayDaskCallback):
          normalized = cb._ray_callback
      elif isinstance(cb, RayCallback):
          normalized = cb
      else:
          message = "Callbacks must be either 'RayDaskCallback' or 'RayCallback' namedtuple"
          raise TypeError(message)
      return normalized
  def unpack_ray_callbacks(cbs):
      """Regroup per-callback hook tuples into per-hook collections.

      Args:
          cbs: An iterable of ``RayCallback`` namedtuples (one per active
              callback), or ``None``.

      Returns:
          A ``RayCallbacks`` namedtuple with one ``<hook>_cbs`` field per name
          in ``CBS``. When ``cbs`` is ``None`` or empty, every field is ``()``.
          Otherwise each field is a list holding that hook from every callback.
          For hooks not in ``CBS_DONT_DROP``, falsy (unset) entries are removed,
          and the field becomes ``None`` if nothing remains. Hooks in
          ``CBS_DONT_DROP`` keep every entry, including ``None``, and therefore
          stay index-aligned with ``cbs``.
      """
      # Materialize once. A one-shot iterator such as a generator is always
      # truthy, so an empty one would otherwise pass the emptiness check. Then
      # ``zip(*cbs)`` would yield no columns, and ``RayCallbacks()`` would
      # raise TypeError for missing fields.
      cbs = list(cbs) if cbs else []
      if not cbs:
          return RayCallbacks(*([()] * len(CBS)))
      # zip(*cbs) transposes "per callback" tuples into "per hook" columns.
      # Only drop unset callback methods for hooks that aren't in CBS_DONT_DROP.
      return RayCallbacks(
          *(
              [cb for cb in hook_cbs if cb or CBS[idx] in CBS_DONT_DROP] or None
              for idx, hook_cbs in enumerate(zip(*cbs))
          )
      )
  @contextlib.contextmanager
  def local_ray_callbacks(callbacks=None):
      """
      Allow Dask-Ray callbacks to work with nested schedulers.

      If ``callbacks`` is given, it is yielded as-is (or ``()`` if it is
      falsy), and the global registry is left alone. Otherwise the currently
      active global callbacks (``RayDaskCallback.ray_active``) are yielded.
      For the duration of the context the registry is replaced by an empty set,
      then restored on exit.

      As a result, global callbacks are only used by the first scheduler that
      picks them up: only the outermost scheduler sees them.
      """
      if callbacks is not None:
          yield callbacks or ()
          return
      # Take the globally active callbacks and leave an empty set behind, so
      # schedulers nested inside this context see no global callbacks.
      taken = RayDaskCallback.ray_active
      RayDaskCallback.ray_active = set()
      try:
          yield taken or ()
      finally:
          RayDaskCallback.ray_active = taken
  class ProgressBarCallback(RayDaskCallback):
      """Track Dask-on-Ray task progress and timings in a named Ray actor.

      The hooks record submission, scheduling and completion timestamps for
      every Dask key. They send these to a Ray actor named
      ``"_dask_on_ray_pb"``. An existing actor with that name is reused after
      its state is reset; otherwise a new one is created. The actor handle is
      kept in ``self.pb``, e.g. ``ray.get(cb.pb.result.remote())`` or
      ``ray.get(cb.pb.report.remote())``.
      """

      def __init__(self):
          # Run the base-class initializers; no hooks are overridden via kwargs.
          super().__init__()

          @ray.remote
          class ProgressBarActor:
              """Actor-side store of per-key ``datetime`` timestamps."""

              def __init__(self):
                  self._init()

              def submit(self, key, deps, now):
                  """Record that ``key`` (with dependencies ``deps``) was submitted at ``now``."""
                  self.deps[key].update(deps.keys())
                  self.submitted[key] = now
                  self.submission_queue.append((key, now))

              def task_scheduled(self, key, now):
                  """Record that ``key`` was picked up by a Ray worker at ``now``."""
                  self.scheduled[key] = now

              def finish(self, key, now):
                  """Record that ``key`` finished executing at ``now``."""
                  self.finished[key] = now

              def result(self):
                  """Return ``(number of submitted keys, number of finished keys)``."""
                  return len(self.submitted), len(self.finished)

              def report(self):
                  """Summarize timings for every finished key.

                  Returns:
                      A ``defaultdict(dict)`` mapping each finished key to
                      ``{"execution_time": seconds, "scheduling_time": seconds}``.
                      It also has a ``"submission_order"`` entry holding the list of
                      ``(key, submit_time)`` pairs in submission order.

                  Raises:
                      KeyError: If a finished key has no recorded submission or
                          scheduling time.
                  """
                  result = defaultdict(dict)
                  for key, finished in self.finished.items():
                      submitted = self.submitted[key]
                      scheduled = self.scheduled[key]
                      result[key]["execution_time"] = (
                          finished - scheduled
                      ).total_seconds()
                      # Calculate the scheduling time.
                      # This is inaccurate.
                      # We should subtract scheduled - (last dep completed).
                      # But currently it is not easy because
                      # of how getitem is implemented in dask on ray sort.
                      result[key]["scheduling_time"] = (
                          scheduled - submitted
                      ).total_seconds()
                  result["submission_order"] = self.submission_queue
                  return result

              def ready(self):
                  """No-op; awaited after creation to block until the actor is up."""
                  pass

              def reset(self):
                  """Discard all recorded state."""
                  self._init()

              def _init(self):
                  self.submission_queue = []
                  # Plain dicts: looking up an unrecorded key raises KeyError.
                  self.submitted = {}
                  self.scheduled = {}
                  self.finished = {}
                  self.deps = defaultdict(set)

          try:
              self.pb = ray.get_actor("_dask_on_ray_pb")
          except ValueError:
              # ray.get_actor raises ValueError when no actor has this name:
              # create one and wait until it is ready.
              self.pb = ProgressBarActor.options(name="_dask_on_ray_pb").remote()
              ray.get(self.pb.ready.remote())
          else:
              # Reuse the existing actor, clearing state from a previous run.
              # This call sits outside the `try`, so a failure here is not
              # mistaken for "actor missing". That mistake would create a second
              # actor under an already-taken name.
              ray.get(self.pb.reset.remote())

      def _ray_postsubmit(self, task, key, deps, object_ref):
          # Indicate the dask task is submitted (fire-and-forget actor call).
          self.pb.submit.remote(key, deps, datetime.now())

      def _ray_pretask(self, key, object_refs):
          # Indicate the dask task has been picked up for execution.
          self.pb.task_scheduled.remote(key, datetime.now())

      def _ray_posttask(self, key, result, pre_state):
          # Indicate the dask task is finished.
          self.pb.finish.remote(key, datetime.now())

      def _ray_finish(self, result):
          print("All tasks are completed.")