| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384 |
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterator, Optional

from lightgbm import Booster

import ray.tune
from ray.train.lightgbm._lightgbm_utils import RayReportCallback
from ray.tune import Checkpoint
from ray.util.annotations import Deprecated, PublicAPI
@PublicAPI(stability="beta")
class TuneReportCheckpointCallback(RayReportCallback):
    """Creates a callback that reports metrics and checkpoints model.

    Args:
        metrics: Metrics to report. If this is a list,
            each item should be a metric key reported by LightGBM,
            and it will be reported to Ray Train/Tune under the same name.
            This can also be a dict of {<key-to-report>: <lightgbm-metric-key>},
            which can be used to rename LightGBM default metrics.
        filename: Customize the saved checkpoint file type by passing
            a filename. Defaults to "model.txt".
        frequency: How often to save checkpoints, in terms of iterations.
            Defaults to 0 (no checkpoints are saved during training).
        checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
        results_postprocessing_fn: An optional Callable that takes in
            the metrics dict that will be reported (after it has been flattened)
            and returns a modified dict.

    Examples
    --------

    Reporting checkpoints and metrics to Ray Tune when running many
    independent LightGBM trials (without data parallelism within a trial).

    .. testcode::
        :skipif: True

        import lightgbm

        from ray.tune.integration.lightgbm import TuneReportCheckpointCallback

        config = {
            # ...
            "metric": ["binary_logloss", "binary_error"],
        }

        # Report only log loss to Tune after each boosting iteration.
        bst = lightgbm.train(
            ...,
            callbacks=[
                TuneReportCheckpointCallback(
                    metrics={"loss": "eval-binary_logloss"}, frequency=1
                )
            ],
        )
    """

    @contextmanager
    def _get_checkpoint(self, model: Booster) -> Iterator[Checkpoint]:
        """Save ``model`` into a temporary directory and yield it as a Checkpoint.

        This is a generator wrapped by ``contextlib.contextmanager``, so it
        yields exactly one ``Checkpoint`` and never ``None``. The checkpoint
        points at a ``tempfile.TemporaryDirectory`` that is removed when the
        caller's ``with`` block exits. Use the checkpoint (e.g. report it)
        inside that block.

        Args:
            model: The LightGBM booster to save.

        Yields:
            A ``Checkpoint`` built from the temporary directory holding the
            model file saved under ``self._filename``.
        """
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
            yield Checkpoint.from_directory(temp_checkpoint_dir)

    def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
        """Report ``report_dict`` to Ray Tune together with a checkpoint of ``model``.

        Args:
            report_dict: Metrics to report.
            model: The LightGBM booster to checkpoint.
        """
        # Report while still inside the ``with`` block: the temporary
        # checkpoint directory is deleted as soon as the block exits.
        with self._get_checkpoint(model=model) as checkpoint:
            ray.tune.report(report_dict, checkpoint=checkpoint)

    def _report_metrics(self, report_dict: Dict):
        """Report ``report_dict`` to Ray Tune without attaching a checkpoint.

        Args:
            report_dict: Metrics to report.
        """
        ray.tune.report(report_dict)
@Deprecated
class TuneReportCallback:
    """Removed callback, kept only to fail loudly with a migration hint.

    Any attempt to construct this class raises ``DeprecationWarning``.
    The message points users to ``TuneReportCheckpointCallback``.
    """

    def __new__(cls: type, *args, **kwargs):
        # TODO(justinvyu): [code_removal] Remove in 2.11.
        migration_hint = (
            "`TuneReportCallback` is deprecated. "
            "Use `ray.tune.integration.lightgbm.TuneReportCheckpointCallback` instead."
        )
        raise DeprecationWarning(migration_hint)
|