| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382 |
- """
- The following is adapted from Dask release 2021.03.1:
- https://github.com/dask/dask/blob/2021.03.1/dask/local.py
- """
- import os
- import warnings
- from queue import Empty, Queue
- import dask
- from dask import config
- try:
- from dask._task_spec import DataNode, DependenciesMapping
- except ImportError:
- warnings.warn(
- "Dask on Ray is available only on dask>=2024.11.0, "
- f"you are on version {dask.__version__}."
- )
- from dask.callbacks import local_callbacks, unpack_callbacks
- from dask.core import flatten, get_dependencies, reverse_dict
- from dask.order import order
if os.name != "nt":

    def queue_get(q):
        """Block until an item is available on ``q`` and return it."""
        return q.get()

else:
    # On Windows a blocking Queue.get does not handle interrupts properly.
    # Poll instead: the interval is long enough not to hurt performance and
    # short enough that a user trying to kill the application won't notice.

    def queue_get(q):
        """Return the next item from ``q``, polling every 0.1 seconds."""
        while True:
            try:
                item = q.get(block=True, timeout=0.1)
            except Empty:
                continue
            return item
def start_state_from_dask(dsk, cache=None, sortkey=None):
    """Build the initial scheduler state for a dask graph.

    Parameters
    ----------
    dsk : dict
        Task graph. Every value that is a ``DataNode`` is evaluated
        immediately and its result is stored in ``cache``.
    cache : dict-like, optional
        Result storage, mutated in place. When omitted, the ``"cache"`` entry
        of the dask config is used, falling back to a fresh ``dict``.
    sortkey : callable, optional
        Key function used to order the initial ``ready`` list. Defaults to
        ``order(dsk).get``.

    Returns
    -------
    dict
        State dictionary with the keys ``dependencies``, ``dependents``,
        ``waiting``, ``waiting_data``, ``cache``, ``ready``, ``running``,
        ``finished`` and ``released``.

    Examples
    --------
    >>> dsk = {
    ...     'x': 1,
    ...     'y': 2,
    ...     'z': (inc, 'x'),
    ...     'w': (add, 'z', 'y')}  # doctest: +SKIP
    >>> from pprint import pprint  # doctest: +SKIP
    >>> pprint(start_state_from_dask(dsk))  # doctest: +SKIP
    {'cache': {'x': 1, 'y': 2},
     'dependencies': {'w': {'z', 'y'}, 'x': set(), 'y': set(), 'z': {'x'}},
     'dependents': {'w': set(), 'x': {'z'}, 'y': {'w'}, 'z': {'w'}},
     'finished': set(),
     'ready': ['z'],
     'released': set(),
     'running': set(),
     'waiting': {'w': {'z'}},
     'waiting_data': {'x': {'z'}, 'y': {'w'}, 'z': {'w'}}}
    """
    if sortkey is None:
        sortkey = order(dsk).get
    if cache is None:
        cache = config.get("cache", None)
    if cache is None:
        cache = dict()

    # Literal data needs no scheduling: materialize it straight into the cache.
    data_keys = set()
    for k, v in dsk.items():
        if isinstance(v, DataNode):
            cache[k] = v()
            data_keys.add(k)

    dependencies = DependenciesMapping(dsk)
    # Outstanding dependencies of every key that still has to be computed.
    waiting = {k: set(v) for k, v in dependencies.items() if k not in data_keys}

    dependents = reverse_dict(dependencies)
    # Anything already cached counts as a satisfied dependency.
    for a in cache:
        for b in dependents.get(a, ()):
            waiting[b].remove(a)
    # Keys whose data must be kept because some dependent still needs it.
    waiting_data = {k: v.copy() for k, v in dependents.items() if v}

    ready_set = {k for k, v in waiting.items() if not v}
    # Descending order, so popping from the end of the list yields the entry
    # with the smallest sortkey first.
    ready = sorted(ready_set, key=sortkey, reverse=True)
    waiting = {k: v for k, v in waiting.items() if v}

    state = {
        "dependencies": dependencies,
        "dependents": dependents,
        "waiting": waiting,
        "waiting_data": waiting_data,
        "cache": cache,
        "ready": ready,
        "running": set(),
        "finished": set(),
        "released": set(),
    }
    return state
def execute_task(key, task_info, dumps, loads, get_id, pack_exception):
    """Run one task and report its outcome.

    ``task_info`` is deserialized with ``loads`` into a ``(task, data)`` pair
    and ``task(data)`` is called. On success the result is paired with the
    value of ``get_id()`` and serialized with ``dumps``. Any exception,
    including ``BaseException`` subclasses, is handed to
    ``pack_exception(exc, dumps)`` instead.

    Returns
    -------
    tuple
        ``(key, payload, failed)`` where ``failed`` is ``True`` when
        ``payload`` came from ``pack_exception``.
    """
    try:
        task, data = loads(task_info)
        payload = dumps((task(data), get_id()))
    except BaseException as err:
        return key, pack_exception(err, dumps), True
    return key, payload, False
