| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- from typing import Optional
- import gymnasium as gym
- import numpy as np
- from ray.rllib.env.multi_agent_env import MultiAgentEnv
- from ray.rllib.env.utils import try_import_pyspiel
- pyspiel = try_import_pyspiel(error=True)
class OpenSpielEnv(MultiAgentEnv):
    """Expose an OpenSpiel (``pyspiel``) game through RLlib's ``MultiAgentEnv`` API.

    Agent IDs are the integer player indices ``0 .. num_players - 1``.

    * Sequential games: each observation dict holds only the player whose
      turn it is, and `step()` expects that player's action in the action
      dict.
    * All other (simultaneous) games: every agent is observed and must act
      on every step.

    Chance nodes are resolved internally by sampling from the distribution
    reported by ``state.chance_outcomes()``.
    """

    # Source of randomness for chance nodes and for replacing rejected moves.
    # Defaults to NumPy's global RNG (the `np.random` module offers the same
    # `choice` API as `RandomState`), so unseeded behavior is unchanged.
    # `reset(seed=...)` swaps in a private, seeded `RandomState` on the
    # instance.
    _rng = np.random

    def __init__(self, env):
        """Wrap the given game object.

        Args:
            env: The OpenSpiel game. It must provide ``num_players()``,
                ``get_type()``, ``observation_tensor_size()``,
                ``num_distinct_actions()`` and ``new_initial_state()``.
        """
        super().__init__()
        self.env = env
        self.agents = self.possible_agents = list(range(self.env.num_players()))
        # Store the open-spiel game type.
        self.type = self.env.get_type()
        # Stores the current open-spiel game state (set by `reset()`).
        self.state = None

        # One flat, unbounded float32 observation vector per agent.
        self.observation_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(
                    float("-inf"),
                    float("inf"),
                    (self.env.observation_tensor_size(),),
                    dtype=np.float32,
                )
                for aid in self.possible_agents
            }
        )
        # One discrete action space per agent.
        self.action_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Discrete(self.env.num_distinct_actions())
                for aid in self.possible_agents
            }
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new game and return the initial observations.

        Args:
            seed: If not None, chance outcomes and random replacement moves
                of this instance are drawn from a private RNG seeded with
                this value, making episodes reproducible. If None, the RNG
                currently in use (NumPy's global RNG by default) is kept.
            options: Unused; accepted for API compatibility.

        Returns:
            A tuple ``(observations, infos)``; ``infos`` is always empty.
        """
        if seed is not None:
            self._rng = np.random.RandomState(seed)
        self.state = self.env.new_initial_state()
        return self._get_obs(), {}

    def step(self, action):
        """Apply the given action(s) and advance the game.

        Args:
            action: Dict mapping agent ID to its chosen action. In sequential
                games it must contain the current player; otherwise it must
                contain every agent.

        Returns:
            A tuple ``(obs, rewards, terminateds, truncateds, infos)``.
            ``truncateds`` is always all-False and ``infos`` always empty.
        """
        # Before applying action(s), there could be chance nodes.
        # E.g. if env has to figure out, which agent's action should get
        # resolved first in a simultaneous node.
        self._solve_chance_nodes()

        penalties = {}

        if self._is_sequential():
            curr_player = self.state.current_player()
            assert curr_player in action, "No action given for the current player."
            try:
                self.state.apply_action(action[curr_player])
            # TODO: (sven) resolve this hack by publishing legal actions
            #  with each step.
            except pyspiel.SpielError:
                # The game rejected the move: play a random legal one instead
                # and penalize the acting player.
                self.state.apply_action(self._rng.choice(self.state.legal_actions()))
                penalties[curr_player] = -0.1
        else:
            assert (
                self.state.current_player() == -2
            ), "Expected a simultaneous-move node."
            # Apparently, this works, even if one or more actions are invalid.
            self.state.apply_actions([action[ag] for ag in range(self.num_agents)])

        # Now that we have applied all actions, get the next obs.
        obs = self._get_obs()

        # Compile rewards dict and add the accumulated penalties
        # (for taking invalid actions).
        # NOTE(review): `returns()` looks like per-episode cumulative values in
        # OpenSpiel (as opposed to `rewards()`) - confirm this is intended for
        # games that pay out before the terminal state.
        rewards = dict(enumerate(self.state.returns()))
        for ag, penalty in penalties.items():
            rewards[ag] += penalty

        # Are we done?
        is_terminated = self.state.is_terminal()
        terminateds = {ag: is_terminated for ag in range(self.num_agents)}
        terminateds["__all__"] = is_terminated
        truncateds = {ag: False for ag in range(self.num_agents)}
        truncateds["__all__"] = False

        return obs, rewards, terminateds, truncateds, {}

    def render(self, mode=None) -> None:
        """Print the current game state when ``mode == "human"``."""
        if mode == "human":
            print(self.state)

    def _is_sequential(self) -> bool:
        """Return True if the wrapped game reports sequential dynamics."""
        return str(self.type.dynamics) == "Dynamics.SEQUENTIAL"

    def _get_obs(self):
        """Build the observation dict for the current state.

        Returns:
            ``{}`` if the game is over; otherwise a dict mapping agent ID to
            a flat float32 array (current player only in sequential games,
            all agents otherwise).
        """
        # Before calculating an observation, there could be chance nodes
        # (that may have an effect on the actual observations).
        # E.g. After reset, figure out initial (random) positions of the
        # agents.
        self._solve_chance_nodes()

        if self.state.is_terminal():
            return {}

        if self._is_sequential():
            curr_player = self.state.current_player()
            return {
                curr_player: np.reshape(self.state.observation_tensor(), [-1]).astype(
                    np.float32
                )
            }

        assert self.state.current_player() == -2, "Expected a simultaneous-move node."
        return {
            ag: np.reshape(self.state.observation_tensor(ag), [-1]).astype(np.float32)
            for ag in range(self.num_agents)
        }

    def _solve_chance_nodes(self):
        """Sample and apply chance outcomes until a non-chance node is reached."""
        # Chance node(s): Sample a (non-player) action and apply.
        while self.state.is_chance_node():
            assert self.state.current_player() == -1, "Expected a chance node."
            actions, probs = zip(*self.state.chance_outcomes())
            self.state.apply_action(self._rng.choice(actions, p=probs))
|