import math
from typing import Iterator, List, Optional

from ray.data import DataIterator
from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
from ray.rllib.utils import unflatten_dict
from ray.rllib.utils.annotations import DeveloperAPI
from ray.rllib.utils.typing import DeviceType, EpisodeType


@DeveloperAPI
class MiniBatchIteratorBase:
    """The base class for all minibatch iterators."""

    def __init__(
        self,
        batch: MultiAgentBatch,
        *,
        num_epochs: int = 1,
        shuffle_batch_per_epoch: bool = True,
        minibatch_size: int,
        num_total_minibatches: int = 0,
    ) -> None:
        """Initializes a MiniBatchIteratorBase instance.

        The base class stores nothing; subclasses keep whatever state they need.

        Args:
            batch: The input multi-agent batch.
            num_epochs: The number of complete passes over the entire train batch.
                Each pass might be further split into n minibatches (if
                `minibatch_size` provided). The train batch is generated from the
                given `episodes` through the Learner connector pipeline.
            shuffle_batch_per_epoch: Whether to shuffle the train batch between
                epochs (honored by subclasses that support it).
            minibatch_size: The size of minibatches to use to further split the
                train batch into per epoch. The train batch is generated from the
                given `episodes` through the Learner connector pipeline.
            num_total_minibatches: The total number of minibatches to loop through
                (over all `num_epochs` epochs). It's only required to set this to
                != 0 in multi-agent + multi-GPU situations, in which the
                MultiAgentEpisodes themselves are roughly sharded equally, however,
                they might contain SingleAgentEpisodes with very lopsided length
                distributions. Thus, without this fixed, pre-computed value, one
                Learner might go through a different number of minibatch passes
                than others causing a deadlock.
        """
        pass


