multi_agent_env.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807
  1. import copy
  2. import logging
  3. from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
  4. import gymnasium as gym
  5. import numpy as np
  6. from ray._common.deprecation import Deprecated
  7. from ray.rllib.env.base_env import BaseEnv
  8. from ray.rllib.env.env_context import EnvContext
  9. from ray.rllib.utils.annotations import OldAPIStack, override
  10. from ray.rllib.utils.typing import (
  11. AgentID,
  12. EnvCreator,
  13. EnvID,
  14. EnvType,
  15. MultiAgentDict,
  16. MultiEnvDict,
  17. )
  18. from ray.util import log_once
  19. from ray.util.annotations import DeveloperAPI, PublicAPI
  # If the obs space is Dict type, look for the global state under this key.
  ENV_STATE = "state"

  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)
  @PublicAPI(stability="beta")
  class MultiAgentEnv(gym.Env):
      """An environment that hosts multiple independent agents.

      Agents are identified by AgentIDs (string).
      """

      # Optional mappings from AgentID to individual agents' spaces.
      # Set this to an "exhaustive" dictionary, mapping all possible AgentIDs to
      # individual agents' spaces. Alternatively, override
      # `get_observation_space(agent_id=...)` and `get_action_space(agent_id=...)`,
      # which is the API that RLlib uses to get individual spaces and whose default
      # implementation is to simply look up `agent_id` in these dicts.
      observation_spaces: Optional[Dict[AgentID, gym.Space]] = None
      action_spaces: Optional[Dict[AgentID, gym.Space]] = None

      # All agents currently active in the environment. This attribute may change
      # during the lifetime of the env or even during an individual episode.
      # NOTE: This class-level default is a shared list. `__init__` rebinds it to a
      # fresh per-instance list whenever it is empty, so never mutate it in place
      # on the class itself.
      agents: List[AgentID] = []
      # All agents that may appear in the environment, ever.
      # This attribute should not be changed during the lifetime of this env.
      possible_agents: List[AgentID] = []

      # @OldAPIStack, use `observation_spaces` and `action_spaces`, instead.
      observation_space: Optional[gym.Space] = None
      action_space: Optional[gym.Space] = None

      def __init__(self):
          """Initializes the env and tries to infer missing agent attributes."""
          super().__init__()

          # @OldAPIStack
          if not hasattr(self, "_agent_ids"):
              self._agent_ids = set()

          # If these important attributes are not set, try to infer them.
          if not self.agents:
              self.agents = list(self._agent_ids)
          if not self.possible_agents:
              # `list(...)` (rather than `.copy()`) also works if a subclass
              # defined `agents` as a tuple or set and always yields a list.
              self.possible_agents = list(self.agents)

      def reset(
          self,
          *,
          seed: Optional[int] = None,
          options: Optional[dict] = None,
      ) -> Tuple[MultiAgentDict, MultiAgentDict]:  # type: ignore
          """Resets the env and returns observations from ready agents.

          This base implementation only forwards `seed` and `options` to
          `gym.Env.reset()` and does not return anything itself. Subclasses are
          expected to call it via `super().reset(...)` and then return their own
          `(observations, infos)` tuple.

          Args:
              seed: An optional seed to use for the new episode.
              options: An optional options dict, passed on to `gym.Env.reset()`.

          Returns:
              A tuple consisting of 1) the new observations for each ready agent
              and 2) the info values for these agents.

          .. testcode::
              :skipif: True

              from ray.rllib.env.multi_agent_env import MultiAgentEnv
              class MyMultiAgentEnv(MultiAgentEnv):
                  # Define your env here.
                  ...
              env = MyMultiAgentEnv()
              obs, infos = env.reset(seed=42, options={})
              print(obs)

          .. testoutput::

              {
                  "car_0": [2.4, 1.6],
                  "car_1": [3.4, -3.2],
                  "traffic_light_1": [0, 3, 5, 1],
              }
          """
          # Call super's `reset()` method to (maybe) set the given `seed`.
          super().reset(seed=seed, options=options)

      def step(
          self, action_dict: MultiAgentDict
      ) -> Tuple[
          MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict, MultiAgentDict
      ]:
          """Returns observations from ready agents.

          The returns are dicts mapping from agent_id strings to values. The
          number of agents in the env can vary over time.

          Args:
              action_dict: Dict mapping agent IDs to the actions these agents
                  should take in this step.

          Returns:
              Tuple containing 1) new observations for
              each ready agent, 2) reward values for each ready agent. If
              the episode is just started, the value will be None.
              3) Terminated values for each ready agent. The special key
              "__all__" (required) is used to indicate env termination.
              4) Truncated values for each ready agent.
              5) Info values for each agent id (may be empty dicts).

          Raises:
              NotImplementedError: Always; subclasses must override this method.

          .. testcode::
              :skipif: True

              env = ...
              obs, rewards, terminateds, truncateds, infos = env.step(action_dict={
                  "car_0": 1, "car_1": 0, "traffic_light_1": 2,
              })
              print(rewards)

              print(terminateds)

              print(infos)

          .. testoutput::

              {
                  "car_0": 3,
                  "car_1": -1,
                  "traffic_light_1": 0,
              }
              {
                  "car_0": False,    # car_0 is still running
                  "car_1": True,     # car_1 is terminated
                  "__all__": False,  # the env is not terminated
              }
              {
                  "car_0": {},  # info for car_0
                  "car_1": {},  # info for car_1
              }
          """
          raise NotImplementedError

      def render(self) -> None:
          """Tries to render the environment."""

          # By default, do nothing.
          pass

      def get_observation_space(self, agent_id: AgentID) -> gym.Space:
          """Returns the observation space to use for the given agent.

          Args:
              agent_id: The ID of the agent whose observation space to return.

          Returns:
              `self.observation_spaces[agent_id]`, if `self.observation_spaces` is
              set. Otherwise (old API stack), the sub-space under `agent_id` of a
              Dict-type `self.observation_space`, or `self.observation_space`
              itself if no such sub-space exists.
          """
          if self.observation_spaces is not None:
              return self.observation_spaces[agent_id]

          # @OldAPIStack behavior.
          # `self.observation_space` is a `gym.spaces.Dict` AND contains `agent_id`.
          if (
              isinstance(self.observation_space, gym.spaces.Dict)
              and agent_id in self.observation_space.spaces
          ):
              return self.observation_space[agent_id]
          # `self.observation_space` is not a `gym.spaces.Dict` OR doesn't contain
          # `agent_id` -> The defined space is most likely meant to be the space
          # for all agents.
          return self.observation_space

      def get_action_space(self, agent_id: AgentID) -> gym.Space:
          """Returns the action space to use for the given agent.

          Args:
              agent_id: The ID of the agent whose action space to return.

          Returns:
              `self.action_spaces[agent_id]`, if `self.action_spaces` is set.
              Otherwise (old API stack), the sub-space under `agent_id` of a
              Dict-type `self.action_space`, or `self.action_space` itself if no
              such sub-space exists.
          """
          if self.action_spaces is not None:
              return self.action_spaces[agent_id]

          # @OldAPIStack behavior.
          # `self.action_space` is a `gym.spaces.Dict` AND contains `agent_id`.
          if (
              isinstance(self.action_space, gym.spaces.Dict)
              and agent_id in self.action_space.spaces
          ):
              return self.action_space[agent_id]
          # `self.action_space` is not a `gym.spaces.Dict` OR doesn't contain
          # `agent_id` -> The defined space is most likely meant to be the space
          # for all agents.
          return self.action_space

      @property
      def num_agents(self) -> int:
          """The number of entries in `self.agents`."""
          return len(self.agents)

      @property
      def max_num_agents(self) -> int:
          """The number of entries in `self.possible_agents`."""
          return len(self.possible_agents)

      # fmt: off
      # __grouping_doc_begin__
      def with_agent_groups(
          self,
          groups: Dict[str, List[AgentID]],
          obs_space: Optional[gym.Space] = None,
          act_space: Optional[gym.Space] = None,
      ) -> "MultiAgentEnv":
          """Convenience method for grouping together agents in this env.

          An agent group is a list of agent IDs that are mapped to a single
          logical agent. All agents of the group must act at the same time in the
          environment. The grouped agent exposes Tuple action and observation
          spaces that are the concatenated action and obs spaces of the
          individual agents.

          The rewards of all the agents in a group are summed. The individual
          agent rewards are available under the "individual_rewards" key of the
          group info return.

          Agent grouping is required to leverage algorithms such as Q-Mix.

          Args:
              groups: Mapping from group id to a list of the agent ids
                  of group members. If an agent id is not present in any group
                  value, it will be left ungrouped. The group id becomes a new agent ID
                  in the final environment.
              obs_space: Optional observation space for the grouped
                  env. Must be a tuple space. If not provided, will infer this to be a
                  Tuple of n individual agents spaces (n=num agents in a group).
              act_space: Optional action space for the grouped env.
                  Must be a tuple space. If not provided, will infer this to be a Tuple
                  of n individual agents spaces (n=num agents in a group).

          Returns:
              A `GroupAgentsWrapper` around this env.

          .. testcode::
              :skipif: True

              from ray.rllib.env.multi_agent_env import MultiAgentEnv

              class MyMultiAgentEnv(MultiAgentEnv):
                  # define your env here
                  ...

              env = MyMultiAgentEnv(...)
              grouped_env = env.with_agent_groups({
                  "group1": ["agent1", "agent2", "agent3"],
                  "group2": ["agent4", "agent5"],
              })

          """

          # Imported locally, as in the original code (presumably to avoid a
          # circular import; confirm before moving to module level).
          from ray.rllib.env.wrappers.group_agents_wrapper import GroupAgentsWrapper
          return GroupAgentsWrapper(self, groups, obs_space, act_space)

      # __grouping_doc_end__
      # fmt: on

      @OldAPIStack
      @Deprecated(new="MultiAgentEnv.possible_agents", error=False)
      def get_agent_ids(self) -> Set[AgentID]:
          """Returns the set of agent IDs of this env (old API stack).

          Returns:
              `self._agent_ids` (converted to a set, if necessary), if non-empty.
              Otherwise, a new set built from `self.agents`.
          """
          if not hasattr(self, "_agent_ids"):
              self._agent_ids = set()
          if not isinstance(self._agent_ids, set):
              self._agent_ids = set(self._agent_ids)
          # Make this backward compatible as much as possible.
          return self._agent_ids if self._agent_ids else set(self.agents)

      @OldAPIStack
      def to_base_env(
          self,
          make_env: Optional[Callable[[int], EnvType]] = None,
          num_envs: int = 1,
          remote_envs: bool = False,
          remote_env_batch_wait_ms: int = 0,
          restart_failed_sub_environments: bool = False,
      ) -> "BaseEnv":
          """Converts an RLlib MultiAgentEnv into a BaseEnv object.

          The resulting BaseEnv is always vectorized (contains n
          sub-environments) to support batched forward passes, where n may
          also be 1. BaseEnv also supports async execution via the `poll` and
          `send_actions` methods and thus supports external simulators.

          Args:
              make_env: A callable taking an int as input (which indicates
                  the number of individual sub-environments within the final
                  vectorized BaseEnv) and returning one individual
                  sub-environment.
              num_envs: The number of sub-environments to create in the
                  resulting (vectorized) BaseEnv. The already existing `env`
                  will be one of the `num_envs`.
              remote_envs: Whether each sub-env should be a @ray.remote
                  actor. You can set this behavior in your config via the
                  `remote_worker_envs=True` option.
              remote_env_batch_wait_ms: The wait time (in ms) to poll remote
                  sub-environments for, if applicable. Only used if
                  `remote_envs` is True.
              restart_failed_sub_environments: If True and any sub-environment (within
                  a vectorized env) throws any error during env stepping, we will try to
                  restart the faulty sub-environment. This is done
                  without disturbing the other (still intact) sub-environments.

          Returns:
              The resulting BaseEnv object.
          """
          from ray.rllib.env.remote_base_env import RemoteBaseEnv

          if remote_envs:
              env = RemoteBaseEnv(
                  make_env,
                  num_envs,
                  multiagent=True,
                  remote_env_batch_wait_ms=remote_env_batch_wait_ms,
                  restart_failed_sub_environments=restart_failed_sub_environments,
              )
          # Sub-environments are not ray.remote actors.
          else:
              env = MultiAgentEnvWrapper(
                  make_env=make_env,
                  existing_envs=[self],
                  num_envs=num_envs,
                  restart_failed_sub_environments=restart_failed_sub_environments,
              )

          return env
  @DeveloperAPI
  def make_multi_agent(
      env_name_or_creator: Union[str, EnvCreator],
  ) -> Type["MultiAgentEnv"]:
      """Convenience wrapper for any single-agent env to be converted into MA.

      Allows you to convert a simple (single-agent) `gym.Env` class
      into a `MultiAgentEnv` class. This function simply stacks n instances
      of the given ``gym.Env`` class into one unified ``MultiAgentEnv`` class
      and returns this class, thus pretending the agents act together in the
      same environment, whereas - under the hood - they live separately from
      each other in n parallel single-agent envs.

      Agent IDs in the resulting env are int numbers starting from 0
      (first agent).

      Args:
          env_name_or_creator: String specifier or env_maker function taking
              an EnvContext object as only arg and returning a gym.Env.

      Returns:
          New MultiAgentEnv class to be used as env.
          The constructor takes a config dict with `num_agents` key
          (default=1). The rest of the config dict will be passed on to the
          underlying single-agent env's constructor.

      .. testcode::
          :skipif: True

          from ray.rllib.env.multi_agent_env import make_multi_agent
          # By gym string:
          ma_cartpole_cls = make_multi_agent("CartPole-v1")
          # Create a 2 agent multi-agent cartpole.
          ma_cartpole = ma_cartpole_cls({"num_agents": 2})
          obs, infos = ma_cartpole.reset()
          print(obs)

          # By env-maker callable:
          from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
          ma_stateless_cartpole_cls = make_multi_agent(
             lambda config: StatelessCartPole(config))
          # Create a 3 agent multi-agent stateless cartpole.
          ma_stateless_cartpole = ma_stateless_cartpole_cls(
             {"num_agents": 3})
          obs, infos = ma_stateless_cartpole.reset()
          print(obs)

      .. testoutput::

          {0: [...], 1: [...]}
          {0: [...], 1: [...], 2: [...]}
      """

      class MultiEnv(MultiAgentEnv):
          """Stacks `num_agents` independent single-agent envs into one env."""

          def __init__(self, config: EnvContext = None):
              super().__init__()

              # Note: Explicitly check for None here, because config
              # can have an empty dict but meaningful data fields (worker_index,
              # vector_index) etc.
              # TODO (sven): Clean this up, so we are not mixing up dict fields
              #  with data fields.
              if config is None:
                  config = {}
              else:
                  # Note the deepcopy is needed b/c (a) we need to remove the
                  # `num_agents` keyword and (b) with `num_envs > 0` in the
                  # `VectorMultiAgentEnv` all following environment creations
                  # need the same config again.
                  config = copy.deepcopy(config)
              num = config.pop("num_agents", 1)
              # A string specifier is resolved via `gym.make()` (the remaining
              # config is not used in this case); a callable receives the config.
              if isinstance(env_name_or_creator, str):
                  self.envs = [gym.make(env_name_or_creator) for _ in range(num)]
              else:
                  self.envs = [env_name_or_creator(config) for _ in range(num)]
              # Indices of sub-envs that reported terminated/truncated since the
              # last reset.
              self.terminateds = set()
              self.truncateds = set()
              # Agent IDs are the sub-envs' indices.
              self.observation_spaces = {
                  i: env.observation_space for i, env in enumerate(self.envs)
              }
              self.action_spaces = {
                  i: env.action_space for i, env in enumerate(self.envs)
              }
              self.agents = list(range(num))
              self.possible_agents = self.agents.copy()

          @override(MultiAgentEnv)
          def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
              self.terminateds = set()
              self.truncateds = set()
              obs, infos = {}, {}
              # Every sub-env is reset with the same `seed` and `options`.
              for i, env in enumerate(self.envs):
                  obs[i], infos[i] = env.reset(seed=seed, options=options)
                  if not self.observation_spaces[i].contains(obs[i]):
                      logger.warning("MultiEnv does not contain obs.")
              return obs, infos

          @override(MultiAgentEnv)
          def step(self, action_dict):
              obs, rew, terminated, truncated, info = {}, {}, {}, {}, {}

              # The environment is expecting an action for at least one agent.
              if not action_dict:
                  raise ValueError(
                      "The environment is expecting an action for at least one agent."
                  )

              # Only step those sub-envs, for which an action was provided.
              for i, action in action_dict.items():
                  obs[i], rew[i], terminated[i], truncated[i], info[i] = self.envs[
                      i
                  ].step(action)
                  if terminated[i]:
                      self.terminateds.add(i)
                  if truncated[i]:
                      self.truncateds.add(i)
              # TODO: Flaw in our MultiAgentEnv API wrt. new gymnasium: Need to return
              #  an additional episode_done bool that covers cases where all agents are
              #  either terminated or truncated, but not all are truncated and not all are
              #  terminated. We can then get rid of the aweful `__all__` special keys!
              terminated["__all__"] = len(self.terminateds | self.truncateds) == len(
                  self.envs
              )
              truncated["__all__"] = len(self.truncateds) == len(self.envs)
              return obs, rew, terminated, truncated, info

          @override(MultiAgentEnv)
          def render(self):
              # This render method simply renders all n underlying individual single-agent
              # envs and concatenates their images (on top of each other if the returned
              # images have dims where [width] > [height], otherwise next to each other).
              render_images = [e.render() for e in self.envs]
              if render_images[0].shape[1] > render_images[0].shape[0]:
                  concat_dim = 0
              else:
                  concat_dim = 1
              return np.concatenate(render_images, axis=concat_dim)

      return MultiEnv
  @OldAPIStack
  class MultiAgentEnvWrapper(BaseEnv):
      """Internal adapter of MultiAgentEnv to BaseEnv.

      This also supports vectorization if num_envs > 1.
      """

      def __init__(
          self,
          make_env: Callable[[int], EnvType],
          existing_envs: List["MultiAgentEnv"],
          num_envs: int,
          restart_failed_sub_environments: bool = False,
      ):
          """Wraps MultiAgentEnv(s) into the BaseEnv API.

          Args:
              make_env: Factory that produces a new MultiAgentEnv instance taking the
                  vector index as only call argument.
                  Must be defined, if the number of existing envs is less than num_envs.
              existing_envs: List of already existing multi-agent envs. The list
                  itself is copied and not modified.
              num_envs: Desired num multiagent envs to have at the end in
                  total. This will include the given (already created)
                  `existing_envs`.
              restart_failed_sub_environments: If True and any sub-environment (within
                  this vectorized env) throws any error during env stepping, we will try
                  to restart the faulty sub-environment. This is done
                  without disturbing the other (still intact) sub-environments.

          Raises:
              ValueError: If more sub-environments need to be created, but no
                  `make_env` callable was provided.
          """
          self.make_env = make_env
          # Copy, so that appending new sub-envs below does not alter the
          # caller's list.
          self.envs = list(existing_envs)
          self.num_envs = num_envs
          self.restart_failed_sub_environments = restart_failed_sub_environments
          # Indices of sub-envs whose last step reported `__all__`=True.
          self.terminateds = set()
          self.truncateds = set()
          if len(self.envs) < self.num_envs and self.make_env is None:
              raise ValueError(
                  "`make_env` must be provided if fewer than `num_envs` "
                  f"({self.num_envs}) existing envs ({len(self.envs)}) are given!"
              )
          while len(self.envs) < self.num_envs:
              self.envs.append(self.make_env(len(self.envs)))
          for env in self.envs:
              assert isinstance(env, MultiAgentEnv)
          self._init_env_state(idx=None)
          self._unwrapped_env = self.envs[0].unwrapped

      @override(BaseEnv)
      def poll(
          self,
      ) -> Tuple[
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
          MultiEnvDict,
      ]:
          obs, rewards, terminateds, truncateds, infos = {}, {}, {}, {}, {}
          # Collect the buffered results of each sub-env under its vector index.
          for i, env_state in enumerate(self.env_states):
              (
                  obs[i],
                  rewards[i],
                  terminateds[i],
                  truncateds[i],
                  infos[i],
              ) = env_state.poll()
          # The sixth return value is always an empty dict here.
          return obs, rewards, terminateds, truncateds, infos, {}

      @override(BaseEnv)
      def send_actions(self, action_dict: MultiEnvDict) -> None:
          for env_id, agent_dict in action_dict.items():
              if env_id in self.terminateds or env_id in self.truncateds:
                  raise ValueError(
                      f"Env {env_id} is already done and cannot accept new actions"
                  )
              env = self.envs[env_id]
              try:
                  obs, rewards, terminateds, truncateds, infos = env.step(agent_dict)
              except Exception as e:
                  if self.restart_failed_sub_environments:
                      # Exceptions raised without arguments have empty `args`;
                      # fall back to `repr(e)` so that logging itself cannot fail
                      # and thereby prevent the restart.
                      logger.exception(e.args[0] if e.args else repr(e))
                      self.try_restart(env_id=env_id)
                      # Report the error as the "observation" of a terminated
                      # episode.
                      obs = e
                      rewards = {}
                      terminateds = {"__all__": True}
                      truncateds = {"__all__": False}
                      infos = {}
                  else:
                      raise

              assert isinstance(
                  obs, (dict, Exception)
              ), "Not a multi-agent obs dict or an Exception!"
              assert isinstance(rewards, dict), "Not a multi-agent reward dict!"
              assert isinstance(terminateds, dict), "Not a multi-agent terminateds dict!"
              assert isinstance(truncateds, dict), "Not a multi-agent truncateds dict!"
              assert isinstance(infos, dict), "Not a multi-agent info dict!"
              if isinstance(obs, dict):
                  info_diff = set(infos).difference(set(obs))
                  if info_diff and info_diff != {"__common__"}:
                      raise ValueError(
                          "Key set for infos must be a subset of obs (plus optionally "
                          "the '__common__' key for infos concerning all/no agents): "
                          "{} vs {}".format(infos.keys(), obs.keys())
                      )
              if "__all__" not in terminateds:
                  raise ValueError(
                      "In multi-agent environments, '__all__': True|False must "
                      "be included in the 'terminateds' dict: got {}.".format(terminateds)
                  )
              elif "__all__" not in truncateds:
                  raise ValueError(
                      "In multi-agent environments, '__all__': True|False must "
                      "be included in the 'truncateds' dict: got {}.".format(truncateds)
                  )

              # Remember finished sub-envs, so that further actions for them are
              # rejected until they have been reset.
              if terminateds["__all__"]:
                  self.terminateds.add(env_id)
              if truncateds["__all__"]:
                  self.truncateds.add(env_id)
              self.env_states[env_id].observe(
                  obs, rewards, terminateds, truncateds, infos
              )

      @override(BaseEnv)
      def try_reset(
          self,
          env_id: Optional[EnvID] = None,
          *,
          seed: Optional[int] = None,
          options: Optional[dict] = None,
      ) -> Optional[Tuple[MultiEnvDict, MultiEnvDict]]:
          ret_obs = {}
          ret_infos = {}
          # Normalize `env_id` to a list of indices (None -> all sub-envs).
          if isinstance(env_id, int):
              env_id = [env_id]
          if env_id is None:
              env_id = list(range(len(self.envs)))
          for idx in env_id:
              obs, infos = self.env_states[idx].reset(seed=seed, options=options)

              if isinstance(obs, Exception):
                  if self.restart_failed_sub_environments:
                      # Replace the faulty sub-env; the exception is still
                      # returned to the caller as this sub-env's "observation".
                      self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
                  else:
                      raise obs
              else:
                  assert isinstance(obs, dict), "Not a multi-agent obs dict!"
              if obs is not None:
                  # The sub-env may accept actions again.
                  self.terminateds.discard(idx)
                  self.truncateds.discard(idx)
              ret_obs[idx] = obs
              ret_infos[idx] = infos
          return ret_obs, ret_infos

      @override(BaseEnv)
      def try_restart(self, env_id: Optional[EnvID] = None) -> None:
          # Normalize `env_id` to a list of indices (None -> all sub-envs).
          if isinstance(env_id, int):
              env_id = [env_id]
          if env_id is None:
              env_id = list(range(len(self.envs)))
          for idx in env_id:
              # Try closing down the old (possibly faulty) sub-env, but ignore errors.
              try:
                  self.envs[idx].close()
              except Exception as e:
                  if log_once("close_sub_env"):
                      logger.warning(
                          "Trying to close old and replaced sub-environment (at vector "
                          f"index={idx}), but closing resulted in error:\n{e}"
                      )
              # Try recreating the sub-env.
              logger.warning(f"Trying to restart sub-environment at index {idx}.")
              self.env_states[idx].env = self.envs[idx] = self.make_env(idx)
              logger.warning(f"Sub-environment at index {idx} restarted successfully.")

      @override(BaseEnv)
      def get_sub_environments(
          self, as_dict: bool = False
      ) -> Union[Dict[str, EnvType], List[EnvType]]:
          if as_dict:
              # Keys are the sub-envs' vector indices.
              return {_id: env_state.env for _id, env_state in enumerate(self.env_states)}
          return [state.env for state in self.env_states]

      @override(BaseEnv)
      def try_render(self, env_id: Optional[EnvID] = None) -> None:
          # Render the first sub-env by default.
          if env_id is None:
              env_id = 0
          assert isinstance(env_id, int)
          return self.envs[env_id].render()

      @property
      @override(BaseEnv)
      def observation_space(self) -> gym.spaces.Dict:
          # Taken from the first sub-env.
          return self.envs[0].observation_space

      @property
      @override(BaseEnv)
      def action_space(self) -> gym.Space:
          # Taken from the first sub-env.
          return self.envs[0].action_space

      @override(BaseEnv)
      def get_agent_ids(self) -> Set[AgentID]:
          # Taken from the first sub-env.
          return self.envs[0].get_agent_ids()

      def _init_env_state(self, idx: Optional[int] = None) -> None:
          """Resets all or one particular sub-environment's state (by index).

          Args:
              idx: The index to reset at. If None, reset all the sub-environments' states.
          """
          # If index is None, reset all sub-envs' states:
          if idx is None:
              self.env_states = [
                  _MultiAgentEnvState(env, self.restart_failed_sub_environments)
                  for env in self.envs
              ]
          # Index provided, reset only the sub-env's state at the given index.
          else:
              assert isinstance(idx, int)
              self.env_states[idx] = _MultiAgentEnvState(
                  self.envs[idx], self.restart_failed_sub_environments
              )
  @OldAPIStack
  class _MultiAgentEnvState:
      """Buffers the latest results of one MultiAgentEnv between `poll()` calls.

      Results are stored via `observe()` (or `reset()`) and handed out, and
      cleared, via `poll()`.
      """

      def __init__(self, env: MultiAgentEnv, return_error_as_obs: bool = False):
          """Initializes the (still empty) state for the given env.

          Args:
              env: The MultiAgentEnv whose results to buffer.
              return_error_as_obs: If True, an exception raised by `env.reset()`
                  is logged and returned in place of the observations and infos,
                  instead of being raised.
          """
          assert isinstance(env, MultiAgentEnv)
          self.env = env
          self.return_error_as_obs = return_error_as_obs
          # Whether `reset()` has already been triggered by a first `poll()`.
          self.initialized = False
          self.last_obs = {}
          self.last_rewards = {}
          self.last_terminateds = {"__all__": False}
          self.last_truncateds = {"__all__": False}
          self.last_infos = {}

      def poll(
          self,
      ) -> Tuple[
          MultiAgentDict,
          MultiAgentDict,
          MultiAgentDict,
          MultiAgentDict,
          MultiAgentDict,
      ]:
          """Returns (and releases) the currently buffered results.

          Resets the env first, if this is the very first call.

          Returns:
              Tuple of observations, rewards, terminateds, truncateds, and infos.
              The observations may be an Exception, in which case
              `terminateds["__all__"]` is True and `truncateds["__all__"]` is False.
          """
          if not self.initialized:
              # TODO(sven): Should we make it possible to pass in a seed here?
              self.reset()
              self.initialized = True

          observations = self.last_obs
          rewards = {}
          terminateds = {"__all__": self.last_terminateds["__all__"]}
          truncateds = {"__all__": self.last_truncateds["__all__"]}
          infos = self.last_infos

          # If episode is done or we have an error, release everything we have.
          if (
              terminateds["__all__"]
              or truncateds["__all__"]
              or isinstance(observations, Exception)
          ):
              rewards = self.last_rewards
              self.last_rewards = {}
              terminateds = self.last_terminateds
              self.last_terminateds = {}
              truncateds = self.last_truncateds
              self.last_truncateds = {}
              # Apply the error overrides only after both dicts have been
              # swapped in, so that they end up in the returned dicts.
              if isinstance(observations, Exception):
                  terminateds["__all__"] = True
                  truncateds["__all__"] = False
              self.last_obs = {}
              infos = self.last_infos
              self.last_infos = {}
          # Only release those agents' rewards/terminateds/truncateds/infos, whose
          # observations we have.
          else:
              for ag in observations.keys():
                  if ag in self.last_rewards:
                      rewards[ag] = self.last_rewards[ag]
                      del self.last_rewards[ag]
                  if ag in self.last_terminateds:
                      terminateds[ag] = self.last_terminateds[ag]
                      del self.last_terminateds[ag]
                  if ag in self.last_truncateds:
                      truncateds[ag] = self.last_truncateds[ag]
                      del self.last_truncateds[ag]

          # The `__all__` flags have been handed out above; clear them.
          self.last_terminateds["__all__"] = False
          self.last_truncateds["__all__"] = False
          return observations, rewards, terminateds, truncateds, infos

      def observe(
          self,
          obs: MultiAgentDict,
          rewards: MultiAgentDict,
          terminateds: MultiAgentDict,
          truncateds: MultiAgentDict,
          infos: MultiAgentDict,
      ):
          """Stores the results of an env step until the next `poll()`.

          Observations and infos replace the stored ones. Rewards are added to
          not-yet-released rewards of the same agent, terminated/truncated flags
          are OR-ed with not-yet-released flags of the same key.

          Args:
              obs: The new observations (or an Exception).
              rewards: The new rewards per agent.
              terminateds: The new terminated flags (incl. `__all__`).
              truncateds: The new truncated flags (incl. `__all__`).
              infos: The new infos.
          """
          self.last_obs = obs
          for ag, r in rewards.items():
              if ag in self.last_rewards:
                  self.last_rewards[ag] += r
              else:
                  self.last_rewards[ag] = r
          for ag, d in terminateds.items():
              if ag in self.last_terminateds:
                  self.last_terminateds[ag] = self.last_terminateds[ag] or d
              else:
                  self.last_terminateds[ag] = d
          for ag, t in truncateds.items():
              if ag in self.last_truncateds:
                  self.last_truncateds[ag] = self.last_truncateds[ag] or t
              else:
                  self.last_truncateds[ag] = t
          self.last_infos = infos

      def reset(
          self,
          *,
          seed: Optional[int] = None,
          options: Optional[dict] = None,
      ) -> Tuple[MultiAgentDict, MultiAgentDict]:
          """Resets the env and re-initializes the buffered results.

          Args:
              seed: An optional seed, passed on to the env's `reset()`.
              options: An optional options dict, passed on to the env's `reset()`.

          Returns:
              Tuple of the new observations and infos. If the env's `reset()`
              raised and `return_error_as_obs` is True, both are the exception.

          Raises:
              Exception: Whatever the env's `reset()` raised, if
                  `return_error_as_obs` is False.
          """
          try:
              obs_and_infos = self.env.reset(seed=seed, options=options)
          except Exception as e:
              if self.return_error_as_obs:
                  # Exceptions raised without arguments have empty `args`;
                  # fall back to `repr(e)` so that logging itself cannot fail.
                  logger.exception(e.args[0] if e.args else repr(e))
                  obs_and_infos = e, e
              else:
                  raise

          self.last_obs, self.last_infos = obs_and_infos
          self.last_rewards = {}
          self.last_terminateds = {"__all__": False}
          self.last_truncateds = {"__all__": False}

          return self.last_obs, self.last_infos