"""Dask-on-Ray schedulers.

Provides ``ray_dask_get`` (thread-pool based submission) and
``ray_dask_get_sync`` (synchronous submission), which execute the tasks of a
Dask graph as Ray tasks, plus helpers to enable/disable the scheduler in the
Dask config.
"""

import atexit
import threading
import time
import warnings
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from pprint import pprint
from typing import Callable, Optional

import dask
from dask.core import ishashable, istask

try:
    from dask._task_spec import Alias, DataNode, Task, TaskRef, convert_legacy_graph
except ImportError:
    warnings.warn(
        "Dask on Ray is available only on dask>=2024.11.0, "
        f"you are on version {dask.__version__}."
    )
from dask.system import CPU_COUNT
from dask.threaded import _thread_get_id, pack_exception

import ray
from ray.util.dask.callbacks import local_ray_callbacks, unpack_ray_callbacks
from ray.util.dask.common import unpack_object_refs
from ray.util.dask.scheduler_utils import apply_sync, get_async

# Thread that imported this module; it owns the shared `default_pool`.
main_thread = threading.current_thread()
# Lazily created pool reused by `ray_dask_get` calls made from `main_thread`
# without an explicit `num_workers`.
default_pool = None
# thread -> {num_workers -> ThreadPool}, for pools created by `ray_dask_get`.
pools = defaultdict(dict)
# Guards `default_pool` and `pools`.
pools_lock = threading.Lock()

TOP_LEVEL_RESOURCES_ERR_MSG = (
    'Use ray_remote_args={"resources": {...}} instead of resources={...} to specify '
    "required Ray task resources; see "
    "https://docs.ray.io/en/master/ray-core/package-ref.html#ray-remote."
)


def enable_dask_on_ray(
    shuffle: Optional[str] = "tasks",
    use_shuffle_optimization: Optional[bool] = True,
) -> dask.config.set:
    """
    Enable Dask-on-Ray scheduler. This helper sets the Dask-on-Ray scheduler
    as the default Dask scheduler in the Dask config. By default, it will also
    cause the task-based shuffle to be used for any Dask shuffle operations
    (required for multi-node Ray clusters, not sharing a filesystem), and will
    enable a Ray-specific shuffle optimization.

    >>> enable_dask_on_ray()
    >>> ddf.compute()  # <-- will use the Dask-on-Ray scheduler.

    If used as a context manager, the Dask-on-Ray scheduler will only be used
    within the context's scope.

    >>> with enable_dask_on_ray():
    ...     ddf.compute()  # <-- will use the Dask-on-Ray scheduler.

    >>> ddf.compute()  # <-- won't use the Dask-on-Ray scheduler.

    Args:
        shuffle: The shuffle method used by Dask, either "tasks" or
            "disk". This should be "tasks" if using a multi-node Ray cluster.
            Defaults to "tasks".
        use_shuffle_optimization: Enable our custom Ray-specific shuffle
            optimization. Defaults to True.

    Returns:
        The Dask config object, which can be used as a context manager to limit
        the scope of the Dask-on-Ray scheduler to the corresponding context.
    """
    if use_shuffle_optimization:
        # Imported lazily so the optimization module is only loaded when used.
        from ray.util.dask.optimizations import dataframe_optimize
    else:
        dataframe_optimize = None
    # Manually set the global Dask scheduler config.
    # We also force the task-based shuffle to be used since the disk-based
    # shuffle doesn't work for a multi-node Ray cluster that doesn't share
    # the filesystem.
    return dask.config.set(
        scheduler=ray_dask_get, shuffle=shuffle, dataframe_optimize=dataframe_optimize
    )


def disable_dask_on_ray():
    """
    Unsets the scheduler, shuffle method, and DataFrame optimizer.

    Returns:
        The `dask.config.set` object produced by resetting the three options
        to None.
    """
    return dask.config.set(scheduler=None, shuffle=None, dataframe_optimize=None)


