| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270 |
- from collections import deque
- from typing import List, Tuple, Union
- import numpy as np
- import tree # pip install dm_tree
- from ray.rllib.utils.spaces.space_utils import BatchedNdArray, batch
- from ray.util.annotations import DeveloperAPI
- @DeveloperAPI
- def create_mask_and_seq_lens(episode_len: int, T: int) -> Tuple[List, List]:
- """Creates loss mask and a seq_lens array, given an episode length and T.
- Args:
- episode_lens: A list of episode lengths to infer the loss mask and seq_lens
- array from.
- T: The maximum number of timesteps in each "row", also known as the maximum
- sequence length (max_seq_len). Episodes are split into chunks that are at
- most `T` long and remaining timesteps will be zero-padded (and masked out).
- Returns:
- Tuple consisting of a) list of the loss masks to use (masking out areas that
- are past the end of an episode (or rollout), but had to be zero-added due to
- the added extra time rank (of length T) and b) the list of sequence lengths
- resulting from splitting the given episodes into chunks of at most `T`
- timesteps.
- """
- mask = []
- seq_lens = []
- len_ = min(episode_len, T)
- seq_lens.append(len_)
- row = np.array([1] * len_ + [0] * (T - len_), np.bool_)
- mask.append(row)
- # Handle sequence lengths greater than T.
- overflow = episode_len - T
- while overflow > 0:
- len_ = min(overflow, T)
- seq_lens.append(len_)
- extra_row = np.array([1] * len_ + [0] * (T - len_), np.bool_)
- mask.append(extra_row)
- overflow -= T
- return mask, seq_lens
- @DeveloperAPI
- def split_and_zero_pad(
- item_list: List[Union[BatchedNdArray, np._typing.NDArray, float]],
- max_seq_len: int,
- ) -> List[np._typing.NDArray]:
- """Splits the contents of `item_list` into a new list of ndarrays and returns it.
- In the returned list, each item is one ndarray of len (axis=0) `max_seq_len`.
- The last item in the returned list may be (right) zero-padded, if necessary, to
- reach `max_seq_len`.
- If `item_list` contains one or more `BatchedNdArray` (instead of individual
- items), these will be split accordingly along their axis=0 to yield the returned
- structure described above.
- .. testcode::
- from ray.rllib.utils.postprocessing.zero_padding import (
- BatchedNdArray,
- split_and_zero_pad,
- )
- from ray.rllib.utils.test_utils import check
- # Simple case: `item_list` contains individual floats.
- check(
- split_and_zero_pad([0, 1, 2, 3, 4, 5, 6, 7], 5),
- [[0, 1, 2, 3, 4], [5, 6, 7, 0, 0]],
- )
- # `item_list` contains BatchedNdArray (ndarrays that explicitly declare they
- # have a batch axis=0).
- check(
- split_and_zero_pad([
- BatchedNdArray([0, 1]),
- BatchedNdArray([2, 3, 4, 5]),
- BatchedNdArray([6, 7, 8]),
- ], 5),
- [[0, 1, 2, 3, 4], [5, 6, 7, 8, 0]],
- )
- Args:
- item_list: A list of individual items or BatchedNdArrays to be split into
- `max_seq_len` long pieces (the last of which may be zero-padded).
- max_seq_len: The maximum length of each item in the returned list.
- Returns:
- A list of np.ndarrays (all of length `max_seq_len`), which contains the same
- data as `item_list`, but split into sub-chunks of size `max_seq_len`.
- The last item in the returned list may be zero-padded, if necessary.
- """
- zero_element = tree.map_structure(
- lambda s: np.zeros_like([s[0]] if isinstance(s, BatchedNdArray) else s),
- item_list[0],
- )
- # The replacement list (to be returned) for `items_list`.
- # Items list contains n individual items.
- # -> ret will contain m batched rows, where m == n // T and the last row
- # may be zero padded (until T).
- ret = []
- # List of the T-axis item, collected to form the next row.
- current_time_row = []
- current_t = 0
- item_list = deque(item_list)
- while len(item_list) > 0:
- item = item_list.popleft()
- t = max_seq_len - current_t
- # In case `item` is a complex struct.
- item_flat = tree.flatten(item)
- item_list_append = []
- current_time_row_flat_items = []
- add_to_current_t = 0
- for itm in item_flat:
- # `itm` is already a batched np.array: Split if necessary.
- if isinstance(itm, BatchedNdArray):
- current_time_row_flat_items.append(itm[:t])
- if len(itm) <= t:
- add_to_current_t = len(itm)
- else:
- add_to_current_t = t
- item_list_append.append(itm[t:])
- # `itm` is a single item (no batch axis): Append and continue with next
- # item.
- else:
- current_time_row_flat_items.append(itm)
- add_to_current_t = 1
- current_t += add_to_current_t
- current_time_row.append(tree.unflatten_as(item, current_time_row_flat_items))
- if item_list_append:
- item_list.appendleft(tree.unflatten_as(item, item_list_append))
- # `current_time_row` is "full" (max_seq_len): Append as ndarray (with batch
- # axis) to `ret`.
- if current_t == max_seq_len:
- ret.append(
- batch(
- current_time_row,
- individual_items_already_have_batch_dim="auto",
- )
- )
- current_time_row = []
- current_t = 0
- # `current_time_row` is unfinished: Pad, if necessary and append to `ret`.
- if current_t > 0 and current_t < max_seq_len:
- current_time_row.extend([zero_element] * (max_seq_len - current_t))
- ret.append(
- batch(current_time_row, individual_items_already_have_batch_dim="auto")
- )
- return ret
- @DeveloperAPI
- def split_and_zero_pad_n_episodes(
- nd_array: np._typing.NDArray,
- episode_lens: List[int],
- max_seq_len: int,
- ) -> List[np._typing.NDArray]:
- """Splits and zero-pads a single np.ndarray based on episode lens and a maxlen.
- Args:
- nd_array: The single np.ndarray to be split into n chunks, based on the given
- `episode_lens` and the `max_seq_len` argument. For example, if `nd_array`
- has a batch dimension (axis 0) of 21, `episode_lens` is [15, 3, 3], and
- `max_seq_len` is 6, then the returned list would have np.ndarrays in it of
- batch dimensions (axis 0): [6, 6, 6 (zero-padded), 6 (zero-padded),
- 6 (zero-padded)].
- Note that this function doesn't work on nested data, such as dicts of
- ndarrays.
- episode_lens: A list of episode lengths along which to split and zero-pad the
- given `nd_array`.
- max_seq_len: The maximum sequence length to split at (and zero-pad).
- Returns: A list of n np.ndarrays, resulting from splitting and zero-padding the
- given `nd_array`.
- """
- ret = []
- cursor = 0
- for episode_len in episode_lens:
- items = BatchedNdArray(nd_array[cursor : cursor + episode_len])
- ret.extend(split_and_zero_pad([items], max_seq_len))
- cursor += episode_len
- return ret
- @DeveloperAPI
- def unpad_data_if_necessary(
- episode_lens: List[int],
- data: np._typing.NDArray,
- ) -> np._typing.NDArray:
- """Removes right-side zero-padding from data based on `episode_lens`.
- ..testcode::
- from ray.rllib.utils.postprocessing.zero_padding import unpad_data_if_necessary
- import numpy as np
- unpadded = unpad_data_if_necessary(
- episode_lens=[4, 2],
- data=np.array([
- [2, 4, 5, 3, 0, 0, 0, 0],
- [-1, 3, 0, 0, 0, 0, 0, 0],
- ]),
- )
- assert (unpadded == [2, 4, 5, 3, -1, 3]).all()
- unpadded = unpad_data_if_necessary(
- episode_lens=[1, 5],
- data=np.array([
- [2, 0, 0, 0, 0],
- [-1, -2, -3, -4, -5],
- ]),
- )
- assert (unpadded == [2, -1, -2, -3, -4, -5]).all()
- Args:
- episode_lens: A list of actual episode lengths.
- data: A 2D np.ndarray with right-side zero-padded rows.
- Returns:
- A 1D np.ndarray resulting from concatenation of the un-padded
- input data along the 0-axis.
- """
- # If data des NOT have time dimension, return right away.
- if len(data.shape) == 1:
- return data
- # Assert we only have B and T dimensions (meaning this function only operates
- # on single-float data, such as value function predictions, advantages, or rewards).
- assert len(data.shape) == 2
- new_data = []
- row_idx = 0
- T = data.shape[1]
- for len_ in episode_lens:
- # Calculate how many full rows this array occupies and how many elements are
- # in the last, potentially partial row.
- num_rows, col_idx = divmod(len_, T)
- # If the array spans multiple full rows, fully include these rows.
- for i in range(num_rows):
- new_data.append(data[row_idx])
- row_idx += 1
- # If there are elements in the last, potentially partial row, add this
- # partial row as well.
- if col_idx > 0:
- new_data.append(data[row_idx, :col_idx])
- # Move to the next row for the next array (skip the zero-padding zone).
- row_idx += 1
- return np.concatenate(new_data)
|