| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678 |
- import atexit
- import threading
- import time
- import warnings
- from collections import OrderedDict, defaultdict
- from collections.abc import Mapping
- from dataclasses import dataclass
- from multiprocessing.pool import ThreadPool
- from pprint import pprint
- from typing import Optional
- import dask
- from dask.core import ishashable, istask
- try:
- from dask._task_spec import Alias, DataNode, Task, TaskRef, convert_legacy_graph
- except ImportError:
- warnings.warn(
- "Dask on Ray is available only on dask>=2024.11.0, "
- f"you are on version {dask.__version__}."
- )
- from dask.system import CPU_COUNT
- from dask.threaded import _thread_get_id, pack_exception
- import ray
- from ray.util.dask.callbacks import local_ray_callbacks, unpack_ray_callbacks
- from ray.util.dask.common import unpack_object_refs
- from ray.util.dask.scheduler_utils import apply_sync, get_async
# Thread that was current when this module was imported. `ray_dask_get` only
# hands out the shared `default_pool` to calls made from this thread.
main_thread = threading.current_thread()
# Shared thread pool, created lazily by `ray_dask_get` on first use.
default_pool = None
# Per-thread pool cache: thread -> {num_workers -> ThreadPool}.
pools = defaultdict(dict)
# Guards reads and writes of `default_pool` and `pools`.
pools_lock = threading.Lock()
# Error raised when `resources` is given as a top-level compute kwarg or as a
# Dask annotation instead of being nested under `ray_remote_args`.
TOP_LEVEL_RESOURCES_ERR_MSG = (
    'Use ray_remote_args={"resources": {...}} instead of resources={...} to specify '
    "required Ray task resources; see "
    "https://docs.ray.io/en/master/ray-core/package-ref.html#ray-remote."
)
def enable_dask_on_ray(
    shuffle: Optional[str] = "tasks",
    use_shuffle_optimization: Optional[bool] = True,
) -> dask.config.set:
    """
    Make the Dask-on-Ray scheduler the default scheduler in the Dask config.

    Besides the scheduler, this also sets the Dask shuffle method (task-based
    by default, which is required for multi-node Ray clusters that do not
    share a filesystem) and, unless disabled, installs the Ray-specific
    DataFrame shuffle optimization.

    >>> enable_dask_on_ray()
    >>> ddf.compute()  # <-- will use the Dask-on-Ray scheduler.

    The returned object can be used as a context manager, in which case the
    Dask-on-Ray scheduler is only active inside the ``with`` block.

    >>> with enable_dask_on_ray():
    ...     ddf.compute()  # <-- will use the Dask-on-Ray scheduler.
    >>> ddf.compute()  # <-- won't use the Dask-on-Ray scheduler.

    Args:
        shuffle: The shuffle method used by Dask, either "tasks" or
            "disk". This should be "tasks" if using a multi-node Ray cluster.
            Defaults to "tasks".
        use_shuffle_optimization: Enable the custom Ray-specific shuffle
            optimization. Defaults to True.

    Returns:
        The Dask config object, which can be used as a context manager to
        limit the scope of the Dask-on-Ray scheduler to that context.
    """
    # The optimizer module is only imported when it is actually requested.
    optimizer = None
    if use_shuffle_optimization:
        from ray.util.dask.optimizations import dataframe_optimize as optimizer

    # The task-based shuffle is forced by default because the disk-based
    # shuffle does not work on a multi-node Ray cluster without a shared
    # filesystem.
    config_overrides = {
        "scheduler": ray_dask_get,
        "shuffle": shuffle,
        "dataframe_optimize": optimizer,
    }
    return dask.config.set(**config_overrides)
def disable_dask_on_ray():
    """
    Reset the scheduler, shuffle method, and DataFrame optimizer to ``None``.

    Returns:
        The Dask config object produced by ``dask.config.set``.
    """
    cleared = dict.fromkeys(("scheduler", "shuffle", "dataframe_optimize"))
    return dask.config.set(**cleared)
