| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061 |
- import gymnasium as gym
- import numpy as np
- from ray.rllib.utils.annotations import PublicAPI
@PublicAPI
class Simplex(gym.Space):
    """Represents a d - 1 dimensional Simplex in R^d.

    That is, all coordinates are in [0, 1] and sum to 1.
    The dimension d of the simplex is assumed to be shape[-1].

    Additionally one can specify the underlying distribution of
    the simplex as a Dirichlet distribution by providing concentration
    parameters. By default, sampling is uniform, i.e. concentration is
    all 1s.

    Example usage:
        self.action_space = spaces.Simplex(shape=(3, 4))
        --> 3 independent 4d Dirichlet with uniform concentration
    """

    def __init__(self, shape, concentration=None, dtype=np.float32):
        """Create the space.

        Args:
            shape: Tuple or list; the last entry is the simplex dimension d,
                any leading entries index independent simplices.
            concentration: Optional array-like of Dirichlet concentration
                parameters whose first axis has length d. Defaults to all 1s
                (uniform over the simplex).
            dtype: Dtype of the arrays returned by `sample()`.

        Raises:
            AssertionError: If `shape` is not a tuple/list or the length of
                `concentration` does not match `shape[-1]`.
        """
        # isinstance (rather than an exact type match) also admits
        # tuple/list subclasses such as named tuples.
        assert isinstance(shape, (tuple, list)), type(shape)
        super().__init__(shape, dtype)
        self.dim = self.shape[-1]

        if concentration is not None:
            # Accept any array-like; an ndarray input is stored as-is
            # (np.asarray does not copy it).
            concentration = np.asarray(concentration)
            assert (
                concentration.shape[0] == shape[-1]
            ), f"{concentration.shape[0]} vs {shape[-1]}"
            self.concentration = concentration
        else:
            self.concentration = np.array([1] * self.dim)

    def sample(self):
        """Draw one element of the space from the Dirichlet distribution.

        Uses the space's own random generator (`self.np_random`) so that
        seeding the space makes sampling reproducible.

        Returns:
            An array of shape `self.shape` and dtype `self.dtype`.
        """
        draws = self.np_random.dirichlet(self.concentration, size=self.shape[:-1])
        return draws.astype(self.dtype)

    def contains(self, x):
        """Return True if `x` has this space's shape and lies on the simplex.

        `x` lies on the simplex when no coordinate is negative and the
        coordinates along the last axis sum to 1 (within `np.allclose`
        tolerance). Together these also bound every coordinate by 1.
        """
        x = np.asarray(x)
        if x.shape != self.shape:
            return False
        # Summing to one alone is not enough: e.g. [2, -1] sums to 1 but
        # is not a point of the simplex.
        if not np.all(x >= 0):
            return False
        totals = np.sum(x, axis=-1)
        return bool(np.allclose(totals, np.ones_like(totals)))

    def to_jsonable(self, sample_n):
        """Convert a batch of samples into nested (JSON-friendly) lists."""
        return np.array(sample_n).tolist()

    def from_jsonable(self, sample_n):
        """Convert a JSON-decoded batch back into a list of arrays."""
        return [np.asarray(sample) for sample in sample_n]

    def __repr__(self):
        return f"Simplex({self.shape}; {self.concentration})"

    def __eq__(self, other):
        """Spaces are equal if shape and concentration (approximately) match."""
        if not isinstance(other, Simplex):
            return False
        # Compare shapes first: np.allclose would raise a broadcasting error
        # on concentration vectors of different lengths.
        if self.shape != other.shape:
            return False
        if np.shape(self.concentration) != np.shape(other.concentration):
            return False
        return bool(np.allclose(self.concentration, other.concentration))
|