pbt_example.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. #!/usr/bin/env python
  2. import argparse
  3. import random
  4. import numpy as np
  5. import ray
  6. from ray import tune
  7. from ray.tune.schedulers import PopulationBasedTraining
  8. class PBTBenchmarkExample(tune.Trainable):
  9. """Toy PBT problem for benchmarking adaptive learning rate.
  10. The goal is to optimize this trainable's accuracy. The accuracy increases
  11. fastest at the optimal lr, which is a function of the current accuracy.
  12. The optimal lr schedule for this problem is the triangle wave as follows.
  13. Note that many lr schedules for real models also follow this shape:
  14. best lr
  15. ^
  16. | /\
  17. | / \
  18. | / \
  19. | / \
  20. ------------> accuracy
  21. In this problem, using PBT with a population of 2-4 is sufficient to
  22. roughly approximate this lr schedule. Higher population sizes will yield
  23. faster convergence. Training will not converge without PBT.
  24. """
  25. def setup(self, config):
  26. self.lr = config["lr"]
  27. self.accuracy = 0.0 # end = 1000
  28. def step(self):
  29. midpoint = 100 # lr starts decreasing after acc > midpoint
  30. q_tolerance = 3 # penalize exceeding lr by more than this multiple
  31. noise_level = 2 # add gaussian noise to the acc increase
  32. # triangle wave:
  33. # - start at 0.001 @ t=0,
  34. # - peak at 0.01 @ t=midpoint,
  35. # - end at 0.001 @ t=midpoint * 2,
  36. if self.accuracy < midpoint:
  37. optimal_lr = 0.01 * self.accuracy / midpoint
  38. else:
  39. optimal_lr = 0.01 - 0.01 * (self.accuracy - midpoint) / midpoint
  40. optimal_lr = min(0.01, max(0.001, optimal_lr))
  41. # compute accuracy increase
  42. q_err = max(self.lr, optimal_lr) / min(self.lr, optimal_lr)
  43. if q_err < q_tolerance:
  44. self.accuracy += (1.0 / q_err) * random.random()
  45. elif self.lr > optimal_lr:
  46. self.accuracy -= (q_err - q_tolerance) * random.random()
  47. self.accuracy += noise_level * np.random.normal()
  48. self.accuracy = max(0, self.accuracy)
  49. return {
  50. "mean_accuracy": self.accuracy,
  51. "cur_lr": self.lr,
  52. "optimal_lr": optimal_lr, # for debugging
  53. "q_err": q_err, # for debugging
  54. "done": self.accuracy > midpoint * 2,
  55. }
  56. def save_checkpoint(self, checkpoint_dir):
  57. return {
  58. "accuracy": self.accuracy,
  59. "lr": self.lr,
  60. }
  61. def load_checkpoint(self, checkpoint):
  62. self.accuracy = checkpoint["accuracy"]
  63. def reset_config(self, new_config):
  64. self.lr = new_config["lr"]
  65. self.config = new_config
  66. return True
  67. if __name__ == "__main__":
  68. parser = argparse.ArgumentParser()
  69. parser.add_argument(
  70. "--smoke-test", action="store_true", help="Finish quickly for testing"
  71. )
  72. args, _ = parser.parse_known_args()
  73. if args.smoke_test:
  74. ray.init(num_cpus=2) # force pausing to happen for test
  75. perturbation_interval = 5
  76. pbt = PopulationBasedTraining(
  77. time_attr="training_iteration",
  78. perturbation_interval=perturbation_interval,
  79. hyperparam_mutations={
  80. # distribution for resampling
  81. "lr": lambda: random.uniform(0.0001, 0.02),
  82. # allow perturbations within this set of categorical values
  83. "some_other_factor": [1, 2],
  84. },
  85. )
  86. tuner = tune.Tuner(
  87. PBTBenchmarkExample,
  88. run_config=tune.RunConfig(
  89. name="pbt_class_api_example",
  90. # Stop when done = True or at some # of train steps (whichever comes first)
  91. stop={
  92. "done": True,
  93. "training_iteration": 10 if args.smoke_test else 1000,
  94. },
  95. verbose=0,
  96. # We recommend matching `perturbation_interval` and `checkpoint_interval`
  97. # (e.g. checkpoint every 4 steps, and perturb on those same steps)
  98. # or making `perturbation_interval` a multiple of `checkpoint_interval`
  99. # (e.g. checkpoint every 2 steps, and perturb every 4 steps).
  100. # This is to ensure that the lastest checkpoints are being used by PBT
  101. # when trials decide to exploit. If checkpointing and perturbing are not
  102. # aligned, then PBT may use a stale checkpoint to resume from.
  103. checkpoint_config=tune.CheckpointConfig(
  104. checkpoint_frequency=perturbation_interval,
  105. checkpoint_score_attribute="mean_accuracy",
  106. num_to_keep=4,
  107. ),
  108. ),
  109. tune_config=tune.TuneConfig(
  110. scheduler=pbt,
  111. metric="mean_accuracy",
  112. mode="max",
  113. reuse_actors=True,
  114. num_samples=8,
  115. ),
  116. param_space={
  117. "lr": 0.0001,
  118. # note: this parameter is perturbed but has no effect on
  119. # the model training in this example
  120. "some_other_factor": 1,
  121. },
  122. )
  123. results = tuner.fit()
  124. print("Best hyperparameters found were: ", results.get_best_result().config)