import logging
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union

import numpy as np
import tree  # pip install dm_tree

from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv
from ray.rllib.env.external_env import ExternalEnvWrapper
from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls
from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup
from ray.rllib.evaluation.episode_v2 import EpisodeV2
from ray.rllib.evaluation.metrics import RolloutMetrics
from ray.rllib.models.preprocessors import Preprocessor
from ray.rllib.policy.policy import Policy
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
from ray.rllib.utils.annotations import OldAPIStack
from ray.rllib.utils.filter import Filter
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.spaces.space_utils import get_original_space, unbatch
from ray.rllib.utils.typing import (
    ActionConnectorDataType,
    AgentConnectorDataType,
    AgentID,
    EnvActionType,
    EnvID,
    EnvInfoDict,
    EnvObsType,
    MultiAgentDict,
    MultiEnvDict,
    PolicyID,
    PolicyOutputType,
    SampleBatchType,
    StateBatches,
    TensorStructType,
)
from ray.util.debug import log_once

if TYPE_CHECKING:
    from gymnasium.envs.classic_control.rendering import SimpleImageViewer

    from ray.rllib.callbacks.callbacks import RLlibCallback
    from ray.rllib.evaluation.rollout_worker import RolloutWorker

logger = logging.getLogger(__name__)


MIN_LARGE_BATCH_THRESHOLD = 1000
DEFAULT_LARGE_BATCH_THRESHOLD = 5000
# Multiplier used to convert seconds into milliseconds.
MS_TO_SEC = 1000.0


@OldAPIStack
class _PerfStats:
    """Sampler perf stats that will be included in rollout metrics."""

    def __init__(self, ema_coef: Optional[float] = None):
        """Initializes all timers to zero.

        Args:
            ema_coef: If not None, enable Exponential Moving Average mode
                using this coefficient. If None, timers are accumulated and
                averaged over the number of iterations.
        """
        # If not None, enable Exponential Moving Average mode.
        # The way we update stats is by:
        #     updated = (1 - ema_coef) * old + ema_coef * new
        # In general provides more responsive stats about sampler performance.
        # TODO(jungong) : make ema the default (only) mode if it works well.
        self.ema_coef = ema_coef

        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def incr(self, field: str, value: Union[int, float]):
        """Updates the stat named `field` with `value`.

        Args:
            field: Name of the attribute to update. "iters" is always a plain
                counter; all other fields are either summed up (global
                average mode) or EMA-smoothed (if `ema_coef` is set).
            value: The increment (or the new sample in EMA mode).
        """
        if field == "iters":
            self.iters += value
            return

        # All the other fields support either global average or ema mode.
        if self.ema_coef is None:
            # Global average.
            self.__dict__[field] += value
        else:
            self.__dict__[field] = (1.0 - self.ema_coef) * self.__dict__[
                field
            ] + self.ema_coef * value

    def _get_avg(self):
        """Returns per-iteration mean timings (in ms) in global average mode."""
        # Mean multiplicator (1000 = sec -> ms).
        # Guard against `get()` being called before any iteration has been
        # recorded (all accumulated timers are still 0.0 in that case).
        factor = MS_TO_SEC / self.iters if self.iters > 0 else 0.0
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * factor,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * factor,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * factor,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * factor,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * factor,
        }

    def _get_ema(self):
        """Returns the EMA-smoothed timings, converted to ms."""
        # In EMA mode, stats are already (exponentially) averaged,
        # hence we only need to do the sec -> ms conversion here.
        return {
            # Raw observation preprocessing.
            "mean_raw_obs_processing_ms": self.raw_obs_processing_time * MS_TO_SEC,
            # Computing actions through policy.
            "mean_inference_ms": self.inference_time * MS_TO_SEC,
            # Processing actions (to be sent to env, e.g. clipping).
            "mean_action_processing_ms": self.action_processing_time * MS_TO_SEC,
            # Waiting for environment (during poll).
            "mean_env_wait_ms": self.env_wait_time * MS_TO_SEC,
            # Environment rendering (False by default).
            "mean_env_render_ms": self.env_render_time * MS_TO_SEC,
        }

    def get(self):
        """Returns the stats dict for the configured mode (average or EMA)."""
        if self.ema_coef is None:
            return self._get_avg()
        else:
            return self._get_ema()


@OldAPIStack
class _NewDefaultDict(defaultdict):
    """A defaultdict whose `default_factory` is called with the missing key."""

    def __missing__(self, env_id):
        ret = self[env_id] = self.default_factory(env_id)
        return ret


