zero_padding.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. from collections import deque
  2. from typing import List, Tuple, Union
  3. import numpy as np
  4. import tree # pip install dm_tree
  5. from ray.rllib.utils.spaces.space_utils import BatchedNdArray, batch
  6. from ray.util.annotations import DeveloperAPI
  7. @DeveloperAPI
  8. def create_mask_and_seq_lens(episode_len: int, T: int) -> Tuple[List, List]:
  9. """Creates loss mask and a seq_lens array, given an episode length and T.
  10. Args:
  11. episode_lens: A list of episode lengths to infer the loss mask and seq_lens
  12. array from.
  13. T: The maximum number of timesteps in each "row", also known as the maximum
  14. sequence length (max_seq_len). Episodes are split into chunks that are at
  15. most `T` long and remaining timesteps will be zero-padded (and masked out).
  16. Returns:
  17. Tuple consisting of a) list of the loss masks to use (masking out areas that
  18. are past the end of an episode (or rollout), but had to be zero-added due to
  19. the added extra time rank (of length T) and b) the list of sequence lengths
  20. resulting from splitting the given episodes into chunks of at most `T`
  21. timesteps.
  22. """
  23. mask = []
  24. seq_lens = []
  25. len_ = min(episode_len, T)
  26. seq_lens.append(len_)
  27. row = np.array([1] * len_ + [0] * (T - len_), np.bool_)
  28. mask.append(row)
  29. # Handle sequence lengths greater than T.
  30. overflow = episode_len - T
  31. while overflow > 0:
  32. len_ = min(overflow, T)
  33. seq_lens.append(len_)
  34. extra_row = np.array([1] * len_ + [0] * (T - len_), np.bool_)
  35. mask.append(extra_row)
  36. overflow -= T
  37. return mask, seq_lens
  38. @DeveloperAPI
  39. def split_and_zero_pad(
  40. item_list: List[Union[BatchedNdArray, np._typing.NDArray, float]],
  41. max_seq_len: int,
  42. ) -> List[np._typing.NDArray]:
  43. """Splits the contents of `item_list` into a new list of ndarrays and returns it.
  44. In the returned list, each item is one ndarray of len (axis=0) `max_seq_len`.
  45. The last item in the returned list may be (right) zero-padded, if necessary, to
  46. reach `max_seq_len`.
  47. If `item_list` contains one or more `BatchedNdArray` (instead of individual
  48. items), these will be split accordingly along their axis=0 to yield the returned
  49. structure described above.
  50. .. testcode::
  51. from ray.rllib.utils.postprocessing.zero_padding import (
  52. BatchedNdArray,
  53. split_and_zero_pad,
  54. )
  55. from ray.rllib.utils.test_utils import check
  56. # Simple case: `item_list` contains individual floats.
  57. check(
  58. split_and_zero_pad([0, 1, 2, 3, 4, 5, 6, 7], 5),
  59. [[0, 1, 2, 3, 4], [5, 6, 7, 0, 0]],
  60. )
  61. # `item_list` contains BatchedNdArray (ndarrays that explicitly declare they
  62. # have a batch axis=0).
  63. check(
  64. split_and_zero_pad([
  65. BatchedNdArray([0, 1]),
  66. BatchedNdArray([2, 3, 4, 5]),
  67. BatchedNdArray([6, 7, 8]),
  68. ], 5),
  69. [[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]],
  70. )
  71. Args:
  72. item_list: A list of individual items or BatchedNdArrays to be split into
  73. `max_seq_len` long pieces (the last of which may be zero-padded).
  74. max_seq_len: The maximum length of each item in the returned list.
  75. Returns:
  76. A list of np.ndarrays (all of length `max_seq_len`), which contains the same
  77. data as `item_list`, but split into sub-chunks of size `max_seq_len`.
  78. The last item in the returned list may be zero-padded, if necessary.
  79. """
  80. zero_element = tree.map_structure(
  81. lambda s: np.zeros_like([s[0]] if isinstance(s, BatchedNdArray) else s),
  82. item_list[0],
  83. )
  84. # The replacement list (to be returned) for `items_list`.
  85. # Items list contains n individual items.
  86. # -> ret will contain m batched rows, where m == n // T and the last row
  87. # may be zero padded (until T).
  88. ret = []
  89. # List of the T-axis item, collected to form the next row.
  90. current_time_row = []
  91. current_t = 0
  92. item_list = deque(item_list)
  93. while len(item_list) > 0:
  94. item = item_list.popleft()
  95. t = max_seq_len - current_t
  96. # In case `item` is a complex struct.
  97. item_flat = tree.flatten(item)
  98. item_list_append = []
  99. current_time_row_flat_items = []
  100. add_to_current_t = 0
  101. for itm in item_flat:
  102. # `itm` is already a batched np.array: Split if necessary.
  103. if isinstance(itm, BatchedNdArray):
  104. current_time_row_flat_items.append(itm[:t])
  105. if len(itm) <= t:
  106. add_to_current_t = len(itm)
  107. else:
  108. add_to_current_t = t
  109. item_list_append.append(itm[t:])
  110. # `itm` is a single item (no batch axis): Append and continue with next
  111. # item.
  112. else:
  113. current_time_row_flat_items.append(itm)
  114. add_to_current_t = 1
  115. current_t += add_to_current_t
  116. current_time_row.append(tree.unflatten_as(item, current_time_row_flat_items))
  117. if item_list_append:
  118. item_list.appendleft(tree.unflatten_as(item, item_list_append))
  119. # `current_time_row` is "full" (max_seq_len): Append as ndarray (with batch
  120. # axis) to `ret`.
  121. if current_t == max_seq_len:
  122. ret.append(
  123. batch(
  124. current_time_row,
  125. individual_items_already_have_batch_dim="auto",
  126. )
  127. )
  128. current_time_row = []
  129. current_t = 0
  130. # `current_time_row` is unfinished: Pad, if necessary and append to `ret`.
  131. if current_t > 0 and current_t < max_seq_len:
  132. current_time_row.extend([zero_element] * (max_seq_len - current_t))
  133. ret.append(
  134. batch(current_time_row, individual_items_already_have_batch_dim="auto")
  135. )
  136. return ret
  137. @DeveloperAPI
  138. def split_and_zero_pad_n_episodes(
  139. nd_array: np._typing.NDArray,
  140. episode_lens: List[int],
  141. max_seq_len: int,
  142. ) -> List[np._typing.NDArray]:
  143. """Splits and zero-pads a single np.ndarray based on episode lens and a maxlen.
  144. Args:
  145. nd_array: The single np.ndarray to be split into n chunks, based on the given
  146. `episode_lens` and the `max_seq_len` argument. For example, if `nd_array`
  147. has a batch dimension (axis 0) of 21, `episode_lens` is [15, 3, 3], and
  148. `max_seq_len` is 6, then the returned list would have np.ndarrays in it of
  149. batch dimensions (axis 0): [6, 6, 6 (zero-padded), 6 (zero-padded),
  150. 6 (zero-padded)].
  151. Note that this function doesn't work on nested data, such as dicts of
  152. ndarrays.
  153. episode_lens: A list of episode lengths along which to split and zero-pad the
  154. given `nd_array`.
  155. max_seq_len: The maximum sequence length to split at (and zero-pad).
  156. Returns: A list of n np.ndarrays, resulting from splitting and zero-padding the
  157. given `nd_array`.
  158. """
  159. ret = []
  160. cursor = 0
  161. for episode_len in episode_lens:
  162. items = BatchedNdArray(nd_array[cursor : cursor + episode_len])
  163. ret.extend(split_and_zero_pad([items], max_seq_len))
  164. cursor += episode_len
  165. return ret
  166. @DeveloperAPI
  167. def unpad_data_if_necessary(
  168. episode_lens: List[int],
  169. data: np._typing.NDArray,
  170. ) -> np._typing.NDArray:
  171. """Removes right-side zero-padding from data based on `episode_lens`.
  172. ..testcode::
  173. from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary
  174. import numpy as np
  175. unpadded = unpad_data_if_necessary(
  176. episode_lens=[4, 2],
  177. data=np.array([
  178. [2, 4, 5, 3, 0, 0, 0, 0],
  179. [-1, 3, 0, 0, 0, 0, 0, 0],
  180. ]),
  181. )
  182. assert (unpadded == [2, 4, 5, 3, -1, 3]).all()
  183. unpadded = unpad_data_if_necessary(
  184. episode_lens=[1, 5],
  185. data=np.array([
  186. [2, 0, 0, 0, 0],
  187. [-1, -2, -3, -4, -5],
  188. ]),
  189. )
  190. assert (unpadded == [2, -1, -2, -3, -4, -5]).all()
  191. Args:
  192. episode_lens: A list of actual episode lengths.
  193. data: A 2D np.ndarray with right-side zero-padded rows.
  194. Returns:
  195. A 1D np.ndarray resulting from concatenation of the un-padded
  196. input data along the 0-axis.
  197. """
  198. # If data des NOT have time dimension, return right away.
  199. if len(data.shape) == 1:
  200. return data
  201. # Assert we only have B and T dimensions (meaning this function only operates
  202. # on single-float data, such as value function predictions, advantages, or rewards).
  203. assert len(data.shape) == 2
  204. new_data = []
  205. row_idx = 0
  206. T = data.shape[1]
  207. for len_ in episode_lens:
  208. # Calculate how many full rows this array occupies and how many elements are
  209. # in the last, potentially partial row.
  210. num_rows, col_idx = divmod(len_, T)
  211. # If the array spans multiple full rows, fully include these rows.
  212. for i in range(num_rows):
  213. new_data.append(data[row_idx])
  214. row_idx += 1
  215. # If there are elements in the last, potentially partial row, add this
  216. # partial row as well.
  217. if col_idx > 0:
  218. new_data.append(data[row_idx, :col_idx])
  219. # Move to the next row for the next array (skip the zero-padding zone).
  220. row_idx += 1
  221. return np.concatenate(new_data)