#!/usr/bin/env python
"""Example of using PBT with RLlib.

Note that this requires a cluster with at least 8 GPUs in order for all trials
to run concurrently; otherwise PBT will round-robin train the trials, which
is less efficient (or you can set ``"num_gpus": 0`` to use CPUs for SGD
instead). Note that Tune in general does not need 8 GPUs; this is just a more
computationally demanding example.
"""
import random

from ray import tune
from ray.rllib.algorithms.ppo import PPO
from ray.tune.schedulers import PopulationBasedTraining
  13. if __name__ == "__main__":
  14. # Postprocess the perturbed config to ensure it's still valid
  15. def explore(config):
  16. # ensure we collect enough timesteps to do sgd
  17. if config["train_batch_size"] < config["sgd_minibatch_size"] * 2:
  18. config["train_batch_size"] = config["sgd_minibatch_size"] * 2
  19. # ensure we run at least one sgd iter
  20. if config["num_sgd_iter"] < 1:
  21. config["num_sgd_iter"] = 1
  22. return config
  23. pbt = PopulationBasedTraining(
  24. time_attr="time_total_s",
  25. perturbation_interval=120,
  26. resample_probability=0.25,
  27. # Specifies the mutations of these hyperparams
  28. hyperparam_mutations={
  29. "lambda": lambda: random.uniform(0.9, 1.0),
  30. "clip_param": lambda: random.uniform(0.01, 0.5),
  31. "lr": [1e-3, 5e-4, 1e-4, 5e-5, 1e-5],
  32. "num_sgd_iter": lambda: random.randint(1, 30),
  33. "sgd_minibatch_size": lambda: random.randint(128, 16384),
  34. "train_batch_size": lambda: random.randint(2000, 160000),
  35. },
  36. custom_explore_fn=explore,
  37. )
  38. tuner = tune.Tuner(
  39. PPO,
  40. run_config=tune.RunConfig(
  41. name="pbt_humanoid_test",
  42. ),
  43. tune_config=tune.TuneConfig(
  44. scheduler=pbt,
  45. num_samples=8,
  46. metric="episode_reward_mean",
  47. mode="max",
  48. reuse_actors=True,
  49. ),
  50. param_space={
  51. "env": "Humanoid-v1",
  52. "kl_coeff": 1.0,
  53. "num_workers": 8,
  54. "num_gpus": 1,
  55. "model": {"free_log_std": True},
  56. # These params are tuned from a fixed starting value.
  57. "lambda": 0.95,
  58. "clip_param": 0.2,
  59. "lr": 1e-4,
  60. # These params start off randomly drawn from a set.
  61. "num_sgd_iter": tune.choice([10, 20, 30]),
  62. "sgd_minibatch_size": tune.choice([128, 512, 2048]),
  63. "train_batch_size": tune.choice([10000, 20000, 40000]),
  64. },
  65. )
  66. results = tuner.fit()
  67. print("best hyperparameters: ", results.get_best_result().config)