lightgbm_trainer.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import logging
  2. from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
  3. import ray.train
  4. from ray.train import Checkpoint
  5. from ray.train.trainer import GenDataset
  6. from ray.train.v2.api.config import RunConfig, ScalingConfig
  7. from ray.train.v2.api.data_parallel_trainer import DataParallelTrainer
  8. from ray.train.v2.api.validation_config import ValidationConfig
  9. from ray.util.annotations import Deprecated
  10. if TYPE_CHECKING:
  11. from ray.train.lightgbm import LightGBMConfig
  12. logger = logging.getLogger(__name__)
  13. class LightGBMTrainer(DataParallelTrainer):
  14. """A Trainer for distributed data-parallel LightGBM training.
  15. Example
  16. -------
  17. .. testcode::
  18. :skipif: True
  19. import lightgbm as lgb
  20. import ray.data
  21. import ray.train
  22. from ray.train.lightgbm import RayTrainReportCallback
  23. from ray.train.lightgbm import LightGBMTrainer
  24. def train_fn_per_worker(config: dict):
  25. # (Optional) Add logic to resume training state from a checkpoint.
  26. # ray.train.get_checkpoint()
  27. # 1. Get the dataset shard for the worker and convert to a `lgb.Dataset`
  28. train_ds_iter, eval_ds_iter = (
  29. ray.train.get_dataset_shard("train"),
  30. ray.train.get_dataset_shard("validation"),
  31. )
  32. train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
  33. train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
  34. train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
  35. eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
  36. train_set = lgb.Dataset(train_X, label=train_y)
  37. eval_set = lgb.Dataset(eval_X, label=eval_y)
  38. # 2. Run distributed data-parallel training.
  39. # `get_network_params` sets up the necessary configurations for LightGBM
  40. # to set up the data parallel training worker group on your Ray cluster.
  41. params = {
  42. "objective": "regression",
  43. # Adding the lines below are the only changes needed
  44. # for your `lgb.train` call!
  45. "tree_learner": "data_parallel",
  46. "pre_partition": True,
  47. **ray.train.lightgbm.get_network_params(),
  48. }
  49. lgb.train(
  50. params,
  51. train_set,
  52. valid_sets=[eval_set],
  53. valid_names=["eval"],
  54. num_boost_round=1,
  55. # To access the checkpoint from trainer, you need this callback.
  56. callbacks=[RayTrainReportCallback()],
  57. )
  58. train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
  59. eval_ds = ray.data.from_items(
  60. [{"x": x, "y": x + 1} for x in range(32, 32 + 16)]
  61. )
  62. trainer = LightGBMTrainer(
  63. train_fn_per_worker,
  64. datasets={"train": train_ds, "validation": eval_ds},
  65. scaling_config=ray.train.ScalingConfig(num_workers=2),
  66. )
  67. result = trainer.fit()
  68. booster = RayTrainReportCallback.get_model(result.checkpoint)
  69. Args:
  70. train_loop_per_worker: The training function to execute on each worker.
  71. This function can either take in zero arguments or a single ``Dict``
  72. argument which is set by defining ``train_loop_config``.
  73. Within this function you can use any of the
  74. :ref:`Ray Train Loop utilities <train-loop-api>`.
  75. train_loop_config: A configuration ``Dict`` to pass in as an argument to
  76. ``train_loop_per_worker``.
  77. This is typically used for specifying hyperparameters.
  78. lightgbm_config: The configuration for setting up the distributed lightgbm
  79. backend. See :class:`~ray.train.lightgbm.LightGBMConfig` for more info.
  80. scaling_config: The configuration for how to scale data parallel training.
  81. ``num_workers`` determines how many Python processes are used for training,
  82. and ``use_gpu`` determines whether or not each process should use GPUs.
  83. See :class:`~ray.train.ScalingConfig` for more info.
  84. run_config: The configuration for the execution of the training run.
  85. See :class:`~ray.train.RunConfig` for more info.
  86. datasets: The Ray Datasets to ingest for training.
  87. Datasets are keyed by name (``{name: dataset}``).
  88. Each dataset can be accessed from within the ``train_loop_per_worker``
  89. by calling ``ray.train.get_dataset_shard(name)``.
  90. Sharding and additional configuration can be done by
  91. passing in a ``dataset_config``.
  92. dataset_config: The configuration for ingesting the input ``datasets``.
  93. By default, all the Ray Dataset are split equally across workers.
  94. See :class:`~ray.train.DataConfig` for more details.
  95. validation_config: [Alpha] Configuration for checkpoint validation.
  96. If provided and ``ray.train.report`` is called with the ``validation``
  97. argument, Ray Train will validate the reported checkpoint using
  98. the validation function specified in this config.
  99. resume_from_checkpoint: [Deprecated]
  100. metadata: [Deprecated]
  101. """
  102. def __init__(
  103. self,
  104. train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
  105. *,
  106. train_loop_config: Optional[Dict] = None,
  107. lightgbm_config: Optional["LightGBMConfig"] = None,
  108. scaling_config: Optional[ScalingConfig] = None,
  109. run_config: Optional[RunConfig] = None,
  110. datasets: Optional[Dict[str, GenDataset]] = None,
  111. dataset_config: Optional[ray.train.DataConfig] = None,
  112. validation_config: Optional[ValidationConfig] = None,
  113. # TODO: [Deprecated]
  114. metadata: Optional[Dict[str, Any]] = None,
  115. resume_from_checkpoint: Optional[Checkpoint] = None,
  116. # TODO: [Deprecated] Legacy LightGBMTrainer API
  117. label_column: Optional[str] = None,
  118. params: Optional[Dict[str, Any]] = None,
  119. num_boost_round: Optional[int] = None,
  120. ):
  121. if (
  122. label_column is not None
  123. or params is not None
  124. or num_boost_round is not None
  125. ):
  126. raise DeprecationWarning(
  127. "The legacy LightGBMTrainer API is deprecated. "
  128. "Please switch to passing in a custom `train_loop_per_worker` "
  129. "function instead. "
  130. "See this issue for more context: "
  131. "https://github.com/ray-project/ray/issues/50042"
  132. )
  133. from ray.train.lightgbm import LightGBMConfig
  134. super(LightGBMTrainer, self).__init__(
  135. train_loop_per_worker=train_loop_per_worker,
  136. train_loop_config=train_loop_config,
  137. backend_config=lightgbm_config or LightGBMConfig(),
  138. scaling_config=scaling_config,
  139. dataset_config=dataset_config,
  140. run_config=run_config,
  141. datasets=datasets,
  142. resume_from_checkpoint=resume_from_checkpoint,
  143. metadata=metadata,
  144. validation_config=validation_config,
  145. )
  146. @classmethod
  147. @Deprecated
  148. def get_model(
  149. cls,
  150. checkpoint: Checkpoint,
  151. ):
  152. """Retrieve the LightGBM model stored in this checkpoint.
  153. This API is deprecated. Use `RayTrainReportCallback.get_model` instead.
  154. """
  155. raise DeprecationWarning(
  156. "`LightGBMTrainer.get_model` is deprecated. "
  157. "Use `RayTrainReportCallback.get_model` instead."
  158. )