scheduler.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. import atexit
  2. import threading
  3. import time
  4. import warnings
  5. from collections import OrderedDict, defaultdict
  6. from collections.abc import Mapping
  7. from dataclasses import dataclass
  8. from multiprocessing.pool import ThreadPool
  9. from pprint import pprint
  10. from typing import Optional
  11. import dask
  12. from dask.core import ishashable, istask
  13. try:
  14. from dask._task_spec import Alias, DataNode, Task, TaskRef, convert_legacy_graph
  15. except ImportError:
  16. warnings.warn(
  17. "Dask on Ray is available only on dask>=2024.11.0, "
  18. f"you are on version {dask.__version__}."
  19. )
  20. from dask.system import CPU_COUNT
  21. from dask.threaded import _thread_get_id, pack_exception
  22. import ray
  23. from ray.util.dask.callbacks import local_ray_callbacks, unpack_ray_callbacks
  24. from ray.util.dask.common import unpack_object_refs
  25. from ray.util.dask.scheduler_utils import apply_sync, get_async
  26. main_thread = threading.current_thread()
  27. default_pool = None
  28. pools = defaultdict(dict)
  29. pools_lock = threading.Lock()
  30. TOP_LEVEL_RESOURCES_ERR_MSG = (
  31. 'Use ray_remote_args={"resources": {...}} instead of resources={...} to specify '
  32. "required Ray task resources; see "
  33. "https://docs.ray.io/en/master/ray-core/package-ref.html#ray-remote."
  34. )
  35. def enable_dask_on_ray(
  36. shuffle: Optional[str] = "tasks",
  37. use_shuffle_optimization: Optional[bool] = True,
  38. ) -> dask.config.set:
  39. """
  40. Enable Dask-on-Ray scheduler. This helper sets the Dask-on-Ray scheduler
  41. as the default Dask scheduler in the Dask config. By default, it will also
  42. cause the task-based shuffle to be used for any Dask shuffle operations
  43. (required for multi-node Ray clusters, not sharing a filesystem), and will
  44. enable a Ray-specific shuffle optimization.
  45. >>> enable_dask_on_ray()
  46. >>> ddf.compute() # <-- will use the Dask-on-Ray scheduler.
  47. If used as a context manager, the Dask-on-Ray scheduler will only be used
  48. within the context's scope.
  49. >>> with enable_dask_on_ray():
  50. ... ddf.compute() # <-- will use the Dask-on-Ray scheduler.
  51. >>> ddf.compute() # <-- won't use the Dask-on-Ray scheduler.
  52. Args:
  53. shuffle: The shuffle method used by Dask, either "tasks" or
  54. "disk". This should be "tasks" if using a multi-node Ray cluster.
  55. Defaults to "tasks".
  56. use_shuffle_optimization: Enable our custom Ray-specific shuffle
  57. optimization. Defaults to True.
  58. Returns:
  59. The Dask config object, which can be used as a context manager to limit
  60. the scope of the Dask-on-Ray scheduler to the corresponding context.
  61. """
  62. if use_shuffle_optimization:
  63. from ray.util.dask.optimizations import dataframe_optimize
  64. else:
  65. dataframe_optimize = None
  66. # Manually set the global Dask scheduler config.
  67. # We also force the task-based shuffle to be used since the disk-based
  68. # shuffle doesn't work for a multi-node Ray cluster that doesn't share
  69. # the filesystem.
  70. return dask.config.set(
  71. scheduler=ray_dask_get, shuffle=shuffle, dataframe_optimize=dataframe_optimize
  72. )
  73. def disable_dask_on_ray():
  74. """
  75. Unsets the scheduler, shuffle method, and DataFrame optimizer.
  76. """
  77. return dask.config.set(scheduler=None, shuffle=None, dataframe_optimize=None)
