v2.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
import logging
from typing import Any, Callable, Dict, Optional, Union

import ray.train
from ray.train import Checkpoint
from ray.train.data_parallel_trainer import DataParallelTrainer
# NOTE(review): `get_network_params` is not referenced by this module's code;
# the `noqa: F401` marks the unused import as intentional (presumably so it
# can be imported from here) — confirm before removing.
from ray.train.lightgbm.config import LightGBMConfig, get_network_params # noqa: F401
from ray.train.trainer import GenDataset

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  9. class LightGBMTrainer(DataParallelTrainer):
  10. """A Trainer for distributed data-parallel LightGBM training.
  11. Example
  12. -------
  13. .. testcode::
  14. :skipif: True
  15. import lightgbm as lgb
  16. import ray.data
  17. import ray.train
  18. from ray.train.lightgbm import RayTrainReportCallback, LightGBMTrainer
  19. def train_fn_per_worker(config: dict):
  20. # (Optional) Add logic to resume training state from a checkpoint.
  21. # ray.train.get_checkpoint()
  22. # 1. Get the dataset shard for the worker and convert to a `lgb.Dataset`
  23. train_ds_iter, eval_ds_iter = (
  24. ray.train.get_dataset_shard("train"),
  25. ray.train.get_dataset_shard("validation"),
  26. )
  27. train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
  28. train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
  29. train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
  30. eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
  31. train_set = lgb.Dataset(train_X, label=train_y)
  32. eval_set = lgb.Dataset(eval_X, label=eval_y)
  33. # 2. Run distributed data-parallel training.
  34. # `get_network_params` sets up the necessary configurations for LightGBM
  35. # to set up the data parallel training worker group on your Ray cluster.
  36. params = {
  37. "objective": "regression",
  38. # Adding the line below is the only change needed
  39. # for your `lgb.train` call!
  40. **ray.train.lightgbm.get_network_params(),
  41. }
  42. lgb.train(
  43. params,
  44. train_set,
  45. valid_sets=[eval_set],
  46. valid_names=["eval"],
  47. callbacks=[RayTrainReportCallback()],
  48. )
  49. train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
  50. eval_ds = ray.data.from_items(
  51. [{"x": x, "y": x + 1} for x in range(32, 32 + 16)]
  52. )
  53. trainer = LightGBMTrainer(
  54. train_fn_per_worker,
  55. datasets={"train": train_ds, "validation": eval_ds},
  56. scaling_config=ray.train.ScalingConfig(num_workers=4),
  57. )
  58. result = trainer.fit()
  59. booster = RayTrainReportCallback.get_model(result.checkpoint)
  60. Args:
  61. train_loop_per_worker: The training function to execute on each worker.
  62. This function can either take in zero arguments or a single ``Dict``
  63. argument which is set by defining ``train_loop_config``.
  64. Within this function you can use any of the
  65. :ref:`Ray Train Loop utilities <train-loop-api>`.
  66. train_loop_config: A configuration ``Dict`` to pass in as an argument to
  67. ``train_loop_per_worker``.
  68. This is typically used for specifying hyperparameters.
  69. lightgbm_config: The configuration for setting up the distributed lightgbm
  70. backend. See :class:`~ray.train.lightgbm.LightGBMConfig` for more info.
  71. datasets: The Ray Datasets to use for training and validation.
  72. dataset_config: The configuration for ingesting the input ``datasets``.
  73. By default, all the Ray Dataset are split equally across workers.
  74. See :class:`~ray.train.DataConfig` for more details.
  75. scaling_config: The configuration for how to scale data parallel training.
  76. ``num_workers`` determines how many Python processes are used for training,
  77. and ``use_gpu`` determines whether or not each process should use GPUs.
  78. See :class:`~ray.train.ScalingConfig` for more info.
  79. run_config: The configuration for the execution of the training run.
  80. See :class:`~ray.train.RunConfig` for more info.
  81. resume_from_checkpoint: A checkpoint to resume training from.
  82. This checkpoint can be accessed from within ``train_loop_per_worker``
  83. by calling ``ray.train.get_checkpoint()``.
  84. metadata: Dict that should be made available via
  85. `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
  86. for checkpoints saved from this Trainer. Must be JSON-serializable.
  87. """
  88. def __init__(
  89. self,
  90. train_loop_per_worker: Union[Callable[[], None], Callable[[Dict], None]],
  91. *,
  92. train_loop_config: Optional[Dict] = None,
  93. lightgbm_config: Optional[LightGBMConfig] = None,
  94. scaling_config: Optional[ray.train.ScalingConfig] = None,
  95. run_config: Optional[ray.train.RunConfig] = None,
  96. datasets: Optional[Dict[str, GenDataset]] = None,
  97. dataset_config: Optional[ray.train.DataConfig] = None,
  98. metadata: Optional[Dict[str, Any]] = None,
  99. resume_from_checkpoint: Optional[Checkpoint] = None,
  100. ):
  101. super(LightGBMTrainer, self).__init__(
  102. train_loop_per_worker=train_loop_per_worker,
  103. train_loop_config=train_loop_config,
  104. backend_config=lightgbm_config or LightGBMConfig(),
  105. scaling_config=scaling_config,
  106. dataset_config=dataset_config,
  107. run_config=run_config,
  108. datasets=datasets,
  109. resume_from_checkpoint=resume_from_checkpoint,
  110. metadata=metadata,
  111. )