xgboost_example.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from typing import Dict, List
  2. import numpy as np
  3. import sklearn.datasets
  4. import sklearn.metrics
  5. import xgboost as xgb
  6. from sklearn.model_selection import train_test_split
  7. import ray
  8. from ray import tune
  9. from ray.tune.integration.xgboost import TuneReportCheckpointCallback
  10. from ray.tune.schedulers import ASHAScheduler
# File name of the booster checkpoint inside each trial's Tune checkpoint.
# Used both when TuneReportCheckpointCallback saves the model during
# training and when get_best_model_checkpoint loads it back.
CHECKPOINT_FILENAME = "booster-checkpoint.json"
  12. def train_breast_cancer(config: dict):
  13. # This is a simple training function to be passed into Tune
  14. # Load dataset
  15. data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
  16. # Split into train and test set
  17. train_x, test_x, train_y, test_y = train_test_split(data, labels, test_size=0.25)
  18. # Build input matrices for XGBoost
  19. train_set = xgb.DMatrix(train_x, label=train_y)
  20. test_set = xgb.DMatrix(test_x, label=test_y)
  21. # Train the classifier, using the Tune callback
  22. xgb.train(
  23. config,
  24. train_set,
  25. evals=[(test_set, "test")],
  26. verbose_eval=False,
  27. callbacks=[
  28. TuneReportCheckpointCallback(frequency=1, filename=CHECKPOINT_FILENAME)
  29. ],
  30. )
  31. def train_breast_cancer_cv(config: dict):
  32. # This is a simple training function to be passed into Tune
  33. # using xgboost's cross validation functionality
  34. # Load dataset
  35. data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
  36. # For CV, we need to average over a list of results form folds
  37. def average_cv_folds(results_dict: Dict[str, List[float]]) -> Dict[str, float]:
  38. return {k: np.mean(v) for k, v in results_dict.items()}
  39. train_set = xgb.DMatrix(data, label=labels)
  40. # Run CV, using the Tune callback
  41. xgb.cv(
  42. config,
  43. train_set,
  44. verbose_eval=False,
  45. stratified=True,
  46. # Checkpointing is not supported for CV
  47. callbacks=[
  48. TuneReportCheckpointCallback(
  49. results_postprocessing_fn=average_cv_folds, frequency=0
  50. )
  51. ],
  52. )
  53. def get_best_model_checkpoint(best_result: "ray.tune.Result"):
  54. best_bst = TuneReportCheckpointCallback.get_model(
  55. best_result.checkpoint, filename=CHECKPOINT_FILENAME
  56. )
  57. accuracy = 1.0 - best_result.metrics["test-error"]
  58. print(f"Best model parameters: {best_result.config}")
  59. print(f"Best model total accuracy: {accuracy:.4f}")
  60. return best_bst
  61. def tune_xgboost(use_cv: bool = False):
  62. search_space = {
  63. # You can mix constants with search space objects.
  64. "objective": "binary:logistic",
  65. "eval_metric": ["logloss", "error"],
  66. "max_depth": tune.randint(1, 9),
  67. "min_child_weight": tune.choice([1, 2, 3]),
  68. "subsample": tune.uniform(0.5, 1.0),
  69. "eta": tune.loguniform(1e-4, 1e-1),
  70. }
  71. # This will enable aggressive early stopping of bad trials.
  72. scheduler = ASHAScheduler(
  73. max_t=10, grace_period=1, reduction_factor=2 # 10 training iterations
  74. )
  75. tuner = tune.Tuner(
  76. tune.with_resources(
  77. train_breast_cancer if not use_cv else train_breast_cancer_cv,
  78. # You can add "gpu": 0.1 to allocate GPUs
  79. resources={"cpu": 1},
  80. ),
  81. tune_config=tune.TuneConfig(
  82. metric="test-logloss",
  83. mode="min",
  84. num_samples=10,
  85. scheduler=scheduler,
  86. ),
  87. param_space=search_space,
  88. )
  89. results = tuner.fit()
  90. return results.get_best_result()
  91. if __name__ == "__main__":
  92. import argparse
  93. parser = argparse.ArgumentParser()
  94. parser.add_argument(
  95. "--use-cv", action="store_true", help="Use `xgb.cv` instead of `xgb.train`."
  96. )
  97. args, _ = parser.parse_known_args()
  98. best_result = tune_xgboost(args.use_cv)
  99. # Load the best model checkpoint.
  100. # Checkpointing is not supported when using `xgb.cv`
  101. if not args.use_cv:
  102. best_bst = get_best_model_checkpoint(best_result)
  103. # You could now do further predictions with
  104. # best_bst.predict(...)