# base_env.py
  1. import logging
  2. from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
  3. import gymnasium as gym
  4. import ray
  5. from ray.rllib.utils.annotations import OldAPIStack
  6. from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
  7. if TYPE_CHECKING:
  8. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  # String marker related to asynchronous resets.
  # NOTE(review): presumably this is the value the `BaseEnv.try_reset()` docstring
  # refers to as ASYNC_RESET_REQUEST (no constant of that name exists in this
  # module) -- confirm against the callers/implementers of `try_reset()`.
  9. ASYNC_RESET_RETURN = "async_reset_return"
  # Module-level logger, named after this module.
  10. logger = logging.getLogger(__name__)
  11. @OldAPIStack
  12. class BaseEnv:
  13. """The lowest-level env interface used by RLlib for sampling.
  14. BaseEnv models multiple agents executing asynchronously in multiple
  15. vectorized sub-environments. A call to `poll()` returns observations from
  16. ready agents keyed by their sub-environment ID and agent IDs, and
  17. actions for those agents can be sent back via `send_actions()`.
  18. All other RLlib supported env types can be converted to BaseEnv.
  19. RLlib handles these conversions internally in RolloutWorker, for example:
  20. gym.Env => rllib.VectorEnv => rllib.BaseEnv
  21. rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
  22. rllib.ExternalEnv => rllib.BaseEnv
  23. .. testcode::
  24. :skipif: True
  25. MyBaseEnv = ...
  26. env = MyBaseEnv()
  27. obs, rewards, terminateds, truncateds, infos, off_policy_actions = (
  28. env.poll()
  29. )
  30. print(obs)
  31. env.send_actions({
  32. "env_0": {
  33. "car_0": 0,
  34. "car_1": 1,
  35. }, ...
  36. })
  37. obs, rewards, terminateds, truncateds, infos, off_policy_actions = (
  38. env.poll()
  39. )
  40. print(obs)
  41. print(terminateds)
  42. .. testoutput::
  43. {
  44. "env_0": {
  45. "car_0": [2.4, 1.6],
  46. "car_1": [3.4, -3.2],
  47. },
  48. "env_1": {
  49. "car_0": [8.0, 4.1],
  50. },
  51. "env_2": {
  52. "car_0": [2.3, 3.3],
  53. "car_1": [1.4, -0.2],
  54. "car_3": [1.2, 0.1],
  55. },
  56. }
  57. {
  58. "env_0": {
  59. "car_0": [4.1, 1.7],
  60. "car_1": [3.2, -4.2],
  61. }, ...
  62. }
  63. {
  64. "env_0": {
  65. "__all__": False,
  66. "car_0": False,
  67. "car_1": True,
  68. }, ...
  69. }
  70. """
  71. def to_base_env(
  72. self,
  73. make_env: Optional[Callable[[int], EnvType]] = None,
  74. num_envs: int = 1,
  75. remote_envs: bool = False,
  76. remote_env_batch_wait_ms: int = 0,
  77. restart_failed_sub_environments: bool = False,
  78. ) -> "BaseEnv":
  79. """Converts an RLlib-supported env into a BaseEnv object.
  80. Supported types for the `env` arg are gym.Env, BaseEnv,
  81. VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.
  82. The resulting BaseEnv is always vectorized (contains n
  83. sub-environments) to support batched forward passes, where n may also
  84. be 1. BaseEnv also supports async execution via the `poll` and
  85. `send_actions` methods and thus supports external simulators.
  86. TODO: Support gym3 environments, which are already vectorized.
  87. Args:
  88. env: An already existing environment of any supported env type
  89. to convert/wrap into a BaseEnv. Supported types are gym.Env,
  90. BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
  91. ExternalMultiAgentEnv.
  92. make_env: A callable taking an int as input (which indicates the
  93. number of individual sub-environments within the final
  94. vectorized BaseEnv) and returning one individual
  95. sub-environment.
  96. num_envs: The number of sub-environments to create in the
  97. resulting (vectorized) BaseEnv. The already existing `env`
  98. will be one of the `num_envs`.
  99. remote_envs: Whether each sub-env should be a @ray.remote actor.
  100. You can set this behavior in your config via the
  101. `remote_worker_envs=True` option.
  102. remote_env_batch_wait_ms: The wait time (in ms) to poll remote
  103. sub-environments for, if applicable. Only used if
  104. `remote_envs` is True.
  105. policy_config: Optional policy config dict.
  106. Returns:
  107. The resulting BaseEnv object.
  108. """
  109. return self
  110. def poll(
  111. self,
  112. ) -> Tuple[
  113. MultiEnvDict,
  114. MultiEnvDict,
  115. MultiEnvDict,
  116. MultiEnvDict,
  117. MultiEnvDict,
  118. MultiEnvDict,
  119. ]:
  120. """Returns observations from ready agents.
  121. All return values are two-level dicts mapping from EnvID to dicts
  122. mapping from AgentIDs to (observation/reward/etc..) values.
  123. The number of agents and sub-environments may vary over time.
  124. Returns:
  125. Tuple consisting of:
  126. New observations for each ready agent.
  127. Reward values for each ready agent. If the episode is just started,
  128. the value will be None.
  129. Terminated values for each ready agent. The special key "__all__" is used to
  130. indicate episode termination.
  131. Truncated values for each ready agent. The special key "__all__"
  132. is used to indicate episode truncation.
  133. Info values for each ready agent.
  134. Agents may take off-policy actions, in which case, there will be an entry
  135. in this dict that contains the taken action. There is no need to
  136. `send_actions()` for agents that have already chosen off-policy actions.
  137. """
  138. raise NotImplementedError
  139. def send_actions(self, action_dict: MultiEnvDict) -> None:
  140. """Called to send actions back to running agents in this env.
  141. Actions should be sent for each ready agent that returned observations
  142. in the previous poll() call.
  143. Args:
  144. action_dict: Actions values keyed by env_id and agent_id.
  145. """
  146. raise NotImplementedError
  147. def try_reset(
  148. self,
  149. env_id: Optional[EnvID] = None,
  150. *,
  151. seed: Optional[int] = None,
  152. options: Optional[dict] = None,
  153. ) -> Tuple[Optional[MultiEnvDict], Optional[MultiEnvDict]]:
  154. """Attempt to reset the sub-env with the given id or all sub-envs.
  155. If the environment does not support synchronous reset, a tuple of
  156. (ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) can be returned here.
  157. Note: A MultiAgentDict is returned when using the deprecated wrapper
  158. classes such as `ray.rllib.env.base_env._MultiAgentEnvToBaseEnv`,
  159. however for consistency with the poll() method, a `MultiEnvDict` is
  160. returned from the new wrapper classes, such as
  161. `ray.rllib.env.multi_agent_env.MultiAgentEnvWrapper`.
  162. Args:
  163. env_id: The sub-environment's ID if applicable. If None, reset
  164. the entire Env (i.e. all sub-environments).
  165. seed: The seed to be passed to the sub-environment(s) when
  166. resetting it. If None, will not reset any existing PRNG. If you pass an
  167. integer, the PRNG will be reset even if it already exists.
  168. options: An options dict to be passed to the sub-environment(s) when
  169. resetting it.
  170. Returns:
  171. A tuple consisting of a) the reset (multi-env/multi-agent) observation
  172. dict and b) the reset (multi-env/multi-agent) infos dict. Returns the
  173. (ASYNC_RESET_REQUEST, ASYNC_RESET_REQUEST) tuple, if not supported.
  174. """
  175. return None, None
  176. def try_restart(self, env_id: Optional[EnvID] = None) -> None:
  177. """Attempt to restart the sub-env with the given id or all sub-envs.
  178. This could result in the sub-env being completely removed (gc'd) and recreated.
  179. Args:
  180. env_id: The sub-environment's ID, if applicable. If None, restart
  181. the entire Env (i.e. all sub-environments).
  182. """
  183. return None
  184. def get_sub_environments(self, as_dict: bool = False) -> Union[List[EnvType], dict]:
  185. """Return a reference to the underlying sub environments, if any.
  186. Args:
  187. as_dict: If True, return a dict mapping from env_id to env.
  188. Returns:
  189. List or dictionary of the underlying sub environments or [] / {}.
  190. """
  191. if as_dict:
  192. return {}
  193. return []
  194. def get_agent_ids(self) -> Set[AgentID]:
  195. """Return the agent ids for the sub_environment.
  196. Returns:
  197. All agent ids for each the environment.
  198. """
  199. return {}
  200. def try_render(self, env_id: Optional[EnvID] = None) -> None:
  201. """Tries to render the sub-environment with the given id or all.
  202. Args:
  203. env_id: The sub-environment's ID, if applicable.
  204. If None, renders the entire Env (i.e. all sub-environments).
  205. """
  206. # By default, do nothing.
  207. pass
  208. def stop(self) -> None:
  209. """Releases all resources used."""
  210. # Try calling `close` on all sub-environments.
  211. for env in self.get_sub_environments():
  212. if hasattr(env, "close"):
  213. env.close()
  214. @property
  215. def observation_space(self) -> gym.Space:
  216. """Returns the observation space for each agent.
  217. Note: samples from the observation space need to be preprocessed into a
  218. `MultiEnvDict` before being used by a policy.
  219. Returns:
  220. The observation space for each environment.
  221. """
  222. raise NotImplementedError
  223. @property
  224. def action_space(self) -> gym.Space:
  225. """Returns the action space for each agent.
  226. Note: samples from the action space need to be preprocessed into a
  227. `MultiEnvDict` before being passed to `send_actions`.
  228. Returns:
  229. The observation space for each environment.
  230. """
  231. raise NotImplementedError
  232. def last(
  233. self,
  234. ) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict]:
  235. """Returns the last observations, rewards, done- truncated flags and infos ...
  236. that were returned by the environment.
  237. Returns:
  238. The last observations, rewards, done- and truncated flags, and infos
  239. for each sub-environment.
  240. """
  241. logger.warning("last has not been implemented for this environment.")
  242. return {}, {}, {}, {}, {}
  243. # Fixed agent identifier when there is only the single agent in the env
  244. _DUMMY_AGENT_ID = "agent0"
  245. @OldAPIStack
  246. def with_dummy_agent_id(
  247. env_id_to_values: Dict[EnvID, Any], dummy_id: "AgentID" = _DUMMY_AGENT_ID
  248. ) -> MultiEnvDict:
  249. ret = {}
  250. for (env_id, value) in env_id_to_values.items():
  251. # If the value (e.g. the observation) is an Exception, publish this error
  252. # under the env ID so the caller of `poll()` knows that the entire episode
  253. # (sub-environment) has crashed.
  254. ret[env_id] = value if isinstance(value, Exception) else {dummy_id: value}
  255. return ret
  256. @OldAPIStack
  257. def convert_to_base_env(
  258. env: EnvType,
  259. make_env: Callable[[int], EnvType] = None,
  260. num_envs: int = 1,
  261. remote_envs: bool = False,
  262. remote_env_batch_wait_ms: int = 0,
  263. worker: Optional["RolloutWorker"] = None,
  264. restart_failed_sub_environments: bool = False,
  265. ) -> "BaseEnv":
  266. """Converts an RLlib-supported env into a BaseEnv object.
  267. Supported types for the `env` arg are gym.Env, BaseEnv,
  268. VectorEnv, MultiAgentEnv, ExternalEnv, or ExternalMultiAgentEnv.
  269. The resulting BaseEnv is always vectorized (contains n
  270. sub-environments) to support batched forward passes, where n may also
  271. be 1. BaseEnv also supports async execution via the `poll` and
  272. `send_actions` methods and thus supports external simulators.
  273. TODO: Support gym3 environments, which are already vectorized.
  274. Args:
  275. env: An already existing environment of any supported env type
  276. to convert/wrap into a BaseEnv. Supported types are gym.Env,
  277. BaseEnv, VectorEnv, MultiAgentEnv, ExternalEnv, and
  278. ExternalMultiAgentEnv.
  279. make_env: A callable taking an int as input (which indicates the
  280. number of individual sub-environments within the final
  281. vectorized BaseEnv) and returning one individual
  282. sub-environment.
  283. num_envs: The number of sub-environments to create in the
  284. resulting (vectorized) BaseEnv. The already existing `env`
  285. will be one of the `num_envs`.
  286. remote_envs: Whether each sub-env should be a @ray.remote actor.
  287. You can set this behavior in your config via the
  288. `remote_worker_envs=True` option.
  289. remote_env_batch_wait_ms: The wait time (in ms) to poll remote
  290. sub-environments for, if applicable. Only used if
  291. `remote_envs` is True.
  292. worker: An optional RolloutWorker that owns the env. This is only
  293. used if `remote_worker_envs` is True in your config and the
  294. `on_sub_environment_created` custom callback needs to be called
  295. on each created actor.
  296. restart_failed_sub_environments: If True and any sub-environment (within
  297. a vectorized env) throws any error during env stepping, the
  298. Sampler will try to restart the faulty sub-environment. This is done
  299. without disturbing the other (still intact) sub-environment and without
  300. the RolloutWorker crashing.
  301. Returns:
  302. The resulting BaseEnv object.
  303. """
  304. from ray.rllib.env.external_env import ExternalEnv
  305. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  306. from ray.rllib.env.remote_base_env import RemoteBaseEnv
  307. from ray.rllib.env.vector_env import VectorEnv, VectorEnvWrapper
  308. if remote_envs and num_envs == 1:
  309. raise ValueError(
  310. "Remote envs only make sense to use if num_envs > 1 "
  311. "(i.e. environment vectorization is enabled)."
  312. )
  313. # Given `env` has a `to_base_env` method -> Call that to convert to a BaseEnv type.
  314. if isinstance(env, (BaseEnv, MultiAgentEnv, VectorEnv, ExternalEnv)):
  315. return env.to_base_env(
  316. make_env=make_env,
  317. num_envs=num_envs,
  318. remote_envs=remote_envs,
  319. remote_env_batch_wait_ms=remote_env_batch_wait_ms,
  320. restart_failed_sub_environments=restart_failed_sub_environments,
  321. )
  322. # `env` is not a BaseEnv yet -> Need to convert/vectorize.
  323. else:
  324. # Sub-environments are ray.remote actors:
  325. if remote_envs:
  326. # Determine, whether the already existing sub-env (could
  327. # be a ray.actor) is multi-agent or not.
  328. multiagent = (
  329. ray.get(env._is_multi_agent.remote())
  330. if hasattr(env, "_is_multi_agent")
  331. else False
  332. )
  333. env = RemoteBaseEnv(
  334. make_env,
  335. num_envs,
  336. multiagent=multiagent,
  337. remote_env_batch_wait_ms=remote_env_batch_wait_ms,
  338. existing_envs=[env],
  339. worker=worker,
  340. restart_failed_sub_environments=restart_failed_sub_environments,
  341. )
  342. # Sub-environments are not ray.remote actors.
  343. else:
  344. # Convert gym.Env to VectorEnv ...
  345. env = VectorEnv.vectorize_gym_envs(
  346. make_env=make_env,
  347. existing_envs=[env],
  348. num_envs=num_envs,
  349. action_space=env.action_space,
  350. observation_space=env.observation_space,
  351. restart_failed_sub_environments=restart_failed_sub_environments,
  352. )
  353. # ... then the resulting VectorEnv to a BaseEnv.
  354. env = VectorEnvWrapper(env)
  355. # Make sure conversion went well.
  356. assert isinstance(env, BaseEnv), env
  357. return env