| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391 |
- import math
- from typing import List, Optional
- from ray.data import DataIterator
- from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
- from ray.rllib.utils import unflatten_dict
- from ray.rllib.utils.annotations import DeveloperAPI
- from ray.rllib.utils.typing import DeviceType, EpisodeType
@DeveloperAPI
class MiniBatchIteratorBase:
    """Common constructor interface shared by the minibatch iterators."""

    def __init__(
        self,
        batch: MultiAgentBatch,
        *,
        num_epochs: int = 1,
        shuffle_batch_per_epoch: bool = True,
        minibatch_size: int,
        num_total_minibatches: int = 0,
    ) -> None:
        """Declares the arguments a minibatch iterator accepts.

        The base class neither stores nor acts on any of these arguments;
        subclasses decide what to do with them.

        Args:
            batch: The multi-agent batch to iterate over.
            num_epochs: Number of complete passes over the entire train batch.
                Each pass may be split further into several minibatches (if
                `minibatch_size` is provided). The train batch is generated
                from the given `episodes` through the Learner connector
                pipeline.
            shuffle_batch_per_epoch: Whether the batch should be shuffled
                again for each epoch. How (and whether) this is applied is up
                to the subclass.
            minibatch_size: Size of the minibatches into which the train batch
                is split within each epoch.
            num_total_minibatches: Total number of minibatches to loop through
                (over all `num_epochs` epochs). Only needs to be != 0 in
                multi-agent + multi-GPU setups: there, the MultiAgentEpisodes
                are sharded roughly equally, but they may contain
                SingleAgentEpisodes with very lopsided length distributions.
                Without this fixed, pre-computed value, one Learner might run
                a different number of minibatch passes than the others, which
                causes a deadlock.
        """