@OldAPIStack
def _build_multi_agent_batch(
    episode_id: int,
    batch_builder: _PolicyCollectorGroup,
    large_batch_threshold: int,
    multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
    """Build MultiAgentBatch from a dict of _PolicyCollectors.

    Args:
        episode_id: ID of the episode (only used in the large-batch warning).
        batch_builder: The _PolicyCollectorGroup holding the per-policy
            collectors as well as the env- and agent-step counts.
        large_batch_threshold: Number of buffered agent steps above which a
            (one-time) warning is logged.
        multiple_episodes_in_batch: Whether multiple episodes may be packed
            into one batch (only affects the warning text).

    Returns:
        Always returns a sample batch in MultiAgentBatch format.
    """
    ma_batch = {}
    for pid, collector in batch_builder.policy_collectors.items():
        # Skip policies that have not collected anything.
        if collector.agent_steps <= 0:
            continue

        if batch_builder.agent_steps > large_batch_threshold and log_once(
            "large_batch_warning"
        ):
            logger.warning(
                "More than {} observations in {} env steps for "
                "episode {} ".format(
                    batch_builder.agent_steps, batch_builder.env_steps, episode_id
                )
                + "are buffered in the sampler. If this is more than you "
                "expected, check that that you set a horizon on your "
                "environment correctly and that it terminates at some "
                "point. Note: In multi-agent environments, "
                "`rollout_fragment_length` sets the batch size based on "
                "(across-agents) environment steps, not the steps of "
                "individual agents, which can result in unexpectedly "
                "large batches."
                + (
                    "Also, you may be waiting for your Env to "
                    "terminate (batch_mode=`complete_episodes`). Make sure "
                    "it does at some point."
                    if not multiple_episodes_in_batch
                    else ""
                )
            )

        batch = collector.build()

        ma_batch[pid] = batch

    # Create the multi agent batch.
    return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)


@OldAPIStack
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
    """Batch a list of input SampleBatches into a single SampleBatch.

    Args:
        eval_data: list of SampleBatches.

    Returns:
        single batched SampleBatch.
    """
    inference_batch = concat_samples(eval_data)
    if "state_in_0" in inference_batch:
        # One entry of sequence length 1 per input batch.
        batch_size = len(eval_data)
        inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
    return inference_batch


