| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import tempfile
- from pathlib import Path
- from typing import TYPE_CHECKING, Optional
- import xgboost
- from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
- from ray.util.annotations import PublicAPI
- if TYPE_CHECKING:
- from ray.data.preprocessor import Preprocessor
@PublicAPI(stability="beta")
class XGBoostCheckpoint(FrameworkCheckpoint):
    """A :py:class:`~ray.train.Checkpoint` with XGBoost-specific functionality."""

    # Name of the file, inside the checkpoint directory, that holds the
    # serialized booster. Written by ``from_model`` and read by ``get_model``.
    MODEL_FILENAME = "model.json"

    @classmethod
    def from_model(
        cls,
        booster: xgboost.Booster,
        *,
        preprocessor: Optional["Preprocessor"] = None,
        path: Optional[str] = None,
    ) -> "XGBoostCheckpoint":
        """Create a :py:class:`~ray.train.Checkpoint` that stores an XGBoost
        model.

        Args:
            booster: The XGBoost model to store in the checkpoint.
            preprocessor: A fitted preprocessor to be applied before inference.
            path: The path to the directory where the checkpoint file will be saved.
                This should start as an empty directory, since the *entire*
                directory will be treated as the checkpoint when reported.
                By default, a temporary directory will be created.

        Returns:
            An :py:class:`XGBoostCheckpoint` containing the specified booster.

        Raises:
            ValueError: If ``path`` is given but is not an existing directory.

        Examples:

            .. testcode::

                import numpy as np
                import ray

                from ray.train.xgboost import XGBoostCheckpoint
                import xgboost

                train_X = np.array([[1, 2], [3, 4]])
                train_y = np.array([0, 1])

                model = xgboost.XGBClassifier().fit(train_X, train_y)
                checkpoint = XGBoostCheckpoint.from_model(model.get_booster())
        """
        # A falsy ``path`` (``None`` or an empty string) falls back to a
        # freshly created temporary directory.
        checkpoint_path = Path(path or tempfile.mkdtemp())

        if not checkpoint_path.is_dir():
            raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")

        booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())

        checkpoint = cls.from_directory(checkpoint_path.as_posix())
        # Compare against ``None`` explicitly so that a preprocessor object
        # which happens to evaluate as falsy is still attached.
        if preprocessor is not None:
            checkpoint.set_preprocessor(preprocessor)
        return checkpoint

    def get_model(self) -> xgboost.Booster:
        """Retrieve the XGBoost model stored in this checkpoint.

        Returns:
            A new :py:class:`xgboost.Booster` loaded from the
            ``MODEL_FILENAME`` file in the checkpoint directory.
        """
        # Load the model while the checkpoint directory context is open.
        with self.as_directory() as checkpoint_path:
            booster = xgboost.Booster()
            booster.load_model(Path(checkpoint_path, self.MODEL_FILENAME).as_posix())
            return booster
|