remote_fn.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. from typing import Any, Dict, Hashable, List
  2. import ray
  3. CACHED_FUNCTIONS = {}
  4. def cached_remote_fn(fn: Any, **ray_remote_args) -> Any:
  5. """Lazily defines a ray.remote function.
  6. This is used in Datasets to avoid circular import issues with ray.remote.
  7. (ray imports ray.data in order to allow ``ray.data.read_foo()`` to work,
  8. which means ray.remote cannot be used top-level in ray.data).
  9. NOTE: Dynamic arguments should not be passed in directly,
  10. and should be set with ``options`` instead:
  11. ``cached_remote_fn(fn, **static_args).options(**dynamic_args)``.
  12. """
  13. # NOTE: Hash of the passed in arguments guarantees that we're caching
  14. # complete instantiation of the Ray's remote method
  15. #
  16. # To compute the hash of passed in arguments and make sure it's deterministic
  17. # - Sort all KV-pairs by the keys
  18. # - Convert sorted list into tuple
  19. # - Compute hash of the resulting tuple
  20. hashable_args = _make_hashable(ray_remote_args)
  21. args_hash = hash(hashable_args)
  22. if (fn, args_hash) not in CACHED_FUNCTIONS:
  23. default_ray_remote_args = {
  24. # Use the default scheduling strategy for all tasks so that we will
  25. # not inherit a placement group from the caller, if there is one.
  26. # The caller of this function may override the scheduling strategy
  27. # as needed.
  28. "scheduling_strategy": "DEFAULT",
  29. "max_retries": -1,
  30. }
  31. ray_remote_args = {**default_ray_remote_args, **ray_remote_args}
  32. _add_system_error_to_retry_exceptions(ray_remote_args)
  33. CACHED_FUNCTIONS[(fn, args_hash)] = ray.remote(**ray_remote_args)(fn)
  34. return CACHED_FUNCTIONS[(fn, args_hash)]
  35. def _make_hashable(obj):
  36. if isinstance(obj, (List, tuple)):
  37. return tuple([_make_hashable(o) for o in obj])
  38. elif isinstance(obj, Dict):
  39. converted = [(_make_hashable(k), _make_hashable(v)) for k, v in obj.items()]
  40. return tuple(sorted(converted, key=lambda t: t[0]))
  41. elif isinstance(obj, Hashable):
  42. return obj
  43. else:
  44. raise ValueError(f"Type {type(obj)} is not hashable")
  45. def _add_system_error_to_retry_exceptions(ray_remote_args) -> None:
  46. """Modify the remote args so that Ray retries `RaySystemError`s.
  47. Ray typically automatically retries system errors. However, in some cases, Ray won't
  48. retry system errors if they're raised from task code. To ensure that Ray Data is
  49. fault tolerant to those errors, we need to add `RaySystemError` to the
  50. `retry_exceptions` list.
  51. TODO: Fix this in Ray Core. See https://github.com/ray-project/ray/pull/45079.
  52. """
  53. retry_exceptions = ray_remote_args.get("retry_exceptions", False)
  54. assert isinstance(retry_exceptions, (list, bool))
  55. if (
  56. isinstance(retry_exceptions, list)
  57. and ray.exceptions.RaySystemError not in retry_exceptions
  58. ):
  59. retry_exceptions.append(ray.exceptions.RaySystemError)
  60. elif not retry_exceptions:
  61. retry_exceptions = [ray.exceptions.RaySystemError]
  62. ray_remote_args["retry_exceptions"] = retry_exceptions