xgboost_checkpoint.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. import tempfile
  2. from pathlib import Path
  3. from typing import TYPE_CHECKING, Optional
  4. import xgboost
  5. from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
  6. from ray.util.annotations import PublicAPI
  7. if TYPE_CHECKING:
  8. from ray.data.preprocessor import Preprocessor
  9. @PublicAPI(stability="beta")
  10. class XGBoostCheckpoint(FrameworkCheckpoint):
  11. """A :py:class:`~ray.train.Checkpoint` with XGBoost-specific functionality."""
  12. MODEL_FILENAME = "model.json"
  13. @classmethod
  14. def from_model(
  15. cls,
  16. booster: xgboost.Booster,
  17. *,
  18. preprocessor: Optional["Preprocessor"] = None,
  19. path: Optional[str] = None,
  20. ) -> "XGBoostCheckpoint":
  21. """Create a :py:class:`~ray.train.Checkpoint` that stores an XGBoost
  22. model.
  23. Args:
  24. booster: The XGBoost model to store in the checkpoint.
  25. preprocessor: A fitted preprocessor to be applied before inference.
  26. path: The path to the directory where the checkpoint file will be saved.
  27. This should start as an empty directory, since the *entire*
  28. directory will be treated as the checkpoint when reported.
  29. By default, a temporary directory will be created.
  30. Returns:
  31. An :py:class:`XGBoostCheckpoint` containing the specified ``Estimator``.
  32. Examples:
  33. ... testcode::
  34. import numpy as np
  35. import ray
  36. from ray.train.xgboost import XGBoostCheckpoint
  37. import xgboost
  38. train_X = np.array([[1, 2], [3, 4]])
  39. train_y = np.array([0, 1])
  40. model = xgboost.XGBClassifier().fit(train_X, train_y)
  41. checkpoint = XGBoostCheckpoint.from_model(model.get_booster())
  42. """
  43. checkpoint_path = Path(path or tempfile.mkdtemp())
  44. if not checkpoint_path.is_dir():
  45. raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")
  46. booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())
  47. checkpoint = cls.from_directory(checkpoint_path.as_posix())
  48. if preprocessor:
  49. checkpoint.set_preprocessor(preprocessor)
  50. return checkpoint
  51. def get_model(self) -> xgboost.Booster:
  52. """Retrieve the XGBoost model stored in this checkpoint."""
  53. with self.as_directory() as checkpoint_path:
  54. booster = xgboost.Booster()
  55. booster.load_model(Path(checkpoint_path, self.MODEL_FILENAME).as_posix())
  56. return booster