| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513 |
- from typing import Any, List, Optional, Union
- import gymnasium as gym
- import numpy as np
- import tree # pip install dm_tree
- from ray.rllib.utils.annotations import DeveloperAPI
@DeveloperAPI
class BatchedNdArray(np.ndarray):
    """Marker subclass of `np.ndarray` signalling that a batch dim already exists.

    The `batch()` utility relies on this type to tell two situations apart:
    having to stack n individual items (none of which has a batch dim) vs.
    having to concatenate n items that are already batched (each one possibly
    with a different batch size, but definitely with some batch dim).

    TODO (sven): Maybe replace this by a list-override instead.
    """

    def __new__(cls, input_array):
        # Convert the input to an ndarray and re-interpret it as an instance of
        # this subclass through a view.
        return np.asarray(input_array).view(cls)
@DeveloperAPI
def get_original_space(space: gym.Space) -> gym.Space:
    """Follows the `original_space` chain of a space to its very end.

    As long as the current space exposes an `original_space` attribute, that
    attribute is followed. The first space without such an attribute is returned.

    Args:
        space: The space to get the original space for.

    Returns:
        The space at the end of the `original_space` chain, or `space` itself if
        it has no `original_space` attribute.
    """
    current = space
    while hasattr(current, "original_space"):
        current = current.original_space
    return current
@DeveloperAPI
def is_composite_space(space: gym.Space) -> bool:
    """Tells whether the given space is a composite space.

    This follows the glossary of `gymnasium`, by which any space that holds
    other spaces is defined as being 'composite'.

    Note that the check is done on the exact type of `space`: Only instances
    whose type is precisely Dict, Graph, Sequence, or Tuple qualify (instances
    of subclasses of these do not).

    Args:
        space: The space to be checked for being composed of other spaces.

    Returns:
        True, if the space is composed of other spaces, otherwise False.
    """
    composite_types = (
        gym.spaces.Dict,
        gym.spaces.Graph,
        gym.spaces.Sequence,
        gym.spaces.Tuple,
    )
    return type(space) in composite_types
@DeveloperAPI
def flatten_space(space: gym.Space) -> List[gym.Space]:
    """Breaks a (possibly nested) gym.Space down into its primitive components.

    Primitive components are any non Tuple/Dict spaces.

    Args:
        space: The gym.Space to flatten. This may be any supported type
            (including nested Tuples and Dicts).

    Returns:
        The flat list of primitive spaces, in depth-first order. Tuple children
        are visited in their given order, Dict (and FlexDict) children in the
        sorted order of their keys. The list contains no Tuples or Dicts.
    """
    # Imported locally, as in the original implementation.
    from ray.rllib.utils.spaces.flexdict import FlexDict

    primitives = []

    def _collect(sub_space):
        # Tuple: Descend into each child in order.
        if isinstance(sub_space, gym.spaces.Tuple):
            for child in sub_space:
                _collect(child)
            return
        # Dict-like: Descend into the children, sorted by key.
        if isinstance(sub_space, (gym.spaces.Dict, FlexDict)):
            for key in sorted(sub_space.spaces):
                _collect(sub_space[key])
            return
        # Anything else counts as primitive.
        primitives.append(sub_space)

    _collect(space)
    return primitives
@DeveloperAPI
def get_base_struct_from_space(space):
    """Converts a Tuple/Dict Space into an equally structured native py tuple/dict.

    Args:
        space: The Space to get the python struct for.

    Returns:
        Union[dict,tuple,gym.Space]: The struct equivalent to the given Space.
            Note that the returned struct still contains all original
            "primitive" Spaces (e.g. Box, Discrete) as its leaves.

    .. testcode::
        :skipif: True

        get_base_struct_from_space(Dict({
            "a": Box(),
            "b": Tuple([Discrete(2), Discrete(3)])
        }))

    .. testoutput::

        dict(a=Box(), b=tuple(Discrete(2), Discrete(3)))
    """

    def _to_struct(sub_space):
        # Tuple space -> python tuple of converted children.
        if isinstance(sub_space, gym.spaces.Tuple):
            return tuple(_to_struct(child) for child in sub_space)
        # Dict space -> python dict of converted children (same keys).
        if isinstance(sub_space, gym.spaces.Dict):
            return {key: _to_struct(sub_space[key]) for key in sub_space.spaces}
        # Any other space is kept as-is.
        return sub_space

    return _to_struct(space)
