import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import lightgbm

from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
from ray.util.annotations import PublicAPI

if TYPE_CHECKING:
    from ray.data.preprocessor import Preprocessor


@PublicAPI(stability="beta")
class LightGBMCheckpoint(FrameworkCheckpoint):
    """A :py:class:`~ray.train.Checkpoint` with LightGBM-specific functionality."""

    # Name of the file, inside the checkpoint directory, that holds the
    # serialized booster. Written by ``from_model`` and read by ``get_model``.
    MODEL_FILENAME = "model.txt"

    @classmethod
    def from_model(
        cls,
        booster: lightgbm.Booster,
        *,
        preprocessor: Optional["Preprocessor"] = None,
        path: Optional[str] = None,
    ) -> "LightGBMCheckpoint":
        """Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.

        Args:
            booster: The LightGBM model to store in the checkpoint.
            preprocessor: A fitted preprocessor to be applied before inference.
            path: The path to the directory where the checkpoint file will be saved.
                This should start as an empty directory, since the *entire*
                directory will be treated as the checkpoint when reported.
                By default, a temporary directory will be created.

        Returns:
            A :py:class:`LightGBMCheckpoint` containing the specified ``booster``.

        Raises:
            ValueError: If ``path`` is given but is not an existing directory.

        Examples:
            .. testcode::

                import lightgbm
                import numpy as np
                from ray.train.lightgbm import LightGBMCheckpoint

                train_X = np.array([[1, 2], [3, 4]])
                train_y = np.array([0, 1])

                model = lightgbm.LGBMClassifier().fit(train_X, train_y)
                checkpoint = LightGBMCheckpoint.from_model(model.booster_)
        """
        # A missing (or empty) ``path`` falls back to a fresh temporary
        # directory; this method does not remove that directory afterwards.
        checkpoint_path = Path(path or tempfile.mkdtemp())

        if not checkpoint_path.is_dir():
            raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")

        booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())

        checkpoint = cls.from_directory(checkpoint_path.as_posix())
        # Compare against None explicitly: a supplied preprocessor must be
        # attached even if the object happens to evaluate as falsy.
        if preprocessor is not None:
            checkpoint.set_preprocessor(preprocessor)
        return checkpoint

    def get_model(self) -> lightgbm.Booster:
        """Retrieve the LightGBM model stored in this checkpoint.

        Returns:
            A ``lightgbm.Booster`` constructed from the ``MODEL_FILENAME`` file
            inside this checkpoint's directory.
        """
        # The booster is constructed inside the ``with`` block, i.e. while the
        # directory yielded by ``as_directory()`` is still in scope.
        with self.as_directory() as checkpoint_path:
            return lightgbm.Booster(
                model_file=Path(checkpoint_path, self.MODEL_FILENAME).as_posix()
            )