config.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import json
  2. import logging
  3. import os
  4. import threading
  5. from contextlib import contextmanager
  6. from dataclasses import dataclass
  7. from typing import Optional
  8. import xgboost
  9. from packaging.version import Version
  10. from xgboost import RabitTracker
  11. from xgboost.collective import CommunicatorContext
  12. import ray
  13. from ray.train._internal.base_worker_group import BaseWorkerGroup
  14. from ray.train.backend import Backend, BackendConfig
# Module-level logger. In this file it carries the tracker start-up parameters
# (DEBUG) and the shutdown join-timeout notice (WARNING).
logger = logging.getLogger(__name__)
  16. @dataclass
  17. class XGBoostConfig(BackendConfig):
  18. """Configuration for xgboost collective communication setup.
  19. Ray Train will set up the necessary coordinator processes and environment
  20. variables for your workers to communicate with each other.
  21. Additional configuration options can be passed into the
  22. `xgboost.collective.CommunicatorContext` that wraps your own `xgboost.train` code.
  23. See the `xgboost.collective` module for more information:
  24. https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/collective.py
  25. Args:
  26. xgboost_communicator: The backend to use for collective communication for
  27. distributed xgboost training. For now, only "rabit" is supported.
  28. """
  29. xgboost_communicator: str = "rabit"
  30. @property
  31. def train_func_context(self):
  32. @contextmanager
  33. def collective_communication_context():
  34. with CommunicatorContext(**_get_xgboost_args()):
  35. yield
  36. return collective_communication_context
  37. @property
  38. def backend_cls(self):
  39. if self.xgboost_communicator == "rabit":
  40. return (
  41. _XGBoostRabitBackend
  42. if Version(xgboost.__version__) >= Version("2.1.0")
  43. else _XGBoostRabitBackend_pre_xgb210
  44. )
  45. raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")
  46. class _XGBoostRabitBackend(Backend):
  47. def __init__(self):
  48. self._tracker: Optional[RabitTracker] = None
  49. self._wait_thread: Optional[threading.Thread] = None
  50. def _setup_xgboost_distributed_backend(self, worker_group: BaseWorkerGroup):
  51. # Set up the rabit tracker on the Train driver.
  52. num_workers = len(worker_group)
  53. rabit_args = {"n_workers": num_workers}
  54. train_driver_ip = ray.util.get_node_ip_address()
  55. # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
  56. # align with Ray Train worker ranks.
  57. # The worker ranks will be sorted by `dmlc_task_id`,
  58. # which is defined below.
  59. self._tracker = RabitTracker(
  60. n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
  61. )
  62. self._tracker.start()
  63. # The RabitTracker is started in a separate thread, and the
  64. # `wait_for` method must be called for `worker_args` to return.
  65. self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
  66. self._wait_thread.start()
  67. rabit_args.update(self._tracker.worker_args())
  68. start_log = (
  69. "RabitTracker coordinator started with parameters:\n"
  70. f"{json.dumps(rabit_args, indent=2)}"
  71. )
  72. logger.debug(start_log)
  73. def set_xgboost_communicator_args(args):
  74. import ray.train
  75. args["dmlc_task_id"] = (
  76. f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
  77. f"{ray.get_runtime_context().get_actor_id()}"
  78. )
  79. _set_xgboost_args(args)
  80. worker_group.execute(set_xgboost_communicator_args, rabit_args)
  81. def on_training_start(
  82. self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig
  83. ):
  84. assert backend_config.xgboost_communicator == "rabit"
  85. self._setup_xgboost_distributed_backend(worker_group)
  86. def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig):
  87. timeout = 5
  88. if self._wait_thread is not None:
  89. self._wait_thread.join(timeout=timeout)
  90. if self._wait_thread.is_alive():
  91. logger.warning(
  92. "During shutdown, the RabitTracker thread failed to join "
  93. f"within {timeout} seconds. "
  94. "The process will still be terminated as part of Ray actor cleanup."
  95. )
  96. class _XGBoostRabitBackend_pre_xgb210(Backend):
  97. def __init__(self):
  98. self._tracker: Optional[RabitTracker] = None
  99. def _setup_xgboost_distributed_backend(self, worker_group: BaseWorkerGroup):
  100. # Set up the rabit tracker on the Train driver.
  101. num_workers = len(worker_group)
  102. rabit_args = {"DMLC_NUM_WORKER": num_workers}
  103. train_driver_ip = ray.util.get_node_ip_address()
  104. # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
  105. # align with Ray Train worker ranks.
  106. # The worker ranks will be sorted by `DMLC_TASK_ID`,
  107. # which is defined below.
  108. self._tracker = RabitTracker(
  109. n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
  110. )
  111. self._tracker.start(n_workers=num_workers)
  112. worker_args = self._tracker.worker_envs()
  113. rabit_args.update(worker_args)
  114. start_log = (
  115. "RabitTracker coordinator started with parameters:\n"
  116. f"{json.dumps(rabit_args, indent=2)}"
  117. )
  118. logger.debug(start_log)
  119. def set_xgboost_env_vars():
  120. import ray.train
  121. for k, v in rabit_args.items():
  122. os.environ[k] = str(v)
  123. # Ranks are assigned in increasing order of the worker's task id.
  124. # This task id will be sorted by increasing world rank.
  125. os.environ["DMLC_TASK_ID"] = (
  126. f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
  127. f"{ray.get_runtime_context().get_actor_id()}"
  128. )
  129. worker_group.execute(set_xgboost_env_vars)
  130. def on_training_start(
  131. self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig
  132. ):
  133. assert backend_config.xgboost_communicator == "rabit"
  134. self._setup_xgboost_distributed_backend(worker_group)
  135. def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig):
  136. if not self._tracker:
  137. return
  138. timeout = 5
  139. self._tracker.thread.join(timeout=timeout)
  140. if self._tracker.thread.is_alive():
  141. logger.warning(
  142. "During shutdown, the RabitTracker thread failed to join "
  143. f"within {timeout} seconds. "
  144. "The process will still be terminated as part of Ray actor cleanup."
  145. )
# Process-wide communicator arguments. Written by `_set_xgboost_args` and read
# by `_get_xgboost_args`; starts out empty.
_xgboost_args: dict = {}
# Guards every read and rebind of `_xgboost_args`.
_xgboost_args_lock = threading.Lock()
  148. def _set_xgboost_args(args):
  149. with _xgboost_args_lock:
  150. global _xgboost_args
  151. _xgboost_args = args
  152. def _get_xgboost_args() -> dict:
  153. with _xgboost_args_lock:
  154. return _xgboost_args