# env_runner.py
  1. import abc
  2. import logging
  3. from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
  4. import gymnasium as gym
  5. import tree # pip install dm_tree
  6. import ray
  7. from ray.rllib.core import COMPONENT_RL_MODULE
  8. from ray.rllib.env.env_errors import StepFailedRecreateEnvError
  9. from ray.rllib.utils.actor_manager import FaultAwareApply
  10. from ray.rllib.utils.debug import update_global_seed_if_necessary
  11. from ray.rllib.utils.framework import try_import_tf
  12. from ray.rllib.utils.metrics import ENV_RESET_TIMER, ENV_STEP_TIMER
  13. from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
  14. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  15. from ray.rllib.utils.typing import StateDict, TensorType
  16. from ray.util.annotations import DeveloperAPI, PublicAPI
  17. from ray.util.metrics import Counter
  18. if TYPE_CHECKING:
  19. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  20. logger = logging.getLogger("ray.rllib")
  21. tf1, tf, _ = try_import_tf()
  22. ENV_RESET_FAILURE = "env_reset_failure"
  23. ENV_STEP_FAILURE = "env_step_failure"
  24. NUM_ENV_STEP_FAILURES_LIFETIME = "num_env_step_failures"
  25. # TODO (sven): As soon as RolloutWorker is no longer supported, make this base class
  26. # a Checkpointable. Currently, only some of its subclasses are Checkpointables.
  27. @PublicAPI(stability="alpha")
  28. class EnvRunner(FaultAwareApply, metaclass=abc.ABCMeta):
  29. """Base class for distributed RL-style data collection from an environment.
  30. The EnvRunner API's core functionalities can be summarized as:
  31. - Gets configured via passing a AlgorithmConfig object to the constructor.
  32. Normally, subclasses of EnvRunner then construct their own environment (possibly
  33. vectorized) copies and RLModules/Policies and use the latter to step through the
  34. environment in order to collect training data.
  35. - Clients of EnvRunner can use the `sample()` method to collect data for training
  36. from the environment(s).
  37. - EnvRunner offers parallelism via creating n remote Ray Actors based on this class.
  38. Use `ray.remote([resources])(EnvRunner)` method to create the corresponding Ray
  39. remote class. Then instantiate n Actors using the Ray `[ctor].remote(...)` syntax.
  40. - EnvRunner clients can get information about the server/node on which the
  41. individual Actors are running.
  42. """
  43. def __init__(self, *, config: "AlgorithmConfig", **kwargs):
  44. """Initializes an EnvRunner instance.
  45. Args:
  46. config: The AlgorithmConfig to use to setup this EnvRunner.
  47. **kwargs: Forward compatibility kwargs.
  48. """
  49. self.config: AlgorithmConfig = config.copy(copy_frozen=False)
  50. self.num_env_steps_sampled_lifetime = 0
  51. # Get the worker index on which this instance is running.
  52. # TODO (sven): We should make these c'tor named args.
  53. self.worker_index: int = kwargs.get("worker_index")
  54. self.num_workers: int = kwargs.get("num_workers", self.config.num_env_runners)
  55. self.env = None
  56. # Create a MetricsLogger object for logging custom stats.
  57. self.metrics: MetricsLogger = MetricsLogger(
  58. stats_cls_lookup=config.stats_cls_lookup,
  59. root=False,
  60. )
  61. super().__init__()
  62. # This eager check is necessary for certain all-framework tests
  63. # that use tf's eager_mode() context generator.
  64. if (
  65. tf1
  66. and (self.config.framework_str == "tf2" or config.enable_tf1_exec_eagerly)
  67. and not tf1.executing_eagerly()
  68. ):
  69. tf1.enable_eager_execution()
  70. # Determine actual seed for this particular worker based on worker index AND
  71. # whether it's an eval worker.
  72. self._seed: Optional[int] = None
  73. if self.config.seed is not None:
  74. self._seed = int(
  75. self.config.seed
  76. + (self.worker_index or 0)
  77. # Eval workers get a +1M seed.
  78. + (1e6 * self.config.in_evaluation)
  79. )
  80. # Seed everything (random, numpy, torch, tf), if `seed` is provided.
  81. update_global_seed_if_necessary(
  82. framework=self.config.framework_str,
  83. seed=self._seed,
  84. )
  85. # Ray metrics
  86. self._metrics_num_try_env_step = Counter(
  87. name="rllib_env_runner_num_try_env_step_counter",
  88. description="Number of env.step() calls attempted in this Env Runner.",
  89. tag_keys=("rllib",),
  90. )
  91. self._metrics_num_try_env_step.set_default_tags(
  92. {"rllib": self.__class__.__name__}
  93. )
  94. self._metrics_num_env_steps_sampled = Counter(
  95. name="rllib_env_runner_num_env_steps_sampled_counter",
  96. description="Number of env steps sampled in this Env Runner.",
  97. tag_keys=("rllib",),
  98. )
  99. self._metrics_num_env_steps_sampled.set_default_tags(
  100. {"rllib": self.__class__.__name__}
  101. )
  102. self._shared_data = None
  103. @abc.abstractmethod
  104. def assert_healthy(self):
  105. """Checks that self.__init__() has been completed properly.
  106. Useful in case an `EnvRunner` is run as @ray.remote (Actor) and the owner
  107. would like to make sure the Ray Actor has been properly initialized.
  108. Raises:
  109. AssertionError: If the EnvRunner Actor has NOT been properly initialized.
  110. """
  111. # TODO: Make this an abstract method that must be implemented.
  112. def make_env(self):
  113. """Creates the RL environment for this EnvRunner and assigns it to `self.env`.
  114. Note that users should be able to change the EnvRunner's config (e.g. change
  115. `self.config.env_config`) and then call this method to create new environments
  116. with the updated configuration.
  117. It should also be called after a failure of an earlier env in order to clean up
  118. the existing env (for example `close()` it), re-create a new one, and then
  119. continue sampling with that new env.
  120. """
  121. pass
  122. # TODO: Make this an abstract method that must be implemented.
  123. def make_module(self):
  124. """Creates the RLModule for this EnvRunner and assigns it to `self.module`.
  125. Note that users should be able to change the EnvRunner's config (e.g. change
  126. `self.config.rl_module_spec`) and then call this method to create a new RLModule
  127. with the updated configuration.
  128. """
  129. pass
  130. @abc.abstractmethod
  131. def sample(self, **kwargs) -> Any:
  132. """Returns experiences (of any form) sampled from this EnvRunner.
  133. The exact nature and size of collected data are defined via the EnvRunner's
  134. config and may be overridden by the given arguments.
  135. Args:
  136. **kwargs: Forward compatibility kwargs.
  137. Returns:
  138. The collected experience in any form.
  139. """
  140. # TODO (sven): Make this an abstract method that must be overridden.
  141. def get_metrics(self) -> Any:
  142. """Returns metrics (in any form) of the thus far collected, completed episodes.
  143. Returns:
  144. Metrics of any form.
  145. """
  146. pass
  147. @DeveloperAPI
  148. def sample_get_state_and_metrics(
  149. self,
  150. ) -> Tuple[ray.ObjectRef, StateDict, StateDict]:
  151. """Convenience method for fast, async algorithms.
  152. Use this in Algorithms that need to sample Episode lists as ray.ObjectRef, but
  153. also require (in the same remote call) the metrics and the EnvRunner states,
  154. except for the module weights.
  155. """
  156. _episodes = self.sample()
  157. # Get the EnvRunner's connector states.
  158. _connector_states = self.get_state(not_components=COMPONENT_RL_MODULE)
  159. _metrics = self.get_metrics()
  160. # Return episode lists by reference so we don't have to send them to the
  161. # main algo process, but to the Aggregator- or Learner actors directly.
  162. return ray.put(_episodes), _connector_states, _metrics
  163. @abc.abstractmethod
  164. def get_spaces(self) -> Dict[str, Tuple[gym.Space, gym.Space]]:
  165. """Returns a dict mapping ModuleIDs to 2-tuples of obs- and action space."""
  166. def stop(self) -> None:
  167. """Releases all resources used by this EnvRunner.
  168. For example, when using a gym.Env in this EnvRunner, you should make sure
  169. that its `close()` method is called.
  170. """
  171. pass
  172. def __del__(self) -> None:
  173. """If this Actor is deleted, clears all resources used by it."""
  174. pass
  175. def _try_env_reset(
  176. self,
  177. *,
  178. seed: Optional[int] = None,
  179. options: Optional[dict] = None,
  180. ) -> Tuple[Any, Any]:
  181. """Tries resetting the env and - if an error occurs - handles it gracefully.
  182. Args:
  183. seed: An optional seed (int) to be passed to the Env.reset() call.
  184. options: An optional options-dict to be passed to the Env.reset() call.
  185. Returns:
  186. The results of calling `Env.reset()`, which is a tuple of observations and
  187. info dicts.
  188. Raises:
  189. Exception: In case `config.restart_failed_sub_environments` is False and
  190. `Env.reset()` resulted in an error.
  191. """
  192. # Try to reset.
  193. try:
  194. with self.metrics.log_time(ENV_RESET_TIMER):
  195. obs, infos = self.env.reset(seed=seed, options=options)
  196. # Everything ok -> return.
  197. return obs, infos
  198. # Error.
  199. except Exception as e:
  200. # If user wants to simply restart the env -> recreate env and try again
  201. # (calling this method recursively until success).
  202. if self.config.restart_failed_sub_environments:
  203. logger.exception(
  204. "Resetting the env resulted in an error! The original error "
  205. f"is: {e.args[0]}"
  206. )
  207. # Recreate the env and simply try again.
  208. self.make_env()
  209. return self._try_env_reset(seed=seed, options=options)
  210. else:
  211. raise e
  212. def _try_env_step(self, actions):
  213. """Tries stepping the env and - if an error occurs - handles it gracefully."""
  214. try:
  215. with self.metrics.log_time(ENV_STEP_TIMER):
  216. results = self.env.step(actions)
  217. self._log_env_steps(metric=self._metrics_num_try_env_step, num_steps=1)
  218. return results
  219. except Exception as e:
  220. self.metrics.log_value(
  221. NUM_ENV_STEP_FAILURES_LIFETIME, 1, reduce="lifetime_sum"
  222. )
  223. if self.config.restart_failed_sub_environments:
  224. if not isinstance(e, StepFailedRecreateEnvError):
  225. logger.exception(
  226. f"RLlib {self.__class__.__name__}: Environment step failed. Will force reset env(s) in this EnvRunner. The original error is: {e}"
  227. )
  228. # Recreate the env.
  229. self.make_env()
  230. # And return that the stepping failed. The caller will then handle
  231. # specific cleanup operations (for example discarding thus-far collected
  232. # data and repeating the step attempt).
  233. return ENV_STEP_FAILURE
  234. else:
  235. logger.exception(
  236. f"RLlib {self.__class__.__name__}: Environment step failed and "
  237. "'config.restart_failed_sub_environments' is False. "
  238. "This env will not be recreated. "
  239. "Consider setting 'fault_tolerance(restart_failed_sub_environments=True)' in your AlgorithmConfig "
  240. "in order to automatically re-create and force-reset an env."
  241. f"The original error type: {type(e)}. "
  242. f"{e}"
  243. )
  244. raise RuntimeError from e
  245. def _convert_to_tensor(self, struct) -> TensorType:
  246. """Converts structs to a framework-specific tensor."""
  247. if self.config.framework_str == "torch":
  248. return convert_to_torch_tensor(struct)
  249. else:
  250. return tree.map_structure(tf.convert_to_tensor, struct)
  251. def _log_env_steps(self, metric: Counter, num_steps: int) -> None:
  252. if num_steps > 0:
  253. metric.inc(value=num_steps)
  254. else:
  255. logger.warning(
  256. f"RLlib {self.__class__.__name__}: Skipping Prometheus logging for metric '{metric.info['name']}'. "
  257. f"Received num_steps={num_steps}, but the number of steps must be greater than 0."
  258. )
  259. def _reset_envs_and_episodes(self, explore: bool):
  260. """Helper method to reset the envs, ongoing episodes and shared data.
  261. This resets the global env_ts and agent_ts variables and deletes ongoing episodes.
  262. The done episodes are preserved.
  263. Args:
  264. explore: Whether we sample in exploration or inference mode.
  265. """
  266. self._ongoing_episodes = [None for _ in range(self.num_envs)]
  267. self._shared_data = {}
  268. self._reset_envs(self._ongoing_episodes, self._shared_data, explore)
  269. # We just reset the env. Don't have to force this again in the next
  270. # call to `self._sample_timesteps()`.
  271. self._needs_initial_reset = False