pb2_ppo_example.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. import argparse
  2. import os
  3. import random
  4. from datetime import datetime
  5. import pandas as pd
  6. from ray.tune import run, sample_from
  7. from ray.tune.schedulers import PopulationBasedTraining
  8. from ray.tune.schedulers.pb2 import PB2
  9. # Postprocess the perturbed config to ensure it is still valid (only used with PBT).
  def explore(config):
      """Repair a perturbed hyperparameter config in place and return it.

      Passed as ``custom_explore_fn`` to the PBT scheduler in this script.

      Args:
          config: Mutable mapping holding at least ``"train_batch_size"``,
              ``"sgd_minibatch_size"`` and ``"lambda"``.

      Returns:
          The same ``config`` object, after the adjustments below.
      """
      batch_size = config["train_batch_size"]
      # The train batch must hold at least two SGD minibatches.
      smallest_batch = config["sgd_minibatch_size"] * 2
      if batch_size < smallest_batch:
          batch_size = config["train_batch_size"] = smallest_batch

      # Cap lambda at 1; values at or below 1 are left untouched.
      config["lambda"] = min(config["lambda"], 1)

      # Perturbation can produce a float; the batch size is stored as an int.
      config["train_batch_size"] = int(batch_size)
      return config
  # Columns kept from every trial's progress dataframe when exporting results.
  RESULT_COLUMNS = [
      "timesteps_total",
      "episodes_total",
      "episode_reward_mean",
      "info/learner/default_policy/cur_kl_coeff",
  ]


  def str2bool(value):
      """Parse a command-line boolean value.

      ``argparse``'s ``type=bool`` calls ``bool()`` on the raw string, so every
      non-empty value (including ``"False"``) becomes ``True``. This parser
      accepts the usual spellings instead.

      Args:
          value: A ``bool`` (returned unchanged) or a string such as
              ``"true"``, ``"False"``, ``"1"``, ``"no"``.

      Returns:
          The parsed boolean. The empty string maps to ``False``.

      Raises:
          argparse.ArgumentTypeError: If the string is not a recognised boolean.
      """
      if isinstance(value, bool):
          return value
      lowered = value.strip().lower()
      if lowered in ("1", "true", "t", "yes", "y", "on"):
          return True
      if lowered in ("0", "false", "f", "no", "n", "off", ""):
          return False
      raise argparse.ArgumentTypeError(
          "expected a boolean value, got {!r}".format(value)
      )


  def parse_args(argv=None):
      """Build the command-line parser and parse ``argv``.

      Args:
          argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

      Returns:
          The populated ``argparse.Namespace``.
      """
      parser = argparse.ArgumentParser()
      parser.add_argument("--max", type=int, default=1000000)
      parser.add_argument("--algo", type=str, default="PPO")
      # NOTE(review): --num_workers is parsed but never forwarded to the trainer
      # config below. It is kept for CLI compatibility; confirm whether it
      # should be wired in as "num_workers".
      parser.add_argument("--num_workers", type=int, default=4)
      parser.add_argument("--num_samples", type=int, default=4)
      parser.add_argument("--t_ready", type=int, default=50000)
      parser.add_argument("--seed", type=int, default=0)
      parser.add_argument(
          "--horizon",
          type=int,
          default=None,
          help="Episode horizon. When omitted: 1600 for BipedalWalker-v2/v3, "
          "1000 for any other env.",
      )
      parser.add_argument(
          "--perturb", type=float, default=0.25
      )  # quantile fraction; also the resample probability when using PBT
      parser.add_argument("--env_name", type=str, default="BipedalWalker-v2")
      parser.add_argument(
          "--criteria", type=str, default="timesteps_total"
      )  # "training_iteration", "time_total_s"
      parser.add_argument(
          "--net",
          type=str,
          default="32_32",
          help="Underscore-separated hidden layer sizes, e.g. 64_64.",
      )  # May be important to use a larger network for bigger tasks.
      parser.add_argument("--filename", type=str, default="")
      parser.add_argument(
          "--method", type=str, default="pb2", choices=["pbt", "pb2"]
      )
      # Accepts "--save_csv", "--save_csv True" and "--save_csv False".
      parser.add_argument(
          "--save_csv", type=str2bool, nargs="?", const=True, default=False
      )
      return parser.parse_args(argv)


  def resolve_horizon(args):
      """Return the episode horizon for this run.

      An explicit ``--horizon`` wins; otherwise the horizon is derived from the
      environment name (BipedalWalker needs 1600, everything else uses 1000).
      """
      if args.horizon is not None:
          return args.horizon
      if args.env_name in ("BipedalWalker-v2", "BipedalWalker-v3"):
          return 1600
      return 1000


  def build_scheduler(args):
      """Instantiate only the scheduler selected by ``--method``.

      Args:
          args: Parsed arguments; reads ``method``, ``criteria``, ``t_ready``
              and ``perturb``.

      Returns:
          A ``PopulationBasedTraining`` instance for ``"pbt"``, otherwise a
          ``PB2`` instance.
      """
      shared = dict(
          time_attr=args.criteria,
          metric="episode_reward_mean",
          mode="max",
          perturbation_interval=args.t_ready,
          quantile_fraction=args.perturb,  # copy bottom % with top %
      )
      if args.method == "pbt":
          return PopulationBasedTraining(
              resample_probability=args.perturb,
              # Specifies the search space for these hyperparams
              hyperparam_mutations={
                  "lambda": lambda: random.uniform(0.9, 1.0),
                  "clip_param": lambda: random.uniform(0.1, 0.5),
                  "lr": lambda: random.uniform(1e-5, 1e-3),
                  "train_batch_size": lambda: random.randint(1000, 60000),
              },
              custom_explore_fn=explore,
              **shared
          )
      return PB2(
          # Specifies the hyperparam search space
          hyperparam_bounds={
              "lambda": [0.9, 1.0],
              "clip_param": [0.1, 0.5],
              "lr": [1e-5, 1e-3],
              "train_batch_size": [1000, 60000],
          },
          **shared
      )


  def collect_results(trial_dfs):
      """Stack per-trial progress frames into a single results frame.

      Args:
          trial_dfs: Iterable of per-trial ``pandas.DataFrame`` objects, each
              containing the columns in ``RESULT_COLUMNS``.

      Returns:
          One frame with ``RESULT_COLUMNS`` plus an ``"Agent"`` column holding
          the trial's position in ``trial_dfs``, re-indexed from 0. An empty
          frame is returned when there are no trials.

      Raises:
          KeyError: If a trial frame lacks one of ``RESULT_COLUMNS``.
      """
      # ``assign`` returns a new frame, so the caller's frames are not mutated
      # and no assignment is made on a column slice.
      frames = [
          df[RESULT_COLUMNS].assign(Agent=agent)
          for agent, df in enumerate(trial_dfs)
      ]
      if not frames:
          return pd.DataFrame()
      return pd.concat(frames, ignore_index=True)


  def main():
      """Run the PBT/PB2 experiment and optionally export results to CSV."""
      args = parse_args()
      horizon = resolve_horizon(args)

      # Take a single timestamp so the date and time parts cannot disagree.
      now = datetime.now()
      timelog = "{}_{}".format(now.date(), now.time())

      args.dir = "{}_{}_{}_Size{}_{}_{}".format(
          args.algo,
          args.filename,
          args.method,
          str(args.num_samples),
          args.env_name,
          args.criteria,
      )

      analysis = run(
          args.algo,
          name="{}_{}_{}_seed{}_{}".format(
              timelog, args.method, args.env_name, str(args.seed), args.filename
          ),
          scheduler=build_scheduler(args),
          verbose=1,
          num_samples=args.num_samples,
          reuse_actors=True,
          stop={args.criteria: args.max},
          config={
              "env": args.env_name,
              "log_level": "INFO",
              "seed": args.seed,
              "kl_coeff": 1.0,
              "num_gpus": 0,
              "horizon": horizon,
              "observation_filter": "MeanStdFilter",
              "model": {
                  # One hidden layer per underscore-separated size in --net.
                  "fcnet_hiddens": [int(size) for size in args.net.split("_")],
                  "free_log_std": True,
              },
              "num_sgd_iter": 10,
              "sgd_minibatch_size": 128,
              "lambda": sample_from(lambda spec: random.uniform(0.9, 1.0)),
              "clip_param": sample_from(lambda spec: random.uniform(0.1, 0.5)),
              "lr": sample_from(lambda spec: random.uniform(1e-5, 1e-3)),
              "train_batch_size": sample_from(
                  lambda spec: random.randint(1000, 60000)
              ),
          },
      )

      results = collect_results(analysis.trial_dataframes.values())

      if args.save_csv:
          out_dir = os.path.join("data", args.dir)
          # exist_ok avoids the check-then-create race of exists() + makedirs().
          os.makedirs(out_dir, exist_ok=True)
          results.to_csv(
              os.path.join(out_dir, "seed{}.csv".format(str(args.seed)))
          )


  if __name__ == "__main__":
      main()