@DeveloperAPI
class MiniBatchCyclicIterator(MiniBatchIteratorBase):
    """This implements a simple multi-agent minibatch iterator.

    This iterator will split the input multi-agent batch into minibatches where the
    size of batch for each module_id (aka policy_id) is equal to minibatch_size. If
    the input batch is smaller than minibatch_size, then the iterator will cycle
    through the batch until it has covered `num_epochs` epochs.
    """

    def __init__(
        self,
        batch: MultiAgentBatch,
        *,
        num_epochs: int = 1,
        minibatch_size: int,
        shuffle_batch_per_epoch: bool = True,
        num_total_minibatches: int = 0,
    ) -> None:
        """Initializes a MiniBatchCyclicIterator instance.

        See `MiniBatchIteratorBase.__init__` for the meaning of the arguments.
        """
        super().__init__(
            batch,
            num_epochs=num_epochs,
            minibatch_size=minibatch_size,
            shuffle_batch_per_epoch=shuffle_batch_per_epoch,
            num_total_minibatches=num_total_minibatches,
        )

        self._batch = batch
        self._minibatch_size = minibatch_size
        self._num_epochs = num_epochs
        self._shuffle_batch_per_epoch = shuffle_batch_per_epoch

        # mapping from module_id to the start index of the batch
        self._start = {mid: 0 for mid in batch.policy_batches.keys()}
        # mapping from module_id to the number of epochs covered for each module_id
        self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches.keys()}

        self._minibatch_count = 0
        self._num_total_minibatches = num_total_minibatches

    def __iter__(self) -> Iterator[MultiAgentBatch]:
        """Yields `MultiAgentBatch` minibatches, cycling through the batch.

        If `num_total_minibatches` is 0, iteration continues until every module's
        batch has been covered at least `num_epochs` times. Otherwise, it continues
        until `num_total_minibatches` minibatches have been yielded. The progress
        counters live on `self`, so they carry over between `__iter__` calls.

        Raises:
            ValueError: If the batch of any module_id is empty.
        """
        while (
            # Make sure each item in the total batch gets at least iterated over
            # `self._num_epochs` times.
            (
                self._num_total_minibatches == 0
                and min(self._num_covered_epochs.values()) < self._num_epochs
            )
            # Make sure we reach at least the given minimum number of mini-batches.
            or (
                self._num_total_minibatches > 0
                and self._minibatch_count < self._num_total_minibatches
            )
        ):
            minibatch = {}
            for module_id, module_batch in self._batch.policy_batches.items():

                if len(module_batch) == 0:
                    raise ValueError(
                        f"The batch for module_id {module_id} is empty! "
                        "This will create an infinite loop because we need to cover "
                        "the same number of samples for each module_id."
                    )
                s = self._start[module_id]  # start

                # TODO (sven): Fix this bug for LSTMs:
                #  In an RNN-setting, the Learner connector already has zero-padded
                #  and added a timerank to the batch. Thus, n_step would still be
                #  based on the BxT dimension, rather than the new B dimension
                #  (excluding T), which then leads to minibatches way too large.
                #  However, changing this already would break APPO/IMPALA w/o LSTMs
                #  as these setups require sequencing, BUT their batches are not yet
                #  time-ranked (this is done only in their loss functions via the
                #  `make_time_major` utility).
                n_steps = self._minibatch_size

                samples_to_concat = []

                # `get_len` returns the length of a batch along the dimension we
                # slice in. If we are not slicing the batch in the batch dimension
                # B, that is simply `len(batch)`; otherwise it is the length of the
                # batch's SEQ_LENS list.
                if module_batch._slice_seq_lens_in_B:
                    assert module_batch.get(SampleBatch.SEQ_LENS) is not None, (
                        "MiniBatchCyclicIterator requires SampleBatch.SEQ_LENS "
                        "to be present in the batch for slicing a batch in the batch "
                        "dimension B."
                    )

                    def get_len(b):
                        return len(b[SampleBatch.SEQ_LENS])

                    # Convert `minibatch_size` into a number of B-rows
                    # (sequences), proportionally to `len(module_batch)`.
                    n_steps = int(
                        get_len(module_batch)
                        * (self._minibatch_size / len(module_batch))
                    )

                else:

                    def get_len(b):
                        return len(b)

                # Cycle through the batch until we have enough samples.
                while s + n_steps >= get_len(module_batch):
                    sample = module_batch[s:]
                    samples_to_concat.append(sample)
                    len_sample = get_len(sample)
                    assert len_sample > 0, "Length of a sample must be > 0!"
                    n_steps -= len_sample
                    s = 0
                    self._num_covered_epochs[module_id] += 1
                    # Shuffle the individual single-agent batch, if required.
                    # This happens every time we wrap around to the start of this
                    # module's batch (i.e. once per covered epoch), so that each
                    # epoch goes through a different set of minibatches.
                    if self._shuffle_batch_per_epoch:
                        module_batch.shuffle()

                e = s + n_steps  # end
                if e > s:
                    samples_to_concat.append(module_batch[s:e])

                # concatenate all the samples, we should have minibatch_size of sample
                # after this step
                minibatch[module_id] = concat_samples(samples_to_concat)
                # The loop above guarantees `e < get_len(module_batch)`; the next
                # minibatch for this module continues from here.
                self._start[module_id] = e

            # Note (Kourosh): env_steps is the total number of env_steps that this
            # multi-agent batch is covering. It should be simply inherited from the
            # original multi-agent batch.
            minibatch = MultiAgentBatch(minibatch, len(self._batch))
            yield minibatch

            self._minibatch_count += 1


class MiniBatchDummyIterator(MiniBatchIteratorBase):
    """Trivial iterator that yields the complete input batch once per iteration."""

    def __init__(self, batch: MultiAgentBatch, **kwargs):
        super().__init__(batch, **kwargs)
        self._batch = batch

    def __iter__(self) -> Iterator[MultiAgentBatch]:
        yield self._batch


