_xgboost_utils.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
import tempfile
from abc import abstractmethod
from collections import OrderedDict
from contextlib import contextmanager
from pathlib import Path
from typing import Callable, ContextManager, Dict, Iterator, List, Optional, Union

from xgboost.core import Booster

import ray.train
from ray.train import Checkpoint
from ray.tune.utils import flatten_dict
from ray.util.annotations import PublicAPI
  12. try:
  13. from xgboost.callback import TrainingCallback
  14. except ImportError:
  15. class TrainingCallback:
  16. pass
  17. class RayReportCallback(TrainingCallback):
  18. CHECKPOINT_NAME = "model.ubj"
  19. def __init__(
  20. self,
  21. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  22. filename: str = CHECKPOINT_NAME,
  23. frequency: int = 0,
  24. checkpoint_at_end: bool = True,
  25. results_postprocessing_fn: Optional[
  26. Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
  27. ] = None,
  28. ):
  29. if isinstance(metrics, str):
  30. metrics = [metrics]
  31. self._metrics = metrics
  32. self._filename = filename
  33. self._frequency = frequency
  34. self._checkpoint_at_end = checkpoint_at_end
  35. self._results_postprocessing_fn = results_postprocessing_fn
  36. # Keeps track of the eval metrics from the last iteration,
  37. # so that the latest metrics can be reported with the checkpoint
  38. # at the end of training.
  39. self._evals_log = None
  40. # Keep track of the last checkpoint iteration to avoid double-checkpointing
  41. # when using `checkpoint_at_end=True`.
  42. self._last_checkpoint_iteration = None
  43. @classmethod
  44. def get_model(
  45. cls,
  46. checkpoint: Checkpoint,
  47. filename: str = CHECKPOINT_NAME,
  48. ) -> Booster:
  49. """Retrieve the model stored in a checkpoint reported by this callback.
  50. Args:
  51. checkpoint: The checkpoint object returned by a training run.
  52. The checkpoint should be saved by an instance of this callback.
  53. filename: The filename to load the model from, which should match
  54. the filename used when creating the callback.
  55. Returns:
  56. The model loaded from the checkpoint.
  57. """
  58. with checkpoint.as_directory() as checkpoint_path:
  59. booster = Booster()
  60. booster.load_model(Path(checkpoint_path, filename).as_posix())
  61. return booster
  62. def _get_report_dict(self, evals_log):
  63. if isinstance(evals_log, OrderedDict):
  64. # xgboost>=1.3
  65. result_dict = flatten_dict(evals_log, delimiter="-")
  66. for k in list(result_dict):
  67. result_dict[k] = result_dict[k][-1]
  68. else:
  69. # xgboost<1.3
  70. result_dict = dict(evals_log)
  71. if not self._metrics:
  72. report_dict = result_dict
  73. else:
  74. report_dict = {}
  75. for key in self._metrics:
  76. if isinstance(self._metrics, dict):
  77. metric = self._metrics[key]
  78. else:
  79. metric = key
  80. report_dict[key] = result_dict[metric]
  81. if self._results_postprocessing_fn:
  82. report_dict = self._results_postprocessing_fn(report_dict)
  83. return report_dict
  84. @abstractmethod
  85. def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
  86. """Get checkpoint from model.
  87. This method needs to be implemented by subclasses.
  88. """
  89. raise NotImplementedError
  90. @abstractmethod
  91. def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
  92. """Save checkpoint and report metrics corresonding to this checkpoint.
  93. This method needs to be implemented by subclasses.
  94. """
  95. raise NotImplementedError
  96. @abstractmethod
  97. def _report_metrics(self, report_dict: Dict):
  98. """Report Metrics.
  99. This method needs to be implemented by subclasses.
  100. """
  101. raise NotImplementedError
  102. def after_iteration(self, model: Booster, epoch: int, evals_log: Dict):
  103. self._evals_log = evals_log
  104. checkpointing_disabled = self._frequency == 0
  105. # Ex: if frequency=2, checkpoint at epoch 1, 3, 5, ... (counting from 0)
  106. should_checkpoint = (
  107. not checkpointing_disabled and (epoch + 1) % self._frequency == 0
  108. )
  109. report_dict = self._get_report_dict(evals_log)
  110. if should_checkpoint:
  111. self._last_checkpoint_iteration = epoch
  112. self._save_and_report_checkpoint(report_dict, model)
  113. else:
  114. self._report_metrics(report_dict)
  115. def after_training(self, model: Booster) -> Booster:
  116. if not self._checkpoint_at_end:
  117. return model
  118. if (
  119. self._last_checkpoint_iteration is not None
  120. and model.num_boosted_rounds() - 1 == self._last_checkpoint_iteration
  121. ):
  122. # Avoids a duplicate checkpoint if the checkpoint frequency happens
  123. # to align with the last iteration.
  124. return model
  125. report_dict = self._get_report_dict(self._evals_log) if self._evals_log else {}
  126. self._save_and_report_checkpoint(report_dict, model)
  127. return model
  128. @PublicAPI(stability="beta")
  129. class RayTrainReportCallback(RayReportCallback):
  130. """XGBoost callback to save checkpoints and report metrics.
  131. Args:
  132. metrics: Metrics to report. If this is a list,
  133. each item describes the metric key reported to XGBoost,
  134. and it will be reported under the same name.
  135. This can also be a dict of {<key-to-report>: <xgboost-metric-key>},
  136. which can be used to rename xgboost default metrics.
  137. filename: Customize the saved checkpoint file type by passing
  138. a filename. Defaults to "model.ubj".
  139. frequency: How often to save checkpoints, in terms of iterations.
  140. Defaults to 0 (no checkpoints are saved during training).
  141. checkpoint_at_end: Whether or not to save a checkpoint at the end of training.
  142. results_postprocessing_fn: An optional Callable that takes in
  143. the metrics dict that will be reported (after it has been flattened)
  144. and returns a modified dict. For example, this can be used to
  145. average results across CV fold when using ``xgboost.cv``.
  146. Examples
  147. --------
  148. Reporting checkpoints and metrics to Ray Tune when running many
  149. independent xgboost trials (without data parallelism within a trial).
  150. .. testcode::
  151. :skipif: True
  152. import xgboost
  153. from ray.tune import Tuner
  154. from ray.train.xgboost import RayTrainReportCallback
  155. def train_fn(config):
  156. # Report log loss to Ray Tune after each validation epoch.
  157. bst = xgboost.train(
  158. ...,
  159. callbacks=[
  160. RayTrainReportCallback(
  161. metrics={"loss": "eval-logloss"}, frequency=1
  162. )
  163. ],
  164. )
  165. tuner = Tuner(train_fn)
  166. results = tuner.fit()
  167. Loading a model from a checkpoint reported by this callback.
  168. .. testcode::
  169. :skipif: True
  170. from ray.train.xgboost import RayTrainReportCallback
  171. # Get a `Checkpoint` object that is saved by the callback during training.
  172. result = trainer.fit()
  173. booster = RayTrainReportCallback.get_model(result.checkpoint)
  174. """
  175. def __init__(
  176. self,
  177. metrics: Optional[Union[str, List[str], Dict[str, str]]] = None,
  178. filename: str = RayReportCallback.CHECKPOINT_NAME,
  179. frequency: int = 0,
  180. checkpoint_at_end: bool = True,
  181. results_postprocessing_fn: Optional[
  182. Callable[[Dict[str, Union[float, List[float]]]], Dict[str, float]]
  183. ] = None,
  184. ):
  185. super().__init__(
  186. metrics=metrics,
  187. filename=filename,
  188. frequency=frequency,
  189. checkpoint_at_end=checkpoint_at_end,
  190. results_postprocessing_fn=results_postprocessing_fn,
  191. )
  192. @contextmanager
  193. def _get_checkpoint(self, model: Booster) -> Optional[Checkpoint]:
  194. # NOTE: The world rank returns None for Tune usage without Train.
  195. if ray.train.get_context().get_world_rank() in (0, None):
  196. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  197. model.save_model(Path(temp_checkpoint_dir, self._filename).as_posix())
  198. yield Checkpoint(temp_checkpoint_dir)
  199. else:
  200. yield None
  201. def _save_and_report_checkpoint(self, report_dict: Dict, model: Booster):
  202. with self._get_checkpoint(model=model) as checkpoint:
  203. ray.train.report(report_dict, checkpoint=checkpoint)
  204. def _report_metrics(self, report_dict: Dict):
  205. ray.train.report(report_dict)