dqn.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859
  1. """
  2. Deep Q-Networks (DQN, Rainbow, Parametric DQN)
  3. ==============================================
  4. This file defines the distributed Algorithm class for the Deep Q-Networks
  5. algorithm. See `dqn_[tf|torch]_policy.py` for the definition of the policies.
  6. Detailed documentation:
  7. https://docs.ray.io/en/master/rllib-algorithms.html#deep-q-networks-dqn-rainbow-parametric-dqn
  8. """ # noqa: E501
  9. import logging
  10. from collections import defaultdict
  11. from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
  12. import numpy as np
  13. from typing_extensions import Self
  14. from ray._common.deprecation import DEPRECATED_VALUE
  15. from ray.rllib.algorithms.algorithm import Algorithm
  16. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  17. from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
  18. from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
  19. from ray.rllib.core.learner import Learner
  20. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  21. from ray.rllib.execution.rollout_ops import (
  22. synchronous_parallel_sample,
  23. )
  24. from ray.rllib.execution.train_ops import (
  25. multi_gpu_train_one_step,
  26. train_one_step,
  27. )
  28. from ray.rllib.policy.policy import Policy
  29. from ray.rllib.policy.sample_batch import MultiAgentBatch
  30. from ray.rllib.utils import deep_update
  31. from ray.rllib.utils.annotations import override
  32. from ray.rllib.utils.metrics import (
  33. ALL_MODULES,
  34. ENV_RUNNER_RESULTS,
  35. ENV_RUNNER_SAMPLING_TIMER,
  36. LAST_TARGET_UPDATE_TS,
  37. LEARNER_RESULTS,
  38. LEARNER_UPDATE_TIMER,
  39. NUM_AGENT_STEPS_SAMPLED,
  40. NUM_AGENT_STEPS_SAMPLED_LIFETIME,
  41. NUM_ENV_STEPS_SAMPLED,
  42. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  43. NUM_TARGET_UPDATES,
  44. REPLAY_BUFFER_ADD_DATA_TIMER,
  45. REPLAY_BUFFER_RESULTS,
  46. REPLAY_BUFFER_SAMPLE_TIMER,
  47. REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
  48. SAMPLE_TIMER,
  49. SYNCH_WORKER_WEIGHTS_TIMER,
  50. TD_ERROR_KEY,
  51. TIMERS,
  52. )
  53. from ray.rllib.utils.numpy import convert_to_numpy
  54. from ray.rllib.utils.replay_buffers.utils import (
  55. sample_min_n_steps_from_buffer,
  56. update_priorities_in_episode_replay_buffer,
  57. update_priorities_in_replay_buffer,
  58. validate_buffer_config,
  59. )
  60. from ray.rllib.utils.typing import (
  61. LearningRateOrSchedule,
  62. ResultDict,
  63. RLModuleSpecType,
  64. SampleBatchType,
  65. )
  66. logger = logging.getLogger(__name__)
  67. class DQNConfig(AlgorithmConfig):
  68. r"""Defines a configuration class from which a DQN Algorithm can be built.
  69. .. testcode::
  70. from ray.rllib.algorithms.dqn.dqn import DQNConfig
  71. config = (
  72. DQNConfig()
  73. .environment("CartPole-v1")
  74. .training(replay_buffer_config={
  75. "type": "PrioritizedEpisodeReplayBuffer",
  76. "capacity": 60000,
  77. "alpha": 0.5,
  78. "beta": 0.5,
  79. })
  80. .env_runners(num_env_runners=1)
  81. )
  82. algo = config.build()
  83. algo.train()
  84. algo.stop()
  85. .. testcode::
  86. from ray.rllib.algorithms.dqn.dqn import DQNConfig
  87. from ray import tune
  88. config = (
  89. DQNConfig()
  90. .environment("CartPole-v1")
  91. .training(
  92. num_atoms=tune.grid_search([1,])
  93. )
  94. )
  95. tune.Tuner(
  96. "DQN",
  97. run_config=tune.RunConfig(stop={"training_iteration":1}),
  98. param_space=config,
  99. ).fit()
  100. .. testoutput::
  101. :hide:
  102. ...
  103. """
  104. def __init__(self, algo_class=None):
  105. """Initializes a DQNConfig instance."""
  106. self.exploration_config = {
  107. "type": "EpsilonGreedy",
  108. "initial_epsilon": 1.0,
  109. "final_epsilon": 0.02,
  110. "epsilon_timesteps": 10000,
  111. }
  112. super().__init__(algo_class=algo_class or DQN)
  113. # Overrides of AlgorithmConfig defaults
  114. # `env_runners()`
  115. # Set to `self.n_step`, if 'auto'.
  116. self.rollout_fragment_length: Union[int, str] = "auto"
  117. # New stack uses `epsilon` as either a constant value or a scheduler
  118. # defined like this.
  119. # TODO (simon): Ensure that users can understand how to provide epsilon.
  120. # (sven): Should we add this to `self.env_runners(epsilon=..)`?
  121. self.epsilon = [(0, 1.0), (10000, 0.05)]
  122. # `training()`
  123. self.grad_clip = 40.0
  124. # Note: Only when using enable_rl_module_and_learner=True can the clipping mode
  125. # be configured by the user. On the old API stack, RLlib will always clip by
  126. # global_norm, no matter the value of `grad_clip_by`.
  127. self.grad_clip_by = "global_norm"
  128. self.lr = 5e-4
  129. self.train_batch_size = 32
  130. # `evaluation()`
  131. self.evaluation(evaluation_config=AlgorithmConfig.overrides(explore=False))
  132. # `reporting()`
  133. self.min_time_s_per_iteration = None
  134. self.min_sample_timesteps_per_iteration = 1000
  135. # DQN specific config settings.
  136. # fmt: off
  137. # __sphinx_doc_begin__
  138. self.target_network_update_freq = 500
  139. self.num_steps_sampled_before_learning_starts = 1000
  140. self.store_buffer_in_checkpoints = False
  141. self.adam_epsilon = 1e-8
  142. self.tau = 1.0
  143. self.num_atoms = 1
  144. self.v_min = -10.0
  145. self.v_max = 10.0
  146. self.noisy = False
  147. self.sigma0 = 0.5
  148. self.dueling = True
  149. self.hiddens = [256]
  150. self.double_q = True
  151. self.n_step = 1
  152. self.before_learn_on_batch = None
  153. self.training_intensity = None
  154. self.td_error_loss_fn = "huber"
  155. self.categorical_distribution_temperature = 1.0
  156. # The burn-in for stateful `RLModule`s.
  157. self.burn_in_len = 0
  158. # Replay buffer configuration.
  159. self.replay_buffer_config = {
  160. "type": "PrioritizedEpisodeReplayBuffer",
  161. # Size of the replay buffer. Note that if async_updates is set,
  162. # then each worker will have a replay buffer of this size.
  163. "capacity": 50000,
  164. "alpha": 0.6,
  165. # Beta parameter for sampling from prioritized replay buffer.
  166. "beta": 0.4,
  167. }
  168. # fmt: on
  169. # __sphinx_doc_end__
  170. self.lr_schedule = None # @OldAPIStack
  171. # Deprecated
  172. self.buffer_size = DEPRECATED_VALUE
  173. self.prioritized_replay = DEPRECATED_VALUE
  174. self.learning_starts = DEPRECATED_VALUE
  175. self.replay_batch_size = DEPRECATED_VALUE
  176. # Can not use DEPRECATED_VALUE here because -1 is a common config value
  177. self.replay_sequence_length = None
  178. self.prioritized_replay_alpha = DEPRECATED_VALUE
  179. self.prioritized_replay_beta = DEPRECATED_VALUE
  180. self.prioritized_replay_eps = DEPRECATED_VALUE
  181. @override(AlgorithmConfig)
  182. def training(
  183. self,
  184. *,
  185. target_network_update_freq: Optional[int] = NotProvided,
  186. replay_buffer_config: Optional[dict] = NotProvided,
  187. store_buffer_in_checkpoints: Optional[bool] = NotProvided,
  188. lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided,
  189. epsilon: Optional[LearningRateOrSchedule] = NotProvided,
  190. adam_epsilon: Optional[float] = NotProvided,
  191. grad_clip: Optional[int] = NotProvided,
  192. num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
  193. tau: Optional[float] = NotProvided,
  194. num_atoms: Optional[int] = NotProvided,
  195. v_min: Optional[float] = NotProvided,
  196. v_max: Optional[float] = NotProvided,
  197. noisy: Optional[bool] = NotProvided,
  198. sigma0: Optional[float] = NotProvided,
  199. dueling: Optional[bool] = NotProvided,
  200. hiddens: Optional[int] = NotProvided,
  201. double_q: Optional[bool] = NotProvided,
  202. n_step: Optional[Union[int, Tuple[int, int]]] = NotProvided,
  203. before_learn_on_batch: Callable[
  204. [Type[MultiAgentBatch], List[Type[Policy]], Type[int]],
  205. Type[MultiAgentBatch],
  206. ] = NotProvided,
  207. training_intensity: Optional[float] = NotProvided,
  208. td_error_loss_fn: Optional[str] = NotProvided,
  209. categorical_distribution_temperature: Optional[float] = NotProvided,
  210. burn_in_len: Optional[int] = NotProvided,
  211. **kwargs,
  212. ) -> Self:
  213. """Sets the training related configuration.
  214. Args:
  215. target_network_update_freq: Update the target network every
  216. `target_network_update_freq` sample steps.
  217. replay_buffer_config: Replay buffer config.
  218. Examples:
  219. {
  220. "_enable_replay_buffer_api": True,
  221. "type": "MultiAgentReplayBuffer",
  222. "capacity": 50000,
  223. "replay_sequence_length": 1,
  224. }
  225. - OR -
  226. {
  227. "_enable_replay_buffer_api": True,
  228. "type": "MultiAgentPrioritizedReplayBuffer",
  229. "capacity": 50000,
  230. "prioritized_replay_alpha": 0.6,
  231. "prioritized_replay_beta": 0.4,
  232. "prioritized_replay_eps": 1e-6,
  233. "replay_sequence_length": 1,
  234. }
  235. - Where -
  236. prioritized_replay_alpha: Alpha parameter controls the degree of
  237. prioritization in the buffer. In other words, when a buffer sample has
  238. a higher temporal-difference error, with how much more probability
  239. should it drawn to use to update the parametrized Q-network. 0.0
  240. corresponds to uniform probability. Setting much above 1.0 may quickly
  241. result as the sampling distribution could become heavily “pointy” with
  242. low entropy.
  243. prioritized_replay_beta: Beta parameter controls the degree of
  244. importance sampling which suppresses the influence of gradient updates
  245. from samples that have higher probability of being sampled via alpha
  246. parameter and the temporal-difference error.
  247. prioritized_replay_eps: Epsilon parameter sets the baseline probability
  248. for sampling so that when the temporal-difference error of a sample is
  249. zero, there is still a chance of drawing the sample.
  250. store_buffer_in_checkpoints: Set this to True, if you want the contents of
  251. your buffer(s) to be stored in any saved checkpoints as well.
  252. Warnings will be created if:
  253. - This is True AND restoring from a checkpoint that contains no buffer
  254. data.
  255. - This is False AND restoring from a checkpoint that does contain
  256. buffer data.
  257. epsilon: Epsilon exploration schedule. In the format of [[timestep, value],
  258. [timestep, value], ...]. A schedule must start from
  259. timestep 0.
  260. adam_epsilon: Adam optimizer's epsilon hyper parameter.
  261. grad_clip: If not None, clip gradients during optimization at this value.
  262. num_steps_sampled_before_learning_starts: Number of timesteps to collect
  263. from rollout workers before we start sampling from replay buffers for
  264. learning. Whether we count this in agent steps or environment steps
  265. depends on config.multi_agent(count_steps_by=..).
  266. tau: Update the target by \tau * policy + (1-\tau) * target_policy.
  267. num_atoms: Number of atoms for representing the distribution of return.
  268. When this is greater than 1, distributional Q-learning is used.
  269. v_min: Minimum value estimation
  270. v_max: Maximum value estimation
  271. noisy: Whether to use noisy network to aid exploration. This adds parametric
  272. noise to the model weights.
  273. sigma0: Control the initial parameter noise for noisy nets.
  274. dueling: Whether to use dueling DQN.
  275. hiddens: Dense-layer setup for each the advantage branch and the value
  276. branch
  277. double_q: Whether to use double DQN.
  278. n_step: N-step target updates. If >1, sars' tuples in trajectories will be
  279. postprocessed to become sa[discounted sum of R][s t+n] tuples. An
  280. integer will be interpreted as a fixed n-step value. If a tuple of 2
  281. ints is provided here, the n-step value will be drawn for each sample(!)
  282. in the train batch from a uniform distribution over the closed interval
  283. defined by `[n_step[0], n_step[1]]`.
  284. before_learn_on_batch: Callback to run before learning on a multi-agent
  285. batch of experiences.
  286. training_intensity: The intensity with which to update the model (vs
  287. collecting samples from the env).
  288. If None, uses "natural" values of:
  289. `train_batch_size` / (`rollout_fragment_length` x `num_env_runners` x
  290. `num_envs_per_env_runner`).
  291. If not None, will make sure that the ratio between timesteps inserted
  292. into and sampled from the buffer matches the given values.
  293. Example:
  294. training_intensity=1000.0
  295. train_batch_size=250
  296. rollout_fragment_length=1
  297. num_env_runners=1 (or 0)
  298. num_envs_per_env_runner=1
  299. -> natural value = 250 / 1 = 250.0
  300. -> will make sure that replay+train op will be executed 4x asoften as
  301. rollout+insert op (4 * 250 = 1000).
  302. See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further
  303. details.
  304. td_error_loss_fn: "huber" or "mse". loss function for calculating TD error
  305. when num_atoms is 1. Note that if num_atoms is > 1, this parameter
  306. is simply ignored, and softmax cross entropy loss will be used.
  307. categorical_distribution_temperature: Set the temperature parameter used
  308. by Categorical action distribution. A valid temperature is in the range
  309. of [0, 1]. Note that this mostly affects evaluation since TD error uses
  310. argmax for return calculation.
  311. burn_in_len: The burn-in period for a stateful RLModule. It allows the
  312. Learner to utilize the initial `burn_in_len` steps in a replay sequence
  313. solely for unrolling the network and establishing a typical starting
  314. state. The network is then updated on the remaining steps of the
  315. sequence. This process helps mitigate issues stemming from a poor
  316. initial state - zero or an outdated recorded state. Consider setting
  317. this parameter to a positive integer if your stateful RLModule faces
  318. convergence challenges or exhibits signs of catastrophic forgetting.
  319. Returns:
  320. This updated AlgorithmConfig object.
  321. """
  322. # Pass kwargs onto super's `training()` method.
  323. super().training(**kwargs)
  324. if target_network_update_freq is not NotProvided:
  325. self.target_network_update_freq = target_network_update_freq
  326. if replay_buffer_config is not NotProvided:
  327. # Override entire `replay_buffer_config` if `type` key changes.
  328. # Update, if `type` key remains the same or is not specified.
  329. new_replay_buffer_config = deep_update(
  330. {"replay_buffer_config": self.replay_buffer_config},
  331. {"replay_buffer_config": replay_buffer_config},
  332. False,
  333. ["replay_buffer_config"],
  334. ["replay_buffer_config"],
  335. )
  336. self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
  337. if store_buffer_in_checkpoints is not NotProvided:
  338. self.store_buffer_in_checkpoints = store_buffer_in_checkpoints
  339. if lr_schedule is not NotProvided:
  340. self.lr_schedule = lr_schedule
  341. if epsilon is not NotProvided:
  342. self.epsilon = epsilon
  343. if adam_epsilon is not NotProvided:
  344. self.adam_epsilon = adam_epsilon
  345. if grad_clip is not NotProvided:
  346. self.grad_clip = grad_clip
  347. if num_steps_sampled_before_learning_starts is not NotProvided:
  348. self.num_steps_sampled_before_learning_starts = (
  349. num_steps_sampled_before_learning_starts
  350. )
  351. if tau is not NotProvided:
  352. self.tau = tau
  353. if num_atoms is not NotProvided:
  354. self.num_atoms = num_atoms
  355. if v_min is not NotProvided:
  356. self.v_min = v_min
  357. if v_max is not NotProvided:
  358. self.v_max = v_max
  359. if noisy is not NotProvided:
  360. self.noisy = noisy
  361. if sigma0 is not NotProvided:
  362. self.sigma0 = sigma0
  363. if dueling is not NotProvided:
  364. self.dueling = dueling
  365. if hiddens is not NotProvided:
  366. self.hiddens = hiddens
  367. if double_q is not NotProvided:
  368. self.double_q = double_q
  369. if n_step is not NotProvided:
  370. self.n_step = n_step
  371. if before_learn_on_batch is not NotProvided:
  372. self.before_learn_on_batch = before_learn_on_batch
  373. if training_intensity is not NotProvided:
  374. self.training_intensity = training_intensity
  375. if td_error_loss_fn is not NotProvided:
  376. self.td_error_loss_fn = td_error_loss_fn
  377. if categorical_distribution_temperature is not NotProvided:
  378. self.categorical_distribution_temperature = (
  379. categorical_distribution_temperature
  380. )
  381. if burn_in_len is not NotProvided:
  382. self.burn_in_len = burn_in_len
  383. return self
  384. @override(AlgorithmConfig)
  385. def validate(self) -> None:
  386. # Call super's validation method.
  387. super().validate()
  388. if self.enable_rl_module_and_learner:
  389. # `lr_schedule` checking.
  390. if self.lr_schedule is not None:
  391. self._value_error(
  392. "`lr_schedule` is deprecated and must be None! Use the "
  393. "`lr` setting to setup a schedule."
  394. )
  395. else:
  396. if not self.in_evaluation:
  397. validate_buffer_config(self)
  398. # TODO (simon): Find a clean solution to deal with configuration configs
  399. # when using the new API stack.
  400. if self.exploration_config["type"] == "ParameterNoise":
  401. if self.batch_mode != "complete_episodes":
  402. self._value_error(
  403. "ParameterNoise Exploration requires `batch_mode` to be "
  404. "'complete_episodes'. Try setting `config.env_runners("
  405. "batch_mode='complete_episodes')`."
  406. )
  407. if self.noisy:
  408. self._value_error(
  409. "ParameterNoise Exploration and `noisy` network cannot be"
  410. " used at the same time!"
  411. )
  412. if self.td_error_loss_fn not in ["huber", "mse"]:
  413. self._value_error("`td_error_loss_fn` must be 'huber' or 'mse'!")
  414. # Check rollout_fragment_length to be compatible with n_step.
  415. if (
  416. not self.in_evaluation
  417. and self.rollout_fragment_length != "auto"
  418. and self.rollout_fragment_length < self.n_step
  419. ):
  420. self._value_error(
  421. f"Your `rollout_fragment_length` ({self.rollout_fragment_length}) is "
  422. f"smaller than `n_step` ({self.n_step})! "
  423. "Try setting config.env_runners(rollout_fragment_length="
  424. f"{self.n_step})."
  425. )
  426. # Check, if the `max_seq_len` is longer then the burn-in.
  427. if (
  428. "max_seq_len" in self.model_config
  429. and 0 < self.model_config["max_seq_len"] <= self.burn_in_len
  430. ):
  431. raise ValueError(
  432. f"Your defined `burn_in_len`={self.burn_in_len} is larger or equal "
  433. f"`max_seq_len`={self.model_config['max_seq_len']}! Either decrease "
  434. "the `burn_in_len` or increase your `max_seq_len`."
  435. )
  436. # Validate that we use the corresponding `EpisodeReplayBuffer` when using
  437. # episodes.
  438. # TODO (sven, simon): Implement the multi-agent case for replay buffers.
  439. from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
  440. EpisodeReplayBuffer,
  441. )
  442. if (
  443. self.enable_env_runner_and_connector_v2
  444. and not isinstance(self.replay_buffer_config["type"], str)
  445. and not issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
  446. ):
  447. self._value_error(
  448. "When using the new `EnvRunner API` the replay buffer must be of type "
  449. "`EpisodeReplayBuffer`."
  450. )
  451. elif not self.enable_env_runner_and_connector_v2 and (
  452. (
  453. isinstance(self.replay_buffer_config["type"], str)
  454. and "Episode" in self.replay_buffer_config["type"]
  455. )
  456. or issubclass(self.replay_buffer_config["type"], EpisodeReplayBuffer)
  457. ):
  458. self._value_error(
  459. "When using the old API stack the replay buffer must not be of type "
  460. "`EpisodeReplayBuffer`! We suggest you use the following config to run "
  461. "DQN on the old API stack: `config.training(replay_buffer_config={"
  462. "'type': 'MultiAgentPrioritizedReplayBuffer', "
  463. "'prioritized_replay_alpha': [alpha], "
  464. "'prioritized_replay_beta': [beta], "
  465. "'prioritized_replay_eps': [eps], "
  466. "})`."
  467. )
  468. @override(AlgorithmConfig)
  469. def get_rollout_fragment_length(self, worker_index: int = 0) -> int:
  470. if self.rollout_fragment_length == "auto":
  471. return (
  472. self.n_step[1]
  473. if isinstance(self.n_step, (tuple, list))
  474. else self.n_step
  475. )
  476. else:
  477. return self.rollout_fragment_length
  478. @override(AlgorithmConfig)
  479. def get_default_rl_module_spec(self) -> RLModuleSpecType:
  480. if self.framework_str == "torch":
  481. from ray.rllib.algorithms.dqn.torch.default_dqn_torch_rl_module import (
  482. DefaultDQNTorchRLModule,
  483. )
  484. return RLModuleSpec(
  485. module_class=DefaultDQNTorchRLModule,
  486. model_config=self.model_config,
  487. )
  488. else:
  489. raise ValueError(
  490. f"The framework {self.framework_str} is not supported! "
  491. "Use `config.framework('torch')` instead."
  492. )
  493. @property
  494. @override(AlgorithmConfig)
  495. def _model_config_auto_includes(self) -> Dict[str, Any]:
  496. return super()._model_config_auto_includes | {
  497. "double_q": self.double_q,
  498. "dueling": self.dueling,
  499. "epsilon": self.epsilon,
  500. "num_atoms": self.num_atoms,
  501. "std_init": self.sigma0,
  502. "v_max": self.v_max,
  503. "v_min": self.v_min,
  504. }
  505. @override(AlgorithmConfig)
  506. def get_default_learner_class(self) -> Union[Type["Learner"], str]:
  507. if self.framework_str == "torch":
  508. from ray.rllib.algorithms.dqn.torch.dqn_torch_learner import (
  509. DQNTorchLearner,
  510. )
  511. return DQNTorchLearner
  512. else:
  513. raise ValueError(
  514. f"The framework {self.framework_str} is not supported! "
  515. "Use `config.framework('torch')` instead."
  516. )
  517. def calculate_rr_weights(config: AlgorithmConfig) -> List[float]:
  518. """Calculate the round robin weights for the rollout and train steps"""
  519. if not config.training_intensity:
  520. return [1, 1]
  521. # Calculate the "native ratio" as:
  522. # [train-batch-size] / [size of env-rolled-out sampled data]
  523. # This is to set freshly rollout-collected data in relation to
  524. # the data we pull from the replay buffer (which also contains old
  525. # samples).
  526. native_ratio = config.total_train_batch_size / (
  527. config.get_rollout_fragment_length()
  528. * config.num_envs_per_env_runner
  529. # Add one to workers because the local
  530. # worker usually collects experiences as well, and we avoid division by zero.
  531. * max(config.num_env_runners + 1, 1)
  532. )
  533. # Training intensity is specified in terms of
  534. # (steps_replayed / steps_sampled), so adjust for the native ratio.
  535. sample_and_train_weight = config.training_intensity / native_ratio
  536. if sample_and_train_weight < 1:
  537. return [int(np.round(1 / sample_and_train_weight)), 1]
  538. else:
  539. return [1, int(np.round(sample_and_train_weight))]
  540. class DQN(Algorithm):
  541. @classmethod
  542. @override(Algorithm)
  543. def get_default_config(cls) -> DQNConfig:
  544. return DQNConfig()
  545. @classmethod
  546. @override(Algorithm)
  547. def get_default_policy_class(
  548. cls, config: AlgorithmConfig
  549. ) -> Optional[Type[Policy]]:
  550. if config["framework"] == "torch":
  551. return DQNTorchPolicy
  552. else:
  553. return DQNTFPolicy
  554. @override(Algorithm)
  555. def setup(self, config: AlgorithmConfig) -> None:
  556. super().setup(config)
  557. if self.config.enable_env_runner_and_connector_v2 and self.env_runner_group:
  558. if self.env_runner is None:
  559. self._module_is_stateful = self.env_runner_group.foreach_env_runner(
  560. lambda er: er.module.is_stateful(),
  561. remote_worker_ids=[1],
  562. local_env_runner=False,
  563. )[0]
  564. else:
  565. self._module_is_stateful = self.env_runner.module.is_stateful()
  566. @override(Algorithm)
  567. def training_step(self) -> None:
  568. """DQN training iteration function.
  569. Each training iteration, we:
  570. - Sample (MultiAgentBatch) from workers.
  571. - Store new samples in replay buffer.
  572. - Sample training batch (MultiAgentBatch) from replay buffer.
  573. - Learn on training batch.
  574. - Update remote workers' new policy weights.
  575. - Update target network every `target_network_update_freq` sample steps.
  576. - Return all collected metrics for the iteration.
  577. Returns:
  578. The results dict from executing the training iteration.
  579. """
  580. # Old API stack (Policy, RolloutWorker, Connector).
  581. if not self.config.enable_env_runner_and_connector_v2:
  582. return self._training_step_old_api_stack()
  583. # New API stack (RLModule, Learner, EnvRunner, ConnectorV2).
  584. return self._training_step_new_api_stack()
  585. def _training_step_new_api_stack(self):
  586. # Alternate between storing and sampling and training.
  587. store_weight, sample_and_train_weight = calculate_rr_weights(self.config)
  588. # Run multiple sampling + storing to buffer iterations.
  589. for _ in range(store_weight):
  590. with self.metrics.log_time((TIMERS, ENV_RUNNER_SAMPLING_TIMER)):
  591. # Sample in parallel from workers.
  592. episodes, env_runner_results = synchronous_parallel_sample(
  593. worker_set=self.env_runner_group,
  594. concat=True,
  595. sample_timeout_s=self.config.sample_timeout_s,
  596. _uses_new_env_runners=True,
  597. _return_metrics=True,
  598. )
  599. # Reduce EnvRunner metrics over the n EnvRunners.
  600. self.metrics.aggregate(env_runner_results, key=ENV_RUNNER_RESULTS)
  601. # Add the sampled experiences to the replay buffer.
  602. with self.metrics.log_time((TIMERS, REPLAY_BUFFER_ADD_DATA_TIMER)):
  603. self.local_replay_buffer.add(episodes)
  604. if self.config.count_steps_by == "agent_steps":
  605. current_ts = sum(
  606. self.metrics.peek(
  607. (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED_LIFETIME), default={}
  608. ).values()
  609. )
  610. else:
  611. current_ts = self.metrics.peek(
  612. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME), default=0
  613. )
  614. # If enough experiences have been sampled start training.
  615. if current_ts >= self.config.num_steps_sampled_before_learning_starts:
  616. # Run multiple sample-from-buffer and update iterations.
  617. for _ in range(sample_and_train_weight):
  618. # Sample a list of episodes used for learning from the replay buffer.
  619. with self.metrics.log_time((TIMERS, REPLAY_BUFFER_SAMPLE_TIMER)):
  620. episodes = self.local_replay_buffer.sample(
  621. num_items=self.config.total_train_batch_size,
  622. n_step=self.config.n_step,
  623. # In case an `EpisodeReplayBuffer` is used we need to provide
  624. # the sequence length.
  625. batch_length_T=(
  626. self._module_is_stateful
  627. * self.config.model_config.get("max_seq_len", 0)
  628. ),
  629. lookback=int(self._module_is_stateful),
  630. # TODO (simon): Implement `burn_in_len` in SAC and remove this
  631. # if-else clause.
  632. min_batch_length_T=self.config.burn_in_len
  633. if hasattr(self.config, "burn_in_len")
  634. else 0,
  635. gamma=self.config.gamma,
  636. beta=self.config.replay_buffer_config.get("beta"),
  637. sample_episodes=True,
  638. )
  639. # Get the replay buffer metrics.
  640. replay_buffer_results = self.local_replay_buffer.get_metrics()
  641. self.metrics.aggregate(
  642. [replay_buffer_results], key=REPLAY_BUFFER_RESULTS
  643. )
  644. # Perform an update on the buffer-sampled train batch.
  645. with self.metrics.log_time((TIMERS, LEARNER_UPDATE_TIMER)):
  646. learner_results = self.learner_group.update(
  647. episodes=episodes,
  648. timesteps={
  649. NUM_ENV_STEPS_SAMPLED_LIFETIME: (
  650. self.metrics.peek(
  651. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME)
  652. )
  653. ),
  654. NUM_AGENT_STEPS_SAMPLED_LIFETIME: (
  655. self.metrics.peek(
  656. (
  657. ENV_RUNNER_RESULTS,
  658. NUM_AGENT_STEPS_SAMPLED_LIFETIME,
  659. )
  660. )
  661. ),
  662. },
  663. )
  664. # Isolate TD-errors from result dicts (we should not log these to
  665. # disk or WandB, they might be very large).
  666. td_errors = defaultdict(list)
  667. for res in learner_results:
  668. for module_id, module_results in res.items():
  669. if TD_ERROR_KEY in module_results:
  670. td_errors[module_id].extend(
  671. convert_to_numpy(
  672. module_results.pop(TD_ERROR_KEY).peek()
  673. )
  674. )
  675. td_errors = {
  676. module_id: {TD_ERROR_KEY: np.concatenate(s, axis=0)}
  677. for module_id, s in td_errors.items()
  678. }
  679. self.metrics.aggregate(learner_results, key=LEARNER_RESULTS)
  680. # Update replay buffer priorities.
  681. with self.metrics.log_time((TIMERS, REPLAY_BUFFER_UPDATE_PRIOS_TIMER)):
  682. update_priorities_in_episode_replay_buffer(
  683. replay_buffer=self.local_replay_buffer,
  684. td_errors=td_errors,
  685. )
  686. # Update weights and global_vars - after learning on the local worker -
  687. # on all remote workers.
  688. with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
  689. modules_to_update = set(learner_results[0].keys()) - {ALL_MODULES}
  690. # NOTE: the new API stack does not use global vars.
  691. self.env_runner_group.sync_weights(
  692. from_worker_or_learner_group=self.learner_group,
  693. policies=modules_to_update,
  694. global_vars=None,
  695. inference_only=True,
  696. )
  697. def _training_step_old_api_stack(self) -> ResultDict:
  698. """Training step for the old API stack.
  699. More specifically this training step relies on `RolloutWorker`.
  700. """
  701. train_results = {}
  702. # We alternate between storing new samples and sampling and training
  703. store_weight, sample_and_train_weight = calculate_rr_weights(self.config)
  704. for _ in range(store_weight):
  705. # Sample (MultiAgentBatch) from workers.
  706. with self._timers[SAMPLE_TIMER]:
  707. new_sample_batch: SampleBatchType = synchronous_parallel_sample(
  708. worker_set=self.env_runner_group,
  709. concat=True,
  710. sample_timeout_s=self.config.sample_timeout_s,
  711. )
  712. # Return early if all our workers failed.
  713. if not new_sample_batch:
  714. return {}
  715. # Update counters
  716. self._counters[NUM_AGENT_STEPS_SAMPLED] += new_sample_batch.agent_steps()
  717. self._counters[NUM_ENV_STEPS_SAMPLED] += new_sample_batch.env_steps()
  718. # Store new samples in replay buffer.
  719. self.local_replay_buffer.add(new_sample_batch)
  720. global_vars = {
  721. "timestep": self._counters[NUM_ENV_STEPS_SAMPLED],
  722. }
  723. # Update target network every `target_network_update_freq` sample steps.
  724. cur_ts = self._counters[
  725. (
  726. NUM_AGENT_STEPS_SAMPLED
  727. if self.config.count_steps_by == "agent_steps"
  728. else NUM_ENV_STEPS_SAMPLED
  729. )
  730. ]
  731. if cur_ts > self.config.num_steps_sampled_before_learning_starts:
  732. for _ in range(sample_and_train_weight):
  733. # Sample training batch (MultiAgentBatch) from replay buffer.
  734. train_batch = sample_min_n_steps_from_buffer(
  735. self.local_replay_buffer,
  736. self.config.total_train_batch_size,
  737. count_by_agent_steps=self.config.count_steps_by == "agent_steps",
  738. )
  739. # Postprocess batch before we learn on it
  740. post_fn = self.config.get("before_learn_on_batch") or (lambda b, *a: b)
  741. train_batch = post_fn(train_batch, self.env_runner_group, self.config)
  742. # Learn on training batch.
  743. # Use simple optimizer (only for multi-agent or tf-eager; all other
  744. # cases should use the multi-GPU optimizer, even if only using 1 GPU)
  745. if self.config.get("simple_optimizer") is True:
  746. train_results = train_one_step(self, train_batch)
  747. else:
  748. train_results = multi_gpu_train_one_step(self, train_batch)
  749. # Update replay buffer priorities.
  750. update_priorities_in_replay_buffer(
  751. self.local_replay_buffer,
  752. self.config,
  753. train_batch,
  754. train_results,
  755. )
  756. last_update = self._counters[LAST_TARGET_UPDATE_TS]
  757. if cur_ts - last_update >= self.config.target_network_update_freq:
  758. to_update = self.env_runner.get_policies_to_train()
  759. self.env_runner.foreach_policy_to_train(
  760. lambda p, pid, to_update=to_update: (
  761. pid in to_update and p.update_target()
  762. )
  763. )
  764. self._counters[NUM_TARGET_UPDATES] += 1
  765. self._counters[LAST_TARGET_UPDATE_TS] = cur_ts
  766. # Update weights and global_vars - after learning on the local worker -
  767. # on all remote workers.
  768. with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]:
  769. self.env_runner_group.sync_weights(global_vars=global_vars)
  770. # Return all collected metrics for the iteration.
  771. return train_results