| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230 |
- from typing import Optional
- import gymnasium as gym
- from ray.rllib.env.multi_agent_env import MultiAgentEnv
- from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class PettingZooEnv(MultiAgentEnv):
    """An interface to the PettingZoo MARL environment library.

    See: https://github.com/Farama-Foundation/PettingZoo

    Inherits from MultiAgentEnv and exposes a given AEC
    (agent-environment-cycle) game from the PettingZoo project via the
    MultiAgentEnv public API.

    Note that the wrapper has the following important limitation:

    Environments are positive sum games (-> Agents are expected to cooperate
    to maximize reward). This isn't a hard restriction, it's just that
    standard algorithms aren't expected to work well in highly competitive
    games.

    Also note that the earlier existing restriction of all agents having the same
    observation- and action spaces has been lifted. Different agents can now have
    different spaces and the entire environment's e.g. `self.action_space` is a Dict
    mapping agent IDs to individual agents' spaces. Same for `self.observation_space`.

    .. testcode::
        :skipif: True

        from pettingzoo.butterfly import prison_v3
        from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
        env = PettingZooEnv(prison_v3.env())
        obs, infos = env.reset()
        # only returns the observation for the agent which should be stepping
        print(obs)

    .. testoutput::

        {
            'prisoner_0': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        obs, rewards, terminateds, truncateds, infos = env.step({
            "prisoner_0": 1
        })
        # only returns the observation, reward, info, etc, for
        # the agent whose turn is next.
        print(obs)

    .. testoutput::

        {
            'prisoner_1': array([[[0, 0, 0],
                [0, 0, 0],
                [0, 0, 0],
                ...,
                [0, 0, 0],
                [0, 0, 0],
                [0, 0, 0]]], dtype=uint8)
        }

    .. testcode::
        :skipif: True

        print(rewards)

    .. testoutput::

        {
            'prisoner_1': 0
        }

    .. testcode::
        :skipif: True

        print(terminateds)

    .. testoutput::

        {
            'prisoner_1': False, '__all__': False
        }

    .. testcode::
        :skipif: True

        print(truncateds)

    .. testoutput::

        {
            'prisoner_1': False, '__all__': False
        }

    .. testcode::
        :skipif: True

        print(infos)

    .. testoutput::

        {
            'prisoner_1': {'map_tuple': (1, 0)}
        }
    """

    def __init__(self, env):
        """Wraps the given PettingZoo AEC environment.

        Args:
            env: The PettingZoo AEC environment to wrap. It is reset once here
                so that its list of agents can be read.
        """
        super().__init__()
        self.env = env
        # The agent list is read right after a reset.
        env.reset()

        self._agent_ids = set(self.env.agents)

        # If these important attributes are not set, try to infer them.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

        # Set these attributes for sampling in `VectorMultiAgentEnv`s.
        self.observation_spaces = {
            aid: self.env.observation_space(aid) for aid in self._agent_ids
        }
        self.action_spaces = {
            aid: self.env.action_space(aid) for aid in self._agent_ids
        }

        self.observation_space = gym.spaces.Dict(self.observation_spaces)
        self.action_space = gym.spaces.Dict(self.action_spaces)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env and returns the first acting agent's observation.

        Args:
            seed: Optional seed, forwarded to the wrapped env's `reset()`.
            options: Optional options dict, forwarded to the wrapped env's
                `reset()`.

        Returns:
            A tuple of (observation dict, info dict). The observation dict has a
            single entry for the currently selected agent. The info dict is
            whatever the wrapped env's `reset()` returned, or `{}` if that value
            is falsy (e.g. `None`).
        """
        info = self.env.reset(seed=seed, options=options)
        return (
            {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
            info or {},
        )

    def step(self, action):
        """Applies the currently selected agent's action and advances the cycle.

        Args:
            action: Dict mapping agent IDs to actions. Only the entry for the
                currently selected agent (`self.env.agent_selection`) is used.

        Returns:
            A tuple (obs, rewards, terminateds, truncateds, infos) of dicts keyed
            by agent ID. They contain one entry per agent visited while advancing
            to the next agent that can still act. `terminateds` and `truncateds`
            additionally carry an `"__all__"` key.
        """
        self.env.step(action[self.env.agent_selection])

        obs_d = {}
        rew_d = {}
        terminated_d = {}
        truncated_d = {}
        info_d = {}

        # Collect the last results of every agent we pass on the way to the next
        # agent that is neither terminated nor truncated.
        while self.env.agents:
            obs, rew, terminated, truncated, info = self.env.last()
            agent_id = self.env.agent_selection
            obs_d[agent_id] = obs
            rew_d[agent_id] = rew
            terminated_d[agent_id] = terminated
            truncated_d[agent_id] = truncated
            info_d[agent_id] = info
            if (
                self.env.terminations[self.env.agent_selection]
                or self.env.truncations[self.env.agent_selection]
            ):
                # Finished agents are stepped with a `None` action so that the
                # cycle can move on to the next agent.
                self.env.step(None)
            else:
                break

        # The episode is over only once the wrapped env has no agents left.
        all_gone = not self.env.agents
        terminated_d["__all__"] = all_gone and all(terminated_d.values())
        truncated_d["__all__"] = all_gone and all(truncated_d.values())

        return obs_d, rew_d, terminated_d, truncated_d, info_d

    def close(self):
        """Closes the wrapped environment."""
        self.env.close()

    def render(self):
        """Renders the wrapped environment and returns its render result.

        The wrapped env is called without arguments: PettingZoo versions that
        provide the `terminations`/`truncations` API used by `step()` take the
        render mode at construction time and reject a positional mode here.
        """
        return self.env.render()

    @property
    def get_sub_environments(self):
        """The unwrapped underlying PettingZoo environment."""
        return self.env.unwrapped
@PublicAPI
class ParallelPettingZooEnv(MultiAgentEnv):
    """An interface to PettingZoo environments that use the parallel API.

    Exposes the wrapped env via the MultiAgentEnv public API: `reset()` and
    `step()` are forwarded directly to the wrapped env, and `step()` adds the
    `"__all__"` key to the returned `terminateds` and `truncateds` dicts.

    `self.observation_space` and `self.action_space` are Dict spaces mapping
    agent IDs to the individual agents' spaces.
    """

    def __init__(self, env):
        """Wraps the given PettingZoo parallel environment.

        Args:
            env: The PettingZoo parallel environment to wrap. It is reset once
                here so that its list of agents can be read.
        """
        super().__init__()
        self.par_env = env
        # The agent list is read right after a reset.
        self.par_env.reset()
        self._agent_ids = set(self.par_env.agents)

        # If these important attributes are not set, try to infer them.
        if not self.agents:
            self.agents = list(self._agent_ids)
        if not self.possible_agents:
            self.possible_agents = self.agents.copy()

        # Per-agent space dicts, set the same way `PettingZooEnv` sets them.
        self.observation_spaces = {
            aid: self.par_env.observation_space(aid) for aid in self._agent_ids
        }
        self.action_spaces = {
            aid: self.par_env.action_space(aid) for aid in self._agent_ids
        }

        self.observation_space = gym.spaces.Dict(self.observation_spaces)
        self.action_space = gym.spaces.Dict(self.action_spaces)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Resets the wrapped env.

        Args:
            seed: Optional seed, forwarded to the wrapped env's `reset()`.
            options: Optional options dict, forwarded to the wrapped env's
                `reset()`.

        Returns:
            A tuple of (observations, infos) as returned by the wrapped env;
            a falsy infos value is replaced by `{}`.
        """
        obs, info = self.par_env.reset(seed=seed, options=options)
        return obs, info or {}

    def step(self, action_dict):
        """Steps the wrapped env with the given actions.

        Args:
            action_dict: Dict of actions, passed unchanged to the wrapped env's
                `step()`.

        Returns:
            The (obs, rewards, terminateds, truncateds, infos) tuple returned by
            the wrapped env. An `"__all__"` key is added in place to
            `terminateds` and `truncateds`; it is True only if every value
            in the respective dict is truthy.
        """
        obss, rews, terminateds, truncateds, infos = self.par_env.step(action_dict)
        terminateds["__all__"] = all(terminateds.values())
        truncateds["__all__"] = all(truncateds.values())
        return obss, rews, terminateds, truncateds, infos

    def close(self):
        """Closes the wrapped environment."""
        self.par_env.close()

    def render(self):
        """Renders the wrapped environment and returns its render result.

        The wrapped env is called without arguments: current PettingZoo
        versions take the render mode at construction time and reject a
        positional mode here.
        """
        return self.par_env.render()

    @property
    def get_sub_environments(self):
        """The unwrapped underlying PettingZoo environment."""
        return self.par_env.unwrapped
|