hyperband_example.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. #!/usr/bin/env python
  2. import argparse
  3. import ray
  4. from ray import tune
  5. from ray.tune.schedulers import HyperBandScheduler
  6. from ray.tune.utils.mock_trainable import MyTrainableClass
  7. if __name__ == "__main__":
  8. parser = argparse.ArgumentParser()
  9. parser.add_argument(
  10. "--smoke-test", action="store_true", help="Finish quickly for testing"
  11. )
  12. args, _ = parser.parse_known_args()
  13. ray.init(num_cpus=4 if args.smoke_test else None)
  14. # Hyperband early stopping, configured with `episode_reward_mean` as the
  15. # objective and `training_iteration` as the time unit,
  16. # which is automatically filled by Tune.
  17. hyperband = HyperBandScheduler(time_attr="training_iteration", max_t=200)
  18. tuner = tune.Tuner(
  19. MyTrainableClass,
  20. run_config=tune.RunConfig(
  21. name="hyperband_test",
  22. stop={"training_iteration": 1 if args.smoke_test else 200},
  23. verbose=1,
  24. failure_config=tune.FailureConfig(
  25. fail_fast=True,
  26. ),
  27. ),
  28. tune_config=tune.TuneConfig(
  29. num_samples=20 if args.smoke_test else 200,
  30. metric="episode_reward_mean",
  31. mode="max",
  32. scheduler=hyperband,
  33. ),
  34. param_space={"width": tune.randint(10, 90), "height": tune.randint(0, 100)},
  35. )
  36. results = tuner.fit()
  37. print("Best hyperparameters found were: ", results.get_best_result().config)