@DeveloperAPI
def get_dummy_batch_for_space(
    space: gym.Space,
    batch_size: int = 32,
    *,
    fill_value: Union[float, int, str] = 0.0,
    time_size: Optional[int] = None,
    time_major: bool = False,
    one_hot_discrete: bool = False,
) -> np.ndarray:
    """Creates batched dummy data (using `batch_size`) for the given `space`.

    Note: The returned batch will not pass a `space.contains(batch)` test
    as an additional batch dimension has to be added at axis 0, unless `batch_size` is
    set to 0.

    Args:
        space: The space to get a dummy batch for. Dict/Tuple spaces, as well as
            plain python dicts/tuples of spaces, are handled recursively.
        batch_size: The required batch size (B). Note that this can also
            be 0 (only if `time_size` is None!), which will result in a
            non-batched sample for the given space (no batch dim).
        fill_value: The value to fill the batch with
            or "random" for random values.
        time_size: If not None, add an optional time axis
            of `time_size` size to the returned batch. This time axis might either
            be inserted at axis=1 (default) or axis=0, if `time_major` is True.
        time_major: If True AND `time_size` is not None, return batch
            as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size`
            is None, ignore this setting and return [B x ...].
        one_hot_discrete: If True, will return one-hot vectors (instead of
            int-values) for those sub-components of a (possibly complex) `space`
            that are Discrete or MultiDiscrete. Note that in case `fill_value` is 0.0,
            this will result in zero-hot vectors (where all slots have a value of 0.0).

    Returns:
        The dummy batch of size `batch_size` matching the given space. For complex
        spaces, this is a struct (of the same nesting as the space) whose leaves
        are the dummy batches of the individual components.
    """
    # Complex spaces: Recurse into each component with identical settings.
    if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, dict, tuple)):
        if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
            struct = get_base_struct_from_space(space)
        else:
            struct = space

        def _dummy_for_component(component):
            return get_dummy_batch_for_space(
                space=component,
                batch_size=batch_size,
                fill_value=fill_value,
                time_size=time_size,
                time_major=time_major,
                one_hot_discrete=one_hot_discrete,
            )

        return tree.map_structure(_dummy_for_component, struct)

    # Swap (Multi)Discrete spaces for a float Box of the one-hot width.
    if one_hot_discrete:
        if isinstance(space, gym.spaces.Discrete):
            space = gym.spaces.Box(0.0, 1.0, (space.n,), np.float32)
        elif isinstance(space, gym.spaces.MultiDiscrete):
            space = gym.spaces.Box(0.0, 1.0, (np.sum(space.nvec),), np.float32)

    # Primitive spaces: Box, Discrete, MultiDiscrete.
    # Random values: Use gym's sample() method.
    if fill_value == "random":
        # No time axis: Either a list of B samples or one single sample (B=0).
        if time_size is None:
            if batch_size > 0:
                samples = [space.sample() for _ in range(batch_size)]
            else:
                samples = space.sample()
            return np.array(samples, dtype=space.dtype)

        assert batch_size > 0 and time_size > 0
        # The leading (outer) axis is T if time-major, otherwise B.
        if time_major:
            outer, inner = time_size, batch_size
        else:
            outer, inner = batch_size, time_size
        return np.array(
            [[space.sample() for _ in range(inner)] for _ in range(outer)],
            dtype=space.dtype,
        )

    # Fill value given: Use np.full.
    if time_size is None:
        leading_dims = [batch_size] if batch_size > 0 else []
    else:
        assert batch_size > 0 and time_size > 0
        if time_major:
            leading_dims = [time_size, batch_size]
        else:
            leading_dims = [batch_size, time_size]
    return np.full(
        leading_dims + list(space.shape), fill_value=fill_value, dtype=space.dtype
    )
@DeveloperAPI
def flatten_to_single_ndarray(input_):
    """Merges a list/tuple/dict of np.ndarrays into one single 1D np.ndarray.

    Any input that is not a list, tuple, or dict is returned unchanged.

    Args:
        input_ (Union[List[np.ndarray], np.ndarray]): The list of ndarrays or
            a single ndarray.

    Returns:
        np.ndarray: The result after concatenating all single arrays in input_.

    .. testcode::
        :skipif: True

        flatten_to_single_ndarray([
            np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
            np.array([7, 8, 9]),
        ])

    .. testoutput::

        np.array([
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        ])
    """
    # Nothing to concatenate for non-container inputs.
    if not isinstance(input_, (list, tuple, dict)):
        return input_

    # Reshape each leaf to 1D, then join all of them into one long vector.
    flat_parts = [np.reshape(leaf, [-1]) for leaf in tree.flatten(input_)]
    return np.concatenate(flat_parts, axis=0).flatten()
