sampler.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import logging
  2. import queue
  3. from abc import ABCMeta, abstractmethod
  4. from collections import defaultdict, namedtuple
  5. from typing import (
  6. TYPE_CHECKING,
  7. Any,
  8. List,
  9. Optional,
  10. Type,
  11. Union,
  12. )
  13. from ray._common.deprecation import DEPRECATED_VALUE, deprecation_warning
  14. from ray.rllib.env.base_env import BaseEnv, convert_to_base_env
  15. from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
  16. from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
  17. from ray.rllib.evaluation.env_runner_v2 import EnvRunnerV2, _PerfStats
  18. from ray.rllib.evaluation.metrics import RolloutMetrics
  19. from ray.rllib.offline import InputReader
  20. from ray.rllib.policy.sample_batch import concat_samples
  21. from ray.rllib.utils.annotations import OldAPIStack, override
  22. from ray.rllib.utils.framework import try_import_tf
  23. from ray.rllib.utils.typing import SampleBatchType
  24. from ray.util.debug import log_once
  25. if TYPE_CHECKING:
  26. from ray.rllib.callbacks.callbacks import RLlibCallback
  27. from ray.rllib.evaluation.observation_function import ObservationFunction
  28. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  29. tf1, tf, _ = try_import_tf()
  30. logger = logging.getLogger(__name__)
  31. _PolicyEvalData = namedtuple(
  32. "_PolicyEvalData",
  33. ["env_id", "agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"],
  34. )
  35. # A batch of RNN states with dimensions [state_index, batch, state_object].
  36. StateBatch = List[List[Any]]
  37. class _NewEpisodeDefaultDict(defaultdict):
  38. def __missing__(self, env_id):
  39. if self.default_factory is None:
  40. raise KeyError(env_id)
  41. else:
  42. ret = self[env_id] = self.default_factory(env_id)
  43. return ret
  44. @OldAPIStack
  45. class SamplerInput(InputReader, metaclass=ABCMeta):
  46. """Reads input experiences from an existing sampler."""
  47. @override(InputReader)
  48. def next(self) -> SampleBatchType:
  49. batches = [self.get_data()]
  50. batches.extend(self.get_extra_batches())
  51. if len(batches) == 0:
  52. raise RuntimeError("No data available from sampler.")
  53. return concat_samples(batches)
  54. @abstractmethod
  55. def get_data(self) -> SampleBatchType:
  56. """Called by `self.next()` to return the next batch of data.
  57. Override this in child classes.
  58. Returns:
  59. The next batch of data.
  60. """
  61. raise NotImplementedError
  62. @abstractmethod
  63. def get_metrics(self) -> List[RolloutMetrics]:
  64. """Returns list of episode metrics since the last call to this method.
  65. The list will contain one RolloutMetrics object per completed episode.
  66. Returns:
  67. List of RolloutMetrics objects, one per completed episode since
  68. the last call to this method.
  69. """
  70. raise NotImplementedError
  71. @abstractmethod
  72. def get_extra_batches(self) -> List[SampleBatchType]:
  73. """Returns list of extra batches since the last call to this method.
  74. The list will contain all SampleBatches or
  75. MultiAgentBatches that the user has provided thus-far. Users can
  76. add these "extra batches" to an episode by calling the episode's
  77. `add_extra_batch([SampleBatchType])` method. This can be done from
  78. inside an overridden `Policy.compute_actions_from_input_dict(...,
  79. episodes)` or from a custom callback's `on_episode_[start|step|end]()`
  80. methods.
  81. Returns:
  82. List of SamplesBatches or MultiAgentBatches provided thus-far by
  83. the user since the last call to this method.
  84. """
  85. raise NotImplementedError
  86. @OldAPIStack
  87. class SyncSampler(SamplerInput):
  88. """Sync SamplerInput that collects experiences when `get_data()` is called."""
  89. def __init__(
  90. self,
  91. *,
  92. worker: "RolloutWorker",
  93. env: BaseEnv,
  94. clip_rewards: Union[bool, float],
  95. rollout_fragment_length: int,
  96. count_steps_by: str = "env_steps",
  97. callbacks: "RLlibCallback",
  98. multiple_episodes_in_batch: bool = False,
  99. normalize_actions: bool = True,
  100. clip_actions: bool = False,
  101. observation_fn: Optional["ObservationFunction"] = None,
  102. sample_collector_class: Optional[Type[SampleCollector]] = None,
  103. render: bool = False,
  104. # Obsolete.
  105. policies=None,
  106. policy_mapping_fn=None,
  107. preprocessors=None,
  108. obs_filters=None,
  109. tf_sess=None,
  110. horizon=DEPRECATED_VALUE,
  111. soft_horizon=DEPRECATED_VALUE,
  112. no_done_at_end=DEPRECATED_VALUE,
  113. ):
  114. """Initializes a SyncSampler instance.
  115. Args:
  116. worker: The RolloutWorker that will use this Sampler for sampling.
  117. env: Any Env object. Will be converted into an RLlib BaseEnv.
  118. clip_rewards: True for +/-1.0 clipping,
  119. actual float value for +/- value clipping. False for no
  120. clipping.
  121. rollout_fragment_length: The length of a fragment to collect
  122. before building a SampleBatch from the data and resetting
  123. the SampleBatchBuilder object.
  124. count_steps_by: One of "env_steps" (default) or "agent_steps".
  125. Use "agent_steps", if you want rollout lengths to be counted
  126. by individual agent steps. In a multi-agent env,
  127. a single env_step contains one or more agent_steps, depending
  128. on how many agents are present at any given time in the
  129. ongoing episode.
  130. callbacks: The RLlibCallback object to use when episode
  131. events happen during rollout.
  132. multiple_episodes_in_batch: Whether to pack multiple
  133. episodes into each batch. This guarantees batches will be
  134. exactly `rollout_fragment_length` in size.
  135. normalize_actions: Whether to normalize actions to the
  136. action space's bounds.
  137. clip_actions: Whether to clip actions according to the
  138. given action_space's bounds.
  139. observation_fn: Optional multi-agent observation func to use for
  140. preprocessing observations.
  141. sample_collector_class: An optional SampleCollector sub-class to
  142. use to collect, store, and retrieve environment-, model-,
  143. and sampler data.
  144. render: Whether to try to render the environment after each step.
  145. """
  146. # All of the following arguments are deprecated. They will instead be
  147. # provided via the passed in `worker` arg, e.g. `worker.policy_map`.
  148. if log_once("deprecated_sync_sampler_args"):
  149. if policies is not None:
  150. deprecation_warning(old="policies")
  151. if policy_mapping_fn is not None:
  152. deprecation_warning(old="policy_mapping_fn")
  153. if preprocessors is not None:
  154. deprecation_warning(old="preprocessors")
  155. if obs_filters is not None:
  156. deprecation_warning(old="obs_filters")
  157. if tf_sess is not None:
  158. deprecation_warning(old="tf_sess")
  159. if horizon != DEPRECATED_VALUE:
  160. deprecation_warning(old="horizon", error=True)
  161. if soft_horizon != DEPRECATED_VALUE:
  162. deprecation_warning(old="soft_horizon", error=True)
  163. if no_done_at_end != DEPRECATED_VALUE:
  164. deprecation_warning(old="no_done_at_end", error=True)
  165. self.base_env = convert_to_base_env(env)
  166. self.rollout_fragment_length = rollout_fragment_length
  167. self.extra_batches = queue.Queue()
  168. self.perf_stats = _PerfStats(
  169. ema_coef=worker.config.sampler_perf_stats_ema_coef,
  170. )
  171. if not sample_collector_class:
  172. sample_collector_class = SimpleListCollector
  173. self.sample_collector = sample_collector_class(
  174. worker.policy_map,
  175. clip_rewards,
  176. callbacks,
  177. multiple_episodes_in_batch,
  178. rollout_fragment_length,
  179. count_steps_by=count_steps_by,
  180. )
  181. self.render = render
  182. # Keep a reference to the underlying EnvRunnerV2 instance for
  183. # unit testing purpose.
  184. self._env_runner_obj = EnvRunnerV2(
  185. worker=worker,
  186. base_env=self.base_env,
  187. multiple_episodes_in_batch=multiple_episodes_in_batch,
  188. callbacks=callbacks,
  189. perf_stats=self.perf_stats,
  190. rollout_fragment_length=rollout_fragment_length,
  191. count_steps_by=count_steps_by,
  192. render=self.render,
  193. )
  194. self._env_runner = self._env_runner_obj.run()
  195. self.metrics_queue = queue.Queue()
  196. @override(SamplerInput)
  197. def get_data(self) -> SampleBatchType:
  198. while True:
  199. item = next(self._env_runner)
  200. if isinstance(item, RolloutMetrics):
  201. self.metrics_queue.put(item)
  202. else:
  203. return item
  204. @override(SamplerInput)
  205. def get_metrics(self) -> List[RolloutMetrics]:
  206. completed = []
  207. while True:
  208. try:
  209. completed.append(
  210. self.metrics_queue.get_nowait()._replace(
  211. perf_stats=self.perf_stats.get()
  212. )
  213. )
  214. except queue.Empty:
  215. break
  216. return completed
  217. @override(SamplerInput)
  218. def get_extra_batches(self) -> List[SampleBatchType]:
  219. extra = []
  220. while True:
  221. try:
  222. extra.append(self.extra_batches.get_nowait())
  223. except queue.Empty:
  224. break
  225. return extra