@DeveloperAPI
class MiniBatchCyclicIterator(MiniBatchIteratorBase):
    """A simple multi-agent minibatch iterator that cycles through the batch.

    The input multi-agent batch is split into minibatches in which the batch of
    each module_id (aka policy_id) has size `minibatch_size`. Whenever the end
    of a module's batch is reached, slicing wraps around to its beginning, so a
    batch smaller than `minibatch_size` is simply cycled through. Iteration
    stops once every module has been covered `num_epochs` times or, if
    `num_total_minibatches` > 0, once that many minibatches have been yielded.
    """

    def __init__(
        self,
        batch: MultiAgentBatch,
        *,
        num_epochs: int = 1,
        minibatch_size: int,
        shuffle_batch_per_epoch: bool = True,
        num_total_minibatches: int = 0,
    ) -> None:
        """Initializes a MiniBatchCyclicIterator instance.

        Args:
            batch: The multi-agent batch to cut minibatches from.
            num_epochs: Minimum number of times each module's batch must be
                covered. Only used if `num_total_minibatches` is 0.
            minibatch_size: The requested size of each module's minibatch.
            shuffle_batch_per_epoch: If True, a module's batch is shuffled
                (in place) each time slicing wraps around its end.
            num_total_minibatches: If > 0, exactly this many minibatches are
                yielded, regardless of `num_epochs`.
        """
        super().__init__(
            batch,
            num_epochs=num_epochs,
            minibatch_size=minibatch_size,
            shuffle_batch_per_epoch=shuffle_batch_per_epoch,
            num_total_minibatches=num_total_minibatches,
        )
        self._batch = batch
        self._minibatch_size = minibatch_size
        self._num_epochs = num_epochs
        self._shuffle_batch_per_epoch = shuffle_batch_per_epoch
        # Maps module_id to the index at which its next minibatch starts.
        self._start = {mid: 0 for mid in batch.policy_batches}
        # Maps module_id to the number of epochs already covered for it.
        self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches}
        self._minibatch_count = 0
        self._num_total_minibatches = num_total_minibatches

    def __iter__(self):
        """Yields one MultiAgentBatch per minibatch until the stop criterion is met.

        Raises:
            ValueError: If the batch of any module_id is empty.
        """
        while self._has_more_minibatches():
            minibatch = {
                module_id: self._next_module_minibatch(module_id, module_batch)
                for module_id, module_batch in self._batch.policy_batches.items()
            }
            # Note (Kourosh): env_steps is the total number of env_steps that
            # this multi-agent batch is covering. It should be simply inherited
            # from the original multi-agent batch.
            yield MultiAgentBatch(minibatch, len(self._batch))
            self._minibatch_count += 1

    def _has_more_minibatches(self) -> bool:
        """Returns True as long as another minibatch has to be produced."""
        if self._num_total_minibatches > 0:
            # Make sure we reach the given number of minibatches.
            return self._minibatch_count < self._num_total_minibatches
        if self._num_total_minibatches == 0:
            # Make sure each item in the total batch gets iterated over at
            # least `self._num_epochs` times.
            return min(self._num_covered_epochs.values()) < self._num_epochs
        return False

    def _next_module_minibatch(self, module_id, module_batch):
        """Cuts the next minibatch for one module, wrapping around if needed."""
        if len(module_batch) == 0:
            raise ValueError(
                f"The batch for module_id {module_id} is empty! "
                "This will create an infinite loop because we need to cover "
                "the same number of samples for each module_id."
            )

        start = self._start[module_id]
        # TODO (sven): Fix this bug for LSTMs:
        #  In an RNN-setting, the Learner connector already has zero-padded
        #  and added a timerank to the batch. Thus, n_step would still be based
        #  on the BxT dimension, rather than the new B dimension (excluding T),
        #  which then leads to minibatches way too large.
        #  However, changing this already would break APPO/IMPALA w/o LSTMs as
        #  these setups require sequencing, BUT their batches are not yet time-
        #  ranked (this is done only in their loss functions via the
        #  `make_time_major` utility).
        n_steps = self._minibatch_size
        samples_to_concat = []

        # `get_len` measures a batch in the unit that slicing operates on:
        # the number of sequences if the batch is sliced along the batch
        # dimension B, the plain batch length otherwise.
        if module_batch._slice_seq_lens_in_B:
            assert module_batch.get(SampleBatch.SEQ_LENS) is not None, (
                "MiniBatchCyclicIterator requires SampleBatch.SEQ_LENS "
                "to be present in the batch for slicing a batch in the batch "
                "dimension B."
            )

            def get_len(b):
                return len(b[SampleBatch.SEQ_LENS])

            # Scale the requested size from timesteps to number of sequences.
            n_steps = int(
                get_len(module_batch) * (self._minibatch_size / len(module_batch))
            )
        else:
            get_len = len

        # Cycle through the batch until we have enough samples.
        while start + n_steps >= get_len(module_batch):
            sample = module_batch[start:]
            samples_to_concat.append(sample)
            len_sample = get_len(sample)
            assert len_sample > 0, "Length of a sample must be > 0!"
            n_steps -= len_sample
            start = 0
            self._num_covered_epochs[module_id] += 1
            # Shuffle the individual single-agent batch, if required. Doing so
            # on every wrap-around makes each pass go through a different set
            # of minibatches.
            if self._shuffle_batch_per_epoch:
                module_batch.shuffle()

        end = start + n_steps
        if end > start:
            samples_to_concat.append(module_batch[start:end])

        # Concatenate all collected pieces into this module's minibatch.
        module_minibatch = concat_samples(samples_to_concat)
        # Remember where the next minibatch for this module begins.
        self._start[module_id] = end
        return module_minibatch
class MiniBatchDummyIterator(MiniBatchIteratorBase):
    """Minibatch iterator that yields the given batch once, unchanged."""

    def __init__(self, batch: MultiAgentBatch, **kwargs):
        """Stores `batch`; all keyword arguments are passed on to the base class."""
        super().__init__(batch, **kwargs)
        self._batch = batch

    def __iter__(self):
        """Yields the stored batch as the one and only item."""
        yield from (self._batch,)
