env_runner_v2.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233
  1. import logging
  2. import time
  3. from collections import defaultdict
  4. from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union
  5. import numpy as np
  6. import tree # pip install dm_tree
  7. from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
  8. from ray.rllib.env.external_env import ExternalEnvWrapper
  9. from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
  10. from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
  11. from ray.rllib.evaluation.episode_v2 import EpisodeV2
  12. from ray.rllib.evaluation.metrics import RolloutMetrics
  13. from ray.rllib.models.preprocessors import Preprocessor
  14. from ray.rllib.policy.policy import Policy
  15. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
  16. from ray.rllib.utils.annotations import OldAPIStack
  17. from ray.rllib.utils.filter import Filter
  18. from ray.rllib.utils.numpy import convert_to_numpy
  19. from ray.rllib.utils.spaces.space_utils import get_original_space, unbatch
  20. from ray.rllib.utils.typing import (
  21. ActionConnectorDataType,
  22. AgentConnectorDataType,
  23. AgentID,
  24. EnvActionType,
  25. EnvID,
  26. EnvInfoDict,
  27. EnvObsType,
  28. MultiAgentDict,
  29. MultiEnvDict,
  30. PolicyID,
  31. PolicyOutputType,
  32. SampleBatchType,
  33. StateBatches,
  34. TensorStructType,
  35. )
  36. from ray.util.debug import log_once
  37. if TYPE_CHECKING:
  38. from gymnasium.envs.classic_control.rendering import SimpleImageViewer
  39. from ray.rllib.callbacks.callbacks import RLlibCallback
  40. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  41. logger = logging.getLogger(__name__)
  42. MIN_LARGE_BATCH_THRESHOLD = 1000
  43. DEFAULT_LARGE_BATCH_THRESHOLD = 5000
  44. MS_TO_SEC = 1000.0
  45. @OldAPIStack
  46. class _PerfStats:
  47. """Sampler perf stats that will be included in rollout metrics."""
  48. def __init__(self, ema_coef: Optional[float] = None):
  49. # If not None, enable Exponential Moving Average mode.
  50. # The way we update stats is by:
  51. # updated = (1 - ema_coef) * old + ema_coef * new
  52. # In general provides more responsive stats about sampler performance.
  53. # TODO(jungong) : make ema the default (only) mode if it works well.
  54. self.ema_coef = ema_coef
  55. self.iters = 0
  56. self.raw_obs_processing_time = 0.0
  57. self.inference_time = 0.0
  58. self.action_processing_time = 0.0
  59. self.env_wait_time = 0.0
  60. self.env_render_time = 0.0
  61. def incr(self, field: str, value: Union[int, float]):
  62. if field == "iters":
  63. self.iters += value
  64. return
  65. # All the other fields support either global average or ema mode.
  66. if self.ema_coef is None:
  67. # Global average.
  68. self.__dict__[field] += value
  69. else:
  70. self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
  71. field
  72. ] + self.ema_coef * value
  73. def _get_avg(self):
  74. # Mean multiplicator (1000 = sec -> ms).
  75. factor = MS_TO_SEC / self.iters
  76. return {
  77. # Raw observation preprocessing.
  78. "mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
  79. # Computing actions through policy.
  80. "mean_inference_ms": self.inference_time * factor,
  81. # Processing actions (to be sent to env, e.g. clipping).
  82. "mean_action_processing_ms": self.action_processing_time * factor,
  83. # Waiting for environment (during poll).
  84. "mean_env_wait_ms": self.env_wait_time * factor,
  85. # Environment rendering (False by default).
  86. "mean_env_render_ms": self.env_render_time * factor,
  87. }
  88. def _get_ema(self):
  89. # In EMA mode, stats are already (exponentially) averaged,
  90. # hence we only need to do the sec -> ms conversion here.
  91. return {
  92. # Raw observation preprocessing.
  93. "mean_raw_obs_processing_ms": self.raw_obs_processing_time * MS_TO_SEC,
  94. # Computing actions through policy.
  95. "mean_inference_ms": self.inference_time * MS_TO_SEC,
  96. # Processing actions (to be sent to env, e.g. clipping).
  97. "mean_action_processing_ms": self.action_processing_time * MS_TO_SEC,
  98. # Waiting for environment (during poll).
  99. "mean_env_wait_ms": self.env_wait_time * MS_TO_SEC,
  100. # Environment rendering (False by default).
  101. "mean_env_render_ms": self.env_render_time * MS_TO_SEC,
  102. }
  103. def get(self):
  104. if self.ema_coef is None:
  105. return self._get_avg()
  106. else:
  107. return self._get_ema()
  108. @OldAPIStack
  109. class _NewDefaultDict(defaultdict):
  110. def __missing__(self, env_id):
  111. ret = self[env_id] = self.default_factory(env_id)
  112. return ret
  113. @OldAPIStack
  114. def _build_multi_agent_batch(
  115. episode_id: int,
  116. batch_builder: _PolicyCollectorGroup,
  117. large_batch_threshold: int,
  118. multiple_episodes_in_batch: bool,
  119. ) -> MultiAgentBatch:
  120. """Build MultiAgentBatch from a dict of _PolicyCollectors.
  121. Args:
  122. env_steps: total env steps.
  123. policy_collectors: collected training SampleBatchs by policy.
  124. Returns:
  125. Always returns a sample batch in MultiAgentBatch format.
  126. """
  127. ma_batch = {}
  128. for pid, collector in batch_builder.policy_collectors.items():
  129. if collector.agent_steps <= 0:
  130. continue
  131. if batch_builder.agent_steps > large_batch_threshold and log_once(
  132. "large_batch_warning"
  133. ):
  134. logger.warning(
  135. "More than {} observations in {} env steps for "
  136. "episode {} ".format(
  137. batch_builder.agent_steps, batch_builder.env_steps, episode_id
  138. )
  139. + "are buffered in the sampler. If this is more than you "
  140. "expected, check that that you set a horizon on your "
  141. "environment correctly and that it terminates at some "
  142. "point. Note: In multi-agent environments, "
  143. "`rollout_fragment_length` sets the batch size based on "
  144. "(across-agents) environment steps, not the steps of "
  145. "individual agents, which can result in unexpectedly "
  146. "large batches."
  147. + (
  148. "Also, you may be waiting for your Env to "
  149. "terminate (batch_mode=`complete_episodes`). Make sure "
  150. "it does at some point."
  151. if not multiple_episodes_in_batch
  152. else ""
  153. )
  154. )
  155. batch = collector.build()
  156. ma_batch[pid] = batch
  157. # Create the multi agent batch.
  158. return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
  159. @OldAPIStack
  160. def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
  161. """Batch a list of input SampleBatches into a single SampleBatch.
  162. Args:
  163. eval_data: list of SampleBatches.
  164. Returns:
  165. single batched SampleBatch.
  166. """
  167. inference_batch = concat_samples(eval_data)
  168. if "state_in_0" in inference_batch:
  169. batch_size = len(eval_data)
  170. inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
  171. return inference_batch
  172. @OldAPIStack
  173. class EnvRunnerV2:
  174. """Collect experiences from user environment using Connectors."""
  175. def __init__(
  176. self,
  177. worker: "RolloutWorker",
  178. base_env: BaseEnv,
  179. multiple_episodes_in_batch: bool,
  180. callbacks: "RLlibCallback",
  181. perf_stats: _PerfStats,
  182. rollout_fragment_length: int = 200,
  183. count_steps_by: str = "env_steps",
  184. render: bool = None,
  185. ):
  186. """
  187. Args:
  188. worker: Reference to the current rollout worker.
  189. base_env: Env implementing BaseEnv.
  190. multiple_episodes_in_batch: Whether to pack multiple
  191. episodes into each batch. This guarantees batches will be exactly
  192. `rollout_fragment_length` in size.
  193. callbacks: User callbacks to run on episode events.
  194. perf_stats: Record perf stats into this object.
  195. rollout_fragment_length: The length of a fragment to collect
  196. before building a SampleBatch from the data and resetting
  197. the SampleBatchBuilder object.
  198. count_steps_by: One of "env_steps" (default) or "agent_steps".
  199. Use "agent_steps", if you want rollout lengths to be counted
  200. by individual agent steps. In a multi-agent env,
  201. a single env_step contains one or more agent_steps, depending
  202. on how many agents are present at any given time in the
  203. ongoing episode.
  204. render: Whether to try to render the environment after each
  205. step.
  206. """
  207. self._worker = worker
  208. if isinstance(base_env, ExternalEnvWrapper):
  209. raise ValueError(
  210. "Policies using the new Connector API do not support ExternalEnv."
  211. )
  212. self._base_env = base_env
  213. self._multiple_episodes_in_batch = multiple_episodes_in_batch
  214. self._callbacks = callbacks
  215. self._perf_stats = perf_stats
  216. self._rollout_fragment_length = rollout_fragment_length
  217. self._count_steps_by = count_steps_by
  218. self._render = render
  219. # May be populated for image rendering.
  220. self._simple_image_viewer: Optional[
  221. "SimpleImageViewer"
  222. ] = self._get_simple_image_viewer()
  223. # Keeps track of active episodes.
  224. self._active_episodes: Dict[EnvID, EpisodeV2] = {}
  225. self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
  226. self._new_batch_builder
  227. )
  228. self._large_batch_threshold: int = (
  229. max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
  230. if self._rollout_fragment_length != float("inf")
  231. else DEFAULT_LARGE_BATCH_THRESHOLD
  232. )
  233. def _get_simple_image_viewer(self):
  234. """Maybe construct a SimpleImageViewer instance for episode rendering."""
  235. # Try to render the env, if required.
  236. if not self._render:
  237. return None
  238. try:
  239. from gymnasium.envs.classic_control.rendering import SimpleImageViewer
  240. return SimpleImageViewer()
  241. except (ImportError, ModuleNotFoundError):
  242. self._render = False # disable rendering
  243. logger.warning(
  244. "Could not import gymnasium.envs.classic_control."
  245. "rendering! Try `pip install gymnasium[all]`."
  246. )
  247. return None
  248. def _call_on_episode_start(self, episode, env_id):
  249. # Call each policy's Exploration.on_episode_start method.
  250. # Note: This may break the exploration (e.g. ParameterNoise) of
  251. # policies in the `policy_map` that have not been recently used
  252. # (and are therefore stashed to disk). However, we certainly do not
  253. # want to loop through all (even stashed) policies here as that
  254. # would counter the purpose of the LRU policy caching.
  255. for p in self._worker.policy_map.cache.values():
  256. if getattr(p, "exploration", None) is not None:
  257. p.exploration.on_episode_start(
  258. policy=p,
  259. environment=self._base_env,
  260. episode=episode,
  261. tf_sess=p.get_session(),
  262. )
  263. # Call `on_episode_start()` callback.
  264. self._callbacks.on_episode_start(
  265. worker=self._worker,
  266. base_env=self._base_env,
  267. policies=self._worker.policy_map,
  268. env_index=env_id,
  269. episode=episode,
  270. )
  271. def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
  272. """Create a new batch builder.
  273. We create a _PolicyCollectorGroup based on the full policy_map
  274. as the batch builder.
  275. """
  276. return _PolicyCollectorGroup(self._worker.policy_map)
  277. def run(self) -> Iterator[SampleBatchType]:
  278. """Samples and yields training episodes continuously.
  279. Yields:
  280. Object containing state, action, reward, terminal condition,
  281. and other fields as dictated by `policy`.
  282. """
  283. while True:
  284. outputs = self.step()
  285. for o in outputs:
  286. yield o
  287. def step(self) -> List[SampleBatchType]:
  288. """Samples training episodes by stepping through environments."""
  289. self._perf_stats.incr("iters", 1)
  290. t0 = time.time()
  291. # Get observations from all ready agents.
  292. # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
  293. (
  294. unfiltered_obs,
  295. rewards,
  296. terminateds,
  297. truncateds,
  298. infos,
  299. off_policy_actions,
  300. ) = self._base_env.poll()
  301. env_poll_time = time.time() - t0
  302. # Process observations and prepare for policy evaluation.
  303. t1 = time.time()
  304. # types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
  305. # List[Union[RolloutMetrics, SampleBatchType]]
  306. active_envs, to_eval, outputs = self._process_observations(
  307. unfiltered_obs=unfiltered_obs,
  308. rewards=rewards,
  309. terminateds=terminateds,
  310. truncateds=truncateds,
  311. infos=infos,
  312. )
  313. self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)
  314. # Do batched policy eval (accross vectorized envs).
  315. t2 = time.time()
  316. # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
  317. eval_results = self._do_policy_eval(to_eval=to_eval)
  318. self._perf_stats.incr("inference_time", time.time() - t2)
  319. # Process results and update episode state.
  320. t3 = time.time()
  321. actions_to_send: Dict[
  322. EnvID, Dict[AgentID, EnvActionType]
  323. ] = self._process_policy_eval_results(
  324. active_envs=active_envs,
  325. to_eval=to_eval,
  326. eval_results=eval_results,
  327. off_policy_actions=off_policy_actions,
  328. )
  329. self._perf_stats.incr("action_processing_time", time.time() - t3)
  330. # Return computed actions to ready envs. We also send to envs that have
  331. # taken off-policy actions; those envs are free to ignore the action.
  332. t4 = time.time()
  333. self._base_env.send_actions(actions_to_send)
  334. self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)
  335. self._maybe_render()
  336. return outputs
  337. def _get_rollout_metrics(
  338. self, episode: EpisodeV2, policy_map: Dict[str, Policy]
  339. ) -> List[RolloutMetrics]:
  340. """Get rollout metrics from completed episode."""
  341. # TODO(jungong) : why do we need to handle atari metrics differently?
  342. # Can we unify atari and normal env metrics?
  343. atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
  344. if atari_metrics is not None:
  345. for m in atari_metrics:
  346. m._replace(custom_metrics=episode.custom_metrics)
  347. return atari_metrics
  348. # Create connector metrics
  349. connector_metrics = {}
  350. active_agents = episode.get_agents()
  351. for agent in active_agents:
  352. policy_id = episode.policy_for(agent)
  353. policy = episode.policy_map[policy_id]
  354. connector_metrics[policy_id] = policy.get_connector_metrics()
  355. # Otherwise, return RolloutMetrics for the episode.
  356. return [
  357. RolloutMetrics(
  358. episode_length=episode.length,
  359. episode_reward=episode.total_reward,
  360. agent_rewards=dict(episode.agent_rewards),
  361. custom_metrics=episode.custom_metrics,
  362. perf_stats={},
  363. hist_data=episode.hist_data,
  364. media=episode.media,
  365. connector_metrics=connector_metrics,
  366. )
  367. ]
    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        terminateds: MultiEnvDict,
        truncateds: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[
        Set[EnvID],
        Dict[PolicyID, List[AgentConnectorDataType]],
        List[Union[RolloutMetrics, SampleBatchType]],
    ]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        For every sub-environment in `unfiltered_obs`:
        - Advance (or create) its episode.
        - Run the per-agent data through the policies' agent connectors.
        - Queue still-active agents for policy evaluation.
        - Hand finished episodes to `_handle_done_episode()`.

        Args:
            unfiltered_obs: The unfiltered, raw observations from the BaseEnv
                (vectorized, possibly multi-agent). Dict of dict: By env index,
                then agent ID, then mapped to actual obs. An env's value may
                also be an Exception, signaling a faulty sub-environment.
            rewards: The rewards MultiEnvDict of the BaseEnv.
            terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
            truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
            infos: The MultiEnvDict of infos dicts of the BaseEnv.

        Returns:
            A tuple of:
                A set of env IDs that were active during this step.
                AgentConnectorDataType for active agents for policy evaluation,
                grouped by policy ID.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        # Note that we need to track envs that are active during this round explicitly,
        # just to be confident which envs require us to send at least an empty action
        # dict to.
        # We can not get this from the _active_episode or to_eval lists because
        # 1. All envs are not required to step during every single step. And
        # 2. to_eval only contains data for the agents that are still active. An env may
        # be active but all agents are done during the step.
        active_envs: Set[EnvID] = set()
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check for env_id having returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
            # one of its sub-environments is faulty and should be restarted (and the
            # ongoing episode should not be used for training).
            if isinstance(env_obs, Exception):
                assert terminateds[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the terminateds[__all__] flag must also be set to "
                    "True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self._handle_done_episode(
                    env_id=env_id,
                    env_obs_or_exception=env_obs,
                    is_done=True,
                    active_envs=active_envs,
                    to_eval=to_eval,
                    outputs=outputs,
                )
                continue

            # Look up this env's ongoing episode, or create and register one.
            if env_id not in self._active_episodes:
                episode: EpisodeV2 = self.create_episode(env_id)
                self._active_episodes[env_id] = episode
            else:
                episode: EpisodeV2 = self._active_episodes[env_id]
            # If this episode is brand-new, call the episode start callback(s).
            # Note: EpisodeV2s are initialized with length=-1 (before the reset).
            if not episode.has_init_obs():
                self._call_on_episode_start(episode, env_id)

            # Check episode termination conditions.
            if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
                all_agents_done = True
            else:
                all_agents_done = False
                # Only envs whose episode continues are marked active here.
                active_envs.add(env_id)

            # Special handling of common info dict.
            episode.set_last_info("__common__", infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is terminated or truncated.
            agent_terminateds = {}
            agent_truncateds = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                # Per-agent flags also honor the env-wide "__all__" flags.
                agent_terminated = bool(
                    terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
                )
                agent_terminateds[agent_id] = agent_terminated
                agent_truncated = bool(
                    truncateds[env_id]["__all__"]
                    or truncateds[env_id].get(agent_id, False)
                )
                agent_truncateds[agent_id] = agent_truncated

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs(agent_id) and (
                    agent_terminated or agent_truncated
                ):
                    continue

                values_dict = {
                    SampleBatch.T: episode.length,  # Episodes start at -1 before we
                    # add the initial obs. After that, we infer from initial obs at
                    # t=0 since that will be our new episode.length.
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be populated by
                    # StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.TERMINATEDS: agent_terminated,
                    # Was the episode truncated artificially
                    # (e.g. b/c of some time limit)?
                    SampleBatch.TRUNCATEDS: agent_truncated,
                    SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS: obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Let's check to see if there are any agents that haven't got the
                # last obs yet. If there are, we have to create fake-last
                # observations for them. (the environment is not required to do so if
                # terminateds[__all__]==True or truncateds[__all__]==True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if (
                        agent_terminateds.get(agent_id, False)
                        or agent_truncateds.get(agent_id, False)
                        or episode.is_done(agent_id)
                    ):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake observation by sampling the original env
                    # observation space.
                    obs_space = get_original_space(policy.observation_space)
                    # Although there is no obs for this agent, there may be
                    # good rewards and info dicts for it.
                    # This is the case for e.g. OpenSpiel games, where a reward
                    # is only earned with the last step, but the obs for that
                    # step is {}.
                    reward = rewards[env_id].get(agent_id, 0.0)
                    info = infos[env_id].get(agent_id, {})
                    values_dict = {
                        SampleBatch.T: episode.length,
                        SampleBatch.ENV_ID: env_id,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                        # TODO(sven): These should be the summed-up(!) rewards since the
                        # last observation received for this agent.
                        SampleBatch.REWARDS: reward,
                        # NOTE(review): TERMINATEDS is hard-coded True even if the
                        # episode ended via truncateds["__all__"]; TRUNCATEDS only
                        # reflects this agent's own flag. Confirm this is intended.
                        SampleBatch.TERMINATEDS: True,
                        SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
                        SampleBatch.INFOS: info,
                        SampleBatch.NEXT_OBS: obs_space.sample(),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # Run agent connectors.
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (
                    policy.agent_connectors
                ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]

                # For all agents mapped to policy_id, run their data
                # through agent_connectors.
                processed = policy.agent_connectors(acd_list)

                for d in processed:
                    # Record transition info if applicable.
                    # An agent's first observation starts its trajectory; later
                    # ones are recorded as a full (action, reward, done, next obs)
                    # transition.
                    if not episode.has_init_obs(d.agent_id):
                        episode.add_init_obs(
                            agent_id=d.agent_id,
                            init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                            init_infos=d.data.raw_dict[SampleBatch.INFOS],
                            t=d.data.raw_dict[SampleBatch.T],
                        )
                    else:
                        episode.add_action_reward_done_next_obs(
                            d.agent_id, d.data.raw_dict
                        )

                    # Need to evaluate next actions.
                    if not (
                        all_agents_done
                        or agent_terminateds.get(d.agent_id, False)
                        or agent_truncateds.get(d.agent_id, False)
                        or episode.is_done(d.agent_id)
                    ):
                        # Add to eval set if env is not done and this particular agent
                        # is also not done.
                        item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
                        to_eval[policy_id].append(item)

            # Finished advancing episode by 1 step, mark it so.
            episode.step()

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if episode.length > 0:
                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is terminated/truncated for all agents
            # (terminateds[__all__] == True or truncateds[__all__] == True).
            if all_agents_done:
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                # Note: the `is_done` argument below is always True inside this
                # branch (same condition that set `all_agents_done`).
                self._handle_done_episode(
                    env_id,
                    env_obs,
                    terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
                    active_envs,
                    to_eval,
                    outputs,
                )

            # Try to build something.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode
                )
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    # (`self._batch_builders` is a _NewDefaultDict, so a fresh
                    # builder is created on the next access for this env.)
                    del self._batch_builders[env_id]

        return active_envs, to_eval, outputs
  611. def _build_done_episode(
  612. self,
  613. env_id: EnvID,
  614. is_done: bool,
  615. outputs: List[SampleBatchType],
  616. ):
  617. """Builds a MultiAgentSampleBatch from the episode and adds it to outputs.
  618. Args:
  619. env_id: The env id.
  620. is_done: Whether the env is done.
  621. outputs: The list of outputs to add the
  622. """
  623. episode: EpisodeV2 = self._active_episodes[env_id]
  624. batch_builder = self._batch_builders[env_id]
  625. episode.postprocess_episode(
  626. batch_builder=batch_builder,
  627. is_done=is_done,
  628. check_dones=is_done,
  629. )
  630. # If, we are not allowed to pack the next episode into the same
  631. # SampleBatch (batch_mode=complete_episodes) -> Build the
  632. # MultiAgentBatch from a single episode and add it to "outputs".
  633. # Otherwise, just postprocess and continue collecting across
  634. # episodes.
  635. if not self._multiple_episodes_in_batch:
  636. ma_sample_batch = _build_multi_agent_batch(
  637. episode.episode_id,
  638. batch_builder,
  639. self._large_batch_threshold,
  640. self._multiple_episodes_in_batch,
  641. )
  642. if ma_sample_batch:
  643. outputs.append(ma_sample_batch)
  644. # SampleBatch built from data collected by batch_builder.
  645. # Clean up and delete the batch_builder.
  646. del self._batch_builders[env_id]
  647. def __process_resetted_obs_for_eval(
  648. self,
  649. env_id: EnvID,
  650. obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
  651. infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
  652. episode: EpisodeV2,
  653. to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
  654. ):
  655. """Process resetted obs through agent connectors for policy eval.
  656. Args:
  657. env_id: The env id.
  658. obs: The Resetted obs.
  659. episode: New episode.
  660. to_eval: List of agent connector data for policy eval.
  661. """
  662. per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
  663. # types: AgentID, EnvObsType
  664. for agent_id, raw_obs in obs[env_id].items():
  665. policy_id: PolicyID = episode.policy_for(agent_id)
  666. per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))
  667. for policy_id, agents_obs in per_policy_resetted_obs.items():
  668. policy = self._worker.policy_map[policy_id]
  669. acd_list: List[AgentConnectorDataType] = [
  670. AgentConnectorDataType(
  671. env_id,
  672. agent_id,
  673. {
  674. SampleBatch.NEXT_OBS: obs,
  675. SampleBatch.INFOS: infos,
  676. SampleBatch.T: episode.length,
  677. SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
  678. },
  679. )
  680. for agent_id, obs in agents_obs
  681. ]
  682. # Call agent connectors on these initial obs.
  683. processed = policy.agent_connectors(acd_list)
  684. for d in processed:
  685. episode.add_init_obs(
  686. agent_id=d.agent_id,
  687. init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
  688. init_infos=d.data.raw_dict[SampleBatch.INFOS],
  689. t=d.data.raw_dict[SampleBatch.T],
  690. )
  691. to_eval[policy_id].append(d)
    def _handle_done_episode(
        self,
        env_id: EnvID,
        env_obs_or_exception: Union[MultiAgentDict, Exception],
        is_done: bool,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        outputs: List[SampleBatchType],
    ) -> None:
        """Handle an all-finished (or failed) episode and reset its sub-env.

        Steps, in order:
        1. If the sub-env produced an Exception, report a faulty episode.
           Otherwise, add the episode's rollout metrics to `outputs` and build
           its samples via `_build_done_episode`.
        2. End the old episode (`end_episode`) and create a new one.
        3. Reset the sub-env via `try_reset`, retrying (and reporting a faulty
           episode for each attempt) as long as the reset yields an Exception.
        4. Reset the agent connectors of all in-memory policies for `env_id`.
        5. If `try_reset` returned actual observations (neither None nor
           ASYNC_RESET_RETURN), register the new episode, call the
           episode-start callbacks, queue the reset obs for policy eval, and
           mark `env_id` as active again.

        Args:
            env_id: Environment ID.
            env_obs_or_exception: Last per-environment observation or Exception.
            is_done: If all agents are done.
            active_envs: Set of active env ids. `env_id` is added to it when a
                new episode is started in step 5.
            to_eval: Output container for policy eval data.
            outputs: Output container for collected sample batches and rollout
                metrics.
        """
        if isinstance(env_obs_or_exception, Exception):
            episode_or_exception: Exception = env_obs_or_exception
            # Tell the sampler we have got a faulty episode.
            outputs.append(RolloutMetrics(episode_faulty=True))
        else:
            episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
            # Add rollout metrics.
            outputs.extend(
                self._get_rollout_metrics(
                    episode_or_exception, policy_map=self._worker.policy_map
                )
            )
            # Output the collected episode after adding rollout metrics so that we
            # always fetch metrics with RolloutWorker before we fetch samples.
            # This is because we need to behave like env_runner() for now.
            self._build_done_episode(env_id, is_done, outputs)

        # Clean up and delete the post-processed episode now that we have collected
        # its data. This also removes `env_id` from `self._active_episodes`,
        # which `create_episode()` below asserts.
        self.end_episode(env_id, episode_or_exception)
        # Create a new episode instance (before we reset the sub-environment).
        new_episode: EpisodeV2 = self.create_episode(env_id)

        # The sub environment at index `env_id` might throw an exception
        # during the following `try_reset()` attempt. If configured with
        # `restart_failed_sub_environments=True`, the BaseEnv will restart
        # the affected sub environment (create a new one using its c'tor) and
        # must reset the recreated sub env right after that.
        # Should the sub environment fail indefinitely during these
        # repeated reset attempts, the entire worker will be blocked.
        # This would be ok, b/c the alternative would be the worker crashing
        # entirely.
        while True:
            resetted_obs, resetted_infos = self._base_env.try_reset(env_id)

            # Stop retrying on: no reset result, an async reset, or a reset
            # that produced real (non-Exception) observations.
            if (
                resetted_obs is None
                or resetted_obs == ASYNC_RESET_RETURN
                or not isinstance(resetted_obs[env_id], Exception)
            ):
                break
            else:
                # Report a faulty episode.
                outputs.append(RolloutMetrics(episode_faulty=True))

        # Reset connector state if this is a hard reset.
        # NOTE: this runs unconditionally, for every policy currently held in
        # `policy_map.cache`.
        for p in self._worker.policy_map.cache.values():
            p.agent_connectors.reset(env_id)

        # Creates a new episode if this is not async return.
        # If reset is async, we will get its result in some future poll.
        # NOTE(review): in the None/async case `new_episode` (already announced
        # via `on_episode_created`) is not stored here — presumably a fresh
        # episode is created when the async result arrives; confirm.
        if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
            self._active_episodes[env_id] = new_episode
            self._call_on_episode_start(new_episode, env_id)

            self.__process_resetted_obs_for_eval(
                env_id,
                resetted_obs,
                resetted_infos,
                new_episode,
                to_eval,
            )

            # Step after adding initial obs. This will give us 0 env and agent step.
            new_episode.step()
            active_envs.add(env_id)
  771. def create_episode(self, env_id: EnvID) -> EpisodeV2:
  772. """Creates a new EpisodeV2 instance and returns it.
  773. Calls `on_episode_created` callbacks, but does NOT reset the respective
  774. sub-environment yet.
  775. Args:
  776. env_id: Env ID.
  777. Returns:
  778. The newly created EpisodeV2 instance.
  779. """
  780. # Make sure we currently don't have an active episode under this env ID.
  781. assert env_id not in self._active_episodes
  782. # Create a new episode under the same `env_id` and call the
  783. # `on_episode_created` callbacks.
  784. new_episode = EpisodeV2(
  785. env_id,
  786. self._worker.policy_map,
  787. self._worker.policy_mapping_fn,
  788. worker=self._worker,
  789. callbacks=self._callbacks,
  790. )
  791. # Call `on_episode_created()` callback.
  792. self._callbacks.on_episode_created(
  793. worker=self._worker,
  794. base_env=self._base_env,
  795. policies=self._worker.policy_map,
  796. env_index=env_id,
  797. episode=new_episode,
  798. )
  799. return new_episode
  800. def end_episode(
  801. self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
  802. ):
  803. """Cleans up an episode that has finished.
  804. Args:
  805. env_id: Env ID.
  806. episode_or_exception: Instance of an episode if it finished successfully.
  807. Otherwise, the exception that was thrown,
  808. """
  809. # Signal the end of an episode, either successfully with an Episode or
  810. # unsuccessfully with an Exception.
  811. self._callbacks.on_episode_end(
  812. worker=self._worker,
  813. base_env=self._base_env,
  814. policies=self._worker.policy_map,
  815. episode=episode_or_exception,
  816. env_index=env_id,
  817. )
  818. # Call each (in-memory) policy's Exploration.on_episode_end
  819. # method.
  820. # Note: This may break the exploration (e.g. ParameterNoise) of
  821. # policies in the `policy_map` that have not been recently used
  822. # (and are therefore stashed to disk). However, we certainly do not
  823. # want to loop through all (even stashed) policies here as that
  824. # would counter the purpose of the LRU policy caching.
  825. for p in self._worker.policy_map.cache.values():
  826. if getattr(p, "exploration", None) is not None:
  827. p.exploration.on_episode_end(
  828. policy=p,
  829. environment=self._base_env,
  830. episode=episode_or_exception,
  831. tf_sess=p.get_session(),
  832. )
  833. if isinstance(episode_or_exception, EpisodeV2):
  834. episode = episode_or_exception
  835. if episode.total_agent_steps == 0:
  836. # if the key does not exist it means that throughout the episode all
  837. # observations were empty (i.e. there was no agent in the env)
  838. msg = (
  839. f"Data from episode {episode.episode_id} does not show any agent "
  840. f"interactions. Hint: Make sure for at least one timestep in the "
  841. f"episode, env.step() returns non-empty values."
  842. )
  843. raise ValueError(msg)
  844. # Clean up the episode and batch_builder for this env id.
  845. if env_id in self._active_episodes:
  846. del self._active_episodes[env_id]
  847. def _try_build_truncated_episode_multi_agent_batch(
  848. self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
  849. ) -> Union[None, SampleBatch, MultiAgentBatch]:
  850. # Measure batch size in env-steps.
  851. if self._count_steps_by == "env_steps":
  852. built_steps = batch_builder.env_steps
  853. ongoing_steps = episode.active_env_steps
  854. # Measure batch-size in agent-steps.
  855. else:
  856. built_steps = batch_builder.agent_steps
  857. ongoing_steps = episode.active_agent_steps
  858. # Reached the fragment-len -> We should build an MA-Batch.
  859. if built_steps + ongoing_steps >= self._rollout_fragment_length:
  860. if self._count_steps_by != "agent_steps":
  861. assert built_steps + ongoing_steps == self._rollout_fragment_length, (
  862. f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
  863. f"rollout_fragment_length ({self._rollout_fragment_length})."
  864. )
  865. # If we reached the fragment-len only because of `episode_id`
  866. # (still ongoing) -> postprocess `episode_id` first.
  867. if built_steps < self._rollout_fragment_length:
  868. episode.postprocess_episode(batch_builder=batch_builder, is_done=False)
  869. # If builder has collected some data,
  870. # build the MA-batch and add to return values.
  871. if batch_builder.agent_steps > 0:
  872. return _build_multi_agent_batch(
  873. episode.episode_id,
  874. batch_builder,
  875. self._large_batch_threshold,
  876. self._multiple_episodes_in_batch,
  877. )
  878. # No batch-builder:
  879. # We have reached the rollout-fragment length w/o any agent
  880. # steps! Warn that the environment may never request any
  881. # actions from any agents.
  882. elif log_once("no_agent_steps"):
  883. logger.warning(
  884. "Your environment seems to be stepping w/o ever "
  885. "emitting agent observations (agents are never "
  886. "requested to act)!"
  887. )
  888. return None
  889. def _do_policy_eval(
  890. self,
  891. to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
  892. ) -> Dict[PolicyID, PolicyOutputType]:
  893. """Call compute_actions on collected episode data to get next action.
  894. Args:
  895. to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects
  896. (items in these lists will be the batch's items for the model
  897. forward pass).
  898. Returns:
  899. Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
  900. """
  901. policies = self._worker.policy_map
  902. # In case policy map has changed, try to find the new policy that
  903. # should handle all these per-agent eval data.
  904. # Throws exception if these agents are mapped to multiple different
  905. # policies now.
  906. def _try_find_policy_again(eval_data: AgentConnectorDataType):
  907. policy_id = None
  908. for d in eval_data:
  909. episode = self._active_episodes[d.env_id]
  910. # Force refresh policy mapping on the episode.
  911. pid = episode.policy_for(d.agent_id, refresh=True)
  912. if policy_id is not None and pid != policy_id:
  913. raise ValueError(
  914. "Policy map changed. The list of eval data that was handled "
  915. f"by a same policy is now handled by policy {pid} "
  916. "and {policy_id}. "
  917. "Please don't do this in the middle of an episode."
  918. )
  919. policy_id = pid
  920. return _get_or_raise(self._worker.policy_map, policy_id)
  921. eval_results: Dict[PolicyID, TensorStructType] = {}
  922. for policy_id, eval_data in to_eval.items():
  923. # In case the policyID has been removed from this worker, we need to
  924. # re-assign policy_id and re-lookup the Policy object to use.
  925. try:
  926. policy: Policy = _get_or_raise(policies, policy_id)
  927. except ValueError:
  928. # policy_mapping_fn from the worker may have already been
  929. # changed (mapping fn not staying constant within one episode).
  930. policy: Policy = _try_find_policy_again(eval_data)
  931. input_dict = _batch_inference_sample_batches(
  932. [d.data.sample_batch for d in eval_data]
  933. )
  934. eval_results[policy_id] = policy.compute_actions_from_input_dict(
  935. input_dict,
  936. timestep=policy.global_timestep,
  937. episodes=[self._active_episodes[t.env_id] for t in eval_data],
  938. )
  939. return eval_results
  940. def _process_policy_eval_results(
  941. self,
  942. active_envs: Set[EnvID],
  943. to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
  944. eval_results: Dict[PolicyID, PolicyOutputType],
  945. off_policy_actions: MultiEnvDict,
  946. ):
  947. """Process the output of policy neural network evaluation.
  948. Records policy evaluation results into agent connectors and
  949. returns replies to send back to agents in the env.
  950. Args:
  951. active_envs: Set of env IDs that are still active.
  952. to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
  953. eval_results: Mapping of policy IDs to list of
  954. actions, rnn-out states, extra-action-fetches dicts.
  955. off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
  956. off-policy-action, returned by a `BaseEnv.poll()` call.
  957. Returns:
  958. Nested dict of env id -> agent id -> actions to be sent to
  959. Env (np.ndarrays).
  960. """
  961. actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
  962. for env_id in active_envs:
  963. actions_to_send[env_id] = {} # at minimum send empty dict
  964. # types: PolicyID, List[AgentConnectorDataType]
  965. for policy_id, eval_data in to_eval.items():
  966. actions: TensorStructType = eval_results[policy_id][0]
  967. actions = convert_to_numpy(actions)
  968. rnn_out: StateBatches = eval_results[policy_id][1]
  969. extra_action_out: dict = eval_results[policy_id][2]
  970. # In case actions is a list (representing the 0th dim of a batch of
  971. # primitive actions), try converting it first.
  972. if isinstance(actions, list):
  973. actions = np.array(actions)
  974. # Split action-component batches into single action rows.
  975. actions: List[EnvActionType] = unbatch(actions)
  976. policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
  977. assert (
  978. policy.agent_connectors and policy.action_connectors
  979. ), "EnvRunnerV2 requires action connectors to work."
  980. # types: int, EnvActionType
  981. for i, action in enumerate(actions):
  982. env_id: int = eval_data[i].env_id
  983. agent_id: AgentID = eval_data[i].agent_id
  984. input_dict: TensorStructType = eval_data[i].data.raw_dict
  985. rnn_states: List[StateBatches] = tree.map_structure(
  986. lambda x, i=i: x[i], rnn_out
  987. )
  988. # extra_action_out could be a nested dict
  989. fetches: Dict = tree.map_structure(
  990. lambda x, i=i: x[i], extra_action_out
  991. )
  992. # Post-process policy output by running them through action connectors.
  993. ac_data = ActionConnectorDataType(
  994. env_id, agent_id, input_dict, (action, rnn_states, fetches)
  995. )
  996. action_to_send, rnn_states, fetches = policy.action_connectors(
  997. ac_data
  998. ).output
  999. # The action we want to buffer is the direct output of
  1000. # compute_actions_from_input_dict() here. This is because we want to
  1001. # send the unsqushed actions to the environment while learning and
  1002. # possibly basing subsequent actions on the squashed actions.
  1003. action_to_buffer = (
  1004. action
  1005. if env_id not in off_policy_actions
  1006. or agent_id not in off_policy_actions[env_id]
  1007. else off_policy_actions[env_id][agent_id]
  1008. )
  1009. # Notify agent connectors with this new policy output.
  1010. # Necessary for state buffering agent connectors, for example.
  1011. ac_data: ActionConnectorDataType = ActionConnectorDataType(
  1012. env_id,
  1013. agent_id,
  1014. input_dict,
  1015. (action_to_buffer, rnn_states, fetches),
  1016. )
  1017. policy.agent_connectors.on_policy_output(ac_data)
  1018. assert agent_id not in actions_to_send[env_id]
  1019. actions_to_send[env_id][agent_id] = action_to_send
  1020. return actions_to_send
  1021. def _maybe_render(self):
  1022. """Visualize environment."""
  1023. # Check if we should render.
  1024. if not self._render or not self._simple_image_viewer:
  1025. return
  1026. t5 = time.time()
  1027. # Render can either return an RGB image (uint8 [w x h x 3] numpy
  1028. # array) or take care of rendering itself (returning True).
  1029. rendered = self._base_env.try_render()
  1030. # Rendering returned an image -> Display it in a SimpleImageViewer.
  1031. if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
  1032. self._simple_image_viewer.imshow(rendered)
  1033. elif rendered not in [True, False, None]:
  1034. raise ValueError(
  1035. f"The env's ({self._base_env}) `try_render()` method returned an"
  1036. " unsupported value! Make sure you either return a "
  1037. "uint8/w x h x 3 (RGB) image or handle rendering in a "
  1038. "window and then return `True`."
  1039. )
  1040. self._perf_stats.incr("env_render_time", time.time() - t5)
  1041. def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
  1042. """Atari games have multiple logical episodes, one per life.
  1043. However, for metrics reporting we count full episodes, all lives included.
  1044. """
  1045. sub_environments = base_env.get_sub_environments()
  1046. if not sub_environments:
  1047. return None
  1048. atari_out = []
  1049. for sub_env in sub_environments:
  1050. monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
  1051. if not monitor:
  1052. return None
  1053. for eps_rew, eps_len in monitor.next_episode_results():
  1054. atari_out.append(RolloutMetrics(eps_len, eps_rew))
  1055. return atari_out
  1056. def _get_or_raise(
  1057. mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
  1058. ) -> Union[Policy, Preprocessor, Filter]:
  1059. """Returns an object under key `policy_id` in `mapping`.
  1060. Args:
  1061. mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
  1062. mapping dict from policy id (str) to actual object (Policy,
  1063. Preprocessor, etc.).
  1064. policy_id: The policy ID to lookup.
  1065. Returns:
  1066. Union[Policy, Preprocessor, Filter]: The found object.
  1067. Raises:
  1068. ValueError: If `policy_id` cannot be found in `mapping`.
  1069. """
  1070. if policy_id not in mapping:
  1071. raise ValueError(
  1072. "Could not find policy for agent: PolicyID `{}` not found "
  1073. "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
  1074. )
  1075. return mapping[policy_id]