lightgbm_example.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. import lightgbm as lgb
  2. import sklearn.datasets
  3. import sklearn.metrics
  4. from sklearn.model_selection import train_test_split
  5. from ray import tune
  6. from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
  7. from ray.tune.schedulers import ASHAScheduler
  8. def train_breast_cancer(config: dict):
  9. # This is a simple training function to be passed into Tune
  10. # Load dataset
  11. data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
  12. # Split into train and test set
  13. train_x, test_x, train_y, test_y = train_test_split(data, target, test_size=0.25)
  14. # Build input Datasets for LightGBM
  15. train_set = lgb.Dataset(train_x, label=train_y)
  16. test_set = lgb.Dataset(test_x, label=test_y)
  17. # Train the classifier, using the Tune callback
  18. lgb.train(
  19. config,
  20. train_set,
  21. valid_sets=[test_set],
  22. valid_names=["eval"],
  23. callbacks=[
  24. TuneReportCheckpointCallback(
  25. {
  26. "binary_error": "eval-binary_error",
  27. "binary_logloss": "eval-binary_logloss",
  28. }
  29. )
  30. ],
  31. )
  32. def train_breast_cancer_cv(config: dict):
  33. # This is a simple training function to be passed into Tune, using
  34. # lightgbm's cross validation functionality
  35. # Load dataset
  36. data, target = sklearn.datasets.load_breast_cancer(return_X_y=True)
  37. train_set = lgb.Dataset(data, label=target)
  38. # Run CV, using the Tune callback
  39. lgb.cv(
  40. config,
  41. train_set,
  42. stratified=True,
  43. # Checkpointing is not supported for CV
  44. # LightGBM aggregates metrics over folds automatically
  45. # with the cv_agg key. Both mean and standard deviation
  46. # are provided.
  47. callbacks=[
  48. TuneReportCheckpointCallback(
  49. {
  50. "binary_error": "valid-binary_error-mean",
  51. "binary_logloss": "valid-binary_logloss-mean",
  52. "binary_error_stdv": "valid-binary_error-stdv",
  53. "binary_logloss_stdv": "valid-binary_logloss-stdv",
  54. },
  55. frequency=0,
  56. )
  57. ],
  58. )
  59. if __name__ == "__main__":
  60. import argparse
  61. parser = argparse.ArgumentParser()
  62. parser.add_argument(
  63. "--use-cv", action="store_true", help="Use `lgb.cv` instead of `lgb.train`."
  64. )
  65. args, _ = parser.parse_known_args()
  66. config = {
  67. "objective": "binary",
  68. "metric": ["binary_error", "binary_logloss"],
  69. "verbose": -1,
  70. "boosting_type": tune.grid_search(["gbdt", "dart"]),
  71. "num_leaves": tune.randint(10, 1000),
  72. "learning_rate": tune.loguniform(1e-8, 1e-1),
  73. }
  74. tuner = tune.Tuner(
  75. train_breast_cancer if not args.use_cv else train_breast_cancer_cv,
  76. tune_config=tune.TuneConfig(
  77. metric="binary_error",
  78. mode="min",
  79. num_samples=2,
  80. scheduler=ASHAScheduler(),
  81. ),
  82. param_space=config,
  83. )
  84. results = tuner.fit()
  85. print("Best hyperparameters found were: ", results.get_best_result().config)