| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683 |
- """RNN utils for RLlib.
- The main trick here is that we add the time dimension at the last moment.
- The non-LSTM layers of the model see their inputs as one flat batch. Before
- the LSTM cell, we reshape the input to add the expected time dimension. During
- postprocessing, we dynamically pad the experience batches so that this
- reshaping is possible.
- Note that this padding strategy only works out if we assume zero inputs don't
- meaningfully affect the loss function. This happens to be true for all the
- current algorithms: https://github.com/ray-project/ray/issues/2992
- """
- import functools
- import logging
- from typing import List, Optional
- import numpy as np
- import tree # pip install dm_tree
- from ray.rllib.policy.sample_batch import SampleBatch
- from ray.rllib.utils.annotations import OldAPIStack
- from ray.rllib.utils.debug import summarize
- from ray.rllib.utils.framework import try_import_tf, try_import_torch
- from ray.rllib.utils.typing import SampleBatchType, TensorType, ViewRequirementsDict
- from ray.util import log_once
# Optional deep-learning framework handles, resolved once at import time via
# RLlib's `try_import_*` helpers (NOTE(review): presumably these are
# placeholders when the framework isn't installed — confirm in
# ray.rllib.utils.framework). `tf` and `torch` are used by the tf/torch
# branches of the functions below.
tf1, tf, tfv = try_import_tf()
torch, _ = try_import_torch()
# Module-level logger used for one-time info/warning messages below.
logger = logging.getLogger(__name__)
@OldAPIStack
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
    _enable_new_api_stack: bool = False,
    padding: str = "zero",
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    `batch` is modified in place: feature columns are replaced by their padded
    versions, state-in columns by one initial state per sequence, and
    `SampleBatch.SEQ_LENS` is (re)written. `batch.zero_padded` is set to True
    up front (even if no padding turns out to be necessary), so calling this
    function a second time on the same batch is a no-op.

    Args:
        batch: The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len: The max. sequence length to use for chopping.
        shuffle: Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req: The int by which the batch dimension
            must be dividable.
        feature_keys: An optional list of keys to apply sequence-chopping
            to. If None, use all keys in batch that are not
            "state_in/out_"-type keys (on the new API stack, "state_out"
            keys are chopped as features as well).
        view_requirements: An optional Policy ViewRequirements dict to
            be able to infer whether e.g. dynamic max'ing should be
            applied over the seq_lens.
        _enable_new_api_stack: This is a temporary flag to enable the new RLModule API.
            After a complete rollout of the new API, this flag will be removed.
        padding: Padding type to use. Either "zero" or "last". Zero padding
            will pad with zeros, last padding will pad with the last value.
    """
    # If already zero-padded, skip.
    if batch.zero_padded:
        return

    batch.zero_padded = True

    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0
        )
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if _enable_new_api_stack and ("state_in" in batch or "state_out" in batch):
        # TODO (Kourosh): This is a temporary fix to enable the new RLModule API.
        # We should think of a more elegant solution once we have confirmed that other
        # parts of the API are stable and user-friendly.
        seq_lens = batch.get(SampleBatch.SEQ_LENS)

        # state_in is a nested dict of tensors of states. We need to retrieve the
        # length of the innermost tensor (which should already be the same as the
        # length of other tensors) and compare it to len(seq_lens).
        # The branch condition above also admits batches that only carry
        # "state_out"; treat those as having no state inputs instead of
        # raising a KeyError.
        state_ins = tree.flatten(batch["state_in"]) if "state_in" in batch else []
        if state_ins:
            assert all(
                len(state_in) == len(state_ins[0]) for state_in in state_ins
            ), "All state_in tensors should have the same batch_dim size."

            # If the batch dim of the states equals the number of sequences, the
            # states are already one-per-sequence. As in the old-stack branch
            # below, only compare when seq_lens are actually present.
            if seq_lens is not None and len(state_ins[0]) == len(seq_lens):
                states_already_reduced_to_init = True

            # TODO (Kourosh): What is the use-case of DynamicMax functionality?
            dynamic_max = True
        else:
            dynamic_max = False

    elif not _enable_new_api_stack and (
        "state_in_0" in batch or "state_out_0" in batch
    ):
        # Check, whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk. Guard against
        # batches that only carry "state_out_0" (admitted by the condition above).
        if (
            batch.get(SampleBatch.SEQ_LENS) is not None
            and "state_in_0" in batch
            and len(batch["state_in_0"]) == len(batch[SampleBatch.SEQ_LENS])
        ):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements and view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False

    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
        batch.max_seq_len = max_seq_len

    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, _ in batch.items():
        if k.startswith("state_in"):
            state_keys.append(k)
        elif (
            not feature_keys
            # The new API stack chops "state_out" columns like any feature.
            and (_enable_new_api_stack or not k.startswith("state_out"))
            and k not in [SampleBatch.SEQ_LENS]
        ):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        feature_columns=[batch[k] for k in feature_keys_],
        state_columns=[batch[k] for k in state_keys],
        episode_ids=batch.get(SampleBatch.EPS_ID),
        unroll_ids=batch.get(SampleBatch.UNROLL_ID),
        agent_indices=batch.get(SampleBatch.AGENT_INDEX),
        seq_lens=batch.get(SampleBatch.SEQ_LENS),
        max_seq_len=max_seq_len,
        dynamic_max=dynamic_max,
        states_already_reduced_to_init=states_already_reduced_to_init,
        shuffle=shuffle,
        handle_nested_data=True,
        padding=padding,
        pad_infos_with_empty_dicts=_enable_new_api_stack,
    )
    # `handle_nested_data=True` returns one list of flattened leaves per
    # feature column; restore each column's original (nested) structure.
    for i, k in enumerate(feature_keys_):
        batch[k] = tree.unflatten_as(batch[k], feature_sequences[i])
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch[SampleBatch.SEQ_LENS] = np.array(seq_lens)
    if dynamic_max:
        batch.max_seq_len = max(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info(
            "Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
                summarize(
                    {
                        "features": feature_sequences,
                        "initial_states": initial_states,
                        "seq_lens": seq_lens,
                        "max_seq_len": max_seq_len,
                    }
                )
            )
        )
@OldAPIStack
def add_time_dimension(
    padded_inputs: TensorType,
    *,
    seq_lens: TensorType,
    framework: str = "tf",
    time_major: bool = False,
):
    """Adds a time dimension to padded inputs.

    Args:
        padded_inputs: a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        seq_lens: A 1D tensor of sequence lengths, denoting the non-padded length
            in timesteps of each rollout in the batch. Only its length (the
            number of sequences) is used.
        framework: The framework string ("tf2", "tf", "torch" or "np").
        time_major: Whether data should be returned in time-major (TxB)
            format or not (BxT). Not supported for tf.

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """
    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_inputs = tf.convert_to_tensor(padded_inputs)
        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = tf.shape(seq_lens)[0]
        time_size = padded_batch_size // new_batch_size
        new_shape = tf.concat(
            [
                tf.expand_dims(new_batch_size, axis=0),
                tf.expand_dims(time_size, axis=0),
                tf.shape(padded_inputs)[1:],
            ],
            axis=0,
        )
        return tf.reshape(padded_inputs, new_shape)
    elif framework == "torch":
        padded_inputs = torch.as_tensor(padded_inputs)
        padded_batch_size = padded_inputs.shape[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        # `len()` works for tensors, arrays and plain lists of seq lens alike.
        new_batch_size = len(seq_lens)
        time_size = padded_batch_size // new_batch_size
        batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:]
        # `reshape` (unlike `view`) also accepts non-contiguous inputs; for
        # contiguous inputs it returns a view just like `view` would.
        padded_outputs = padded_inputs.reshape(batch_major_shape)
        if time_major:
            # Swap the batch and time dimensions
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
    else:
        assert framework == "np", "Unknown framework: {}".format(framework)
        padded_inputs = np.asarray(padded_inputs)
        padded_batch_size = padded_inputs.shape[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = len(seq_lens)
        time_size = padded_batch_size // new_batch_size
        batch_major_shape = (new_batch_size, time_size) + padded_inputs.shape[1:]
        padded_outputs = padded_inputs.reshape(batch_major_shape)
        if time_major:
            # Swap the batch and time dimensions. NOTE: numpy's
            # `ndarray.transpose(0, 1)` treats (0, 1) as a full axes permutation
            # (a no-op for 2D, an error for >2D), so swap the axes explicitly.
            padded_outputs = np.swapaxes(padded_outputs, 0, 1)
        return padded_outputs
@OldAPIStack
def chop_into_sequences(
    *,
    feature_columns,
    state_columns,
    max_seq_len,
    episode_ids=None,
    unroll_ids=None,
    agent_indices=None,
    dynamic_max=True,
    shuffle=False,
    seq_lens=None,
    states_already_reduced_to_init=False,
    handle_nested_data=False,
    _extra_padding=0,
    padding: str = "zero",
    pad_infos_with_empty_dicts: bool = False,
):
    """Truncate and pad experiences into fixed-length sequences.

    Args:
        feature_columns: List of arrays containing features.
        state_columns: List of arrays containing LSTM state values.
        max_seq_len: Max length of sequences. Sequences longer than
            `max_seq_len` are split into consecutive sub-sequences of at most
            `max_seq_len` timesteps each.
        episode_ids (List[EpisodeID]): List of episode ids for each step.
        unroll_ids (List[UnrollID]): List of identifiers for the sample batch.
            This is used to make sure sequences are cut between sample batches.
        agent_indices (List[AgentID]): List of agent ids for each step. Note
            that this has to be combined with episode_ids for uniqueness.
        dynamic_max: Whether to dynamically shrink the max seq len.
            For example, if max len is 20 and the actual max seq len in the
            data is 7, it will be shrunk to 7.
        shuffle: Whether to shuffle the sequence outputs. Whole sequences are
            permuted; features, initial states and seq_lens stay aligned.
        seq_lens: Optional precomputed sequence lengths. If None or empty,
            they are derived from `episode_ids`, `agent_indices` and
            `unroll_ids`.
        states_already_reduced_to_init: If True, `state_columns` already hold
            exactly one (initial) state per sequence and are used as-is.
        handle_nested_data: If True, assume that the data in
            `feature_columns` could be nested structures (of data).
            If False, assumes that all items in `feature_columns` are
            only np.ndarrays (no nested structured of np.ndarrays).
        _extra_padding: Add extra padding to the end of sequences (only
            applied when `dynamic_max` is True).
        padding: Padding type to use. Either "zero" or "last". Zero padding
            will pad with zeros, last padding will pad with the last value.
        pad_infos_with_empty_dicts: If True, will zero-pad INFOs with empty
            dicts (instead of None). Used by the new API stack in the meantime,
            however, as soon as the new ConnectorV2 API will be activated (as
            part of the new API stack), we will no longer use this utility function
            anyway.

    Returns:
        f_pad: Padded feature columns. These will be of shape
            [NUM_SEQUENCES * MAX_SEQ_LEN, ...].
        s_init: Initial states for each sequence, of shape
            [NUM_SEQUENCES, ...].
        seq_lens: List of sequence lengths, of shape [NUM_SEQUENCES].

    .. testcode::
        :skipif: True

        from ray.rllib.policy.rnn_sequencing import chop_into_sequences
        f_pad, s_init, seq_lens = chop_into_sequences(
            episode_ids=[1, 1, 5, 5, 5, 5],
            unroll_ids=[4, 4, 4, 4, 4, 4],
            agent_indices=[0, 0, 0, 0, 0, 0],
            feature_columns=[[4, 4, 8, 8, 8, 8],
                             [1, 1, 0, 1, 1, 0]],
            state_columns=[[4, 5, 4, 5, 5, 5]],
            max_seq_len=3)
        print(f_pad)
        print(s_init)
        print(seq_lens)

    .. testoutput::

        [[4, 4, 0, 8, 8, 8, 8, 0, 0],
         [1, 1, 0, 0, 1, 1, 0, 0, 0]]
        [[4, 4, 5]]
        [2, 3, 1]
    """
    if seq_lens is None or len(seq_lens) == 0:
        prev_id = None
        seq_lens = []
        seq_len = 0
        # Combine episode id, agent index and unroll id (shifted into the
        # upper 32 bits) into one id per row; a sequence ends when it changes.
        unique_ids = np.add(
            np.add(episode_ids, agent_indices),
            np.array(unroll_ids, dtype=np.int64) << 32,
        )
        for uid in unique_ids:
            if (prev_id is not None and uid != prev_id) or seq_len >= max_seq_len:
                seq_lens.append(seq_len)
                seq_len = 0
            seq_len += 1
            prev_id = uid
        if seq_len:
            seq_lens.append(seq_len)
        seq_lens = np.array(seq_lens, dtype=np.int32)

    # Dynamically shrink max len as needed to optimize memory usage.
    # Guard against an empty batch, where `max()` would raise.
    if dynamic_max:
        max_seq_len = (max(seq_lens) if len(seq_lens) > 0 else 0) + _extra_padding

    length = len(seq_lens) * max_seq_len

    # One list per feature column, holding one padded entry per flattened leaf.
    feature_sequences = []
    for col in feature_columns:
        if isinstance(col, list):
            col = np.array(col)
        feature_sequences.append([])

        for f in tree.flatten(col):
            # Save unnecessary copy.
            if not isinstance(f, np.ndarray):
                f = np.array(f)

            # New stack behavior (temporarily until we move to ConnectorV2 API, where
            # this (admittedly convoluted) function will no longer be used at all).
            if (
                f.dtype == object
                and pad_infos_with_empty_dicts
                and f.size > 0
                and isinstance(f[0], dict)
            ):
                f_pad = [{} for _ in range(length)]
            # Old stack behavior: Pad INFOs with None.
            elif f.dtype == object or f.dtype.type is np.str_:
                f_pad = [None] * length
            # Pad everything else with zeros.
            else:
                # Make sure type doesn't change.
                f_pad = np.zeros((length,) + np.shape(f)[1:], dtype=f.dtype)

            # Copy each sequence into its own `max_seq_len`-sized slot.
            seq_base = 0
            i = 0
            for len_ in seq_lens:
                for seq_offset in range(len_):
                    f_pad[seq_base + seq_offset] = f[i]
                    i += 1
                if padding == "last":
                    for seq_offset in range(len_, max_seq_len):
                        f_pad[seq_base + seq_offset] = f[i - 1]
                seq_base += max_seq_len

            assert i == len(f), f
            feature_sequences[-1].append(f_pad)

    if states_already_reduced_to_init:
        initial_states = state_columns
    else:
        initial_states = []
        for state_column in state_columns:
            if isinstance(state_column, list):
                state_column = np.array(state_column)
            initial_state_flat = []
            # state_column may have a nested structure (e.g. LSTM state).
            for s in tree.flatten(state_column):
                # Skip unnecessary copy.
                if not isinstance(s, np.ndarray):
                    s = np.array(s)
                # Keep only the state at the first timestep of each sequence.
                s_init = []
                i = 0
                for len_ in seq_lens:
                    s_init.append(s[i])
                    i += len_
                initial_state_flat.append(np.array(s_init))
            initial_states.append(tree.unflatten_as(state_column, initial_state_flat))

    if shuffle:
        seq_lens = np.asarray(seq_lens)
        num_seqs = len(seq_lens)
        permutation = np.random.permutation(num_seqs)
        # `feature_sequences` is a list (per feature column) of lists (per
        # flattened leaf); permute each leaf in place within its own slot so
        # the per-column structure is preserved.
        for leaves in feature_sequences:
            for j, f in enumerate(leaves):
                if isinstance(f, np.ndarray):
                    by_seq = np.reshape(f, (num_seqs, max_seq_len) + f.shape[1:])
                    leaves[j] = np.reshape(by_seq[permutation], f.shape)
                else:
                    # Object columns (e.g. infos) were padded into plain lists.
                    leaves[j] = [
                        f[p * max_seq_len + t]
                        for p in permutation
                        for t in range(max_seq_len)
                    ]
        # Build a new list (a caller-provided `state_columns` list is returned
        # as-is when `states_already_reduced_to_init`, so don't mutate it) and
        # permute every leaf of possibly nested state structs.
        initial_states = [
            tree.map_structure(lambda x: x[permutation], s) for s in initial_states
        ]
        seq_lens = seq_lens[permutation]

    # Classic behavior: Don't assume data in feature_columns are nested
    # structs. Don't return them as flattened lists, but as is (index 0).
    if not handle_nested_data:
        feature_sequences = [f[0] for f in feature_sequences]

    return feature_sequences, initial_states, seq_lens
@OldAPIStack
def timeslice_along_seq_lens_with_overlap(
    sample_batch: SampleBatchType,
    seq_lens: Optional[List[int]] = None,
    zero_pad_max_seq_len: int = 0,
    pre_overlap: int = 0,
    zero_init_states: bool = True,
) -> List["SampleBatch"]:
    """Slices batch along `seq_lens` (each seq-len item produces one batch).

    Args:
        sample_batch: The SampleBatch to timeslice.
        seq_lens (Optional[List[int]]): An optional list of seq_lens to slice
            at. If None, use `sample_batch[SampleBatch.SEQ_LENS]`. If that is
            missing as well, slice into chunks of
            `zero_pad_max_seq_len - pre_overlap` timesteps.
        zero_pad_max_seq_len: If >0, already zero-pad the resulting
            slices up to this length. NOTE: This max-len will include the
            additional timesteps gained via setting pre_overlap (see Example).
        pre_overlap: If >0, will overlap each two consecutive slices by
            this many timesteps (toward the left side). This will cause
            zero-padding at the very beginning of the batch.
        zero_init_states: Whether initial states should always be
            zero'd. If False, will use the state_outs of the batch to
            populate state_in values (slices starting at the very first
            timestep are always zero-initialized, as no prior state_out
            exists).

    Returns:
        List[SampleBatch]: The list of (new) SampleBatches.

    Raises:
        ValueError: If no seq_lens are available and
            `zero_pad_max_seq_len - pre_overlap` is not positive.

    Examples:
        assert seq_lens == [5, 5, 2]
        assert sample_batch.count == 12
        # self = 0 1 2 3 4 | 5 6 7 8 9 | 10 11 <- timesteps
        slices = timeslice_along_seq_lens_with_overlap(
            sample_batch=sample_batch,
            zero_pad_max_seq_len=10,
            pre_overlap=3)
        # Z = zero padding (at beginning or end).
        #             |pre (3)|     seq     | max-seq-len (up to 10)
        # slices[0] = | Z Z Z |  0  1 2 3 4 | Z Z
        # slices[1] = | 2 3 4 |  5  6 7 8 9 | Z Z
        # slices[2] = | 7 8 9 | 10 11 Z Z Z | Z Z
        # Note that `zero_pad_max_seq_len=10` includes the 3 pre-overlaps
        #  count (makes sure each slice has exactly length 10).
    """
    if seq_lens is None:
        seq_lens = sample_batch.get(SampleBatch.SEQ_LENS)
    else:
        if sample_batch.get(SampleBatch.SEQ_LENS) is not None and log_once(
            "overriding_sequencing_information"
        ):
            logger.warning(
                "Found sequencing information in a batch that will be "
                "ignored when slicing. Ignore this warning if you know "
                "what you are doing."
            )

    if seq_lens is None:
        max_seq_len = zero_pad_max_seq_len - pre_overlap
        if max_seq_len <= 0:
            raise ValueError(
                "Cannot derive `seq_lens` for time slicing: "
                "`zero_pad_max_seq_len - pre_overlap` must be > 0, but is "
                f"{zero_pad_max_seq_len} - {pre_overlap} = {max_seq_len}."
            )
        if log_once("no_sequence_lengths_available_for_time_slicing"):
            logger.warning(
                "Trying to slice a batch along sequences without "
                "sequence lengths being provided in the batch. Batch will "
                "be sliced into slices of size "
                "{} = {} - {} = zero_pad_max_seq_len - pre_overlap.".format(
                    max_seq_len, zero_pad_max_seq_len, pre_overlap
                )
            )
        num_seq_lens, last_seq_len = divmod(len(sample_batch), max_seq_len)
        # Each full slice holds `max_seq_len` real timesteps (the pre-overlap
        # brings it up to `zero_pad_max_seq_len`), so the seq_lens must sum to
        # exactly `len(sample_batch)`.
        seq_lens = [max_seq_len] * num_seq_lens + (
            [last_seq_len] if last_seq_len else []
        )

    assert (
        seq_lens is not None and len(seq_lens) > 0
    ), "Cannot timeslice along `seq_lens` when `seq_lens` is empty or None!"

    # Generate n slices based on seq_lens: (overlap start, slice start, end).
    start = 0
    slices = []
    for seq_len in seq_lens:
        pre_begin = start - pre_overlap
        slice_begin = start
        end = start + seq_len
        slices.append((pre_begin, slice_begin, end))
        start += seq_len

    timeslices = []
    for begin, slice_begin, end in slices:
        zero_length = None
        data_begin = 0
        zero_init_states_ = zero_init_states
        if begin < 0:
            # The pre-overlap reaches before the batch start: left-pad with zeros.
            # NOTE(review): if 0 < slice_begin < pre_overlap, the real timesteps
            # [0, slice_begin) are replaced by zeros here; confirm intended.
            zero_length = pre_overlap
            data_begin = slice_begin
            zero_init_states_ = True
        else:
            eps_ids = sample_batch[SampleBatch.EPS_ID][begin:end]
            is_last_episode_ids = eps_ids == eps_ids[-1]
            if not is_last_episode_ids[0]:
                # Replace timesteps of earlier episodes by zero padding.
                zero_length = int(sum(1.0 - is_last_episode_ids))
                data_begin = begin + zero_length
                zero_init_states_ = True
            elif begin == 0:
                # No timestep precedes the batch start, so there is no
                # state_out to carry over (`state_out[-1:0]` would be empty).
                zero_init_states_ = True

        if zero_length is not None:
            data = {
                k: np.concatenate(
                    [
                        np.zeros(shape=(zero_length,) + v.shape[1:], dtype=v.dtype),
                        v[data_begin:end],
                    ]
                )
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }
        else:
            data = {
                k: v[begin:end]
                for k, v in sample_batch.items()
                if k != SampleBatch.SEQ_LENS
            }

        if zero_init_states_:
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                data[key] = np.zeros_like(sample_batch[key][0:1])
                # Del state_out_n from data if exists.
                data.pop("state_out_{}".format(i), None)
                i += 1
                key = "state_in_{}".format(i)
        else:
            # TODO: This will not work with attention nets as their state_outs are
            # not compatible with state_ins.
            i = 0
            key = "state_in_{}".format(i)
            while key in data:
                # The state_out of the step right before the slice is the
                # state_in of the slice's first step.
                data[key] = sample_batch["state_out_{}".format(i)][begin - 1 : begin]
                del data["state_out_{}".format(i)]
                i += 1
                key = "state_in_{}".format(i)

        timeslices.append(SampleBatch(data, seq_lens=[end - begin]))

    # Zero-pad each slice if necessary.
    if zero_pad_max_seq_len > 0:
        for ts in timeslices:
            ts.right_zero_pad(max_seq_len=zero_pad_max_seq_len, exclude_states=True)

    return timeslices
@OldAPIStack
def get_fold_unfold_fns(b_dim: int, t_dim: int, framework: str):
    """Produces two functions to fold/unfold any Tensors in a struct.

    Args:
        b_dim: The batch dimension to use for folding.
        t_dim: The time dimension to use for folding.
        framework: The framework to use for folding. One of "tf2" (or "tf")
            or "torch".

    Returns:
        fold: A function that takes a struct of tensors and reshapes
            them to have a first dimension of `b_dim * t_dim`.
        unfold: A function that takes a struct of tensors and reshapes
            them to have a first dimension of `b_dim` and a second dimension
            of `t_dim`.
        `None` leaves of the struct are passed through unchanged by both.

    Raises:
        ValueError: If `framework` is not supported.
    """
    # Exact match: the previous `framework in "tf2"` was a substring test that
    # also accepted e.g. "", "t" or "2".
    if framework in ("tf2", "tf"):
        # TensorFlow traced eager complains if we don't convert these to tensors here
        b_dim = tf.convert_to_tensor(b_dim)
        t_dim = tf.convert_to_tensor(t_dim)

        def fold_mapping(item):
            if item is None:
                # Pass `None` leaves of the struct through unchanged.
                return item
            item = tf.convert_to_tensor(item)
            shape = tf.shape(item)
            other_dims = shape[2:]
            return tf.reshape(item, tf.concat([[b_dim * t_dim], other_dims], axis=0))

        def unfold_mapping(item):
            if item is None:
                return item
            item = tf.convert_to_tensor(item)
            # Use the dynamic shape (as `fold_mapping` does) so unknown static
            # dims (None) don't break building the target shape.
            other_dims = tf.shape(item)[1:]
            return tf.reshape(item, tf.concat([[b_dim], [t_dim], other_dims], axis=0))

    elif framework == "torch":

        def fold_mapping(item):
            if item is None:
                # Torch has no representation for `None`, so we return None
                return item
            item = torch.as_tensor(item)
            size = list(item.size())
            current_b_dim, current_t_dim = list(size[:2])

            assert (b_dim, t_dim) == (current_b_dim, current_t_dim), (
                "All tensors in the struct must have the same batch and time "
                "dimensions. Got {} and {}.".format(
                    (b_dim, t_dim), (current_b_dim, current_t_dim)
                )
            )

            other_dims = size[2:]
            return item.reshape([b_dim * t_dim] + other_dims)

        def unfold_mapping(item):
            if item is None:
                return item
            item = torch.as_tensor(item)
            size = list(item.size())
            current_b_dim = size[0]
            other_dims = size[1:]
            assert current_b_dim == b_dim * t_dim, (
                "The first dimension of the tensor must be equal to the product of "
                "the desired batch and time dimensions. Got {} and {}.".format(
                    current_b_dim, b_dim * t_dim
                )
            )
            return item.reshape([b_dim, t_dim] + other_dims)

    else:
        raise ValueError(f"framework {framework} not implemented!")

    return functools.partial(tree.map_structure, fold_mapping), functools.partial(
        tree.map_structure, unfold_mapping
    )
|