env_context.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import copy
  2. from typing import Optional
  3. from ray.rllib.utils.typing import EnvConfigDict
  4. from ray.util.annotations import DeveloperAPI
  5. @DeveloperAPI
  6. class EnvContext(dict):
  7. """Wraps env configurations to include extra rllib metadata.
  8. These attributes can be used to parameterize environments per process.
  9. For example, one might use `worker_index` to control which data file an
  10. environment reads in on initialization.
  11. RLlib auto-sets these attributes when constructing registered envs.
  12. """
  13. def __init__(
  14. self,
  15. env_config: EnvConfigDict,
  16. worker_index: int,
  17. vector_index: int = 0,
  18. remote: bool = False,
  19. num_workers: Optional[int] = None,
  20. recreated_worker: bool = False,
  21. ):
  22. """Initializes an EnvContext instance.
  23. Args:
  24. env_config: The env's configuration defined under the
  25. "env_config" key in the Algorithm's config.
  26. worker_index: When there are multiple workers created, this
  27. uniquely identifies the worker the env is created in.
  28. 0 for local worker, >0 for remote workers.
  29. vector_index: When there are multiple envs per worker, this
  30. uniquely identifies the env index within the worker.
  31. Starts from 0.
  32. remote: Whether individual sub-environments (in a vectorized
  33. env) should be @ray.remote actors or not.
  34. num_workers: The total number of (remote) workers in the set.
  35. 0 if only a local worker exists.
  36. recreated_worker: Whether the worker that holds this env is a recreated one.
  37. This means that it replaced a previous (failed) worker when
  38. `restart_failed_env_runners=True` in the Algorithm's config.
  39. """
  40. # Store the env_config in the (super) dict.
  41. dict.__init__(self, env_config)
  42. # Set some metadata attributes.
  43. self.worker_index = worker_index
  44. self.vector_index = vector_index
  45. self.remote = remote
  46. self.num_workers = num_workers
  47. self.recreated_worker = recreated_worker
  48. def copy_with_overrides(
  49. self,
  50. env_config: Optional[EnvConfigDict] = None,
  51. worker_index: Optional[int] = None,
  52. vector_index: Optional[int] = None,
  53. remote: Optional[bool] = None,
  54. num_workers: Optional[int] = None,
  55. recreated_worker: Optional[bool] = None,
  56. ) -> "EnvContext":
  57. """Returns a copy of this EnvContext with some attributes overridden.
  58. Args:
  59. env_config: Optional env config to use. None for not overriding
  60. the one from the source (self).
  61. worker_index: Optional worker index to use. None for not
  62. overriding the one from the source (self).
  63. vector_index: Optional vector index to use. None for not
  64. overriding the one from the source (self).
  65. remote: Optional remote setting to use. None for not overriding
  66. the one from the source (self).
  67. num_workers: Optional num_workers to use. None for not overriding
  68. the one from the source (self).
  69. recreated_worker: Optional flag, indicating, whether the worker that holds
  70. the env is a recreated one. This means that it replaced a previous
  71. (failed) worker when `restart_failed_env_runners=True` in the
  72. Algorithm's config.
  73. Returns:
  74. A new EnvContext object as a copy of self plus the provided
  75. overrides.
  76. """
  77. return EnvContext(
  78. copy.deepcopy(env_config) if env_config is not None else self,
  79. worker_index if worker_index is not None else self.worker_index,
  80. vector_index if vector_index is not None else self.vector_index,
  81. remote if remote is not None else self.remote,
  82. num_workers if num_workers is not None else self.num_workers,
  83. recreated_worker if recreated_worker is not None else self.recreated_worker,
  84. )
  85. def set_defaults(self, defaults: dict) -> None:
  86. """Sets missing keys of self to the values given in `defaults`.
  87. If `defaults` contains keys that already exist in self, don't override
  88. the values with these defaults.
  89. Args:
  90. defaults: The key/value pairs to add to self, but only for those
  91. keys in `defaults` that don't exist yet in self.
  92. .. testcode::
  93. :skipif: True
  94. from ray.rllib.env.env_context import EnvContext
  95. env_ctx = EnvContext({"a": 1, "b": 2}, worker_index=0)
  96. env_ctx.set_defaults({"a": -42, "c": 3})
  97. print(env_ctx)
  98. .. testoutput::
  99. {"a": 1, "b": 2, "c": 3}
  100. """
  101. for key, value in defaults.items():
  102. if key not in self:
  103. self[key] = value
  104. def __str__(self):
  105. return (
  106. super().__str__()[:-1]
  107. + f", worker={self.worker_index}/{self.num_workers}, "
  108. f"vector_idx={self.vector_index}, remote={self.remote}" + "}"
  109. )