| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233 |
- import logging
- import time
- from collections import defaultdict
- from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union
- import numpy as np
- import tree # pip install dm_tree
- from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
- from ray.rllib.env.external_env import ExternalEnvWrapper
- from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
- from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
- from ray.rllib.evaluation.episode_v2 import EpisodeV2
- from ray.rllib.evaluation.metrics import RolloutMetrics
- from ray.rllib.models.preprocessors import Preprocessor
- from ray.rllib.policy.policy import Policy
- from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
- from ray.rllib.utils.annotations import OldAPIStack
- from ray.rllib.utils.filter import Filter
- from ray.rllib.utils.numpy import convert_to_numpy
- from ray.rllib.utils.spaces.space_utils import get_original_space, unbatch
- from ray.rllib.utils.typing import (
- ActionConnectorDataType,
- AgentConnectorDataType,
- AgentID,
- EnvActionType,
- EnvID,
- EnvInfoDict,
- EnvObsType,
- MultiAgentDict,
- MultiEnvDict,
- PolicyID,
- PolicyOutputType,
- SampleBatchType,
- StateBatches,
- TensorStructType,
- )
- from ray.util.debug import log_once
- if TYPE_CHECKING:
- from gymnasium.envs.classic_control.rendering import SimpleImageViewer
- from ray.rllib.callbacks.callbacks import RLlibCallback
- from ray.rllib.evaluation.rollout_worker import RolloutWorker
# Module-level logger.
logger = logging.getLogger(__name__)