def _prepare_graph_and_remote_args(dsk, kwargs):
    """
    Validate resource options, unwrap the Dask graph, and build the per-key
    Ray remote args. Shared by `ray_dask_get` and `ray_dask_get_sync` so both
    schedulers treat `ray_remote_args` and annotations identically.

    Args:
        dsk: Dask graph (a mapping), or an object exposing either
            `_optimized_dsk` or `__dask_graph__()` from which the graph is
            taken.
        kwargs: The scheduler's keyword arguments. `ray_remote_args` is popped
            from this dict in place.

    Returns:
        A 2-tuple of the (possibly unwrapped) graph and a dict mapping each
        Dask key to its Ray remote args.

    Raises:
        ValueError: If `resources` is given as a top-level keyword argument or
            as a Dask annotation.
    """
    if "resources" in kwargs:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
    ray_remote_args = kwargs.pop("ray_remote_args", {})
    annotations = dask.get_annotations()
    if "resources" in annotations:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)

    # Take out the dask graph if it is an Expr for dask>=2025.4.0.
    if not isinstance(dsk, Mapping):
        if hasattr(dsk, "_optimized_dsk"):
            # For Expr with this property
            dsk = dsk._optimized_dsk
        else:
            # For any other Expr
            dsk = dsk.__dask_graph__()

    scoped_ray_remote_args = _build_key_scoped_ray_remote_args(
        dsk, annotations, ray_remote_args
    )
    return dsk, scoped_ray_remote_args


def ray_dask_get(dsk, keys, **kwargs):
    """
    A Dask-Ray scheduler. This scheduler will send top-level (non-inlined) Dask
    tasks to a Ray cluster for execution. The scheduler will wait for the
    tasks to finish executing, fetch the results, and repackage them into the
    appropriate Dask collections. This particular scheduler uses a threadpool
    to submit Ray tasks.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager), the number of threads to use when
    submitting the Ray tasks, or the threadpool used to submit Ray tasks:

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get,
            ray_callbacks=some_ray_dask_callbacks,
            num_workers=8,
            pool=some_cool_pool,
        )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        num_workers (Optional[int]): The number of worker threads to use in
            the Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
            submit Ray tasks.
        ray_remote_args (Optional[dict]): Ray task options applied to every
            submitted task; per-layer `ray_remote_args` annotations override
            them.
        ray_persist (bool): If True, return the Ray object references instead
            of fetching their values.

    Returns:
        Computed values corresponding to the provided keys.

    Raises:
        ValueError: If `resources` is given as a keyword argument or as a Dask
            annotation instead of inside `ray_remote_args`.
    """
    global default_pool

    num_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)
    # We attempt to reuse any other thread pools that have been created within
    # this thread and with the given number of workers. We reuse a global
    # thread pool if num_workers is not given and we're in the main thread.
    thread = threading.current_thread()
    if pool is None:
        with pools_lock:
            if num_workers is None and thread is main_thread:
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif thread in pools and num_workers in pools[thread]:
                pool = pools[thread][num_workers]
            else:
                pool = ThreadPool(num_workers)
                atexit.register(pool.close)
                pools[thread][num_workers] = pool

    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)
    enable_progress_bar = kwargs.pop("_ray_enable_progress_bar", None)

    # Handle Ray remote args and resource annotations.
    dsk, scoped_ray_remote_args = _prepare_graph_and_remote_args(dsk, kwargs)

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # Make sure the graph is in the new format
        dsk = convert_legacy_graph(dsk)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        object_refs = get_async(
            _apply_async_wrapper(
                pool.apply_async,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                scoped_ray_remote_args,
            ),
            len(pool._pool),
            dsk,
            keys,
            get_id=_thread_get_id,
            pack_exception=pack_exception,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        if persist:
            result = object_refs
        else:
            pb_actor = None
            if enable_progress_bar:
                pb_actor = ray.get_actor("_dask_on_ray_pb")
            result = ray_get_unpack(object_refs, progress_bar_actor=pb_actor)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

    # cleanup pools associated with dead threads.
    with pools_lock:
        active_threads = set(threading.enumerate())
        if thread is not main_thread:
            for t in list(pools):
                if t not in active_threads:
                    for p in pools.pop(t).values():
                        p.close()
    return result


def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs):
    """
    Wraps the given pool `apply_async` function, hotswapping `real_func` in as
    the function to be applied and adding `extra_args` and `extra_kwargs` to
    `real_func`'s call.

    Args:
        apply_async: The pool function to be wrapped.
        real_func: The real function that we wish the pool apply
            function to execute.
        *extra_args: Extra positional arguments to pass to the `real_func`.
        **extra_kwargs: Extra keyword arguments to pass to the `real_func`.

    Returns:
        A wrapper function that will ignore its first `func` argument and
        pass `real_func` in its place. To be passed to `dask.local.get_async`.
    """

    def wrapper(func, args=(), kwds=None, callback=None):  # noqa: M511
        # `func` is intentionally unused: `real_func` replaces it.
        if not kwds:
            kwds = {}
        return apply_async(
            real_func,
            args=args + extra_args,
            kwds=dict(kwds, **extra_kwargs),
            callback=callback,
        )

    return wrapper