def ray_dask_get(dsk, keys, **kwargs):
    """
    A Dask-Ray scheduler that submits Ray tasks from a thread pool.

    Top-level (non-inlined) Dask tasks are sent to a Ray cluster for
    execution. Unless ``ray_persist`` is set, the scheduler waits for the
    tasks to finish, fetches the results, and repackages them into the
    appropriate Dask collections.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager), the number of threads to use when
    submitting the Ray tasks, or the threadpool used to submit Ray tasks:

    >>> dask.compute(
    ...     obj,
    ...     scheduler=ray_dask_get,
    ...     ray_callbacks=some_ray_dask_callbacks,
    ...     num_workers=8,
    ...     pool=some_cool_pool,
    ... )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary. Non-mapping
            inputs are unwrapped via ``_optimized_dsk`` or
            ``__dask_graph__()``.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        num_workers (Optional[int]): The number of worker threads to use in
            the Ray task submission traversal of the Dask graph.
        pool (Optional[ThreadPool]): A multiprocessing threadpool to use to
            submit Ray tasks.
        ray_persist (bool): If truthy, return the Ray object references
            instead of fetching their values.
        ray_remote_args (dict): Ray task options applied to every submitted
            task; ``ray_remote_args`` annotations on the graph override them.
        _ray_enable_progress_bar: If truthy, the actor named
            ``"_dask_on_ray_pb"`` is looked up and used to render progress
            while results are fetched.

    Returns:
        Computed values corresponding to the provided keys.

    Raises:
        ValueError: If ``resources`` is passed as a keyword argument or is
            present in the active Dask annotations.
    """
    global default_pool

    requested_workers = kwargs.pop("num_workers", None)
    pool = kwargs.pop("pool", None)

    # Reuse a pool previously created for this thread and worker count. The
    # global pool is only shared when no worker count was requested and we
    # are running on the main thread.
    caller = threading.current_thread()
    if pool is None:
        with pools_lock:
            if requested_workers is None and caller is main_thread:
                if default_pool is None:
                    default_pool = ThreadPool(CPU_COUNT)
                    atexit.register(default_pool.close)
                pool = default_pool
            elif caller in pools and requested_workers in pools[caller]:
                pool = pools[caller][requested_workers]
            else:
                pool = ThreadPool(requested_workers)
                atexit.register(pool.close)
                pools[caller][requested_workers] = pool

    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)
    enable_progress_bar = kwargs.pop("_ray_enable_progress_bar", None)

    # Ray resources must be nested under `ray_remote_args`, both in the
    # compute kwargs and in Dask annotations.
    if "resources" in kwargs:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
    ray_remote_args = kwargs.pop("ray_remote_args", {})
    annotations = dask.get_annotations()
    if "resources" in annotations:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)

    # Take out the dask graph if it is an Expr for dask>=2025.4.0.
    if not isinstance(dsk, Mapping):
        dsk = (
            dsk._optimized_dsk
            if hasattr(dsk, "_optimized_dsk")
            else dsk.__dask_graph__()
        )

    scoped_ray_remote_args = _build_key_scoped_ray_remote_args(
        dsk, annotations, ray_remote_args
    )

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)

        # Make sure the graph is in the new format.
        dsk = convert_legacy_graph(dsk)

        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        submit = _apply_async_wrapper(
            pool.apply_async,
            _rayify_task_wrapper,
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            scoped_ray_remote_args,
        )
        worker_count = len(pool._pool)
        object_refs = get_async(
            submit,
            worker_count,
            dsk,
            keys,
            get_id=_thread_get_id,
            pack_exception=pack_exception,
            **kwargs,
        )

        if ray_postsubmit_all_cbs is not None:
            for notify in ray_postsubmit_all_cbs:
                notify(object_refs, dsk)

        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk

        if persist:
            result = object_refs
        else:
            pb_actor = (
                ray.get_actor("_dask_on_ray_pb") if enable_progress_bar else None
            )
            result = ray_get_unpack(object_refs, progress_bar_actor=pb_actor)

        if ray_finish_cbs is not None:
            for notify in ray_finish_cbs:
                notify(result)

    # Clean up pools associated with dead threads.
    with pools_lock:
        alive = set(threading.enumerate())
        if caller is not main_thread:
            stale_threads = [t for t in pools if t not in alive]
            for stale in stale_threads:
                for stale_pool in pools.pop(stale).values():
                    stale_pool.close()
    return result
