| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208 |
- import logging
- from typing import List, Optional, Union
- import tree
- from ray.rllib.env.env_runner_group import EnvRunnerGroup
- from ray.rllib.policy.sample_batch import (
- DEFAULT_POLICY_ID,
- SampleBatch,
- concat_samples,
- )
- from ray.rllib.utils.annotations import ExperimentalAPI, OldAPIStack
- from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED
- from ray.rllib.utils.sgd import standardized
- from ray.rllib.utils.typing import EpisodeType, SampleBatchType
- logger = logging.getLogger(__name__)
@ExperimentalAPI
def synchronous_parallel_sample(
    *,
    worker_set: EnvRunnerGroup,
    max_agent_steps: Optional[int] = None,
    max_env_steps: Optional[int] = None,
    concat: bool = True,
    sample_timeout_s: Optional[float] = None,
    random_actions: bool = False,
    _uses_new_env_runners: bool = False,
    _return_metrics: bool = False,
) -> Union[List[SampleBatchType], SampleBatchType, List[EpisodeType], EpisodeType]:
    """Runs parallel and synchronous rollouts on all remote workers.

    Waits for all workers to return from the remote calls.

    If no remote workers exist (`worker_set.num_remote_workers() <= 0`), the
    local EnvRunner is used for sampling instead.

    Sampling is repeated in rounds until the requested number of agent or env
    steps has been collected. If neither `max_agent_steps` nor `max_env_steps`
    is given, rounds are repeated only until at least one env step has been
    collected. The loop also ends early if a round on the remote workers
    returns nothing or no healthy remote workers are left.

    Args:
        worker_set: The EnvRunnerGroup to use for sampling.
        max_agent_steps: Optional number of agent steps to be included in the
            final batch or list of episodes. Mutually exclusive with
            `max_env_steps`. A value of 0 is treated like None.
        max_env_steps: Optional number of environment steps to be included in the
            final batch or list of episodes. Mutually exclusive with
            `max_agent_steps`. A value of 0 is treated like None.
        concat: Whether to aggregate all resulting batches or episodes. In case of
            batches, the list of batches is concatenated at the end. In case of
            episodes, all episode lists from workers are flattened into a single
            list. Note that only the literal `True` triggers aggregation.
        sample_timeout_s: The timeout in sec to use on the `foreach_env_runner` call.
            After this time, the call will return with a result (or not if all
            EnvRunners are stalling). If None, will block indefinitely and not timeout.
        random_actions: If True, `random_actions=True` is forwarded as a keyword
            argument to each EnvRunner's `sample()` call. If False, `sample()` is
            called without arguments.
        _uses_new_env_runners: Whether the new `EnvRunner API` is used. In this case
            episodes instead of `SampleBatch` objects are returned.
        _return_metrics: If True, `get_metrics()` is additionally called on every
            EnvRunner right after its `sample()` call. The step counters are then
            computed from the `NUM_AGENT_STEPS_SAMPLED` / `NUM_ENV_STEPS_SAMPLED`
            entries of these metrics (instead of from the sampled data) and the
            function returns a tuple (see Returns).

    Returns:
        The list of collected sample batch types or episode types (one for each parallel
        rollout worker in the given `worker_set`), or - if `concat` is True - the
        concatenated batch / flattened list of episodes. If `_return_metrics` is
        True, a tuple `(samples, metrics)` is returned instead, where `metrics` is
        the list of all metrics dicts collected from the EnvRunners.

    Raises:
        AssertionError: If both `max_agent_steps` and `max_env_steps` are provided.

    .. testcode::

        # Define an RLlib Algorithm.
        from ray.rllib.algorithms.ppo import PPO, PPOConfig

        config = (
            PPOConfig()
            .environment("CartPole-v1")
        )
        algorithm = config.build()
        # 2 remote EnvRunners (num_env_runners=2):
        episodes = synchronous_parallel_sample(
            worker_set=algorithm.env_runner_group,
            _uses_new_env_runners=True,
            concat=False,
        )
        print(len(episodes))

    .. testoutput::

        2
    """
    # Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
    assert not (
        max_agent_steps is not None and max_env_steps is not None
    ), "Only one of `max_agent_steps` or `max_env_steps` may be provided!"

    agent_or_env_steps = 0
    # Truthiness on purpose: a limit of 0 behaves like "no limit given".
    max_agent_or_env_steps = max_agent_steps or max_env_steps or None
    sample_batches_or_episodes = []
    all_stats_dicts = []

    random_action_kwargs = {} if not random_actions else {"random_actions": True}

    # Stop collecting batches as soon as one criterion is met.
    while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
        max_agent_or_env_steps is not None
        and agent_or_env_steps < max_agent_or_env_steps
    ):
        # No remote workers in the set -> Use local worker for collecting
        # samples.
        if worker_set.num_remote_workers() <= 0:
            sampled_data = [worker_set.local_env_runner.sample(**random_action_kwargs)]
            if _return_metrics:
                stats_dicts = [worker_set.local_env_runner.get_metrics()]
        # Loop over remote workers' `sample()` method in parallel.
        else:
            sampled_data = worker_set.foreach_env_runner(
                (
                    (lambda w: w.sample(**random_action_kwargs))
                    if not _return_metrics
                    else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
                ),
                local_env_runner=False,
                timeout_seconds=sample_timeout_s,
            )
            # Nothing was returned (maybe all workers are stalling) or no healthy
            # remote workers left: Break.
            # There is no point staying in this loop, since we will not be able to
            # get any new samples if we don't have any healthy remote workers left.
            if not sampled_data or worker_set.num_healthy_remote_workers() <= 0:
                if not sampled_data:
                    logger.warning(
                        "No samples returned from remote workers. If you have a "
                        "slow environment or model, consider increasing the "
                        "`sample_timeout_s` or decreasing the "
                        "`rollout_fragment_length` in `AlgorithmConfig.env_runners()`."
                    )
                else:
                    # We only get here with non-empty `sampled_data`, i.e. because
                    # the healthy-worker check above triggered the break.
                    logger.warning(
                        "No healthy remote workers left. Trying to restore workers ..."
                    )
                break

            # Each remote result is a `(samples, metrics)` tuple in this mode:
            # split them into two parallel lists.
            if _return_metrics:
                stats_dicts = [s[1] for s in sampled_data]
                sampled_data = [s[0] for s in sampled_data]

        # Update our counters for the stopping criterion of the while loop.
        if _return_metrics:
            # Count steps via the returned metrics rather than the samples.
            if max_agent_steps:
                agent_or_env_steps += sum(
                    int(agent_stat)
                    for stat_dict in stats_dicts
                    for agent_stat in stat_dict[NUM_AGENT_STEPS_SAMPLED].values()
                )
            else:
                agent_or_env_steps += sum(
                    int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
                )
            sample_batches_or_episodes.extend(sampled_data)
            all_stats_dicts.extend(stats_dicts)
        else:
            for batch_or_episode in sampled_data:
                # With the new EnvRunners, each worker result is an iterable of
                # episodes; otherwise it is a single batch object.
                if max_agent_steps:
                    agent_or_env_steps += (
                        sum(e.agent_steps() for e in batch_or_episode)
                        if _uses_new_env_runners
                        else batch_or_episode.agent_steps()
                    )
                else:
                    agent_or_env_steps += (
                        sum(e.env_steps() for e in batch_or_episode)
                        if _uses_new_env_runners
                        else batch_or_episode.env_steps()
                    )
                sample_batches_or_episodes.append(batch_or_episode)
                # Break out (and ignore the remaining samples) if max timesteps (batch
                # size) reached. We want to avoid collecting batches that are too large
                # only because of a failed/restarted worker causing a second iteration
                # of the main loop.
                if (
                    max_agent_or_env_steps is not None
                    and agent_or_env_steps >= max_agent_or_env_steps
                ):
                    break

    if concat is True:
        # If we have episodes flatten the episode list.
        if _uses_new_env_runners:
            sample_batches_or_episodes = tree.flatten(sample_batches_or_episodes)
        # Otherwise we concatenate the `SampleBatch` objects
        else:
            sample_batches_or_episodes = concat_samples(sample_batches_or_episodes)

    if _return_metrics:
        return sample_batches_or_episodes, all_stats_dicts
    return sample_batches_or_episodes
@OldAPIStack
def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
    """Applies `standardized()` in place to the given fields of a batch.

    A plain `SampleBatch` is temporarily converted via `as_multi_agent()` so that
    single- and multi-agent inputs share one code path; afterwards, the
    `DEFAULT_POLICY_ID` policy batch is returned for such inputs.

    Args:
        samples: The `SampleBatch` or multi-agent batch to process.
        fields: The names of the fields to standardize. Names not present in a
            policy batch are skipped for that batch.

    Returns:
        For a `SampleBatch` input, the `DEFAULT_POLICY_ID` policy batch of its
        multi-agent form; otherwise the same `samples` object that was passed in.
    """
    was_single_agent = isinstance(samples, SampleBatch)
    ma_batch = samples.as_multi_agent() if was_single_agent else samples

    # Overwrite each requested (and present) field of every policy batch.
    for policy_batch in ma_batch.policy_batches.values():
        for name in fields:
            if name in policy_batch:
                policy_batch[name] = standardized(policy_batch[name])

    if was_single_agent:
        return ma_batch.policy_batches[DEFAULT_POLICY_ID]
    return ma_batch
|