# sample_batch.py
  1. import collections
  2. import itertools
  3. import sys
  4. from functools import partial
  5. from numbers import Number
  6. from typing import Dict, Iterator, List, Optional, Set, Union
  7. import numpy as np
  8. import tree # pip install dm_tree
  9. from ray._common.deprecation import Deprecated, deprecation_warning
  10. from ray.rllib.core.columns import Columns
  11. from ray.rllib.utils.annotations import DeveloperAPI, ExperimentalAPI, PublicAPI
  12. from ray.rllib.utils.compression import is_compressed, pack, unpack
  13. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  14. from ray.rllib.utils.torch_utils import convert_to_torch_tensor
  15. from ray.rllib.utils.typing import (
  16. ModuleID,
  17. PolicyID,
  18. SampleBatchType,
  19. TensorType,
  20. ViewRequirementsDict,
  21. )
  22. from ray.util import log_once
  # Framework handles returned by the `try_import_*` helpers. Code in this file
  # guards their use with `if tf and ...` / `if torch and ...`.
  tf1, tf, tfv = try_import_tf()
  torch, _ = try_import_torch()

  # Policy ID used by default in single-agent environments.
  DEFAULT_POLICY_ID = "default_policy"
  @DeveloperAPI
  def attempt_count_timesteps(tensor_dict: dict):
      """Infer the number of timesteps held by a dict of (possibly nested) columns.

      If a usable `SampleBatch.SEQ_LENS` entry exists, the count is the sum of the
      sequence lengths. Otherwise the columns are visited in order and the length
      of the first one that reports a non-zero `len()` is returned. The
      `SampleBatch.INFOS` column and all `state_in_*` / `state_out_*` columns are
      never used for counting.

      Args:
          tensor_dict: A SampleBatch or another dict.

      Returns:
          The inferred number of timesteps (>= 0). Zero if nothing could be
          counted.
      """
      # Fast path: derive the count directly from the sequence lengths.
      seq_lens = tensor_dict.get(SampleBatch.SEQ_LENS)
      if seq_lens is not None:
          # A tf tensor without a `numpy` attribute cannot be summed here.
          unreadable_tf_tensor = (
              tf and tf.is_tensor(seq_lens) and not hasattr(seq_lens, "numpy")
          )
          if not unreadable_tf_tensor and len(seq_lens) > 0:
              if torch and torch.is_tensor(seq_lens):
                  return int(seq_lens.sum().item())
              return int(sum(seq_lens))

      # Slow path: find the first column whose length can be determined.
      state_prefixes = ("state_in_", "state_out_")
      for key, value in tensor_dict.items():
          if key == SampleBatch.SEQ_LENS:
              continue
          assert isinstance(key, str), tensor_dict
          # No assumptions are made about the content of infos, and nesting in
          # state columns can potentially mess things up: skip both.
          if key == SampleBatch.INFOS or key.startswith(state_prefixes):
              continue

          # Nested values (e.g. a nested observation) are flattened; only the
          # first leaf is measured, as all leaves share the batch dimension.
          leaves = tree.flatten(value) if isinstance(value, (dict, tuple)) else [value]
          # TODO: Drop support for lists and Numbers as values.
          leaves = [
              np.array(leaf) if isinstance(leaf, (Number, list)) else leaf
              for leaf in leaves
          ]
          try:
              num_timesteps = len(leaves[0])
          except Exception:
              # Best effort: this column cannot be measured, try the next one.
              continue
          if num_timesteps:
              return num_timesteps

      # Unable to count.
      return 0
  @PublicAPI
  class SampleBatch(dict):
      """Wrapper around a dictionary with string keys and array-like values.

      For example, {"obs": [1, 2, 3], "reward": [0, -1, 1]} is a batch of three
      samples, each with an "obs" and "reward" attribute.
      """

      # On rows in SampleBatch:
      # A row generally signifies one timestep. Most importantly, at t=0,
      # SampleBatch.OBS will usually be the reset-observation, while
      # SampleBatch.ACTIONS will be the action based on the reset-observation and
      # so on. This scheme is derived from RLlib's sampling logic.

      # Column names that have moved to `Columns`; the aliases below are kept
      # for backward compatibility only.
      OBS = Columns.OBS
      ACTIONS = Columns.ACTIONS
      REWARDS = Columns.REWARDS
      TERMINATEDS = Columns.TERMINATEDS
      TRUNCATEDS = Columns.TRUNCATEDS
      INFOS = Columns.INFOS
      SEQ_LENS = Columns.SEQ_LENS
      T = Columns.T
      ACTION_DIST_INPUTS = Columns.ACTION_DIST_INPUTS
      ACTION_PROB = Columns.ACTION_PROB
      ACTION_LOGP = Columns.ACTION_LOGP
      VF_PREDS = Columns.VF_PREDS
      VALUES_BOOTSTRAPPED = Columns.VALUES_BOOTSTRAPPED
      EPS_ID = Columns.EPS_ID
      NEXT_OBS = Columns.NEXT_OBS

      # Action distribution object.
      ACTION_DIST = "action_dist"
      # Action chosen before SampleBatch.ACTIONS.
      PREV_ACTIONS = "prev_actions"
      # Reward received before SampleBatch.REWARDS.
      PREV_REWARDS = "prev_rewards"
      # An env ID (e.g. the index for a vectorized sub-env).
      ENV_ID = "env_id"
      # Uniquely identifies an agent within an episode.
      AGENT_INDEX = "agent_index"
      # Uniquely identifies a sample batch. This is important to distinguish RNN
      # sequences from the same episode when multiple sample batches are
      # concatenated (fusing sequences across batches can be unsafe).
      UNROLL_ID = "unroll_id"

      # RE3: only computed and used when the RE3 exploration strategy is enabled.
      OBS_EMBEDS = "obs_embeds"

      # Decision Transformer.
      RETURNS_TO_GO = "returns_to_go"
      ATTENTION_MASKS = "attention_masks"

      # Do not set this key directly. Instead, the values under this key are
      # auto-computed via the values of the TERMINATEDS and TRUNCATEDS keys.
      DONES = "dones"
      # Use SampleBatch.OBS instead.
      CUR_OBS = "obs"
  @PublicAPI
  def __init__(self, *args, **kwargs):
      """Constructs a sample batch (same params as dict constructor).

      Note: All args and those kwargs not listed below will be passed
      as-is to the parent dict constructor.

      Args:
          _time_major: Whether data in this sample batch is time-major. Only
              relevant if the data contains sequences. Stored as None when
              not given.
          _max_seq_len: The max sequence chunk length if the data contains
              sequences. When omitted, it is derived from the `seq_lens`
              column where possible.
          _zero_padded: Whether the data in this batch contains sequences AND
              these sequences are right-zero-padded according to the
              `_max_seq_len` setting.
          _is_training: Whether this batch is used for training. If False,
              batch may be used for e.g. action computations (inference).
              When omitted, an "is_training" entry is popped from the data
              (defaulting to False).
          _num_grad_updates: Weighted average number of gradient updates
              performed on the policy/ies that collected this batch.

      Raises:
          KeyError: If a `DONES` key is passed as a keyword argument.
      """
      if SampleBatch.DONES in kwargs:
          raise KeyError(
              "SampleBatch cannot be constructed anymore with a `DONES` key! "
              "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
              " DONES will then be automatically computed using terminated|truncated."
          )

      # Strip the underscore-prefixed settings before the remaining kwargs are
      # handed to the dict constructor.
      # Possible seq_lens (TxB or BxT) setup.
      self.time_major = kwargs.pop("_time_major", None)
      # Maximum seq len value.
      self.max_seq_len = kwargs.pop("_max_seq_len", None)
      # Is already right-zero-padded?
      self.zero_padded = kwargs.pop("_zero_padded", False)
      # Whether this batch is used for training (vs inference).
      self._is_training = kwargs.pop("_is_training", None)
      # Weighted average number of grad updates that have been performed on the
      # policy/ies that were used to collect this batch.
      # E.g.: Two rollout workers collect samples of 50ts each
      # (rollout_fragment_length=50). One of them has a policy that has undergone
      # 2 updates thus far, the other worker uses a policy that has undergone 3
      # updates thus far. The train batch size is 100, so these 2 batches are
      # concatenated into a new 100ts batch whose `num_grad_updates` is 2.5
      # (the weighted average; both original batches contribute 50%).
      self.num_grad_updates: Optional[float] = kwargs.pop("_num_grad_updates", None)

      # Make the actual data accessible by column name (str), e.g.
      # self["some-col"].
      dict.__init__(self, *args, **kwargs)

      # Controls how sequence lengths are sliced for this batch:
      # - True: tensors are shaped (B, T, ...), each index of B is one
      #   sequence, and one sequence is sliced out per index in B.
      # - False: tensors are shaped (B*T, ...), each index is one timestep, and
      #   one sequence is sliced per T steps in B*T.
      # True is only meant to be used for batches that are fed into
      # Learner._update(); all other places in RLlib are not expected to need it.
      self._slice_seq_lens_in_B = False

      # Key-tracking and interception state.
      self.accessed_keys = set()
      self.added_keys = set()
      self.deleted_keys = set()
      self.intercepted_values = {}
      self.get_interceptor = None

      # Normalize the seq-lens column: drop None/empty-list values, turn
      # non-empty lists into int32 arrays, re-store torch/tf tensors as they are.
      seq_lens = self.get(SampleBatch.SEQ_LENS)
      if isinstance(seq_lens, list):
          if len(seq_lens) == 0:
              self.pop(SampleBatch.SEQ_LENS, None)
          else:
              seq_lens = np.array(seq_lens, dtype=np.int32)
              self[SampleBatch.SEQ_LENS] = seq_lens
      elif seq_lens is None:
          self.pop(SampleBatch.SEQ_LENS, None)
      elif (torch and torch.is_tensor(seq_lens)) or (tf and tf.is_tensor(seq_lens)):
          self[SampleBatch.SEQ_LENS] = seq_lens

      # Derive the max sequence length if it was not provided explicitly
      # (skipped for tf tensors).
      can_derive_max = self.max_seq_len is None and seq_lens is not None
      if can_derive_max and not (tf and tf.is_tensor(seq_lens)) and len(seq_lens) > 0:
          if torch and torch.is_tensor(seq_lens):
              self.max_seq_len = seq_lens.max().item()
          else:
              self.max_seq_len = max(seq_lens)

      if self._is_training is None:
          self._is_training = self.pop("is_training", False)

      # TODO: Drop support for lists and Numbers as values.
      # Convert lists and plain numbers into numpy arrays (INFOS excluded).
      for column, value in self.items():
          if column != SampleBatch.INFOS and isinstance(value, (Number, list)):
              self[column] = np.array(value)

      self.count = attempt_count_timesteps(self)

      # A convenience map for slicing this batch into sub-batches along the time
      # axis. This helps reduce repeated iterations through the batch's seq_lens
      # array to find good slicing points. Built lazily when needed.
      self._slice_map = []
  @PublicAPI
  def __len__(self) -> int:
      """Returns the number of samples in this batch (its `count` attribute)."""
      num_samples = self.count
      return num_samples
  @PublicAPI
  def agent_steps(self) -> int:
      """Returns the number of steps in this batch; identical to `len(self)`.

      Exists to be compatible with `MultiAgentBatch.agent_steps()`.
      """
      num_steps = len(self)
      return num_steps
  @PublicAPI
  def env_steps(self) -> int:
      """Returns the number of steps in this batch; identical to `len(self)`.

      Exists to be compatible with `MultiAgentBatch.env_steps()`.
      """
      num_steps = len(self)
      return num_steps
  @DeveloperAPI
  def enable_slicing_by_batch_id(self):
      """Sets the `_slice_seq_lens_in_B` flag of this batch to True."""
      self._slice_seq_lens_in_B = True
  @DeveloperAPI
  def disable_slicing_by_batch_id(self):
      """Sets the `_slice_seq_lens_in_B` flag of this batch to False."""
      self._slice_seq_lens_in_B = False
  @ExperimentalAPI
  def is_terminated_or_truncated(self) -> bool:
      """Returns True if `self` is either terminated or truncated at idx -1.

      The TRUNCATEDS column is only consulted if it is present in this batch.
      """
      last_terminated = self[SampleBatch.TERMINATEDS][-1]
      if last_terminated:
          return last_terminated
      if SampleBatch.TRUNCATEDS not in self:
          return False
      return self[SampleBatch.TRUNCATEDS][-1]
  @ExperimentalAPI
  def is_single_trajectory(self) -> bool:
      """Returns True if this SampleBatch only contains one trajectory.

      All timesteps except the last one must be not terminated AND (if a
      TRUNCATEDS column is present) not truncated.
      """
      # Any early termination means more than one trajectory.
      if any(self[SampleBatch.TERMINATEDS][:-1]):
          return False
      # Without truncation data, the terminated check alone decides.
      if SampleBatch.TRUNCATEDS not in self:
          return True
      return not any(self[SampleBatch.TRUNCATEDS][:-1])
  @staticmethod
  @PublicAPI
  @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
  def concat_samples(samples):
      # Deprecated stub with an intentionally empty body; the `Deprecated`
      # decorator names `concat_samples()` from rllib.policy.sample_batch as the
      # replacement.
      return None
  @PublicAPI
  def concat(self, other: "SampleBatch") -> "SampleBatch":
      """Concatenates `other` to this one and returns a new SampleBatch.

      Delegates to `concat_samples()`, passing `[self, other]`.

      Args:
          other: The other SampleBatch object to concat to this one.

      Returns:
          The new SampleBatch, resulting from concating `other` to `self`.

      .. testcode::
          :skipif: True

          import numpy as np
          from ray.rllib.policy.sample_batch import SampleBatch
          b1 = SampleBatch({"a": np.array([1, 2])})
          b2 = SampleBatch({"a": np.array([3, 4, 5])})
          print(b1.concat(b2))

      .. testoutput::

          {"a": np.array([1, 2, 3, 4, 5])}
      """
      batches = [self, other]
      return concat_samples(batches)
  @PublicAPI
  def copy(self, shallow: bool = False) -> "SampleBatch":
      """Creates a deep or shallow copy of this SampleBatch and returns it.

      Only numpy arrays are affected by `shallow`; all other leaf values are
      passed to the new batch as they are.

      Args:
          shallow: Whether the copying should be done shallowly.

      Returns:
          A deep or shallow copy of this SampleBatch object.
      """

      def _copy_leaf(value):
          # Only ndarrays are (possibly) duplicated.
          if isinstance(value, np.ndarray):
              return np.array(value, copy=not shallow)
          return value

      data = tree.map_structure(_copy_leaf, dict(self))
      duplicate = SampleBatch(
          data,
          _time_major=self.time_major,
          _zero_padded=self.zero_padded,
          _max_seq_len=self.max_seq_len,
          _num_grad_updates=self.num_grad_updates,
      )
      duplicate.set_get_interceptor(self.get_interceptor)
      # NOTE: The key-tracking sets are shared by reference with `self`, not
      # copied.
      duplicate.added_keys = self.added_keys
      duplicate.deleted_keys = self.deleted_keys
      duplicate.accessed_keys = self.accessed_keys
      return duplicate
  @PublicAPI
  def rows(self) -> Iterator[Dict[str, TensorType]]:
      """Returns an iterator over data rows, i.e. dicts with column values.

      Note that if `seq_lens` is set in self, we set it to 1 in the rows.

      Yields:
          The column values of the row in this iteration.

      .. testcode::
          :skipif: True

          from ray.rllib.policy.sample_batch import SampleBatch
          batch = SampleBatch({
              "a": [1, 2, 3],
              "b": [4, 5, 6],
              "seq_lens": [1, 2]
          })
          for row in batch.rows():
              print(row)

      .. testoutput::

          {"a": 1, "b": 4, "seq_lens": 1}
          {"a": 2, "b": 5, "seq_lens": 1}
          {"a": 3, "b": 6, "seq_lens": 1}
      """
      # Value reported under the seq-lens key in every row.
      row_seq_len = 1
      if self.get(SampleBatch.SEQ_LENS, 1) is None:
          row_seq_len = None

      batch_as_dict = dict(self)

      def _pick(path, value, index):
          # Every column is indexed by row, except the seq-lens column.
          if path[0] != self.SEQ_LENS:
              return value[index]
          return row_seq_len

      for index in range(self.count):
          yield tree.map_structure_with_path(
              lambda path, value, index=index: _pick(path, value, index),
              batch_as_dict,
          )
  @PublicAPI
  def columns(self, keys: List[str]) -> List[any]:
      """Returns a list of the batch-data in the specified columns.

      Args:
          keys: List of column names for which to return the data.

      Returns:
          The list of data items ordered by the order of column
          names in `keys`.

      .. testcode::
          :skipif: True

          from ray.rllib.policy.sample_batch import SampleBatch
          batch = SampleBatch({"a": [1], "b": [2], "c": [3]})
          print(batch.columns(["a", "b"]))

      .. testoutput::

          [[1], [2]]
      """
      # TODO: (sven) Make this work for nested data as well.
      return [self[column] for column in keys]
  @PublicAPI
  def shuffle(self) -> "SampleBatch":
      """Shuffles the rows of this batch in-place.

      One single permutation is drawn and applied to all columns (including
      INFOS), so rows stay aligned across columns.

      Returns:
          This very (now shuffled) SampleBatch.

      Raises:
          ValueError: If self[SampleBatch.SEQ_LENS] is defined and the batch is
              not zero-padded.

      .. testcode::
          :skipif: True

          from ray.rllib.policy.sample_batch import SampleBatch
          batch = SampleBatch({"a": [1, 2, 3, 4]})
          print(batch.shuffle())

      .. testoutput::

          {"a": [4, 1, 3, 2]}
      """
      has_time_rank = self.get(SampleBatch.SEQ_LENS) is not None

      # Shuffling the data when we have `seq_lens` defined is probably
      # a bad idea!
      if has_time_rank and not self.zero_padded:
          raise ValueError(
              "SampleBatch.shuffle not possible when your data has "
              "`seq_lens` defined AND is not zero-padded yet!"
          )

      # Get a permutation over the single items once and use the same
      # permutation for all the data (otherwise, data would become
      # meaningless).
      if not has_time_rank:
          # Shuffle by individual item.
          permutation = np.random.permutation(self.count)
      else:
          # Shuffle along batch axis (leave axis=1/time-axis as-is).
          permutation = np.random.permutation(len(self[SampleBatch.SEQ_LENS]))

      # INFOS is handled separately from the `tree.map_structure` call and is
      # permuted element by element instead.
      self_as_dict = dict(self)
      infos = self_as_dict.pop(Columns.INFOS, None)
      shuffled = tree.map_structure(lambda v: v[permutation], self_as_dict)
      if infos is not None:
          # Store the permuted infos in `shuffled` (the dict that is written
          # back below); otherwise INFOS would keep its old order while all
          # other columns are shuffled.
          shuffled[Columns.INFOS] = [infos[i] for i in permutation]

      self.update(shuffled)

      # Flush cache such that intercepted values are recalculated after the
      # shuffling.
      self.intercepted_values = {}
      return self
  @PublicAPI
  def split_by_episode(self, key: Optional[str] = None) -> List["SampleBatch"]:
      """Splits by `eps_id` column and returns list of new batches.

      If `eps_id` is not present, splits by `dones` instead (implemented via
      the TERMINATEDS column and, if present, the TRUNCATEDS column).

      Args:
          key: If specified, overwrite default and use key to split.

      Returns:
          List of batches, one per distinct episode.

      Raises:
          KeyError: If `key` is `eps_id` but that column is not present.

      .. testcode::
          :skipif: True

          from ray.rllib.policy.sample_batch import SampleBatch
          # "eps_id" is present
          batch = SampleBatch(
              {"a": [1, 2, 3], "eps_id": [0, 0, 1]})
          print(batch.split_by_episode())

          # "eps_id" not present, split by "dones" instead
          batch = SampleBatch(
              {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 1]})
          print(batch.split_by_episode())

          # The last episode is appended even if it does not end with done
          batch = SampleBatch(
              {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 1, 0, 0]})
          print(batch.split_by_episode())

          batch = SampleBatch(
              {"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]})
          print(batch.split_by_episode())

      .. testoutput::

          [{"a": [1, 2], "eps_id": [0, 0]}, {"a": [3], "eps_id": [1]}]
          [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 1]}]
          [{"a": [1, 2, 3], "dones": [0, 0, 1]}, {"a": [4, 5], "dones": [0, 0]}]
          [{"a": [1, 2, 3, 4, 5], "dones": [0, 0, 0, 0, 0]}]
      """
      assert key is None or key in [SampleBatch.EPS_ID, SampleBatch.DONES], (
          f"`SampleBatch.split_by_episode(key={key})` invalid! "
          f"Must be [None|'dones'|'eps_id']."
      )

      def slice_by_eps_id():
          slices = []
          # Produce a new slice whenever we find a new episode ID.
          cur_eps_id = self[SampleBatch.EPS_ID][0]
          offset = 0
          for i in range(self.count):
              next_eps_id = self[SampleBatch.EPS_ID][i]
              if next_eps_id != cur_eps_id:
                  slices.append(self[offset:i])
                  offset = i
                  cur_eps_id = next_eps_id
          # Add final slice.
          slices.append(self[offset : self.count])
          return slices

      def slice_by_terminateds_or_truncateds():
          slices = []
          offset = 0
          for i in range(self.count):
              if self[SampleBatch.TERMINATEDS][i] or (
                  SampleBatch.TRUNCATEDS in self and self[SampleBatch.TRUNCATEDS][i]
              ):
                  # Since self[i] is the last timestep of the episode,
                  # append it to the batch, then set offset to the start
                  # of the next batch
                  slices.append(self[offset : i + 1])
                  offset = i + 1
          # Add final slice (an episode that has not ended within this batch).
          if offset != self.count:
              slices.append(self[offset:])
          return slices

      key_to_method = {
          SampleBatch.EPS_ID: slice_by_eps_id,
          SampleBatch.DONES: slice_by_terminateds_or_truncateds,
      }

      # If key not specified, default to this order.
      key_resolve_order = [SampleBatch.EPS_ID, SampleBatch.DONES]

      slices = None
      if key is not None:
          # If key specified, directly use it.
          if key == SampleBatch.EPS_ID and key not in self:
              raise KeyError(f"{self} does not have key `{key}`!")
          slices = key_to_method[key]()
      else:
          # If key not specified, go in order. The DONES strategy is always
          # accepted, the EPS_ID strategy only if that column exists.
          for candidate in key_resolve_order:
              if candidate == SampleBatch.DONES or candidate in self:
                  slices = key_to_method[candidate]()
                  break
          if slices is None:
              raise KeyError(f"{self} does not have keys {key_resolve_order}!")

      # Both message parts are wrapped in one parenthesized expression so that
      # the complete text ends up in the AssertionError.
      assert sum(s.count for s in slices) == self.count, (
          f"Calling split_by_episode on {self} returns {slices} "
          f"which should in total have {self.count} timesteps!"
      )
      return slices
  def slice(
      self, start: int, end: int, state_start=None, state_end=None
  ) -> "SampleBatch":
      """Returns a slice of the row data of this batch.

      Without a (non-empty) `seq_lens` column, every column is simply sliced
      with `[start:end]`. With sequences, the `state_in_*` columns and the
      `seq_lens` of the result are cut per sequence rather than per timestep.

      Args:
          start: Starting index. If < 0, will left-zero-pad.
          end: Ending index.
          state_start: Optional start index into the `state_in_*` columns and
              `seq_lens`. If None, the sequence range is derived from `start`
              and `end`.
          state_end: Optional end index matching `state_start`. Must be given
              if `state_start` is given.

      Returns:
          A new SampleBatch, which has a slice of this batch's data.
      """
      has_sequences = (
          self.get(SampleBatch.SEQ_LENS) is not None
          and len(self[SampleBatch.SEQ_LENS]) > 0
      )

      # Simple case: no sequence information, slice all columns alike.
      if not has_sequences:
          return SampleBatch(
              tree.map_structure(lambda value: value[start:end], self),
              _is_training=self.is_training,
              _time_major=self.time_major,
              _num_grad_updates=self.num_grad_updates,
          )

      def _left_zero_padded(column):
          # Prepend `-start` zero rows, then take the rows up to `end`.
          padding = np.zeros(shape=(-start,) + column.shape[1:], dtype=column.dtype)
          return np.concatenate([padding, column[0:end]])

      def _plain_slice(column):
          return tree.map_structure(lambda s: s[start:end], column)

      def _copy_state_ins(seq_from, seq_to):
          # Copy state_in_0, state_in_1, ... (as many as exist) for the given
          # range of sequences into `data`.
          state_idx = 0
          state_key = "state_in_{}".format(state_idx)
          while state_key in self:
              data[state_key] = self[state_key][seq_from:seq_to]
              state_idx += 1
              state_key = "state_in_{}".format(state_idx)

      # Per-timestep columns (everything except seq_lens and state_in_*).
      cut = _left_zero_padded if start < 0 else _plain_slice
      data = {
          key: cut(column)
          for key, column in self.items()
          if key != SampleBatch.SEQ_LENS and not key.startswith("state_in_")
      }

      if state_start is not None:
          # Caller provided the sequence range explicitly.
          assert state_end is not None
          _copy_state_ins(state_start, state_end)
          seq_lens = list(self[SampleBatch.SEQ_LENS][state_start:state_end])
          # Adjust seq_lens if necessary: shorten the last sequence so that
          # the lengths add up to the number of sliced rows.
          data_len = len(data[next(iter(data))])
          if sum(seq_lens) != data_len:
              assert sum(seq_lens) > data_len
              seq_lens[-1] = data_len - sum(seq_lens[:-1])
      else:
          # Derive the sequence range covering [start, end) from seq_lens and
          # fix the state_in_x data accordingly.
          timesteps_seen = 0
          first_seq = None
          seq_lens = None
          for seq_idx, seq_len in enumerate(self[SampleBatch.SEQ_LENS]):
              timesteps_seen += seq_len
              if timesteps_seen < end:
                  # Not at the end yet; remember the first overlapping sequence.
                  if first_seq is None and timesteps_seen > start:
                      first_seq = seq_idx
                  continue

              # This sequence contains `end`: finalize states and seq_lens.
              if first_seq is None:
                  first_seq = seq_idx
              _copy_state_ins(first_seq, seq_idx + 1)
              seq_lens = list(self[SampleBatch.SEQ_LENS][first_seq:seq_idx]) + [
                  seq_len - (timesteps_seen - end)
              ]
              if start < 0:
                  seq_lens[0] += -start
              surplus = sum(seq_lens) - (end - start)
              if surplus > 0:
                  seq_lens[0] -= surplus
              assert sum(seq_lens) == (end - start)
              break

      return SampleBatch(
          data,
          seq_lens=seq_lens,
          _is_training=self.is_training,
          _time_major=self.time_major,
          _num_grad_updates=self.num_grad_updates,
      )
  def _batch_slice(self, slice_: slice) -> "SampleBatch":
      """Helper method to handle SampleBatch slicing using a slice object.

      The returned SampleBatch uses the same underlying data object as
      `self`, so changing the slice will also change `self`.

      Note that only zero or positive bounds are allowed for both start
      and stop values. The slice step must be 1 (or None, which is the
      same).

      Args:
          slice_: The python slice object to slice by.

      Returns:
          A new SampleBatch, however "linking" into the same data
          (sliced) as self.
      """
      start = slice_.start or 0
      stop = slice_.stop or len(self[SampleBatch.SEQ_LENS])
      # If stop goes beyond the length of this batch -> Make it go till the
      # end only (including last item).
      # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
      if stop > len(self):
          stop = len(self)
      assert start >= 0 and stop >= 0 and slice_.step in [1, None]

      # Exclude INFOs from regular array slicing as the data under this column
      # might be a list (not good for `tree.map_structure` call).
      # Furthermore, slicing does not work when the data in the column is
      # singular (not a list or array).
      infos = self.pop(SampleBatch.INFOS, None)
      try:
          data = tree.map_structure(lambda value: value[start:stop], self)
      finally:
          # Put infos back into `self`, even if slicing raised.
          if infos is not None:
              self[Columns.INFOS] = infos

      if infos is not None:
          # Slice infos according to SEQ_LENS: the covered infos entries run
          # from the summed lengths of all sequences before `start` up to the
          # summed lengths of all sequences before `stop`. Both bounds are
          # absolute indices into `infos`.
          seq_lens = self[SampleBatch.SEQ_LENS]
          info_slice_start = int(sum(seq_lens[:start]))
          info_slice_stop = info_slice_start + int(sum(seq_lens[start:stop]))
          data[SampleBatch.INFOS] = infos[info_slice_start:info_slice_stop]

      return SampleBatch(
          data,
          _is_training=self.is_training,
          _time_major=self.time_major,
          _num_grad_updates=self.num_grad_updates,
      )
  @PublicAPI
  def timeslices(
      self,
      size: Optional[int] = None,
      num_slices: Optional[int] = None,
      k: Optional[int] = None,
  ) -> List["SampleBatch"]:
      """Returns SampleBatches, each one representing a k-slice of this one.

      Will start from timestep 0 and produce slices of size=k.

      If `size` is given (or derived from the deprecated `k`), it takes
      precedence over `num_slices`. When `len(self)` is not a multiple of
      `size`, the last returned slice is shorter than `size`.

      Args:
          size: The size (in timesteps) of each returned SampleBatch.
          num_slices: The number of slices to produce.
          k: Deprecated: Use size or num_slices instead. The size
              (in timesteps) of each returned SampleBatch.

      Returns:
          The list of `num_slices` (new) SampleBatches or n (new)
          SampleBatches each one of size `size`.

      Raises:
          ValueError: If a non-positive `size` is given for a non-empty
              batch.
      """
      if size is None and num_slices is None:
          deprecation_warning("k", "size or num_slices")
          assert k is not None
          size = k

      total = len(self)

      if size is None:
          assert isinstance(num_slices, int)

          slices = []
          left = total
          start = 0
          # Spread the remaining timesteps evenly over the slices that are
          # still to be produced. `len_` never exceeds `left`, so `left`
          # reaches exactly 0.
          while left:
              len_ = left // (num_slices - len(slices))
              stop = start + len_
              slices.append(self[start:stop])
              left -= len_
              start = stop

          return slices

      assert isinstance(size, int)
      if total == 0:
          return []
      if size <= 0:
          raise ValueError(f"`size` must be a positive int, got {size}!")
      # Step through the batch in chunks of `size`. The final chunk may be
      # shorter; slicing clamps a `stop` that lies beyond the batch end.
      # (Counting a remainder down with `while left:` never terminated when
      # `len(self)` was not a multiple of `size`.)
      return [self[start : start + size] for start in range(0, total, size)]
  @Deprecated(new="SampleBatch.right_zero_pad", error=True)
  def zero_pad(self, max_seq_len, exclude_states=True):
      """Deprecated placeholder; use `SampleBatch.right_zero_pad` instead.

      The body itself does nothing. The method is only kept so that the
      `Deprecated` decorator (configured with `error=True`) can point
      callers to the replacement.
      """
      return None
  def right_zero_pad(self, max_seq_len: int, exclude_states: bool = True):
      """Right (adding zeros at end) zero-pads this SampleBatch in-place.

      Every sequence (as described by `seq_lens`) is copied into its own
      slot of length `max_seq_len`; the remainder of each slot stays zero
      (or `None` for object/string columns). Afterwards `self.zero_padded`
      is True and `self.max_seq_len` holds the given `max_seq_len`.

      Args:
          max_seq_len: The max (total) length to zero pad to.
          exclude_states: If False, also right-zero-pad all
              `state_in_x` data. If True, leave `state_in_x` keys
              as-is.

      Returns:
          This very (now right-zero-padded) SampleBatch.

      Raises:
          ValueError: If self[SampleBatch.SEQ_LENS] is None (not defined).

      .. testcode::
          :skipif: True

          from ray.rllib.policy.sample_batch import SampleBatch
          batch = SampleBatch(
              {"a": [1, 2, 3], "seq_lens": [1, 2]})
          print(batch.right_zero_pad(max_seq_len=4))

          batch = SampleBatch({"a": [1, 2, 3],
                               "state_in_0": [1.0, 3.0],
                               "seq_lens": [1, 2]})
          print(batch.right_zero_pad(max_seq_len=5))

      .. testoutput::

          {"a": [1, 0, 0, 0, 2, 3, 0, 0], "seq_lens": [1, 2]}
          {"a": [1, 0, 0, 0, 0, 2, 3, 0, 0, 0],
           "state_in_0": [1.0, 3.0],  # <- all state-ins remain as-is
           "seq_lens": [1, 2]}
      """
      seq_lens = self.get(SampleBatch.SEQ_LENS)
      if seq_lens is None:
          raise ValueError(
              "Cannot right-zero-pad SampleBatch if no `seq_lens` field "
              f"present! SampleBatch={self}"
          )

      # One slot of `max_seq_len` entries per sequence.
      padded_len = len(seq_lens) * max_seq_len

      def _zero_pad_in_place(path, value):
          column = path[0]
          # Never touch "seq_lens" itself ...
          if column == SampleBatch.SEQ_LENS:
              return
          # ... nor the "state_in_..." columns, unless explicitly asked to.
          if exclude_states is True and column.startswith("state_in_"):
              return

          # Allocate the padded target. Object/string columns become a
          # plain list of Nones; everything else keeps its dtype and its
          # trailing dimensions.
          holds_objects = value.dtype == object or value.dtype.type is np.str_
          if holds_objects:
              padded = [None] * padded_len
          else:
              padded = np.zeros(
                  (padded_len,) + np.shape(value)[1:], dtype=value.dtype
              )

          # Copy each sequence to the beginning of its own slot.
          src = dst = 0
          for seq_len in self[SampleBatch.SEQ_LENS]:
              padded[dst : dst + seq_len] = value[src : src + seq_len]
              dst += max_seq_len
              src += seq_len
          assert src == len(value), value

          # Walk down `path` and swap in the padded data at the leaf.
          node = self
          leaf_depth = len(path) - 1
          for depth, key in enumerate(path):
              if depth == leaf_depth:
                  node[key] = padded
              node = node[key]

      # Map over a plain-dict snapshot, while the callback writes into self.
      tree.map_structure_with_path(_zero_pad_in_place, dict(self))

      # Remember that (and to what length) we are zero-padded now.
      self.zero_padded = True
      self.max_seq_len = max_seq_len

      return self
  @ExperimentalAPI
  def to_device(
      self,
      device,
      framework: str = "torch",
      pin_memory: bool = False,
      use_stream: bool = False,
      stream: Optional[Union["torch.cuda.Stream", "torch.cuda.Stream"]] = None,
  ):
      """Converts every column in place via `convert_to_torch_tensor`.

      Args:
          device: Forwarded to `convert_to_torch_tensor` as the target.
          framework: Only "torch" is implemented.
          pin_memory: Forwarded to `convert_to_torch_tensor`.
          use_stream: Forwarded to `convert_to_torch_tensor`.
          stream: Forwarded to `convert_to_torch_tensor`.

      Returns:
          This very SampleBatch (columns replaced by the converted values).

      Raises:
          NotImplementedError: For any `framework` other than "torch".
      """
      if framework != "torch":
          raise NotImplementedError

      assert torch is not None
      # Only existing keys are reassigned, so the dict does not change
      # size while we iterate over it.
      for column, value in self.items():
          self[column] = convert_to_torch_tensor(
              value,
              device,
              pin_memory=pin_memory,
              use_stream=use_stream,
              stream=stream,
          )
      return self
  @PublicAPI
  def size_bytes(self) -> int:
      """Returns sum over number of bytes of all data buffers.

      Numpy arrays contribute their ``.nbytes``; every other leaf value
      contributes ``sys.getsizeof(...)``.

      Returns:
          The overall size in bytes of the data buffer (all columns).
      """
      total = 0
      for leaf in tree.flatten(self):
          if isinstance(leaf, np.ndarray):
              total += leaf.nbytes
          else:
              total += sys.getsizeof(leaf)
      return total
  def get(self, key, default=None):
      """Returns one column (by key) from the data or a default value.

      The lookup goes through this object's item access; `default` is
      returned only if that lookup raises a KeyError.

      Args:
          key: The key to look up.
          default: The value to return if `key` cannot be found.

      Returns:
          The value stored under `key` or `default`.
      """
      try:
          value = self[key]
      except KeyError:
          return default
      return value
  @PublicAPI
  def as_multi_agent(self, module_id: Optional[ModuleID] = None) -> "MultiAgentBatch":
      """Returns the respective MultiAgentBatch.

      Note, if `module_id` is not provided (or falsy), `DEFAULT_POLICY_ID`
      is used as the key.

      Args:
          module_id: An optional module ID. If `None`, the
              `DEFAULT_POLICY_ID` is used.

      Returns:
          A MultiAgentBatch holding this SampleBatch as its only entry
          (under `module_id` or `DEFAULT_POLICY_ID`) and using
          `self.count` as its number of env steps.
      """
      batch_key = module_id or DEFAULT_POLICY_ID
      return MultiAgentBatch({batch_key: self}, self.count)
  @PublicAPI
  def __getitem__(self, key: Union[str, slice]) -> TensorType:
      """Returns one column (by key) from the data or a sliced new batch.

      Args:
          key: The key (column name) to return or
              a slice object for slicing this SampleBatch.

      Returns:
          The data under the given key or a sliced version of this batch.
      """
      # Slice objects produce a new (sliced) batch.
      if isinstance(key, slice):
          return self._slice(key)

      # Special key DONES -> served from the TERMINATEDS column.
      # NOTE(review): this used to be described as
      # `TERMINATEDS | TRUNCATEDS`, but only TERMINATEDS is returned here;
      # confirm that this is intended.
      if key == SampleBatch.DONES:
          return self[SampleBatch.TERMINATEDS]

      # Backward compatibility for when "input-dicts" were used.
      if key == "is_training":
          if log_once("SampleBatch['is_training']"):
              deprecation_warning(
                  old="SampleBatch['is_training']",
                  new="SampleBatch.is_training",
                  error=False,
              )
          return self.is_training

      # Keep track of which (existing) columns have been read.
      if not hasattr(self, key) and key in self:
          self.accessed_keys.add(key)

      raw_value = dict.__getitem__(self, key)
      if self.get_interceptor is None:
          return raw_value

      # Run the interceptor only once per key, then serve the cached result.
      if key not in self.intercepted_values:
          self.intercepted_values[key] = self.get_interceptor(raw_value)
      return self.intercepted_values[key]
  @PublicAPI
  def __setitem__(self, key, item) -> None:
      """Inserts (overrides) an entire column (by key) in the data buffer.

      Args:
          key: The column name to set a value for.
          item: The data to insert.

      Raises:
          KeyError: If `key` is the DONES key, which cannot be set directly.
      """
      # Disallow setting DONES key directly.
      if key == SampleBatch.DONES:
          raise KeyError(
              "Cannot set `DONES` anymore in a SampleBatch! "
              "Instead, set the new TERMINATEDS and TRUNCATEDS keys. The values under"
              " DONES will then be automatically computed using terminated|truncated."
          )
      # Defend against creating SampleBatch via pickle (no property
      # `added_keys` and first item is already set).
      elif not hasattr(self, "added_keys"):
          dict.__setitem__(self, key, item)
          return

      # Backward compatibility for when "input-dicts" were used.
      if key == "is_training":
          if log_once("SampleBatch['is_training']"):
              deprecation_warning(
                  old="SampleBatch['is_training']",
                  new="SampleBatch.is_training",
                  error=False,
              )
          # Go through `set_training()` so that a cached, intercepted
          # `_is_training` value gets dropped as well. Only assigning
          # `self._is_training` would let the `is_training` property keep
          # returning the stale cached value.
          self.set_training(item)
          return

      # Keep track of columns that did not exist before.
      if key not in self:
          self.added_keys.add(key)

      dict.__setitem__(self, key, item)
      # Keep an already cached value for this key in sync with the new item.
      if key in self.intercepted_values:
          self.intercepted_values[key] = item
  @property
  def is_training(self):
      """The `is_training` flag of this batch.

      If a get-interceptor is set and the stored flag is a plain bool, the
      flag is passed through the interceptor once and the result is cached
      in `self.intercepted_values` under "_is_training". Otherwise, the
      stored value is returned as-is.
      """
      if self.get_interceptor is None or not isinstance(self._is_training, bool):
          return self._is_training

      if "_is_training" not in self.intercepted_values:
          self.intercepted_values["_is_training"] = self.get_interceptor(
              self._is_training
          )
      return self.intercepted_values["_is_training"]
  def set_training(self, training: Union[bool, "tf1.placeholder"] = True):
      """Sets the `is_training` flag for this SampleBatch.

      Args:
          training: The new value of the flag.
      """
      self._is_training = training
      # A previously intercepted (cached) flag is outdated now.
      cache = self.intercepted_values
      cache.pop("_is_training", None)
  @PublicAPI
  def __delitem__(self, key):
      """Removes the column `key`, recording the key in `self.deleted_keys`.

      The key is recorded before the actual removal takes place.

      Args:
          key: The column name to delete.
      """
      self.deleted_keys.add(key)
      dict.__delitem__(self, key)
  @DeveloperAPI
  def compress(
      self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
  ) -> "SampleBatch":
      """Compresses the data buffers (by column) in place.

      Args:
          bulk: Whether to compress across the batch dimension (0)
              as well. If False will compress n separate list items, where n
              is the batch size.
          columns: The columns to compress. Default: Only
              compress the obs and new_obs columns.

      Returns:
          This very (now compressed) SampleBatch.
      """

      def _compress_in_place(path, value):
          # Only the requested top-level columns are touched.
          if path[0] not in columns:
              return

          # Walk down `path` and replace the leaf by its packed version.
          node = self
          leaf_depth = len(path) - 1
          for depth, key in enumerate(path):
              if depth == leaf_depth:
                  node[key] = (
                      pack(value)
                      if bulk
                      else np.array([pack(item) for item in value])
                  )
              node = node[key]

      tree.map_structure_with_path(_compress_in_place, self)

      return self
  @DeveloperAPI
  def decompress_if_needed(
      self, columns: Set[str] = frozenset(["obs", "new_obs"])
  ) -> "SampleBatch":
      """Decompresses data buffers (per column, if compressed) in place.

      Handles both bulk-compressed values and values whose individual
      items are compressed. Values that are not compressed stay untouched.

      Args:
          columns: The columns to decompress. Default: Only
              decompress the obs and new_obs columns.

      Returns:
          This very (now uncompressed) SampleBatch.
      """

      def _decompress_in_place(path, value):
          # Only the requested top-level columns are touched.
          if path[0] not in columns:
              return

          # Descend to the container that holds the leaf.
          *parent_keys, leaf_key = path
          node = self
          for key in parent_keys:
              node = node[key]

          if is_compressed(value):
              # Bulk compressed.
              node[leaf_key] = unpack(value)
          elif len(value) > 0 and is_compressed(value[0]):
              # Non bulk compressed (item by item).
              node[leaf_key] = np.array([unpack(item) for item in value])

      tree.map_structure_with_path(_decompress_in_place, self)

      return self
  @DeveloperAPI
  def set_get_interceptor(self, fn):
      """Sets a function to be called on every getitem.

      Args:
          fn: The new get-interceptor.
      """
      interceptor_changed = fn is not self.get_interceptor
      if interceptor_changed:
          # Values cached for the previous interceptor are no longer valid.
          self.intercepted_values = {}
      self.get_interceptor = fn
  def __repr__(self):
      """Returns a short summary: count, number of sequences, column names.

      The number of sequences is only included, if a `seq_lens` column is
      present; in that case, `seq_lens` is left out of the listed keys.
      """
      column_names = list(self.keys())
      if self.get(SampleBatch.SEQ_LENS) is None:
          return f"SampleBatch({self.count}: {column_names})"

      column_names.remove(SampleBatch.SEQ_LENS)
      return f"SampleBatch({self.count} (seqs={len(self['seq_lens'])}): {column_names})"
  def _slice(self, slice_: slice) -> "SampleBatch":
      """Helper method to handle SampleBatch slicing using a slice object.

      The returned SampleBatch uses the same underlying data object as
      `self`, so changing the slice will also change `self`.

      Note that only zero or positive bounds are allowed for both start
      and stop values. The slice step must be 1 (or None, which is the
      same).

      If `self._slice_seq_lens_in_B` is set, the call is delegated to
      `self._batch_slice()`.

      The INFOS column is excluded from the regular column mapping. It is
      always put back into `self` afterwards and is only added to the
      returned slice, if it is a list or numpy array.

      Args:
          slice_: The python slice object to slice by.

      Returns:
          A new SampleBatch, however "linking" into the same data
          (sliced) as self.
      """
      if self._slice_seq_lens_in_B:
          return self._batch_slice(slice_)

      start = slice_.start or 0
      stop = slice_.stop or len(self)
      # If stop goes beyond the length of this batch -> Make it go till the
      # end only (including last item).
      # Analogous to `l = [0, 1, 2]; l[:100] -> [0, 1, 2];`.
      if stop > len(self):
          stop = len(self)

      has_seq_lens = (
          self.get(SampleBatch.SEQ_LENS) is not None
          and len(self[SampleBatch.SEQ_LENS]) > 0
      )

      if not has_seq_lens:
          infos = self.pop(SampleBatch.INFOS, None)
          try:
              data = tree.map_structure(lambda s: s[start:stop], self)
          finally:
              # Restore INFOS on `self` in any case. This used to happen only
              # for list/ndarray infos, silently dropping all other types.
              if infos is not None:
                  self[SampleBatch.INFOS] = infos
          if isinstance(infos, (list, np.ndarray)):
              data[SampleBatch.INFOS] = infos[start:stop]

          return SampleBatch(
              data,
              _is_training=self.is_training,
              _time_major=self.time_major,
              _num_grad_updates=self.num_grad_updates,
          )

      # Build our slice-map, if not done already. Entry `t` holds, for
      # timestep `t`, the index of the sequence it belongs to and the
      # (unpadded) offset at which that sequence begins.
      if not self._slice_map:
          sum_ = 0
          for i, l in enumerate(map(int, self[SampleBatch.SEQ_LENS])):
              self._slice_map.extend([(i, sum_)] * l)
              sum_ = sum_ + l
          # Extra entry at the very end, so that a `stop` equal to the
          # length of this batch can be looked up as well (would result in
          # an index error below otherwise).
          self._slice_map.append((len(self[SampleBatch.SEQ_LENS]), sum_))

      start_seq_len, start_unpadded = self._slice_map[start]
      stop_seq_len, stop_unpadded = self._slice_map[stop]
      start_padded = start_unpadded
      stop_padded = stop_unpadded
      if self.zero_padded:
          # Each sequence occupies exactly `max_seq_len` slots.
          start_padded = start_seq_len * self.max_seq_len
          stop_padded = stop_seq_len * self.max_seq_len

      def map_(path, value):
          # "seq_lens" and "state_in_..." columns are indexed per sequence,
          # all other columns by (possibly padded) data offsets.
          if path[0] != SampleBatch.SEQ_LENS and not path[0].startswith(
              "state_in_"
          ):
              return value[start_padded:stop_padded]
          else:
              return value[start_seq_len:stop_seq_len]

      infos = self.pop(SampleBatch.INFOS, None)
      try:
          data = tree.map_structure_with_path(map_, self)
      finally:
          # Restore INFOS on `self` in any case (see above).
          if infos is not None:
              self[SampleBatch.INFOS] = infos
      if isinstance(infos, (list, np.ndarray)):
          data[SampleBatch.INFOS] = infos[start_unpadded:stop_unpadded]

      return SampleBatch(
          data,
          _is_training=self.is_training,
          _time_major=self.time_major,
          _zero_padded=self.zero_padded,
          _max_seq_len=self.max_seq_len if self.zero_padded else None,
          _num_grad_updates=self.num_grad_updates,
      )
  @Deprecated(error=False)
  def _get_slice_indices(self, slice_size):
      """Computes index ranges that cut this batch into slices.

      Args:
          slice_size: The target number of timesteps per slice.

      Returns:
          A tuple consisting of a) the list of (start, end) index pairs
          into the data columns and b) the list of (start, end) sequence
          index pairs. The second list stays empty, if this batch has no
          (non-empty) `seq_lens`.
      """
      data_slices = []
      data_slices_states = []

      has_seq_lens = (
          self.get(SampleBatch.SEQ_LENS) is not None
          and len(self[SampleBatch.SEQ_LENS]) > 0
      )

      # No sequences: Simply cut into consecutive chunks of `slice_size`.
      if not has_seq_lens:
          pos = 0
          while pos < self.count:
              data_slices.append((pos, pos + slice_size))
              pos += slice_size
          return data_slices, data_slices_states

      assert np.all(self[SampleBatch.SEQ_LENS] < slice_size), (
          "ERROR: `slice_size` must be larger than the max. seq-len "
          "in the batch!"
      )
      data_start = 0
      steps_in_slice = 0
      padded_end = 0
      first_seq = 0
      seq_idx = 0
      while seq_idx < len(self[SampleBatch.SEQ_LENS]):
          seq_len = self[SampleBatch.SEQ_LENS][seq_idx]
          steps_in_slice += seq_len
          padded_end += self.max_seq_len if self.zero_padded else seq_len
          # Complete minibatch -> Append to data_slices.
          if steps_in_slice >= slice_size:
              end_seq = seq_idx + 1
              if self.zero_padded:
                  # We are already zero-padded: Cut in chunks of max_seq_len.
                  data_slices.append((data_start, padded_end))
                  data_start = padded_end
              else:
                  # We are not zero-padded yet; all sequences are
                  # back-to-back.
                  data_slices.append((data_start, data_start + slice_size))
                  data_start += slice_size
                  if steps_in_slice > slice_size:
                      # The current sequence did not fit entirely: Move the
                      # next start back to where this sequence begins and
                      # process the sequence again.
                      overhead = steps_in_slice - slice_size
                      data_start -= seq_len - overhead
                      seq_idx -= 1

              data_slices_states.append((first_seq, end_seq))
              steps_in_slice = 0
              first_seq = seq_idx + 1
          seq_idx += 1

      return data_slices, data_slices_states
  @ExperimentalAPI
  def get_single_step_input_dict(
      self,
      view_requirements: ViewRequirementsDict,
      index: Union[str, int] = "last",
  ) -> "SampleBatch":
      """Creates single ts SampleBatch at given index from `self`.

      For usage as input-dict for model (action or value function) calls.

      Args:
          view_requirements: A view requirements dict from the model for
              which to produce the input_dict.
          index: An integer index value indicating the
              position in the trajectory for which to generate the
              compute_actions input dict. Set to "last" to generate the dict
              at the very end of the trajectory (e.g. for value estimation).
              Note that "last" is different from -1, as "last" will use the
              final NEXT_OBS as observation input.

      Returns:
          The (single-timestep) input dict for ModelV2 calls.
      """
      # For index="last", these columns are read from their "next" pendant.
      last_mappings = {
          SampleBatch.OBS: SampleBatch.NEXT_OBS,
          SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS,
          SampleBatch.PREV_REWARDS: SampleBatch.REWARDS,
      }
      use_last = index == "last"

      input_dict = {}
      for view_col, view_req in view_requirements.items():
          if view_req.used_for_compute_actions is False:
              continue

          # Create batches of size 1 (single-agent input-dict).
          data_col = view_req.data_col or view_col

          # Single index somewhere inside the trajectory (non-last).
          if not use_last:
              input_dict[view_col] = self[data_col][
                  index : None if index == -1 else index + 1
              ]
              continue

          data_col = last_mappings.get(data_col, data_col)

          # Single index (no range requested).
          if view_req.shift_from is None:
              input_dict[view_col] = tree.map_structure(
                  lambda v: v[-1:],  # keep as array (w/ 1 element)
                  self[data_col],
              )
              continue

          # Range needed.
          # Batch repeat value > 1: We have single frames in the
          # batch at each timestep (for the `data_col`).
          data = self[view_col][-1]
          traj_len = len(self[data_col])
          missing_at_end = traj_len % view_req.batch_repeat_value
          # Index into the observations column must be shifted by
          # -1 b/c index=0 for observations means the current (last
          # seen) observation (after having taken an action).
          obs_shift = (
              -1 if data_col in (SampleBatch.OBS, SampleBatch.NEXT_OBS) else 0
          )
          from_ = view_req.shift_from + obs_shift
          to_ = view_req.shift_to + obs_shift + 1
          if to_ == 0:
              to_ = None
          window = np.concatenate([data, self[data_col][-missing_at_end:]])[
              from_:to_
          ]
          input_dict[view_col] = np.array([window])

      return SampleBatch(input_dict, seq_lens=np.array([1], dtype=np.int32))
  @PublicAPI
  class MultiAgentBatch:
      """A batch of experiences from multiple agents in the environment.

      Attributes:
          policy_batches (Dict[PolicyID, SampleBatch]): Dict mapping policy IDs to
              SampleBatches of experiences.
          count: The number of env steps in this batch.
      """

      @PublicAPI
      def __init__(self, policy_batches: Dict[PolicyID, SampleBatch], env_steps: int):
          """Initialize a MultiAgentBatch instance.

          Args:
              policy_batches: Dict mapping policy IDs to SampleBatches of experiences.
              env_steps: The number of environment steps in the environment
                  this batch contains. This will be less than the number of
                  transitions this batch contains across all policies in total.
          """
          for v in policy_batches.values():
              assert isinstance(v, SampleBatch)

          self.policy_batches = policy_batches
          # Called "count" for uniformity with SampleBatch.
          # Prefer to access this via the `env_steps()` method when possible
          # for clarity.
          self.count = env_steps

      @PublicAPI
      def env_steps(self) -> int:
          """The number of env steps (there are >= 1 agent steps per env step).

          Returns:
              The number of environment steps contained in this batch.
          """
          return self.count

      @PublicAPI
      def __len__(self) -> int:
          """Same as `self.env_steps()`."""
          return self.count

      @PublicAPI
      def agent_steps(self) -> int:
          """The number of agent steps (there are >= 1 agent steps per env step).

          Returns:
              The number of agent steps total in this batch.
          """
          return sum(batch.count for batch in self.policy_batches.values())

      @PublicAPI
      def timeslices(self, k: int) -> List["MultiAgentBatch"]:
          """Returns k-step batches holding data for each agent at those steps.

          For examples, suppose we have agent1 observations [a1t1, a1t2, a1t3],
          for agent2, [a2t1, a2t3], and for agent3, [a3t3] only.

          Calling timeslices(1) would return three MultiAgentBatches containing
          [a1t1, a2t1], [a1t2], and [a1t3, a2t3, a3t3].

          Calling timeslices(2) would return two MultiAgentBatches containing
          [a1t1, a1t2, a2t1], and [a1t3, a2t3, a3t3].

          This method is used to implement "lockstep" replay mode. Note that this
          method does not guarantee each batch contains only data from a single
          unroll. Batches might contain data from multiple different envs.

          Args:
              k: The number of env steps after which a slice is completed.

          Returns:
              The list of (at least one) MultiAgentBatch slices.
          """
          from ray.rllib.evaluation.sample_batch_builder import SampleBatchBuilder

          # Collect one (eps_id, t, agent_index, policy_id, row) entry per
          # agent step.
          steps = [
              (
                  row[SampleBatch.EPS_ID],
                  row[SampleBatch.T],
                  row[SampleBatch.AGENT_INDEX],
                  policy_id,
                  row,
              )
              for policy_id, batch in self.policy_batches.items()
              for row in batch.rows()
          ]
          # Sort on the four leading fields only. Comparing entire tuples
          # would fall through to the row data (a mapping, which cannot be
          # ordered) whenever two entries tie on those fields.
          steps.sort(key=lambda step: step[:4])

          finished_slices = []
          cur_slice = collections.defaultdict(SampleBatchBuilder)
          cur_slice_size = 0

          def finish_slice():
              # Turns the builders collected so far into one MultiAgentBatch
              # and resets the accumulation state.
              nonlocal cur_slice_size
              assert cur_slice_size > 0
              batch = MultiAgentBatch(
                  {pid: builder.build_and_reset() for pid, builder in cur_slice.items()},
                  cur_slice_size,
              )
              cur_slice_size = 0
              cur_slice.clear()
              finished_slices.append(batch)

          # For each unique env timestep (same eps_id and t).
          for _, group in itertools.groupby(steps, lambda x: x[:2]):
              # Accumulate into the current slice.
              for _, _, _, policy_id, row in group:
                  cur_slice[policy_id].add_values(**row)
              cur_slice_size += 1
              # Slice has reached target number of env steps.
              if cur_slice_size >= k:
                  finish_slice()
                  assert cur_slice_size == 0

          # Flush the (incomplete) remainder.
          if cur_slice_size > 0:
              finish_slice()

          assert len(finished_slices) > 0, finished_slices
          return finished_slices

      @staticmethod
      @PublicAPI
      def wrap_as_needed(
          policy_batches: Dict[PolicyID, SampleBatch], env_steps: int
      ) -> Union[SampleBatch, "MultiAgentBatch"]:
          """Returns SampleBatch or MultiAgentBatch, depending on given policies.

          If policy_batches is empty (i.e. {}) it returns an empty MultiAgentBatch.

          Args:
              policy_batches: Mapping from policy ids to SampleBatch.
              env_steps: Number of env steps in the batch.

          Returns:
              The single default policy's SampleBatch or a MultiAgentBatch
              (more than one policy).
          """
          if len(policy_batches) == 1 and DEFAULT_POLICY_ID in policy_batches:
              return policy_batches[DEFAULT_POLICY_ID]
          return MultiAgentBatch(policy_batches=policy_batches, env_steps=env_steps)

      @staticmethod
      @PublicAPI
      @Deprecated(new="concat_samples() from rllib.policy.sample_batch", error=True)
      def concat_samples(samples: List["MultiAgentBatch"]) -> "MultiAgentBatch":
          """Deprecated; forwards to `concat_samples_into_ma_batch()`."""
          return concat_samples_into_ma_batch(samples)

      @PublicAPI
      def copy(self) -> "MultiAgentBatch":
          """Deep-copies self into a new MultiAgentBatch.

          Returns:
              The copy of self with deep-copied data.
          """
          return MultiAgentBatch(
              {k: v.copy() for (k, v) in self.policy_batches.items()}, self.count
          )

      @ExperimentalAPI
      def to_device(
          self,
          device,
          framework="torch",
          pin_memory: bool = False,
          use_stream: bool = False,
          stream: Optional[Union["torch.cuda.Stream", "torch.cuda.Stream"]] = None,
      ):
          """Calls `to_device()` on every policy batch and stores the results.

          Args:
              device: Forwarded to each policy batch's `to_device()`.
              framework: Only "torch" is implemented.
              pin_memory: Forwarded to each policy batch's `to_device()`.
              use_stream: Forwarded to each policy batch's `to_device()`.
              stream: Forwarded to each policy batch's `to_device()`.

          Returns:
              This very MultiAgentBatch.

          Raises:
              NotImplementedError: For any `framework` other than "torch".
          """
          if framework != "torch":
              raise NotImplementedError

          assert torch is not None
          # Only existing keys are reassigned; the dict size does not change
          # during iteration.
          for pid, policy_batch in self.policy_batches.items():
              self.policy_batches[pid] = policy_batch.to_device(
                  device,
                  framework=framework,
                  pin_memory=pin_memory,
                  use_stream=use_stream,
                  stream=stream,
              )
          return self

      @PublicAPI
      def size_bytes(self) -> int:
          """Returns the summed `size_bytes()` of all policy batches.

          Returns:
              The overall size in bytes of all policy batches (all columns).
          """
          return sum(b.size_bytes() for b in self.policy_batches.values())

      @DeveloperAPI
      def compress(
          self, bulk: bool = False, columns: Set[str] = frozenset(["obs", "new_obs"])
      ) -> None:
          """Compresses each policy batch (per column) in place.

          Args:
              bulk: Whether to compress across the batch dimension (0)
                  as well. If False will compress n separate list items, where n
                  is the batch size.
              columns: Set of column names to compress.
          """
          for batch in self.policy_batches.values():
              batch.compress(bulk=bulk, columns=columns)

      @DeveloperAPI
      def decompress_if_needed(
          self, columns: Set[str] = frozenset(["obs", "new_obs"])
      ) -> "MultiAgentBatch":
          """Decompresses each policy batch (per column), if already compressed.

          Args:
              columns: Set of column names to decompress.

          Returns:
              Self.
          """
          for batch in self.policy_batches.values():
              batch.decompress_if_needed(columns)
          return self

      @DeveloperAPI
      def as_multi_agent(self) -> "MultiAgentBatch":
          """Simply returns `self` (already a MultiAgentBatch).

          Returns:
              This very instance of MultiAgentBatch.
          """
          return self

      def __getitem__(self, key: str) -> SampleBatch:
          """Returns the SampleBatch for the given policy id."""
          return self.policy_batches[key]

      def __str__(self):
          """Returns a string showing the policy batches and the env steps."""
          return f"MultiAgentBatch({self.policy_batches!s}, env_steps={self.count})"

      def __repr__(self):
          """Same representation as `__str__`."""
          return f"MultiAgentBatch({self.policy_batches!s}, env_steps={self.count})"
  @PublicAPI
  def concat_samples(samples: List[SampleBatchType]) -> SampleBatchType:
      """Concatenates a list of SampleBatches or MultiAgentBatches.

      If all items in the list are of SampleBatch type, the output will be
      a SampleBatch type. Otherwise, the output will be a MultiAgentBatch type.
      If input is a mixture of SampleBatch and MultiAgentBatch types, it will treat
      SampleBatch objects as MultiAgentBatch types with 'default_policy' key and
      concatenate it with the rest of the MultiAgentBatch objects.
      Empty samples are simply ignored.

      Args:
          samples: List of SampleBatches or MultiAgentBatches to be
              concatenated.

      Returns:
          A new (concatenated) SampleBatch or MultiAgentBatch.

      Raises:
          ValueError: If the non-empty SampleBatches disagree on their
              `zero_padded`, `time_major` or `max_seq_len` settings.

      .. testcode::
          :skipif: True

          import numpy as np
          from ray.rllib.policy.sample_batch import SampleBatch
          b1 = SampleBatch({"a": np.array([1, 2]),
                            "b": np.array([10, 11])})
          b2 = SampleBatch({"a": np.array([3]),
                            "b": np.array([12])})
          print(concat_samples([b1, b2]))


          c1 = MultiAgentBatch({'default_policy': {
                                          "a": np.array([1, 2]),
                                          "b": np.array([10, 11])
                                          }}, env_steps=2)
          c2 = SampleBatch({"a": np.array([3]),
                            "b": np.array([12])})
          print(concat_samples([c1, c2]))

      .. testoutput::

          {"a": np.array([1, 2, 3]), "b": np.array([10, 11, 12])}
          MultiAgentBatch = {'default_policy': {"a": np.array([1, 2, 3]),
                                                "b": np.array([10, 11, 12])}}

      """
      if any(isinstance(s, MultiAgentBatch) for s in samples):
          return concat_samples_into_ma_batch(samples)

      # The output is a SampleBatch type.
      concatd_seq_lens = []
      # Accumulators for the count-weighted average of `num_grad_updates`.
      grad_updates_count = 0
      grad_updates_weighted_sum = 0.0
      concated_samples = []
      # Reference settings, taken from the first non-empty batch.
      zero_padded = max_seq_len = time_major = None
      for s in samples:
          if s.count <= 0:
              continue

          # Take the reference settings from the first non-empty batch only.
          # (Keying this on `max_seq_len is None` re-read the settings from
          # every batch whenever `max_seq_len` was None, which disabled the
          # consistency checks below.)
          if not concated_samples:
              zero_padded = s.zero_padded
              max_seq_len = s.max_seq_len
              time_major = s.time_major

          # Make sure these settings are consistent amongst all batches.
          if s.zero_padded != zero_padded or s.time_major != time_major:
              raise ValueError(
                  "All SampleBatches' `zero_padded` and `time_major` settings "
                  "must be consistent!"
              )
          if (
              s.max_seq_len is None or max_seq_len is None
          ) and s.max_seq_len != max_seq_len:
              raise ValueError(
                  "Samples must consistently either provide or omit `max_seq_len`!"
              )
          elif zero_padded and s.max_seq_len != max_seq_len:
              raise ValueError(
                  "For `zero_padded` SampleBatches, the values of `max_seq_len` "
                  "must be consistent!"
              )

          if max_seq_len is not None:
              max_seq_len = max(max_seq_len, s.max_seq_len)
          if s.get(SampleBatch.SEQ_LENS) is not None:
              concatd_seq_lens.extend(s[SampleBatch.SEQ_LENS])
          if s.num_grad_updates is not None:
              grad_updates_count += s.count
              grad_updates_weighted_sum += s.num_grad_updates * s.count

          concated_samples.append(s)

      # If we don't have any samples (0 or only empty SampleBatches),
      # return an empty SampleBatch here.
      if not concated_samples:
          return SampleBatch()

      # Collect the concat'd data (columns are taken from the first batch).
      concatd_data = {}
      _concat_values_w_time = partial(_concat_values, time_major=time_major)
      for k in concated_samples[0].keys():
          values_to_concat = [s[k] for s in concated_samples]
          if k == SampleBatch.INFOS:
              # INFOS is concatenated as a whole (not mapped over a
              # possibly nested structure).
              concatd_data[k] = _concat_values(
                  *values_to_concat,
                  time_major=time_major,
              )
          else:
              concatd_data[k] = tree.map_structure(
                  _concat_values_w_time, *values_to_concat
              )

      # If the collected seq-lens are framework tensors, convert the list
      # into one tensor of that framework.
      if concatd_seq_lens:
          if torch and torch.is_tensor(concatd_seq_lens[0]):
              concatd_seq_lens = torch.Tensor(concatd_seq_lens)
          elif tf and tf.is_tensor(concatd_seq_lens[0]):
              concatd_seq_lens = tf.convert_to_tensor(concatd_seq_lens)

      # Return a new (concat'd) SampleBatch.
      return SampleBatch(
          concatd_data,
          seq_lens=concatd_seq_lens,
          _time_major=time_major,
          _zero_padded=zero_padded,
          _max_seq_len=max_seq_len,
          # Compute weighted average of the num_grad_updates for the batches
          # (assuming they all come from the same policy).
          _num_grad_updates=(
              grad_updates_weighted_sum / (grad_updates_count or 1.0)
          ),
      )
  1456. @PublicAPI
  1457. def concat_samples_into_ma_batch(samples: List[SampleBatchType]) -> "MultiAgentBatch":
  1458. """Concatenates a list of SampleBatchTypes to a single MultiAgentBatch type.
  1459. This function, as opposed to concat_samples() forces the output to always be
  1460. MultiAgentBatch which is more generic than SampleBatch.
  1461. Args:
  1462. samples: List of SampleBatches or MultiAgentBatches to be
  1463. concatenated.
  1464. Returns:
  1465. A new (concatenated) MultiAgentBatch.
  1466. .. testcode::
  1467. :skipif: True
  1468. import numpy as np
  1469. from ray.rllib.policy.sample_batch import SampleBatch
  1470. b1 = MultiAgentBatch({'default_policy': {
  1471. "a": np.array([1, 2]),
  1472. "b": np.array([10, 11])
  1473. }}, env_steps=2)
  1474. b2 = SampleBatch({"a": np.array([3]),
  1475. "b": np.array([12])})
  1476. print(concat_samples([b1, b2]))
  1477. .. testoutput::
  1478. {'default_policy': {"a": np.array([1, 2, 3]),
  1479. "b": np.array([10, 11, 12])}}
  1480. """
  1481. policy_batches = collections.defaultdict(list)
  1482. env_steps = 0
  1483. for s in samples:
  1484. # Some batches in `samples` may be SampleBatch.
  1485. if isinstance(s, SampleBatch):
  1486. # If empty SampleBatch: ok (just ignore).
  1487. if len(s) <= 0:
  1488. continue
  1489. else:
  1490. # if non-empty: just convert to MA-batch and move forward
  1491. s = s.as_multi_agent()
  1492. elif not isinstance(s, MultiAgentBatch):
  1493. # Otherwise: Error.
  1494. raise ValueError(
  1495. "`concat_samples_into_ma_batch` can only concat "
  1496. "SampleBatch|MultiAgentBatch objects, not {}!".format(type(s).__name__)
  1497. )
  1498. for key, batch in s.policy_batches.items():
  1499. policy_batches[key].append(batch)
  1500. env_steps += s.env_steps()
  1501. out = {}
  1502. for key, batches in policy_batches.items():
  1503. out[key] = concat_samples(batches)
  1504. return MultiAgentBatch(out, env_steps)
  1505. def _concat_values(*values, time_major=None) -> TensorType:
  1506. """Concatenates a list of values.
  1507. Args:
  1508. values: The values to concatenate.
  1509. time_major: Whether to concatenate along the first axis
  1510. (time_major=False) or the second axis (time_major=True).
  1511. """
  1512. if torch and torch.is_tensor(values[0]):
  1513. return torch.cat(values, dim=1 if time_major else 0)
  1514. elif isinstance(values[0], np.ndarray):
  1515. return np.concatenate(values, axis=1 if time_major else 0)
  1516. elif tf and tf.is_tensor(values[0]):
  1517. return tf.concat(values, axis=1 if time_major else 0)
  1518. elif isinstance(values[0], list):
  1519. concatenated_list = []
  1520. for sublist in values:
  1521. concatenated_list.extend(sublist)
  1522. return concatenated_list
  1523. else:
  1524. raise ValueError(
  1525. f"Unsupported type for concatenation: {type(values[0])} "
  1526. f"first element: {values[0]}"
  1527. )
  1528. @DeveloperAPI
  1529. def convert_ma_batch_to_sample_batch(batch: SampleBatchType) -> SampleBatch:
  1530. """Converts a MultiAgentBatch to a SampleBatch if necessary.
  1531. Args:
  1532. batch: The SampleBatchType to convert.
  1533. Returns:
  1534. batch: the converted SampleBatch
  1535. Raises:
  1536. ValueError if the MultiAgentBatch has more than one policy_id
  1537. or if the policy_id is not `DEFAULT_POLICY_ID`
  1538. """
  1539. if isinstance(batch, MultiAgentBatch):
  1540. policy_keys = batch.policy_batches.keys()
  1541. if len(policy_keys) == 1 and DEFAULT_POLICY_ID in policy_keys:
  1542. batch = batch.policy_batches[DEFAULT_POLICY_ID]
  1543. else:
  1544. raise ValueError(
  1545. "RLlib tried to convert a multi agent-batch with data from more "
  1546. "than one policy to a single-agent batch. This is not supported and "
  1547. "may be due to a number of issues. Here are two possible ones:"
  1548. "1) Off-Policy Estimation is not implemented for "
  1549. "multi-agent batches. You can set `off_policy_estimation_methods: {}` "
  1550. "to resolve this."
  1551. "2) Loading multi-agent data for offline training is not implemented."
  1552. "Load single-agent data instead to resolve this."
  1553. )
  1554. return batch