xgboost_trainer.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import logging
  2. from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
  3. import ray.train
  4. from ray.train import Checkpoint
  5. from ray.train.trainer import GenDataset
  6. from ray.train.v2.api.config import RunConfig, ScalingConfig
  7. from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
  8. from ray.train.v2.api.validation_config import ValidationConfig
  9. from ray.util.annotations import Deprecated
  10. if TYPE_CHECKING:
  11. from ray.train.xgboost import XGBoostConfig
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  13. class XGBoostTrainer(DataParallelTrainer):
  14. """A Trainer for distributed data-parallel XGBoost training.
  15. Example
  16. -------
  17. .. testcode::
  18. import xgboost
  19. import ray.data
  20. import ray.train
  21. from ray.train.xgboost import RayTrainReportCallback
  22. from ray.train.xgboost import XGBoostTrainer
  23. def train_fn_per_worker(config: dict):
  24. # (Optional) Add logic to resume training state from a checkpoint.
  25. # ray.train.get_checkpoint()
  26. # 1. Get the dataset shard for the worker and convert to a `xgboost.DMatrix`
  27. train_ds_iter, eval_ds_iter = (
  28. ray.train.get_dataset_shard("train"),
  29. ray.train.get_dataset_shard("validation"),
  30. )
  31. train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
  32. train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
  33. train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
  34. eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
  35. dtrain = xgboost.DMatrix(train_X, label=train_y)
  36. deval = xgboost.DMatrix(eval_X, label=eval_y)
  37. params = {
  38. "tree_method": "approx",
  39. "objective": "reg:squarederror",
  40. "eta": 1e-4,
  41. "subsample": 0.5,
  42. "max_depth": 2,
  43. }
  44. # 2. Do distributed data-parallel training.
  45. # Ray Train sets up the necessary coordinator processes and
  46. # environment variables for your workers to communicate with each other.
  47. bst = xgboost.train(
  48. params,
  49. dtrain=dtrain,
  50. evals=[(deval, "validation")],
  51. num_boost_round=1,
  52. callbacks=[RayTrainReportCallback()],
  53. )
  54. train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
  55. eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
  56. trainer = XGBoostTrainer(
  57. train_fn_per_worker,
  58. datasets={"train": train_ds, "validation": eval_ds},
  59. scaling_config=ray.train.ScalingConfig(num_workers=2),
  60. )
  61. result = trainer.fit()
  62. booster = RayTrainReportCallback.get_model(result.checkpoint)
  63. Args:
  64. train_loop_per_worker: The training function to execute on each worker.
  65. This function can either take in zero arguments or a single ``Dict``
  66. argument which is set by defining ``train_loop_config``.
  67. Within this function you can use any of the
  68. :ref:`Ray Train Loop utilities <train-loop-api>`.
  69. train_loop_config: A configuration ``Dict`` to pass in as an argument to
  70. ``train_loop_per_worker``.
  71. This is typically used for specifying hyperparameters.
  72. xgboost_config: The configuration for setting up the distributed xgboost
  73. backend. Defaults to using the "rabit" backend.
  74. See :class:`~ray.train.xgboost.XGBoostConfig` for more info.
  75. scaling_config: The configuration for how to scale data parallel training.
  76. ``num_workers`` determines how many Python processes are used for training,
  77. and ``use_gpu`` determines whether or not each process should use GPUs.
  78. See :class:`~ray.train.ScalingConfig` for more info.
  79. run_config: The configuration for the execution of the training run.
  80. See :class:`~ray.train.RunConfig` for more info.
  81. datasets: The Ray Datasets to ingest for training.
  82. Datasets are keyed by name (``{name: dataset}``).
  83. Each dataset can be accessed from within the ``train_loop_per_worker``
  84. by calling ``ray.train.get_dataset_shard(name)``.
  85. Sharding and additional configuration can be done by
  86. passing in a ``dataset_config``.
  87. dataset_config: The configuration for ingesting the input ``datasets``.
  88. By default, all the Ray Dataset are split equally across workers.
  89. See :class:`~ray.train.DataConfig` for more details.
  90. validation_config: [Alpha] Configuration for checkpoint validation.
  91. If provided and ``ray.train.report`` is called with the ``validation``
  92. argument, Ray Train will validate the reported checkpoint using
  93. the validation function specified in this config.
  94. resume_from_checkpoint: [Deprecated]
  95. metadata: [Deprecated]
  96. """
  97. def __init__(
  98. self,
  99. train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
  100. *,
  101. train_loop_config: Optional[Dict] = None,
  102. xgboost_config: Optional["XGBoostConfig"] = None,
  103. scaling_config: Optional[ScalingConfig] = None,
  104. run_config: Optional[RunConfig] = None,
  105. datasets: Optional[Dict[str, GenDataset]] = None,
  106. dataset_config: Optional[ray.train.DataConfig] = None,
  107. validation_config: Optional[ValidationConfig] = None,
  108. # TODO: [Deprecated]
  109. metadata: Optional[Dict[str, Any]] = None,
  110. resume_from_checkpoint: Optional[Checkpoint] = None,
  111. # TODO(justinvyu): [Deprecated] Legacy XGBoostTrainer API
  112. label_column: Optional[str] = None,
  113. params: Optional[Dict[str, Any]] = None,
  114. num_boost_round: Optional[int] = None,
  115. ):
  116. if (
  117. label_column is not None
  118. or params is not None
  119. or num_boost_round is not None
  120. ):
  121. raise DeprecationWarning(
  122. "The legacy XGBoostTrainer API is deprecated. "
  123. "Please switch to passing in a custom `train_loop_per_worker` "
  124. "function instead. "
  125. "See this issue for more context: "
  126. "https://github.com/ray-project/ray/issues/50042"
  127. )
  128. from ray.train.xgboost import XGBoostConfig
  129. super(XGBoostTrainer, self).__init__(
  130. train_loop_per_worker=train_loop_per_worker,
  131. train_loop_config=train_loop_config,
  132. backend_config=xgboost_config or XGBoostConfig(),
  133. scaling_config=scaling_config,
  134. dataset_config=dataset_config,
  135. run_config=run_config,
  136. datasets=datasets,
  137. resume_from_checkpoint=resume_from_checkpoint,
  138. metadata=metadata,
  139. validation_config=validation_config,
  140. )
  141. @classmethod
  142. @Deprecated
  143. def get_model(cls, checkpoint: Checkpoint):
  144. """[Deprecated] Retrieve the XGBoost model stored in this checkpoint."""
  145. raise DeprecationWarning(
  146. "`XGBoostTrainer.get_model` is deprecated. "
  147. "Use `RayTrainReportCallback.get_model` instead."
  148. )