"""Ray Tune example: hyperparameter search for a LightGBM binary classifier.

Two trainables are defined, one using ``lgb.train`` on a train/validation
split and one using ``lgb.cv``. Run as a script; pass ``--use-cv`` to select
the cross-validation variant.
"""
import lightgbm as lgb
import sklearn.datasets
import sklearn.metrics
from sklearn.model_selection import train_test_split

from ray import tune
from ray.tune.integration.lightgbm import TuneReportCheckpointCallback
from ray.tune.schedulers import ASHAScheduler


def train_breast_cancer(config: dict):
    """Tune trainable: fit LightGBM on a train/validation split.

    Args:
        config: LightGBM parameters for this trial; passed unchanged to
            ``lgb.train``.
    """
    features, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)

    # Hold out a quarter of the samples as the validation set.
    x_fit, x_val, y_fit, y_val = train_test_split(features, labels, test_size=0.25)

    # Wrap both splits in LightGBM Datasets.
    fit_set = lgb.Dataset(x_fit, label=y_fit)
    val_set = lgb.Dataset(x_val, label=y_val)

    # Keys are the names reported to Tune; values are the LightGBM
    # evaluation results, prefixed with the "eval" name from `valid_names`.
    reported_metrics = {
        "binary_error": "eval-binary_error",
        "binary_logloss": "eval-binary_logloss",
    }
    tune_callback = TuneReportCheckpointCallback(reported_metrics)

    # Train the classifier; the callback forwards the metrics to Tune.
    lgb.train(
        config,
        fit_set,
        valid_sets=[val_set],
        valid_names=["eval"],
        callbacks=[tune_callback],
    )


def train_breast_cancer_cv(config: dict):
    """Tune trainable: evaluate ``config`` with LightGBM's cross validation.

    The full breast-cancer dataset is handed to ``lgb.cv`` with stratified
    folds, so no explicit train/test split is made here.

    Args:
        config: LightGBM parameters for this trial; passed unchanged to
            ``lgb.cv``.
    """
    features, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
    full_set = lgb.Dataset(features, label=labels)

    # LightGBM aggregates metrics over folds automatically
    # with the cv_agg key. Both mean and standard deviation
    # are provided.
    reported_metrics = {
        "binary_error": "valid-binary_error-mean",
        "binary_logloss": "valid-binary_logloss-mean",
        "binary_error_stdv": "valid-binary_error-stdv",
        "binary_logloss_stdv": "valid-binary_logloss-stdv",
    }

    # Checkpointing is not supported for CV, hence frequency=0.
    tune_callback = TuneReportCheckpointCallback(reported_metrics, frequency=0)

    # Run CV; the callback forwards the aggregated metrics to Tune.
    lgb.cv(
        config,
        full_set,
        stratified=True,
        callbacks=[tune_callback],
    )


if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--use-cv", action="store_true", help="Use `lgb.cv` instead of `lgb.train`."
    )
    # Unrecognised arguments are ignored rather than rejected.
    cli_args, _ = arg_parser.parse_known_args()

    # Fixed LightGBM settings plus the hyperparameters Tune searches over.
    search_space = {
        "objective": "binary",
        "metric": ["binary_error", "binary_logloss"],
        "verbose": -1,
        "boosting_type": tune.grid_search(["gbdt", "dart"]),
        "num_leaves": tune.randint(10, 1000),
        "learning_rate": tune.loguniform(1e-8, 1e-1),
    }

    # Pick the trainable requested on the command line.
    if cli_args.use_cv:
        trainable = train_breast_cancer_cv
    else:
        trainable = train_breast_cancer

    # Minimise the reported binary_error, using the ASHA scheduler.
    search_settings = tune.TuneConfig(
        metric="binary_error",
        mode="min",
        num_samples=2,
        scheduler=ASHAScheduler(),
    )

    tuner = tune.Tuner(
        trainable,
        tune_config=search_settings,
        param_space=search_space,
    )
    result_grid = tuner.fit()

    print("Best hyperparameters found were: ", result_grid.get_best_result().config)