lightgbm_trainer.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. import logging
  2. from functools import partial
  3. from typing import Any, Callable, Dict, Optional, Union
  4. import lightgbm
  5. import ray
  6. from ray.train import Checkpoint
  7. from ray.train.constants import TRAIN_DATASET_KEY
  8. from ray.train.lightgbm._lightgbm_utils import RayTrainReportCallback
  9. from ray.train.lightgbm.config import LightGBMConfig
  10. from ray.train.lightgbm.v2 import LightGBMTrainer as SimpleLightGBMTrainer
  11. from ray.train.trainer import GenDataset
  12. from ray.train.utils import _log_deprecation_warning
  13. from ray.util.annotations import PublicAPI
  14. logger = logging.getLogger(__name__)
  15. LEGACY_LIGHTGBM_TRAINER_DEPRECATION_MESSAGE = (
  16. "Passing in `lightgbm.train` kwargs such as `params`, `num_boost_round`, "
  17. "`label_column`, etc. to `LightGBMTrainer` is deprecated "
  18. "in favor of the new API which accepts a `train_loop_per_worker` argument, "
  19. "similar to the other DataParallelTrainer APIs (ex: TorchTrainer). "
  20. "See this issue for more context: "
  21. "https://github.com/ray-project/ray/issues/50042"
  22. )
  23. def _lightgbm_train_fn_per_worker(
  24. config: dict,
  25. label_column: str,
  26. num_boost_round: int,
  27. dataset_keys: set,
  28. lightgbm_train_kwargs: dict,
  29. ):
  30. checkpoint = ray.train.get_checkpoint()
  31. starting_model = None
  32. remaining_iters = num_boost_round
  33. if checkpoint:
  34. starting_model = RayTrainReportCallback.get_model(checkpoint)
  35. starting_iter = starting_model.current_iteration()
  36. remaining_iters = num_boost_round - starting_iter
  37. logger.info(
  38. f"Model loaded from checkpoint will train for "
  39. f"additional {remaining_iters} iterations (trees) in order "
  40. "to achieve the target number of iterations "
  41. f"({num_boost_round=})."
  42. )
  43. train_ds_iter = ray.train.get_dataset_shard(TRAIN_DATASET_KEY)
  44. train_df = train_ds_iter.materialize().to_pandas()
  45. eval_ds_iters = {
  46. k: ray.train.get_dataset_shard(k)
  47. for k in dataset_keys
  48. if k != TRAIN_DATASET_KEY
  49. }
  50. eval_dfs = {k: d.materialize().to_pandas() for k, d in eval_ds_iters.items()}
  51. train_X, train_y = train_df.drop(label_column, axis=1), train_df[label_column]
  52. train_set = lightgbm.Dataset(train_X, label=train_y)
  53. # NOTE: Include the training dataset in the evaluation datasets.
  54. # This allows `train-*` metrics to be calculated and reported.
  55. valid_sets = [train_set]
  56. valid_names = [TRAIN_DATASET_KEY]
  57. for eval_name, eval_df in eval_dfs.items():
  58. eval_X, eval_y = eval_df.drop(label_column, axis=1), eval_df[label_column]
  59. valid_sets.append(lightgbm.Dataset(eval_X, label=eval_y))
  60. valid_names.append(eval_name)
  61. # Add network params of the worker group to enable distributed training.
  62. config.update(ray.train.lightgbm.get_network_params())
  63. config.setdefault("tree_learner", "data_parallel")
  64. config.setdefault("pre_partition", True)
  65. lightgbm.train(
  66. params=config,
  67. train_set=train_set,
  68. num_boost_round=remaining_iters,
  69. valid_sets=valid_sets,
  70. valid_names=valid_names,
  71. init_model=starting_model,
  72. **lightgbm_train_kwargs,
  73. )
  74. @PublicAPI(stability="beta")
  75. class LightGBMTrainer(SimpleLightGBMTrainer):
  76. """A Trainer for distributed data-parallel LightGBM training.
  77. Example
  78. -------
  79. .. testcode::
  80. :skipif: True
  81. import lightgbm
  82. import ray.data
  83. import ray.train
  84. from ray.train.lightgbm import RayTrainReportCallback, LightGBMTrainer
  85. def train_fn_per_worker(config: dict):
  86. # (Optional) Add logic to resume training state from a checkpoint.
  87. # ray.train.get_checkpoint()
  88. # 1. Get the dataset shard for the worker and convert to a `lightgbm.Dataset`
  89. train_ds_iter, eval_ds_iter = (
  90. ray.train.get_dataset_shard("train"),
  91. ray.train.get_dataset_shard("validation"),
  92. )
  93. train_ds, eval_ds = train_ds_iter.materialize(), eval_ds_iter.materialize()
  94. train_df, eval_df = train_ds.to_pandas(), eval_ds.to_pandas()
  95. train_X, train_y = train_df.drop("y", axis=1), train_df["y"]
  96. eval_X, eval_y = eval_df.drop("y", axis=1), eval_df["y"]
  97. dtrain = lightgbm.Dataset(train_X, label=train_y)
  98. deval = lightgbm.Dataset(eval_X, label=eval_y)
  99. params = {
  100. "objective": "regression",
  101. "metric": "l2",
  102. "learning_rate": 1e-4,
  103. "subsample": 0.5,
  104. "max_depth": 2,
  105. # Adding the line below is the only change needed
  106. # for your `lgb.train` call!
  107. **ray.train.lightgbm.get_network_params(),
  108. }
  109. # 2. Do distributed data-parallel training.
  110. # Ray Train sets up the necessary coordinator processes and
  111. # environment variables for your workers to communicate with each other.
  112. bst = lightgbm.train(
  113. params,
  114. train_set=dtrain,
  115. valid_sets=[deval],
  116. valid_names=["validation"],
  117. num_boost_round=10,
  118. callbacks=[RayTrainReportCallback()],
  119. )
  120. train_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(32)])
  121. eval_ds = ray.data.from_items([{"x": x, "y": x + 1} for x in range(16)])
  122. trainer = LightGBMTrainer(
  123. train_fn_per_worker,
  124. datasets={"train": train_ds, "validation": eval_ds},
  125. scaling_config=ray.train.ScalingConfig(num_workers=4),
  126. )
  127. result = trainer.fit()
  128. booster = RayTrainReportCallback.get_model(result.checkpoint)
  129. Args:
  130. train_loop_per_worker: The training function to execute on each worker.
  131. This function can either take in zero arguments or a single ``Dict``
  132. argument which is set by defining ``train_loop_config``.
  133. Within this function you can use any of the
  134. :ref:`Ray Train Loop utilities <train-loop-api>`.
  135. train_loop_config: A configuration ``Dict`` to pass in as an argument to
  136. ``train_loop_per_worker``.
  137. This is typically used for specifying hyperparameters.
  138. lightgbm_config: The configuration for setting up the distributed lightgbm
  139. backend. Defaults to using the "rabit" backend.
  140. See :class:`~ray.train.lightgbm.LightGBMConfig` for more info.
  141. datasets: The Ray Datasets to use for training and validation.
  142. dataset_config: The configuration for ingesting the input ``datasets``.
  143. By default, all the Ray Datasets are split equally across workers.
  144. See :class:`~ray.train.DataConfig` for more details.
  145. scaling_config: The configuration for how to scale data parallel training.
  146. ``num_workers`` determines how many Python processes are used for training,
  147. and ``use_gpu`` determines whether or not each process should use GPUs.
  148. See :class:`~ray.train.ScalingConfig` for more info.
  149. run_config: The configuration for the execution of the training run.
  150. See :class:`~ray.train.RunConfig` for more info.
  151. resume_from_checkpoint: A checkpoint to resume training from.
  152. This checkpoint can be accessed from within ``train_loop_per_worker``
  153. by calling ``ray.train.get_checkpoint()``.
  154. metadata: Dict that should be made available via
  155. `ray.train.get_context().get_metadata()` and in `checkpoint.get_metadata()`
  156. for checkpoints saved from this Trainer. Must be JSON-serializable.
  157. label_column: [Deprecated] Name of the label column. A column with this name
  158. must be present in the training dataset.
  159. params: [Deprecated] LightGBM training parameters.
  160. Refer to `LightGBM documentation <https://lightgbm.readthedocs.io/>`_
  161. for a list of possible parameters.
  162. num_boost_round: [Deprecated] Target number of boosting iterations (trees in the model).
  163. Note that unlike in ``lightgbm.train``, this is the target number
  164. of trees, meaning that if you set ``num_boost_round=10`` and pass a model
  165. that has already been trained for 5 iterations, it will be trained for 5
  166. iterations more, instead of 10 more.
  167. **train_kwargs: [Deprecated] Additional kwargs passed to ``lightgbm.train()`` function.
  168. """
  169. _handles_checkpoint_freq = True
  170. _handles_checkpoint_at_end = True
  171. def __init__(
  172. self,
  173. train_loop_per_worker: Optional[
  174. Union[Callable[[], None], Callable[[Dict], None]]
  175. ] = None,
  176. *,
  177. train_loop_config: Optional[Dict] = None,
  178. lightgbm_config: Optional[LightGBMConfig] = None,
  179. scaling_config: Optional[ray.train.ScalingConfig] = None,
  180. run_config: Optional[ray.train.RunConfig] = None,
  181. datasets: Optional[Dict[str, GenDataset]] = None,
  182. dataset_config: Optional[ray.train.DataConfig] = None,
  183. resume_from_checkpoint: Optional[Checkpoint] = None,
  184. metadata: Optional[Dict[str, Any]] = None,
  185. # TODO: [Deprecated] Legacy LightGBMTrainer API
  186. label_column: Optional[str] = None,
  187. params: Optional[Dict[str, Any]] = None,
  188. num_boost_round: Optional[int] = None,
  189. **train_kwargs,
  190. ):
  191. # TODO: [Deprecated] Legacy LightGBMTrainer API
  192. legacy_api = train_loop_per_worker is None
  193. if legacy_api:
  194. train_loop_per_worker = self._get_legacy_train_fn_per_worker(
  195. lightgbm_train_kwargs=train_kwargs,
  196. run_config=run_config,
  197. label_column=label_column,
  198. num_boost_round=num_boost_round,
  199. datasets=datasets,
  200. )
  201. train_loop_config = params or {}
  202. elif train_kwargs:
  203. _log_deprecation_warning(
  204. "Passing `lightgbm.train` kwargs to `LightGBMTrainer` is deprecated. "
  205. f"Got kwargs: {train_kwargs.keys()}\n"
  206. "In your training function, you can call `lightgbm.train(**kwargs)` "
  207. "with arbitrary arguments. "
  208. f"{LEGACY_LIGHTGBM_TRAINER_DEPRECATION_MESSAGE}"
  209. )
  210. super(LightGBMTrainer, self).__init__(
  211. train_loop_per_worker=train_loop_per_worker,
  212. train_loop_config=train_loop_config,
  213. lightgbm_config=lightgbm_config,
  214. scaling_config=scaling_config,
  215. run_config=run_config,
  216. datasets=datasets,
  217. dataset_config=dataset_config,
  218. resume_from_checkpoint=resume_from_checkpoint,
  219. metadata=metadata,
  220. )
  221. def _get_legacy_train_fn_per_worker(
  222. self,
  223. lightgbm_train_kwargs: Dict,
  224. run_config: Optional[ray.train.RunConfig],
  225. datasets: Optional[Dict[str, GenDataset]],
  226. label_column: Optional[str],
  227. num_boost_round: Optional[int],
  228. ) -> Callable[[Dict], None]:
  229. """Get the training function for the legacy LightGBMTrainer API."""
  230. datasets = datasets or {}
  231. if not datasets.get(TRAIN_DATASET_KEY):
  232. raise ValueError(
  233. "`datasets` must be provided for the LightGBMTrainer API "
  234. "if `train_loop_per_worker` is not provided. "
  235. "This dict must contain the training dataset under the "
  236. f"key: '{TRAIN_DATASET_KEY}'. "
  237. f"Got keys: {list(datasets.keys())}"
  238. )
  239. if not label_column:
  240. raise ValueError(
  241. "`label_column` must be provided for the LightGBMTrainer API "
  242. "if `train_loop_per_worker` is not provided. "
  243. "This is the column name of the label in the dataset."
  244. )
  245. num_boost_round = num_boost_round or 10
  246. _log_deprecation_warning(LEGACY_LIGHTGBM_TRAINER_DEPRECATION_MESSAGE)
  247. # Initialize a default Ray Train metrics/checkpoint reporting callback if needed
  248. callbacks = lightgbm_train_kwargs.get("callbacks", [])
  249. user_supplied_callback = any(
  250. isinstance(callback, RayTrainReportCallback) for callback in callbacks
  251. )
  252. callback_kwargs = {}
  253. if run_config:
  254. checkpoint_frequency = run_config.checkpoint_config.checkpoint_frequency
  255. checkpoint_at_end = run_config.checkpoint_config.checkpoint_at_end
  256. callback_kwargs["frequency"] = checkpoint_frequency
  257. # Default `checkpoint_at_end=True` unless the user explicitly sets it.
  258. callback_kwargs["checkpoint_at_end"] = (
  259. checkpoint_at_end if checkpoint_at_end is not None else True
  260. )
  261. if not user_supplied_callback:
  262. callbacks.append(RayTrainReportCallback(**callback_kwargs))
  263. lightgbm_train_kwargs["callbacks"] = callbacks
  264. train_fn_per_worker = partial(
  265. _lightgbm_train_fn_per_worker,
  266. label_column=label_column,
  267. num_boost_round=num_boost_round,
  268. dataset_keys=set(datasets),
  269. lightgbm_train_kwargs=lightgbm_train_kwargs,
  270. )
  271. return train_fn_per_worker
  272. @classmethod
  273. def get_model(
  274. cls,
  275. checkpoint: Checkpoint,
  276. ) -> lightgbm.Booster:
  277. """Retrieve the LightGBM model stored in this checkpoint."""
  278. return RayTrainReportCallback.get_model(checkpoint)