#!/usr/bin/env python
"""Example: tune a learning rate with the PB2 scheduler in Ray Tune.

Runs ``pbt_function`` (imported from ``ray.tune.examples.pbt_function``)
under a ``PB2`` scheduler and prints the configuration of the best result,
where "best" means the highest reported ``mean_accuracy``.

Command-line flags:
    --smoke-test  Initialize Ray with two CPUs before launching the
                  experiment (intended for quick test runs).
"""
import argparse

import ray
from ray import tune
from ray.tune.examples.pbt_function import pbt_function
from ray.tune.schedulers.pb2 import PB2

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    # parse_known_args() is used so that unrecognized arguments are ignored
    # instead of aborting the script.
    args, _ = parser.parse_known_args()
    if args.smoke_test:
        ray.init(num_cpus=2)  # force pausing to happen for test

    # Shared by the scheduler (how often to perturb) and the trainable's
    # config (how often to checkpoint), so both happen at the same cadence.
    perturbation_interval = 5

    pb2_scheduler = PB2(
        time_attr="training_iteration",
        perturbation_interval=perturbation_interval,
        hyperparam_bounds={
            # [lower, upper] bounds for each hyperparameter PB2 may adjust.
            "lr": [0.0001, 0.02],
        },
    )

    tuner = tune.Tuner(
        pbt_function,
        run_config=tune.RunConfig(
            name="pbt_test",
            verbose=False,
            # Each trial stops once it has reported 30 training iterations.
            stop={
                "training_iteration": 30,
            },
            failure_config=tune.FailureConfig(
                fail_fast=True,
            ),
        ),
        tune_config=tune.TuneConfig(
            scheduler=pb2_scheduler,
            metric="mean_accuracy",
            mode="max",
            num_samples=8,
            reuse_actors=True,
        ),
        param_space={
            # Initial learning rate; equal to the lower bound given to PB2.
            "lr": 0.0001,
            # NOTE: only "lr" is listed in hyperparam_bounds above, so this
            # parameter is presumably held fixed rather than perturbed
            # (TODO confirm against the PB2 documentation). According to
            # the original note it has no effect on the model training in
            # this example.
            "some_other_factor": 1,
            # This parameter is not perturbed and is used to determine
            # checkpoint frequency. We set checkpoints and perturbations
            # to happen at the same frequency.
            "checkpoint_interval": perturbation_interval,
        },
    )
    results = tuner.fit()

    print("Best hyperparameters found were: ", results.get_best_result().config)