@DeveloperAPI
class MiniBatchRayDataIterator:
    """Iterates over a `ray.data.DataIterator` in torch minibatches.

    Each yielded item is a `MultiAgentBatch` built from one (unflattened) batch
    returned by `iter_torch_batches`. The epoch iterator is kept on the instance,
    so a new `__iter__` call resumes where the previous one stopped.
    """

    def __init__(
        self,
        *,
        iterator: DataIterator,
        device: DeviceType,
        minibatch_size: int,
        num_iters: Optional[int],
        **kwargs,
    ):
        """Initializes a MiniBatchRayDataIterator instance.

        Args:
            iterator: A `ray.data.DataIterator` over the training data.
            device: The device to pass to `iter_torch_batches`.
            minibatch_size: The `batch_size` to pass to `iter_torch_batches`.
            num_iters: The number of minibatches to yield per `__iter__` call. If
                None, one pass over the remainder of the current epoch is run.
            **kwargs: Forwarded to `iter_torch_batches` (except `return_state`).
        """
        # A `ray.data.DataIterator` that can iterate in different ways over the data.
        self._iterator = iterator
        # Note, in multi-learner settings the `return_state` is in `kwargs`.
        self._kwargs = {k: v for k, v in kwargs.items() if k != "return_state"}

        # Holds a batched_iterable over the dataset.
        self._batched_iterable = self._iterator.iter_torch_batches(
            batch_size=minibatch_size,
            device=device,
            **self._kwargs,
        )
        # Create an iterator that can be stopped and resumed during an epoch.
        self._epoch_iterator = iter(self._batched_iterable)
        self._num_iters = num_iters

    def __iter__(self) -> Iterator[MultiAgentBatch]:
        iteration = 0
        # The stored epoch iterator may have been (partially or fully) consumed by
        # a previous call, so an empty pass over it says nothing about the data.
        epoch_is_fresh = False
        while self._num_iters is None or iteration < self._num_iters:
            epoch_had_batches = False
            for batch in self._epoch_iterator:
                # Update the iteration counter.
                iteration += 1
                epoch_had_batches = True

                batch = unflatten_dict(batch)
                batch = MultiAgentBatch(
                    {
                        module_id: SampleBatch(module_data)
                        for module_id, module_data in batch.items()
                    },
                    env_steps=sum(
                        len(next(iter(module_data.values())))
                        for module_data in batch.values()
                    ),
                )

                yield batch

                # If `num_iters` is reached break and return.
                if self._num_iters and iteration == self._num_iters:
                    break
            else:
                if epoch_is_fresh and not epoch_had_batches:
                    # A freshly created epoch iterator produced no data at all:
                    # the dataset is empty and looping on would never terminate.
                    break
                # Reinstantiate a new epoch iterator.
                self._epoch_iterator = iter(self._batched_iterable)
                epoch_is_fresh = True
                # If a full epoch on the data should be run, stop.
                if not self._num_iters:
                    # Exit the loop.
                    break


@DeveloperAPI
class ShardBatchIterator:
    """Iterator for sharding batch into num_shards batches.

    Args:
        batch: The input multi-agent batch.
        num_shards: The number of shards to split the batch into.

    Yields:
        A MultiAgentBatch of size len(batch) / num_shards.
    """

    def __init__(self, batch: MultiAgentBatch, num_shards: int):
        self._batch = batch
        self._num_shards = num_shards

    def __iter__(self) -> Iterator[MultiAgentBatch]:
        for i in range(self._num_shards):
            # TODO (sven): The following way of sharding a multi-agent batch destroys
            #  the relationship of the different agents' timesteps to each other.
            #  Thus, in case the algorithm requires agent-synchronized data (aka.
            #  "lockstep"), the `ShardBatchIterator` cannot be used.
            batch_to_send = {}
            # Bound up front so that a batch without any policy batches yields
            # empty shards instead of raising an `UnboundLocalError` below.
            batch_size = 0
            for pid, sub_batch in self._batch.policy_batches.items():
                # `math.ceil` returns an int, so these are valid slice indices.
                batch_size = math.ceil(len(sub_batch) / self._num_shards)
                start = batch_size * i
                end = min(start + batch_size, len(sub_batch))
                batch_to_send[pid] = sub_batch[start:end]
            # TODO (Avnish): int(batch_size) ? How should we shard MA batches really?
            new_batch = MultiAgentBatch(batch_to_send, int(batch_size))
            yield new_batch


