# common.py (2.6 KB)
  1. import uuid
  2. from collections import OrderedDict
  3. from collections.abc import Iterator
  4. from operator import getitem
  5. from dask.core import get as get_sync, quote
  6. from dask.utils import apply
  7. import ray
  8. try:
  9. from dataclasses import fields as dataclass_fields, is_dataclass
  10. except ImportError:
  11. # Python < 3.7
  12. def is_dataclass(x):
  13. return False
  14. def dataclass_fields(x):
  15. return []
  16. def unpack_object_refs(*args):
  17. """
  18. Extract Ray object refs from a set of potentially arbitrarily nested
  19. Python objects.
  20. Intended use is to find all Ray object references in a set of (possibly
  21. nested) Python objects, do something to them (get(), wait(), etc.), then
  22. repackage them into equivalent Python objects.
  23. Args:
  24. *args: One or more (potentially nested) Python objects that contain
  25. Ray object references.
  26. Returns:
  27. A 2-tuple of a flat list of all contained Ray object references, and a
  28. function that, when given the corresponding flat list of concrete
  29. values, will return a set of Python objects equivalent to that which
  30. was given in *args, but with all Ray object references replaced with
  31. their corresponding concrete values.
  32. """
  33. object_refs = []
  34. repack_dsk = {}
  35. object_refs_token = uuid.uuid4().hex
  36. def _unpack(expr):
  37. if isinstance(expr, ray.ObjectRef):
  38. token = expr.hex()
  39. repack_dsk[token] = (getitem, object_refs_token, len(object_refs))
  40. object_refs.append(expr)
  41. return token
  42. token = uuid.uuid4().hex
  43. # Treat iterators like lists
  44. typ = list if isinstance(expr, Iterator) else type(expr)
  45. if typ in (list, tuple, set):
  46. repack_task = (typ, [_unpack(i) for i in expr])
  47. elif typ in (dict, OrderedDict):
  48. repack_task = (typ, [[_unpack(k), _unpack(v)] for k, v in expr.items()])
  49. elif is_dataclass(expr):
  50. repack_task = (
  51. apply,
  52. typ,
  53. (),
  54. (
  55. dict,
  56. [
  57. [f.name, _unpack(getattr(expr, f.name))]
  58. for f in dataclass_fields(expr)
  59. ],
  60. ),
  61. )
  62. else:
  63. return expr
  64. repack_dsk[token] = repack_task
  65. return token
  66. out = uuid.uuid4().hex
  67. repack_dsk[out] = (tuple, [_unpack(i) for i in args])
  68. def repack(results):
  69. dsk = repack_dsk.copy()
  70. dsk[object_refs_token] = quote(results)
  71. return get_sync(dsk, out)
  72. return object_refs, repack