scheduler_utils.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
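"""
The following is adapted from Dask release 2021.03.1:
https://github.com/dask/dask/blob/2021.03.1/dask/local.py
It has since been updated to use the ``dask._task_spec`` API
(``DataNode``, ``DependenciesMapping``), which requires dask>=2024.11.0.
"""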
  5. import os
  6. import warnings
  7. from queue import Empty, Queue
  8. import dask
  9. from dask import config
  10. try:
  11. from dask._task_spec import DataNode, DependenciesMapping
  12. except ImportError:
  13. warnings.warn(
  14. "Dask on Ray is available only on dask>=2024.11.0, "
  15. f"you are on version {dask.__version__}."
  16. )
  17. from dask.callbacks import local_callbacks, unpack_callbacks
  18. from dask.core import flatten, get_dependencies, reverse_dict
  19. from dask.order import order
  20. if os.name == "nt":
  21. # Python 3 windows Queue.get doesn't handle interrupts properly. To
  22. # workaround this we poll at a sufficiently large interval that it
  23. # shouldn't affect performance, but small enough that users trying to kill
  24. # an application shouldn't care.
  25. def queue_get(q):
  26. while True:
  27. try:
  28. return q.get(block=True, timeout=0.1)
  29. except Empty:
  30. pass
  31. else:
  32. def queue_get(q):
  33. return q.get()
  34. def start_state_from_dask(dsk, cache=None, sortkey=None):
  35. """Start state from a dask
  36. Examples
  37. --------
  38. >>> dsk = {
  39. 'x': 1,
  40. 'y': 2,
  41. 'z': (inc, 'x'),
  42. 'w': (add, 'z', 'y')} # doctest: +SKIP
  43. >>> from pprint import pprint # doctest: +SKIP
  44. >>> pprint(start_state_from_dask(dsk)) # doctest: +SKIP
  45. {'cache': {'x': 1, 'y': 2},
  46. 'dependencies': {'w': {'z', 'y'}, 'x': set(), 'y': set(), 'z': {'x'}},
  47. 'dependents': {'w': set(), 'x': {'z'}, 'y': {'w'}, 'z': {'w'}},
  48. 'finished': set(),
  49. 'ready': ['z'],
  50. 'released': set(),
  51. 'running': set(),
  52. 'waiting': {'w': {'z'}},
  53. 'waiting_data': {'x': {'z'}, 'y': {'w'}, 'z': {'w'}}}
  54. """
  55. if sortkey is None:
  56. sortkey = order(dsk).get
  57. if cache is None:
  58. cache = config.get("cache", None)
  59. if cache is None:
  60. cache = dict()
  61. data_keys = set()
  62. for k, v in dsk.items():
  63. if isinstance(v, DataNode):
  64. cache[k] = v()
  65. data_keys.add(k)
  66. dsk2 = dsk.copy()
  67. dsk2.update(cache)
  68. dependencies = DependenciesMapping(dsk)
  69. waiting = {k: set(v) for k, v in dependencies.items() if k not in data_keys}
  70. dependents = reverse_dict(dependencies)
  71. for a in cache:
  72. for b in dependents.get(a, ()):
  73. waiting[b].remove(a)
  74. waiting_data = {k: v.copy() for k, v in dependents.items() if v}
  75. ready_set = {k for k, v in waiting.items() if not v}
  76. ready = sorted(ready_set, key=sortkey, reverse=True)
  77. waiting = {k: v for k, v in waiting.items() if v}
  78. state = {
  79. "dependencies": dependencies,
  80. "dependents": dependents,
  81. "waiting": waiting,
  82. "waiting_data": waiting_data,
  83. "cache": cache,
  84. "ready": ready,
  85. "running": set(),
  86. "finished": set(),
  87. "released": set(),
  88. }
  89. return state
  90. def execute_task(key, task_info, dumps, loads, get_id, pack_exception):
  91. """
  92. Compute task and handle all administration
  93. See Also
  94. --------
  95. _execute_task : actually execute task
  96. """
  97. try:
  98. task, data = loads(task_info)
  99. result = task(data)
  100. id = get_id()
  101. result = dumps((result, id))
  102. failed = False
  103. except BaseException as e:
  104. result = pack_exception(e, dumps)
  105. failed = True
  106. return key, result, failed
  107. def release_data(key, state, delete=True):
  108. """Remove data from temporary storage
  109. See Also
  110. --------
  111. finish_task
  112. """
  113. if key in state["waiting_data"]:
  114. assert not state["waiting_data"][key]
  115. del state["waiting_data"][key]
  116. state["released"].add(key)
  117. if delete:
  118. del state["cache"][key]
# Debug switch read by finish_task: when True, finish_task prints the total
# size of cached values (using the optional third-party ``chest`` package)
# each time it releases a dependency's data.
DEBUG = False
  120. def finish_task(
  121. dsk, key, state, results, sortkey, delete=True, release_data=release_data
  122. ):
  123. """
  124. Update execution state after a task finishes
  125. Mutates. This should run atomically (with a lock).
  126. """
  127. for dep in sorted(state["dependents"][key], key=sortkey, reverse=True):
  128. s = state["waiting"][dep]
  129. s.remove(key)
  130. if not s:
  131. del state["waiting"][dep]
  132. state["ready"].append(dep)
  133. for dep in state["dependencies"][key]:
  134. if dep in state["waiting_data"]:
  135. s = state["waiting_data"][dep]
  136. s.remove(key)
  137. if not s and dep not in results:
  138. if DEBUG:
  139. from chest.core import nbytes
  140. print(
  141. "Key: %s\tDep: %s\t NBytes: %.2f\t Release"
  142. % (key, dep, sum(map(nbytes, state["cache"].values()) / 1e6))
  143. )
  144. release_data(dep, state, delete=delete)
  145. elif delete and dep not in results:
  146. release_data(dep, state, delete=delete)
  147. state["finished"].add(key)
  148. state["running"].remove(key)
  149. return state
  150. def nested_get(ind, coll):
  151. """Get nested index from collection
  152. Examples
  153. --------
  154. >>> nested_get(1, 'abc')
  155. 'b'
  156. >>> nested_get([1, 0], 'abc')
  157. ('b', 'a')
  158. >>> nested_get([[1, 0], [0, 1]], 'abc')
  159. (('b', 'a'), ('a', 'b'))
  160. """
  161. if isinstance(ind, list):
  162. return tuple(nested_get(i, coll) for i in ind)
  163. else:
  164. return coll[ind]
  165. def default_get_id():
  166. """Default get_id"""
  167. return None
def default_pack_exception(e, dumps):
    """Default ``pack_exception``: re-raise instead of packing the error.

    ``execute_task`` calls this from inside its ``except`` block, so the bare
    ``raise`` re-raises the exception currently being handled (``e``) with its
    original traceback. ``dumps`` is ignored. If called outside an ``except``
    block, the bare ``raise`` fails with ``RuntimeError`` because there is no
    active exception.
    """
    raise
  170. def reraise(exc, tb=None):
  171. if exc.__traceback__ is not tb:
  172. raise exc.with_traceback(tb)
  173. raise exc
  174. def identity(x):
  175. """Identity function. Returns x.
  176. >>> identity(3)
  177. 3
  178. """
  179. return x
  180. def get_async(
  181. apply_async,
  182. num_workers,
  183. dsk,
  184. result,
  185. cache=None,
  186. get_id=default_get_id,
  187. rerun_exceptions_locally=None,
  188. pack_exception=default_pack_exception,
  189. raise_exception=reraise,
  190. callbacks=None,
  191. dumps=identity,
  192. loads=identity,
  193. **kwargs,
  194. ):
  195. """Asynchronous get function
  196. This is a general version of various asynchronous schedulers for dask. It
  197. takes a an apply_async function as found on Pool objects to form a more
  198. specific ``get`` method that walks through the dask array with parallel
  199. workers, avoiding repeat computation and minimizing memory use.
  200. Parameters
  201. ----------
  202. apply_async : function
  203. Asynchronous apply function as found on Pool or ThreadPool
  204. num_workers : int
  205. The number of active tasks we should have at any one time
  206. dsk : dict
  207. A dask dictionary specifying a workflow
  208. result : key or list of keys
  209. Keys corresponding to desired data
  210. cache : dict-like, optional
  211. Temporary storage of results
  212. get_id : callable, optional
  213. Function to return the worker id, takes no arguments. Examples are
  214. `threading.current_thread` and `multiprocessing.current_process`.
  215. rerun_exceptions_locally : bool, optional
  216. Whether to rerun failing tasks in local process to enable debugging
  217. (False by default)
  218. pack_exception : callable, optional
  219. Function to take an exception and ``dumps`` method, and return a
  220. serialized tuple of ``(exception, traceback)`` to send back to the
  221. scheduler. Default is to just raise the exception.
  222. raise_exception : callable, optional
  223. Function that takes an exception and a traceback, and raises an error.
  224. dumps: callable, optional
  225. Function to serialize task data and results to communicate between
  226. worker and parent. Defaults to identity.
  227. loads: callable, optional
  228. Inverse function of `dumps`. Defaults to identity.
  229. callbacks : tuple or list of tuples, optional
  230. Callbacks are passed in as tuples of length 5. Multiple sets of
  231. callbacks may be passed in as a list of tuples. For more information,
  232. see the dask.diagnostics documentation.
  233. See Also
  234. --------
  235. threaded.get
  236. """
  237. queue = Queue()
  238. if isinstance(result, list):
  239. result_flat = set(flatten(result))
  240. else:
  241. result_flat = {result}
  242. results = set(result_flat)
  243. dsk = dict(dsk)
  244. with local_callbacks(callbacks) as callbacks:
  245. _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
  246. started_cbs = []
  247. succeeded = False
  248. # if start_state_from_dask fails, we will have something
  249. # to pass to the final block.
  250. state = {}
  251. try:
  252. for cb in callbacks:
  253. if cb[0]:
  254. cb[0](dsk)
  255. started_cbs.append(cb)
  256. keyorder = order(dsk)
  257. state = start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get)
  258. for _, start_state, _, _, _ in callbacks:
  259. if start_state:
  260. start_state(dsk, state)
  261. if rerun_exceptions_locally is None:
  262. rerun_exceptions_locally = config.get("rerun_exceptions_locally", False)
  263. if state["waiting"] and not state["ready"]:
  264. raise ValueError("Found no accessible jobs in dask")
  265. def fire_task():
  266. """Fire off a task to the thread pool"""
  267. # Choose a good task to compute
  268. key = state["ready"].pop()
  269. state["running"].add(key)
  270. for f in pretask_cbs:
  271. f(key, dsk, state)
  272. # Prep data to send
  273. data = {dep: state["cache"][dep] for dep in get_dependencies(dsk, key)}
  274. # Submit
  275. apply_async(
  276. execute_task,
  277. args=(
  278. key,
  279. dumps((dsk[key], data)),
  280. dumps,
  281. loads,
  282. get_id,
  283. pack_exception,
  284. ),
  285. callback=queue.put,
  286. )
  287. # Seed initial tasks into the thread pool
  288. while state["ready"] and len(state["running"]) < num_workers:
  289. fire_task()
  290. # Main loop, wait on tasks to finish, insert new ones
  291. while state["waiting"] or state["ready"] or state["running"]:
  292. key, res_info, failed = queue_get(queue)
  293. if failed:
  294. exc, tb = loads(res_info)
  295. if rerun_exceptions_locally:
  296. data = {
  297. dep: state["cache"][dep]
  298. for dep in get_dependencies(dsk, key)
  299. }
  300. task = dsk[key]
  301. task(data) # Re-execute locally
  302. else:
  303. raise_exception(exc, tb)
  304. res, worker_id = loads(res_info)
  305. state["cache"][key] = res
  306. finish_task(dsk, key, state, results, keyorder.get)
  307. for f in posttask_cbs:
  308. f(key, res, dsk, state, worker_id)
  309. while state["ready"] and len(state["running"]) < num_workers:
  310. fire_task()
  311. succeeded = True
  312. finally:
  313. for _, _, _, _, finish in started_cbs:
  314. if finish:
  315. finish(dsk, state, not succeeded)
  316. return nested_get(result, state["cache"])
  317. def apply_sync(func, args=(), kwds=None, callback=None):
  318. """A naive synchronous version of apply_async"""
  319. if kwds is None:
  320. kwds = {}
  321. res = func(*args, **kwds)
  322. if callback is not None:
  323. callback(res)