registration.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import copy
  2. import logging
  3. from typing import Any, Dict, Optional
  4. import gymnasium as gym
  5. from gymnasium.envs.registration import VectorizeMode
  6. from ray.rllib.env.multi_agent_env import MultiAgentEnv
  7. from ray.rllib.env.vector.sync_vector_multi_agent_env import SyncVectorMultiAgentEnv
  8. from ray.rllib.env.vector.vector_multi_agent_env import VectorMultiAgentEnv
  9. logger = logging.getLogger(__file__)
  10. def make_vec(
  11. id: str,
  12. num_envs: int = 1,
  13. vectorization_mode: Optional[VectorizeMode] = None,
  14. vector_kwargs: Optional[Dict[str, Any]] = None,
  15. # TODO (simon): Add wrappers?
  16. **kwargs,
  17. ) -> VectorMultiAgentEnv:
  18. if vector_kwargs is None:
  19. vector_kwargs = {}
  20. if vectorization_mode is None:
  21. vectorization_mode = "sync"
  22. # Create an `gymnasium.envs.registration.EnvSpec` to properly
  23. # initialize the sub-environments.
  24. if isinstance(id, gym.envs.registration.EnvSpec):
  25. env_spec = id
  26. elif isinstance(id, str):
  27. env_spec = gym.envs.registration._find_spec(id)
  28. else:
  29. raise ValueError(f"Invalid id type: {type(id)}. Expected `str` or `EnvSpec`.")
  30. env_spec = copy.deepcopy(env_spec)
  31. env_spec_kwargs = env_spec.kwargs
  32. env_spec.kwargs = dict()
  33. num_envs = env_spec.kwargs.get("num_envs", num_envs)
  34. vectorization_mode = env_spec_kwargs.pop("vectorization_mode", vectorization_mode)
  35. vector_kwargs = env_spec_kwargs.pop("vector_kwargs", vector_kwargs)
  36. env_spec_kwargs.update(kwargs)
  37. # Specify the vectorization mode.
  38. if vectorization_mode is None:
  39. vectorization_mode = VectorizeMode.SYNC
  40. else:
  41. try:
  42. vectorization_mode = VectorizeMode(vectorization_mode)
  43. except ValueError:
  44. raise ValueError(
  45. f"Invalid vectorization mode: {vectorization_mode!r}, "
  46. f"valid modes: {[mode.value for mode in VectorizeMode]}."
  47. )
  48. assert isinstance(vectorization_mode, VectorizeMode)
  49. def create_single_env() -> MultiAgentEnv:
  50. single_env = gym.make(env_spec, **env_spec_kwargs.copy())
  51. return single_env
  52. # Check, the vectorization mode.
  53. if vectorization_mode == VectorizeMode.SYNC:
  54. # Create the synchronized vector environemnt.
  55. env = SyncVectorMultiAgentEnv(
  56. env_fns=(create_single_env for _ in range(num_envs)),
  57. **vector_kwargs,
  58. )
  59. # Other modes are not implemented, yet.
  60. else:
  61. raise ValueError(
  62. "For `MultiAgentEnv` only synchronous environment vectorization "
  63. "is implemented. Use `gym_env_vectorize_mode='sync'`."
  64. )
  65. # Add all creation specifications to the environment.
  66. copied_id_spec = copy.deepcopy(env_spec)
  67. copied_id_spec.kwargs = env_spec_kwargs.copy()
  68. if num_envs != 1:
  69. copied_id_spec.kwargs["num_envs"] = num_envs
  70. copied_id_spec.kwargs["vectorization_mode"] = vectorization_mode.value
  71. if len(vector_kwargs) > 0:
  72. copied_id_spec.kwargs["vector_kwargs"] = vector_kwargs
  73. env.unwrapped.spec = copied_id_spec
  74. # Return the `VectorMultiAgentEnv`.
  75. return env