xgboost_trainer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. import logging
  2. from functools import partial
  3. from typing import Any, Callable, Dict, Optional, Union
  4. import xgboost
  5. from packaging.version import Version
  6. import ray.train
  7. from ray.train import Checkpoint
  8. from ray.train.constants import TRAIN_DATASET_KEY
  9. from ray.train.trainer import GenDataset
  10. from ray.train.utils import _log_deprecation_warning
  11. from ray.train.xgboost import RayTrainReportCallback, XGBoostConfig
  12. from ray.train.xgboost.v2 import XGBoostTrainer as SimpleXGBoostTrainer
  13. from ray.util.annotations import PublicAPI
  14. logger = logging.getLogger(__name__)
  15. LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE = (
  16. "Passing in `xgboost.train` kwargs such as `params`, `num_boost_round`, "
  17. "`label_column`, etc. to `XGBoostTrainer` is deprecated "
  18. "in favor of the new API which accepts a training function, "
  19. "similar to the other DataParallelTrainer APIs (ex: TorchTrainer). "
  20. "See this issue for more context: "
  21. "https://github.com/ray-project/ray/issues/50042"
  22. )
  23. def _xgboost_train_fn_per_worker(
  24. config: dict,
  25. label_column: str,
  26. num_boost_round: int,
  27. dataset_keys: set,
  28. xgboost_train_kwargs: dict,
  29. ):
  30. checkpoint = ray.train.get_checkpoint()
  31. starting_model = None
  32. remaining_iters = num_boost_round
  33. if checkpoint:
  34. starting_model = RayTrainReportCallback.get_model(checkpoint)
  35. starting_iter = starting_model.num_boosted_rounds()
  36. remaining_iters = num_boost_round - starting_iter
  37. logger.info(
  38. f"Model loaded from checkpoint will train for "
  39. f"additional {remaining_iters} iterations (trees) in order "
  40. "to achieve the target number of iterations "
  41. f"({num_boost_round=})."
  42. )
  43. train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
  44. train_df = train_ds_iter.materialize().to_pandas()
  45. eval_ds_iters = {
  46. k: ray.train.get_dataset_shard(k)
  47. for k in dataset_keys
  48. if k != TRAIN_DATASET_KEY
  49. }
  50. eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}
  51. train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
  52. dtrain = xgboost.DMatrix(train_X, label=train_y)
  53. # NOTE: Include the training dataset in the evaluation datasets.
  54. # This allows `train-*` metrics to be calculated and reported.
  55. evals = [(dtrain, TRAIN_DATASET_KEY)]
  56. for eval_name, eval_df in eval_dfs.items():
  57. eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
  58. evals.append((xgboost.DMatrix(eval_X, label=eval_y), eval_name))
  59. evals_result = {}
  60. xgboost.train(
  61. config,
  62. dtrain=dtrain,
  63. evals=evals,
  64. evals_result=evals_result,
  65. num_boost_round=remaining_iters,
  66. xgb_model=starting_model,
  67. **xgboost_train_kwargs,
  68. )
  69. @PublicAPI(stability="beta")
  70. class XGBoostTrainer(SimpleXGBoostTrainer):
  71. """A Trainer for distributed data-parallel XGBoost training.
  72. Example
  73. -------
  74. .. testcode::
  75. :skipif: True
  76. import xgboost
  77. import ray.data
  78. import ray.train
  79. from ray.train.xgboost import RayTrainReportCallback, XGBoostTrainer
  80. def train_fn_per_worker(config: dict):
  81. # (Optional) Add logic to resume training state from a checkpoint.
  82. # ray.train.get_checkpoint()
  83. # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
  84. train_ds_iter, eval_ds_iter = (
  85. ray.train.get_dataset_shard("train"),
  86. ray.train.get_dataset_shard("validation"),
  87. )
  88. train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
  89. train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
  90. train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
  91. eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
  92. dtrain = xgboost.DMatrix(train_X, label=train_y)
  93. deval = xgboost.DMatrix(eval_X, label=eval_y)
  94. params = {
  95. "tree_method": "approx",
  96. "objective": "reg:squarederror",
  97. "eta": 1e-4,
  98. "subsample": 0.5,
  99. "max_depth": 2,
  100. }
  101. # 2. Do distributed data-parallel training.
  102. # Ray Train sets up the necessary coordinator processes and
  103. # environment variables for your workers to communicate with each other.
  104. bst = xgboost.train(
  105. params,
  106. dtrain=dtrain,
  107. evals=[(deval, "validation")],
  108. num_boost_round=10,
  109. callbacks=[RayTrainReportCallback()],
  110. )
  111. train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
  112. eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
  113. trainer = XGBoostTrainer(
  114. train_fn_per_worker,
  115. datasets={"train": train_ds, "validation": eval_ds},
  116. scaling_config=ray.train.ScalingConfig(num_workers=4),
  117. )
  118. result = trainer.fit()
  119. booster = RayTrainReportCallback.get_model(result.checkpoint)
  120. Args:
  121. train_loop_per_worker: The training function to execute on each worker.
  122. This function can either take in zero arguments or a single ``Dict``
  123. argument which is set by defining ``train_loop_config``.
  124. Within this function you can use any of the
  125. :ref:`Ray Train Loop utilities <train-loop-api>`.
  126. train_loop_config: A configuration ``Dict`` to pass in as an argument to
  127. ``train_loop_per_worker``.
  128. This is typically used for specifying hyperparameters.
  129. xgboost_config: The configuration for setting up the distributed xgboost
  130. backend. Defaults to using the "rabit" backend.
  131. See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
  132. datasets: The Ray Datasets to use for training and validation.
  133. dataset_config: The configuration for ingesting the input ``datasets``.
  134. By default, all the Ray Datasets are split equally across workers.
  135. See :class:`~ray.train.DataConfig` for more details.
  136. scaling_config: The configuration for how to scale data parallel training.
  137. ``num_workers`` determines how many Python processes are used for training,
  138. and ``use_gpu`` determines whether or not each process should use GPUs.
  139. See :class:`~ray.train.ScalingConfig` for more info.
  140. run_config: The configuration for the execution of the training run.
  141. See :class:`~ray.train.RunConfig` for more info.
  142. resume_from_checkpoint: A checkpoint to resume training from.
  143. This checkpoint can be accessed from within ``train_loop_per_worker``
  144. by calling ``ray.train.get_checkpoint()``.
  145. metadata: Dict that should be made available via
  146. `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
  147. for checkpoints saved from this Trainer. Must be JSON-serializable.
  148. label_column: [Deprecated] Name of the label column. A column with this name
  149. must be present in the training dataset.
  150. params: [Deprecated] XGBoost training parameters.
  151. Refer to `XGBoost documentation <https://xgboost.readthedocs.io/>`_
  152. for a list of possible parameters.
  153. num_boost_round: [Deprecated] Target number of boosting iterations (trees in the model).
  154. Note that unlike in ``xgboost.train``, this is the target number
  155. of trees, meaning that if you set ``num_boost_round=10`` and pass a model
  156. that has already been trained for 5 iterations, it will be trained for 5
  157. iterations more, instead of 10 more.
  158. **train_kwargs: [Deprecated] Additional kwargs passed to ``xgboost.train()`` function.
  159. """
  160. _handles_checkpoint_freq = True
  161. _handles_checkpoint_at_end = True
  162. def __init__(
  163. self,
  164. train_loop_per_worker: Optional[
  165. Union[Callable[[], None], Callable[[Dict], None]]
  166. ] = None,
  167. *,
  168. train_loop_config: Optional[Dict] = None,
  169. xgboost_config: Optional[XGBoostConfig] = None,
  170. scaling_config: Optional[ray.train.ScalingConfig] = None,
  171. run_config: Optional[ray.train.RunConfig] = None,
  172. datasets: Optional[Dict[str, GenDataset]] = None,
  173. dataset_config: Optional[ray.train.DataConfig] = None,
  174. resume_from_checkpoint: Optional[Checkpoint] = None,
  175. metadata: Optional[Dict[str, Any]] = None,
  176. # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
  177. label_column: Optional[str] = None,
  178. params: Optional[Dict[str, Any]] = None,
  179. num_boost_round: Optional[int] = None,
  180. **train_kwargs,
  181. ):
  182. if Version(xgboost.__version__) < Version("1.7.0"):
  183. raise ImportError(
  184. "`XGBoostTrainer` requires the `xgboost` version to be >= 1.7.0. "
  185. 'Upgrade with: `pip install -U "xgboost>=1.7"`'
  186. )
  187. # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
  188. legacy_api = train_loop_per_worker is None
  189. if legacy_api:
  190. train_loop_per_worker = self._get_legacy_train_fn_per_worker(
  191. xgboost_train_kwargs=train_kwargs,
  192. run_config=run_config,
  193. label_column=label_column,
  194. num_boost_round=num_boost_round,
  195. datasets=datasets,
  196. )
  197. train_loop_config = params or {}
  198. elif train_kwargs:
  199. _log_deprecation_warning(
  200. "Passing `xgboost.train` kwargs to `XGBoostTrainer` is deprecated. "
  201. "In your training function, you can call `xgboost.train(**kwargs)` "
  202. "with arbitrary arguments. "
  203. f"{LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE}"
  204. )
  205. super(XGBoostTrainer, self).__init__(
  206. train_loop_per_worker=train_loop_per_worker,
  207. train_loop_config=train_loop_config,
  208. xgboost_config=xgboost_config,
  209. scaling_config=scaling_config,
  210. run_config=run_config,
  211. datasets=datasets,
  212. dataset_config=dataset_config,
  213. resume_from_checkpoint=resume_from_checkpoint,
  214. metadata=metadata,
  215. )
  216. def _get_legacy_train_fn_per_worker(
  217. self,
  218. xgboost_train_kwargs: Dict,
  219. run_config: Optional[ray.train.RunConfig],
  220. datasets: Optional[Dict[str, GenDataset]],
  221. label_column: Optional[str],
  222. num_boost_round: Optional[int],
  223. ) -> Callable[[Dict], None]:
  224. """Get the training function for the legacy XGBoostTrainer API."""
  225. datasets = datasets or {}
  226. if not datasets.get(TRAIN_DATASET_KEY):
  227. raise ValueError(
  228. "`datasets` must be provided for the XGBoostTrainer API "
  229. "if `train_loop_per_worker` is not provided. "
  230. "This dict must contain the training dataset under the "
  231. f"key: '{TRAIN_DATASET_KEY}'. "
  232. f"Got keys: {list(datasets.keys())}"
  233. )
  234. if not label_column:
  235. raise ValueError(
  236. "`label_column` must be provided for the XGBoostTrainer API "
  237. "if `train_loop_per_worker` is not provided. "
  238. "This is the column name of the label in the dataset."
  239. )
  240. num_boost_round = num_boost_round or 10
  241. _log_deprecation_warning(LEGACY_XGBOOST_TRAINER_DEPRECATION_MESSAGE)
  242. # Initialize a default Ray Train metrics/checkpoint reporting callback if needed
  243. callbacks = xgboost_train_kwargs.get("callbacks", [])
  244. user_supplied_callback = any(
  245. isinstance(callback, RayTrainReportCallback) for callback in callbacks
  246. )
  247. callback_kwargs = {}
  248. if run_config:
  249. checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
  250. checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end
  251. callback_kwargs["frequency"] = checkpoint_frequency
  252. # Default `checkpoint_at_end=True` unless the user explicitly sets it.
  253. callback_kwargs["checkpoint_at_end"] = (
  254. checkpoint_at_end if checkpoint_at_end is not None else True
  255. )
  256. if not user_supplied_callback:
  257. callbacks.append(RayTrainReportCallback(**callback_kwargs))
  258. xgboost_train_kwargs["callbacks"] = callbacks
  259. train_fn_per_worker = partial(
  260. _xgboost_train_fn_per_worker,
  261. label_column=label_column,
  262. num_boost_round=num_boost_round,
  263. dataset_keys=set(datasets),
  264. xgboost_train_kwargs=xgboost_train_kwargs,
  265. )
  266. return train_fn_per_worker
  267. @classmethod
  268. def get_model(
  269. cls,
  270. checkpoint: Checkpoint,
  271. ) -> xgboost.Booster:
  272. """Retrieve the XGBoost model stored in this checkpoint."""
  273. return RayTrainReportCallback.get_model(checkpoint)