remote_base_env.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. import logging
  2. from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
  3. import gymnasium as gym
  4. import ray
  5. from ray.rllib.env.base_env import _DUMMY_AGENT_ID, ASYNC_RESET_RETURN, BaseEnv
  6. from ray.rllib.utils.annotations import OldAPIStack, override
  7. from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiEnvDict
  8. from ray.util import log_once
  9. if TYPE_CHECKING:
  10. from ray.rllib.evaluation.rollout_worker import RolloutWorker
  11. logger = logging.getLogger(__name__)
  @OldAPIStack
  class RemoteBaseEnv(BaseEnv):
      """BaseEnv that executes its sub environments as @ray.remote actors.

      This provides dynamic batching of inference as observations are returned
      from the remote simulator actors. Both single and multi-agent child envs
      are supported, and envs can be stepped synchronously or asynchronously.

      NOTE: This class implicitly assumes that the remote envs are gym.Env's

      You shouldn't need to instantiate this class directly. It's automatically
      inserted when you use the `remote_worker_envs=True` option in your
      Algorithm's config.
      """

      def __init__(
          self,
          make_env: Callable[[int], EnvType],
          num_envs: int,
          multiagent: bool,
          remote_env_batch_wait_ms: int,
          existing_envs: Optional[List[ray.actor.ActorHandle]] = None,
          worker: Optional["RolloutWorker"] = None,
          restart_failed_sub_environments: bool = False,
      ):
          """Initializes a RemoteBaseEnv instance.

          Args:
              make_env: Callable that produces a single (non-vectorized) env,
                  given the vector env index as only arg.
              num_envs: The number of sub-environments to create for the
                  vectorization.
              multiagent: Whether this is a multiagent env or not.
              remote_env_batch_wait_ms: Time to wait for (ray.remote)
                  sub-environments to have new observations available when
                  polled. Only when none of the sub-environments is ready,
                  repeat the `ray.wait()` call until at least one sub-env
                  is ready. Then return only the observations of the ready
                  sub-environment(s).
              existing_envs: Optional list of already created sub-environments.
                  These will be used as-is and only as many new sub-envs as
                  necessary (`num_envs - len(existing_envs)`) will be created.
                  If the entries are actor handles, the passed list itself is
                  kept as `self.actors` and extended in place.
              worker: An optional RolloutWorker that owns the env. This is only
                  used if `remote_worker_envs` is True in your config and the
                  `on_sub_environment_created` custom callback needs to be
                  called on each created actor.
              restart_failed_sub_environments: If True and any sub-environment
                  (within a vectorized env) throws any error during env
                  stepping, the Sampler will try to restart the faulty
                  sub-environment. This is done without disturbing the other
                  (still intact) sub-environment and without the RolloutWorker
                  crashing.
          """
          # Could be creating local or remote envs.
          self.make_env = make_env
          self.num_envs = num_envs
          self.multiagent = multiagent
          # `ray.wait()` expects its timeout in seconds.
          self.poll_timeout = remote_env_batch_wait_ms / 1000
          self.worker = worker
          self.restart_failed_sub_environments = restart_failed_sub_environments

          # Already existing env objects (generated by the RolloutWorker).
          existing_envs = existing_envs or []

          # Whether the given `make_env` callable already returns ActorHandles
          # (@ray.remote class instances) or not.
          self.make_env_creates_actors = False

          self._observation_space = None
          self._action_space = None

          # List of ray actor handles (each handle points to one @ray.remote
          # sub-environment).
          self.actors: Optional[List[ray.actor.ActorHandle]] = None

          # `self.make_env` already produces Actors: Use it directly.
          if len(existing_envs) > 0 and isinstance(
              existing_envs[0], ray.actor.ActorHandle
          ):
              self.make_env_creates_actors = True
              self.actors = existing_envs
              while len(self.actors) < self.num_envs:
                  self.actors.append(self._make_sub_env(len(self.actors)))
          # `self.make_env` produces gym.Envs (or children thereof, such
          # as MultiAgentEnv): Need to auto-wrap it here. The problem with
          # this is that custom methods will get lost. If you would like to
          # keep your custom methods in your envs, you should provide the
          # env class directly in your config (w/o tune.register_env()),
          # such that your class can directly be made a @ray.remote
          # (w/o the wrapping via `_Remote[Multi|Single]AgentEnv`).
          # Also, if `len(existing_envs) > 0`, we have to throw those away
          # as we need to create ray actors here.
          else:
              self.actors = [self._make_sub_env(i) for i in range(self.num_envs)]

          # Utilize existing envs for inferring observation/action spaces.
          if len(existing_envs) > 0:
              self._observation_space = existing_envs[0].observation_space
              self._action_space = existing_envs[0].action_space
          # Have to call actors' remote methods to get observation/action spaces.
          else:
              self._observation_space, self._action_space = ray.get(
                  [
                      self.actors[0].observation_space.remote(),
                      self.actors[0].action_space.remote(),
                  ]
              )

          # Dict mapping object refs (return values of @ray.remote calls),
          # whose actual values we are waiting for (via ray.wait in
          # `self.poll()`) to their corresponding actor handles (the actors
          # that created these return values).
          # Call `reset()` on all @ray.remote sub-environment actors.
          self.pending: Dict[ray.ObjectRef, ray.actor.ActorHandle] = {
              a.reset.remote(): a for a in self.actors
          }

      @override(BaseEnv)
      def poll(
          self,
      ) -> Tuple[
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
      ]:
          """Collects the results of all sub-environments that are ready.

          Blocks (in `self.poll_timeout` slices) until at least one pending
          `reset()`/`step()` call has finished.

          Returns:
              Tuple of obs, rewards, terminateds, truncateds, infos (each keyed
              by the sub-environment's index in `self.actors`) plus an always
              empty dict as sixth item. For a failed sub-environment that was
              restarted, the obs entry is the raised exception itself.

          Raises:
              Exception: Whatever a sub-environment raised, if
                  `restart_failed_sub_environments` is False.
              AssertionError: If an actor-env returns a result that is neither
                  a 2-tuple (`reset()`) nor a 5-tuple (`step()`).
          """
          # each keyed by env_id in [0, num_remote_envs)
          obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {}
          ready = []

          # Wait for at least 1 env to be ready here.
          while not ready:
              ready, _ = ray.wait(
                  list(self.pending),
                  num_returns=len(self.pending),
                  timeout=self.poll_timeout,
              )

          # Get and return observations for each of the ready envs
          env_ids = set()
          for obj_ref in ready:
              # Get the corresponding actor handle from our dict and remove the
              # object ref (we will call `ray.get()` on it and it will no longer
              # be "pending").
              actor = self.pending.pop(obj_ref)
              env_id = self.actors.index(actor)
              env_ids.add(env_id)

              # Get the ready object ref (this may be return value(s) of
              # `reset()` or `step()`).
              try:
                  ret = ray.get(obj_ref)
              except Exception as e:
                  # Something happened on the actor during stepping/resetting.
                  # Do not try to restart. Just raise the error.
                  if not self.restart_failed_sub_environments:
                      raise
                  # Restart sub-environment (create new actor; close old one).
                  # Guard against exceptions that carry no args.
                  logger.exception(e.args[0] if e.args else e)
                  self.try_restart(env_id)
                  # Always return multi-agent data.
                  # Set the observation to the exception, no rewards,
                  # terminated[__all__]=True (episode will be discarded anyways),
                  # no infos. This result is already in its final form and must
                  # NOT go through the per-env unpacking below (the single-agent
                  # path would otherwise wrap the exception into an agent dict).
                  ob, rew, terminated, truncated, info = (
                      e,
                      {},
                      {"__all__": True},
                      {"__all__": False},
                      {},
                  )
              else:
                  # Our sub-envs are simple Actor-turned gym.Envs or
                  # MultiAgentEnvs.
                  if self.make_env_creates_actors:
                      (
                          ob,
                          rew,
                          terminated,
                          truncated,
                          info,
                      ) = self._unpack_actor_result(ret)
                  # Our sub-envs are auto-wrapped (by `_RemoteSingleAgentEnv` or
                  # `_RemoteMultiAgentEnv`) and already behave like multi-agent
                  # envs.
                  else:
                      ob, rew, terminated, truncated, info = ret

              obs[env_id] = ob
              rewards[env_id] = rew
              terminateds[env_id] = terminated
              truncateds[env_id] = truncated
              infos[env_id] = info

          logger.debug("Got obs batch for actors %s", env_ids)
          return obs, rewards, terminateds, truncateds, infos, {}

      def _unpack_actor_result(self, ret):
          """Converts a raw actor-env `reset()`/`step()` result to multi-agent form.

          Args:
              ret: The value returned by the remote `reset()` (2-tuple of obs and
                  info) or `step()` (5-tuple of obs, reward, terminated,
                  truncated, info) call.

          Returns:
              Tuple of obs, rewards, terminateds, truncateds, infos. For
              single-agent envs, all items are keyed by `_DUMMY_AGENT_ID`.

          Raises:
              AssertionError: If `ret` is not a tuple of length 2 or 5.
          """
          # Gym < 0.26: `reset()` result: Only obs.
          if not isinstance(ret, tuple):
              if self.multiagent:
                  raise AssertionError(
                      "Your gymnasium.Env seems to only return a single value "
                      "upon `reset()`! Must return 2 (obs AND infos)."
                  )
              raise AssertionError(
                  "Your gymnasium.Env seems to only return a single value "
                  "upon `reset()`! Must return 2 (obs and infos)."
              )
          # Gym < 0.26? Something went wrong.
          if len(ret) not in (2, 5):
              raise AssertionError(
                  "Your gymnasium.Env seems to NOT return the correct "
                  "number of return values for `step()` (needs to return"
                  " 5 values: obs, reward, terminated, truncated and "
                  "info) or `reset()` (needs to return 2 values: obs and "
                  "info)!"
              )

          rew, terminated, truncated = None, None, None
          if self.multiagent:
              # Gym >= 0.26: `step()` result: Obs, reward, terminated,
              # truncated, info.
              if len(ret) == 5:
                  ob, rew, terminated, truncated, info = ret
              # Gym >= 0.26: `reset()` result: Obs and infos.
              else:
                  ob, info = ret
          else:
              ob = {_DUMMY_AGENT_ID: ret[0]}
              # `step()` result: Obs, reward, terminated, truncated, info.
              if len(ret) == 5:
                  rew = {_DUMMY_AGENT_ID: ret[1]}
                  terminated = {_DUMMY_AGENT_ID: ret[2], "__all__": ret[2]}
                  truncated = {_DUMMY_AGENT_ID: ret[3], "__all__": ret[3]}
                  info = {_DUMMY_AGENT_ID: ret[4]}
              # `reset()` result: Obs and infos.
              else:
                  info = {_DUMMY_AGENT_ID: ret[1]}

          # If this is a `reset()` return value, we only have the initial
          # observations and infos: Set rewards, terminateds, and truncateds to
          # dummy values.
          if rew is None:
              rew = {agent_id: 0 for agent_id in ob.keys()}
              terminated = {"__all__": False}
              truncated = {"__all__": False}

          return ob, rew, terminated, truncated, info

      @override(BaseEnv)
      def send_actions(self, action_dict: MultiEnvDict) -> None:
          """Dispatches `step()` calls to the addressed sub-environments.

          Args:
              action_dict: Actions keyed by sub-environment index. Each resulting
                  object ref is registered in `self.pending` and picked up by a
                  later `poll()`.
          """
          for env_id, actions in action_dict.items():
              actor = self.actors[env_id]
              # `actor` is a simple single-agent (remote) env, e.g. a gym.Env
              # that was made a @ray.remote.
              if not self.multiagent and self.make_env_creates_actors:
                  obj_ref = actor.step.remote(actions[_DUMMY_AGENT_ID])
              # `actor` is already a _RemoteSingleAgentEnv or
              # _RemoteMultiAgentEnv wrapper
              # (handles the multi-agent action_dict automatically).
              else:
                  obj_ref = actor.step.remote(actions)
              self.pending[obj_ref] = actor

      @override(BaseEnv)
      def try_reset(
          self,
          env_id: Optional[EnvID] = None,
          *,
          seed: Optional[int] = None,
          options: Optional[dict] = None,
      ) -> Tuple[MultiEnvDict, MultiEnvDict]:
          """Requests an asynchronous reset of the sub-environment at `env_id`.

          Args:
              env_id: Index of the sub-environment (into `self.actors`).
              seed: Passed on to the remote `reset()` call.
              options: Passed on to the remote `reset()` call.

          Returns:
              `(ASYNC_RESET_RETURN, ASYNC_RESET_RETURN)`.
          """
          actor = self.actors[env_id]
          obj_ref = actor.reset.remote(seed=seed, options=options)
          self.pending[obj_ref] = actor
          # Because this env type does not support synchronous reset requests (with
          # immediate return value), we return ASYNC_RESET_RETURN here to indicate
          # that the reset results will be available via the next `poll()` call.
          return ASYNC_RESET_RETURN, ASYNC_RESET_RETURN

      @override(BaseEnv)
      def try_restart(self, env_id: Optional[EnvID] = None) -> None:
          """Replaces the sub-environment at `env_id` with a newly created one.

          Args:
              env_id: Index of the sub-environment (into `self.actors`).
          """
          # Try closing down the old (possibly faulty) sub-env, but ignore errors.
          try:
              # Close the env on the remote side.
              self.actors[env_id].close.remote()
          except Exception as e:
              if log_once("close_sub_env"):
                  logger.warning(
                      "Trying to close old and replaced sub-environment (at vector "
                      f"index={env_id}), but closing resulted in error:\n{e}"
                  )

          # Terminate the actor itself to free up its resources.
          self.actors[env_id].__ray_terminate__.remote()

          # Re-create a new sub-environment.
          self.actors[env_id] = self._make_sub_env(env_id)

      @override(BaseEnv)
      def stop(self) -> None:
          """Asks all sub-environment actors to terminate."""
          if self.actors is not None:
              for actor in self.actors:
                  actor.__ray_terminate__.remote()

      @override(BaseEnv)
      def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
          """Returns the actor handles, as a list or (if `as_dict`) keyed by index."""
          if as_dict:
              return dict(enumerate(self.actors))
          return self.actors

      @property
      @override(BaseEnv)
      def observation_space(self) -> gym.spaces.Dict:
          """The observation space determined in the constructor."""
          return self._observation_space

      @property
      @override(BaseEnv)
      def action_space(self) -> gym.Space:
          """The action space determined in the constructor."""
          return self._action_space

      def _make_sub_env(self, idx: Optional[int] = None):
          """Creates the sub-environment (actor) for the given vector index.

          Args:
              idx: The vector index handed to `self.make_env` and reported to
                  the `on_sub_environment_created` callback.

          Returns:
              The handle of the newly created actor.
          """
          # Our `make_env` creates ray actors directly.
          if self.make_env_creates_actors:
              sub_env = self.make_env(idx)
          # Our `make_env` returns actual envs -> Have to convert them into actors
          # using our utility wrapper classes.
          else:
              logger.info("Launching env %s in remote actor", idx)
              if self.multiagent:
                  sub_env = _RemoteMultiAgentEnv.remote(self.make_env, idx)
              else:
                  sub_env = _RemoteSingleAgentEnv.remote(self.make_env, idx)

          if self.worker is not None:
              # Hand the callback the actor that was just created. Looking it up
              # via `self.actors[idx]` is wrong here: the slot either does not
              # exist yet (construction) or still holds the replaced actor
              # (restart).
              self.worker.callbacks.on_sub_environment_created(
                  worker=self.worker,
                  sub_environment=sub_env,
                  env_context=self.worker.env_context.copy_with_overrides(
                      vector_index=idx
                  ),
              )
          return sub_env

      @override(BaseEnv)
      def get_agent_ids(self) -> Set[AgentID]:
          """Returns the agent ids: queried from the first actor if multi-agent."""
          if self.multiagent:
              return ray.get(self.actors[0].get_agent_ids.remote())
          else:
              return {_DUMMY_AGENT_ID}
  @ray.remote(num_cpus=0)
  class _RemoteMultiAgentEnv:
      """Ray actor hosting one multi-agent env built via `make_env(i)`."""

      def __init__(self, make_env, i):
          # Ids of all agents that showed up in any `reset()` observation so far.
          self.agent_ids = set()
          self.env = make_env(i)

      def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
          """Resets the wrapped env and pads the result to a step-like 5-tuple.

          Returns:
              Tuple of obs, rewards (0.0 per agent found in obs), terminateds and
              truncateds (both `{"__all__": False}`), and infos.
          """
          initial_obs, initial_info = self.env.reset(seed=seed, options=options)
          # Remember every agent seen in the initial observation.
          self.agent_ids.update(initial_obs.keys())
          # each keyed by agent_id in the env
          zero_rewards = dict.fromkeys(initial_obs.keys(), 0.0)
          return (
              initial_obs,
              zero_rewards,
              {"__all__": False},
              {"__all__": False},
              initial_info,
          )

      def step(self, action_dict):
          """Forwards the action dict to the wrapped env's `step()` unchanged."""
          return self.env.step(action_dict)

      # The two spaces are exposed as methods (rather than attributes) so that
      # they can be fetched from the actor with a `ray.get()` call.
      def observation_space(self):
          return self.env.observation_space

      def action_space(self):
          return self.env.action_space

      def get_agent_ids(self) -> Set[AgentID]:
          """Returns the agent ids collected during `reset()` calls."""
          return self.agent_ids
  @ray.remote(num_cpus=0)
  class _RemoteSingleAgentEnv:
      """Ray actor hosting one gym env and presenting it in multi-agent format.

      All per-agent dicts produced here use `_DUMMY_AGENT_ID` as their only
      agent key.
      """

      def __init__(self, make_env, i):
          self.env = make_env(i)

      def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
          """Resets the wrapped env and pads the result to a step-like 5-tuple.

          Returns:
              Tuple of obs, rewards (0.0), terminateds and truncateds (both
              `{"__all__": False}`), and infos.
          """
          reset_result = self.env.reset(seed=seed, options=options)
          return (
              {_DUMMY_AGENT_ID: reset_result[0]},
              {_DUMMY_AGENT_ID: 0.0},
              {"__all__": False},
              {"__all__": False},
              {_DUMMY_AGENT_ID: reset_result[1]},
          )

      def step(self, action):
          """Steps the wrapped env with the dummy agent's action.

          Args:
              action: Dict holding the action under `_DUMMY_AGENT_ID`.

          Returns:
              Tuple of obs, rewards, terminateds, truncateds, infos, each keyed
              by `_DUMMY_AGENT_ID`; the two flag dicts also carry `"__all__"`.
          """
          obs, rew, terminated, truncated, info = (
              {_DUMMY_AGENT_ID: value}
              for value in self.env.step(action[_DUMMY_AGENT_ID])
          )
          # With a single agent, the episode-wide flag equals the agent's flag.
          for flags in (terminated, truncated):
              flags["__all__"] = flags[_DUMMY_AGENT_ID]
          return obs, rew, terminated, truncated, info

      # The two spaces are exposed as methods (rather than attributes) so that
      # they can be fetched from the actor with a `ray.get()` call.
      def observation_space(self):
          return self.env.observation_space

      def action_space(self):
          return self.env.action_space