# hyperband_function_example.py (2.4 KB)
  1. #!/usr/bin/env python
  2. import argparse
  3. import json
  4. import os
  5. import tempfile
  6. import numpy as np
  7. import ray
  8. from ray import tune
  9. from ray.tune import Checkpoint
  10. from ray.tune.schedulers import HyperBandScheduler
  11. def train_func(config):
  12. step = 0
  13. checkpoint = tune.get_checkpoint()
  14. if checkpoint:
  15. with checkpoint.as_directory() as checkpoint_dir:
  16. with open(os.path.join(checkpoint_dir, "checkpoint.json")) as f:
  17. step = json.load(f)["timestep"] + 1
  18. for timestep in range(step, 100):
  19. v = np.tanh(float(timestep) / config.get("width", 1))
  20. v *= config.get("height", 1)
  21. # Checkpoint the state of the training every 3 steps
  22. # Note that this is only required for certain schedulers
  23. with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
  24. checkpoint = None
  25. if timestep % 3 == 0:
  26. with open(
  27. os.path.join(temp_checkpoint_dir, "checkpoint.json"), "w"
  28. ) as f:
  29. json.dump({"timestep": timestep}, f)
  30. checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
  31. # Here we use `episode_reward_mean`, but you can also report other
  32. # objectives such as loss or accuracy.
  33. tune.report({"episode_reward_mean": v}, checkpoint=checkpoint)
  34. if __name__ == "__main__":
  35. parser = argparse.ArgumentParser()
  36. parser.add_argument(
  37. "--smoke-test", action="store_true", help="Finish quickly for testing"
  38. )
  39. args, _ = parser.parse_known_args()
  40. ray.init(num_cpus=4 if args.smoke_test else None)
  41. # Hyperband early stopping, configured with `episode_reward_mean` as the
  42. # objective and `training_iteration` as the time unit,
  43. # which is automatically filled by Tune.
  44. hyperband = HyperBandScheduler(max_t=200)
  45. tuner = tune.Tuner(
  46. train_func,
  47. run_config=tune.RunConfig(
  48. name="hyperband_test",
  49. stop={"training_iteration": 10 if args.smoke_test else 99999},
  50. failure_config=tune.FailureConfig(
  51. fail_fast=True,
  52. ),
  53. ),
  54. tune_config=tune.TuneConfig(
  55. num_samples=20,
  56. metric="episode_reward_mean",
  57. mode="max",
  58. scheduler=hyperband,
  59. ),
  60. param_space={"height": tune.uniform(0, 100)},
  61. )
  62. results = tuner.fit()
  63. print("Best hyperparameters found were: ", results.get_best_result().config)