pbt_function.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. #!/usr/bin/env python
  2. import argparse
  3. import json
  4. import os
  5. import random
  6. import tempfile
  7. import numpy as np
  8. import ray
  9. from ray import tune
  10. from ray.tune import Checkpoint
  11. from ray.tune.schedulers import PopulationBasedTraining
  12. def pbt_function(config):
  13. """Toy PBT problem for benchmarking adaptive learning rate.
  14. The goal is to optimize this trainable's accuracy. The accuracy increases
  15. fastest at the optimal lr, which is a function of the current accuracy.
  16. The optimal lr schedule for this problem is the triangle wave as follows.
  17. Note that many lr schedules for real models also follow this shape:
  18. best lr
  19. ^
  20. | /\
  21. | / \
  22. | / \
  23. | / \
  24. ------------> accuracy
  25. In this problem, using PBT with a population of 2-4 is sufficient to
  26. roughly approximate this lr schedule. Higher population sizes will yield
  27. faster convergence. Training will not converge without PBT.
  28. """
  29. lr = config["lr"]
  30. checkpoint_interval = config.get("checkpoint_interval", 1)
  31. accuracy = 0.0 # end = 1000
  32. # NOTE: See below why step is initialized to 1
  33. step = 1
  34. checkpoint = tune.get_checkpoint()
  35. if checkpoint:
  36. with checkpoint.as_directory() as checkpoint_dir:
  37. with open(os.path.join(checkpoint_dir, "checkpoint.json"), "r") as f:
  38. checkpoint_dict = json.load(f)
  39. accuracy = checkpoint_dict["acc"]
  40. last_step = checkpoint_dict["step"]
  41. # Current step should be 1 more than the last checkpoint step
  42. step = last_step + 1
  43. # triangle wave:
  44. # - start at 0.001 @ t=0,
  45. # - peak at 0.01 @ t=midpoint,
  46. # - end at 0.001 @ t=midpoint * 2,
  47. midpoint = 100 # lr starts decreasing after acc > midpoint
  48. q_tolerance = 3 # penalize exceeding lr by more than this multiple
  49. noise_level = 2 # add gaussian noise to the acc increase
  50. # Let `stop={"done": True}` in the configs below handle trial stopping
  51. while True:
  52. if accuracy < midpoint:
  53. optimal_lr = 0.01 * accuracy / midpoint
  54. else:
  55. optimal_lr = 0.01 - 0.01 * (accuracy - midpoint) / midpoint
  56. optimal_lr = min(0.01, max(0.001, optimal_lr))
  57. # compute accuracy increase
  58. q_err = max(lr, optimal_lr) / min(lr, optimal_lr)
  59. if q_err < q_tolerance:
  60. accuracy += (1.0 / q_err) * random.random()
  61. elif lr > optimal_lr:
  62. accuracy -= (q_err - q_tolerance) * random.random()
  63. accuracy += noise_level * np.random.normal()
  64. accuracy = max(0, accuracy)
  65. metrics = {
  66. "mean_accuracy": accuracy,
  67. "cur_lr": lr,
  68. "optimal_lr": optimal_lr, # for debugging
  69. "q_err": q_err, # for debugging
  70. "done": accuracy > midpoint * 2, # this stops the training process
  71. }
  72. if step % checkpoint_interval == 0:
  73. # Checkpoint every `checkpoint_interval` steps
  74. # NOTE: if we initialized `step=0` above, our checkpointing and perturbing
  75. # would be out of sync by 1 step.
  76. # Ex: if `checkpoint_interval` = `perturbation_interval` = 3
  77. # step: 0 (checkpoint) 1 2 3 (checkpoint)
  78. # training_iteration: 1 2 3 (perturb) 4
  79. with tempfile.TemporaryDirectory() as tempdir:
  80. with open(os.path.join(tempdir, "checkpoint.json"), "w") as f:
  81. checkpoint_dict = {"acc": accuracy, "step": step}
  82. json.dump(checkpoint_dict, f)
  83. tune.report(metrics, checkpoint=Checkpoint.from_directory(tempdir))
  84. else:
  85. tune.report(metrics)
  86. step += 1
  87. def run_tune_pbt(smoke_test=False):
  88. perturbation_interval = 5
  89. pbt = PopulationBasedTraining(
  90. time_attr="training_iteration",
  91. perturbation_interval=perturbation_interval,
  92. hyperparam_mutations={
  93. # distribution for resampling
  94. "lr": tune.uniform(0.0001, 0.02),
  95. # allow perturbations within this set of categorical values
  96. "some_other_factor": [1, 2],
  97. },
  98. )
  99. tuner = tune.Tuner(
  100. pbt_function,
  101. run_config=tune.RunConfig(
  102. name="pbt_function_api_example",
  103. verbose=False,
  104. stop={
  105. # Stop when done = True or at some # of train steps
  106. # (whichever comes first)
  107. "done": True,
  108. "training_iteration": 10 if smoke_test else 1000,
  109. },
  110. failure_config=tune.FailureConfig(
  111. fail_fast=True,
  112. ),
  113. checkpoint_config=tune.CheckpointConfig(
  114. checkpoint_score_attribute="mean_accuracy",
  115. num_to_keep=2,
  116. ),
  117. ),
  118. tune_config=tune.TuneConfig(
  119. scheduler=pbt,
  120. metric="mean_accuracy",
  121. mode="max",
  122. num_samples=8,
  123. reuse_actors=True,
  124. ),
  125. param_space={
  126. "lr": 0.0001,
  127. # Note: `some_other_factor` is perturbed because it is specified under
  128. # the PBT scheduler's `hyperparam_mutations` argument, but has no effect on
  129. # the model training in this example
  130. "some_other_factor": 1,
  131. # Note: `checkpoint_interval` will not be perturbed (since it's not
  132. # included above), and it will be used to determine how many steps to take
  133. # between each checkpoint.
  134. # We recommend matching `perturbation_interval` and `checkpoint_interval`
  135. # (e.g. checkpoint every 4 steps, and perturb on those same steps)
  136. # or making `perturbation_interval` a multiple of `checkpoint_interval`
  137. # (e.g. checkpoint every 2 steps, and perturb every 4 steps).
  138. # This is to ensure that the lastest checkpoints are being used by PBT
  139. # when trials decide to exploit. If checkpointing and perturbing are not
  140. # aligned, then PBT may use a stale checkpoint to resume from.
  141. "checkpoint_interval": perturbation_interval,
  142. },
  143. )
  144. results = tuner.fit()
  145. print("Best hyperparameters found were: ", results.get_best_result().config)
  146. if __name__ == "__main__":
  147. parser = argparse.ArgumentParser()
  148. parser.add_argument(
  149. "--smoke-test",
  150. action="store_true",
  151. default=False,
  152. help="Finish quickly for testing",
  153. )
  154. args, _ = parser.parse_known_args()
  155. if args.smoke_test:
  156. ray.init(num_cpus=2) # force pausing to happen for test
  157. run_tune_pbt(smoke_test=args.smoke_test)