@DeveloperAPI
def batch(
    list_of_structs: List[Any],
    *,
    individual_items_already_have_batch_dim: Union[bool, str] = False,
):
    """Converts input from a list of (nested) structs to a (nested) struct of batches.

    Input: List of structs (each of these structs representing a single batch item).
        [
            {"a": 1, "b": (4, 7.0)},  <- batch item 1
            {"a": 2, "b": (5, 8.0)},  <- batch item 2
            {"a": 3, "b": (6, 9.0)},  <- batch item 3
        ]

    Output: Struct of different batches (each batch has size=3 b/c there were 3 items
        in the original list):
        {
            "a": np.array([1, 2, 3]),
            "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
        }

    Args:
        list_of_structs: The list of (possibly nested) structs. Each item
            in this list represents a single batch item.
        individual_items_already_have_batch_dim: True, if the individual items in
            `list_of_structs` already have a batch dim. In this case, we will
            concatenate (instead of stack) at the end. In the example above, this would
            look like this: Input: [{"a": [1], "b": ([4], [7.0])}, ...] -> Output: same
            as in above example.
            If the special value "auto" is used, the setting is inferred from the
            first item in `list_of_structs`: It is treated as True, if the first
            leaf of that item is a `BatchedNdArray`, and as False otherwise
            (including the case in which that item has no leaves at all).

    Returns:
        The struct of component batches. Each leaf item in this struct represents the
        batch for a single component (in case struct is tuple/dict). If the input is a
        simple list of primitive items, e.g. a list of floats, a np.array of floats
        will be returned.

    Raises:
        ValueError: If `list_of_structs` is empty.
    """
    if not list_of_structs:
        raise ValueError("Input `list_of_structs` does not contain any items.")

    # TODO (sven): Maybe replace this by a list-override (usage of which indicated
    #  this method that concatenate should be used (not stack)).
    if individual_items_already_have_batch_dim == "auto":
        first_item_leaves = tree.flatten(list_of_structs[0])
        # Guard against structs without any leaves (e.g. an empty dict), for which
        # indexing the first leaf would fail.
        individual_items_already_have_batch_dim = bool(
            first_item_leaves
        ) and isinstance(first_item_leaves[0], BatchedNdArray)

    # Already batched items get concatenated along their batch axis, individual
    # items get stacked along a new leading axis.
    np_func = np.concatenate if individual_items_already_have_batch_dim else np.stack
    return tree.map_structure(
        lambda *leaves: np.ascontiguousarray(np_func(leaves, axis=0)),
        *list_of_structs,
    )
@DeveloperAPI
def unbatch(batches_struct):
    """Splits a (nested) struct of batches into a list of individual structs.

    Input: Struct of different batches (each batch has size=3):
        {
            "a": np.array([1, 2, 3]),
            "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
        }

    Output: Batch (list) of structs (each of these structs representing a
        single action):
        [
            {"a": 1, "b": (4, 7.0)},  <- action 1
            {"a": 2, "b": (5, 8.0)},  <- action 2
            {"a": 3, "b": (6, 9.0)},  <- action 3
        ]

    Args:
        batches_struct: The struct of component batches. Each leaf item
            in this struct represents the batch for a single component
            (in case struct is tuple/dict).
            Alternatively, `batches_struct` may also simply be a batch of
            primitives (non tuple/dict).

    Returns:
        The list of individual structs. Each item in the returned list represents a
        single (maybe complex) batch item.
    """
    leaves = tree.flatten(batches_struct)
    # The number of items to produce is taken from the first leaf's length.
    num_items = len(leaves[0])
    # For each batch position, pick that position from every leaf and re-pack
    # the picked values into the original structure.
    return [
        tree.unflatten_as(batches_struct, [leaf[idx] for leaf in leaves])
        for idx in range(num_items)
    ]
@DeveloperAPI
def clip_action(action, action_space):
    """Clips the components of `action` to the bounds of the given space struct.

    Only applies to Box components within the action space. All other
    components are passed through unchanged.

    Args:
        action: The action to be clipped. This could be any complex
            action, e.g. a dict or tuple.
        action_space: The action space struct,
            e.g. `{"a": Discrete(2)}` for a space: Dict({"a": Discrete(2)}).

    Returns:
        Any: The input action, but clipped by value according to the space's
            bounds.
    """

    def _clip_component(component, component_space):
        # Non-Box components have no bounds to clip to.
        if not isinstance(component_space, gym.spaces.Box):
            return component
        return np.clip(component, component_space.low, component_space.high)

    return tree.map_structure(_clip_component, action, action_space)