- def _apply_async_wrapper(apply_async, real_func, *extra_args, **extra_kwargs):
- """
- Wraps the given pool `apply_async` function, hotswapping `real_func` in as
- the function to be applied and adding `extra_args` and `extra_kwargs` to
- `real_func`'s call.
- Args:
- apply_async: The pool function to be wrapped.
- real_func: The real function that we wish the pool apply
- function to execute.
- *extra_args: Extra positional arguments to pass to the `real_func`.
- **extra_kwargs: Extra keyword arguments to pass to the `real_func`.
- Returns:
- A wrapper function that will ignore it's first `func` argument and
- pass `real_func` in its place. To be passed to `dask.local.get_async`.
- """
- def wrapper(func, args=(), kwds=None, callback=None): # noqa: M511
- if not kwds:
- kwds = {}
- return apply_async(
- real_func,
- args=args + extra_args,
- kwds=dict(kwds, **extra_kwargs),
- callback=callback,
- )
- return wrapper
- def _rayify_task_wrapper(
- key,
- task_info,
- dumps,
- loads,
- get_id,
- pack_exception,
- ray_presubmit_cbs,
- ray_postsubmit_cbs,
- ray_pretask_cbs,
- ray_posttask_cbs,
- scoped_ray_remote_args,
- ):
- """
- The core Ray-Dask task execution wrapper, to be given to the thread pool's
- `apply_async` function. Exactly the same as `execute_task`, except that it
- calls `_rayify_task` on the task instead of `_execute_task`.
- Args:
- key: The Dask graph key whose corresponding task we wish to
- execute.
- task_info: The task to execute and its dependencies.
- dumps: A result serializing function.
- loads: A task_info deserializing function.
- get_id: An ID generating function.
- pack_exception: An exception serializing function.
- ray_presubmit_cbs: Pre-task submission callbacks.
- ray_postsubmit_cbs: Post-task submission callbacks.
- ray_pretask_cbs: Pre-task execution callbacks.
- ray_posttask_cbs: Post-task execution callbacks.
- scoped_ray_remote_args: Ray task options for each key.
- Returns:
- A 3-tuple of the task's key, a literal or a Ray object reference for a
- Ray task's result, and whether the Ray task submission failed.
- """
- try:
- task, deps = loads(task_info)
- result = _rayify_task(
- task,
- key,
- deps,
- ray_presubmit_cbs,
- ray_postsubmit_cbs,
- ray_pretask_cbs,
- ray_posttask_cbs,
- scoped_ray_remote_args.get(key, {}),
- )
- id = get_id()
- result = dumps((result, id))
- failed = False
- except BaseException as e:
- result = pack_exception(e, dumps)
- failed = True
- return key, result, failed
- def _rayify_task(
- task,
- key,
- deps,
- ray_presubmit_cbs,
- ray_postsubmit_cbs,
- ray_pretask_cbs,
- ray_posttask_cbs,
- ray_remote_args,
- ):
- """
- Rayifies the given task, submitting it as a Ray task to the Ray cluster.
- Args:
- task: A Dask graph value, being either a literal, dependency
- key, Dask task, or a list thereof.
- key: The Dask graph key for the given task.
- deps: The dependencies of this task.
- ray_presubmit_cbs: Pre-task submission callbacks.
- ray_postsubmit_cbs: Post-task submission callbacks.
- ray_pretask_cbs: Pre-task execution callbacks.
- ray_posttask_cbs: Post-task execution callbacks.
- ray_remote_args: Ray task options. See :func:`ray.remote` for details.
- Returns:
- A literal, a Ray object reference representing a submitted task, or a
- list thereof.
- """
- if isinstance(task, list):
- # Recursively rayify this list. This will still bottom out at the first
- # actual task encountered, inlining any tasks in that task's arguments.
- return [
- _rayify_task(
- t,
- key,
- deps,
- ray_presubmit_cbs,
- ray_postsubmit_cbs,
- ray_pretask_cbs,
- ray_posttask_cbs,
- ray_remote_args,
- )
- for t in task
- ]
- elif istask(task):
- # Unpacks and repacks Ray object references and submits the task to the
- # Ray cluster for execution.
- if ray_presubmit_cbs is not None:
- alternate_returns = [cb(task, key, deps) for cb in ray_presubmit_cbs]
- for alternate_return in alternate_returns:
- # We don't submit a Ray task if a presubmit callback returns
- # a non-`None` value, instead we return said value.
- # NOTE: This returns the first non-None presubmit callback
- # return value.
- if alternate_return is not None:
- return alternate_return
- if isinstance(task, Alias):
- target = task.target
- if isinstance(target, TaskRef):
- # for 2024.12.0
- return deps[target.key]
- else:
- # for 2024.12.1+
- return deps[target]
- elif isinstance(task, Task):
- func = task.func
- else:
- raise ValueError("Invalid task type: %s" % type(task))
- # If the function's arguments contain nested object references, we must
- # unpack said object references into a flat set of arguments so that
- # Ray properly tracks the object dependencies between Ray tasks.
- arg_object_refs, repack = unpack_object_refs(deps)
- # Submit the task using a wrapper function.
- object_refs = dask_task_wrapper.options(
- name=f"dask:{key!s}",
- num_returns=(
- 1 if not isinstance(func, MultipleReturnFunc) else func.num_returns
- ),
- **ray_remote_args,
- ).remote(
- task,
- repack,
- key,
- ray_pretask_cbs,
- ray_posttask_cbs,
- *arg_object_refs,
- )
- if ray_postsubmit_cbs is not None:
- for cb in ray_postsubmit_cbs:
- cb(task, key, deps, object_refs)
- return object_refs
- elif not ishashable(task):
- return task
- elif task in deps:
- return deps[task]
- else:
- return task
@ray.remote
def dask_task_wrapper(
    task, repack, key, ray_pretask_cbs, ray_posttask_cbs, *arg_object_refs
):
    """
    A Ray remote function acting as a Dask task wrapper. This function will
    repackage the given `arg_object_refs` into its original `deps` using
    `repack`, and then pass it to the provided Dask Task object, `task`.

    Args:
        task: The Dask Task class object to execute.
        repack: A function that repackages the provided args into
            the original (possibly nested) Python objects.
        key: The Dask key for this task.
        ray_pretask_cbs: Pre-task execution callbacks, or None. `None`
            entries in the list are skipped.
        ray_posttask_cbs: Post-task execution callbacks, or None. Each
            callback is paired positionally with the state returned by the
            pre-task callback at the same index; when no pre-task callbacks
            were given, every post-task callback receives `None` as its
            state.
        *arg_object_refs (ObjectRef): Ray object references representing the
            dependencies' results.

    Returns:
        The output of the Dask task. In the context of Ray, a
        dask_task_wrapper.remote() invocation will return a Ray object
        reference representing the Ray task's result.
    """
    # Stays None when there are no pre-task callbacks, so the post-task loop
    # below never reads an unbound name.
    pre_states = None
    if ray_pretask_cbs is not None:
        pre_states = [
            cb(key, arg_object_refs) if cb is not None else None
            for cb in ray_pretask_cbs
        ]

    (repacked_deps,) = repack(arg_object_refs)

    # De-reference the potentially nested arguments recursively.
    def _dereference_args(x):
        if isinstance(x, Task):
            # NOTE: the Task is updated in place.
            x.args = _dereference_args(x.args)
            return x
        elif isinstance(x, Mapping):
            return {k: _dereference_args(v) for k, v in x.items()}
        elif isinstance(x, tuple):
            return tuple(_dereference_args(item) for item in x)
        elif isinstance(x, ray.ObjectRef):
            return ray.get(x)
        elif isinstance(x, DataNode):
            if isinstance(x.value, ray.ObjectRef):
                value = ray.get(x.value)
                return DataNode(key=x.key, value=value)
            return x
        else:
            return x

    task = _dereference_args(task)
    result = task(repacked_deps)

    if ray_posttask_cbs is not None:
        if pre_states is None:
            # Post-task callbacks without matching pre-task callbacks get no
            # pre-task state.
            pre_states = [None for _ in ray_posttask_cbs]
        for cb, pre_state in zip(ray_posttask_cbs, pre_states):
            if cb is not None:
                cb(key, result, pre_state)

    return result
