flexdict.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import gymnasium as gym
  2. from ray.rllib.utils.annotations import PublicAPI
  3. @PublicAPI
  4. class FlexDict(gym.spaces.Dict):
  5. """Gym Dictionary with arbitrary keys updatable after instantiation
  6. Example:
  7. space = FlexDict({})
  8. space['key'] = spaces.Box(4,)
  9. See also: documentation for gym.spaces.Dict
  10. """
  11. def __init__(self, spaces=None, **spaces_kwargs):
  12. err = "Use either Dict(spaces=dict(...)) or Dict(foo=x, bar=z)"
  13. assert (spaces is None) or (not spaces_kwargs), err
  14. if spaces is None:
  15. spaces = spaces_kwargs
  16. for space in spaces.values():
  17. self.assertSpace(space)
  18. super().__init__(spaces=spaces)
  19. def assertSpace(self, space):
  20. err = "Values of the dict should be instances of gym.Space"
  21. assert issubclass(type(space), gym.spaces.Space), err
  22. def sample(self):
  23. return {k: space.sample() for k, space in self.spaces.items()}
  24. def __getitem__(self, key):
  25. return self.spaces[key]
  26. def __setitem__(self, key, space):
  27. self.assertSpace(space)
  28. self.spaces[key] = space
  29. def __repr__(self):
  30. return (
  31. "FlexDict("
  32. + ", ".join([str(k) + ":" + str(s) for k, s in self.spaces.items()])
  33. + ")"
  34. )