@DeveloperAPI
class ShardEpisodesIterator:
    """Iterator for sharding a list of Episodes into `num_shards` lists of Episodes."""

    def __init__(
        self,
        episodes: List[EpisodeType],
        num_shards: int,
        len_lookback_buffer: Optional[int] = None,
    ):
        """Initializes a ShardEpisodesIterator instance.

        Args:
            episodes: The input list of Episodes.
            num_shards: The number of shards to split the episodes into.
            len_lookback_buffer: An optional length of a lookback buffer to enforce
                on the returned shards. When splitting an episode, the second piece
                might need a lookback buffer (into the first piece) depending on the
                user's settings.
        """
        self._episodes = sorted(episodes, key=len, reverse=True)
        self._num_shards = num_shards
        self._len_lookback_buffer = len_lookback_buffer
        self._total_length = sum(len(e) for e in episodes)
        # Distribute the total length as evenly as possible; earlier shards get
        # the floor share, later ones absorb the remainder.
        self._target_lengths = [0 for _ in range(self._num_shards)]
        remaining_length = self._total_length
        for s in range(self._num_shards):
            len_ = remaining_length // (num_shards - s)
            self._target_lengths[s] = len_
            remaining_length -= len_

    def __iter__(self) -> Iterator[List[EpisodeType]]:
        """Runs one iteration through this sharder.

        Yields:
            A sub-list of Episodes of size roughly `len(episodes) / num_shards`. The
            yielded sublists might have slightly different total sums of episode
            lengths, in order to not have to drop even a single timestep.
        """
        # Work on a copy: episodes get replaced by their not-yet-assigned remainder
        # below, and mutating `self._episodes` would make a second iteration over
        # this sharder operate on already-sliced episodes.
        episodes = list(self._episodes)
        sublists = [[] for _ in range(self._num_shards)]
        lengths = [0 for _ in range(self._num_shards)]
        episode_index = 0

        while episode_index < len(episodes):
            episode = episodes[episode_index]
            min_index = lengths.index(min(lengths))

            # Add the whole episode if it fits within the target length
            if lengths[min_index] + len(episode) <= self._target_lengths[min_index]:
                sublists[min_index].append(episode)
                lengths[min_index] += len(episode)
                episode_index += 1
            # Otherwise, slice the episode
            else:
                remaining_length = self._target_lengths[min_index] - lengths[min_index]
                if remaining_length > 0:
                    slice_part, remaining_part = (
                        # Note that the first slice will automatically "inherit" the
                        # lookback buffer size of the episode.
                        episode[:remaining_length],
                        # However, the second slice might need a user defined lookback
                        # buffer (into the first slice).
                        episode.slice(
                            slice(remaining_length, None),
                            len_lookback_buffer=self._len_lookback_buffer,
                        ),
                    )
                    sublists[min_index].append(slice_part)
                    lengths[min_index] += len(slice_part)
                    episodes[episode_index] = remaining_part
                else:
                    # The least-filled shard is already at its target; put the
                    # whole episode there rather than dropping any timesteps.
                    assert remaining_length == 0
                    sublists[min_index].append(episode)
                    episode_index += 1

        for sublist in sublists:
            yield sublist


@DeveloperAPI
class ShardObjectRefIterator:
    """Iterator for sharding a list of ray ObjectRefs into num_shards sub-lists.

    Args:
        object_refs: The input list of ray ObjectRefs.
        num_shards: The number of shards to split the references into.

    Yields:
        A sub-list of ray ObjectRefs with lengths as equal as possible.
    """

    def __init__(self, object_refs, num_shards: int):
        self._object_refs = object_refs
        self._num_shards = num_shards

    def __iter__(self):
        # Every shard gets `sublist_size` refs; the first `remaining_elements`
        # shards get one extra.
        sublist_size, remaining_elements = divmod(
            len(self._object_refs), self._num_shards
        )

        start = 0
        for i in range(self._num_shards):
            # Determine the end index for the current sublist
            end = start + sublist_size + (1 if i < remaining_elements else 0)
            # Append the sublist to the result
            yield self._object_refs[start:end]
            # Update the start index for the next sublist
            start = end