def render_progress_bar(tracker, object_refs):
    """
    Render a tqdm progress bar until the given object references are ready.

    The bar is driven by polling `tracker.result`, whose return value is
    unpacked as a `(submitted, finished)` pair. Polling stops once the
    finished count reaches the total observed on the first poll, or once
    `ray.wait` reports every reference in `object_refs` as ready. The bar is
    always closed, even if polling raises or is interrupted.

    Args:
        tracker: Actor handle exposing `result` and `report` remote methods.
        object_refs: The Ray object references whose completion is awaited.
    """
    from tqdm import tqdm

    # At this time, every task should be submitted.
    total, finished = ray.get(tracker.result.remote())
    reported_finished_so_far = 0
    pb_bar = tqdm(total=total, position=0)
    try:
        pb_bar.set_description("")
        while finished < total:
            _, finished = ray.get(tracker.result.remote())
            # tqdm's update() takes an increment, not an absolute count.
            pb_bar.update(finished - reported_finished_so_far)
            reported_finished_so_far = finished
            ready_refs, _ = ray.wait(
                object_refs,
                timeout=0,
                num_returns=len(object_refs),
                fetch_local=False,
            )
            if len(ready_refs) == len(object_refs):
                break
            time.sleep(0.1)
    finally:
        # Close the bar on every exit path so a failed or interrupted poll
        # does not leave it open.
        pb_bar.close()

    submitted, finished = ray.get(tracker.result.remote())
    if submitted != finished:
        print("Completed. There was state inconsistency.")
        pprint(ray.get(tracker.report.remote()))