def release_data(key, state, delete=True):
    """Mark ``key`` as released and optionally drop its cached value.

    Any ``waiting_data`` entry for ``key`` is removed and is asserted to be
    empty first. The key is then added to ``state["released"]`` and, when
    ``delete`` is true, removed from ``state["cache"]``.

    See Also
    --------
    finish_task
    """
    pending = state["waiting_data"]
    if key in pending:
        assert not pending[key]
        del pending[key]

    state["released"].add(key)

    if not delete:
        return
    del state["cache"][key]
# Set to True to print a cache-size report each time finish_task releases a
# dependency's data. The report imports the optional ``chest`` package.
DEBUG = False


def finish_task(
    dsk, key, state, results, sortkey, delete=True, release_data=release_data
):
    """Update the execution state after the task ``key`` has finished.

    Mutates ``state``. This should run atomically (with a lock).

    Parameters
    ----------
    dsk : dict
        The task graph (not read by this function).
    key : hashable
        Key of the task that just finished.
    state : dict
        Scheduler state as built by ``start_state_from_dask``.
    results : set
        Keys requested by the caller; their data is never released.
    sortkey : callable
        Key function used to order newly ready dependents.
    delete : bool, optional
        Forwarded to ``release_data``. Also enables releasing dependencies
        that have no ``waiting_data`` entry.
    release_data : callable, optional
        Called as ``release_data(dep, state, delete=delete)``.

    Returns
    -------
    dict
        The same ``state`` object, after mutation.
    """
    # Dependents whose last outstanding dependency was ``key`` become ready.
    for dep in sorted(state["dependents"][key], key=sortkey, reverse=True):
        s = state["waiting"][dep]
        s.remove(key)
        if not s:
            del state["waiting"][dep]
            state["ready"].append(dep)

    # Release inputs of ``key`` that nothing else is waiting on.
    for dep in state["dependencies"][key]:
        if dep in state["waiting_data"]:
            s = state["waiting_data"][dep]
            s.remove(key)
            if not s and dep not in results:
                if DEBUG:
                    from chest.core import nbytes

                    # Sum the sizes first, then scale; dividing the ``map``
                    # object itself raises TypeError.
                    print(
                        "Key: %s\tDep: %s\t NBytes: %.2f\t Release"
                        % (key, dep, sum(map(nbytes, state["cache"].values())) / 1e6)
                    )
                release_data(dep, state, delete=delete)
        elif delete and dep not in results:
            release_data(dep, state, delete=delete)

    state["finished"].add(key)
    state["running"].remove(key)

    return state
def nested_get(ind, coll):
    """Look up a possibly nested index in a collection.

    A ``list`` index is resolved element by element and yields a tuple of the
    same nesting; any other index is used directly as ``coll[ind]``.

    Examples
    --------
    >>> nested_get(1, 'abc')
    'b'
    >>> nested_get([1, 0], 'abc')
    ('b', 'a')
    >>> nested_get([[1, 0], [0, 1]], 'abc')
    (('b', 'a'), ('a', 'b'))
    """
    if not isinstance(ind, list):
        return coll[ind]
    return tuple([nested_get(sub, coll) for sub in ind])
def default_get_id():
    """Return ``None``, the worker id used when no ``get_id`` is supplied."""
    return None
def default_pack_exception(e, dumps):
    """Re-raise the exception currently being handled instead of packing it.

    Both arguments are ignored. This must be called from inside an ``except``
    block, as ``execute_task`` does, because it relies on a bare ``raise``.
    """
    raise
def reraise(exc, tb=None):
    """Raise ``exc``, attaching ``tb`` as its traceback if it differs."""
    if exc.__traceback__ is tb:
        raise exc
    raise exc.with_traceback(tb)
def identity(x):
    """Return the argument unchanged.

    >>> identity(3)
    3
    """
    return x
