rollout_ops.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import logging
  2. from typing import List, Optional, Union
  3. import tree
  4. from ray.rllib.env.env_runner_group import EnvRunnerGroup
  5. from ray.rllib.policy.sample_batch import (
  6. DEFAULT_POLICY_ID,
  7. SampleBatch,
  8. concat_samples,
  9. )
  10. from ray.rllib.utils.annotations import ExperimentalAPI, OldAPIStack
  11. from ray.rllib.utils.metrics import NUM_AGENT_STEPS_SAMPLED, NUM_ENV_STEPS_SAMPLED
  12. from ray.rllib.utils.sgd import standardized
  13. from ray.rllib.utils.typing import EpisodeType, SampleBatchType
  14. logger = logging.getLogger(__name__)
  15. @ExperimentalAPI
  16. def synchronous_parallel_sample(
  17. *,
  18. worker_set: EnvRunnerGroup,
  19. max_agent_steps: Optional[int] = None,
  20. max_env_steps: Optional[int] = None,
  21. concat: bool = True,
  22. sample_timeout_s: Optional[float] = None,
  23. random_actions: bool = False,
  24. _uses_new_env_runners: bool = False,
  25. _return_metrics: bool = False,
  26. ) -> Union[List[SampleBatchType], SampleBatchType, List[EpisodeType], EpisodeType]:
  27. """Runs parallel and synchronous rollouts on all remote workers.
  28. Waits for all workers to return from the remote calls.
  29. If no remote workers exist (num_workers == 0), use the local worker
  30. for sampling.
  31. Alternatively to calling `worker.sample.remote()`, the user can provide a
  32. `remote_fn()`, which will be applied to the worker(s) instead.
  33. Args:
  34. worker_set: The EnvRunnerGroup to use for sampling.
  35. remote_fn: If provided, use `worker.apply.remote(remote_fn)` instead
  36. of `worker.sample.remote()` to generate the requests.
  37. max_agent_steps: Optional number of agent steps to be included in the
  38. final batch or list of episodes.
  39. max_env_steps: Optional number of environment steps to be included in the
  40. final batch or list of episodes.
  41. concat: Whether to aggregate all resulting batches or episodes. in case of
  42. batches the list of batches is concatinated at the end. in case of
  43. episodes all episode lists from workers are flattened into a single list.
  44. sample_timeout_s: The timeout in sec to use on the `foreach_env_runner` call.
  45. After this time, the call will return with a result (or not if all
  46. EnvRunners are stalling). If None, will block indefinitely and not timeout.
  47. _uses_new_env_runners: Whether the new `EnvRunner API` is used. In this case
  48. episodes instead of `SampleBatch` objects are returned.
  49. Returns:
  50. The list of collected sample batch types or episode types (one for each parallel
  51. rollout worker in the given `worker_set`).
  52. .. testcode::
  53. # Define an RLlib Algorithm.
  54. from ray.rllib.algorithms.ppo import PPO, PPOConfig
  55. config = (
  56. PPOConfig()
  57. .environment("CartPole-v1")
  58. )
  59. algorithm = config.build()
  60. # 2 remote EnvRunners (num_env_runners=2):
  61. episodes = synchronous_parallel_sample(
  62. worker_set=algorithm.env_runner_group,
  63. _uses_new_env_runners=True,
  64. concat=False,
  65. )
  66. print(len(episodes))
  67. .. testoutput::
  68. 2
  69. """
  70. # Only allow one of `max_agent_steps` or `max_env_steps` to be defined.
  71. assert not (max_agent_steps is not None and max_env_steps is not None)
  72. agent_or_env_steps = 0
  73. max_agent_or_env_steps = max_agent_steps or max_env_steps or None
  74. sample_batches_or_episodes = []
  75. all_stats_dicts = []
  76. random_action_kwargs = {} if not random_actions else {"random_actions": True}
  77. # Stop collecting batches as soon as one criterium is met.
  78. while (max_agent_or_env_steps is None and agent_or_env_steps == 0) or (
  79. max_agent_or_env_steps is not None
  80. and agent_or_env_steps < max_agent_or_env_steps
  81. ):
  82. # No remote workers in the set -> Use local worker for collecting
  83. # samples.
  84. if worker_set.num_remote_workers() <= 0:
  85. sampled_data = [worker_set.local_env_runner.sample(**random_action_kwargs)]
  86. if _return_metrics:
  87. stats_dicts = [worker_set.local_env_runner.get_metrics()]
  88. # Loop over remote workers' `sample()` method in parallel.
  89. else:
  90. sampled_data = worker_set.foreach_env_runner(
  91. (
  92. (lambda w: w.sample(**random_action_kwargs))
  93. if not _return_metrics
  94. else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
  95. ),
  96. local_env_runner=False,
  97. timeout_seconds=sample_timeout_s,
  98. )
  99. # Nothing was returned (maybe all workers are stalling) or no healthy
  100. # remote workers left: Break.
  101. # There is no point staying in this loop, since we will not be able to
  102. # get any new samples if we don't have any healthy remote workers left.
  103. if not sampled_data or worker_set.num_healthy_remote_workers() <= 0:
  104. if not sampled_data:
  105. logger.warning(
  106. "No samples returned from remote workers. If you have a "
  107. "slow environment or model, consider increasing the "
  108. "`sample_timeout_s` or decreasing the "
  109. "`rollout_fragment_length` in `AlgorithmConfig.env_runners()."
  110. )
  111. elif worker_set.num_healthy_remote_workers() <= 0:
  112. logger.warning(
  113. "No healthy remote workers left. Trying to restore workers ..."
  114. )
  115. break
  116. if _return_metrics:
  117. stats_dicts = [s[1] for s in sampled_data]
  118. sampled_data = [s[0] for s in sampled_data]
  119. # Update our counters for the stopping criterion of the while loop.
  120. if _return_metrics:
  121. if max_agent_steps:
  122. agent_or_env_steps += sum(
  123. int(agent_stat)
  124. for stat_dict in stats_dicts
  125. for agent_stat in stat_dict[NUM_AGENT_STEPS_SAMPLED].values()
  126. )
  127. else:
  128. agent_or_env_steps += sum(
  129. int(stat_dict[NUM_ENV_STEPS_SAMPLED]) for stat_dict in stats_dicts
  130. )
  131. sample_batches_or_episodes.extend(sampled_data)
  132. all_stats_dicts.extend(stats_dicts)
  133. else:
  134. for batch_or_episode in sampled_data:
  135. if max_agent_steps:
  136. agent_or_env_steps += (
  137. sum(e.agent_steps() for e in batch_or_episode)
  138. if _uses_new_env_runners
  139. else batch_or_episode.agent_steps()
  140. )
  141. else:
  142. agent_or_env_steps += (
  143. sum(e.env_steps() for e in batch_or_episode)
  144. if _uses_new_env_runners
  145. else batch_or_episode.env_steps()
  146. )
  147. sample_batches_or_episodes.append(batch_or_episode)
  148. # Break out (and ignore the remaining samples) if max timesteps (batch
  149. # size) reached. We want to avoid collecting batches that are too large
  150. # only because of a failed/restarted worker causing a second iteration
  151. # of the main loop.
  152. if (
  153. max_agent_or_env_steps is not None
  154. and agent_or_env_steps >= max_agent_or_env_steps
  155. ):
  156. break
  157. if concat is True:
  158. # If we have episodes flatten the episode list.
  159. if _uses_new_env_runners:
  160. sample_batches_or_episodes = tree.flatten(sample_batches_or_episodes)
  161. # Otherwise we concatenate the `SampleBatch` objects
  162. else:
  163. sample_batches_or_episodes = concat_samples(sample_batches_or_episodes)
  164. if _return_metrics:
  165. return sample_batches_or_episodes, all_stats_dicts
  166. return sample_batches_or_episodes
  167. @OldAPIStack
  168. def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType:
  169. """Standardize fields of the given SampleBatch"""
  170. wrapped = False
  171. if isinstance(samples, SampleBatch):
  172. samples = samples.as_multi_agent()
  173. wrapped = True
  174. for policy_id in samples.policy_batches:
  175. batch = samples.policy_batches[policy_id]
  176. for field in fields:
  177. if field in batch:
  178. batch[field] = standardized(batch[field])
  179. if wrapped:
  180. samples = samples.policy_batches[DEFAULT_POLICY_ID]
  181. return samples