def ray_get_unpack(object_refs, progress_bar_actor=None):
    """
    Resolve every Ray object reference inside a (possibly nested) structure.

    Tuples are treated as lists. If the list holds anything other than bare
    object references, it is flattened with `unpack_object_refs`, fetched
    with a single `ray.get()`, and repacked into its original shape.

    Args:
        object_refs: A (potentially nested) Python object containing Ray object
            references.
        progress_bar_actor: Optional actor handle; when truthy, a progress bar
            is rendered with `render_progress_bar` before the values are
            fetched.

    Returns:
        The input Python object with all contained Ray object references
        resolved with their concrete values.
    """

    def fetch(refs):
        if progress_bar_actor:
            render_progress_bar(progress_bar_actor, refs)
        return ray.get(refs)

    if isinstance(object_refs, tuple):
        object_refs = list(object_refs)

    needs_flattening = isinstance(object_refs, list) and not all(
        isinstance(ref, ray.ObjectRef) for ref in object_refs
    )
    if not needs_flattening:
        return fetch(object_refs)

    # We flatten the object references before calling ray.get(), since Dask
    # loves to nest collections in nested tuples and Ray expects a flat
    # list of object references. We repack the results after ray.get()
    # completes.
    flat_refs, repack = unpack_object_refs(*object_refs)
    return repack(fetch(flat_refs))
def ray_dask_get_sync(dsk, keys, **kwargs):
    """
    A synchronous Dask-Ray scheduler. This scheduler will send top-level
    (non-inlined) Dask tasks to a Ray cluster for execution. The scheduler will
    wait for the tasks to finish executing, fetch the results, and repackage
    them into the appropriate Dask collections. This particular scheduler
    submits Ray tasks synchronously, which can be useful for debugging.

    This can be passed directly to `dask.compute()`, as the scheduler:

    >>> dask.compute(obj, scheduler=ray_dask_get_sync)

    You can override the currently active global Dask-Ray callbacks (e.g.
    supplied via a context manager):

    >>> dask.compute(
    ...     obj,
    ...     scheduler=ray_dask_get_sync,
    ...     ray_callbacks=some_ray_dask_callbacks,
    ... )

    Args:
        dsk: Dask graph, represented as a task DAG dictionary. Non-mapping
            inputs are unwrapped via ``_optimized_dsk`` or
            ``__dask_graph__()``, as in `ray_dask_get`.
        keys (List[str]): List of Dask graph keys whose values we wish to
            compute and return.
        ray_callbacks (Optional[list[callable]]): Dask-Ray callbacks.
        ray_persist (bool): If truthy, return the Ray object references
            instead of fetching their values.
        ray_remote_args (dict): Ray task options applied to every submitted
            task; ``ray_remote_args`` annotations on the graph override them.

    Returns:
        Computed values corresponding to the provided keys.

    Raises:
        ValueError: If ``resources`` is passed as a keyword argument or is
            present in the active Dask annotations.
    """
    ray_callbacks = kwargs.pop("ray_callbacks", None)
    persist = kwargs.pop("ray_persist", False)

    # Handle Ray remote args and resource annotations the same way as
    # `ray_dask_get` does.
    if "resources" in kwargs:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
    ray_remote_args = kwargs.pop("ray_remote_args", {})
    annotations = dask.get_annotations()
    if "resources" in annotations:
        raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)

    # Take out the dask graph if it is an Expr for dask>=2025.4.0.
    if not isinstance(dsk, Mapping):
        if hasattr(dsk, "_optimized_dsk"):
            # For Expr with this property
            dsk = dsk._optimized_dsk
        else:
            # For any other Expr
            dsk = dsk.__dask_graph__()

    # `_rayify_task_wrapper` requires the per-key Ray task options as its
    # final positional argument, so they must be built and forwarded here.
    scoped_ray_remote_args = _build_key_scoped_ray_remote_args(
        dsk, annotations, ray_remote_args
    )

    with local_ray_callbacks(ray_callbacks) as ray_callbacks:
        # Unpack the Ray-specific callbacks.
        (
            ray_presubmit_cbs,
            ray_postsubmit_cbs,
            ray_pretask_cbs,
            ray_posttask_cbs,
            ray_postsubmit_all_cbs,
            ray_finish_cbs,
        ) = unpack_ray_callbacks(ray_callbacks)
        # Make sure the graph is in the new format
        dsk = convert_legacy_graph(dsk)
        # NOTE: We hijack Dask's `get_async` function, injecting a different
        # task executor.
        object_refs = get_async(
            _apply_async_wrapper(
                apply_sync,
                _rayify_task_wrapper,
                ray_presubmit_cbs,
                ray_postsubmit_cbs,
                ray_pretask_cbs,
                ray_posttask_cbs,
                scoped_ray_remote_args,
            ),
            1,
            dsk,
            keys,
            **kwargs,
        )
        if ray_postsubmit_all_cbs is not None:
            for cb in ray_postsubmit_all_cbs:
                cb(object_refs, dsk)
        # NOTE: We explicitly delete the Dask graph here so object references
        # are garbage-collected before this function returns, i.e. before all
        # Ray tasks are done. Otherwise, no intermediate objects will be
        # cleaned up until all Ray tasks are done.
        del dsk
        if persist:
            result = object_refs
        else:
            result = ray_get_unpack(object_refs)
        if ray_finish_cbs is not None:
            for cb in ray_finish_cbs:
                cb(result)

        return result