@DeveloperAPI
def unsquash_action(action, action_space_struct):
    """Maps normalized `action` components back into the bounds of their spaces.

    Inverse of `normalize_action()`. Useful for mapping policy action
    outputs (normalized between -1.0 and 1.0) to an env's action space.
    Unsquashing results in cont. action component values between the
    given Space's bounds (`low` and `high`). This only applies to Box
    components within the action space that are bounded (below and above) in
    all of their dimensions. Components of float32 or float64 Boxes are rescaled
    and clipped, components of integer Boxes are shifted by the Box's `low`.

    Args:
        action: The action to be unsquashed. This could be any complex
            action, e.g. a dict or tuple.
        action_space_struct: The action space struct,
            e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).

    Returns:
        Any: The input action, but unsquashed, according to the space's
            bounds. An unsquashed action is ready to be sent to the
            environment (`BaseEnv.send_actions([unsquashed actions])`).
    """

    def _unsquash_component(a, s):
        is_fully_bounded_box = (
            isinstance(s, gym.spaces.Box)
            and np.all(s.bounded_below)
            and np.all(s.bounded_above)
        )
        # Leave non-Box and not fully bounded components untouched.
        if not is_fully_bounded_box:
            return a

        if s.dtype == np.float32 or s.dtype == np.float64:
            # Assuming values are roughly between -1.0 and 1.0 ->
            # unsquash them to the given bounds.
            rescaled = s.low + (a + 1.0) * (s.high - s.low) / 2.0
            # Clip to given bounds, just in case the squashed values were
            # outside [-1.0, 1.0].
            return np.clip(rescaled, s.low, s.high)

        if np.issubdtype(s.dtype, np.integer):
            # For Categorical and MultiCategorical actions, shift the selection
            # into the proper range.
            return s.low + a

        return a

    return tree.map_structure(_unsquash_component, action, action_space_struct)
@DeveloperAPI
def normalize_action(action, action_space_struct):
    """Normalizes all (Box) components in `action` to be in [-1.0, 1.0].

    Inverse of `unsquash_action()`. Useful for mapping an env's action
    (arbitrary bounded values) to a [-1.0, 1.0] interval.
    This only applies to Box components within the action space, whose
    dtype is float32 or float64 and which are bounded (below and above) in all
    of their dimensions. This is the same condition under which
    `unsquash_action()` rescales a component. All other components (including
    those of unbounded Boxes, for which the normalization would only produce
    inf/nan values) are returned unchanged.

    Args:
        action: The action to be normalized. This could be any complex
            action, e.g. a dict or tuple.
        action_space_struct: The action space struct,
            e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).

    Returns:
        Any: The input action, but normalized, according to the space's
            bounds.
    """

    def map_(a, s):
        if (
            isinstance(s, gym.spaces.Box)
            and (s.dtype == np.float32 or s.dtype == np.float64)
            # Only normalize where finite bounds exist in all dimensions. With an
            # infinite `low` or `high`, the formula below would yield inf/nan.
            and np.all(s.bounded_below)
            and np.all(s.bounded_above)
        ):
            # Normalize values to be exactly between -1.0 and 1.0.
            a = ((a - s.low) * 2.0) / (s.high - s.low) - 1.0
        return a

    return tree.map_structure(map_, action, action_space_struct)
@DeveloperAPI
def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
    """Convert all the components of the element to match the space dtypes.

    The conversion is done per component, by comparing each component of
    `element` with its counterpart in `sampled_element`:
    - Counterpart is a np.ndarray: An ndarray component is cast to the
      counterpart's dtype (if it differs). A python float/int component is
      wrapped into a 0-dim array of that dtype, which is only possible if the
      counterpart itself is 0-dim.
    - Counterpart is an int (python `int` or `np.int_`): A python float component
      holding an integral value becomes a python int; a `np.float64` component
      becomes a `np.int64`.
    - Otherwise, the component is returned unchanged.

    Args:
        element: The element to be converted.
        sampled_element: An element sampled from a space to be matched
            to.

    Returns:
        The input element, but with all its components converted to match
        the space dtypes.

    Raises:
        AssertionError: If a component is neither a np.ndarray nor a python
            float/int, while its counterpart is a np.ndarray.
        ValueError: If a component is a python float/int, while its counterpart
            is a np.ndarray that is not 0-dim.
    """

    def map_(elem, s):
        if isinstance(s, np.ndarray):
            if not isinstance(elem, np.ndarray):
                assert isinstance(
                    elem, (float, int)
                ), f"ERROR: `elem` ({elem}) must be np.array, float or int!"
                if s.shape == ():
                    elem = np.array(elem, dtype=s.dtype)
                else:
                    # Single-line message (a backslash continuation inside the
                    # string literal would embed source indentation into it).
                    raise ValueError(
                        "Element should be of type np.ndarray but is instead of "
                        f"type {type(elem)}"
                    )
            elif s.dtype != elem.dtype:
                elem = elem.astype(s.dtype)

        # Gymnasium now uses np.int64 as the dtype of a Discrete action space
        elif isinstance(s, (int, np.int_)):
            if isinstance(elem, float) and elem.is_integer():
                elem = int(elem)

            # Note: This does not check if the float element is actually an integer
            # `np.float64` is the type that the `np.float_` alias referred to. The
            # alias itself was removed in NumPy 2.0.
            if isinstance(elem, np.float64):
                elem = np.int64(elem)

        return elem

    return tree.map_structure(map_, element, sampled_element, check_types=False)
|