dreamerv3.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732
  1. """
  2. [1] Mastering Diverse Domains through World Models - 2023
  3. D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
  4. https://arxiv.org/pdf/2301.04104v1.pdf
  5. [2] Mastering Atari with Discrete World Models - 2021
  6. D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
  7. https://arxiv.org/pdf/2010.02193.pdf
  8. """
  9. import logging
  10. from typing import Any, Dict, Optional, Union
  11. import gymnasium as gym
  12. from typing_extensions import Self
  13. from ray.rllib.algorithms.algorithm import Algorithm
  14. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
  15. from ray.rllib.algorithms.dreamerv3.dreamerv3_catalog import DreamerV3Catalog
  16. from ray.rllib.algorithms.dreamerv3.utils import do_symlog_obs
  17. from ray.rllib.algorithms.dreamerv3.utils.add_is_firsts_to_batch import (
  18. AddIsFirstsToBatch,
  19. )
  20. from ray.rllib.algorithms.dreamerv3.utils.summaries import (
  21. report_dreamed_eval_trajectory_vs_samples,
  22. report_predicted_vs_sampled_obs,
  23. report_sampling_and_replay_buffer,
  24. )
  25. from ray.rllib.connectors.common import AddStatesFromEpisodesToBatch
  26. from ray.rllib.core import DEFAULT_MODULE_ID
  27. from ray.rllib.core.columns import Columns
  28. from ray.rllib.core.rl_module.rl_module import RLModuleSpec
  29. from ray.rllib.env import INPUT_ENV_SINGLE_SPACES
  30. from ray.rllib.execution.rollout_ops import synchronous_parallel_sample
  31. from ray.rllib.policy.sample_batch import SampleBatch
  32. from ray.rllib.utils import deep_update
  33. from ray.rllib.utils.annotations import PublicAPI, override
  34. from ray.rllib.utils.metrics import (
  35. ENV_RUNNER_RESULTS,
  36. LEARN_ON_BATCH_TIMER,
  37. LEARNER_RESULTS,
  38. NUM_ENV_STEPS_SAMPLED_LIFETIME,
  39. NUM_ENV_STEPS_TRAINED_LIFETIME,
  40. NUM_GRAD_UPDATES_LIFETIME,
  41. NUM_SYNCH_WORKER_WEIGHTS,
  42. REPLAY_BUFFER_RESULTS,
  43. SAMPLE_TIMER,
  44. SYNCH_WORKER_WEIGHTS_TIMER,
  45. TIMERS,
  46. )
  47. from ray.rllib.utils.numpy import one_hot
  48. from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
  49. from ray.rllib.utils.typing import LearningRateOrSchedule
  50. logger = logging.getLogger(__name__)
  51. class DreamerV3Config(AlgorithmConfig):
  52. """Defines a configuration class from which a DreamerV3 can be built.
  53. .. testcode::
  54. from ray.rllib.algorithms.dreamerv3 import DreamerV3Config
  55. config = (
  56. DreamerV3Config()
  57. .environment("CartPole-v1")
  58. .training(
  59. model_size="XS",
  60. training_ratio=1,
  61. # TODO
  62. model={
  63. "batch_size_B": 1,
  64. "batch_length_T": 1,
  65. "horizon_H": 1,
  66. "gamma": 0.997,
  67. "model_size": "XS",
  68. },
  69. )
  70. )
  71. config = config.learners(num_learners=0)
  72. # Build a Algorithm object from the config and run 1 training iteration.
  73. algo = config.build()
  74. # algo.train()
  75. del algo
  76. """
  77. def __init__(self, algo_class=None):
  78. """Initializes a DreamerV3Config instance."""
  79. super().__init__(algo_class=algo_class or DreamerV3)
  80. # fmt: off
  81. # __sphinx_doc_begin__
  82. # DreamerV3 specific settings:
  83. self.model_size = "XS"
  84. self.training_ratio = 1024
  85. self.replay_buffer_config = {
  86. "type": "EpisodeReplayBuffer",
  87. "capacity": int(1e6),
  88. }
  89. self.world_model_lr = 1e-4
  90. self.actor_lr = 3e-5
  91. self.critic_lr = 3e-5
  92. self.batch_size_B = 16
  93. self.batch_length_T = 64
  94. self.horizon_H = 15
  95. self.gae_lambda = 0.95 # [1] eq. 7.
  96. self.entropy_scale = 3e-4 # [1] eq. 11.
  97. self.return_normalization_decay = 0.99 # [1] eq. 11 and 12.
  98. self.train_critic = True
  99. self.train_actor = True
  100. self.intrinsic_rewards_scale = 0.1
  101. self.world_model_grad_clip_by_global_norm = 1000.0
  102. self.critic_grad_clip_by_global_norm = 100.0
  103. self.actor_grad_clip_by_global_norm = 100.0
  104. self.symlog_obs = "auto"
  105. self.use_float16 = False
  106. self.use_curiosity = False
  107. # Reporting.
  108. # DreamerV3 is super sample efficient and only needs very few episodes
  109. # (normally) to learn. Leaving this at its default value would gravely
  110. # underestimate the learning performance over the course of an experiment.
  111. self.metrics_num_episodes_for_smoothing = 1
  112. self.report_individual_batch_item_stats = False
  113. self.report_dream_data = False
  114. self.report_images_and_videos = False
  115. # Override some of AlgorithmConfig's default values with DreamerV3-specific
  116. # values.
  117. self.lr = None
  118. self.gamma = 0.997 # [1] eq. 7.
  119. # Do not use! Set `batch_size_B` and `batch_length_T` instead.
  120. self.train_batch_size = None
  121. self.num_env_runners = 0
  122. self.rollout_fragment_length = 1
  123. # Dreamer only runs on the new API stack.
  124. self.enable_rl_module_and_learner = True
  125. self.enable_env_runner_and_connector_v2 = True
  126. # TODO (sven): DreamerV3 still uses its own EnvRunner class. This env-runner
  127. # does not use connectors. We therefore should not attempt to merge/broadcast
  128. # the connector states between EnvRunners (if >0). Note that this is only
  129. # relevant if num_env_runners > 0, which is normally not the case when using
  130. # this algo.
  131. self.use_worker_filter_stats = False
  132. # __sphinx_doc_end__
  133. # fmt: on
  134. @override(AlgorithmConfig)
  135. def build_env_to_module_connector(self, env, spaces, device):
  136. connector = super().build_env_to_module_connector(env, spaces, device)
  137. # Prepend the "is_first" connector such that the RSSM knows, when to insert
  138. # its (learned) internal state into the batch.
  139. # We have to do this before the `AddStatesFromEpisodesToBatch` piece
  140. # such that the column is properly batched/time-ranked.
  141. if self.add_default_connectors_to_learner_pipeline:
  142. connector.insert_before(
  143. AddStatesFromEpisodesToBatch,
  144. AddIsFirstsToBatch(),
  145. )
  146. return connector
  147. @property
  148. def batch_size_B_per_learner(self):
  149. """Returns the batch_size_B per Learner worker.
  150. Needed by some of the DreamerV3 loss math."""
  151. return self.batch_size_B // (self.num_learners or 1)
  152. @override(AlgorithmConfig)
  153. def training(
  154. self,
  155. *,
  156. model_size: Optional[str] = NotProvided,
  157. training_ratio: Optional[float] = NotProvided,
  158. batch_size_B: Optional[int] = NotProvided,
  159. batch_length_T: Optional[int] = NotProvided,
  160. horizon_H: Optional[int] = NotProvided,
  161. gae_lambda: Optional[float] = NotProvided,
  162. entropy_scale: Optional[float] = NotProvided,
  163. return_normalization_decay: Optional[float] = NotProvided,
  164. train_critic: Optional[bool] = NotProvided,
  165. train_actor: Optional[bool] = NotProvided,
  166. intrinsic_rewards_scale: Optional[float] = NotProvided,
  167. world_model_lr: Optional[LearningRateOrSchedule] = NotProvided,
  168. actor_lr: Optional[LearningRateOrSchedule] = NotProvided,
  169. critic_lr: Optional[LearningRateOrSchedule] = NotProvided,
  170. world_model_grad_clip_by_global_norm: Optional[float] = NotProvided,
  171. critic_grad_clip_by_global_norm: Optional[float] = NotProvided,
  172. actor_grad_clip_by_global_norm: Optional[float] = NotProvided,
  173. symlog_obs: Optional[Union[bool, str]] = NotProvided,
  174. use_float16: Optional[bool] = NotProvided,
  175. replay_buffer_config: Optional[dict] = NotProvided,
  176. use_curiosity: Optional[bool] = NotProvided,
  177. **kwargs,
  178. ) -> Self:
  179. """Sets the training related configuration.
  180. Args:
  181. model_size: The main switch for adjusting the overall model size. See [1]
  182. (table B) for more information on the effects of this setting on the
  183. model architecture.
  184. Supported values are "XS", "S", "M", "L", "XL" (as per the paper), as
  185. well as, "nano", "micro", "mini", and "XXS" (for RLlib's
  186. implementation). See ray.rllib.algorithms.dreamerv3.utils.
  187. __init__.py for the details on what exactly each size does to the layer
  188. sizes, number of layers, etc..
  189. training_ratio: The ratio of total steps trained (sum of the sizes of all
  190. batches ever sampled from the replay buffer) over the total env steps
  191. taken (in the actual environment, not the dreamed one). For example,
  192. if the training_ratio is 1024 and the batch size is 1024, we would take
  193. 1 env step for every training update: 1024 / 1. If the training ratio
  194. is 512 and the batch size is 1024, we would take 2 env steps and then
  195. perform a single training update (on a 1024 batch): 1024 / 2.
  196. batch_size_B: The batch size (B) interpreted as number of rows (each of
  197. length `batch_length_T`) to sample from the replay buffer in each
  198. iteration.
  199. batch_length_T: The batch length (T) interpreted as the length of each row
  200. sampled from the replay buffer in each iteration. Note that
  201. `batch_size_B` rows will be sampled in each iteration. Rows normally
  202. contain consecutive data (consecutive timesteps from the same episode),
  203. but there might be episode boundaries in a row as well.
  204. horizon_H: The horizon (in timesteps) used to create dreamed data from the
  205. world model, which in turn is used to train/update both actor- and
  206. critic networks.
  207. gae_lambda: The lambda parameter used for computing the GAE-style
  208. value targets for the actor- and critic losses.
  209. entropy_scale: The factor with which to multiply the entropy loss term
  210. inside the actor loss.
  211. return_normalization_decay: The decay value to use when computing the
  212. running EMA values for return normalization (used in the actor loss).
  213. train_critic: Whether to train the critic network. If False, `train_actor`
  214. must also be False (cannot train actor w/o training the critic).
  215. train_actor: Whether to train the actor network. If True, `train_critic`
  216. must also be True (cannot train actor w/o training the critic).
  217. intrinsic_rewards_scale: The factor to multiply intrinsic rewards with
  218. before adding them to the extrinsic (environment) rewards.
  219. world_model_lr: The learning rate or schedule for the world model optimizer.
  220. actor_lr: The learning rate or schedule for the actor optimizer.
  221. critic_lr: The learning rate or schedule for the critic optimizer.
  222. world_model_grad_clip_by_global_norm: World model grad clipping value
  223. (by global norm).
  224. critic_grad_clip_by_global_norm: Critic grad clipping value
  225. (by global norm).
  226. actor_grad_clip_by_global_norm: Actor grad clipping value (by global norm).
  227. symlog_obs: Whether to symlog observations or not. If set to "auto"
  228. (default), will check for the environment's observation space and then
  229. only symlog if not an image space.
  230. use_float16: Whether to train with mixed float16 precision. In this mode,
  231. model parameters are stored as float32, but all computations are
  232. performed in float16 space (except for losses and distribution params
  233. and outputs).
  234. replay_buffer_config: Replay buffer config.
  235. Only serves in DreamerV3 to set the capacity of the replay buffer.
  236. Note though that in the paper ([1]) a size of 1M is used for all
  237. benchmarks and there doesn't seem to be a good reason to change this
  238. parameter.
  239. Examples:
  240. {
  241. "type": "EpisodeReplayBuffer",
  242. "capacity": 100000,
  243. }
  244. Returns:
  245. This updated AlgorithmConfig object.
  246. """
  247. # Not fully supported/tested yet.
  248. if use_curiosity is not NotProvided:
  249. raise ValueError(
  250. "`DreamerV3Config.curiosity` is not fully supported and tested yet! "
  251. "It thus remains disabled for now."
  252. )
  253. # Pass kwargs onto super's `training()` method.
  254. super().training(**kwargs)
  255. if model_size is not NotProvided:
  256. self.model_size = model_size
  257. if training_ratio is not NotProvided:
  258. self.training_ratio = training_ratio
  259. if batch_size_B is not NotProvided:
  260. self.batch_size_B = batch_size_B
  261. if batch_length_T is not NotProvided:
  262. self.batch_length_T = batch_length_T
  263. if horizon_H is not NotProvided:
  264. self.horizon_H = horizon_H
  265. if gae_lambda is not NotProvided:
  266. self.gae_lambda = gae_lambda
  267. if entropy_scale is not NotProvided:
  268. self.entropy_scale = entropy_scale
  269. if return_normalization_decay is not NotProvided:
  270. self.return_normalization_decay = return_normalization_decay
  271. if train_critic is not NotProvided:
  272. self.train_critic = train_critic
  273. if train_actor is not NotProvided:
  274. self.train_actor = train_actor
  275. if intrinsic_rewards_scale is not NotProvided:
  276. self.intrinsic_rewards_scale = intrinsic_rewards_scale
  277. if world_model_lr is not NotProvided:
  278. self.world_model_lr = world_model_lr
  279. if actor_lr is not NotProvided:
  280. self.actor_lr = actor_lr
  281. if critic_lr is not NotProvided:
  282. self.critic_lr = critic_lr
  283. if world_model_grad_clip_by_global_norm is not NotProvided:
  284. self.world_model_grad_clip_by_global_norm = (
  285. world_model_grad_clip_by_global_norm
  286. )
  287. if critic_grad_clip_by_global_norm is not NotProvided:
  288. self.critic_grad_clip_by_global_norm = critic_grad_clip_by_global_norm
  289. if actor_grad_clip_by_global_norm is not NotProvided:
  290. self.actor_grad_clip_by_global_norm = actor_grad_clip_by_global_norm
  291. if symlog_obs is not NotProvided:
  292. self.symlog_obs = symlog_obs
  293. if use_float16 is not NotProvided:
  294. self.use_float16 = use_float16
  295. if replay_buffer_config is not NotProvided:
  296. # Override entire `replay_buffer_config` if `type` key changes.
  297. # Update, if `type` key remains the same or is not specified.
  298. new_replay_buffer_config = deep_update(
  299. {"replay_buffer_config": self.replay_buffer_config},
  300. {"replay_buffer_config": replay_buffer_config},
  301. False,
  302. ["replay_buffer_config"],
  303. ["replay_buffer_config"],
  304. )
  305. self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"]
  306. return self
  307. @override(AlgorithmConfig)
  308. def reporting(
  309. self,
  310. *,
  311. report_individual_batch_item_stats: Optional[bool] = NotProvided,
  312. report_dream_data: Optional[bool] = NotProvided,
  313. report_images_and_videos: Optional[bool] = NotProvided,
  314. **kwargs,
  315. ):
  316. """Sets the reporting related configuration.
  317. Args:
  318. report_individual_batch_item_stats: Whether to include loss and other stats
  319. per individual timestep inside the training batch in the result dict
  320. returned by `training_step()`. If True, besides the `CRITIC_L_total`,
  321. the individual critic loss values per batch row and time axis step
  322. in the train batch (CRITIC_L_total_B_T) will also be part of the
  323. results.
  324. report_dream_data: Whether to include the dreamed trajectory data in the
  325. result dict returned by `training_step()`. If True, however, will
  326. slice each reported item in the dream data down to the shape.
  327. (H, B, t=0, ...), where H is the horizon and B is the batch size. The
  328. original time axis will only be represented by the first timestep
  329. to not make this data too large to handle.
  330. report_images_and_videos: Whether to include any image/video data in the
  331. result dict returned by `training_step()`.
  332. **kwargs:
  333. Returns:
  334. This updated AlgorithmConfig object.
  335. """
  336. super().reporting(**kwargs)
  337. if report_individual_batch_item_stats is not NotProvided:
  338. self.report_individual_batch_item_stats = report_individual_batch_item_stats
  339. if report_dream_data is not NotProvided:
  340. self.report_dream_data = report_dream_data
  341. if report_images_and_videos is not NotProvided:
  342. self.report_images_and_videos = report_images_and_videos
  343. return self
  344. @override(AlgorithmConfig)
  345. def validate(self) -> None:
  346. # Call the super class' validation method first.
  347. super().validate()
  348. # Make sure, users are not using DreamerV3 yet for multi-agent:
  349. if self.is_multi_agent:
  350. self._value_error("DreamerV3 does NOT support multi-agent setups yet!")
  351. # Make sure, we are configure for the new API stack.
  352. if not self.enable_rl_module_and_learner:
  353. self._value_error(
  354. "DreamerV3 must be run with `config.api_stack("
  355. "enable_rl_module_and_learner=True)`!"
  356. )
  357. # If run on several Learners, the provided batch_size_B must be a multiple
  358. # of `num_learners`.
  359. if self.num_learners > 1 and (self.batch_size_B % self.num_learners != 0):
  360. self._value_error(
  361. f"Your `batch_size_B` ({self.batch_size_B}) must be a multiple of "
  362. f"`num_learners` ({self.num_learners}) in order for "
  363. "DreamerV3 to be able to split batches evenly across your Learner "
  364. "processes."
  365. )
  366. # Cannot train actor w/o critic.
  367. if self.train_actor and not self.train_critic:
  368. self._value_error(
  369. "Cannot train actor network (`train_actor=True`) w/o training critic! "
  370. "Make sure you either set `train_critic=True` or `train_actor=False`."
  371. )
  372. # Use DreamerV3 specific batch size settings.
  373. if self.train_batch_size is not None:
  374. self._value_error(
  375. "`train_batch_size` should NOT be set! Use `batch_size_B` and "
  376. "`batch_length_T` instead."
  377. )
  378. # Must be run with `EpisodeReplayBuffer` type.
  379. if self.replay_buffer_config.get("type") != "EpisodeReplayBuffer":
  380. self._value_error(
  381. "DreamerV3 must be run with the `EpisodeReplayBuffer` type! None "
  382. "other supported."
  383. )
  384. @override(AlgorithmConfig)
  385. def get_default_learner_class(self):
  386. if self.framework_str == "torch":
  387. from ray.rllib.algorithms.dreamerv3.torch.dreamerv3_torch_learner import (
  388. DreamerV3TorchLearner,
  389. )
  390. return DreamerV3TorchLearner
  391. else:
  392. raise ValueError(f"The framework {self.framework_str} is not supported.")
  393. @override(AlgorithmConfig)
  394. def get_default_rl_module_spec(self) -> RLModuleSpec:
  395. if self.framework_str == "torch":
  396. from ray.rllib.algorithms.dreamerv3.torch.dreamerv3_torch_rl_module import (
  397. DreamerV3TorchRLModule as module,
  398. )
  399. else:
  400. raise ValueError(f"The framework {self.framework_str} is not supported.")
  401. return RLModuleSpec(module_class=module, catalog_class=DreamerV3Catalog)
  402. @property
  403. @override(AlgorithmConfig)
  404. def _model_config_auto_includes(self) -> Dict[str, Any]:
  405. return super()._model_config_auto_includes | {
  406. "gamma": self.gamma,
  407. "horizon_H": self.horizon_H,
  408. "model_size": self.model_size,
  409. "symlog_obs": self.symlog_obs,
  410. "use_float16": self.use_float16,
  411. "batch_length_T": self.batch_length_T,
  412. }
  413. class DreamerV3(Algorithm):
  414. """Implementation of the model-based DreamerV3 RL algorithm described in [1]."""
  415. # TODO (sven): Deprecate/do-over the Algorithm.compute_single_action() API.
  416. @override(Algorithm)
  417. def compute_single_action(self, *args, **kwargs):
  418. raise NotImplementedError(
  419. "DreamerV3 does not support the `compute_single_action()` API. Refer to the"
  420. " README here (https://github.com/ray-project/ray/tree/master/rllib/"
  421. "algorithms/dreamerv3) to find more information on how to run action "
  422. "inference with this algorithm."
  423. )
  424. @classmethod
  425. @override(Algorithm)
  426. def get_default_config(cls) -> DreamerV3Config:
  427. return DreamerV3Config()
  428. @override(Algorithm)
  429. def setup(self, config: AlgorithmConfig):
  430. super().setup(config)
  431. # Share RLModule between EnvRunner and single (local) Learner instance.
  432. # To avoid possibly expensive weight synching step.
  433. # if self.config.share_module_between_env_runner_and_learner:
  434. # assert self.env_runner.module is None
  435. # self.env_runner.module = self.learner_group._learner.module[
  436. # DEFAULT_MODULE_ID
  437. # ]
  438. # Create a replay buffer for storing actual env samples.
  439. self.replay_buffer = EpisodeReplayBuffer(
  440. capacity=self.config.replay_buffer_config["capacity"],
  441. batch_size_B=self.config.batch_size_B,
  442. batch_length_T=self.config.batch_length_T,
  443. )
  444. @override(Algorithm)
  445. def training_step(self) -> None:
  446. # Push enough samples into buffer initially before we start training.
  447. if self.training_iteration == 0:
  448. logger.info(
  449. "Filling replay buffer so it contains at least "
  450. f"{self.config.batch_size_B * self.config.batch_length_T} timesteps "
  451. "(required for a single train batch)."
  452. )
  453. # Have we sampled yet in this `training_step()` call?
  454. have_sampled = False
  455. with self.metrics.log_time((TIMERS, SAMPLE_TIMER)):
  456. # Continue sampling from the actual environment (and add collected samples
  457. # to our replay buffer) as long as we:
  458. while (
  459. # a) Don't have at least batch_size_B x batch_length_T timesteps stored
  460. # in the buffer. This is the minimum needed to train.
  461. self.replay_buffer.get_num_timesteps()
  462. < (self.config.batch_size_B * self.config.batch_length_T)
  463. # b) The computed `training_ratio` is >= the configured (desired)
  464. # training ratio (meaning we should continue sampling).
  465. or self.training_ratio >= self.config.training_ratio
  466. # c) we have not sampled at all yet in this `training_step()` call.
  467. or not have_sampled
  468. ):
  469. # Sample using the env runner's module.
  470. episodes, env_runner_results = synchronous_parallel_sample(
  471. worker_set=self.env_runner_group,
  472. max_agent_steps=(
  473. self.config.rollout_fragment_length
  474. * self.config.num_envs_per_env_runner
  475. ),
  476. sample_timeout_s=self.config.sample_timeout_s,
  477. _uses_new_env_runners=True,
  478. _return_metrics=True,
  479. )
  480. self.metrics.aggregate(env_runner_results, key=ENV_RUNNER_RESULTS)
  481. # Add ongoing and finished episodes into buffer. The buffer will
  482. # automatically take care of properly concatenating (by episode IDs)
  483. # the different chunks of the same episodes, even if they come in via
  484. # separate `add()` calls.
  485. self.replay_buffer.add(episodes=episodes)
  486. have_sampled = True
  487. # We took B x T env steps.
  488. env_steps_last_regular_sample = sum(len(eps) for eps in episodes)
  489. total_sampled = env_steps_last_regular_sample
  490. # If we have never sampled before (just started the algo and not
  491. # recovered from a checkpoint), sample B random actions first.
  492. if (
  493. self.metrics.peek(
  494. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
  495. default=0,
  496. )
  497. == 0
  498. ):
  499. _episodes, _env_runner_results = synchronous_parallel_sample(
  500. worker_set=self.env_runner_group,
  501. max_agent_steps=(
  502. self.config.batch_size_B * self.config.batch_length_T
  503. - env_steps_last_regular_sample
  504. ),
  505. sample_timeout_s=self.config.sample_timeout_s,
  506. random_actions=True,
  507. _uses_new_env_runners=True,
  508. _return_metrics=True,
  509. )
  510. self.metrics.aggregate(_env_runner_results, key=ENV_RUNNER_RESULTS)
  511. self.replay_buffer.add(episodes=_episodes)
  512. total_sampled += sum(len(eps) for eps in _episodes)
  513. # Summarize environment interaction and buffer data.
  514. report_sampling_and_replay_buffer(
  515. metrics=self.metrics, replay_buffer=self.replay_buffer
  516. )
  517. # Get the replay buffer metrics.
  518. replay_buffer_results = self.local_replay_buffer.get_metrics()
  519. self.metrics.aggregate([replay_buffer_results], key=REPLAY_BUFFER_RESULTS)
  520. # Use self.spaces for the environment spaces of the env-runners
  521. single_observation_space, single_action_space = self.spaces[
  522. INPUT_ENV_SINGLE_SPACES
  523. ]
  524. # Continue sampling batch_size_B x batch_length_T sized batches from the buffer
  525. # and using these to update our models (`LearnerGroup.update()`)
  526. # until the computed `training_ratio` is larger than the configured one, meaning
  527. # we should go back and collect more samples again from the actual environment.
  528. # However, when calculating the `training_ratio` here, we use only the
  529. # trained steps in this very `training_step()` call over the most recent sample
  530. # amount (`env_steps_last_regular_sample`), not the global values. This is to
  531. # avoid a heavy overtraining at the very beginning when we have just pre-filled
  532. # the buffer with the minimum amount of samples.
  533. replayed_steps_this_iter = sub_iter = 0
  534. while (
  535. replayed_steps_this_iter / env_steps_last_regular_sample
  536. ) < self.config.training_ratio:
  537. # Time individual batch updates.
  538. with self.metrics.log_time((TIMERS, LEARN_ON_BATCH_TIMER)):
  539. logger.info(f"\tSub-iteration {self.training_iteration}/{sub_iter})")
  540. # Draw a new sample from the replay buffer.
  541. sample = self.replay_buffer.sample(
  542. batch_size_B=self.config.batch_size_B,
  543. batch_length_T=self.config.batch_length_T,
  544. )
  545. replayed_steps = self.config.batch_size_B * self.config.batch_length_T
  546. replayed_steps_this_iter += replayed_steps
  547. if isinstance(single_action_space, gym.spaces.Discrete):
  548. sample["actions_ints"] = sample[Columns.ACTIONS]
  549. sample[Columns.ACTIONS] = one_hot(
  550. sample["actions_ints"],
  551. depth=single_action_space.n,
  552. )
  553. # Perform the actual update via our learner group.
  554. learner_results = self.learner_group.update(
  555. batch=SampleBatch(sample).as_multi_agent(),
  556. # TODO(sven): Maybe we should do this broadcase of global timesteps
  557. # at the end, like for EnvRunner global env step counts. Maybe when
  558. # we request the state from the Learners, we can - at the same
  559. # time - send the current globally summed/reduced-timesteps.
  560. timesteps={
  561. NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
  562. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
  563. default=0,
  564. )
  565. },
  566. )
  567. self.metrics.aggregate(learner_results, key=LEARNER_RESULTS)
  568. sub_iter += 1
  569. self.metrics.log_value(
  570. NUM_GRAD_UPDATES_LIFETIME, 1, reduce="lifetime_sum"
  571. )
  572. # Log videos showing how the decoder produces observation predictions
  573. # from the posterior states.
  574. # Only every n iterations and only for the first sampled batch row
  575. # (videos are `config.batch_length_T` frames long).
  576. report_predicted_vs_sampled_obs(
  577. # TODO (sven): DreamerV3 is single-agent only.
  578. metrics=self.metrics,
  579. sample=sample,
  580. batch_size_B=self.config.batch_size_B,
  581. batch_length_T=self.config.batch_length_T,
  582. symlog_obs=do_symlog_obs(
  583. single_observation_space,
  584. self.config.symlog_obs,
  585. ),
  586. do_report=(
  587. self.config.report_images_and_videos
  588. and self.training_iteration % 100 == 0
  589. ),
  590. )
  591. # Log videos showing some of the dreamed trajectories and compare them with the
  592. # actual trajectories from the train batch.
  593. # Only every n iterations and only for the first sampled batch row AND first ts.
  594. # (videos are `config.horizon_H` frames long originating from the observation
  595. # at B=0 and T=0 in the train batch).
  596. report_dreamed_eval_trajectory_vs_samples(
  597. metrics=self.metrics,
  598. sample=sample,
  599. burn_in_T=0,
  600. dreamed_T=self.config.horizon_H + 1,
  601. dreamer_model=self.env_runner.module.dreamer_model,
  602. symlog_obs=do_symlog_obs(
  603. single_observation_space,
  604. self.config.symlog_obs,
  605. ),
  606. do_report=(
  607. self.config.report_dream_data and self.training_iteration % 100 == 0
  608. ),
  609. framework=self.config.framework_str,
  610. )
  611. # Update weights - after learning on the LearnerGroup - on all EnvRunner
  612. # workers.
  613. with self.metrics.log_time((TIMERS, SYNCH_WORKER_WEIGHTS_TIMER)):
  614. # Only necessary if RLModule is not shared between (local) EnvRunner and
  615. # (local) Learner.
  616. # if not self.config.share_module_between_env_runner_and_learner:
  617. self.metrics.log_value(NUM_SYNCH_WORKER_WEIGHTS, 1, reduce="sum")
  618. self.env_runner_group.sync_weights(
  619. from_worker_or_learner_group=self.learner_group,
  620. inference_only=True,
  621. )
  622. # Add train results and the actual training ratio to stats. The latter should
  623. # be close to the configured `training_ratio`.
  624. self.metrics.log_value("actual_training_ratio", self.training_ratio, window=1)
  625. @property
  626. def training_ratio(self) -> float:
  627. """Returns the actual training ratio of this Algorithm (not the configured one).
  628. The training ratio is copmuted by dividing the total number of steps
  629. trained thus far (replayed from the buffer) over the total number of actual
  630. env steps taken thus far.
  631. """
  632. eps = 0.0001
  633. return self.metrics.peek(NUM_ENV_STEPS_TRAINED_LIFETIME, default=0) / (
  634. (
  635. self.metrics.peek(
  636. (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME),
  637. default=eps,
  638. )
  639. or eps
  640. )
  641. )
  642. # TODO (sven): Remove this once DreamerV3 is on the new SingleAgentEnvRunner.
  643. @PublicAPI
  644. def __setstate__(self, state) -> None:
  645. """Sts the algorithm to the provided state
  646. Args:
  647. state: The state dictionary to restore this `DreamerV3` instance to.
  648. `state` may have been returned by a call to an `Algorithm`'s
  649. `__getstate__()` method.
  650. """
  651. # Call the `Algorithm`'s `__setstate__()` method.
  652. super().__setstate__(state=state)
  653. # Assign the module to the local `EnvRunner` if sharing is enabled.
  654. # Note, in `Learner.restore_from_path()` the module is first deleted
  655. # and then a new one is built - therefore the worker has no
  656. # longer a copy of the learner.
  657. if self.config.share_module_between_env_runner_and_learner:
  658. assert id(self.env_runner.module) != id(
  659. self.learner_group._learner.module[DEFAULT_MODULE_ID]
  660. )
  661. self.env_runner.module = self.learner_group._learner.module[
  662. DEFAULT_MODULE_ID
  663. ]