| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687 |
- import uuid
- from collections import OrderedDict
- from collections.abc import Iterator
- from operator import getitem
- from dask.core import get as get_sync, quote
- from dask.utils import apply
- import ray
try:
    from dataclasses import fields as dataclass_fields
    from dataclasses import is_dataclass
except ImportError:
    # The ``dataclasses`` module is missing (Python < 3.7). Provide inert
    # stand-ins so callers can use both names unconditionally.
    def is_dataclass(x):
        """Fallback: without ``dataclasses``, no object is a dataclass."""
        return False

    def dataclass_fields(x):
        """Fallback: without ``dataclasses``, report no fields."""
        return []
def unpack_object_refs(*args):
    """
    Extract Ray object refs from a set of potentially arbitrarily nested
    Python objects.

    Intended use is to find all Ray object references in a set of (possibly
    nested) Python objects, do something to them (get(), wait(), etc.), then
    repackage them into equivalent Python objects.

    Containers that are traversed: ``list``, ``tuple``, ``set``, ``dict``,
    ``OrderedDict`` (exact types only, subclasses are treated as leaves),
    iterators (consumed and rebuilt as lists) and dataclass *instances*.
    Everything else, including dataclass classes, is passed through unchanged.

    Args:
        *args: One or more (potentially nested) Python objects that contain
            Ray object references.

    Returns:
        A 2-tuple of a flat list of all contained Ray object references, and a
        function that, when given the corresponding flat list of concrete
        values, will return a tuple (one entry per positional argument) of
        Python objects equivalent to that which was given in *args, but with
        all Ray object references replaced with their corresponding concrete
        values. A reference that occurs several times in *args appears once
        per occurrence in the flat list.
    """
    object_refs = []
    # Dask-style task graph that rebuilds the original nested structure.
    repack_dsk = {}

    # Graph key under which repack() injects the flat list of concrete values.
    object_refs_token = uuid.uuid4().hex

    def _unpack(expr):
        """Register a rebuild task for ``expr``; return its key or the leaf."""
        if isinstance(expr, ray.ObjectRef):
            token = expr.hex()
            # The concrete value will live at this position of the flat list.
            repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
            object_refs.append(expr)
            return token

        # Treat iterators like lists
        typ = list if isinstance(expr, Iterator) else type(expr)
        if typ in (list, tuple, set):
            repack_task = (typ, [_unpack(i) for i in expr])
        elif typ in (dict, OrderedDict):
            repack_task = (typ, [[_unpack(k), _unpack(v)] for k, v in expr.items()])
        elif is_dataclass(expr) and not isinstance(expr, type):
            # is_dataclass() is also true for dataclass *classes*; only
            # instances carry field values that can be unpacked and rebuilt.
            repack_task = (
                apply,
                typ,
                (),
                (
                    dict,
                    [
                        [f.name, _unpack(getattr(expr, f.name))]
                        for f in dataclass_fields(expr)
                    ],
                ),
            )
        else:
            # Leaf value: embedded in the graph as-is.
            return expr

        # Only nodes that need rebuilding get a graph key.
        token = uuid.uuid4().hex
        repack_dsk[token] = repack_task
        return token

    out = uuid.uuid4().hex
    repack_dsk[out] = (tuple, [_unpack(i) for i in args])

    def repack(results):
        """Rebuild the original structure from the flat list ``results``."""
        # Copy so that repack() can be called repeatedly with other results.
        dsk = repack_dsk.copy()
        # quote() keeps the results list from being interpreted as a task.
        dsk[object_refs_token] = quote(results)
        return get_sync(dsk, out)

    return object_refs, repack
|