def get_async(
    apply_async,
    num_workers,
    dsk,
    result,
    cache=None,
    get_id=default_get_id,
    rerun_exceptions_locally=None,
    pack_exception=default_pack_exception,
    raise_exception=reraise,
    callbacks=None,
    dumps=identity,
    loads=identity,
    **kwargs,
):
    """Asynchronous get function

    This is a general version of various asynchronous schedulers for dask. It
    takes an ``apply_async`` function as found on Pool objects to form a more
    specific ``get`` method that walks through the dask graph with parallel
    workers, avoiding repeat computation and minimizing memory use.

    Parameters
    ----------
    apply_async : function
        Asynchronous apply function as found on Pool or ThreadPool. It is
        called with ``execute_task``, an ``args`` tuple and a ``callback``.
    num_workers : int
        The number of active tasks we should have at any one time
    dsk : dict
        A dask dictionary specifying a workflow
    result : key or list of keys
        Keys corresponding to desired data
    cache : dict-like, optional
        Temporary storage of results
    get_id : callable, optional
        Function to return the worker id, takes no arguments. Examples are
        `threading.current_thread` and `multiprocessing.current_process`.
    rerun_exceptions_locally : bool, optional
        Whether to rerun failing tasks in the local process to enable
        debugging. When ``None``, the ``"rerun_exceptions_locally"`` config
        value is used (False by default).
    pack_exception : callable, optional
        Function to take an exception and ``dumps`` method, and return a
        serialized tuple of ``(exception, traceback)`` to send back to the
        scheduler. Default is to just raise the exception.
    raise_exception : callable, optional
        Function that takes an exception and a traceback, and raises an error.
    callbacks : tuple or list of tuples, optional
        Callbacks are passed in as tuples of length 5. Multiple sets of
        callbacks may be passed in as a list of tuples. For more information,
        see the dask.diagnostics documentation.
    dumps : callable, optional
        Function to serialize task data and results to communicate between
        worker and parent. Defaults to identity.
    loads : callable, optional
        Inverse function of `dumps`. Defaults to identity.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    The cached value for ``result``, or a nested tuple of values when
    ``result`` is a (nested) list of keys.

    Raises
    ------
    ValueError
        If tasks are waiting but none is ready to run.

    See Also
    --------
    threaded.get
    """
    done_queue = Queue()

    # Keys requested by the caller; finish_task never releases their data.
    if isinstance(result, list):
        results = set(flatten(result))
    else:
        results = {result}

    dsk = dict(dsk)
    with local_callbacks(callbacks) as callbacks:
        _, _, pretask_cbs, posttask_cbs, _ = unpack_callbacks(callbacks)
        started_cbs = []
        succeeded = False
        # Placeholder so the ``finally`` block has something to hand to the
        # finish callbacks even if start_state_from_dask fails.
        state = {}
        try:
            for cb in callbacks:
                start = cb[0]
                if start:
                    start(dsk)
                started_cbs.append(cb)

            keyorder = order(dsk)
            state = start_state_from_dask(dsk, cache=cache, sortkey=keyorder.get)

            for _, start_state, _, _, _ in callbacks:
                if start_state:
                    start_state(dsk, state)

            if rerun_exceptions_locally is None:
                rerun_exceptions_locally = config.get("rerun_exceptions_locally", False)

            if state["waiting"] and not state["ready"]:
                raise ValueError("Found no accessible jobs in dask")

            def gather_inputs(key):
                """Collect the cached values of ``key``'s dependencies."""
                return {dep: state["cache"][dep] for dep in get_dependencies(dsk, key)}

            def fire_task():
                """Fire off a task to the thread pool"""
                # Choose a good task to compute
                key = state["ready"].pop()
                state["running"].add(key)
                for f in pretask_cbs:
                    f(key, dsk, state)

                # Prep data to send
                data = gather_inputs(key)
                # Submit
                apply_async(
                    execute_task,
                    args=(
                        key,
                        dumps((dsk[key], data)),
                        dumps,
                        loads,
                        get_id,
                        pack_exception,
                    ),
                    callback=done_queue.put,
                )

            def fill_workers():
                """Submit ready tasks until ``num_workers`` are running."""
                while state["ready"] and len(state["running"]) < num_workers:
                    fire_task()

            # Seed initial tasks into the thread pool
            fill_workers()

            # Main loop, wait on tasks to finish, insert new ones
            while state["waiting"] or state["ready"] or state["running"]:
                key, res_info, failed = queue_get(done_queue)
                if failed:
                    exc, tb = loads(res_info)
                    if not rerun_exceptions_locally:
                        raise_exception(exc, tb)
                    else:
                        data = gather_inputs(key)
                        task = dsk[key]
                        task(data)  # Re-execute locally
                res, worker_id = loads(res_info)
                state["cache"][key] = res
                finish_task(dsk, key, state, results, keyorder.get)
                for f in posttask_cbs:
                    f(key, res, dsk, state, worker_id)

                fill_workers()

            succeeded = True

        finally:
            # Only callbacks whose start hook was reached get a finish call.
            for _, _, _, _, finish in started_cbs:
                if finish:
                    finish(dsk, state, not succeeded)

    return nested_get(result, state["cache"])
def apply_sync(func, args=(), kwds=None, callback=None):
    """Call ``func`` immediately, mimicking the ``apply_async`` signature.

    ``func(*args, **kwds)`` runs in the calling thread. If ``callback`` is
    given it is then called with the result. Nothing is returned.
    """
    keywords = {} if kwds is None else kwds
    outcome = func(*args, **keywords)
    if callback is None:
        return
    callback(outcome)
|