xgboost.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  import tempfile
  from contextlib import contextmanager
  from pathlib import Path
  from typing import Callable, Dict, Iterator, List, Optional, Union

  from xgboost.core import Booster

  import ray.tune
  from ray.train.xgboost._xgboost_utils import RayReportCallback
  from ray.tune import Checkpoint
  from ray.util.annotations import Deprecated, PublicAPI
  @PublicAPI(stability="beta")
  class TuneReportCheckpointCallback(RayReportCallback):
      """XGBoost callback to save checkpoints and report metrics for Ray Tune.

      Args:
          metrics: Metrics to report. If this is a list,
              each item describes the metric key reported to XGBoost,
              and it will be reported under the same name.
              This can also be a dict of {<key-to-report>: <xgboost-metric-key>},
              which can be used to rename xgboost default metrics.
              A single string naming one metric key is also accepted.
          filename: Customize the saved checkpoint file type by passing
              a filename. Defaults to "model.ubj".
          frequency: How often to save checkpoints, in terms of iterations.
              Defaults to 0 (no checkpoints are saved during training).
          checkpoint_at_end: Whether or not to save a checkpoint at the end of
              training. Defaults to True.
          results_postprocessing_fn: An optional Callable that takes in
              the metrics dict that will be reported (after it has been flattened)
              and returns a modified dict. For example, this can be used to
              average results across CV folds when using ``xgboost.cv``.

      Examples
      --------

      Reporting checkpoints and metrics to Ray Tune when running many
      independent xgboost trials (without data parallelism within a trial).

      .. testcode::
          :skipif: True

          import xgboost
          from ray.tune import Tuner
          from ray.tune.integration.xgboost import TuneReportCheckpointCallback

          def train_fn(config):
              # Report "eval-logloss" to Ray Tune under the name "loss", and
              # save a checkpoint every iteration (frequency=1).
              bst = xgboost.train(
                  ...,
                  callbacks=[
                      TuneReportCheckpointCallback(
                          metrics={"loss": "eval-logloss"}, frequency=1
                      )
                  ],
              )

          tuner = Tuner(train_fn)
          results = tuner.fit()
      """

      def __init__(
          self,
          metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
          filename: str = RayReportCallback.CHECKPOINT_NAME,
          frequency: int = 0,
          checkpoint_at_end: bool = True,
          results_postprocessing_fn: Optional[
              Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
          ] = None,
      ):
          # All configuration is stored by the shared base class; this class
          # only supplies the Ray Tune specific checkpoint/report hooks below.
          super().__init__(
              metrics=metrics,
              filename=filename,
              frequency=frequency,
              checkpoint_at_end=checkpoint_at_end,
              results_postprocessing_fn=results_postprocessing_fn,
          )

      @contextmanager
      def _get_checkpoint(self, model: Booster) -> Iterator[Checkpoint]:
          """Yield a ``Checkpoint`` containing ``model`` saved to a temp directory.

          The model is written to ``<temp dir>/<self._filename>``. The temporary
          directory is deleted when the ``with`` block exits, so the yielded
          ``Checkpoint`` must only be used inside that block.
          """
          # NOTE(review): ``self._filename`` is not assigned in this class;
          # presumably RayReportCallback stores the ``filename`` argument there.
          with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
              model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
              yield Checkpoint(temp_checkpoint_dir)

      def _save_and_report_checkpoint(
          self, report_dict: Dict, model: Booster
      ) -> None:
          """Report ``report_dict`` to Ray Tune together with a checkpoint of ``model``."""
          # The report must happen while the temporary checkpoint directory
          # still exists, i.e. inside the context manager.
          with self._get_checkpoint(model=model) as checkpoint:
              ray.tune.report(report_dict, checkpoint=checkpoint)

      def _report_metrics(self, report_dict: Dict) -> None:
          """Report ``report_dict`` to Ray Tune without attaching a checkpoint."""
          ray.tune.report(report_dict)
  @Deprecated
  class TuneReportCallback:
      """Removed XGBoost/Tune callback kept only to redirect users.

      Constructing an instance always raises ``DeprecationWarning`` pointing to
      ``TuneReportCheckpointCallback``; no object is ever created.
      """

      def __new__(cls: type, *args, **kwargs):
          # TODO(justinvyu): [code_removal] Remove in 2.11.
          # Every construction attempt fails, whatever arguments are passed.
          deprecation_message = (
              "`TuneReportCallback` is deprecated. "
              "Use `ray.tune.integration.xgboost.TuneReportCheckpointCallback` instead."
          )
          raise DeprecationWarning(deprecation_message)