_lightgbm_utils.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. import tempfile
  2. from abc import abstractmethod
  3. from contextlib import contextmanager
  4. from pathlib import Path
  5. from typing import Callable, Dict, List, Optional, Union
  6. from lightgbm.basic import Booster
  7. from lightgbm.callback import CallbackEnv
  8. import ray.train
  9. from ray.train import Checkpoint
  10. from ray.tune.utils import flatten_dict
  11. from ray.util.annotations import PublicAPI
  12. class RayReportCallback:
  13. CHECKPOINT_NAME = "model.txt"
  14. def __init__(
  15. self,
  16. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  17. filename: str = CHECKPOINT_NAME,
  18. frequency: int = 0,
  19. checkpoint_at_end: bool = True,
  20. results_postprocessing_fn: Optional[
  21. Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
  22. ] = None,
  23. ):
  24. if isinstance(metrics, str):
  25. metrics = [metrics]
  26. self._metrics = metrics
  27. self._filename = filename
  28. self._frequency = frequency
  29. self._checkpoint_at_end = checkpoint_at_end
  30. self._results_postprocessing_fn = results_postprocessing_fn
  31. @classmethod
  32. def get_model(
  33. cls, checkpoint: Checkpoint, filename: str = CHECKPOINT_NAME
  34. ) -> Booster:
  35. """Retrieve the model stored in a checkpoint reported by this callback.
  36. Args:
  37. checkpoint: The checkpoint object returned by a training run.
  38. The checkpoint should be saved by an instance of this callback.
  39. filename: The filename to load the model from, which should match
  40. the filename used when creating the callback.
  41. Returns:
  42. The model loaded from the checkpoint.
  43. """
  44. with checkpoint.as_directory() as checkpoint_path:
  45. return Booster(model_file=Path(checkpoint_path, filename).as_posix())
  46. def _get_report_dict(self, evals_log: Dict[str, Dict[str, list]]) -> dict:
  47. result_dict = flatten_dict(evals_log, delimiter="-")
  48. if not self._metrics:
  49. report_dict = result_dict
  50. else:
  51. report_dict = {}
  52. for key in self._metrics:
  53. if isinstance(self._metrics, dict):
  54. metric = self._metrics[key]
  55. else:
  56. metric = key
  57. report_dict[key] = result_dict[metric]
  58. if self._results_postprocessing_fn:
  59. report_dict = self._results_postprocessing_fn(report_dict)
  60. return report_dict
  61. def _get_eval_result(self, env: CallbackEnv) -> dict:
  62. eval_result = {}
  63. for entry in env.evaluation_result_list:
  64. data_name, eval_name, result = entry[0:3]
  65. if len(entry) > 4:
  66. stdv = entry[4]
  67. suffix = "-mean"
  68. else:
  69. stdv = None
  70. suffix = ""
  71. if data_name not in eval_result:
  72. eval_result[data_name] = {}
  73. eval_result[data_name][eval_name + suffix] = result
  74. if stdv is not None:
  75. eval_result[data_name][eval_name + "-stdv"] = stdv
  76. return eval_result
  77. @abstractmethod
  78. def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
  79. """Get checkpoint from model.
  80. This method needs to be implemented by subclasses.
  81. """
  82. raise NotImplementedError
  83. @abstractmethod
  84. def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
  85. """Save checkpoint and report metrics corresonding to this checkpoint.
  86. This method needs to be implemented by subclasses.
  87. """
  88. raise NotImplementedError
  89. @abstractmethod
  90. def _report_metrics(self, report_dict: Dict):
  91. """Report Metrics.
  92. This method needs to be implemented by subclasses.
  93. """
  94. raise NotImplementedError
  95. def __call__(self, env: CallbackEnv) -> None:
  96. eval_result = self._get_eval_result(env)
  97. report_dict = self._get_report_dict(eval_result)
  98. # Ex: if frequency=2, checkpoint_at_end=True and num_boost_rounds=11,
  99. # you will checkpoint at iterations 1, 3, 5, ..., 9, and 10 (checkpoint_at_end)
  100. # (iterations count from 0)
  101. on_last_iter = env.iteration == env.end_iteration - 1
  102. should_checkpoint_at_end = on_last_iter and self._checkpoint_at_end
  103. should_checkpoint_with_frequency = (
  104. self._frequency != 0 and (env.iteration + 1) % self._frequency == 0
  105. )
  106. should_checkpoint = should_checkpoint_at_end or should_checkpoint_with_frequency
  107. if should_checkpoint:
  108. self._save_and_report_checkpoint(report_dict, env.model)
  109. else:
  110. self._report_metrics(report_dict)
  111. @PublicAPI(stability="beta")
  112. class RayTrainReportCallback(RayReportCallback):
  113. """Creates a callback that reports metrics and checkpoints model.
  114. Args:
  115. metrics: Metrics to report. If this is a list,
  116. each item should be a metric key reported by LightGBM,
  117. and it will be reported to Ray Train/Tune under the same name.
  118. This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
  119. which can be used to rename LightGBM default metrics.
  120. filename: Customize the saved checkpoint file type by passing
  121. a filename. Defaults to "model.txt".
  122. frequency: How often to save checkpoints, in terms of iterations.
  123. Defaults to 0 (no checkpoints are saved during training).
  124. checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
  125. results_postprocessing_fn: An optional Callable that takes in
  126. the metrics dict that will be reported (after it has been flattened)
  127. and returns a modified dict.
  128. Examples
  129. --------
  130. Reporting checkpoints and metrics to Ray Tune when running many
  131. independent LightGBM trials (without data parallelism within a trial).
  132. .. testcode::
  133. :skipif: True
  134. import lightgbm
  135. from ray.train.lightgbm import RayTrainReportCallback
  136. config = {
  137. # ...
  138. "metric": ["binary_logloss", "binary_error"],
  139. }
  140. # Report only log loss to Tune after each validation epoch.
  141. bst = lightgbm.train(
  142. ...,
  143. callbacks=[
  144. RayTrainReportCallback(
  145. metrics={"loss": "eval-binary_logloss"}, frequency=1
  146. )
  147. ],
  148. )
  149. Loading a model from a checkpoint reported by this callback.
  150. .. testcode::
  151. :skipif: True
  152. from ray.train.lightgbm import RayTrainReportCallback
  153. # Get a `Checkpoint` object that is saved by the callback during training.
  154. result = trainer.fit()
  155. booster = RayTrainReportCallback.get_model(result.checkpoint)
  156. """
  157. @contextmanager
  158. def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
  159. if ray.train.get_context().get_world_rank() in (0, None):
  160. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  161. model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
  162. yield Checkpoint.from_directory(temp_checkpoint_dir)
  163. else:
  164. yield None
  165. def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
  166. with self._get_checkpoint(model=model) as checkpoint:
  167. ray.train.report(report_dict, checkpoint=checkpoint)
  168. def _report_metrics(self, report_dict: Dict):
  169. ray.train.report(report_dict)