@OldAPIStack
class EnvRunnerV2:
    """Collect experiences from user environment using Connectors."""

    def __init__(
        self,
        worker: "RolloutWorker",
        base_env: BaseEnv,
        multiple_episodes_in_batch: bool,
        callbacks: "RLlibCallback",
        perf_stats: _PerfStats,
        rollout_fragment_length: int = 200,
        count_steps_by: str = "env_steps",
        render: Optional[bool] = None,
    ):
        """
        Args:
            worker: Reference to the current rollout worker.
            base_env: Env implementing BaseEnv.
            multiple_episodes_in_batch: Whether to pack multiple
                episodes into each batch. This guarantees batches will be exactly
                `rollout_fragment_length` in size.
            callbacks: User callbacks to run on episode events.
            perf_stats: Record perf stats into this object.
            rollout_fragment_length: The length of a fragment to collect
                before building a SampleBatch from the data and resetting
                the SampleBatchBuilder object.
            count_steps_by: One of "env_steps" (default) or "agent_steps".
                Use "agent_steps", if you want rollout lengths to be counted
                by individual agent steps. In a multi-agent env,
                a single env_step contains one or more agent_steps, depending
                on how many agents are present at any given time in the
                ongoing episode.
            render: Whether to try to render the environment after each
                step.

        Raises:
            ValueError: If `base_env` is an ExternalEnvWrapper.
        """
        self._worker = worker
        if isinstance(base_env, ExternalEnvWrapper):
            raise ValueError(
                "Policies using the new Connector API do not support ExternalEnv."
            )
        self._base_env = base_env
        self._multiple_episodes_in_batch = multiple_episodes_in_batch
        self._callbacks = callbacks
        self._perf_stats = perf_stats
        self._rollout_fragment_length = rollout_fragment_length
        self._count_steps_by = count_steps_by
        self._render = render

        # May be populated for image rendering.
        self._simple_image_viewer: Optional[
            "SimpleImageViewer"
        ] = self._get_simple_image_viewer()

        # Keeps track of active episodes.
        self._active_episodes: Dict[EnvID, EpisodeV2] = {}
        self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict(
            self._new_batch_builder
        )

        self._large_batch_threshold: int = (
            max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10)
            if self._rollout_fragment_length != float("inf")
            else DEFAULT_LARGE_BATCH_THRESHOLD
        )

    def _get_simple_image_viewer(self):
        """Maybe construct a SimpleImageViewer instance for episode rendering."""
        # Try to render the env, if required.
        if not self._render:
            return None

        try:
            from gymnasium.envs.classic_control.rendering import SimpleImageViewer

            return SimpleImageViewer()
        except ImportError:  # Also covers ModuleNotFoundError (a subclass).
            self._render = False  # disable rendering
            logger.warning(
                "Could not import gymnasium.envs.classic_control."
                "rendering! Try `pip install gymnasium[all]`."
            )

        return None

    def _call_on_episode_start(self, episode, env_id):
        """Runs all episode-start hooks for `episode`.

        Calls `Exploration.on_episode_start()` on each cached policy that has
        an exploration object, then the user `on_episode_start()` callback.

        Args:
            episode: The episode that is about to start.
            env_id: The (sub-)env ID the episode runs in.
        """
        # Call each policy's Exploration.on_episode_start method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in self._worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_start(
                    policy=p,
                    environment=self._base_env,
                    episode=episode,
                    tf_sess=p.get_session(),
                )
        # Call `on_episode_start()` callback.
        self._callbacks.on_episode_start(
            worker=self._worker,
            base_env=self._base_env,
            policies=self._worker.policy_map,
            env_index=env_id,
            episode=episode,
        )

    def _new_batch_builder(self, _) -> _PolicyCollectorGroup:
        """Create a new batch builder.

        We create a _PolicyCollectorGroup based on the full policy_map
        as the batch builder.
        """
        return _PolicyCollectorGroup(self._worker.policy_map)

    def run(self) -> Iterator[SampleBatchType]:
        """Samples and yields training episodes continuously.

        Yields:
            Object containing state, action, reward, terminal condition,
            and other fields as dictated by `policy`.
        """
        while True:
            yield from self.step()

    def step(self) -> List[SampleBatchType]:
        """Samples training episodes by stepping through environments.

        Returns:
            The outputs (sample batches and rollout metrics) produced while
            processing this round of observations.
        """

        self._perf_stats.incr("iters", 1)

        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        (
            unfiltered_obs,
            rewards,
            terminateds,
            truncateds,
            infos,
            off_policy_actions,
        ) = self._base_env.poll()
        env_poll_time = time.time() - t0

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
        #       List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = self._process_observations(
            unfiltered_obs=unfiltered_obs,
            rewards=rewards,
            terminateds=terminateds,
            truncateds=truncateds,
            infos=infos,
        )
        self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)

        # Do batched policy eval (accross vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = self._do_policy_eval(to_eval=to_eval)
        self._perf_stats.incr("inference_time", time.time() - t2)

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[
            EnvID, Dict[AgentID, EnvActionType]
        ] = self._process_policy_eval_results(
            active_envs=active_envs,
            to_eval=to_eval,
            eval_results=eval_results,
            off_policy_actions=off_policy_actions,
        )
        self._perf_stats.incr("action_processing_time", time.time() - t3)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        self._base_env.send_actions(actions_to_send)
        # Env wait time = time spent polling + time spent sending actions.
        self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)

        self._maybe_render()

        return outputs

    def _get_rollout_metrics(
        self, episode: EpisodeV2, policy_map: Dict[str, Policy]
    ) -> List[RolloutMetrics]:
        """Get rollout metrics from completed episode.

        Args:
            episode: The completed episode.
            policy_map: Not used by this method; connector metrics are looked
                up through `episode.policy_map`.

        Returns:
            The Atari metrics (with the episode's custom metrics attached), if
            `_fetch_atari_metrics()` returns any. Otherwise, a single-item
            list holding the RolloutMetrics of `episode`.
        """
        # TODO(jungong) : why do we need to handle atari metrics differently?
        # Can we unify atari and normal env metrics?
        atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env)
        if atari_metrics is not None:
            # `_replace()` (namedtuple API) returns a NEW object rather than
            # mutating in place, so the results must be kept and returned.
            return [
                m._replace(custom_metrics=episode.custom_metrics)
                for m in atari_metrics
            ]
        # Create connector metrics
        connector_metrics = {}
        active_agents = episode.get_agents()
        for agent in active_agents:
            policy_id = episode.policy_for(agent)
            policy = episode.policy_map[policy_id]
            connector_metrics[policy_id] = policy.get_connector_metrics()
        # Otherwise, return RolloutMetrics for the episode.
        return [
            RolloutMetrics(
                episode_length=episode.length,
                episode_reward=episode.total_reward,
                agent_rewards=dict(episode.agent_rewards),
                custom_metrics=episode.custom_metrics,
                perf_stats={},
                hist_data=episode.hist_data,
                media=episode.media,
                connector_metrics=connector_metrics,
            )
        ]

    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        terminateds: MultiEnvDict,
        truncateds: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[
        Set[EnvID],
        Dict[PolicyID, List[AgentConnectorDataType]],
        List[Union[RolloutMetrics, SampleBatchType]],
    ]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        Args:
            unfiltered_obs: The unfiltered, raw observations from the BaseEnv
                (vectorized, possibly multi-agent). Dict of dict: By env index,
                then agent ID, then mapped to actual obs.
            rewards: The rewards MultiEnvDict of the BaseEnv.
            terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
            truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
            infos: The MultiEnvDict of infos dicts of the BaseEnv.

        Returns:
            A tuple of:
                A set of envs that were active during this step.
                AgentConnectorDataType for active agents for policy evaluation.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        # Note that we need to track envs that are active during this round explicitly,
        # just to be confident which envs require us to send at least an empty action
        # dict to.
        # We can not get this from the _active_episode or to_eval lists because
        # 1. All envs are not required to step during every single step. And
        # 2. to_eval only contains data for the agents that are still active. An env may
        # be active but all agents are done during the step.
        active_envs: Set[EnvID] = set()
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check for env_id having returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
            # one of its sub-environments is faulty and should be restarted (and the
            # ongoing episode should not be used for training).
            if isinstance(env_obs, Exception):
                assert terminateds[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the terminateds[__all__] flag must also be set to "
                    "True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self._handle_done_episode(
                    env_id=env_id,
                    env_obs_or_exception=env_obs,
                    is_done=True,
                    active_envs=active_envs,
                    to_eval=to_eval,
                    outputs=outputs,
                )
                continue

            if env_id not in self._active_episodes:
                episode: EpisodeV2 = self.create_episode(env_id)
                self._active_episodes[env_id] = episode
            else:
                episode: EpisodeV2 = self._active_episodes[env_id]
            # If this episode is brand-new, call the episode start callback(s).
            # Note: EpisodeV2s are initialized with length=-1 (before the reset).
            if not episode.has_init_obs():
                self._call_on_episode_start(episode, env_id)

            # Check episode termination conditions.
            if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
                all_agents_done = True
            else:
                all_agents_done = False
                active_envs.add(env_id)

            # Special handling of common info dict.
            episode.set_last_info("__common__", infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is terminated or truncated.
            agent_terminateds = {}
            agent_truncateds = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                agent_terminated = bool(
                    terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
                )
                agent_terminateds[agent_id] = agent_terminated
                agent_truncated = bool(
                    truncateds[env_id]["__all__"]
                    or truncateds[env_id].get(agent_id, False)
                )
                agent_truncateds[agent_id] = agent_truncated

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs(agent_id) and (
                    agent_terminated or agent_truncated
                ):
                    continue

                values_dict = {
                    SampleBatch.T: episode.length,  # Episodes start at -1 before we
                    # add the initial obs. After that, we infer from initial obs at
                    # t=0 since that will be our new episode.length.
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be populated by
                    # StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.TERMINATEDS: agent_terminated,
                    # Was the episode truncated artificially
                    # (e.g. b/c of some time limit)?
                    SampleBatch.TRUNCATEDS: agent_truncated,
                    SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS: obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Let's check to see if there are any agents that haven't got the
                # last obs yet. If there are, we have to create fake-last
                # observations for them. (the environment is not required to do so if
                # terminateds[__all__]==True or truncateds[__all__]==True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if (
                        agent_terminateds.get(agent_id, False)
                        or agent_truncateds.get(agent_id, False)
                        or episode.is_done(agent_id)
                    ):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake observation by sampling the original env
                    # observation space.
                    obs_space = get_original_space(policy.observation_space)
                    # Although there is no obs for this agent, there may be
                    # good rewards and info dicts for it.
                    # This is the case for e.g. OpenSpiel games, where a reward
                    # is only earned with the last step, but the obs for that
                    # step is {}.
                    reward = rewards[env_id].get(agent_id, 0.0)
                    info = infos[env_id].get(agent_id, {})
                    values_dict = {
                        SampleBatch.T: episode.length,
                        SampleBatch.ENV_ID: env_id,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                        # TODO(sven): These should be the summed-up(!) rewards since the
                        #  last observation received for this agent.
                        SampleBatch.REWARDS: reward,
                        SampleBatch.TERMINATEDS: True,
                        SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
                        SampleBatch.INFOS: info,
                        SampleBatch.NEXT_OBS: obs_space.sample(),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # Run agent connectors.
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (
                    policy.agent_connectors
                ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]

                # For all agents mapped to policy_id, run their data
                # through agent_connectors.
                processed = policy.agent_connectors(acd_list)

                for d in processed:
                    # Record transition info if applicable.
                    if not episode.has_init_obs(d.agent_id):
                        episode.add_init_obs(
                            agent_id=d.agent_id,
                            init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                            init_infos=d.data.raw_dict[SampleBatch.INFOS],
                            t=d.data.raw_dict[SampleBatch.T],
                        )
                    else:
                        episode.add_action_reward_done_next_obs(
                            d.agent_id, d.data.raw_dict
                        )

                    # Need to evaluate next actions.
                    if not (
                        all_agents_done
                        or agent_terminateds.get(d.agent_id, False)
                        or agent_truncateds.get(d.agent_id, False)
                        or episode.is_done(d.agent_id)
                    ):
                        # Add to eval set if env is not done and this particular agent
                        # is also not done.
                        item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
                        to_eval[policy_id].append(item)

            # Finished advancing episode by 1 step, mark it so.
            episode.step()

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if episode.length > 0:
                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is terminated/truncated for all agents
            # (terminateds[__all__] == True or truncateds[__all__] == True).
            if all_agents_done:
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                self._handle_done_episode(
                    env_id,
                    env_obs,
                    terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
                    active_envs,
                    to_eval,
                    outputs,
                )

            # Try to build something.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode
                )
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    del self._batch_builders[env_id]

        return active_envs, to_eval, outputs

    def _build_done_episode(
        self,
        env_id: EnvID,
        is_done: bool,
        outputs: List[SampleBatchType],
    ):
        """Builds a MultiAgentSampleBatch from the episode and adds it to outputs.

        Args:
            env_id: The env id.
            is_done: Whether the env is done.
            outputs: The list of outputs to add the built batch to.
        """
        episode: EpisodeV2 = self._active_episodes[env_id]
        batch_builder = self._batch_builders[env_id]

        episode.postprocess_episode(
            batch_builder=batch_builder,
            is_done=is_done,
            check_dones=is_done,
        )

        # If, we are not allowed to pack the next episode into the same
        # SampleBatch (batch_mode=complete_episodes) -> Build the
        # MultiAgentBatch from a single episode and add it to "outputs".
        # Otherwise, just postprocess and continue collecting across
        # episodes.
        if not self._multiple_episodes_in_batch:
            ma_sample_batch = _build_multi_agent_batch(
                episode.episode_id,
                batch_builder,
                self._large_batch_threshold,
                self._multiple_episodes_in_batch,
            )
            if ma_sample_batch:
                outputs.append(ma_sample_batch)

            # SampleBatch built from data collected by batch_builder.
            # Clean up and delete the batch_builder.
            del self._batch_builders[env_id]

    def __process_resetted_obs_for_eval(
        self,
        env_id: EnvID,
        obs: Dict[EnvID, Dict[AgentID, EnvObsType]],
        infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]],
        episode: EpisodeV2,
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
    ):
        """Process resetted obs through agent connectors for policy eval.

        Args:
            env_id: The env id.
            obs: The Resetted obs.
            infos: The infos returned alongside the resetted obs.
            episode: New episode.
            to_eval: List of agent connector data for policy eval.
        """
        per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list)
        # types: AgentID, EnvObsType
        for agent_id, raw_obs in obs[env_id].items():
            policy_id: PolicyID = episode.policy_for(agent_id)
            per_policy_resetted_obs[policy_id].append((agent_id, raw_obs))

        for policy_id, agents_obs in per_policy_resetted_obs.items():
            policy = self._worker.policy_map[policy_id]
            # NOTE(review): `SampleBatch.INFOS` receives the complete `infos`
            # structure here, not a per-agent entry as in
            # `_process_observations()` -- confirm whether this is intended.
            acd_list: List[AgentConnectorDataType] = [
                AgentConnectorDataType(
                    env_id,
                    agent_id,
                    {
                        SampleBatch.NEXT_OBS: obs,
                        SampleBatch.INFOS: infos,
                        SampleBatch.T: episode.length,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    },
                )
                for agent_id, obs in agents_obs
            ]
            # Call agent connectors on these initial obs.
            processed = policy.agent_connectors(acd_list)

            for d in processed:
                episode.add_init_obs(
                    agent_id=d.agent_id,
                    init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                    init_infos=d.data.raw_dict[SampleBatch.INFOS],
                    t=d.data.raw_dict[SampleBatch.T],
                )
                to_eval[policy_id].append(d)

    def _handle_done_episode(
        self,
        env_id: EnvID,
        env_obs_or_exception: MultiAgentDict,
        is_done: bool,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        outputs: List[SampleBatchType],
    ) -> None:
        """Handle an all-finished episode.

        Add collected SampleBatch to batch builder. Reset corresponding env, etc.

        Args:
            env_id: Environment ID.
            env_obs_or_exception: Last per-environment observation or Exception.
            is_done: If all agents are done.
            active_envs: Set of active env ids.
            to_eval: Output container for policy eval data.
            outputs: Output container for collected sample batches.
        """
        if isinstance(env_obs_or_exception, Exception):
            episode_or_exception: Exception = env_obs_or_exception
            # Tell the sampler we have got a faulty episode.
            outputs.append(RolloutMetrics(episode_faulty=True))
        else:
            episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
            # Add rollout metrics.
            outputs.extend(
                self._get_rollout_metrics(
                    episode_or_exception, policy_map=self._worker.policy_map
                )
            )
            # Output the collected episode after adding rollout metrics so that we
            # always fetch metrics with RolloutWorker before we fetch samples.
            # This is because we need to behave like env_runner() for now.
            self._build_done_episode(env_id, is_done, outputs)

        # Clean up and deleted the post-processed episode now that we have collected
        # its data.
        self.end_episode(env_id, episode_or_exception)
        # Create a new episode instance (before we reset the sub-environment).
        new_episode: EpisodeV2 = self.create_episode(env_id)

        # The sub environment at index `env_id` might throw an exception
        # during the following `try_reset()` attempt. If configured with
        # `restart_failed_sub_environments=True`, the BaseEnv will restart
        # the affected sub environment (create a new one using its c'tor) and
        # must reset the recreated sub env right after that.
        # Should the sub environment fail indefinitely during these
        # repeated reset attempts, the entire worker will be blocked.
        # This would be ok, b/c the alternative would be the worker crashing
        # entirely.
        while True:
            resetted_obs, resetted_infos = self._base_env.try_reset(env_id)

            if (
                resetted_obs is None
                or resetted_obs == ASYNC_RESET_RETURN
                or not isinstance(resetted_obs[env_id], Exception)
            ):
                break
            else:
                # Report a faulty episode.
                outputs.append(RolloutMetrics(episode_faulty=True))

        # Reset connector state if this is a hard reset.
        for p in self._worker.policy_map.cache.values():
            p.agent_connectors.reset(env_id)

        # Creates a new episode if this is not async return.
        # If reset is async, we will get its result in some future poll.
        if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
            self._active_episodes[env_id] = new_episode
            self._call_on_episode_start(new_episode, env_id)

            self.__process_resetted_obs_for_eval(
                env_id,
                resetted_obs,
                resetted_infos,
                new_episode,
                to_eval,
            )

            # Step after adding initial obs. This will give us 0 env and agent step.
            new_episode.step()
            active_envs.add(env_id)

    def create_episode(self, env_id: EnvID) -> EpisodeV2:
        """Creates a new EpisodeV2 instance and returns it.

        Calls `on_episode_created` callbacks, but does NOT reset the respective
        sub-environment yet.

        Args:
            env_id: Env ID.

        Returns:
            The newly created EpisodeV2 instance.
        """
        # Make sure we currently don't have an active episode under this env ID.
        assert env_id not in self._active_episodes

        # Create a new episode under the same `env_id` and call the
        # `on_episode_created` callbacks.
        new_episode = EpisodeV2(
            env_id,
            self._worker.policy_map,
            self._worker.policy_mapping_fn,
            worker=self._worker,
            callbacks=self._callbacks,
        )

        # Call `on_episode_created()` callback.
        self._callbacks.on_episode_created(
            worker=self._worker,
            base_env=self._base_env,
            policies=self._worker.policy_map,
            env_index=env_id,
            episode=new_episode,
        )
        return new_episode

    def end_episode(
        self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception]
    ):
        """Cleans up an episode that has finished.

        Args:
            env_id: Env ID.
            episode_or_exception: Instance of an episode if it finished successfully.
                Otherwise, the exception that was thrown,

        Raises:
            ValueError: If a successfully finished episode has recorded zero
                agent steps.
        """
        # Signal the end of an episode, either successfully with an Episode or
        # unsuccessfully with an Exception.
        self._callbacks.on_episode_end(
            worker=self._worker,
            base_env=self._base_env,
            policies=self._worker.policy_map,
            episode=episode_or_exception,
            env_index=env_id,
        )

        # Call each (in-memory) policy's Exploration.on_episode_end
        # method.
        # Note: This may break the exploration (e.g. ParameterNoise) of
        # policies in the `policy_map` that have not been recently used
        # (and are therefore stashed to disk). However, we certainly do not
        # want to loop through all (even stashed) policies here as that
        # would counter the purpose of the LRU policy caching.
        for p in self._worker.policy_map.cache.values():
            if getattr(p, "exploration", None) is not None:
                p.exploration.on_episode_end(
                    policy=p,
                    environment=self._base_env,
                    episode=episode_or_exception,
                    tf_sess=p.get_session(),
                )

        if isinstance(episode_or_exception, EpisodeV2):
            episode = episode_or_exception
            if episode.total_agent_steps == 0:
                # if the key does not exist it means that throughout the episode all
                # observations were empty (i.e. there was no agent in the env)
                msg = (
                    f"Data from episode {episode.episode_id} does not show any agent "
                    f"interactions. Hint: Make sure for at least one timestep in the "
                    f"episode, env.step() returns non-empty values."
                )
                raise ValueError(msg)

        # Clean up the episode and batch_builder for this env id.
        if env_id in self._active_episodes:
            del self._active_episodes[env_id]

    def _try_build_truncated_episode_multi_agent_batch(
        self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2
    ) -> Union[None, SampleBatch, MultiAgentBatch]:
        """Builds a batch once `rollout_fragment_length` steps are reached.

        Args:
            batch_builder: The batch builder holding already postprocessed data.
            episode: The (possibly still ongoing) episode.

        Returns:
            The built MultiAgentBatch, or None if the fragment length has not
            been reached yet or no agent steps have been collected.
        """
        # Measure batch size in env-steps.
        if self._count_steps_by == "env_steps":
            built_steps = batch_builder.env_steps
            ongoing_steps = episode.active_env_steps
        # Measure batch-size in agent-steps.
        else:
            built_steps = batch_builder.agent_steps
            ongoing_steps = episode.active_agent_steps

        # Reached the fragment-len -> We should build an MA-Batch.
        if built_steps + ongoing_steps >= self._rollout_fragment_length:
            if self._count_steps_by != "agent_steps":
                assert built_steps + ongoing_steps == self._rollout_fragment_length, (
                    f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != "
                    f"rollout_fragment_length ({self._rollout_fragment_length})."
                )

            # If we reached the fragment-len only because of `episode_id`
            # (still ongoing) -> postprocess `episode_id` first.
            if built_steps < self._rollout_fragment_length:
                episode.postprocess_episode(batch_builder=batch_builder, is_done=False)

            # If builder has collected some data,
            # build the MA-batch and add to return values.
            if batch_builder.agent_steps > 0:
                return _build_multi_agent_batch(
                    episode.episode_id,
                    batch_builder,
                    self._large_batch_threshold,
                    self._multiple_episodes_in_batch,
                )
            # No batch-builder:
            # We have reached the rollout-fragment length w/o any agent
            # steps! Warn that the environment may never request any
            # actions from any agents.
            elif log_once("no_agent_steps"):
                logger.warning(
                    "Your environment seems to be stepping w/o ever "
                    "emitting agent observations (agents are never "
                    "requested to act)!"
                )

        return None

    def _do_policy_eval(
        self,
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
    ) -> Dict[PolicyID, PolicyOutputType]:
        """Call compute_actions on collected episode data to get next action.

        Args:
            to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects
                (items in these lists will be the batch's items for the model
                forward pass).

        Returns:
            Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs.

        Raises:
            ValueError: If a policy can no longer be found and the affected
                agents now map to more than one policy.
        """
        policies = self._worker.policy_map

        # In case policy map has changed, try to find the new policy that
        # should handle all these per-agent eval data.
        # Throws exception if these agents are mapped to multiple different
        # policies now.
        def _try_find_policy_again(eval_data: List[AgentConnectorDataType]):
            policy_id = None
            for d in eval_data:
                episode = self._active_episodes[d.env_id]
                # Force refresh policy mapping on the episode.
                pid = episode.policy_for(d.agent_id, refresh=True)
                if policy_id is not None and pid != policy_id:
                    raise ValueError(
                        "Policy map changed. The list of eval data that was handled "
                        f"by a same policy is now handled by policy {pid} "
                        f"and {policy_id}. "
                        "Please don't do this in the middle of an episode."
                    )
                policy_id = pid
            return _get_or_raise(self._worker.policy_map, policy_id)

        eval_results: Dict[PolicyID, TensorStructType] = {}
        for policy_id, eval_data in to_eval.items():
            # In case the policyID has been removed from this worker, we need to
            # re-assign policy_id and re-lookup the Policy object to use.
            try:
                policy: Policy = _get_or_raise(policies, policy_id)
            except ValueError:
                # policy_mapping_fn from the worker may have already been
                # changed (mapping fn not staying constant within one episode).
                policy: Policy = _try_find_policy_again(eval_data)

            input_dict = _batch_inference_sample_batches(
                [d.data.sample_batch for d in eval_data]
            )

            eval_results[policy_id] = policy.compute_actions_from_input_dict(
                input_dict,
                timestep=policy.global_timestep,
                episodes=[self._active_episodes[t.env_id] for t in eval_data],
            )

        return eval_results

    def _process_policy_eval_results(
        self,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        eval_results: Dict[PolicyID, PolicyOutputType],
        off_policy_actions: MultiEnvDict,
    ):
        """Process the output of policy neural network evaluation.

        Records policy evaluation results into agent connectors and
        returns replies to send back to agents in the env.

        Args:
            active_envs: Set of env IDs that are still active.
            to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
            eval_results: Mapping of policy IDs to list of
                actions, rnn-out states, extra-action-fetches dicts.
            off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
                off-policy-action, returned by a `BaseEnv.poll()` call.

        Returns:
            Nested dict of env id -> agent id -> actions to be sent to
            Env (np.ndarrays).
        """
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)
        for env_id in active_envs:
            actions_to_send[env_id] = {}  # at minimum send empty dict

        # types: PolicyID, List[AgentConnectorDataType]
        for policy_id, eval_data in to_eval.items():
            actions: TensorStructType = eval_results[policy_id][0]
            actions = convert_to_numpy(actions)

            rnn_out: StateBatches = eval_results[policy_id][1]
            extra_action_out: dict = eval_results[policy_id][2]

            # In case actions is a list (representing the 0th dim of a batch of
            # primitive actions), try converting it first.
            if isinstance(actions, list):
                actions = np.array(actions)
            # Split action-component batches into single action rows.
            actions: List[EnvActionType] = unbatch(actions)

            policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
            assert (
                policy.agent_connectors and policy.action_connectors
            ), "EnvRunnerV2 requires action connectors to work."

            # types: int, EnvActionType
            for i, action in enumerate(actions):
                env_id: int = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id
                input_dict: TensorStructType = eval_data[i].data.raw_dict

                rnn_states: List[StateBatches] = tree.map_structure(
                    lambda x, i=i: x[i], rnn_out
                )

                # extra_action_out could be a nested dict
                fetches: Dict = tree.map_structure(
                    lambda x, i=i: x[i], extra_action_out
                )

                # Post-process policy output by running them through action connectors.
                ac_data = ActionConnectorDataType(
                    env_id, agent_id, input_dict, (action, rnn_states, fetches)
                )

                action_to_send, rnn_states, fetches = policy.action_connectors(
                    ac_data
                ).output

                # The action we want to buffer is the direct output of
                # compute_actions_from_input_dict() here.
