from typing import Optional

import gymnasium as gym
import numpy as np

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.utils import try_import_pyspiel

pyspiel = try_import_pyspiel(error=True)


class OpenSpielEnv(MultiAgentEnv):
    """Multi-agent environment wrapping an OpenSpiel game object.

    Agent IDs are the integer player indices ``0 .. num_players - 1``.
    Every agent gets a flat float32 ``Box`` observation space of size
    ``env.observation_tensor_size()`` and a ``Discrete`` action space of size
    ``env.num_distinct_actions()``.

    Both sequential and simultaneous games are handled; chance nodes are
    resolved internally by sampling from their outcome distribution.
    """

    def __init__(self, env):
        """Set up agent IDs and per-agent spaces from the wrapped game.

        Args:
            env: The OpenSpiel game object. It must provide `num_players()`,
                `get_type()`, `observation_tensor_size()`,
                `num_distinct_actions()` and `new_initial_state()`.
        """
        super().__init__()
        self.env = env
        self.agents = self.possible_agents = list(range(self.env.num_players()))
        # Store the open-spiel game type.
        self.type = self.env.get_type()
        # Stores the current open-spiel game state (set by `reset()`).
        self.state = None

        self.observation_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Box(
                    float("-inf"),
                    float("inf"),
                    (self.env.observation_tensor_size(),),
                    dtype=np.float32,
                )
                for aid in self.possible_agents
            }
        )
        self.action_space = gym.spaces.Dict(
            {
                aid: gym.spaces.Discrete(self.env.num_distinct_actions())
                for aid in self.possible_agents
            }
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new game and return the initial observations.

        Args:
            seed: Accepted for API compatibility; not used by this method.
            options: Accepted for API compatibility; not used by this method.

        Returns:
            A tuple `(obs, infos)`, where `obs` is the dict produced by
            `_get_obs()` and `infos` is an empty dict.
        """
        self.state = self.env.new_initial_state()
        return self._get_obs(), {}

    def step(self, action):
        """Apply the given action(s) and advance the game.

        Args:
            action: Dict mapping agent ID to action. In a sequential game it
                must contain the current player's ID; in a simultaneous game
                it is indexed by every agent ID in `range(self.num_agents)`.

        Returns:
            A tuple `(obs, rewards, terminateds, truncateds, infos)`.
            `rewards` maps each player index to its entry in
            `state.returns()`, plus a -0.1 penalty for a player whose action
            was rejected in a sequential game. `terminateds` and `truncateds`
            contain one entry per agent plus `"__all__"`; `truncateds` is
            always all-False. `infos` is an empty dict.
        """
        # Before applying action(s), there could be chance nodes.
        # E.g. if env has to figure out, which agent's action should get
        # resolved first in a simultaneous node.
        self._solve_chance_nodes()
        penalties = {}

        # Sequential game:
        if self._is_sequential():
            curr_player = self.state.current_player()
            assert curr_player in action
            try:
                self.state.apply_action(action[curr_player])
            # TODO: (sven) resolve this hack by publishing legal actions
            #  with each step.
            except pyspiel.SpielError:
                # The chosen action was rejected: fall back to a random legal
                # action and penalize the acting player.
                self.state.apply_action(np.random.choice(self.state.legal_actions()))
                penalties[curr_player] = -0.1
        # Simultaneous game.
        else:
            assert self.state.current_player() == -2
            # Apparently, this works, even if one or more actions are invalid.
            self.state.apply_actions([action[ag] for ag in range(self.num_agents)])

        # Now that we have applied all actions, get the next obs. This also
        # resolves any chance nodes that follow the applied action(s).
        obs = self._get_obs()

        # Compile rewards dict and add the accumulated penalties
        # (for taking invalid actions).
        rewards = dict(enumerate(self.state.returns()))
        for ag, penalty in penalties.items():
            rewards[ag] += penalty

        # Are we done?
        is_terminated = self.state.is_terminal()
        terminateds = {ag: is_terminated for ag in range(self.num_agents)}
        terminateds["__all__"] = is_terminated
        truncateds = {ag: False for ag in range(self.num_agents)}
        truncateds["__all__"] = False

        return obs, rewards, terminateds, truncateds, {}

    def render(self, mode=None) -> None:
        """Print the current game state if `mode` is "human"; else do nothing."""
        if mode == "human":
            print(self.state)

    def _is_sequential(self) -> bool:
        """Return True if the wrapped game's dynamics are sequential."""
        return str(self.type.dynamics) == "Dynamics.SEQUENTIAL"

    def _get_obs(self):
        """Build the observation dict for the current state.

        Returns:
            An empty dict if the state is terminal. In a sequential game, a
            dict with a single entry for the current player. In a simultaneous
            game, a dict with one entry per agent. All observations are
            flattened float32 arrays.
        """
        # Before calculating an observation, there could be chance nodes
        # (that may have an effect on the actual observations).
        # E.g. After reset, figure out initial (random) positions of the
        # agents.
        self._solve_chance_nodes()

        if self.state.is_terminal():
            return {}

        # Sequential game:
        if self._is_sequential():
            curr_player = self.state.current_player()
            return {
                curr_player: np.reshape(self.state.observation_tensor(), [-1]).astype(
                    np.float32
                )
            }
        # Simultaneous game.
        assert self.state.current_player() == -2
        return {
            ag: np.reshape(self.state.observation_tensor(ag), [-1]).astype(np.float32)
            for ag in range(self.num_agents)
        }

    def _solve_chance_nodes(self):
        """Resolve consecutive chance nodes by sampling their outcomes.

        While the state is a chance node, an outcome is drawn with NumPy's
        global random generator according to `state.chance_outcomes()` and
        applied to the state.
        """
        # Chance node(s): Sample a (non-player) action and apply.
        while self.state.is_chance_node():
            assert self.state.current_player() == -1
            actions, probs = zip(*self.state.chance_outcomes())
            action = np.random.choice(actions, p=probs)
            self.state.apply_action(action)