appo.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452
"""Asynchronous Proximal Policy Optimization (APPO)

The algorithm is described in [1] (under the name "IMPACT").

Detailed documentation:
https://docs.ray.io/en/master/rllib-algorithms.html#appo

[1] IMPACT: Importance Weighted Asynchronous Architectures with Clipped Target Networks.
Luo et al. 2020
https://arxiv.org/pdf/1912.00167
"""
  9. import logging
  10. from typing import Optional, Type
  11. from typing_extensions import Self
  12. from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
  13. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  14. from ray.rllib.algorithms.impala.impala import IMPALA, IMPALAConfig
  15. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  16. from ray.rllib.policy.policy import Policy
  17. from ray.rllib.utils.annotations import override
  18. from ray.rllib.utils.metrics import (
  19. LAST_TARGET_UPDATE_TS,
  20. LEARNER_STATS_KEY,
  21. NUM_AGENT_STEPS_SAMPLED,
  22. NUM_ENV_STEPS_SAMPLED,
  23. NUM_TARGET_UPDATES,
  24. )
  25. logger = logging.getLogger(__name__)
  26. LEARNER_RESULTS_KL_KEY = "mean_kl_loss"
  27. LEARNER_RESULTS_CURR_KL_COEFF_KEY = "curr_kl_coeff"
  28. OLD_ACTION_DIST_KEY = "old_action_dist"
  29. class APPOConfig(IMPALAConfig):
  30. """Defines a configuration class from which an APPO Algorithm can be built.
  31. .. testcode::
  32. from ray.rllib.algorithms.appo import APPOConfig
  33. config = (
  34. APPOConfig()
  35. .training(lr=0.01, grad_clip=30.0, train_batch_size_per_learner=50)
  36. )
  37. config = config.learners(num_learners=1)
  38. config = config.env_runners(num_env_runners=1)
  39. config = config.environment("CartPole-v1")
  40. # Build an Algorithm object from the config and run 1 training iteration.
  41. algo = config.build()
  42. algo.train()
  43. del algo
  44. .. testcode::
  45. from ray.rllib.algorithms.appo import APPOConfig
  46. from ray import tune
  47. config = APPOConfig()
  48. # Update the config object.
  49. config = config.training(lr=tune.grid_search([0.001,]))
  50. # Set the config object's env.
  51. config = config.environment(env="CartPole-v1")
  52. # Use to_dict() to get the old-style python config dict when running with tune.
  53. tune.Tuner(
  54. "APPO",
  55. run_config=tune.RunConfig(
  56. stop={"training_iteration": 1},
  57. verbose=0,
  58. ),
  59. param_space=config.to_dict(),
  60. ).fit()
  61. .. testoutput::
  62. :hide:
  63. ...
  64. """
  65. def __init__(self, algo_class=None):
  66. """Initializes a APPOConfig instance."""
  67. self.exploration_config = {
  68. # The Exploration class to use. In the simplest case, this is the name
  69. # (str) of any class present in the `rllib.utils.exploration` package.
  70. # You can also provide the python class directly or the full location
  71. # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy.
  72. # EpsilonGreedy").
  73. "type": "StochasticSampling",
  74. # Add constructor kwargs here (if any).
  75. }
  76. super().__init__(algo_class=algo_class or APPO)
  77. # fmt: off
  78. # __sphinx_doc_begin__
  79. # APPO specific settings:
  80. self.vtrace = True
  81. self.use_gae = True
  82. self.lambda_ = 1.0
  83. self.clip_param = 0.4
  84. self.use_kl_loss = False
  85. self.kl_coeff = 1.0
  86. self.kl_target = 0.01
  87. self.target_worker_clipping = 2.0
  88. # If a circular buffer should be used to store training batches. The
  89. # alternative is a simple `Queue`.
  90. self.use_circular_buffer = True
  91. # Circular replay buffer settings.
  92. # Used in [1] for discrete action tasks:
  93. # `circular_buffer_num_batches=4` and `circular_buffer_iterations_per_batch=2`
  94. # For cont. action tasks:
  95. # `circular_buffer_num_batches=16` and `circular_buffer_iterations_per_batch=20`
  96. self.circular_buffer_num_batches = 8
  97. self.circular_buffer_iterations_per_batch = 2
  98. # Size of the simple queue (if `use_circular_buffer` is False).
  99. self.simple_queue_size = 32
  100. # Override some of IMPALAConfig's default values with APPO-specific values.
  101. self.num_env_runners = 2
  102. self.target_network_update_freq = 2
  103. self.broadcast_interval = 1
  104. self.grad_clip = 40.0
  105. # Note: Only when using enable_rl_module_and_learner=True can the clipping mode
  106. # be configured by the user. On the old API stack, RLlib will always clip by
  107. # global_norm, no matter the value of `grad_clip_by`.
  108. self.grad_clip_by = "global_norm"
  109. self.opt_type = "adam"
  110. self.lr = 0.0005
  111. self.decay = 0.99
  112. self.momentum = 0.0
  113. self.epsilon = 0.1
  114. self.vf_loss_coeff = 0.5
  115. self.entropy_coeff = 0.01
  116. self.tau = 1.0
  117. # __sphinx_doc_end__
  118. # fmt: on
  119. self.lr_schedule = None # @OldAPIStack
  120. self.entropy_coeff_schedule = None # @OldAPIStack
  121. self.num_gpus = 0 # @OldAPIStack
  122. self.num_multi_gpu_tower_stacks = 1 # @OldAPIStack
  123. self.minibatch_buffer_size = 1 # @OldAPIStack
  124. self.replay_proportion = 0.0 # @OldAPIStack
  125. self.replay_buffer_num_slots = 100 # @OldAPIStack
  126. self.learner_queue_size = 16 # @OldAPIStack
  127. self.learner_queue_timeout = 300 # @OldAPIStack
  128. # Deprecated keys.
  129. self.target_update_frequency = DEPRECATED_VALUE
  130. self.use_critic = DEPRECATED_VALUE
  131. @override(IMPALAConfig)
  132. def training(
  133. self,
  134. *,
  135. vtrace: Optional[bool] = NotProvided,
  136. use_gae: Optional[bool] = NotProvided,
  137. lambda_: Optional[float] = NotProvided,
  138. clip_param: Optional[float] = NotProvided,
  139. use_kl_loss: Optional[bool] = NotProvided,
  140. kl_coeff: Optional[float] = NotProvided,
  141. kl_target: Optional[float] = NotProvided,
  142. target_network_update_freq: Optional[int] = NotProvided,
  143. tau: Optional[float] = NotProvided,
  144. target_worker_clipping: Optional[float] = NotProvided,
  145. use_circular_buffer: Optional[bool] = NotProvided,
  146. circular_buffer_num_batches: Optional[int] = NotProvided,
  147. circular_buffer_iterations_per_batch: Optional[int] = NotProvided,
  148. simple_queue_size: Optional[int] = NotProvided,
  149. # Deprecated keys.
  150. target_update_frequency=DEPRECATED_VALUE,
  151. use_critic=DEPRECATED_VALUE,
  152. **kwargs,
  153. ) -> Self:
  154. """Sets the training related configuration.
  155. Args:
  156. vtrace: Whether to use V-trace weighted advantages. If false, PPO GAE
  157. advantages will be used instead.
  158. use_gae: If true, use the Generalized Advantage Estimator (GAE)
  159. with a value function, see https://arxiv.org/pdf/1506.02438.pdf.
  160. Only applies if vtrace=False.
  161. lambda_: GAE (lambda) parameter.
  162. clip_param: PPO surrogate slipping parameter.
  163. use_kl_loss: Whether to use the KL-term in the loss function.
  164. kl_coeff: Coefficient for weighting the KL-loss term.
  165. kl_target: Target term for the KL-term to reach (via adjusting the
  166. `kl_coeff` automatically).
  167. target_network_update_freq: NOTE: This parameter is only applicable on
  168. the new API stack. The frequency with which to update the target
  169. policy network from the main trained policy network. The metric
  170. used is `NUM_ENV_STEPS_TRAINED_LIFETIME` and the unit is `n` (see [1]
  171. 4.1.1), where: `n = [circular_buffer_num_batches (N)] *
  172. [circular_buffer_iterations_per_batch (K)] * [train batch size]`
  173. For example, if you set `target_network_update_freq=2`, and N=4, K=2,
  174. and `train_batch_size_per_learner=500`, then the target net is updated
  175. every 2*4*2*500=8000 trained env steps (every 16 batch updates on each
  176. learner).
  177. The authors in [1] suggests that this setting is robust to a range of
  178. choices (try values between 0.125 and 4).
  179. target_network_update_freq: The frequency to update the target policy and
  180. tune the kl loss coefficients that are used during training. After
  181. setting this parameter, the algorithm waits for at least
  182. `target_network_update_freq` number of environment samples to be trained
  183. on before updating the target networks and tune the kl loss
  184. coefficients. NOTE: This parameter is only applicable when using the
  185. Learner API (enable_rl_module_and_learner=True).
  186. tau: The factor by which to update the target policy network towards
  187. the current policy network. Can range between 0 and 1.
  188. e.g. updated_param = tau * current_param + (1 - tau) * target_param
  189. target_worker_clipping: The maximum value for the target-worker-clipping
  190. used for computing the IS ratio, described in [1]
  191. IS = min(π(i) / π(target), ρ) * (π / π(i))
  192. use_circular_buffer: Whether to use a circular buffer for storing
  193. training batches. If false, a simple Queue will be used. Defaults to
  194. True.
  195. circular_buffer_num_batches: The number of train batches that fit
  196. into the circular buffer. Each such train batch can be sampled for
  197. training max. `circular_buffer_iterations_per_batch` times.
  198. circular_buffer_iterations_per_batch: The number of times any train
  199. batch in the circular buffer can be sampled for training. A batch gets
  200. evicted from the buffer either if it's the oldest batch in the buffer
  201. and a new batch is added OR if the batch reaches this max. number of
  202. being sampled.
  203. simple_queue_size: The size of the simple queue (if `use_circular_buffer`
  204. is False) for storing training batches.
  205. Returns:
  206. This updated AlgorithmConfig object.
  207. """
  208. if target_update_frequency != DEPRECATED_VALUE:
  209. deprecation_warning(
  210. old="target_update_frequency",
  211. new="target_network_update_freq",
  212. error=True,
  213. )
  214. if use_critic != DEPRECATED_VALUE:
  215. deprecation_warning(
  216. old="use_critic",
  217. help="`use_critic` no longer supported! APPO always uses a value "
  218. "function (critic).",
  219. error=True,
  220. )
  221. # Pass kwargs onto super's `training()` method.
  222. super().training(**kwargs)
  223. if vtrace is not NotProvided:
  224. self.vtrace = vtrace
  225. if use_gae is not NotProvided:
  226. self.use_gae = use_gae
  227. if lambda_ is not NotProvided:
  228. self.lambda_ = lambda_
  229. if clip_param is not NotProvided:
  230. self.clip_param = clip_param
  231. if use_kl_loss is not NotProvided:
  232. self.use_kl_loss = use_kl_loss
  233. if kl_coeff is not NotProvided:
  234. self.kl_coeff = kl_coeff
  235. if kl_target is not NotProvided:
  236. self.kl_target = kl_target
  237. if target_network_update_freq is not NotProvided:
  238. self.target_network_update_freq = target_network_update_freq
  239. if tau is not NotProvided:
  240. self.tau = tau
  241. if target_worker_clipping is not NotProvided:
  242. self.target_worker_clipping = target_worker_clipping
  243. if use_circular_buffer is not NotProvided:
  244. self.use_circular_buffer = use_circular_buffer
  245. if circular_buffer_num_batches is not NotProvided:
  246. self.circular_buffer_num_batches = circular_buffer_num_batches
  247. if circular_buffer_iterations_per_batch is not NotProvided:
  248. self.circular_buffer_iterations_per_batch = (
  249. circular_buffer_iterations_per_batch
  250. )
  251. if simple_queue_size is not NotProvided:
  252. self.simple_queue_size = simple_queue_size
  253. return self
  254. @override(IMPALAConfig)
  255. def validate(self) -> None:
  256. super().validate()
  257. # On new API stack, circular buffer should be used, not `minibatch_buffer_size`.
  258. if self.enable_rl_module_and_learner:
  259. if self.minibatch_buffer_size != 1 or self.replay_proportion != 0.0:
  260. self._value_error(
  261. "`minibatch_buffer_size/replay_proportion` not valid on new API "
  262. "stack with APPO! "
  263. "Use `circular_buffer_num_batches` for the number of train batches "
  264. "in the circular buffer. To change the maximum number of times "
  265. "any batch may be sampled, set "
  266. "`circular_buffer_iterations_per_batch`."
  267. )
  268. if self.num_multi_gpu_tower_stacks != 1:
  269. self._value_error(
  270. "`num_multi_gpu_tower_stacks` not supported on new API stack with "
  271. "APPO! In order to train on multi-GPU, use "
  272. "`config.learners(num_learners=[number of GPUs], "
  273. "num_gpus_per_learner=1)`. To scale the throughput of batch-to-GPU-"
  274. "pre-loading on each of your `Learners`, set "
  275. "`num_gpu_loader_threads` to a higher number (recommended values: "
  276. "1-8)."
  277. )
  278. if self.learner_queue_size != 16:
  279. self._value_error(
  280. "`learner_queue_size` not supported on new API stack with "
  281. "APPO! In order set the size of the circular buffer (which acts as "
  282. "a 'learner queue'), use "
  283. "`config.training(circular_buffer_num_batches=..)`. To change the "
  284. "maximum number of times any batch may be sampled, set "
  285. "`config.training(circular_buffer_iterations_per_batch=..)`."
  286. )
  287. @override(IMPALAConfig)
  288. def get_default_learner_class(self):
  289. if self.framework_str == "torch":
  290. from ray.rllib.algorithms.appo.torch.appo_torch_learner import (
  291. APPOTorchLearner,
  292. )
  293. return APPOTorchLearner
  294. elif self.framework_str in ["tf2", "tf"]:
  295. raise ValueError(
  296. "TensorFlow is no longer supported on the new API stack! "
  297. "Use `framework='torch'`."
  298. )
  299. else:
  300. raise ValueError(
  301. f"The framework {self.framework_str} is not supported. "
  302. "Use `framework='torch'`."
  303. )
  304. @override(IMPALAConfig)
  305. def get_default_rl_module_spec(self) -> RLModuleSpec:
  306. if self.framework_str == "torch":
  307. from ray.rllib.algorithms.appo.torch.appo_torch_rl_module import (
  308. APPOTorchRLModule as RLModule,
  309. )
  310. else:
  311. raise ValueError(
  312. f"The framework {self.framework_str} is not supported. "
  313. "Use either 'torch' or 'tf2'."
  314. )
  315. return RLModuleSpec(module_class=RLModule)
  316. @property
  317. @override(AlgorithmConfig)
  318. def _model_config_auto_includes(self):
  319. return super()._model_config_auto_includes | {"vf_share_layers": False}
  320. class APPO(IMPALA):
  321. def __init__(self, config, *args, **kwargs):
  322. """Initializes an APPO instance."""
  323. super().__init__(config, *args, **kwargs)
  324. # After init: Initialize target net.
  325. # TODO(avnishn): Does this need to happen in __init__? I think we can move it
  326. # to setup()
  327. if not self.config.enable_rl_module_and_learner:
  328. self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())
  329. @override(IMPALA)
  330. def training_step(self) -> None:
  331. if self.config.enable_rl_module_and_learner:
  332. return super().training_step()
  333. train_results = super().training_step()
  334. # Update the target network and the KL coefficient for the APPO-loss.
  335. # The target network update frequency is calculated automatically by the product
  336. # of `num_epochs` setting (usually 1 for APPO) and `minibatch_buffer_size`.
  337. last_update = self._counters[LAST_TARGET_UPDATE_TS]
  338. cur_ts = self._counters[
  339. (
  340. NUM_AGENT_STEPS_SAMPLED
  341. if self.config.count_steps_by == "agent_steps"
  342. else NUM_ENV_STEPS_SAMPLED
  343. )
  344. ]
  345. target_update_freq = self.config.num_epochs * self.config.minibatch_buffer_size
  346. if cur_ts - last_update > target_update_freq:
  347. self._counters[NUM_TARGET_UPDATES] += 1
  348. self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
  349. # Update our target network.
  350. self.env_runner.foreach_policy_to_train(lambda p, _: p.update_target())
  351. # Also update the KL-coefficient for the APPO loss, if necessary.
  352. if self.config.use_kl_loss:
  353. def update(pi, pi_id):
  354. assert LEARNER_STATS_KEY not in train_results, (
  355. "{} should be nested under policy id key".format(
  356. LEARNER_STATS_KEY
  357. ),
  358. train_results,
  359. )
  360. if pi_id in train_results:
  361. kl = train_results[pi_id][LEARNER_STATS_KEY].get("kl")
  362. assert kl is not None, (train_results, pi_id)
  363. # Make the actual `Policy.update_kl()` call.
  364. pi.update_kl(kl)
  365. else:
  366. logger.warning("No data for {}, not updating kl".format(pi_id))
  367. # Update KL on all trainable policies within the local (trainer)
  368. # Worker.
  369. self.env_runner.foreach_policy_to_train(update)
  370. return train_results
  371. @classmethod
  372. @override(IMPALA)
  373. def get_default_config(cls) -> APPOConfig:
  374. return APPOConfig()
  375. @classmethod
  376. @override(IMPALA)
  377. def get_default_policy_class(
  378. cls, config: AlgorithmConfig
  379. ) -> Optional[Type[Policy]]:
  380. if config["framework"] == "torch":
  381. from ray.rllib.algorithms.appo.appo_torch_policy import APPOTorchPolicy
  382. return APPOTorchPolicy
  383. elif config["framework"] == "tf":
  384. if config.enable_rl_module_and_learner:
  385. raise ValueError(
  386. "RLlib's RLModule and Learner API is not supported for"
  387. " tf1. Use "
  388. "framework='tf2' instead."
  389. )
  390. from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF1Policy
  391. return APPOTF1Policy
  392. else:
  393. from ray.rllib.algorithms.appo.appo_tf_policy import APPOTF2Policy
  394. return APPOTF2Policy