@dataclass
class MultipleReturnFunc:
    """
    Callable wrapper pairing a function with its number of return values.

    Calling the wrapper forwards all arguments to `func`. A dict result is
    converted to a list by looking up the integer keys ``0 .. len - 1`` in
    order; any other result is returned unchanged.
    """

    # The wrapped callable.
    func: callable
    # Number of values the wrapped callable returns; `_rayify_task` passes
    # this as `num_returns` when submitting the Ray task.
    num_returns: int

    def __call__(self, *args, **kwargs):
        outputs = self.func(*args, **kwargs)
        if not isinstance(outputs, (dict, OrderedDict)):
            return outputs
        return [outputs[position] for position in range(len(outputs))]
def multiple_return_get(multiple_returns, idx):
    """Return the item of `multiple_returns` stored at `idx`."""
    selected = multiple_returns[idx]
    return selected
def _build_key_scoped_ray_remote_args(dsk, annotations, ray_remote_args):
    """
    Compute the Ray task options to use for every output key of the graph.

    Args:
        dsk: The Dask graph. Anything that is not a `HighLevelGraph` is
            wrapped into one via `HighLevelGraph.from_collections`.
        annotations: Global Dask annotations, used as the base for every key.
        ray_remote_args: Ray task options given in the compute call, used as
            the base options for every key.

    Returns:
        A dict mapping each layer output key to its Ray task options.

    Raises:
        ValueError: If a layer's own annotations contain ``resources``.
    """

    def _overlay(base, *overrides):
        # Copy `base`, then apply each override in order (later ones win).
        combined = base.copy()
        for override in overrides:
            combined.update(override)
        return combined

    # Handle per-layer annotations.
    hlg_cls = dask.highlevelgraph.HighLevelGraph
    graph = dsk
    if not isinstance(graph, hlg_cls):
        graph = hlg_cls.from_collections(id(dsk), dsk, dependencies=())

    # Build key-scoped annotations.
    per_key_annotations = {}
    for layer_name in graph._toposort_layers():
        layer = graph.layers[layer_name]
        layer_annotations = layer.annotations
        if layer_annotations is None:
            layer_annotations = annotations
        elif "resources" in layer_annotations:
            raise ValueError(TOP_LEVEL_RESOURCES_ERR_MSG)
        for key in layer.get_output_keys():
            # Layer annotations override global annotations, and same-key
            # annotations earlier in the topological sort take precedence.
            per_key_annotations[key] = _overlay(
                annotations,
                layer_annotations,
                per_key_annotations.get(key, {}),
            )

    # Build key-scoped Ray remote args. Annotation-level Ray remote args
    # override the global Ray remote args given in the compute call.
    return {
        key: _overlay(ray_remote_args, key_annotations.get("ray_remote_args", {}))
        for key, key_annotations in per_key_annotations.items()
    }
|