@DeveloperAPI
class MiniBatchRayDataIterator:
    """Yields MultiAgentBatch minibatches drawn from a `ray.data.DataIterator`.

    The underlying epoch iterator is kept on the instance, so an epoch that was
    only partially consumed by one `__iter__` run is resumed by the next one.
    """

    def __init__(
        self,
        *,
        iterator: DataIterator,
        device: DeviceType,
        minibatch_size: int,
        num_iters: Optional[int],
        **kwargs,
    ):
        """Initializes a MiniBatchRayDataIterator instance.

        Args:
            iterator: The `ray.data.DataIterator` to pull batches from.
            device: The device passed to `iter_torch_batches`.
            minibatch_size: The batch size passed to `iter_torch_batches`.
            num_iters: The number of minibatches to yield per `__iter__` run.
                If None, one run goes through whatever is left of the current
                epoch and then stops.
            **kwargs: Forwarded to `iter_torch_batches`, except for
                `return_state`, which is dropped.
        """
        # A `ray.data.DataIterator` that can iterate in different ways over the
        # data.
        self._iterator = iterator
        # Note, in multi-learner settings the `return_state` is in `kwargs`.
        self._kwargs = {k: v for k, v in kwargs.items() if k != "return_state"}
        # Holds a batched iterable over the dataset.
        self._batched_iterable = self._iterator.iter_torch_batches(
            batch_size=minibatch_size,
            device=device,
            **self._kwargs,
        )
        # An iterator that can be stopped and resumed during an epoch.
        self._epoch_iterator = iter(self._batched_iterable)
        self._num_iters = num_iters

    def __iter__(self) -> MultiAgentBatch:
        """Yields minibatches until `num_iters` is reached or the epoch ends.

        Yields:
            A MultiAgentBatch holding one SampleBatch per module_id. Its
            `env_steps` is the sum, over all modules, of the length of the
            first column in each module's data.

        Raises:
            ValueError: If `num_iters` is set, but a newly started epoch does
                not produce a single batch (the loop could never reach
                `num_iters` and would otherwise spin forever).
        """
        iteration = 0
        # Set right after the epoch iterator has been recreated in this run and
        # cleared again as soon as that new iterator produces a batch.
        fresh_epoch_pending = False
        while self._num_iters is None or iteration < self._num_iters:
            for flat_batch in self._epoch_iterator:
                fresh_epoch_pending = False
                # Update the iteration counter.
                iteration += 1

                module_data_by_id = unflatten_dict(flat_batch)
                yield MultiAgentBatch(
                    {
                        module_id: SampleBatch(module_data)
                        for module_id, module_data in module_data_by_id.items()
                    },
                    env_steps=sum(
                        len(next(iter(module_data.values())))
                        for module_data in module_data_by_id.values()
                    ),
                )

                # If `num_iters` is reached break and return.
                if self._num_iters and iteration == self._num_iters:
                    break
            else:
                # The epoch iterator is exhausted. If it had just been
                # recreated and still produced nothing, there is no data.
                if fresh_epoch_pending:
                    raise ValueError(
                        "The data iterator did not yield any batch in a new "
                        "epoch! This would create an infinite loop because "
                        f"`num_iters`={self._num_iters} can never be reached."
                    )
                # Reinstantiate a new epoch iterator.
                self._epoch_iterator = iter(self._batched_iterable)
                # If a full epoch on the data should be run, stop.
                if not self._num_iters:
                    break
                fresh_epoch_pending = True
@DeveloperAPI
class ShardBatchIterator:
    """Iterator for sharding batch into num_shards batches.

    Each policy's batch is cut into consecutive slices of
    `ceil(len(sub_batch) / num_shards)` items. Because of the rounding up, the
    last shard(s) of a policy may be shorter than the others.

    Args:
        batch: The input multi-agent batch.
        num_shards: The number of shards to split the batch into.

    Yields:
        A MultiAgentBatch of size len(batch) / num_shards.
    """

    def __init__(self, batch: MultiAgentBatch, num_shards: int):
        self._batch = batch
        self._num_shards = num_shards

    def __iter__(self):
        for shard_index in range(self._num_shards):
            # TODO (sven): The following way of sharding a multi-agent batch
            #  destroys the relationship of the different agents' timesteps to
            #  each other. Thus, in case the algorithm requires
            #  agent-synchronized data (aka. "lockstep"), the
            #  `ShardBatchIterator` cannot be used.
            batch_to_send = {}
            # Defined up front so that a batch without any policy batches
            # yields empty shards instead of failing on an unbound name.
            shard_size = 0
            for pid, sub_batch in self._batch.policy_batches.items():
                num_items = len(sub_batch)
                shard_size = math.ceil(num_items / self._num_shards)
                start = shard_size * shard_index
                end = min(start + shard_size, num_items)
                batch_to_send[pid] = sub_batch[start:end]
            # TODO (Avnish): How should we shard MA batches really? For now,
            #  the shard size of the policy iterated over last is reported as
            #  the env_steps of the shard.
            yield MultiAgentBatch(batch_to_send, shard_size)
