__init__.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import dask
  2. from packaging.version import Version
  3. # Version(dask.__version__) becomes "0" during doc builds.
  4. if Version(dask.__version__) != Version("0") and Version(dask.__version__) < Version(
  5. "2024.11.0"
  6. ):
  7. # Dask on Ray doesn't work if Dask version is less than 2024.11.0.
  8. raise ImportError(
  9. "Dask on Ray requires Dask version 2024.11.0 or later. "
  10. "Please upgrade your Dask installation."
  11. )
  12. from .callbacks import (
  13. ProgressBarCallback,
  14. RayDaskCallback,
  15. local_ray_callbacks,
  16. unpack_ray_callbacks,
  17. )
  18. from .optimizations import dataframe_optimize
  19. from .scheduler import (
  20. disable_dask_on_ray,
  21. enable_dask_on_ray,
  22. ray_dask_get,
  23. ray_dask_get_sync,
  24. )
  25. dask_persist = dask.persist
  26. def ray_dask_persist(*args, **kwargs):
  27. kwargs["ray_persist"] = True
  28. return dask_persist(*args, **kwargs)
  29. ray_dask_persist.__doc__ = dask_persist.__doc__
  30. dask_persist_mixin = dask.base.DaskMethodsMixin.persist
  31. def ray_dask_persist_mixin(self, **kwargs):
  32. kwargs["ray_persist"] = True
  33. return dask_persist_mixin(self, **kwargs)
  34. ray_dask_persist_mixin.__doc__ = dask_persist_mixin.__doc__
  35. # We patch dask in order to inject a kwarg into its `dask.persist()` calls,
  36. # which the Dask-on-Ray scheduler needs.
  37. # FIXME(Clark): Monkey patching is bad and we should try to avoid this.
  38. def patch_dask(ray_dask_persist, ray_dask_persist_mixin):
  39. dask.persist = ray_dask_persist
  40. dask.base.DaskMethodsMixin.persist = ray_dask_persist_mixin
  41. patch_dask(ray_dask_persist, ray_dask_persist_mixin)
  42. __all__ = [
  43. # Config
  44. "enable_dask_on_ray",
  45. "disable_dask_on_ray",
  46. # Schedulers
  47. "ray_dask_get",
  48. "ray_dask_get_sync",
  49. # Helpers
  50. "ray_dask_persist",
  51. # Callbacks
  52. "RayDaskCallback",
  53. "local_ray_callbacks",
  54. "unpack_ray_callbacks",
  55. # Optimizations
  56. "dataframe_optimize",
  57. "ProgressBarCallback",
  58. ]