env.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. """Common pre-checks for all RLlib experiments."""
  2. import logging
  3. from typing import TYPE_CHECKING, Set
  4. import gymnasium as gym
  5. import numpy as np
  6. import tree # pip install dm_tree
  7. from ray.rllib.utils.annotations import DeveloperAPI
  8. from ray.rllib.utils.error import ERR_MSG_OLD_GYM_API, UnsupportedSpaceException
  9. from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
  10. from ray.util import log_once
  11. if TYPE_CHECKING:
  12. from ray.rllib.env import MultiAgentEnv
  13. logger = logging.getLogger(__name__)
  14. @DeveloperAPI
  15. def check_multiagent_environments(env: "MultiAgentEnv") -> None:
  16. """Checking for common errors in RLlib MultiAgentEnvs.
  17. Args:
  18. env: The env to be checked.
  19. """
  20. from ray.rllib.env import MultiAgentEnv
  21. if not isinstance(env, MultiAgentEnv):
  22. raise ValueError("The passed env is not a MultiAgentEnv.")
  23. elif not (
  24. hasattr(env, "observation_space")
  25. and hasattr(env, "action_space")
  26. and hasattr(env, "_agent_ids")
  27. ):
  28. if log_once("ma_env_super_ctor_called"):
  29. logger.warning(
  30. f"Your MultiAgentEnv {env} does not have some or all of the needed "
  31. "base-class attributes! Make sure you call `super().__init__()` from "
  32. "within your MutiAgentEnv's constructor. "
  33. "This will raise an error in the future."
  34. )
  35. return
  36. try:
  37. obs_and_infos = env.reset(seed=42, options={})
  38. except Exception as e:
  39. raise ValueError(
  40. ERR_MSG_OLD_GYM_API.format(
  41. env, "In particular, the `reset()` method seems to be faulty."
  42. )
  43. ) from e
  44. reset_obs, reset_infos = obs_and_infos
  45. _check_if_element_multi_agent_dict(env, reset_obs, "reset()")
  46. sampled_action = {
  47. aid: env.get_action_space(aid).sample() for aid in reset_obs.keys()
  48. }
  49. _check_if_element_multi_agent_dict(
  50. env, sampled_action, "get_action_space(agent_id=..).sample()"
  51. )
  52. try:
  53. results = env.step(sampled_action)
  54. except Exception as e:
  55. raise ValueError(
  56. ERR_MSG_OLD_GYM_API.format(
  57. env, "In particular, the `step()` method seems to be faulty."
  58. )
  59. ) from e
  60. next_obs, reward, done, truncated, info = results
  61. _check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
  62. _check_if_element_multi_agent_dict(env, reward, "step, reward")
  63. _check_if_element_multi_agent_dict(env, done, "step, done")
  64. _check_if_element_multi_agent_dict(env, truncated, "step, truncated")
  65. _check_if_element_multi_agent_dict(env, info, "step, info", allow_common=True)
  66. _check_reward({"dummy_env_id": reward}, base_env=True, agent_ids=env.agents)
  67. _check_done_and_truncated(
  68. {"dummy_env_id": done},
  69. {"dummy_env_id": truncated},
  70. base_env=True,
  71. agent_ids=env.agents,
  72. )
  73. _check_info({"dummy_env_id": info}, base_env=True, agent_ids=env.agents)
  74. def _check_reward(reward, base_env=False, agent_ids=None):
  75. if base_env:
  76. for _, multi_agent_dict in reward.items():
  77. for agent_id, rew in multi_agent_dict.items():
  78. if not (
  79. np.isreal(rew)
  80. and not isinstance(rew, bool)
  81. and (
  82. np.isscalar(rew)
  83. or (isinstance(rew, np.ndarray) and rew.shape == ())
  84. )
  85. ):
  86. error = (
  87. "Your step function must return rewards that are"
  88. f" integer or float. reward: {rew}. Instead it was a "
  89. f"{type(rew)}"
  90. )
  91. raise ValueError(error)
  92. if not (agent_id in agent_ids or agent_id == "__all__"):
  93. error = (
  94. f"Your reward dictionary must have agent ids that belong to "
  95. f"the environment. AgentIDs received from "
  96. f"env.agents are: {agent_ids}"
  97. )
  98. raise ValueError(error)
  99. elif not (
  100. np.isreal(reward)
  101. and not isinstance(reward, bool)
  102. and (
  103. np.isscalar(reward)
  104. or (isinstance(reward, np.ndarray) and reward.shape == ())
  105. )
  106. ):
  107. error = (
  108. "Your step function must return a reward that is integer or float. "
  109. "Instead it was a {}".format(type(reward))
  110. )
  111. raise ValueError(error)
  112. def _check_done_and_truncated(done, truncated, base_env=False, agent_ids=None):
  113. for what in ["done", "truncated"]:
  114. data = done if what == "done" else truncated
  115. if base_env:
  116. for _, multi_agent_dict in data.items():
  117. for agent_id, done_ in multi_agent_dict.items():
  118. if not isinstance(done_, (bool, np.bool_)):
  119. raise ValueError(
  120. f"Your step function must return `{what}s` that are "
  121. f"boolean. But instead was a {type(data)}"
  122. )
  123. if not (agent_id in agent_ids or agent_id == "__all__"):
  124. error = (
  125. f"Your `{what}s` dictionary must have agent ids that "
  126. f"belong to the environment. AgentIDs received from "
  127. f"env.agents are: {agent_ids}"
  128. )
  129. raise ValueError(error)
  130. elif not isinstance(data, (bool, np.bool_)):
  131. error = (
  132. f"Your step function must return a `{what}` that is a boolean. But "
  133. f"instead was a {type(data)}"
  134. )
  135. raise ValueError(error)
  136. def _check_info(info, base_env=False, agent_ids=None):
  137. if base_env:
  138. for _, multi_agent_dict in info.items():
  139. for agent_id, inf in multi_agent_dict.items():
  140. if not isinstance(inf, dict):
  141. raise ValueError(
  142. "Your step function must return infos that are a dict. "
  143. f"instead was a {type(inf)}: element: {inf}"
  144. )
  145. if not (
  146. agent_id in agent_ids
  147. or agent_id == "__all__"
  148. or agent_id == "__common__"
  149. ):
  150. error = (
  151. f"Your dones dictionary must have agent ids that belong to "
  152. f"the environment. AgentIDs received from "
  153. f"env.agents are: {agent_ids}"
  154. )
  155. raise ValueError(error)
  156. elif not isinstance(info, dict):
  157. error = (
  158. "Your step function must return a info that "
  159. f"is a dict. element type: {type(info)}. element: {info}"
  160. )
  161. raise ValueError(error)
  162. def _not_contained_error(func_name, _type):
  163. _error = (
  164. f"The {_type} collected from {func_name} was not contained within"
  165. f" your env's {_type} space. Its possible that there was a type"
  166. f"mismatch (for example {_type}s of np.float32 and a space of"
  167. f"np.float64 {_type}s), or that one of the sub-{_type}s was"
  168. f"out of bounds"
  169. )
  170. return _error
  171. def _check_if_element_multi_agent_dict(
  172. env,
  173. element,
  174. function_string,
  175. base_env=False,
  176. allow_common=False,
  177. ):
  178. if not isinstance(element, dict):
  179. if base_env:
  180. error = (
  181. f"The element returned by {function_string} contains values "
  182. f"that are not MultiAgentDicts. Instead, they are of "
  183. f"type: {type(element)}"
  184. )
  185. else:
  186. error = (
  187. f"The element returned by {function_string} is not a "
  188. f"MultiAgentDict. Instead, it is of type: "
  189. f" {type(element)}"
  190. )
  191. raise ValueError(error)
  192. agent_ids: Set = set(env.agents)
  193. agent_ids.add("__all__")
  194. if allow_common:
  195. agent_ids.add("__common__")
  196. if not all(k in agent_ids for k in element):
  197. if base_env:
  198. error = (
  199. f"The element returned by {function_string} has agent_ids"
  200. f" that are not the names of the agents in the env."
  201. f"agent_ids in this\nMultiEnvDict:"
  202. f" {list(element.keys())}\nAgentIDs in this env: "
  203. f"{env.agents}"
  204. )
  205. else:
  206. error = (
  207. f"The element returned by {function_string} has agent_ids"
  208. f" that are not the names of the agents in the env. "
  209. f"\nAgentIDs in this MultiAgentDict: "
  210. f"{list(element.keys())}\nAgentIDs in this env: "
  211. f"{env.agents}. You likely need to add the attribute `agents` to your "
  212. f"env, which is a list containing the IDs of agents currently in your "
  213. f"env/episode, as well as, `possible_agents`, which is a list of all "
  214. f"possible agents that could ever show up in your env."
  215. )
  216. raise ValueError(error)
  217. def _find_offending_sub_space(space, value):
  218. """Returns error, value, and space when offending `space.contains(value)` fails.
  219. Returns only the offending sub-value/sub-space in case `space` is a complex Tuple
  220. or Dict space.
  221. Args:
  222. space: The gym.Space to check.
  223. value: The actual (numpy) value to check for matching `space`.
  224. Returns:
  225. Tuple consisting of 1) key-sequence of the offending sub-space or the empty
  226. string if `space` is not complex (Tuple or Dict), 2) the offending sub-space,
  227. 3) the offending sub-space's dtype, 4) the offending sub-value, 5) the offending
  228. sub-value's dtype.
  229. .. testcode::
  230. :skipif: True
  231. path, space, space_dtype, value, value_dtype = _find_offending_sub_space(
  232. gym.spaces.Dict({
  233. -2.0, 1.5, (2, ), np.int8), np.array([-1.5, 3.0])
  234. )
  235. """
  236. if not isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
  237. return None, space, space.dtype, value, _get_type(value)
  238. structured_space = get_base_struct_from_space(space)
  239. def map_fn(p, s, v):
  240. if not s.contains(v):
  241. raise UnsupportedSpaceException((p, s, v))
  242. try:
  243. tree.map_structure_with_path(map_fn, structured_space, value)
  244. except UnsupportedSpaceException as e:
  245. space, value = e.args[0][1], e.args[0][2]
  246. return "->".join(e.args[0][0]), space, space.dtype, value, _get_type(value)
  247. # This is actually an error.
  248. return None, None, None, None, None
  249. def _get_type(var):
  250. return var.dtype if hasattr(var, "dtype") else type(var)