lightgbm.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, Optional

from lightgbm import Booster

import ray.tune
from ray.train.lightgbm._lightgbm_utils import RayReportCallback
from ray.tune import Checkpoint
from ray.util.annotations import Deprecated, PublicAPI
  10. @PublicAPI(stability="beta")
  11. class TuneReportCheckpointCallback(RayReportCallback):
  12. """Creates a callback that reports metrics and checkpoints model.
  13. Args:
  14. metrics: Metrics to report. If this is a list,
  15. each item should be a metric key reported by LightGBM,
  16. and it will be reported to Ray Train/Tune under the same name.
  17. This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
  18. which can be used to rename LightGBM default metrics.
  19. filename: Customize the saved checkpoint file type by passing
  20. a filename. Defaults to "model.txt".
  21. frequency: How often to save checkpoints, in terms of iterations.
  22. Defaults to 0 (no checkpoints are saved during training).
  23. checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
  24. results_postprocessing_fn: An optional Callable that takes in
  25. the metrics dict that will be reported (after it has been flattened)
  26. and returns a modified dict.
  27. Examples
  28. --------
  29. Reporting checkpoints and metrics to Ray Tune when running many
  30. independent LightGBM trials (without data parallelism within a trial).
  31. .. testcode::
  32. :skipif: True
  33. import lightgbm
  34. from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
  35. config = {
  36. # ...
  37. "metric": ["binary_logloss", "binary_error"],
  38. }
  39. # Report only log loss to Tune after each validation epoch.
  40. bst = lightgbm.train(
  41. ...,
  42. callbacks=[
  43. TuneReportCheckpointCallback(
  44. metrics={"loss": "eval-binary_logloss"}, frequency=1
  45. )
  46. ],
  47. )
  48. """
  49. @contextmanager
  50. def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
  51. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  52. model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
  53. yield Checkpoint.from_directory(temp_checkpoint_dir)
  54. def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
  55. with self._get_checkpoint(model=model) as checkpoint:
  56. ray.tune.report(report_dict, checkpoint=checkpoint)
  57. def _report_metrics(self, report_dict: Dict):
  58. ray.tune.report(report_dict)
  59. @Deprecated
  60. class TuneReportCallback:
  61. def __new__(cls: type, *args, **kwargs):
  62. # TODO(justinvyu): [code_removal] Remove in 2.11.
  63. raise DeprecationWarning(
  64. "`TuneReportCallback` is deprecated. "
  65. "Use `ray.tune.integration.lightgbm.TuneReportCheckpointCallback` instead."
  66. )