def _rayify_task_wrapper(
    key,
    task_info,
    dumps,
    loads,
    get_id,
    pack_exception,
    ray_presubmit_cbs,
    ray_postsubmit_cbs,
    ray_pretask_cbs,
    ray_posttask_cbs,
    scoped_ray_remote_args,
):
    """
    The core Ray-Dask task execution wrapper, to be given to the thread pool's
    `apply_async` function. Exactly the same as `execute_task`, except that it
    calls `_rayify_task` on the task instead of `_execute_task`.

    Args:
        key: The Dask graph key whose corresponding task we wish to
            execute.
        task_info: The task to execute and its dependencies.
        dumps: A result serializing function.
        loads: A task_info deserializing function.
        get_id: An ID generating function.
        pack_exception: An exception serializing function.
        ray_presubmit_cbs: Pre-task submission callbacks.
        ray_postsubmit_cbs: Post-task submission callbacks.
        ray_pretask_cbs: Pre-task execution callbacks.
        ray_posttask_cbs: Post-task execution callbacks.
        scoped_ray_remote_args: Ray task options for each key.

    Returns:
        A 3-tuple of the task's key, the serialized result (a literal or a Ray
        object reference for a Ray task's result, paired with the ID from
        `get_id`) or the packed exception, and whether the Ray task submission
        failed.
    """
    try:
        task, deps = loads(task_info)
        result = _rayify_task(
            task,
            key,
            deps,
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            scoped_ray_remote_args.get(key, {}),
        )
        worker_id = get_id()
        result = dumps((result, worker_id))
        failed = False
    except BaseException as e:
        # Every failure, including non-`Exception` ones, is packed and handed
        # back to the caller rather than raised in the pool thread.
        result = pack_exception(e, dumps)
        failed = True
    return key, result, failed


def _rayify_task(
    task,
    key,
    deps,
    ray_presubmit_cbs,
    ray_postsubmit_cbs,
    ray_pretask_cbs,
    ray_posttask_cbs,
    ray_remote_args,
):
    """
    Rayifies the given task, submitting it as a Ray task to the Ray cluster.

    Args:
        task: A Dask graph value, being either a literal, dependency
            key, Dask task, or a list thereof.
        key: The Dask graph key for the given task.
        deps: The dependencies of this task.
        ray_presubmit_cbs: Pre-task submission callbacks.
        ray_postsubmit_cbs: Post-task submission callbacks.
        ray_pretask_cbs: Pre-task execution callbacks.
        ray_posttask_cbs: Post-task execution callbacks.
        ray_remote_args: Ray task options. See :func:`ray.remote` for details.

    Returns:
        A literal, a Ray object reference representing a submitted task, or a
        list thereof.

    Raises:
        ValueError: If `task` is a Dask task that is neither an `Alias` nor a
            `Task`.
    """
    if isinstance(task, list):
        # Recursively rayify this list. This will still bottom out at the first
        # actual task encountered, inlining any tasks in that task's arguments.
        return [
            _rayify_task(
                t,
                key,
                deps,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                ray_remote_args,
            )
            for t in task
        ]
    elif istask(task):
        # Unpacks and repacks Ray object references and submits the task to the
        # Ray cluster for execution.
        if ray_presubmit_cbs is not None:
            # Every presubmit callback is invoked before any result is used.
            alternate_returns = [cb(task, key, deps) for cb in ray_presubmit_cbs]
            for alternate_return in alternate_returns:
                # We don't submit a Ray task if a presubmit callback returns
                # a non-`None` value, instead we return said value.
                # NOTE: This returns the first non-None presubmit callback
                # return value.
                if alternate_return is not None:
                    return alternate_return
        if isinstance(task, Alias):
            # An alias just forwards the value of its target dependency.
            target = task.target
            if isinstance(target, TaskRef):
                # for 2024.12.0
                return deps[target.key]
            else:
                # for 2024.12.1+
                return deps[target]
        elif isinstance(task, Task):
            func = task.func
        else:
            raise ValueError("Invalid task type: %s" % type(task))
        # If the function's arguments contain nested object references, we must
        # unpack said object references into a flat set of arguments so that
        # Ray properly tracks the object dependencies between Ray tasks.
        arg_object_refs, repack = unpack_object_refs(deps)
        # Submit the task using a wrapper function.
        object_refs = dask_task_wrapper.options(
            name=f"dask:{key!s}",
            num_returns=(
                1 if not isinstance(func, MultipleReturnFunc) else func.num_returns
            ),
            **ray_remote_args,
        ).remote(
            task,
            repack,
            key,
            ray_pretask_cbs,
            ray_posttask_cbs,
            *arg_object_refs,
        )
        if ray_postsubmit_cbs is not None:
            for cb in ray_postsubmit_cbs:
                cb(task, key, deps, object_refs)
        return object_refs
    elif not ishashable(task):
        # Unhashable values cannot be dependency keys; pass them through.
        return task
    elif task in deps:
        return deps[task]
    else:
        return task


