lightgbm_checkpoint.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import tempfile
  2. from pathlib import Path
  3. from typing import TYPE_CHECKING, Optional
  4. import lightgbm
  5. from ray.train._internal.framework_checkpoint import FrameworkCheckpoint
  6. from ray.util.annotations import PublicAPI
  7. if TYPE_CHECKING:
  8. from ray.data.preprocessor import Preprocessor
  9. @PublicAPI(stability="beta")
  10. class LightGBMCheckpoint(FrameworkCheckpoint):
  11. """A :py:class:`~ray.train.Checkpoint` with LightGBM-specific functionality."""
  12. MODEL_FILENAME = "model.txt"
  13. @classmethod
  14. def from_model(
  15. cls,
  16. booster: lightgbm.Booster,
  17. *,
  18. preprocessor: Optional["Preprocessor"] = None,
  19. path: Optional[str] = None,
  20. ) -> "LightGBMCheckpoint":
  21. """Create a :py:class:`~ray.train.Checkpoint` that stores a LightGBM model.
  22. Args:
  23. booster: The LightGBM model to store in the checkpoint.
  24. preprocessor: A fitted preprocessor to be applied before inference.
  25. path: The path to the directory where the checkpoint file will be saved.
  26. This should start as an empty directory, since the *entire*
  27. directory will be treated as the checkpoint when reported.
  28. By default, a temporary directory will be created.
  29. Returns:
  30. An :py:class:`LightGBMCheckpoint` containing the specified ``Estimator``.
  31. Examples:
  32. .. testcode::
  33. import lightgbm
  34. import numpy as np
  35. from ray.train.lightgbm import LightGBMCheckpoint
  36. train_X = np.array([[1, 2], [3, 4]])
  37. train_y = np.array([0, 1])
  38. model = lightgbm.LGBMClassifier().fit(train_X, train_y)
  39. checkpoint = LightGBMCheckpoint.from_model(model.booster_)
  40. """
  41. checkpoint_path = Path(path or tempfile.mkdtemp())
  42. if not checkpoint_path.is_dir():
  43. raise ValueError(f"`path` must be a directory, but got: {checkpoint_path}")
  44. booster.save_model(checkpoint_path.joinpath(cls.MODEL_FILENAME).as_posix())
  45. checkpoint = cls.from_directory(checkpoint_path.as_posix())
  46. if preprocessor:
  47. checkpoint.set_preprocessor(preprocessor)
  48. return checkpoint
  49. def get_model(self) -> lightgbm.Booster:
  50. """Retrieve the LightGBM model stored in this checkpoint."""
  51. with self.as_directory() as checkpoint_path:
  52. return lightgbm.Booster(
  53. model_file=Path(checkpoint_path, self.MODEL_FILENAME).as_posix()
  54. )