@DeveloperAPI
class ShardEpisodesIterator:
    """Iterator for sharding a list of Episodes into `num_shards` lists of Episodes."""

    def __init__(
        self,
        episodes: List[EpisodeType],
        num_shards: int,
        len_lookback_buffer: Optional[int] = None,
    ):
        """Initializes a ShardEpisodesIterator instance.

        Args:
            episodes: The input list of Episodes.
            num_shards: The number of shards to split the episodes into.
            len_lookback_buffer: An optional length of a lookback buffer to enforce
                on the returned shards. When splitting an episode, the second piece
                might need a lookback buffer (into the first piece) depending on the
                user's settings.
        """
        # Longest episodes first; `sorted` returns a new list, so the caller's
        # list is left untouched.
        self._episodes = sorted(episodes, key=len, reverse=True)
        self._num_shards = num_shards
        self._len_lookback_buffer = len_lookback_buffer
        self._total_length = sum(len(e) for e in episodes)
        # Distribute the total number of timesteps over the shards such that
        # the targets sum up to exactly `self._total_length`.
        self._target_lengths = [0 for _ in range(self._num_shards)]
        remaining_length = self._total_length
        for s in range(self._num_shards):
            len_ = remaining_length // (num_shards - s)
            self._target_lengths[s] = len_
            remaining_length -= len_

    def __iter__(self) -> List[EpisodeType]:
        """Runs one iteration through this sharder.

        Yields:
            A sub-list of Episodes of size roughly `len(episodes) / num_shards`. The
            yielded sublists might have slightly different total sums of episode
            lengths, in order to not have to drop even a single timestep.
        """
        # Work on a copy: split episodes are replaced by their remainder below,
        # and doing that on `self._episodes` would corrupt any later iteration.
        pending = list(self._episodes)
        sublists = [[] for _ in range(self._num_shards)]
        lengths = [0 for _ in range(self._num_shards)]
        episode_index = 0

        while episode_index < len(pending):
            episode = pending[episode_index]
            # Always fill the currently shortest shard.
            min_index = lengths.index(min(lengths))

            # Add the whole episode if it fits within the target length.
            if lengths[min_index] + len(episode) <= self._target_lengths[min_index]:
                sublists[min_index].append(episode)
                lengths[min_index] += len(episode)
                episode_index += 1
                continue

            # Otherwise, slice the episode.
            remaining_length = self._target_lengths[min_index] - lengths[min_index]
            if remaining_length > 0:
                # Note that the first slice will automatically "inherit" the
                # lookback buffer size of the episode.
                slice_part = episode[:remaining_length]
                # However, the second slice might need a user defined lookback
                # buffer (into the first slice).
                remaining_part = episode.slice(
                    slice(remaining_length, None),
                    len_lookback_buffer=self._len_lookback_buffer,
                )
                sublists[min_index].append(slice_part)
                lengths[min_index] += len(slice_part)
                # Revisit the rest of the episode in the next round.
                pending[episode_index] = remaining_part
            else:
                # The shortest shard is already at its target: hand it the
                # episode as a whole rather than dropping any timesteps.
                assert remaining_length == 0
                sublists[min_index].append(episode)
                episode_index += 1

        yield from sublists
@DeveloperAPI
class ShardObjectRefIterator:
    """Iterator for sharding a list of ray ObjectRefs into num_shards sub-lists.

    Args:
        object_refs: The input list of ray ObjectRefs.
        num_shards: The number of shards to split the references into.

    Yields:
        A sub-list of ray ObjectRefs with lengths as equal as possible.
    """

    def __init__(self, object_refs, num_shards: int):
        self._object_refs = object_refs
        self._num_shards = num_shards

    def __iter__(self):
        # Every shard receives `base_size` items; the first `num_larger` shards
        # receive one additional item each.
        base_size, num_larger = divmod(len(self._object_refs), self._num_shards)

        lower = 0
        for shard_index in range(self._num_shards):
            upper = lower + base_size + int(shard_index < num_larger)
            yield self._object_refs[lower:upper]
            # The next shard starts right where this one ended.
            lower = upper
|