This is because we want to # send the unsqushed actions to the environment while learning and # possibly basing subsequent actions on the squashed actions. action_to_buffer = ( action if env_id not in off_policy_actions or agent_id not in off_policy_actions[env_id] else off_policy_actions[env_id][agent_id] ) # Notify agent connectors with this new policy output. # Necessary for state buffering agent connectors, for example. ac_data: ActionConnectorDataType = ActionConnectorDataType( env_id, agent_id, input_dict, (action_to_buffer, rnn_states, fetches), ) policy.agent_connectors.on_policy_output(ac_data) assert agent_id not in actions_to_send[env_id] actions_to_send[env_id][agent_id] = action_to_send return actions_to_send def _maybe_render(self): """Visualize environment.""" # Check if we should render. if not self._render or not self._simple_image_viewer: return t5 = time.time() # Render can either return an RGB image (uint8 [w x h x 3] numpy # array) or take care of rendering itself (returning True). rendered = self._base_env.try_render() # Rendering returned an image -> Display it in a SimpleImageViewer. if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3: self._simple_image_viewer.imshow(rendered) elif rendered not in [True, False, None]: raise ValueError( f"The env's ({self._base_env}) `try_render()` method returned an" " unsupported value! Make sure you either return a " "uint8/w x h x 3 (RGB) image or handle rendering in a " "window and then return `True`." ) self._perf_stats.incr("env_render_time", time.time() - t5) def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]: """Atari games have multiple logical episodes, one per life. However, for metrics reporting we count full episodes, all lives included. 
""" sub_environments = base_env.get_sub_environments() if not sub_environments: return None atari_out = [] for sub_env in sub_environments: monitor = get_wrapper_by_cls(sub_env, MonitorEnv) if not monitor: return None for eps_rew, eps_len in monitor.next_episode_results(): atari_out.append(RolloutMetrics(eps_len, eps_rew)) return atari_out def _get_or_raise( mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID ) -> Union[Policy, Preprocessor, Filter]: """Returns an object under key `policy_id` in `mapping`. Args: mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The mapping dict from policy id (str) to actual object (Policy, Preprocessor, etc.). policy_id: The policy ID to lookup. Returns: Union[Policy, Preprocessor, Filter]: The found object. Raises: ValueError: If `policy_id` cannot be found in `mapping`. """ if policy_id not in mapping: raise ValueError( "Could not find policy for agent: PolicyID `{}` not found " "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys()) ) return mapping[policy_id]