open_spiel.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. from typing import Optional
  2. import gymnasium as gym
  3. import numpy as np
  4. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  5. from ray.rllib.env.utils import try_import_pyspiel
  6. pyspiel = try_import_pyspiel(error=True)
  class OpenSpielEnv(MultiAgentEnv):
      """RLlib ``MultiAgentEnv`` wrapper around an OpenSpiel game.

      Agent IDs are the integer player indices ``0 .. num_players() - 1``.
      Sequential games (one acting player per step) and simultaneous games
      (all players act at once) are both supported. Chance nodes are resolved
      automatically by sampling from their outcome distribution.
      """

      def __init__(self, env):
          """Wrap an OpenSpiel game.

          Args:
              env: The OpenSpiel game object (presumably a ``pyspiel.Game``,
                  e.g. from ``pyspiel.load_game()`` -- TODO confirm). It is
                  used via ``num_players()``, ``get_type()``,
                  ``observation_tensor_size()``, ``num_distinct_actions()``
                  and ``new_initial_state()``.
          """
          super().__init__()
          self.env = env
          self.agents = self.possible_agents = list(range(self.env.num_players()))
          # Store the open-spiel game type.
          self.type = self.env.get_type()
          # Stores the current open-spiel game state (created by `reset()`).
          self.state = None
          # RNG used for chance-node sampling and the invalid-action fallback.
          # `None` means "use numpy's global RNG" (the legacy behavior);
          # `reset(seed=...)` replaces it with a dedicated, seeded generator.
          self._chance_rng = None

          self.observation_space = gym.spaces.Dict(
              {
                  aid: gym.spaces.Box(
                      float("-inf"),
                      float("inf"),
                      (self.env.observation_tensor_size(),),
                      dtype=np.float32,
                  )
                  for aid in self.possible_agents
              }
          )
          self.action_space = gym.spaces.Dict(
              {
                  aid: gym.spaces.Discrete(self.env.num_distinct_actions())
                  for aid in self.possible_agents
              }
          )

      def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
          """Start a new episode.

          Args:
              seed: If given, (re)seed this env's private RNG so that sampled
                  chance outcomes and invalid-action fallbacks are
                  reproducible. If omitted, the previously used RNG is kept.
              options: Unused; accepted for gymnasium API compatibility.

          Returns:
              Tuple of (observations dict, empty infos dict).
          """
          if seed is not None:
              self._chance_rng = np.random.default_rng(seed)
          self.state = self.env.new_initial_state()
          return self._get_obs(), {}

      def step(self, action):
          """Apply the agents' action(s) and advance the game.

          Args:
              action: Dict mapping agent ID (player index) to an integer
                  action. Sequential games need an entry for the current
                  player; simultaneous games need one for every agent.

          Returns:
              Tuple ``(obs, rewards, terminateds, truncateds, infos)``.
              Rewards are ``state.returns()`` per player, plus a -0.1 penalty
              for a sequential player whose action was rejected as illegal.
          """
          # Before applying action(s), there could be chance nodes.
          # E.g. if env has to figure out, which agent's action should get
          # resolved first in a simultaneous node.
          self._solve_chance_nodes()
          penalties = {}

          if self._is_sequential():
              curr_player = self.state.current_player()
              assert curr_player in action
              try:
                  self.state.apply_action(action[curr_player])
              # TODO: (sven) resolve this hack by publishing legal actions
              #  with each step.
              except pyspiel.SpielError:
                  # Illegal action: play a random legal one instead and
                  # penalize the acting player.
                  self.state.apply_action(
                      self._random_choice(self.state.legal_actions())
                  )
                  penalties[curr_player] = -0.1
          else:
              # Simultaneous node (player id -2 in this wrapper's convention).
              assert self.state.current_player() == -2
              # Apparently, this works, even if one or more actions are invalid.
              self.state.apply_actions([action[ag] for ag in range(self.num_agents)])

          # Now that we have applied all actions, get the next obs.
          obs = self._get_obs()

          # Compile rewards dict and add the accumulated penalties
          # (for taking invalid actions).
          rewards = dict(enumerate(self.state.returns()))
          for ag, penalty in penalties.items():
              rewards[ag] += penalty

          # Are we done? (Episodes are never truncated by this wrapper.)
          is_terminated = self.state.is_terminal()
          terminateds = {ag: is_terminated for ag in range(self.num_agents)}
          terminateds["__all__"] = is_terminated
          truncateds = {ag: False for ag in range(self.num_agents)}
          truncateds["__all__"] = False
          return obs, rewards, terminateds, truncateds, {}

      def render(self, mode=None) -> None:
          """Print the current game state if ``mode == "human"``; else no-op."""
          if mode == "human":
              print(self.state)

      def _get_obs(self):
          """Return the observations for the agent(s) that act next.

          Pending chance nodes are resolved first. Returns ``{}`` once the game
          is over. Otherwise it returns only the current player's observation
          in sequential games, or every agent's observation in simultaneous
          games. Each observation is a flattened float32 array.
          """
          # Before calculating an observation, there could be chance nodes
          # (that may have an effect on the actual observations).
          # E.g. After reset, figure out initial (random) positions of the
          # agents.
          self._solve_chance_nodes()

          if self.state.is_terminal():
              return {}

          if self._is_sequential():
              curr_player = self.state.current_player()
              return {curr_player: self._flatten_obs(self.state.observation_tensor())}

          # Simultaneous game.
          assert self.state.current_player() == -2
          return {
              ag: self._flatten_obs(self.state.observation_tensor(ag))
              for ag in range(self.num_agents)
          }

      def _solve_chance_nodes(self):
          """Sample and apply chance outcomes until a non-chance node is reached."""
          # Chance node(s): Sample a (non-player) action and apply.
          while self.state.is_chance_node():
              # Chance nodes are expected to report player id -1.
              assert self.state.current_player() == -1
              actions, probs = zip(*self.state.chance_outcomes())
              action = self._random_choice(actions, p=probs)
              self.state.apply_action(action)

      def _random_choice(self, options, p=None):
          """Draw one element of ``options``, optionally weighted by ``p``.

          Uses the generator seeded by ``reset(seed=...)`` when there is one,
          otherwise numpy's global RNG (the original behavior).
          """
          rng = np.random if self._chance_rng is None else self._chance_rng
          return rng.choice(options, p=p)

      def _is_sequential(self):
          """Whether the wrapped game has sequential (turn-based) dynamics."""
          # Compare enum values directly rather than their string form.
          return self.type.dynamics == pyspiel.GameType.Dynamics.SEQUENTIAL

      @staticmethod
      def _flatten_obs(tensor):
          """Flatten an observation tensor into a 1D float32 numpy array."""
          return np.reshape(tensor, [-1]).astype(np.float32)