def ray_dask_get(dsk, keys, **kwargs):
    """
    A Dask-Ray scheduler. This scheduler will send top-level (non-inlined) Dask
    tasks to a Ray cluster for execution. The scheduler will wait for the
    tasks to finish executing, fetch the results, and repackage them into the
    appropriate Dask collections. This particular scheduler uses a threadpool
    to submit Ray tasks.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager), the number of threads to use when
    submitting the Ray tasks, or the threadpool used to submit Ray tasks:

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get,
            ray_callbacks=some_ray_dask_callbacks,
            num_workers=8,
            pool=some_cool_pool,
        )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary. A non-Mapping
            object (e.g. a Dask expression) is also accepted: its graph is
            taken from `_optimized_dsk` when that attribute exists, otherwise
            from `__dask_graph__()`.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        num_workers (Optional[int]): The number of worker threads to use in
            the Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
            submit Ray tasks.
        ray_persist (bool): If True, return the Ray object references instead
            of fetching their values. Defaults to False.
        ray_remote_args (dict): Ray task options applied to every task; a
            layer's "ray_remote_args" annotation overrides them for that
            layer's keys.
        _ray_enable_progress_bar: If truthy, the actor named
            "_dask_on_ray_pb" is looked up with `ray.get_actor` and used to
            render a progress bar while results are fetched.
        **kwargs: Any remaining keyword arguments are forwarded to `get_async`.

    Returns:
        Computed values corresponding to the provided keys, or the Ray object
        references for them when `ray_persist` is True.

    Raises:
        ValueError: If `resources` is passed as a keyword argument or set in
            the Dask annotations; use `ray_remote_args={"resources": ...}`
            instead.
    """
    num_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)
    # We attempt to reuse any other thread pools that have been created within
    # this thread and with the given number of workers. We reuse a global
    # thread pool if num_workers is not given and we're in the main thread.
    global default_pool
    thread = threading.current_thread()
    if pool is None:
        # `pools_lock` guards both `default_pool` and the `pools` cache.
        with pools_lock:
            if num_workers is None and thread is main_thread:
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                pool = pools[thread][num_workers]
            else:
                # Cache a new pool for this (thread, num_workers) pair. It is
                # closed at interpreter exit, or earlier by the dead-thread
                # cleanup at the end of this function.
                pool = ThreadPool(num_workers)
                atexit.register(pool.close)
                pools[thread][num_workers] = pool

    # Remove the Ray-specific options from kwargs; whatever is left is
    # forwarded to get_async below.
    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)
    enable_progress_bar = kwargs.pop("_ray_enable_progress_bar", None)

    # Handle Ray remote args and resource annotations.
    # Resources may only be requested through ray_remote_args.
    if "resources" in kwargs:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
    ray_remote_args = kwargs.pop("ray_remote_args", {})
    annotations = dask.get_annotations()
    if "resources" in annotations:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)

    # Take out the dask graph if it is an Expr for dask>=2025.4.0.
    if not isinstance(dsk, Mapping):
        if hasattr(dsk, "_optimized_dsk"):
            # For Expr with this property
            dsk = dsk._optimized_dsk
        else:
            # For any other Expr
            dsk = dsk.__dask_graph__()

    # Resolve the Ray remote options for each key (global args merged with
    # per-layer annotations).
    scoped_ray_remote_args = _build_key_scoped_ray_remote_args(
        dsk, annotations, ray_remote_args
    )

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # Make sure the graph is in the new format
        dsk = convert_legacy_graph(dsk)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        # Every task is routed to `_rayify_task_wrapper` (which submits a Ray
        # task) through the pool's `apply_async`. The worker count is read
        # from the pool's private `_pool` list of workers.
        object_refs = get_async(
            _apply_async_wrapper(
                pool.apply_async,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                scoped_ray_remote_args,
            ),
            len(pool._pool),
            dsk,
            keys,
            get_id=_thread_get_id,
            pack_exception=pack_exception,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        if persist:
            result = object_refs
        else:
            pb_actor = None
            if enable_progress_bar:
                pb_actor = ray.get_actor("_dask_on_ray_pb")
            result = ray_get_unpack(object_refs, progress_bar_actor=pb_actor)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

    # cleanup pools associated with dead threads.
    # NOTE: this only runs for calls made from a non-main thread, and is
    # skipped if an exception escaped above.
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.close()
    return result
  208. def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs):
  209. """
  210. Wraps the given pool `apply_async` function, hotswapping `real_func` in as
  211. the function to be applied and adding `extra_args` and `extra_kwargs` to
  212. `real_func`'s call.
  213. Args:
  214. apply_async: The pool function to be wrapped.
  215. real_func: The real function that we wish the pool apply
  216. function to execute.
  217. *extra_args: Extra positional arguments to pass to the `real_func`.
  218. **extra_kwargs: Extra keyword arguments to pass to the `real_func`.
  219. Returns:
  220. A wrapper function that will ignore it's first `func` argument and
  221. pass `real_func` in its place. To be passed to `dask.local.get_async`.
  222. """
  223. def wrapper(func, args=(), kwds=None, callback=None): # noqa: M511
  224. if not kwds:
  225. kwds = {}
  226. return apply_async(
  227. real_func,
  228. args=args + extra_args,
  229. kwds=dict(kwds, **extra_kwargs),
  230. callback=callback,
  231. )
  232. return wrapper
  233. def _rayify_task_wrapper(
  234. key,
  235. task_info,
  236. dumps,
  237. loads,
  238. get_id,
  239. pack_exception,
  240. ray_presubmit_cbs,
  241. ray_postsubmit_cbs,
  242. ray_pretask_cbs,
  243. ray_posttask_cbs,
  244. scoped_ray_remote_args,
  245. ):
  246. """
  247. The core Ray-Dask task execution wrapper, to be given to the thread pool's
  248. `apply_async` function. Exactly the same as `execute_task`, except that it
  249. calls `_rayify_task` on the task instead of `_execute_task`.
  250. Args:
  251. key: The Dask graph key whose corresponding task we wish to
  252. execute.
  253. task_info: The task to execute and its dependencies.
  254. dumps: A result serializing function.
  255. loads: A task_info deserializing function.
  256. get_id: An ID generating function.
  257. pack_exception: An exception serializing function.
  258. ray_presubmit_cbs: Pre-task submission callbacks.
  259. ray_postsubmit_cbs: Post-task submission callbacks.
  260. ray_pretask_cbs: Pre-task execution callbacks.
  261. ray_posttask_cbs: Post-task execution callbacks.
  262. scoped_ray_remote_args: Ray task options for each key.
  263. Returns:
  264. A 3-tuple of the task's key, a literal or a Ray object reference for a
  265. Ray task's result, and whether the Ray task submission failed.
  266. """
  267. try:
  268. task, deps = loads(task_info)
  269. result = _rayify_task(
  270. task,
  271. key,
  272. deps,
  273. ray_presubmit_cbs,
  274. ray_postsubmit_cbs,
  275. ray_pretask_cbs,
  276. ray_posttask_cbs,
  277. scoped_ray_remote_args.get(key, {}),
  278. )
  279. id = get_id()
  280. result = dumps((result, id))
  281. failed = False
  282. except BaseException as e:
  283. result = pack_exception(e, dumps)
  284. failed = True
  285. return key, result, failed
  286. def _rayify_task(
  287. task,
  288. key,
  289. deps,
  290. ray_presubmit_cbs,
  291. ray_postsubmit_cbs,
  292. ray_pretask_cbs,
  293. ray_posttask_cbs,
  294. ray_remote_args,
  295. ):
  296. """
  297. Rayifies the given task, submitting it as a Ray task to the Ray cluster.
  298. Args:
  299. task: A Dask graph value, being either a literal, dependency
  300. key, Dask task, or a list thereof.
  301. key: The Dask graph key for the given task.
  302. deps: The dependencies of this task.
  303. ray_presubmit_cbs: Pre-task submission callbacks.
  304. ray_postsubmit_cbs: Post-task submission callbacks.
  305. ray_pretask_cbs: Pre-task execution callbacks.
  306. ray_posttask_cbs: Post-task execution callbacks.
  307. ray_remote_args: Ray task options. See :func:`ray.remote` for details.
  308. Returns:
  309. A literal, a Ray object reference representing a submitted task, or a
  310. list thereof.
  311. """
  312. if isinstance(task, list):
  313. # Recursively rayify this list. This will still bottom out at the first
  314. # actual task encountered, inlining any tasks in that task's arguments.
  315. return [
  316. _rayify_task(
  317. t,
  318. key,
  319. deps,
  320. ray_presubmit_cbs,
  321. ray_postsubmit_cbs,
  322. ray_pretask_cbs,
  323. ray_posttask_cbs,
  324. ray_remote_args,
  325. )
  326. for t in task
  327. ]
  328. elif istask(task):
  329. # Unpacks and repacks Ray object references and submits the task to the
  330. # Ray cluster for execution.
  331. if ray_presubmit_cbs is not None:
  332. alternate_returns = [cb(task, key, deps) for cb in ray_presubmit_cbs]
  333. for alternate_return in alternate_returns:
  334. # We don't submit a Ray task if a presubmit callback returns
  335. # a non-`None` value, instead we return said value.
  336. # NOTE: This returns the first non-None presubmit callback
  337. # return value.
  338. if alternate_return is not None:
  339. return alternate_return
  340. if isinstance(task, Alias):
  341. target = task.target
  342. if isinstance(target, TaskRef):
  343. # for 2024.12.0
  344. return deps[target.key]
  345. else:
  346. # for 2024.12.1+
  347. return deps[target]
  348. elif isinstance(task, Task):
  349. func = task.func
  350. else:
  351. raise ValueError("Invalid task type: %s" % type(task))
  352. # If the function's arguments contain nested object references, we must
  353. # unpack said object references into a flat set of arguments so that
  354. # Ray properly tracks the object dependencies between Ray tasks.
  355. arg_object_refs, repack = unpack_object_refs(deps)
  356. # Submit the task using a wrapper function.
  357. object_refs = dask_task_wrapper.options(
  358. name=f"dask:{key!s}",
  359. num_returns=(
  360. 1 if not isinstance(func, MultipleReturnFunc) else func.num_returns
  361. ),
  362. **ray_remote_args,
  363. ).remote(
  364. task,
  365. repack,
  366. key,
  367. ray_pretask_cbs,
  368. ray_posttask_cbs,
  369. *arg_object_refs,
  370. )
  371. if ray_postsubmit_cbs is not None:
  372. for cb in ray_postsubmit_cbs:
  373. cb(task, key, deps, object_refs)
  374. return object_refs
  375. elif not ishashable(task):
  376. return task
  377. elif task in deps:
  378. return deps[task]
  379. else:
  380. return task
  381. @ray.remote
  382. def dask_task_wrapper(
  383. task, repack, key, ray_pretask_cbs, ray_posttask_cbs, *arg_object_refs
  384. ):
  385. """
  386. A Ray remote function acting as a Dask task wrapper. This function will
  387. repackage the given `arg_object_refs` into its original `deps` using
  388. `repack`, and then pass it to the provided Dask Task object , `task`.
  389. Args:
  390. task: The Dask Task class object to execute.
  391. repack: A function that repackages the provided args into
  392. the original (possibly nested) Python objects.
  393. key: The Dask key for this task.
  394. ray_pretask_cbs: Pre-task execution callbacks.
  395. ray_posttask_cbs: Post-task execution callback.
  396. *arg_object_refs (ObjectRef): Ray object references representing the dependencies'
  397. results.
  398. Returns:
  399. The output of the Dask task. In the context of Ray, a
  400. dask_task_wrapper.remote() invocation will return a Ray object
  401. reference representing the Ray task's result.
  402. """
  403. if ray_pretask_cbs is not None:
  404. pre_states = [
  405. cb(key, arg_object_refs) if cb is not None else None
  406. for cb in ray_pretask_cbs
  407. ]
  408. (repacked_deps,) = repack(arg_object_refs)
  409. # De-reference the potentially nested arguments recursively.
  410. def _dereference_args(x):
  411. if isinstance(x, Task):
  412. x.args = _dereference_args(x.args)
  413. return x
  414. elif isinstance(x, Mapping):
  415. return {k: _dereference_args(v) for k, v in x.items()}
  416. elif isinstance(x, tuple):
  417. return tuple(_dereference_args(x) for x in x)
  418. elif isinstance(x, ray.ObjectRef):
  419. return ray.get(x)
  420. elif isinstance(x, DataNode):
  421. if isinstance(x.value, ray.ObjectRef):
  422. value = ray.get(x.value)
  423. return DataNode(key=x.key, value=value)
  424. return x
  425. else:
  426. return x
  427. task = _dereference_args(task)
  428. result = task(repacked_deps)
  429. if ray_posttask_cbs is not None:
  430. for cb, pre_state in zip(ray_posttask_cbs, pre_states):
  431. if cb is not None:
  432. cb(key, result, pre_state)
  433. return result
  434. def render_progress_bar(tracker, object_refs):
  435. from tqdm import tqdm
  436. # At this time, every task should be submitted.
  437. total, finished = ray.get(tracker.result.remote())
  438. reported_finished_so_far = 0
  439. pb_bar = tqdm(total=total, position=0)
  440. pb_bar.set_description("")
  441. ready_refs = []
  442. while finished < total:
  443. submitted, finished = ray.get(tracker.result.remote())
  444. pb_bar.update(finished - reported_finished_so_far)
  445. reported_finished_so_far = finished
  446. ready_refs, _ = ray.wait(
  447. object_refs, timeout=0, num_returns=len(object_refs), fetch_local=False
  448. )
  449. if len(ready_refs) == len(object_refs):
  450. break
  451. time.sleep(0.1)
  452. pb_bar.close()
  453. submitted, finished = ray.get(tracker.result.remote())
  454. if submitted != finished:
  455. print("Completed. There was state inconsistency.")
  456. pprint(ray.get(tracker.report.remote()))
  457. def ray_get_unpack(object_refs, progress_bar_actor=None):
  458. """
  459. Unpacks object references, gets the object references, and repacks.
  460. Traverses arbitrary data structures.
  461. Args:
  462. object_refs: A (potentially nested) Python object containing Ray object
  463. references.
  464. Returns:
  465. The input Python object with all contained Ray object references
  466. resolved with their concrete values.
  467. """
  468. def get_result(object_refs):
  469. if progress_bar_actor:
  470. render_progress_bar(progress_bar_actor, object_refs)
  471. return ray.get(object_refs)
  472. if isinstance(object_refs, tuple):
  473. object_refs = list(object_refs)
  474. if isinstance(object_refs, list) and any(
  475. not isinstance(x, ray.ObjectRef) for x in object_refs
  476. ):
  477. # We flatten the object references before calling ray.get(), since Dask
  478. # loves to nest collections in nested tuples and Ray expects a flat
  479. # list of object references. We repack the results after ray.get()
  480. # completes.
  481. object_refs, repack = unpack_object_refs(*object_refs)
  482. computed_result = get_result(object_refs)
  483. return repack(computed_result)
  484. else:
  485. return get_result(object_refs)
  486. def ray_dask_get_sync(dsk, keys, **kwargs):
  487. """
  488. A synchronous Dask-Ray scheduler. This scheduler will send top-level
  489. (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will
  490. wait for the tasks to finish executing, fetch the results, and repackage
  491. them into the appropriate Dask collections. This particular scheduler
  492. submits Ray tasks synchronously, which can be useful for debugging.
  493. This can be passed directly to `dask.compute()`, as the scheduler:
  494. >>> dask.compute(obj, scheduler=ray_dask_get_sync)
  495. You can override the currently active global Dask-Ray callbacks (e.g.
  496. supplied via a context manager):
  497. >>> dask.compute(
  498. obj,
  499. scheduler=ray_dask_get_sync,
  500. ray_callbacks=some_ray_dask_callbacks,
  501. )
  502. Args:
  503. dsk: Dask graph, represented as a task DAG dictionary.
  504. keys (List[str]): List of Dask graph keys whose values we wish to
  505. compute and return.
  506. Returns:
  507. Computed values corresponding to the provided keys.
  508. """
  509. ray_callbacks = kwargs.pop("ray_callbacks", None)
  510. persist = kwargs.pop("ray_persist", False)
  511. with local_ray_callbacks(ray_callbacks) as ray_callbacks:
  512. # Unpack the Ray-specific callbacks.
  513. (
  514. ray_presubmit_cbs,
  515. ray_postsubmit_cbs,
  516. ray_pretask_cbs,
  517. ray_posttask_cbs,
  518. ray_postsubmit_all_cbs,
  519. ray_finish_cbs,
  520. ) = unpack_ray_callbacks(ray_callbacks)
  521. # Make sure the graph is in the new format
  522. dsk = convert_legacy_graph(dsk)
  523. # NOTE: We hijack Dask's `get_async` function, injecting a different
  524. # task executor.
  525. object_refs = get_async(
  526. _apply_async_wrapper(
  527. apply_sync,
  528. _rayify_task_wrapper,
  529. ray_presubmit_cbs,
  530. ray_postsubmit_cbs,
  531. ray_pretask_cbs,
  532. ray_posttask_cbs,
  533. ),
  534. 1,
  535. dsk,
  536. keys,
  537. **kwargs,
  538. )
  539. if ray_postsubmit_all_cbs is not None:
  540. for cb in ray_postsubmit_all_cbs:
  541. cb(object_refs, dsk)
  542. # NOTE: We explicitly delete the Dask graph here so object references
  543. # are garbage-collected before this function returns, i.e. before all
  544. # Ray tasks are done. Otherwise, no intermediate objects will be
  545. # cleaned up until all Ray tasks are done.
  546. del dsk
  547. if persist:
  548. result = object_refs
  549. else:
  550. result = ray_get_unpack(object_refs)
  551. if ray_finish_cbs is not None:
  552. for cb in ray_finish_cbs:
  553. cb(result)
  554. return result
  555. @dataclass
  556. class MultipleReturnFunc:
  557. func: callable
  558. num_returns: int
  559. def __call__(self, *args, **kwargs):
  560. returns = self.func(*args, **kwargs)
  561. if isinstance(returns, dict) or isinstance(returns, OrderedDict):
  562. returns = [returns[k] for k in range(len(returns))]
  563. return returns
  564. def multiple_return_get(multiple_returns, idx):
  565. return multiple_returns[idx]
  566. def _build_key_scoped_ray_remote_args(dsk, annotations, ray_remote_args):
  567. # Handle per-layer annotations.
  568. if not isinstance(dsk, dask.highlevelgraph.HighLevelGraph):
  569. dsk = dask.highlevelgraph.HighLevelGraph.from_collections(
  570. id(dsk), dsk, dependencies=()
  571. )
  572. # Build key-scoped annotations.
  573. scoped_annotations = {}
  574. layers = [(name, dsk.layers[name]) for name in dsk._toposort_layers()]
  575. for id_, layer in layers:
  576. layer_annotations = layer.annotations
  577. if layer_annotations is None:
  578. layer_annotations = annotations
  579. elif "resources" in layer_annotations:
  580. raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
  581. for key in layer.get_output_keys():
  582. layer_annotations_for_key = annotations.copy()
  583. # Layer annotations override global annotations.
  584. layer_annotations_for_key.update(layer_annotations)
  585. # Let same-key annotations earlier in the topological sort take precedence.
  586. layer_annotations_for_key.update(scoped_annotations.get(key, {}))
  587. scoped_annotations[key] = layer_annotations_for_key
  588. # Build key-scoped Ray remote args.
  589. scoped_ray_remote_args = {}
  590. for key, annotations in scoped_annotations.items():
  591. layer_ray_remote_args = ray_remote_args.copy()
  592. # Layer Ray remote args override global Ray remote args given in the compute
  593. # call.
  594. layer_ray_remote_args.update(annotations.get("ray_remote_args", {}))
  595. scoped_ray_remote_args[key] = layer_ray_remote_args
  596. return scoped_ray_remote_args