minibatch_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. import math
  2. from typing import List, Optional
  3. from ray.data import DataIterator
  4. from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples
  5. from ray.rllib.utils import unflatten_dict
  6. from ray.rllib.utils.annotations import DeveloperAPI
  7. from ray.rllib.utils.typing import DeviceType, EpisodeType
  8. @DeveloperAPI
  9. class MiniBatchIteratorBase:
  10. """The base class for all minibatch iterators."""
  11. def __init__(
  12. self,
  13. batch: MultiAgentBatch,
  14. *,
  15. num_epochs: int = 1,
  16. shuffle_batch_per_epoch: bool = True,
  17. minibatch_size: int,
  18. num_total_minibatches: int = 0,
  19. ) -> None:
  20. """Initializes a MiniBatchIteratorBase instance.
  21. Args:
  22. batch: The input multi-agent batch.
  23. num_epochs: The number of complete passes over the entire train batch. Each
  24. pass might be further split into n minibatches (if `minibatch_size`
  25. provided). The train batch is generated from the given `episodes`
  26. through the Learner connector pipeline.
  27. minibatch_size: The size of minibatches to use to further split the train
  28. batch into per epoch. The train batch is generated from the given
  29. `episodes` through the Learner connector pipeline.
  30. num_total_minibatches: The total number of minibatches to loop through
  31. (over all `num_epochs` epochs). It's only required to set this to != 0
  32. in multi-agent + multi-GPU situations, in which the MultiAgentEpisodes
  33. themselves are roughly sharded equally, however, they might contain
  34. SingleAgentEpisodes with very lopsided length distributions. Thus,
  35. without this fixed, pre-computed value, one Learner might go through a
  36. different number of minibatche passes than others causing a deadlock.
  37. """
  38. pass
  39. @DeveloperAPI
  40. class MiniBatchCyclicIterator(MiniBatchIteratorBase):
  41. """This implements a simple multi-agent minibatch iterator.
  42. This iterator will split the input multi-agent batch into minibatches where the
  43. size of batch for each module_id (aka policy_id) is equal to minibatch_size. If the
  44. input batch is smaller than minibatch_size, then the iterator will cycle through
  45. the batch until it has covered `num_epochs` epochs.
  46. """
  47. def __init__(
  48. self,
  49. batch: MultiAgentBatch,
  50. *,
  51. num_epochs: int = 1,
  52. minibatch_size: int,
  53. shuffle_batch_per_epoch: bool = True,
  54. num_total_minibatches: int = 0,
  55. ) -> None:
  56. """Initializes a MiniBatchCyclicIterator instance."""
  57. super().__init__(
  58. batch,
  59. num_epochs=num_epochs,
  60. minibatch_size=minibatch_size,
  61. shuffle_batch_per_epoch=shuffle_batch_per_epoch,
  62. )
  63. self._batch = batch
  64. self._minibatch_size = minibatch_size
  65. self._num_epochs = num_epochs
  66. self._shuffle_batch_per_epoch = shuffle_batch_per_epoch
  67. # mapping from module_id to the start index of the batch
  68. self._start = {mid: 0 for mid in batch.policy_batches.keys()}
  69. # mapping from module_id to the number of epochs covered for each module_id
  70. self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches.keys()}
  71. self._minibatch_count = 0
  72. self._num_total_minibatches = num_total_minibatches
  73. def __iter__(self):
  74. while (
  75. # Make sure each item in the total batch gets at least iterated over
  76. # `self._num_epochs` times.
  77. (
  78. self._num_total_minibatches == 0
  79. and min(self._num_covered_epochs.values()) < self._num_epochs
  80. )
  81. # Make sure we reach at least the given minimum number of mini-batches.
  82. or (
  83. self._num_total_minibatches > 0
  84. and self._minibatch_count < self._num_total_minibatches
  85. )
  86. ):
  87. minibatch = {}
  88. for module_id, module_batch in self._batch.policy_batches.items():
  89. if len(module_batch) == 0:
  90. raise ValueError(
  91. f"The batch for module_id {module_id} is empty! "
  92. "This will create an infinite loop because we need to cover "
  93. "the same number of samples for each module_id."
  94. )
  95. s = self._start[module_id] # start
  96. # TODO (sven): Fix this bug for LSTMs:
  97. # In an RNN-setting, the Learner connector already has zero-padded
  98. # and added a timerank to the batch. Thus, n_step would still be based
  99. # on the BxT dimension, rather than the new B dimension (excluding T),
  100. # which then leads to minibatches way too large.
  101. # However, changing this already would break APPO/IMPALA w/o LSTMs as
  102. # these setups require sequencing, BUT their batches are not yet time-
  103. # ranked (this is done only in their loss functions via the
  104. # `make_time_major` utility).
  105. n_steps = self._minibatch_size
  106. samples_to_concat = []
  107. # get_len is a function that returns the length of a batch
  108. # if we are not slicing the batch in the batch dimension B, then
  109. # the length of the batch is simply the length of the batch
  110. # o.w the length of the batch is the length list of seq_lens.
  111. if module_batch._slice_seq_lens_in_B:
  112. assert module_batch.get(SampleBatch.SEQ_LENS) is not None, (
  113. "MiniBatchCyclicIterator requires SampleBatch.SEQ_LENS"
  114. "to be present in the batch for slicing a batch in the batch "
  115. "dimension B."
  116. )
  117. def get_len(b):
  118. return len(b[SampleBatch.SEQ_LENS])
  119. n_steps = int(
  120. get_len(module_batch)
  121. * (self._minibatch_size / len(module_batch))
  122. )
  123. else:
  124. def get_len(b):
  125. return len(b)
  126. # Cycle through the batch until we have enough samples.
  127. while s + n_steps >= get_len(module_batch):
  128. sample = module_batch[s:]
  129. samples_to_concat.append(sample)
  130. len_sample = get_len(sample)
  131. assert len_sample > 0, "Length of a sample must be > 0!"
  132. n_steps -= len_sample
  133. s = 0
  134. self._num_covered_epochs[module_id] += 1
  135. # Shuffle the individual single-agent batch, if required.
  136. # This should happen once per minibatch iteration in order to make
  137. # each iteration go through a different set of minibatches.
  138. if self._shuffle_batch_per_epoch:
  139. module_batch.shuffle()
  140. e = s + n_steps # end
  141. if e > s:
  142. samples_to_concat.append(module_batch[s:e])
  143. # concatenate all the samples, we should have minibatch_size of sample
  144. # after this step
  145. minibatch[module_id] = concat_samples(samples_to_concat)
  146. # roll minibatch to zero when we reach the end of the batch
  147. self._start[module_id] = e
  148. # Note (Kourosh): env_steps is the total number of env_steps that this
  149. # multi-agent batch is covering. It should be simply inherited from the
  150. # original multi-agent batch.
  151. minibatch = MultiAgentBatch(minibatch, len(self._batch))
  152. yield minibatch
  153. self._minibatch_count += 1
  154. class MiniBatchDummyIterator(MiniBatchIteratorBase):
  155. def __init__(self, batch: MultiAgentBatch, **kwargs):
  156. super().__init__(batch, **kwargs)
  157. self._batch = batch
  158. def __iter__(self):
  159. yield self._batch
  160. @DeveloperAPI
  161. class MiniBatchRayDataIterator:
  162. def __init__(
  163. self,
  164. *,
  165. iterator: DataIterator,
  166. device: DeviceType,
  167. minibatch_size: int,
  168. num_iters: Optional[int],
  169. **kwargs,
  170. ):
  171. # A `ray.data.DataIterator` that can iterate in different ways over the data.
  172. self._iterator = iterator
  173. # Note, in multi-learner settings the `return_state` is in `kwargs`.
  174. self._kwargs = {k: v for k, v in kwargs.items() if k != "return_state"}
  175. # Holds a batched_iterable over the dataset.
  176. self._batched_iterable = self._iterator.iter_torch_batches(
  177. batch_size=minibatch_size,
  178. device=device,
  179. **self._kwargs,
  180. )
  181. # Create an iterator that can be stopped and resumed during an epoch.
  182. self._epoch_iterator = iter(self._batched_iterable)
  183. self._num_iters = num_iters
  184. def __iter__(self) -> MultiAgentBatch:
  185. iteration = 0
  186. while self._num_iters is None or iteration < self._num_iters:
  187. for batch in self._epoch_iterator:
  188. # Update the iteration counter.
  189. iteration += 1
  190. batch = unflatten_dict(batch)
  191. batch = MultiAgentBatch(
  192. {
  193. module_id: SampleBatch(module_data)
  194. for module_id, module_data in batch.items()
  195. },
  196. env_steps=sum(
  197. len(next(iter(module_data.values())))
  198. for module_data in batch.values()
  199. ),
  200. )
  201. yield (batch)
  202. # If `num_iters` is reached break and return.
  203. if self._num_iters and iteration == self._num_iters:
  204. break
  205. else:
  206. # Reinstantiate a new epoch iterator.
  207. self._epoch_iterator = iter(self._batched_iterable)
  208. # If a full epoch on the data should be run, stop.
  209. if not self._num_iters:
  210. # Exit the loop.
  211. break
  212. @DeveloperAPI
  213. class ShardBatchIterator:
  214. """Iterator for sharding batch into num_shards batches.
  215. Args:
  216. batch: The input multi-agent batch.
  217. num_shards: The number of shards to split the batch into.
  218. Yields:
  219. A MultiAgentBatch of size len(batch) / num_shards.
  220. """
  221. def __init__(self, batch: MultiAgentBatch, num_shards: int):
  222. self._batch = batch
  223. self._num_shards = num_shards
  224. def __iter__(self):
  225. for i in range(self._num_shards):
  226. # TODO (sven): The following way of sharding a multi-agent batch destroys
  227. # the relationship of the different agents' timesteps to each other.
  228. # Thus, in case the algorithm requires agent-synchronized data (aka.
  229. # "lockstep"), the `ShardBatchIterator` cannot be used.
  230. batch_to_send = {}
  231. for pid, sub_batch in self._batch.policy_batches.items():
  232. batch_size = math.ceil(len(sub_batch) / self._num_shards)
  233. start = batch_size * i
  234. end = min(start + batch_size, len(sub_batch))
  235. batch_to_send[pid] = sub_batch[int(start) : int(end)]
  236. # TODO (Avnish): int(batch_size) ? How should we shard MA batches really?
  237. new_batch = MultiAgentBatch(batch_to_send, int(batch_size))
  238. yield new_batch
  239. @DeveloperAPI
  240. class ShardEpisodesIterator:
  241. """Iterator for sharding a list of Episodes into `num_shards` lists of Episodes."""
  242. def __init__(
  243. self,
  244. episodes: List[EpisodeType],
  245. num_shards: int,
  246. len_lookback_buffer: Optional[int] = None,
  247. ):
  248. """Initializes a ShardEpisodesIterator instance.
  249. Args:
  250. episodes: The input list of Episodes.
  251. num_shards: The number of shards to split the episodes into.
  252. len_lookback_buffer: An optional length of a lookback buffer to enforce
  253. on the returned shards. When spitting an episode, the second piece
  254. might need a lookback buffer (into the first piece) depending on the
  255. user's settings.
  256. """
  257. self._episodes = sorted(episodes, key=len, reverse=True)
  258. self._num_shards = num_shards
  259. self._len_lookback_buffer = len_lookback_buffer
  260. self._total_length = sum(len(e) for e in episodes)
  261. self._target_lengths = [0 for _ in range(self._num_shards)]
  262. remaining_length = self._total_length
  263. for s in range(self._num_shards):
  264. len_ = remaining_length // (num_shards - s)
  265. self._target_lengths[s] = len_
  266. remaining_length -= len_
  267. def __iter__(self) -> List[EpisodeType]:
  268. """Runs one iteration through this sharder.
  269. Yields:
  270. A sub-list of Episodes of size roughly `len(episodes) / num_shards`. The
  271. yielded sublists might have slightly different total sums of episode
  272. lengths, in order to not have to drop even a single timestep.
  273. """
  274. sublists = [[] for _ in range(self._num_shards)]
  275. lengths = [0 for _ in range(self._num_shards)]
  276. episode_index = 0
  277. while episode_index < len(self._episodes):
  278. episode = self._episodes[episode_index]
  279. min_index = lengths.index(min(lengths))
  280. # Add the whole episode if it fits within the target length
  281. if lengths[min_index] + len(episode) <= self._target_lengths[min_index]:
  282. sublists[min_index].append(episode)
  283. lengths[min_index] += len(episode)
  284. episode_index += 1
  285. # Otherwise, slice the episode
  286. else:
  287. remaining_length = self._target_lengths[min_index] - lengths[min_index]
  288. if remaining_length > 0:
  289. slice_part, remaining_part = (
  290. # Note that the first slice will automatically "inherit" the
  291. # lookback buffer size of the episode.
  292. episode[:remaining_length],
  293. # However, the second slice might need a user defined lookback
  294. # buffer (into the first slice).
  295. episode.slice(
  296. slice(remaining_length, None),
  297. len_lookback_buffer=self._len_lookback_buffer,
  298. ),
  299. )
  300. sublists[min_index].append(slice_part)
  301. lengths[min_index] += len(slice_part)
  302. self._episodes[episode_index] = remaining_part
  303. else:
  304. assert remaining_length == 0
  305. sublists[min_index].append(episode)
  306. episode_index += 1
  307. for sublist in sublists:
  308. yield sublist
  309. @DeveloperAPI
  310. class ShardObjectRefIterator:
  311. """Iterator for sharding a list of ray ObjectRefs into num_shards sub-lists.
  312. Args:
  313. object_refs: The input list of ray ObjectRefs.
  314. num_shards: The number of shards to split the references into.
  315. Yields:
  316. A sub-list of ray ObjectRefs with lengths as equal as possible.
  317. """
  318. def __init__(self, object_refs, num_shards: int):
  319. self._object_refs = object_refs
  320. self._num_shards = num_shards
  321. def __iter__(self):
  322. # Calculate the size of each sublist
  323. n = len(self._object_refs)
  324. sublist_size = n // self._num_shards
  325. remaining_elements = n % self._num_shards
  326. start = 0
  327. for i in range(self._num_shards):
  328. # Determine the end index for the current sublist
  329. end = start + sublist_size + (1 if i < remaining_elements else 0)
  330. # Append the sublist to the result
  331. yield self._object_refs[start:end]
  332. # Update the start index for the next sublist
  333. start = end