pettingzoo_env.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from typing import Optional
  2. import gymnasium as gym
  3. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  4. from ray.rllib.utils.annotations import PublicAPI
  5. @PublicAPI
  6. class PettingZooEnv(MultiAgentEnv):
  7. """An interface to the PettingZoo MARL environment library.
  8. See: https://github.com/Farama-Foundation/PettingZoo
  9. Inherits from MultiAgentEnv and exposes a given AEC
  10. (actor-environment-cycle) game from the PettingZoo project via the
  11. MultiAgentEnv public API.
  12. Note that the wrapper has the following important limitation:
  13. Environments are positive sum games (-> Agents are expected to cooperate
  14. to maximize reward). This isn't a hard restriction, it just that
  15. standard algorithms aren't expected to work well in highly competitive
  16. games.
  17. Also note that the earlier existing restriction of all agents having the same
  18. observation- and action spaces has been lifted. Different agents can now have
  19. different spaces and the entire environment's e.g. `self.action_space` is a Dict
  20. mapping agent IDs to individual agents' spaces. Same for `self.observation_space`.
  21. .. testcode::
  22. :skipif: True
  23. from pettingzoo.butterfly import prison_v3
  24. from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
  25. env = PettingZooEnv(prison_v3.env())
  26. obs, infos = env.reset()
  27. # only returns the observation for the agent which should be stepping
  28. print(obs)
  29. .. testoutput::
  30. {
  31. 'prisoner_0': array([[[0, 0, 0],
  32. [0, 0, 0],
  33. [0, 0, 0],
  34. ...,
  35. [0, 0, 0],
  36. [0, 0, 0],
  37. [0, 0, 0]]], dtype=uint8)
  38. }
  39. .. testcode::
  40. :skipif: True
  41. obs, rewards, terminateds, truncateds, infos = env.step({
  42. "prisoner_0": 1
  43. })
  44. # only returns the observation, reward, info, etc, for
  45. # the agent who's turn is next.
  46. print(obs)
  47. .. testoutput::
  48. {
  49. 'prisoner_1': array([[[0, 0, 0],
  50. [0, 0, 0],
  51. [0, 0, 0],
  52. ...,
  53. [0, 0, 0],
  54. [0, 0, 0],
  55. [0, 0, 0]]], dtype=uint8)
  56. }
  57. .. testcode::
  58. :skipif: True
  59. print(rewards)
  60. .. testoutput::
  61. {
  62. 'prisoner_1': 0
  63. }
  64. .. testcode::
  65. :skipif: True
  66. print(terminateds)
  67. .. testoutput::
  68. {
  69. 'prisoner_1': False, '__all__': False
  70. }
  71. .. testcode::
  72. :skipif: True
  73. print(truncateds)
  74. .. testoutput::
  75. {
  76. 'prisoner_1': False, '__all__': False
  77. }
  78. .. testcode::
  79. :skipif: True
  80. print(infos)
  81. .. testoutput::
  82. {
  83. 'prisoner_1': {'map_tuple': (1, 0)}
  84. }
  85. """
  86. def __init__(self, env):
  87. super().__init__()
  88. self.env = env
  89. env.reset()
  90. self._agent_ids = set(self.env.agents)
  91. # If these important attributes are not set, try to infer them.
  92. if not self.agents:
  93. self.agents = list(self._agent_ids)
  94. if not self.possible_agents:
  95. self.possible_agents = self.agents.copy()
  96. # Set these attributes for sampling in `VectorMultiAgentEnv`s.
  97. self.observation_spaces = {
  98. aid: self.env.observation_space(aid) for aid in self._agent_ids
  99. }
  100. self.action_spaces = {
  101. aid: self.env.action_space(aid) for aid in self._agent_ids
  102. }
  103. self.observation_space = gym.spaces.Dict(self.observation_spaces)
  104. self.action_space = gym.spaces.Dict(self.action_spaces)
  105. def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
  106. info = self.env.reset(seed=seed, options=options)
  107. return (
  108. {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
  109. info or {},
  110. )
  111. def step(self, action):
  112. self.env.step(action[self.env.agent_selection])
  113. obs_d = {}
  114. rew_d = {}
  115. terminated_d = {}
  116. truncated_d = {}
  117. info_d = {}
  118. while self.env.agents:
  119. obs, rew, terminated, truncated, info = self.env.last()
  120. agent_id = self.env.agent_selection
  121. obs_d[agent_id] = obs
  122. rew_d[agent_id] = rew
  123. terminated_d[agent_id] = terminated
  124. truncated_d[agent_id] = truncated
  125. info_d[agent_id] = info
  126. if (
  127. self.env.terminations[self.env.agent_selection]
  128. or self.env.truncations[self.env.agent_selection]
  129. ):
  130. self.env.step(None)
  131. else:
  132. break
  133. all_gone = not self.env.agents
  134. terminated_d["__all__"] = all_gone and all(terminated_d.values())
  135. truncated_d["__all__"] = all_gone and all(truncated_d.values())
  136. return obs_d, rew_d, terminated_d, truncated_d, info_d
  137. def close(self):
  138. self.env.close()
  139. def render(self):
  140. return self.env.render(self.render_mode)
  141. @property
  142. def get_sub_environments(self):
  143. return self.env.unwrapped
  144. @PublicAPI
  145. class ParallelPettingZooEnv(MultiAgentEnv):
  146. def __init__(self, env):
  147. super().__init__()
  148. self.par_env = env
  149. self.par_env.reset()
  150. self._agent_ids = set(self.par_env.agents)
  151. # If these important attributes are not set, try to infer them.
  152. if not self.agents:
  153. self.agents = list(self._agent_ids)
  154. if not self.possible_agents:
  155. self.possible_agents = self.agents.copy()
  156. self.observation_space = gym.spaces.Dict(
  157. {aid: self.par_env.observation_space(aid) for aid in self._agent_ids}
  158. )
  159. self.action_space = gym.spaces.Dict(
  160. {aid: self.par_env.action_space(aid) for aid in self._agent_ids}
  161. )
  162. def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
  163. obs, info = self.par_env.reset(seed=seed, options=options)
  164. return obs, info or {}
  165. def step(self, action_dict):
  166. obss, rews, terminateds, truncateds, infos = self.par_env.step(action_dict)
  167. terminateds["__all__"] = all(terminateds.values())
  168. truncateds["__all__"] = all(truncateds.values())
  169. return obss, rews, terminateds, truncateds, infos
  170. def close(self):
  171. self.par_env.close()
  172. def render(self):
  173. return self.par_env.render(self.render_mode)
  174. @property
  175. def get_sub_environments(self):
  176. return self.par_env.unwrapped