# ray_backend.py (3.3 KB)
  1. import logging
  2. from typing import Any, Dict, Optional
  3. from joblib import Parallel
  4. from joblib._parallel_backends import MultiprocessingBackend
  5. from joblib.pool import PicklingPool
  6. import ray
  7. from ray._common.usage import usage_lib
  8. from ray.util.multiprocessing.pool import Pool
  9. logger = logging.getLogger(__name__)
  10. class RayBackend(MultiprocessingBackend):
  11. """Ray backend uses ray, a system for scalable distributed computing.
  12. More info about Ray is available here: https://docs.ray.io.
  13. """
  14. def __init__(
  15. self,
  16. nesting_level: Optional[int] = None,
  17. inner_max_num_threads: Optional[int] = None,
  18. ray_remote_args: Optional[Dict[str, Any]] = None,
  19. **kwargs
  20. ):
  21. """``ray_remote_args`` will be used to configure Ray Actors
  22. making up the pool."""
  23. usage_lib.record_library_usage("util.joblib")
  24. self.ray_remote_args = ray_remote_args
  25. super().__init__(
  26. nesting_level=nesting_level,
  27. inner_max_num_threads=inner_max_num_threads,
  28. **kwargs
  29. )
  30. # ray_remote_args is used both in __init__ and configure to allow for it to be
  31. # set in both `parallel_backend` and `Parallel` respectively
  32. def configure(
  33. self,
  34. n_jobs: int = 1,
  35. parallel: Optional[Parallel] = None,
  36. prefer: Optional[str] = None,
  37. require: Optional[str] = None,
  38. ray_remote_args: Optional[Dict[str, Any]] = None,
  39. **memmappingpool_args
  40. ):
  41. """Make Ray Pool the father class of PicklingPool. PicklingPool is a
  42. father class that inherits Pool from multiprocessing.pool. The next
  43. line is a patch, which changes the inheritance of Pool to be from
  44. ray.util.multiprocessing.pool.
  45. ``ray_remote_args`` will be used to configure Ray Actors making up the pool.
  46. This will override ``ray_remote_args`` set during initialization.
  47. """
  48. PicklingPool.__bases__ = (Pool,)
  49. """Use all available resources when n_jobs == -1. Must set RAY_ADDRESS
  50. variable in the environment or run ray.init(address=..) to run on
  51. multiple nodes.
  52. """
  53. if n_jobs == -1:
  54. if not ray.is_initialized():
  55. import os
  56. if "RAY_ADDRESS" in os.environ:
  57. logger.info(
  58. "Connecting to ray cluster at address='{}'".format(
  59. os.environ["RAY_ADDRESS"]
  60. )
  61. )
  62. else:
  63. logger.info("Starting local ray cluster")
  64. ray.init()
  65. ray_cpus = int(ray._private.state.cluster_resources()["CPU"])
  66. n_jobs = ray_cpus
  67. eff_n_jobs = super(RayBackend, self).configure(
  68. n_jobs,
  69. parallel,
  70. prefer,
  71. require,
  72. ray_remote_args=ray_remote_args
  73. if ray_remote_args is not None
  74. else self.ray_remote_args,
  75. **memmappingpool_args
  76. )
  77. return eff_n_jobs
  78. def effective_n_jobs(self, n_jobs):
  79. eff_n_jobs = super(RayBackend, self).effective_n_jobs(n_jobs)
  80. if n_jobs == -1:
  81. ray_cpus = int(ray._private.state.cluster_resources()["CPU"])
  82. eff_n_jobs = ray_cpus
  83. return eff_n_jobs