| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- import json
- import logging
- import os
- import threading
- from contextlib import contextmanager
- from dataclasses import dataclass
- from typing import Optional
- import xgboost
- from packaging.version import Version
- from xgboost import RabitTracker
- from xgboost.collective import CommunicatorContext
- import ray
- from ray.train._internal.base_worker_group import BaseWorkerGroup
- from ray.train.backend import Backend, BackendConfig
- logger = logging.getLogger(__name__)
- @dataclass
- class XGBoostConfig(BackendConfig):
- """Configuration for xgboost collective communication setup.
- Ray Train will set up the necessary coordinator processes and environment
- variables for your workers to communicate with each other.
- Additional configuration options can be passed into the
- `xgboost.collective.CommunicatorContext` that wraps your own `xgboost.train` code.
- See the `xgboost.collective` module for more information:
- https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/collective.py
- Args:
- xgboost_communicator: The backend to use for collective communication for
- distributed xgboost training. For now, only "rabit" is supported.
- """
- xgboost_communicator: str = "rabit"
- @property
- def train_func_context(self):
- @contextmanager
- def collective_communication_context():
- with CommunicatorContext(**_get_xgboost_args()):
- yield
- return collective_communication_context
- @property
- def backend_cls(self):
- if self.xgboost_communicator == "rabit":
- return (
- _XGBoostRabitBackend
- if Version(xgboost.__version__) >= Version("2.1.0")
- else _XGBoostRabitBackend_pre_xgb210
- )
- raise NotImplementedError(f"Unsupported backend: {self.xgboost_communicator}")
- class _XGBoostRabitBackend(Backend):
- def __init__(self):
- self._tracker: Optional[RabitTracker] = None
- self._wait_thread: Optional[threading.Thread] = None
- def _setup_xgboost_distributed_backend(self, worker_group: BaseWorkerGroup):
- # Set up the rabit tracker on the Train driver.
- num_workers = len(worker_group)
- rabit_args = {"n_workers": num_workers}
- train_driver_ip = ray.util.get_node_ip_address()
- # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
- # align with Ray Train worker ranks.
- # The worker ranks will be sorted by `dmlc_task_id`,
- # which is defined below.
- self._tracker = RabitTracker(
- n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
- )
- self._tracker.start()
- # The RabitTracker is started in a separate thread, and the
- # `wait_for` method must be called for `worker_args` to return.
- self._wait_thread = threading.Thread(target=self._tracker.wait_for, daemon=True)
- self._wait_thread.start()
- rabit_args.update(self._tracker.worker_args())
- start_log = (
- "RabitTracker coordinator started with parameters:\n"
- f"{json.dumps(rabit_args, indent=2)}"
- )
- logger.debug(start_log)
- def set_xgboost_communicator_args(args):
- import ray.train
- args["dmlc_task_id"] = (
- f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
- f"{ray.get_runtime_context().get_actor_id()}"
- )
- _set_xgboost_args(args)
- worker_group.execute(set_xgboost_communicator_args, rabit_args)
- def on_training_start(
- self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig
- ):
- assert backend_config.xgboost_communicator == "rabit"
- self._setup_xgboost_distributed_backend(worker_group)
- def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig):
- timeout = 5
- if self._wait_thread is not None:
- self._wait_thread.join(timeout=timeout)
- if self._wait_thread.is_alive():
- logger.warning(
- "During shutdown, the RabitTracker thread failed to join "
- f"within {timeout} seconds. "
- "The process will still be terminated as part of Ray actor cleanup."
- )
- class _XGBoostRabitBackend_pre_xgb210(Backend):
- def __init__(self):
- self._tracker: Optional[RabitTracker] = None
- def _setup_xgboost_distributed_backend(self, worker_group: BaseWorkerGroup):
- # Set up the rabit tracker on the Train driver.
- num_workers = len(worker_group)
- rabit_args = {"DMLC_NUM_WORKER": num_workers}
- train_driver_ip = ray.util.get_node_ip_address()
- # NOTE: sortby="task" is needed to ensure that the xgboost worker ranks
- # align with Ray Train worker ranks.
- # The worker ranks will be sorted by `DMLC_TASK_ID`,
- # which is defined below.
- self._tracker = RabitTracker(
- n_workers=num_workers, host_ip=train_driver_ip, sortby="task"
- )
- self._tracker.start(n_workers=num_workers)
- worker_args = self._tracker.worker_envs()
- rabit_args.update(worker_args)
- start_log = (
- "RabitTracker coordinator started with parameters:\n"
- f"{json.dumps(rabit_args, indent=2)}"
- )
- logger.debug(start_log)
- def set_xgboost_env_vars():
- import ray.train
- for k, v in rabit_args.items():
- os.environ[k] = str(v)
- # Ranks are assigned in increasing order of the worker's task id.
- # This task id will be sorted by increasing world rank.
- os.environ["DMLC_TASK_ID"] = (
- f"[xgboost.ray-rank={ray.train.get_context().get_world_rank():08}]:"
- f"{ray.get_runtime_context().get_actor_id()}"
- )
- worker_group.execute(set_xgboost_env_vars)
- def on_training_start(
- self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig
- ):
- assert backend_config.xgboost_communicator == "rabit"
- self._setup_xgboost_distributed_backend(worker_group)
- def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config: XGBoostConfig):
- if not self._tracker:
- return
- timeout = 5
- self._tracker.thread.join(timeout=timeout)
- if self._tracker.thread.is_alive():
- logger.warning(
- "During shutdown, the RabitTracker thread failed to join "
- f"within {timeout} seconds. "
- "The process will still be terminated as part of Ray actor cleanup."
- )
- _xgboost_args: dict = {}
- _xgboost_args_lock = threading.Lock()
- def _set_xgboost_args(args):
- with _xgboost_args_lock:
- global _xgboost_args
- _xgboost_args = args
- def _get_xgboost_args() -> dict:
- with _xgboost_args_lock:
- return _xgboost_args
|