@ray.remote
def dask_task_wrapper(
    task, repack, key, ray_pretask_cbs, ray_posttask_cbs, *arg_object_refs
):
    """
    A Ray remote function acting as a Dask task wrapper. This function will
    repackage the given `arg_object_refs` into its original `deps` using
    `repack`, and then pass it to the provided Dask Task object, `task`.

    Args:
        task: The Dask Task class object to execute.
        repack: A function that repackages the provided args into
            the original (possibly nested) Python objects.
        key: The Dask key for this task.
        ray_pretask_cbs: Pre-task execution callbacks.
        ray_posttask_cbs: Post-task execution callback.
        *arg_object_refs (ObjectRef): Ray object references representing the
            dependencies' results.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    pre_states = None
    if ray_pretask_cbs is not None:
        pre_states = [
            cb(key, arg_object_refs) if cb is not None else None
            for cb in ray_pretask_cbs
        ]
    (repacked_deps,) = repack(arg_object_refs)

    # De-reference the potentially nested arguments recursively.
    def _dereference_args(x):
        if isinstance(x, Task):
            # NOTE: the task's args are replaced in place.
            x.args = _dereference_args(x.args)
            return x
        elif isinstance(x, Mapping):
            return {k: _dereference_args(v) for k, v in x.items()}
        elif isinstance(x, tuple):
            return tuple(_dereference_args(item) for item in x)
        elif isinstance(x, ray.ObjectRef):
            return ray.get(x)
        elif isinstance(x, DataNode):
            if isinstance(x.value, ray.ObjectRef):
                value = ray.get(x.value)
                return DataNode(key=x.key, value=value)
            return x
        else:
            return x

    task = _dereference_args(task)
    result = task(repacked_deps)

    if ray_posttask_cbs is not None:
        if pre_states is None:
            # No pre-task callbacks ran, so there is no state to hand over;
            # still invoke the post-task callbacks instead of failing on an
            # unbound name.
            pre_states = [None] * len(ray_posttask_cbs)
        for cb, pre_state in zip(ray_posttask_cbs, pre_states):
            if cb is not None:
                cb(key, result, pre_state)

    return result


def render_progress_bar(tracker, object_refs):
    """
    Render a tqdm progress bar until the tracked tasks are finished.

    Polls `tracker.result` every 0.1 seconds and stops once the reported
    finished count reaches the initially reported total, or once `ray.wait`
    reports every reference in `object_refs` as ready.

    Args:
        tracker: Actor handle exposing `result` and `report` remote methods.
            `result` is unpacked as a pair of counts (presumably
            submitted/total and finished; confirm against the actor).
        object_refs: Ray object references to wait on.
    """
    from tqdm import tqdm

    # At this time, every task should be submitted.
    total, finished = ray.get(tracker.result.remote())
    reported_finished_so_far = 0
    # The context manager closes the bar even if polling raises.
    with tqdm(total=total, position=0) as pb_bar:
        pb_bar.set_description("")

        while finished < total:
            _, finished = ray.get(tracker.result.remote())
            # tqdm expects increments, not absolute values.
            pb_bar.update(finished - reported_finished_so_far)
            reported_finished_so_far = finished
            ready_refs, _ = ray.wait(
                object_refs,
                timeout=0,
                num_returns=len(object_refs),
                fetch_local=False,
            )
            if len(ready_refs) == len(object_refs):
                break
            time.sleep(0.1)

    submitted, finished = ray.get(tracker.result.remote())
    if submitted != finished:
        print("Completed. There was state inconsistency.")
        pprint(ray.get(tracker.report.remote()))


def ray_get_unpack(object_refs, progress_bar_actor=None):
    """
    Unpacks object references, gets the object references, and repacks.
    Traverses arbitrary data structures.

    Args:
        object_refs: A (potentially nested) Python object containing Ray object
            references.
        progress_bar_actor: Optional actor handle; when truthy, a progress bar
            is rendered via `render_progress_bar` before the values are
            fetched.

    Returns:
        The input Python object with all contained Ray object references
        resolved with their concrete values.
    """

    def get_result(object_refs):
        if progress_bar_actor:
            render_progress_bar(progress_bar_actor, object_refs)
        return ray.get(object_refs)

    if isinstance(object_refs, tuple):
        object_refs = list(object_refs)

    if isinstance(object_refs, list) and any(
        not isinstance(x, ray.ObjectRef) for x in object_refs
    ):
        # We flatten the object references before calling ray.get(), since Dask
        # loves to nest collections in nested tuples and Ray expects a flat
        # list of object references. We repack the results after ray.get()
        # completes.
        object_refs, repack = unpack_object_refs(*object_refs)
        computed_result = get_result(object_refs)
        return repack(computed_result)
    else:
        return get_result(object_refs)


def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    A synchronous Dask-Ray scheduler. This scheduler will send top-level
    (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will
    wait for the tasks to finish executing, fetch the results, and repackage
    them into the appropriate Dask collections. This particular scheduler
    submits Ray tasks synchronously, which can be useful for debugging.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager):

    >>> dask.compute(
            obj,
            scheduler=ray_dask_get_sync,
            ray_callbacks=some_ray_dask_callbacks,
        )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        ray_remote_args (Optional[dict]): Ray task options applied to every
            submitted task; per-layer `ray_remote_args` annotations override
            them.
        ray_persist (bool): If True, return the Ray object references instead
            of fetching their values.

    Returns:
        Computed values corresponding to the provided keys.

    Raises:
        ValueError: If `resources` is given as a keyword argument or as a Dask
            annotation instead of inside `ray_remote_args`.
    """
    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)

    # `_rayify_task_wrapper` requires the per-key Ray remote args, so they must
    # be built here exactly as in `ray_dask_get`.
    dsk, scoped_ray_remote_args = _prepare_graph_and_remote_args(dsk, kwargs)

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # Make sure the graph is in the new format
        dsk = convert_legacy_graph(dsk)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        object_refs = get_async(
            _apply_async_wrapper(
                apply_sync,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                scoped_ray_remote_args,
            ),
            1,
            dsk,
            keys,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        if persist:
            result = object_refs
        else:
            result = ray_get_unpack(object_refs)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

        return result


@dataclass
class MultipleReturnFunc:
    """
    Callable wrapper marking a function as producing `num_returns` values.

    `_rayify_task` reads `num_returns` from instances of this class to set the
    `num_returns` option of the submitted Ray task.
    """

    # The wrapped function.
    func: Callable
    # Number of return values requested from Ray for this function.
    num_returns: int

    def __call__(self, *args, **kwargs):
        """
        Call the wrapped function; a dict result is converted to a list by
        looking up the integer keys ``0 .. len(result) - 1`` in order.
        """
        returns = self.func(*args, **kwargs)
        if isinstance(returns, (dict, OrderedDict)):
            returns = [returns[k] for k in range(len(returns))]
        return returns


def multiple_return_get(multiple_returns, idx):
    """Return the element at position `idx` of `multiple_returns`."""
    return multiple_returns[idx]


def _build_key_scoped_ray_remote_args(dsk, annotations, ray_remote_args):
    """
    Build the Ray remote args to use for each key of the Dask graph.

    Args:
        dsk: Dask graph; wrapped into a single-layer `HighLevelGraph` if it is
            not one already.
        annotations: Global Dask annotations (a dict).
        ray_remote_args: Global Ray task options given in the compute call.

    Returns:
        A dict mapping each layer output key to its Ray remote args, where
        `ray_remote_args` annotations override the global `ray_remote_args`.

    Raises:
        ValueError: If a layer carries a `resources` annotation.
    """
    # Handle per-layer annotations.
    if not isinstance(dsk, dask.highlevelgraph.HighLevelGraph):
        dsk = dask.highlevelgraph.HighLevelGraph.from_collections(
            id(dsk), dsk, dependencies=()
        )
    # Build key-scoped annotations.
    scoped_annotations = {}
    for layer_name in dsk._toposort_layers():
        layer = dsk.layers[layer_name]
        layer_annotations = layer.annotations
        if layer_annotations is None:
            layer_annotations = annotations
        elif "resources" in layer_annotations:
            raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
        for key in layer.get_output_keys():
            layer_annotations_for_key = annotations.copy()
            # Layer annotations override global annotations.
            layer_annotations_for_key.update(layer_annotations)
            # Let same-key annotations earlier in the topological sort take
            # precedence.
            layer_annotations_for_key.update(scoped_annotations.get(key, {}))
            scoped_annotations[key] = layer_annotations_for_key
    # Build key-scoped Ray remote args.
    scoped_ray_remote_args = {}
    for key, key_annotations in scoped_annotations.items():
        layer_ray_remote_args = ray_remote_args.copy()
        # Layer Ray remote args override global Ray remote args given in the
        # compute call.
        layer_ray_remote_args.update(key_annotations.get("ray_remote_args", {}))
        scoped_ray_remote_args[key] = layer_ray_remote_args
    return scoped_ray_remote_args