simplex.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. import gymnasium as gym
  2. import numpy as np
  3. from ray.rllib.utils.annotations import PublicAPI
  4. @PublicAPI
  5. class Simplex(gym.Space):
  6. """Represents a d - 1 dimensional Simplex in R^d.
  7. That is, all coordinates are in [0, 1] and sum to 1.
  8. The dimension d of the simplex is assumed to be shape[-1].
  9. Additionally one can specify the underlying distribution of
  10. the simplex as a Dirichlet distribution by providing concentration
  11. parameters. By default, sampling is uniform, i.e. concentration is
  12. all 1s.
  13. Example usage:
  14. self.action_space = spaces.Simplex(shape=(3, 4))
  15. --> 3 independent 4d Dirichlet with uniform concentration
  16. """
  17. def __init__(self, shape, concentration=None, dtype=np.float32):
  18. assert type(shape) in [tuple, list]
  19. super().__init__(shape, dtype)
  20. self.dim = self.shape[-1]
  21. if concentration is not None:
  22. assert (
  23. concentration.shape[0] == shape[-1]
  24. ), f"{concentration.shape[0]} vs {shape[-1]}"
  25. self.concentration = concentration
  26. else:
  27. self.concentration = np.array([1] * self.dim)
  28. def sample(self):
  29. return np.random.dirichlet(self.concentration, size=self.shape[:-1]).astype(
  30. self.dtype
  31. )
  32. def contains(self, x):
  33. return x.shape == self.shape and np.allclose(
  34. np.sum(x, axis=-1), np.ones_like(x[..., 0])
  35. )
  36. def to_jsonable(self, sample_n):
  37. return np.array(sample_n).tolist()
  38. def from_jsonable(self, sample_n):
  39. return [np.asarray(sample) for sample in sample_n]
  40. def __repr__(self):
  41. return "Simplex({}; {})".format(self.shape, self.concentration)
  42. def __eq__(self, other):
  43. return (
  44. np.allclose(self.concentration, other.concentration)
  45. and self.shape == other.shape
  46. )