# Lower bound for the "large batch" warning threshold computed from
# `rollout_fragment_length` in `EnvRunnerV2.__init__`.
MIN_LARGE_BATCH_THRESHOLD = 1_000
# "Large batch" warning threshold used when `rollout_fragment_length` is inf.
DEFAULT_LARGE_BATCH_THRESHOLD = 5_000
# Multiplier converting seconds to milliseconds (despite its name).
MS_TO_SEC = 1e3
@OldAPIStack
class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics.

    Timings are recorded in seconds through `incr()` and reported by `get()`
    in milliseconds.

    Two accumulation modes exist:
    - Global average (``ema_coef is None``): timings are summed, and `get()`
      divides them by the number of iterations recorded via
      ``incr("iters", ...)``.
    - Exponential moving average (``ema_coef`` set): every `incr()` does
      ``updated = (1 - ema_coef) * old + ema_coef * new``. This generally
      gives more responsive stats about sampler performance.
    """

    def __init__(self, ema_coef: Optional[float] = None):
        # If not None, enable Exponential Moving Average mode.
        # TODO(jungong) : make ema the default (only) mode if it works well.
        self.ema_coef = ema_coef
        # Iteration counter; always a plain sum, even in EMA mode.
        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def incr(self, field: str, value: Union[int, float]):
        """Record `value` for the stat named `field`.

        Args:
            field: "iters" (plain counter) or the name of one of the timing
                attributes initialized in `__init__`.
            value: Amount to add (global-average mode) or the newest sample
                to fold into the EMA (EMA mode).

        Raises:
            KeyError: If `field` is not an instance attribute.
        """
        if field == "iters":
            self.iters += value
            return

        # All the other fields support either global average or ema mode.
        stats = self.__dict__
        if self.ema_coef is None:
            # Global average: accumulate; `_get_avg` divides later.
            stats[field] += value
        else:
            stats[field] = (1.0 - self.ema_coef) * stats[field] + (
                self.ema_coef * value
            )

    def _scaled_stats(self, factor: float) -> Dict[str, float]:
        """Return every timing stat multiplied by `factor`, keyed by metric name."""
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }

    def _get_avg(self):
        """Per-iteration means in ms (global-average mode)."""
        # No iterations recorded yet -> nothing to average over. Report zeros
        # instead of raising ZeroDivisionError.
        if not self.iters:
            return self._scaled_stats(0.0)
        # Mean multiplicator (1000 = sec -> ms).
        return self._scaled_stats(MS_TO_SEC / self.iters)

    def _get_ema(self):
        """EMA stats in ms (EMA mode).

        The stats are already (exponentially) averaged, so only the
        sec -> ms conversion is needed.
        """
        return self._scaled_stats(MS_TO_SEC)

    def get(self):
        """Return the current stats as a dict of millisecond values."""
        if self.ema_coef is None:
            return self._get_avg()
        return self._get_ema()
@OldAPIStack
class _NewDefaultDict(defaultdict):
    """A defaultdict whose factory receives the missing key.

    Unlike the standard `defaultdict` (which calls `default_factory()` with no
    arguments), a missing lookup calls `default_factory(key)`. The result is
    stored under `key` and then returned.
    """

    def __missing__(self, env_id):
        created = self.default_factory(env_id)
        self[env_id] = created
        return created
@OldAPIStack
def _build_multi_agent_batch(
    episode_id: int,
    batch_builder: _PolicyCollectorGroup,
    large_batch_threshold: int,
    multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
    """Build a MultiAgentBatch from the _PolicyCollectors of `batch_builder`.

    Args:
        episode_id: ID of the episode the data belongs to. Only used in the
            large-batch warning message.
        batch_builder: Collector group holding one _PolicyCollector per policy,
            plus the total env-step and agent-step counts.
        large_batch_threshold: If `batch_builder.agent_steps` exceeds this, a
            one-time warning about an unexpectedly large batch is logged.
        multiple_episodes_in_batch: Whether several episodes may be packed
            into one batch. If False, the warning additionally hints that the
            env may never terminate.

    Returns:
        Always returns a sample batch in MultiAgentBatch format. Policies whose
        collector holds no agent steps are left out.
    """
    ma_batch = {}
    for pid, collector in batch_builder.policy_collectors.items():
        if collector.agent_steps <= 0:
            continue

        # `log_once` makes sure this is only emitted a single time.
        if batch_builder.agent_steps > large_batch_threshold and log_once(
            "large_batch_warning"
        ):
            logger.warning(
                "More than {} observations in {} env steps for "
                "episode {} ".format(
                    batch_builder.agent_steps, batch_builder.env_steps, episode_id
                )
                + "are buffered in the sampler. If this is more than you "
                "expected, check that you set a horizon on your "
                "environment correctly and that it terminates at some "
                "point. Note: In multi-agent environments, "
                "`rollout_fragment_length` sets the batch size based on "
                "(across-agents) environment steps, not the steps of "
                "individual agents, which can result in unexpectedly "
                "large batches."
                + (
                    # Leading space: this is appended right after
                    # "large batches.".
                    " Also, you may be waiting for your Env to "
                    "terminate (batch_mode=`complete_episodes`). Make sure "
                    "it does at some point."
                    if not multiple_episodes_in_batch
                    else ""
                )
            )

        ma_batch[pid] = collector.build()

    # Create the multi agent batch.
    return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)
@OldAPIStack
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
    """Concatenate a list of SampleBatches into one SampleBatch for inference.

    If the concatenated batch has a "state_in_0" column, a SEQ_LENS column is
    attached. It holds one int32 entry of 1 per input SampleBatch.

    Args:
        eval_data: list of SampleBatches.

    Returns:
        single batched SampleBatch.
    """
    batched = concat_samples(eval_data)
    if "state_in_0" not in batched:
        return batched
    batched[SampleBatch.SEQ_LENS] = np.ones(len(eval_data), dtype=np.int32)
    return batched
- @OldAPIStack
- class EnvRunnerV2:
- """Collect experiences from user environment using Connectors."""
def __init__(
    self,
    worker: "RolloutWorker",
    base_env: BaseEnv,
    multiple_episodes_in_batch: bool,
    callbacks: "RLlibCallback",
    perf_stats: _PerfStats,
    rollout_fragment_length: int = 200,
    count_steps_by: str = "env_steps",
    render: bool = None,
):
    """Set up a runner that collects experiences from `base_env`.

    Args:
        worker: Reference to the current rollout worker.
        base_env: Env implementing BaseEnv. Must not be an ExternalEnvWrapper.
        multiple_episodes_in_batch: Whether to pack multiple episodes into
            each batch. This guarantees batches will be exactly
            `rollout_fragment_length` in size.
        callbacks: User callbacks to run on episode events.
        perf_stats: Record perf stats into this object.
        rollout_fragment_length: The length of a fragment to collect before
            building a SampleBatch from the data and resetting the
            SampleBatchBuilder object. Also sets the large-batch warning
            threshold: 10x this value, but at least
            MIN_LARGE_BATCH_THRESHOLD. If this is inf,
            DEFAULT_LARGE_BATCH_THRESHOLD is used instead.
        count_steps_by: One of "env_steps" (default) or "agent_steps".
            Use "agent_steps" if rollout lengths should be counted by
            individual agent steps. In a multi-agent env, a single env_step
            contains one or more agent_steps, depending on how many agents
            are present at any given time in the ongoing episode.
        render: Whether to try to render the environment after each step.

    Raises:
        ValueError: If `base_env` is an ExternalEnvWrapper.
    """
    if isinstance(base_env, ExternalEnvWrapper):
        raise ValueError(
            "Policies using the new Connector API do not support ExternalEnv."
        )

    self._worker = worker
    self._base_env = base_env
    self._multiple_episodes_in_batch = multiple_episodes_in_batch
    self._callbacks = callbacks
    self._perf_stats = perf_stats
    self._rollout_fragment_length = rollout_fragment_length
    self._count_steps_by = count_steps_by
    # Must be assigned before `_get_simple_image_viewer()` runs: that method
    # reads it and may switch it off.
    self._render = render
    # May be populated for image rendering.
    self._simple_image_viewer: Optional["SimpleImageViewer"] = (
        self._get_simple_image_viewer()
    )

    # Keeps track of active episodes, one per sub-env ID.
    self._active_episodes: Dict[EnvID, EpisodeV2] = {}
    # Per sub-env batch builders, created lazily on first access.
    self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
        self._new_batch_builder
    )

    if self._rollout_fragment_length == float("inf"):
        threshold = DEFAULT_LARGE_BATCH_THRESHOLD
    else:
        threshold = max(
            MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10
        )
    self._large_batch_threshold: int = threshold
def _get_simple_image_viewer(self):
    """Maybe construct a SimpleImageViewer instance for episode rendering.

    Returns:
        None if rendering is disabled. Otherwise a SimpleImageViewer. If
        importing or constructing the viewer raises ImportError, rendering is
        switched off (`self._render = False`), a warning is logged and None
        is returned.
    """
    if not self._render:
        return None

    try:
        from gymnasium.envs.classic_control.rendering import SimpleImageViewer

        viewer = SimpleImageViewer()
    except ImportError:  # Also covers ModuleNotFoundError (a subclass).
        self._render = False  # disable rendering
        logger.warning(
            "Could not import gymnasium.envs.classic_control."
            "rendering! Try `pip install gymnasium[all]`."
        )
        viewer = None
    return viewer
def _call_on_episode_start(self, episode, env_id):
    """Notify policy explorations and user callbacks that `episode` started.

    Only policies currently held in the policy map's cache are notified.
    Looping through all (even stashed-to-disk) policies would counter the
    purpose of the LRU policy caching. As a consequence, the exploration
    (e.g. ParameterNoise) of stashed policies may not be notified.

    Args:
        episode: The episode that is starting.
        env_id: ID of the sub-environment the episode runs in.
    """
    for policy in self._worker.policy_map.cache.values():
        if getattr(policy, "exploration", None) is None:
            continue
        policy.exploration.on_episode_start(
            policy=policy,
            environment=self._base_env,
            episode=episode,
            tf_sess=policy.get_session(),
        )

    self._callbacks.on_episode_start(
        worker=self._worker,
        base_env=self._base_env,
        policies=self._worker.policy_map,
        env_index=env_id,
        episode=episode,
    )
def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
    """Create a fresh batch builder spanning the worker's full policy map.

    The single positional argument is ignored.
    """
    policy_map = self._worker.policy_map
    return _PolicyCollectorGroup(policy_map)
def run(self) -> Iterator[SampleBatchType]:
    """Sample forever, yielding every output of each `step()` call in order.

    Yields:
        Object containing state, action, reward, terminal condition,
        and other fields as dictated by `policy`.
    """
    while True:
        for item in self.step():
            yield item
def step(self) -> List[SampleBatchType]:
    """Samples training episodes by stepping through environments.

    One call performs a single round trip with the BaseEnv, in this order:
    1. Poll.
    2. Process observations (agent connectors).
    3. Batched policy inference.
    4. Process actions.
    5. Send actions.
    6. Maybe render.

    The duration of each phase, in seconds, is added to `self._perf_stats`.

    Returns:
        The outputs produced while processing the polled observations:
        RolloutMetrics and/or sample batches.
    """
    # Measure durations with the monotonic, high-resolution perf_counter.
    # The wall clock (time.time()) can jump, e.g. on NTP adjustments, which
    # yields skewed or even negative timings.
    now = time.perf_counter
    self._perf_stats.incr("iters", 1)

    t0 = now()
    # Get observations from all ready agents.
    # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
    (
        unfiltered_obs,
        rewards,
        terminateds,
        truncateds,
        infos,
        off_policy_actions,
    ) = self._base_env.poll()
    env_poll_time = now() - t0

    # Process observations and prepare for policy evaluation.
    t1 = now()
    # types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
    #        List[Union[RolloutMetrics, SampleBatchType]]
    active_envs, to_eval, outputs = self._process_observations(
        unfiltered_obs=unfiltered_obs,
        rewards=rewards,
        terminateds=terminateds,
        truncateds=truncateds,
        infos=infos,
    )
    self._perf_stats.incr("raw_obs_processing_time", now() - t1)

    # Do batched policy eval (across vectorized envs).
    t2 = now()
    # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
    eval_results = self._do_policy_eval(to_eval=to_eval)
    self._perf_stats.incr("inference_time", now() - t2)

    # Process results and update episode state.
    t3 = now()
    actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = (
        self._process_policy_eval_results(
            active_envs=active_envs,
            to_eval=to_eval,
            eval_results=eval_results,
            off_policy_actions=off_policy_actions,
        )
    )
    self._perf_stats.incr("action_processing_time", now() - t3)

    # Return computed actions to ready envs. We also send to envs that have
    # taken off-policy actions; those envs are free to ignore the action.
    t4 = now()
    self._base_env.send_actions(actions_to_send)
    # Env wait time = time spent in poll() + time spent in send_actions().
    self._perf_stats.incr("env_wait_time", env_poll_time + now() - t4)

    self._maybe_render()

    return outputs
def _get_rollout_metrics(
    self, episode: EpisodeV2, policy_map: Dict[str, Policy]
) -> List[RolloutMetrics]:
    """Get rollout metrics from a completed episode.

    Args:
        episode: The finished episode.
        policy_map: The worker's policy map. Currently unused: connector
            metrics are read through `episode.policy_map`.

    Returns:
        If `_fetch_atari_metrics` reports metrics for the base env, those
        metrics with the episode's custom metrics attached. Otherwise, a
        single RolloutMetrics built from the episode.
    """
    # TODO(jungong) : why do we need to handle atari metrics differently?
    #  Can we unify atari and normal env metrics?
    atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
    if atari_metrics is not None:
        # RolloutMetrics is an immutable namedtuple: `_replace` returns a NEW
        # instance. The results must be collected, otherwise the custom
        # metrics are silently dropped.
        return [
            m._replace(custom_metrics=episode.custom_metrics) for m in atari_metrics
        ]

    # Create connector metrics, one entry per policy used by the episode.
    connector_metrics = {}
    for agent in episode.get_agents():
        policy_id = episode.policy_for(agent)
        policy = episode.policy_map[policy_id]
        connector_metrics[policy_id] = policy.get_connector_metrics()

    # Otherwise, return RolloutMetrics for the episode.
    return [
        RolloutMetrics(
            episode_length=episode.length,
            episode_reward=episode.total_reward,
            agent_rewards=dict(episode.agent_rewards),
            custom_metrics=episode.custom_metrics,
            perf_stats={},
            hist_data=episode.hist_data,
            media=episode.media,
            connector_metrics=connector_metrics,
        )
    ]
def _process_observations(
    self,
    unfiltered_obs: MultiEnvDict,
    rewards: MultiEnvDict,
    terminateds: MultiEnvDict,
    truncateds: MultiEnvDict,
    infos: MultiEnvDict,
) -> Tuple[
    Set[EnvID],
    Dict[PolicyID, List[AgentConnectorDataType]],
    List[Union[RolloutMetrics, SampleBatchType]],
]:
    """Process raw obs from env.

    Group data for active agents by policy. Reset environments that are done.

    For every sub-environment in `unfiltered_obs`:
    - A sub-env that returned an Exception instead of an obs dict is handed
      to `_handle_done_episode` as faulty.
    - The episode is created on first sight. Episode-start callbacks are
      called if it has no initial obs yet.
    - One data dict is built per agent. For a finished episode, a fake last
      obs is built for agents that did not receive one.
    - These dicts are run through the policies' agent connectors, and the
      results are recorded in the episode.
    - Agents that are not done are queued for policy evaluation.
    - The episode is stepped and `on_episode_step` is called. The callback is
      skipped for the initial observation.
    - Finished episodes are handed to `_handle_done_episode`.
    - If multiple episodes may share a batch, a truncated batch is
      attempted.

    Args:
        unfiltered_obs: The unfiltered, raw observations from the BaseEnv
            (vectorized, possibly multi-agent). Dict of dict: By env index,
            then agent ID, then mapped to actual obs.
        rewards: The rewards MultiEnvDict of the BaseEnv.
        terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
        truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
        infos: The MultiEnvDict of infos dicts of the BaseEnv.

    Returns:
        A tuple of:
            A set of envs that were active during this step.
            AgentConnectorDataType for active agents for policy evaluation,
            grouped by policy ID.
            SampleBatches and RolloutMetrics for completed agents for output.
    """
    # Output objects.
    # Note that we need to track envs that are active during this round explicitly,
    # just to be confident which envs require us to send at least an empty action
    # dict to.
    # We can not get this from the _active_episode or to_eval lists because
    # 1. All envs are not required to step during every single step. And
    # 2. to_eval only contains data for the agents that are still active. An env may
    # be active but all agents are done during the step.
    active_envs: Set[EnvID] = set()
    to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
    outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

    # For each (vectorized) sub-environment.
    # types: EnvID, Dict[AgentID, EnvObsType]
    for env_id, env_obs in unfiltered_obs.items():
        # Check for env_id having returned an error instead of a multi-agent
        # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
        # one of its sub-environments is faulty and should be restarted (and the
        # ongoing episode should not be used for training).
        if isinstance(env_obs, Exception):
            assert terminateds[env_id]["__all__"] is True, (
                f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                "as observation, the terminateds[__all__] flag must also be set to "
                "True!"
            )
            # all_agents_obs is an Exception here.
            # Drop this episode and skip to next.
            self._handle_done_episode(
                env_id=env_id,
                env_obs_or_exception=env_obs,
                is_done=True,
                active_envs=active_envs,
                to_eval=to_eval,
                outputs=outputs,
            )
            continue

        # Look up the running episode of this sub-env, or create one.
        if env_id not in self._active_episodes:
            episode: EpisodeV2 = self.create_episode(env_id)
            self._active_episodes[env_id] = episode
        else:
            episode: EpisodeV2 = self._active_episodes[env_id]
        # If this episode is brand-new, call the episode start callback(s).
        # Note: EpisodeV2s are initialized with length=-1 (before the reset).
        if not episode.has_init_obs():
            self._call_on_episode_start(episode, env_id)

        # Check episode termination conditions.
        if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
            all_agents_done = True
        else:
            all_agents_done = False
            # Only envs with a still-running episode are marked active here.
            # Finished envs are added by `_handle_done_episode`, and only
            # after a successful synchronous reset.
            active_envs.add(env_id)

        # Special handling of common info dict.
        episode.set_last_info("__common__", infos[env_id].get("__common__", {}))

        # Agent sample batches grouped by policy. Each set of sample batches will
        # go through agent connectors together.
        sample_batches_by_policy = defaultdict(list)
        # Whether an agent is terminated or truncated.
        agent_terminateds = {}
        agent_truncateds = {}
        for agent_id, obs in env_obs.items():
            assert agent_id != "__all__"

            policy_id: PolicyID = episode.policy_for(agent_id)

            # An agent is done if the whole episode is, or if it individually is.
            agent_terminated = bool(
                terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
            )
            agent_terminateds[agent_id] = agent_terminated
            agent_truncated = bool(
                truncateds[env_id]["__all__"]
                or truncateds[env_id].get(agent_id, False)
            )
            agent_truncateds[agent_id] = agent_truncated

            # A completely new agent is already done -> Skip entirely.
            if not episode.has_init_obs(agent_id) and (
                agent_terminated or agent_truncated
            ):
                continue

            values_dict = {
                SampleBatch.T: episode.length,  # Episodes start at -1 before we
                # add the initial obs. After that, we infer from initial obs at
                # t=0 since that will be our new episode.length.
                SampleBatch.ENV_ID: env_id,
                SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                # Last action (SampleBatch.ACTIONS) column will be populated by
                # StateBufferConnector.
                # Reward received after taking action at timestep t.
                SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                # After taking action=a, did we reach terminal?
                SampleBatch.TERMINATEDS: agent_terminated,
                # Was the episode truncated artificially
                # (e.g. b/c of some time limit)?
                SampleBatch.TRUNCATEDS: agent_truncated,
                SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                SampleBatch.NEXT_OBS: obs,
            }

            # Queue this obs sample for connector preprocessing.
            sample_batches_by_policy[policy_id].append((agent_id, values_dict))

        # The entire episode is done.
        if all_agents_done:
            # Let's check to see if there are any agents that haven't got the
            # last obs yet. If there are, we have to create fake-last
            # observations for them. (the environment is not required to do so if
            # terminateds[__all__]==True or truncateds[__all__]==True).
            for agent_id in episode.get_agents():
                # If the latest obs we got for this agent is done, or if its
                # episode state is already done, nothing to do.
                if (
                    agent_terminateds.get(agent_id, False)
                    or agent_truncateds.get(agent_id, False)
                    or episode.is_done(agent_id)
                ):
                    continue

                policy_id: PolicyID = episode.policy_for(agent_id)
                policy = self._worker.policy_map[policy_id]

                # Create a fake observation by sampling the original env
                # observation space.
                obs_space = get_original_space(policy.observation_space)
                # Although there is no obs for this agent, there may be
                # good rewards and info dicts for it.
                # This is the case for e.g. OpenSpiel games, where a reward
                # is only earned with the last step, but the obs for that
                # step is {}.
                reward = rewards[env_id].get(agent_id, 0.0)
                info = infos[env_id].get(agent_id, {})
                values_dict = {
                    SampleBatch.T: episode.length,
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # TODO(sven): These should be the summed-up(!) rewards since the
                    #  last observation received for this agent.
                    SampleBatch.REWARDS: reward,
                    # NOTE(review): marked terminated even if the episode was only
                    #  truncated (truncateds[__all__]) -- confirm this is intended.
                    SampleBatch.TERMINATEDS: True,
                    SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
                    SampleBatch.INFOS: info,
                    SampleBatch.NEXT_OBS: obs_space.sample(),
                }

                # Queue these fake obs for connector preprocessing too.
                sample_batches_by_policy[policy_id].append((agent_id, values_dict))

        # Run agent connectors.
        for policy_id, batches in sample_batches_by_policy.items():
            policy: Policy = self._worker.policy_map[policy_id]
            # Collected full MultiAgentDicts for this environment.
            # Run agent connectors.
            assert (
                policy.agent_connectors
            ), "EnvRunnerV2 requires agent connectors to work."

            acd_list: List[AgentConnectorDataType] = [
                AgentConnectorDataType(env_id, agent_id, data)
                for agent_id, data in batches
            ]

            # For all agents mapped to policy_id, run their data
            # through agent_connectors.
            processed = policy.agent_connectors(acd_list)

            for d in processed:
                # Record transition info if applicable.
                # The first obs of an agent becomes its initial obs; later
                # ones are recorded as a full transition.
                if not episode.has_init_obs(d.agent_id):
                    episode.add_init_obs(
                        agent_id=d.agent_id,
                        init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                        init_infos=d.data.raw_dict[SampleBatch.INFOS],
                        t=d.data.raw_dict[SampleBatch.T],
                    )
                else:
                    episode.add_action_reward_done_next_obs(
                        d.agent_id, d.data.raw_dict
                    )

                # Need to evaluate next actions.
                if not (
                    all_agents_done
                    or agent_terminateds.get(d.agent_id, False)
                    or agent_truncateds.get(d.agent_id, False)
                    or episode.is_done(d.agent_id)
                ):
                    # Add to eval set if env is not done and this particular agent
                    # is also not done.
                    item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
                    to_eval[policy_id].append(item)

        # Finished advancing episode by 1 step, mark it so.
        episode.step()

        # Exception: The very first env.poll() call causes the env to get reset
        # (no step taken yet, just a single starting observation logged).
        # We need to skip this callback in this case.
        if episode.length > 0:
            # Invoke the `on_episode_step` callback after the step is logged
            # to the episode.
            self._callbacks.on_episode_step(
                worker=self._worker,
                base_env=self._base_env,
                policies=self._worker.policy_map,
                episode=episode,
                env_index=env_id,
            )

        # Episode is terminated/truncated for all agents
        # (terminateds[__all__] == True or truncateds[__all__] == True).
        if all_agents_done:
            # _handle_done_episode will build a MultiAgentBatch for all
            # the agents that are done during this step of rollout in
            # the case of _multiple_episodes_in_batch=False.
            self._handle_done_episode(
                env_id,
                env_obs,
                terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
                active_envs,
                to_eval,
                outputs,
            )

        # Try to build something.
        # NOTE(review): if the episode just finished, `episode` still refers to
        #  the finished episode here, not the replacement that
        #  `_handle_done_episode` may have registered -- confirm intended.
        if self._multiple_episodes_in_batch:
            sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                self._batch_builders[env_id], episode
            )
            if sample_batch:
                outputs.append(sample_batch)

                # SampleBatch built from data collected by batch_builder.
                # Clean up and delete the batch_builder.
                del self._batch_builders[env_id]

    return active_envs, to_eval, outputs
def _build_done_episode(
    self,
    env_id: EnvID,
    is_done: bool,
    outputs: List[SampleBatchType],
):
    """Postprocess the active episode of `env_id` into its batch builder.

    If episodes may not be packed together (batch_mode=complete_episodes,
    i.e. `_multiple_episodes_in_batch` is False), a MultiAgentBatch is built
    right away from this single episode. If non-empty, it is appended to
    `outputs` and the env's batch builder is discarded. Otherwise, the data
    stays in the builder and collection continues across episodes.

    Args:
        env_id: The env id.
        is_done: Whether the env is done. Passed to `postprocess_episode` as
            both `is_done` and `check_dones`.
        outputs: The list that a built MultiAgentBatch is appended to.
    """
    episode: EpisodeV2 = self._active_episodes[env_id]
    builder = self._batch_builders[env_id]
    episode.postprocess_episode(
        batch_builder=builder,
        is_done=is_done,
        check_dones=is_done,
    )

    if self._multiple_episodes_in_batch:
        # Keep collecting across episodes; nothing to emit yet.
        return

    ma_sample_batch = _build_multi_agent_batch(
        episode.episode_id,
        builder,
        self._large_batch_threshold,
        self._multiple_episodes_in_batch,
    )
    if not ma_sample_batch:
        return

    outputs.append(ma_sample_batch)
    # The builder's data is now in `outputs`; drop the builder.
    del self._batch_builders[env_id]
def __process_resetted_obs_for_eval(
    self,
    env_id: EnvID,
    obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
    infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
    episode: EpisodeV2,
    to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
):
    """Process resetted obs through agent connectors for policy eval.

    Each agent's connector output is recorded as that agent's initial obs in
    `episode` and queued in `to_eval`, so an action gets computed for it.

    Args:
        env_id: The env id.
        obs: The resetted obs, keyed by env ID, then agent ID.
        infos: The reset infos, keyed by env ID, then agent ID. Each agent
            receives only its own info dict ({} if missing).
        episode: New episode.
        to_eval: Per-policy lists of agent connector data for policy eval.
            Appended to in place.
    """
    # Per-agent infos of this sub-env. Use the same per-agent shape as the
    # infos recorded for regular env steps; previously the whole multi-env
    # infos dict was stored for every agent. Tolerate missing infos.
    env_infos = (infos or {}).get(env_id) or {}

    per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
    # types: AgentID, EnvObsType
    for agent_id, raw_obs in obs[env_id].items():
        policy_id: PolicyID = episode.policy_for(agent_id)
        per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

    for policy_id, agents_obs in per_policy_resetted_obs.items():
        policy = self._worker.policy_map[policy_id]
        acd_list: List[AgentConnectorDataType] = [
            AgentConnectorDataType(
                env_id,
                agent_id,
                {
                    SampleBatch.NEXT_OBS: agent_obs,
                    SampleBatch.INFOS: env_infos.get(agent_id, {}),
                    SampleBatch.T: episode.length,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                },
            )
            for agent_id, agent_obs in agents_obs
        ]

        # Call agent connectors on these initial obs.
        processed = policy.agent_connectors(acd_list)

        for d in processed:
            episode.add_init_obs(
                agent_id=d.agent_id,
                init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                init_infos=d.data.raw_dict[SampleBatch.INFOS],
                t=d.data.raw_dict[SampleBatch.T],
            )
            to_eval[policy_id].append(d)
def _handle_done_episode(
    self,
    env_id: EnvID,
    env_obs_or_exception: MultiAgentDict,
    is_done: bool,
    active_envs: Set[EnvID],
    to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
    outputs: List[SampleBatchType],
) -> None:
    """Handle an all-finished (or faulty) episode and reset its sub-env.

    1. Report the episode:
       - If `env_obs_or_exception` is an Exception, report a faulty episode.
       - Otherwise, append the episode's rollout metrics, then its collected
         data (via `_build_done_episode`).
    2. End the old episode and create a new one.
    3. Reset the sub-env. Retry, and report another faulty episode each
       time, for as long as the reset returns an Exception as observation.
    4. Reset the agent connectors of all cached policies for this env.
    5. Unless the reset returned None or ASYNC_RESET_RETURN:
       - register the new episode and call the episode-start callbacks,
       - queue the reset obs for policy eval,
       - step the new episode once,
       - mark the env as active.

    Args:
        env_id: Environment ID.
        env_obs_or_exception: Last per-environment observation, or the
            Exception the sub-env returned instead.
        is_done: If all agents are done.
        active_envs: Set of active env ids.
        to_eval: Output container for policy eval data.
        outputs: Output container for collected sample batches and metrics.
    """
    if isinstance(env_obs_or_exception, Exception):
        episode_or_error: Union[EpisodeV2, Exception] = env_obs_or_exception
        # Tell the sampler we have got a faulty episode.
        outputs.append(RolloutMetrics(episode_faulty=True))
    else:
        episode_or_error = self._active_episodes[env_id]
        # Metrics are appended before the samples, so that RolloutWorker
        # always fetches metrics before samples (env_runner() behavior).
        outputs.extend(
            self._get_rollout_metrics(
                episode_or_error, policy_map=self._worker.policy_map
            )
        )
        self._build_done_episode(env_id, is_done, outputs)

    # Clean up the finished episode now that its data has been collected.
    self.end_episode(env_id, episode_or_error)

    # The new episode is created BEFORE resetting the sub-environment.
    new_episode: EpisodeV2 = self.create_episode(env_id)

    # `try_reset()` may report an exception for this sub-env. With
    # `restart_failed_sub_environments=True`, the BaseEnv recreates and then
    # resets the sub-env. If it keeps failing, this loop blocks the worker,
    # which is preferred over crashing the worker entirely.
    while True:
        resetted_obs, resetted_infos = self._base_env.try_reset(env_id)
        reset_pending = resetted_obs is None or resetted_obs == ASYNC_RESET_RETURN
        if reset_pending or not isinstance(resetted_obs[env_id], Exception):
            break
        # Report a faulty episode.
        outputs.append(RolloutMetrics(episode_faulty=True))

    # Reset connector state if this is a hard reset.
    for policy in self._worker.policy_map.cache.values():
        policy.agent_connectors.reset(env_id)

    # An async reset delivers its obs in some future poll(); nothing to do yet.
    if reset_pending:
        return

    self._active_episodes[env_id] = new_episode
    self._call_on_episode_start(new_episode, env_id)
    self.__process_resetted_obs_for_eval(
        env_id,
        resetted_obs,
        resetted_infos,
        new_episode,
        to_eval,
    )
    # Step after adding initial obs. This will give us 0 env and agent step.
    new_episode.step()
    active_envs.add(env_id)
def create_episode(self, env_id: EnvID) -> EpisodeV2:
    """Creates a new EpisodeV2 for `env_id` and returns it.

    Fires the `on_episode_created` callback for the new episode. The
    respective sub-environment is NOT reset here; that is left to the caller.

    Args:
        env_id: Env ID the new episode belongs to.

    Returns:
        The newly created EpisodeV2 instance.
    """
    # At most one episode may be active per env ID.
    assert env_id not in self._active_episodes

    worker = self._worker
    episode = EpisodeV2(
        env_id,
        worker.policy_map,
        worker.policy_mapping_fn,
        worker=worker,
        callbacks=self._callbacks,
    )

    # Let user callbacks see the episode before its env gets reset.
    self._callbacks.on_episode_created(
        worker=worker,
        base_env=self._base_env,
        policies=worker.policy_map,
        env_index=env_id,
        episode=episode,
    )
    return episode
def end_episode(
    self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
):
    """Cleans up after an episode has ended, successfully or not.

    Fires the `on_episode_end` callback and every in-memory policy's
    `Exploration.on_episode_end()`. It then checks that a successfully
    finished episode saw at least one agent step, and finally removes the
    episode from `self._active_episodes`.

    Args:
        env_id: Env ID.
        episode_or_exception: The EpisodeV2 instance if the episode finished
            successfully, otherwise the exception that was thrown.

    Raises:
        ValueError: If a finished EpisodeV2 recorded zero agent steps. The
            error is raised before the episode is removed from
            `self._active_episodes`.
    """
    # Tell user callbacks that the episode is over. They receive either the
    # Episode (success) or the Exception (failure).
    self._callbacks.on_episode_end(
        env_index=env_id,
        episode=episode_or_exception,
        worker=self._worker,
        base_env=self._base_env,
        policies=self._worker.policy_map,
    )

    # Only policies currently cached in memory are visited. Looping over
    # stashed (on-disk) policies too would defeat the LRU policy cache. As a
    # result, stashed policies miss this notification, which may break their
    # exploration (e.g. ParameterNoise).
    for policy in self._worker.policy_map.cache.values():
        exploration = getattr(policy, "exploration", None)
        if exploration is None:
            continue
        exploration.on_episode_end(
            policy=policy,
            environment=self._base_env,
            episode=episode_or_exception,
            tf_sess=policy.get_session(),
        )

    if (
        isinstance(episode_or_exception, EpisodeV2)
        and episode_or_exception.total_agent_steps == 0
    ):
        # Throughout the whole episode every observation was empty, i.e. no
        # agent was ever present/acting in the env.
        raise ValueError(
            f"Data from episode {episode_or_exception.episode_id} does not show "
            "any agent interactions. Hint: Make sure for at least one timestep "
            "in the episode, env.step() returns non-empty values."
        )

    # Forget the finished episode for this env ID (no-op if already gone).
    self._active_episodes.pop(env_id, None)
- def _try_build_truncated_episode_multi_agent_batch(
- self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
- ) -> Union[None, SampleBatch, MultiAgentBatch]:
- # Measure batch size in env-steps.
- if self._count_steps_by == "env_steps":
- built_steps = batch_builder.env_steps
- ongoing_steps = episode.active_env_steps
- # Measure batch-size in agent-steps.
- else:
- built_steps = batch_builder.agent_steps
- ongoing_steps = episode.active_agent_steps
- # Reached the fragment-len -> We should build an MA-Batch.
- if built_steps + ongoing_steps >= self._rollout_fragment_length:
- if self._count_steps_by != "agent_steps":
- assert built_steps + ongoing_steps == self._rollout_fragment_length, (
- f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
- f"rollout_fragment_length ({self._rollout_fragment_length})."
- )
- # If we reached the fragment-len only because of `episode_id`
- # (still ongoing) -> postprocess `episode_id` first.
- if built_steps < self._rollout_fragment_length:
- episode.postprocess_episode(batch_builder=batch_builder, is_done=False)
- # If builder has collected some data,
- # build the MA-batch and add to return values.
- if batch_builder.agent_steps > 0:
- return _build_multi_agent_batch(
- episode.episode_id,
- batch_builder,
- self._large_batch_threshold,
- self._multiple_episodes_in_batch,
- )
- # No batch-builder:
- # We have reached the rollout-fragment length w/o any agent
- # steps! Warn that the environment may never request any
- # actions from any agents.
- elif log_once("no_agent_steps"):
- logger.warning(
- "Your environment seems to be stepping w/o ever "
- "emitting agent observations (agents are never "
- "requested to act)!"
- )
- return None
- def _do_policy_eval(
- self,
- to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
- ) -> Dict[PolicyID, PolicyOutputType]:
- """Call compute_actions on collected episode data to get next action.
- Args:
- to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects
- (items in these lists will be the batch's items for the model
- forward pass).
- Returns:
- Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.
- """
- policies = self._worker.policy_map
- # In case policy map has changed, try to find the new policy that
- # should handle all these per-agent eval data.
- # Throws exception if these agents are mapped to multiple different
- # policies now.
- def _try_find_policy_again(eval_data: AgentConnectorDataType):
- policy_id = None
- for d in eval_data:
- episode = self._active_episodes[d.env_id]
- # Force refresh policy mapping on the episode.
- pid = episode.policy_for(d.agent_id, refresh=True)
- if policy_id is not None and pid != policy_id:
- raise ValueError(
- "Policy map changed. The list of eval data that was handled "
- f"by a same policy is now handled by policy {pid} "
- "and {policy_id}. "
- "Please don't do this in the middle of an episode."
- )
- policy_id = pid
- return _get_or_raise(self._worker.policy_map, policy_id)
- eval_results: Dict[PolicyID, TensorStructType] = {}
- for policy_id, eval_data in to_eval.items():
- # In case the policyID has been removed from this worker, we need to
- # re-assign policy_id and re-lookup the Policy object to use.
- try:
- policy: Policy = _get_or_raise(policies, policy_id)
- except ValueError:
- # policy_mapping_fn from the worker may have already been
- # changed (mapping fn not staying constant within one episode).
- policy: Policy = _try_find_policy_again(eval_data)
- input_dict = _batch_inference_sample_batches(
- [d.data.sample_batch for d in eval_data]
- )
- eval_results[policy_id] = policy.compute_actions_from_input_dict(
- input_dict,
- timestep=policy.global_timestep,
- episodes=[self._active_episodes[t.env_id] for t in eval_data],
- )
- return eval_results
- def _process_policy_eval_results(
- self,
- active_envs: Set[EnvID],
- to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
- eval_results: Dict[PolicyID, PolicyOutputType],
- off_policy_actions: MultiEnvDict,
- ):
- """Process the output of policy neural network evaluation.
- Records policy evaluation results into agent connectors and
- returns replies to send back to agents in the env.
- Args:
- active_envs: Set of env IDs that are still active.
- to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
- eval_results: Mapping of policy IDs to list of
- actions, rnn-out states, extra-action-fetches dicts.
- off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
- off-policy-action, returned by a `BaseEnv.poll()` call.
- Returns:
- Nested dict of env id -> agent id -> actions to be sent to
- Env (np.ndarrays).
- """
- actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
- for env_id in active_envs:
- actions_to_send[env_id] = {} # at minimum send empty dict
- # types: PolicyID, List[AgentConnectorDataType]
- for policy_id, eval_data in to_eval.items():
- actions: TensorStructType = eval_results[policy_id][0]
- actions = convert_to_numpy(actions)
- rnn_out: StateBatches = eval_results[policy_id][1]
- extra_action_out: dict = eval_results[policy_id][2]
- # In case actions is a list (representing the 0th dim of a batch of
- # primitive actions), try converting it first.
- if isinstance(actions, list):
- actions = np.array(actions)
- # Split action-component batches into single action rows.
- actions: List[EnvActionType] = unbatch(actions)
- policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
- assert (
- policy.agent_connectors and policy.action_connectors
- ), "EnvRunnerV2 requires action connectors to work."
- # types: int, EnvActionType
- for i, action in enumerate(actions):
- env_id: int = eval_data[i].env_id
- agent_id: AgentID = eval_data[i].agent_id
- input_dict: TensorStructType = eval_data[i].data.raw_dict
- rnn_states: List[StateBatches] = tree.map_structure(
- lambda x, i=i: x[i], rnn_out
- )
- # extra_action_out could be a nested dict
- fetches: Dict = tree.map_structure(
- lambda x, i=i: x[i], extra_action_out
- )
- # Post-process policy output by running them through action connectors.
- ac_data = ActionConnectorDataType(
- env_id, agent_id, input_dict, (action, rnn_states, fetches)
- )
- action_to_send, rnn_states, fetches = policy.action_connectors(
- ac_data
- ).output
- # The action we want to buffer is the direct output of
- # compute_actions_from_input_dict() here. This is because we want to
- # send the unsqushed actions to the environment while learning and
- # possibly basing subsequent actions on the squashed actions.
- action_to_buffer = (
- action
- if env_id not in off_policy_actions
- or agent_id not in off_policy_actions[env_id]
- else off_policy_actions[env_id][agent_id]
- )
- # Notify agent connectors with this new policy output.
- # Necessary for state buffering agent connectors, for example.
- ac_data: ActionConnectorDataType = ActionConnectorDataType(
- env_id,
- agent_id,
- input_dict,
- (action_to_buffer, rnn_states, fetches),
- )
- policy.agent_connectors.on_policy_output(ac_data)
- assert agent_id not in actions_to_send[env_id]
- actions_to_send[env_id][agent_id] = action_to_send
- return actions_to_send
- def _maybe_render(self):
- """Visualize environment."""
- # Check if we should render.
- if not self._render or not self._simple_image_viewer:
- return
- t5 = time.time()
- # Render can either return an RGB image (uint8 [w x h x 3] numpy
- # array) or take care of rendering itself (returning True).
- rendered = self._base_env.try_render()
- # Rendering returned an image -> Display it in a SimpleImageViewer.
- if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3:
- self._simple_image_viewer.imshow(rendered)
- elif rendered not in [True, False, None]:
- raise ValueError(
- f"The env's ({self._base_env}) `try_render()` method returned an"
- " unsupported value! Make sure you either return a "
- "uint8/w x h x 3 (RGB) image or handle rendering in a "
- "window and then return `True`."
- )
- self._perf_stats.incr("env_render_time", time.time() - t5)
- def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]:
- """Atari games have multiple logical episodes, one per life.
- However, for metrics reporting we count full episodes, all lives included.
- """
- sub_environments = base_env.get_sub_environments()
- if not sub_environments:
- return None
- atari_out = []
- for sub_env in sub_environments:
- monitor = get_wrapper_by_cls(sub_env, MonitorEnv)
- if not monitor:
- return None
- for eps_rew, eps_len in monitor.next_episode_results():
- atari_out.append(RolloutMetrics(eps_len, eps_rew))
- return atari_out
- def _get_or_raise(
- mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID
- ) -> Union[Policy, Preprocessor, Filter]:
- """Returns an object under key `policy_id` in `mapping`.
- Args:
- mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The
- mapping dict from policy id (str) to actual object (Policy,
- Preprocessor, etc.).
- policy_id: The policy ID to lookup.
- Returns:
- Union[Policy, Preprocessor, Filter]: The found object.
- Raises:
- ValueError: If `policy_id` cannot be found in `mapping`.
- """
- if policy_id not in mapping:
- raise ValueError(
- "Could not find policy for agent: PolicyID `{}` not found "
- "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys())
- )
- return mapping[policy_id]
|