# space_utils.py (18 KB)

  1. from typing import Any, List, Optional, Union
  2. import gymnasium as gym
  3. import numpy as np
  4. import tree # pip install dm_tree
  5. from ray.rllib.utils.annotations import DeveloperAPI
  6. @DeveloperAPI
  7. class BatchedNdArray(np.ndarray):
  8. """A ndarray-wrapper the usage of which indicates that there a batch dim exists.
  9. This is such that our `batch()` utility can distinguish between having to
  10. stack n individual batch items (each one w/o any batch dim) vs having to
  11. concatenate n already batched items (each one possibly with a different batch
  12. dim, but definitely with some batch dim).
  13. TODO (sven): Maybe replace this by a list-override instead.
  14. """
  15. def __new__(cls, input_array):
  16. # Use __new__ to create a new instance of our subclass.
  17. obj = np.asarray(input_array).view(cls)
  18. return obj
  19. @DeveloperAPI
  20. def get_original_space(space: gym.Space) -> gym.Space:
  21. """Returns the original space of a space, if any.
  22. This function recursively traverses the given space and returns the original space
  23. at the very end of the chain.
  24. Args:
  25. space: The space to get the original space for.
  26. Returns:
  27. The original space or the given space itself if no original space is found.
  28. """
  29. if hasattr(space, "original_space"):
  30. return get_original_space(space.original_space)
  31. else:
  32. return space
  33. @DeveloperAPI
  34. def is_composite_space(space: gym.Space) -> bool:
  35. """Returns true, if the space is composite.
  36. Note, we follow here the glossary of `gymnasium` by which any spoace
  37. that holds other spaces is defined as being 'composite'.
  38. Args:
  39. space: The space to be checked for being composed of other spaces.
  40. Returns:
  41. True, if the space is composed of other spaces, otherwise False.
  42. """
  43. if type(space) in [
  44. gym.spaces.Dict,
  45. gym.spaces.Graph,
  46. gym.spaces.Sequence,
  47. gym.spaces.Tuple,
  48. ]:
  49. return True
  50. else:
  51. return False
  52. @DeveloperAPI
  53. def flatten_space(space: gym.Space) -> List[gym.Space]:
  54. """Flattens a gym.Space into its primitive components.
  55. Primitive components are any non Tuple/Dict spaces.
  56. Args:
  57. space: The gym.Space to flatten. This may be any
  58. supported type (including nested Tuples and Dicts).
  59. Returns:
  60. List[gym.Space]: The flattened list of primitive Spaces. This list
  61. does not contain Tuples or Dicts anymore.
  62. """
  63. def _helper_flatten(space_, return_list):
  64. from ray.rllib.utils.spaces.flexdict import FlexDict
  65. if isinstance(space_, gym.spaces.Tuple):
  66. for s in space_:
  67. _helper_flatten(s, return_list)
  68. elif isinstance(space_, (gym.spaces.Dict, FlexDict)):
  69. for k in sorted(space_.spaces):
  70. _helper_flatten(space_[k], return_list)
  71. else:
  72. return_list.append(space_)
  73. ret = []
  74. _helper_flatten(space, ret)
  75. return ret
  76. @DeveloperAPI
  77. def get_base_struct_from_space(space):
  78. """Returns a Tuple/Dict Space as native (equally structured) py tuple/dict.
  79. Args:
  80. space: The Space to get the python struct for.
  81. Returns:
  82. Union[dict,tuple,gym.Space]: The struct equivalent to the given Space.
  83. Note that the returned struct still contains all original
  84. "primitive" Spaces (e.g. Box, Discrete).
  85. .. testcode::
  86. :skipif: True
  87. get_base_struct_from_space(Dict({
  88. "a": Box(),
  89. "b": Tuple([Discrete(2), Discrete(3)])
  90. }))
  91. .. testoutput::
  92. dict(a=Box(), b=tuple(Discrete(2), Discrete(3)))
  93. """
  94. def _helper_struct(space_):
  95. if isinstance(space_, gym.spaces.Tuple):
  96. return tuple(_helper_struct(s) for s in space_)
  97. elif isinstance(space_, gym.spaces.Dict):
  98. return {k: _helper_struct(space_[k]) for k in space_.spaces}
  99. else:
  100. return space_
  101. return _helper_struct(space)
  102. @DeveloperAPI
  103. def get_dummy_batch_for_space(
  104. space: gym.Space,
  105. batch_size: int = 32,
  106. *,
  107. fill_value: Union[float, int, str] = 0.0,
  108. time_size: Optional[int] = None,
  109. time_major: bool = False,
  110. one_hot_discrete: bool = False,
  111. ) -> np.ndarray:
  112. """Returns batched dummy data (using `batch_size`) for the given `space`.
  113. Note: The returned batch will not pass a `space.contains(batch)` test
  114. as an additional batch dimension has to be added at axis 0, unless `batch_size` is
  115. set to 0.
  116. Args:
  117. space: The space to get a dummy batch for.
  118. batch_size: The required batch size (B). Note that this can also
  119. be 0 (only if `time_size` is None!), which will result in a
  120. non-batched sample for the given space (no batch dim).
  121. fill_value: The value to fill the batch with
  122. or "random" for random values.
  123. time_size: If not None, add an optional time axis
  124. of `time_size` size to the returned batch. This time axis might either
  125. be inserted at axis=1 (default) or axis=0, if `time_major` is True.
  126. time_major: If True AND `time_size` is not None, return batch
  127. as shape [T x B x ...], otherwise as [B x T x ...]. If `time_size`
  128. if None, ignore this setting and return [B x ...].
  129. one_hot_discrete: If True, will return one-hot vectors (instead of
  130. int-values) for those sub-components of a (possibly complex) `space`
  131. that are Discrete or MultiDiscrete. Note that in case `fill_value` is 0.0,
  132. this will result in zero-hot vectors (where all slots have a value of 0.0).
  133. Returns:
  134. The dummy batch of size `bqtch_size` matching the given space.
  135. """
  136. # Complex spaces. Perform recursive calls of this function.
  137. if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple, dict, tuple)):
  138. base_struct = space
  139. if isinstance(space, (gym.spaces.Dict, gym.spaces.Tuple)):
  140. base_struct = get_base_struct_from_space(space)
  141. return tree.map_structure(
  142. lambda s: get_dummy_batch_for_space(
  143. space=s,
  144. batch_size=batch_size,
  145. fill_value=fill_value,
  146. time_size=time_size,
  147. time_major=time_major,
  148. one_hot_discrete=one_hot_discrete,
  149. ),
  150. base_struct,
  151. )
  152. if one_hot_discrete:
  153. if isinstance(space, gym.spaces.Discrete):
  154. space = gym.spaces.Box(0.0, 1.0, (space.n,), np.float32)
  155. elif isinstance(space, gym.spaces.MultiDiscrete):
  156. space = gym.spaces.Box(0.0, 1.0, (np.sum(space.nvec),), np.float32)
  157. # Primitive spaces: Box, Discrete, MultiDiscrete.
  158. # Random values: Use gym's sample() method.
  159. if fill_value == "random":
  160. if time_size is not None:
  161. assert batch_size > 0 and time_size > 0
  162. if time_major:
  163. return np.array(
  164. [
  165. [space.sample() for _ in range(batch_size)]
  166. for t in range(time_size)
  167. ],
  168. dtype=space.dtype,
  169. )
  170. else:
  171. return np.array(
  172. [
  173. [space.sample() for t in range(time_size)]
  174. for _ in range(batch_size)
  175. ],
  176. dtype=space.dtype,
  177. )
  178. else:
  179. return np.array(
  180. [space.sample() for _ in range(batch_size)]
  181. if batch_size > 0
  182. else space.sample(),
  183. dtype=space.dtype,
  184. )
  185. # Fill value given: Use np.full.
  186. else:
  187. if time_size is not None:
  188. assert batch_size > 0 and time_size > 0
  189. if time_major:
  190. shape = [time_size, batch_size]
  191. else:
  192. shape = [batch_size, time_size]
  193. else:
  194. shape = [batch_size] if batch_size > 0 else []
  195. return np.full(
  196. shape + list(space.shape), fill_value=fill_value, dtype=space.dtype
  197. )
  198. @DeveloperAPI
  199. def flatten_to_single_ndarray(input_):
  200. """Returns a single np.ndarray given a list/tuple of np.ndarrays.
  201. Args:
  202. input_ (Union[List[np.ndarray], np.ndarray]): The list of ndarrays or
  203. a single ndarray.
  204. Returns:
  205. np.ndarray: The result after concatenating all single arrays in input_.
  206. .. testcode::
  207. :skipif: True
  208. flatten_to_single_ndarray([
  209. np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]),
  210. np.array([7, 8, 9]),
  211. ])
  212. .. testoutput::
  213. np.array([
  214. 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
  215. ])
  216. """
  217. # Concatenate complex inputs.
  218. if isinstance(input_, (list, tuple, dict)):
  219. expanded = []
  220. for in_ in tree.flatten(input_):
  221. expanded.append(np.reshape(in_, [-1]))
  222. input_ = np.concatenate(expanded, axis=0).flatten()
  223. return input_
  224. @DeveloperAPI
  225. def batch(
  226. list_of_structs: List[Any],
  227. *,
  228. individual_items_already_have_batch_dim: Union[bool, str] = False,
  229. ):
  230. """Converts input from a list of (nested) structs to a (nested) struct of batches.
  231. Input: List of structs (each of these structs representing a single batch item).
  232. [
  233. {"a": 1, "b": (4, 7.0)}, <- batch item 1
  234. {"a": 2, "b": (5, 8.0)}, <- batch item 2
  235. {"a": 3, "b": (6, 9.0)}, <- batch item 3
  236. ]
  237. Output: Struct of different batches (each batch has size=3 b/c there were 3 items
  238. in the original list):
  239. {
  240. "a": np.array([1, 2, 3]),
  241. "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
  242. }
  243. Args:
  244. list_of_structs: The list of (possibly nested) structs. Each item
  245. in this list represents a single batch item.
  246. individual_items_already_have_batch_dim: True, if the individual items in
  247. `list_of_structs` already have a batch dim. In this case, we will
  248. concatenate (instead of stack) at the end. In the example above, this would
  249. look like this: Input: [{"a": [1], "b": ([4], [7.0])}, ...] -> Output: same
  250. as in above example.
  251. If the special value "auto" is used,
  252. Returns:
  253. The struct of component batches. Each leaf item in this struct represents the
  254. batch for a single component (in case struct is tuple/dict). If the input is a
  255. simple list of primitive items, e.g. a list of floats, a np.array of floats
  256. will be returned.
  257. """
  258. if not list_of_structs:
  259. raise ValueError("Input `list_of_structs` does not contain any items.")
  260. # TODO (sven): Maybe replace this by a list-override (usage of which indicated
  261. # this method that concatenate should be used (not stack)).
  262. if individual_items_already_have_batch_dim == "auto":
  263. flat = tree.flatten(list_of_structs[0])
  264. individual_items_already_have_batch_dim = isinstance(flat[0], BatchedNdArray)
  265. np_func = np.concatenate if individual_items_already_have_batch_dim else np.stack
  266. ret = tree.map_structure(
  267. lambda *s: np.ascontiguousarray(np_func(s, axis=0)), *list_of_structs
  268. )
  269. return ret
  270. @DeveloperAPI
  271. def unbatch(batches_struct):
  272. """Converts input from (nested) struct of batches to batch of structs.
  273. Input: Struct of different batches (each batch has size=3):
  274. {
  275. "a": np.array([1, 2, 3]),
  276. "b": (np.array([4, 5, 6]), np.array([7.0, 8.0, 9.0]))
  277. }
  278. Output: Batch (list) of structs (each of these structs representing a
  279. single action):
  280. [
  281. {"a": 1, "b": (4, 7.0)}, <- action 1
  282. {"a": 2, "b": (5, 8.0)}, <- action 2
  283. {"a": 3, "b": (6, 9.0)}, <- action 3
  284. ]
  285. Args:
  286. batches_struct: The struct of component batches. Each leaf item
  287. in this struct represents the batch for a single component
  288. (in case struct is tuple/dict).
  289. Alternatively, `batches_struct` may also simply be a batch of
  290. primitives (non tuple/dict).
  291. Returns:
  292. The list of individual structs. Each item in the returned list represents a
  293. single (maybe complex) batch item.
  294. """
  295. flat_batches = tree.flatten(batches_struct)
  296. out = []
  297. for batch_pos in range(len(flat_batches[0])):
  298. out.append(
  299. tree.unflatten_as(
  300. batches_struct,
  301. [flat_batches[i][batch_pos] for i in range(len(flat_batches))],
  302. )
  303. )
  304. return out
  305. @DeveloperAPI
  306. def clip_action(action, action_space):
  307. """Clips all components in `action` according to the given Space.
  308. Only applies to Box components within the action space.
  309. Args:
  310. action: The action to be clipped. This could be any complex
  311. action, e.g. a dict or tuple.
  312. action_space: The action space struct,
  313. e.g. `{"a": Distrete(2)}` for a space: Dict({"a": Discrete(2)}).
  314. Returns:
  315. Any: The input action, but clipped by value according to the space's
  316. bounds.
  317. """
  318. def map_(a, s):
  319. if isinstance(s, gym.spaces.Box):
  320. a = np.clip(a, s.low, s.high)
  321. return a
  322. return tree.map_structure(map_, action, action_space)
  323. @DeveloperAPI
  324. def unsquash_action(action, action_space_struct):
  325. """Unsquashes all components in `action` according to the given Space.
  326. Inverse of `normalize_action()`. Useful for mapping policy action
  327. outputs (normalized between -1.0 and 1.0) to an env's action space.
  328. Unsquashing results in cont. action component values between the
  329. given Space's bounds (`low` and `high`). This only applies to Box
  330. components within the action space, whose dtype is float32 or float64.
  331. Args:
  332. action: The action to be unsquashed. This could be any complex
  333. action, e.g. a dict or tuple.
  334. action_space_struct: The action space struct,
  335. e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).
  336. Returns:
  337. Any: The input action, but unsquashed, according to the space's
  338. bounds. An unsquashed action is ready to be sent to the
  339. environment (`BaseEnv.send_actions([unsquashed actions])`).
  340. """
  341. def map_(a, s):
  342. if (
  343. isinstance(s, gym.spaces.Box)
  344. and np.all(s.bounded_below)
  345. and np.all(s.bounded_above)
  346. ):
  347. if s.dtype == np.float32 or s.dtype == np.float64:
  348. # Assuming values are roughly between -1.0 and 1.0 ->
  349. # unsquash them to the given bounds.
  350. a = s.low + (a + 1.0) * (s.high - s.low) / 2.0
  351. # Clip to given bounds, just in case the squashed values were
  352. # outside [-1.0, 1.0].
  353. a = np.clip(a, s.low, s.high)
  354. elif np.issubdtype(s.dtype, np.integer):
  355. # For Categorical and MultiCategorical actions, shift the selection
  356. # into the proper range.
  357. a = s.low + a
  358. return a
  359. return tree.map_structure(map_, action, action_space_struct)
  360. @DeveloperAPI
  361. def normalize_action(action, action_space_struct):
  362. """Normalizes all (Box) components in `action` to be in [-1.0, 1.0].
  363. Inverse of `unsquash_action()`. Useful for mapping an env's action
  364. (arbitrary bounded values) to a [-1.0, 1.0] interval.
  365. This only applies to Box components within the action space, whose
  366. dtype is float32 or float64.
  367. Args:
  368. action: The action to be normalized. This could be any complex
  369. action, e.g. a dict or tuple.
  370. action_space_struct: The action space struct,
  371. e.g. `{"a": Box()}` for a space: Dict({"a": Box()}).
  372. Returns:
  373. Any: The input action, but normalized, according to the space's
  374. bounds.
  375. """
  376. def map_(a, s):
  377. if isinstance(s, gym.spaces.Box) and (
  378. s.dtype == np.float32 or s.dtype == np.float64
  379. ):
  380. # Normalize values to be exactly between -1.0 and 1.0.
  381. a = ((a - s.low) * 2.0) / (s.high - s.low) - 1.0
  382. return a
  383. return tree.map_structure(map_, action, action_space_struct)
  384. @DeveloperAPI
  385. def convert_element_to_space_type(element: Any, sampled_element: Any) -> Any:
  386. """Convert all the components of the element to match the space dtypes.
  387. Args:
  388. element: The element to be converted.
  389. sampled_element: An element sampled from a space to be matched
  390. to.
  391. Returns:
  392. The input element, but with all its components converted to match
  393. the space dtypes.
  394. """
  395. def map_(elem, s):
  396. if isinstance(s, np.ndarray):
  397. if not isinstance(elem, np.ndarray):
  398. assert isinstance(
  399. elem, (float, int)
  400. ), f"ERROR: `elem` ({elem}) must be np.array, float or int!"
  401. if s.shape == ():
  402. elem = np.array(elem, dtype=s.dtype)
  403. else:
  404. raise ValueError(
  405. "Element should be of type np.ndarray but is instead of \
  406. type {}".format(
  407. type(elem)
  408. )
  409. )
  410. elif s.dtype != elem.dtype:
  411. elem = elem.astype(s.dtype)
  412. # Gymnasium now uses np.int_64 as the dtype of a Discrete action space
  413. elif isinstance(s, int) or isinstance(s, np.int_):
  414. if isinstance(elem, float) and elem.is_integer():
  415. elem = int(elem)
  416. # Note: This does not check if the float element is actually an integer
  417. if isinstance(elem, np.float_):
  418. elem = np.int64(elem)
  419. return elem
  420. return tree.map_structure(map_, element, sampled_element, check_types=False)