offline_prelearner.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. import copy
  2. import logging
  3. import uuid
  4. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
  5. import gymnasium as gym
  6. import numpy as np
  7. import tree
  8. from ray.rllib.core.columns import Columns
  9. from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleSpec
  10. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  11. from ray.rllib.utils import flatten_dict
  12. from ray.rllib.utils.annotations import (
  13. OverrideToImplementCustomLogic,
  14. OverrideToImplementCustomLogic_CallToSuperRecommended,
  15. )
  16. from ray.rllib.utils.compression import unpack_if_needed
  17. from ray.rllib.utils.replay_buffers.replay_buffer import ReplayBuffer
  18. from ray.rllib.utils.typing import EpisodeType, ModuleID
  19. from ray.util.annotations import PublicAPI
  20. if TYPE_CHECKING:
  21. from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
  22. #: This is the default schema used if no `input_read_schema` is set in
  23. #: the config. If a user passes in a schema into `input_read_schema`
  24. #: this user-defined schema has to comply with the keys of `SCHEMA`,
  25. #: while values correspond to the columns in the user's dataset. Note
  26. #: that only the user-defined values will be overridden while all
  27. #: other values from SCHEMA remain as defined here.
  28. SCHEMA = {
  29. Columns.EPS_ID: Columns.EPS_ID,
  30. Columns.AGENT_ID: Columns.AGENT_ID,
  31. Columns.MODULE_ID: Columns.MODULE_ID,
  32. Columns.OBS: Columns.OBS,
  33. Columns.ACTIONS: Columns.ACTIONS,
  34. Columns.REWARDS: Columns.REWARDS,
  35. Columns.INFOS: Columns.INFOS,
  36. Columns.NEXT_OBS: Columns.NEXT_OBS,
  37. Columns.TERMINATEDS: Columns.TERMINATEDS,
  38. Columns.TRUNCATEDS: Columns.TRUNCATEDS,
  39. Columns.T: Columns.T,
  40. # TODO (simon): Add remove as soon as we are new stack only.
  41. "agent_index": "agent_index",
  42. "dones": "dones",
  43. "unroll_id": "unroll_id",
  44. }
  45. logger = logging.getLogger(__name__)
  46. @PublicAPI(stability="alpha")
  47. class OfflinePreLearner:
  48. """Class that coordinates data transformation from dataset to learner.
  49. This class is an essential part of the new `Offline RL API` of `RLlib`.
  50. It is a callable class that is run in `ray.data.Dataset.map_batches`
  51. when iterating over batches for training. It's basic function is to
  52. convert data in batch from rows to episodes (`SingleAGentEpisode`s
  53. for now) and to then run the learner connector pipeline to convert
  54. further to trainable batches. These batches are used directly in the
  55. `Learner`'s `update` method.
  56. The main reason to run these transformations inside of `map_batches`
  57. is for better performance. Batches can be pre-fetched in `ray.data`
  58. and therefore batch trransformation can be run highly parallelized to
  59. the `Learner''s `update`.
  60. This class can be overridden to implement custom logic for transforming
  61. batches and make them 'Learner'-ready. When deriving from this class
  62. the `__call__` method and `_map_to_episodes` can be overridden to induce
  63. custom logic for the complete transformation pipeline (`__call__`) or
  64. for converting to episodes only ('_map_to_episodes`).
  65. Custom `OfflinePreLearner` classes can be passed into
  66. `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class
  67. will then use the custom class in its data pipeline.
  68. """
  69. @OverrideToImplementCustomLogic_CallToSuperRecommended
  70. def __init__(
  71. self,
  72. *,
  73. config: "AlgorithmConfig",
  74. spaces: Optional[Tuple[gym.Space, gym.Space]] = None,
  75. module_spec: Optional[MultiRLModuleSpec] = None,
  76. module_state: Optional[Dict[ModuleID, Any]] = None,
  77. **kwargs: Dict[str, Any],
  78. ):
  79. self.config: AlgorithmConfig = config
  80. self.input_read_episodes: bool = self.config.input_read_episodes
  81. self.input_read_sample_batches: bool = self.config.input_read_sample_batches
  82. # Build the module from spec.
  83. self._module: MultiRLModule = module_spec.build()
  84. self._module.set_state(module_state)
  85. # Map the module to the device, if necessary.
  86. # TODO (simon): Check here if we already have a list.
  87. # self._set_device(device_strings)
  88. # Store the observation and action space if defined, otherwise we
  89. # set them to `None`. Note, if `None` the `convert_from_jsonable`
  90. # will not convert the input space samples.
  91. self.observation_space, self.action_space = spaces or (None, None)
  92. # Build the learner connector pipeline.
  93. self._learner_connector = self.config.build_learner_connector(
  94. input_observation_space=self.observation_space,
  95. input_action_space=self.action_space,
  96. )
  97. # Cache the policies to be trained to update weights only for these.
  98. self._policies_to_train = self.config.policies_to_train
  99. self._is_multi_agent: bool = config.is_multi_agent
  100. # Set the counter to zero.
  101. self.iter_since_last_module_update: int = 0
  102. # self._future = None
  103. # Set up an episode buffer, if the module is stateful or we sample from
  104. # `SampleBatch` types.
  105. if (
  106. self.input_read_sample_batches
  107. or self._module.is_stateful()
  108. or self.input_read_episodes
  109. ):
  110. # Either the user defined a buffer class or we fall back to the default.
  111. prelearner_buffer_class = (
  112. self.config.prelearner_buffer_class
  113. or self.default_prelearner_buffer_class
  114. )
  115. prelearner_buffer_kwargs = (
  116. self.default_prelearner_buffer_kwargs
  117. | self.config.prelearner_buffer_kwargs
  118. )
  119. # Initialize the buffer.
  120. self.episode_buffer = prelearner_buffer_class(
  121. **prelearner_buffer_kwargs,
  122. )
  123. @OverrideToImplementCustomLogic
  124. def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
  125. """Prepares plain data batches for training with `Learner`'s.
  126. Args:
  127. batch: A dictionary of numpy arrays containing either column data
  128. with `self.config.input_read_schema`, `EpisodeType` data, or
  129. `BatchType` data.
  130. Returns:
  131. A `MultiAgentBatch` that can be passed to `Learner.update` methods.
  132. """
  133. # If we directly read in episodes we just convert to list.
  134. if self.input_read_episodes:
  135. # Import `msgpack` for decoding.
  136. import msgpack
  137. import msgpack_numpy as mnp
  138. # Read the episodes and decode them.
  139. episodes: List[SingleAgentEpisode] = [
  140. SingleAgentEpisode.from_state(
  141. msgpack.unpackb(state, object_hook=mnp.decode)
  142. )
  143. for state in batch["item"]
  144. ]
  145. # Postprocess and sample from the buffer.
  146. episodes = self._postprocess_and_sample(episodes)
  147. # Else, if we have old stack `SampleBatch`es.
  148. elif self.input_read_sample_batches:
  149. episodes: List[
  150. SingleAgentEpisode
  151. ] = OfflinePreLearner._map_sample_batch_to_episode(
  152. self._is_multi_agent,
  153. batch,
  154. to_numpy=True,
  155. schema=SCHEMA | self.config.input_read_schema,
  156. input_compress_columns=self.config.input_compress_columns,
  157. )[
  158. "episodes"
  159. ]
  160. # Postprocess and sample from the buffer.
  161. episodes = self._postprocess_and_sample(episodes)
  162. # Otherwise we map the batch to episodes.
  163. else:
  164. episodes: List[SingleAgentEpisode] = self._map_to_episodes(
  165. self._is_multi_agent,
  166. batch,
  167. schema=SCHEMA | self.config.input_read_schema,
  168. to_numpy=False,
  169. input_compress_columns=self.config.input_compress_columns,
  170. observation_space=self.observation_space,
  171. action_space=self.action_space,
  172. )["episodes"]
  173. # TODO (simon): Make synching work. Right now this becomes blocking or never
  174. # receives weights. Learners appear to be non accessi ble via other actors.
  175. # Increase the counter for updating the module.
  176. # self.iter_since_last_module_update += 1
  177. # if self._future:
  178. # refs, _ = ray.wait([self._future], timeout=0)
  179. # print(f"refs: {refs}")
  180. # if refs:
  181. # module_state = ray.get(self._future)
  182. #
  183. # self._module.set_state(module_state)
  184. # self._future = None
  185. # # Synch the learner module, if necessary. Note, in case of a local learner
  186. # # we have a reference to the module and therefore an up-to-date module.
  187. # if self.learner_is_remote and self.iter_since_last_module_update
  188. # > self.config.prelearner_module_synch_period:
  189. # # Reset the iteration counter.
  190. # self.iter_since_last_module_update = 0
  191. # # Request the module weights from the remote learner.
  192. # self._future =
  193. # self._learner.get_module_state.remote(inference_only=False)
  194. # # module_state =
  195. # ray.get(self._learner.get_module_state.remote(inference_only=False))
  196. # # self._module.set_state(module_state)
  197. # Run the `Learner`'s connector pipeline.
  198. batch = self._learner_connector(
  199. rl_module=self._module,
  200. batch={},
  201. episodes=episodes,
  202. shared_data={},
  203. # TODO (sven): Add MetricsLogger to non-Learner components that have a
  204. # LearnerConnector pipeline.
  205. metrics=None,
  206. )
  207. # Remove all data from modules that should not be trained. We do
  208. # not want to pass around more data than necessary.
  209. for module_id in batch:
  210. if not self._should_module_be_updated(module_id, batch):
  211. del batch[module_id]
  212. # Flatten the dictionary to increase serialization performance.
  213. return flatten_dict(batch)
  214. @property
  215. def default_prelearner_buffer_class(self) -> ReplayBuffer:
  216. """Sets the default replay buffer."""
  217. from ray.rllib.utils.replay_buffers.episode_replay_buffer import (
  218. EpisodeReplayBuffer,
  219. )
  220. # Return the buffer.
  221. return EpisodeReplayBuffer
  222. @property
  223. def default_prelearner_buffer_kwargs(self) -> Dict[str, Any]:
  224. """Sets the default arguments for the replay buffer.
  225. Note, the `capacity` might vary with the size of the episodes or
  226. sample batches in the offline dataset.
  227. """
  228. return {
  229. "capacity": self.config.train_batch_size_per_learner * 10,
  230. "batch_size_B": self.config.train_batch_size_per_learner,
  231. }
  232. def _validate_episodes(
  233. self, episodes: List[SingleAgentEpisode]
  234. ) -> Set[SingleAgentEpisode]:
  235. """Validate episodes sampled from the dataset.
  236. Note, our episode buffers cannot handle either duplicates nor
  237. non-ordered fragmentations, i.e. fragments from episodes that do
  238. not arrive in timestep order.
  239. Args:
  240. episodes: A list of `SingleAgentEpisode` instances sampled
  241. from a dataset.
  242. Returns:
  243. A set of `SingleAgentEpisode` instances.
  244. Raises:
  245. ValueError: If not all episodes are `done`.
  246. """
  247. # Ensure that episodes are all done.
  248. if not all(eps.is_done for eps in episodes):
  249. raise ValueError(
  250. "When sampling from episodes (`input_read_episodes=True`) all "
  251. "recorded episodes must be done (i.e. either `terminated=True`) "
  252. "or `truncated=True`)."
  253. )
  254. # Ensure that episodes do not contain duplicates. Note, this can happen
  255. # if the dataset is small and pulled batches contain multiple episodes.
  256. unique_episode_ids = set()
  257. cleaned_episodes = set()
  258. for eps in episodes:
  259. if (
  260. eps.id_ not in unique_episode_ids
  261. and eps.id_ not in self.episode_buffer.episode_id_to_index
  262. ):
  263. unique_episode_ids.add(eps.id_)
  264. cleaned_episodes.add(eps)
  265. return cleaned_episodes
  266. def _remove_states_from_episodes(
  267. self,
  268. episodes: List[SingleAgentEpisode],
  269. ) -> List[SingleAgentEpisode]:
  270. """Removes states from episodes.
  271. This is necessary, if the module is stateful and we want to
  272. enable the offline RLModule to learn its own state representations.
  273. Args:
  274. episodes: A list of `SingleAgentEpisode` instances.
  275. Returns:
  276. A list of `SingleAgentEpisode` instances without states.
  277. """
  278. for eps in episodes:
  279. if Columns.STATE_OUT in eps.extra_model_outputs:
  280. del eps.extra_model_outputs[Columns.STATE_OUT]
  281. if Columns.STATE_IN in eps.extra_model_outputs:
  282. del eps.extra_model_outputs[Columns.STATE_IN]
  283. return episodes
  284. def _postprocess_and_sample(
  285. self, episodes: List[SingleAgentEpisode]
  286. ) -> List[SingleAgentEpisode]:
  287. """Postprocesses episodes and samples from the buffer.
  288. Args:
  289. episodes: A list of `SingleAgentEpisode` instances.
  290. Returns:
  291. A list of `SingleAgentEpisode` instances sampled from the buffer.
  292. """
  293. # Ensure that all episodes are done and no duplicates are in the batch.
  294. episodes = self._validate_episodes(episodes)
  295. if (
  296. self._module.is_stateful()
  297. and not self.config.prelearner_use_recorded_module_states
  298. ):
  299. episodes = self._remove_states_from_episodes(episodes)
  300. # Add the episodes to the buffer.
  301. self.episode_buffer.add(episodes)
  302. # Sample from the buffer.
  303. batch_length_T = (
  304. self.config.model_config.get("max_seq_len", 0)
  305. if self._module.is_stateful()
  306. else None
  307. )
  308. return self.episode_buffer.sample(
  309. num_items=self.config.train_batch_size_per_learner,
  310. batch_length_T=batch_length_T,
  311. n_step=self.config.get("n_step", 1),
  312. # TODO (simon): This can be removed as soon as DreamerV3 has been
  313. # cleaned up, i.e. can use episode samples for training.
  314. sample_episodes=True,
  315. to_numpy=True,
  316. lookback=self.config.episode_lookback_horizon,
  317. min_batch_length_T=getattr(self.config, "burnin_len", 0),
  318. )
  319. def _should_module_be_updated(self, module_id, multi_agent_batch=None) -> bool:
  320. """Checks which modules in a MultiRLModule should be updated."""
  321. if not self._policies_to_train:
  322. # In case of no update information, the module is updated.
  323. return True
  324. elif not callable(self._policies_to_train):
  325. return module_id in set(self._policies_to_train)
  326. else:
  327. return self._policies_to_train(module_id, multi_agent_batch)
  328. @OverrideToImplementCustomLogic
  329. @staticmethod
  330. def _map_to_episodes(
  331. is_multi_agent: bool,
  332. batch: Dict[str, Union[list, np.ndarray]],
  333. schema: Dict[str, str] = SCHEMA,
  334. to_numpy: bool = False,
  335. input_compress_columns: Optional[List[str]] = None,
  336. ignore_final_observation: Optional[bool] = False,
  337. observation_space: gym.Space = None,
  338. action_space: gym.Space = None,
  339. **kwargs: Dict[str, Any],
  340. ) -> Dict[str, List[EpisodeType]]:
  341. """Maps a batch of data to episodes."""
  342. # Set to empty list, if `None`.
  343. input_compress_columns = input_compress_columns or []
  344. episodes = []
  345. for i, obs in enumerate(batch[schema[Columns.OBS]]):
  346. # If multi-agent we need to extract the agent ID.
  347. # TODO (simon): Check, what happens with the module ID.
  348. if is_multi_agent:
  349. agent_id = (
  350. batch[schema[Columns.AGENT_ID]][i]
  351. if Columns.AGENT_ID in batch
  352. # The old stack uses "agent_index" instead of "agent_id".
  353. # TODO (simon): Remove this as soon as we are new stack only.
  354. else (
  355. batch[schema["agent_index"]][i]
  356. if schema["agent_index"] in batch
  357. else None
  358. )
  359. )
  360. else:
  361. agent_id = None
  362. if is_multi_agent:
  363. # TODO (simon): Add support for multi-agent episodes.
  364. pass
  365. else:
  366. # Unpack observations, if needed.
  367. unpacked_obs = (
  368. unpack_if_needed(obs)
  369. if Columns.OBS in input_compress_columns
  370. else obs
  371. )
  372. # Set the next observation.
  373. if ignore_final_observation:
  374. unpacked_next_obs = tree.map_structure(
  375. lambda x: 0 * x, copy.deepcopy(unpacked_obs)
  376. )
  377. else:
  378. unpacked_next_obs = (
  379. unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i])
  380. if Columns.OBS in input_compress_columns
  381. else batch[schema[Columns.NEXT_OBS]][i]
  382. )
  383. # Build a single-agent episode with a single row of the batch.
  384. episode = SingleAgentEpisode(
  385. id_=str(batch[schema[Columns.EPS_ID]][i])
  386. if schema[Columns.EPS_ID] in batch
  387. else uuid.uuid4().hex,
  388. agent_id=agent_id,
  389. # Observations might be (a) serialized and/or (b) converted
  390. # to a JSONable (when a composite space was used). We unserialize
  391. # and then reconvert from JSONable to space sample.
  392. observations=[unpacked_obs, unpacked_next_obs],
  393. infos=[
  394. {},
  395. batch[schema[Columns.INFOS]][i]
  396. if schema[Columns.INFOS] in batch
  397. else {},
  398. ],
  399. # Actions might be (a) serialized and/or (b) converted to a JSONable
  400. # (when a composite space was used). We unserializer and then
  401. # reconvert from JSONable to space sample.
  402. actions=[
  403. unpack_if_needed(batch[schema[Columns.ACTIONS]][i])
  404. if Columns.ACTIONS in input_compress_columns
  405. else batch[schema[Columns.ACTIONS]][i]
  406. ],
  407. rewards=[batch[schema[Columns.REWARDS]][i]],
  408. terminated=batch[
  409. schema[Columns.TERMINATEDS]
  410. if schema[Columns.TERMINATEDS] in batch
  411. else "dones"
  412. ][i],
  413. truncated=batch[schema[Columns.TRUNCATEDS]][i]
  414. if schema[Columns.TRUNCATEDS] in batch
  415. else False,
  416. # TODO (simon): Results in zero-length episodes in connector.
  417. # t_started=batch[Columns.T if Columns.T in batch else
  418. # "unroll_id"][i][0],
  419. # TODO (simon): Single-dimensional columns are not supported.
  420. # Extra model outputs might be serialized. We unserialize them here
  421. # if needed.
  422. # TODO (simon): Check, if we need here also reconversion from
  423. # JSONable in case of composite spaces.
  424. extra_model_outputs={
  425. k: [
  426. unpack_if_needed(v[i])
  427. if k in input_compress_columns
  428. else v[i]
  429. ]
  430. for k, v in batch.items()
  431. if (
  432. k not in schema
  433. and k not in schema.values()
  434. and k not in ["dones", "agent_index", "type"]
  435. )
  436. },
  437. len_lookback_buffer=0,
  438. )
  439. if to_numpy:
  440. episode.to_numpy()
  441. episodes.append(episode)
  442. # Note, `map_batches` expects a `Dict` as return value.
  443. return {"episodes": episodes}
  444. @OverrideToImplementCustomLogic
  445. @staticmethod
  446. def _map_sample_batch_to_episode(
  447. is_multi_agent: bool,
  448. batch: Dict[str, Union[list, np.ndarray]],
  449. schema: Dict[str, str] = SCHEMA,
  450. to_numpy: bool = False,
  451. input_compress_columns: Optional[List[str]] = None,
  452. ) -> Dict[str, List[EpisodeType]]:
  453. """Maps an old stack `SampleBatch` to new stack episodes."""
  454. # Set `input_compress_columns` to an empty `list` if `None`.
  455. input_compress_columns = input_compress_columns or []
  456. # TODO (simon): CHeck, if needed. It could possibly happen that a batch contains
  457. # data from different episodes. Merging and resplitting the batch would then
  458. # be the solution.
  459. # Check, if batch comes actually from multiple episodes.
  460. # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1
  461. # Define a container to collect episodes.
  462. episodes = []
  463. # Loop over `SampleBatch`es in the `ray.data` batch (a dict).
  464. for i, obs in enumerate(batch[schema[Columns.OBS]]):
  465. # If multi-agent we need to extract the agent ID.
  466. # TODO (simon): Check, what happens with the module ID.
  467. if is_multi_agent:
  468. agent_id = (
  469. # The old stack uses "agent_index" instead of "agent_id".
  470. batch[schema["agent_index"]][i][0]
  471. if schema["agent_index"] in batch
  472. else None
  473. )
  474. else:
  475. agent_id = None
  476. if is_multi_agent:
  477. # TODO (simon): Add support for multi-agent episodes.
  478. pass
  479. else:
  480. # Unpack observations, if needed. Note, observations could
  481. # be either compressed by their entirety (the complete batch
  482. # column) or individually (each column entry).
  483. if isinstance(obs, str):
  484. # Decompress the observations if we have a string, i.e.
  485. # observations are compressed in their entirety.
  486. obs = unpack_if_needed(obs)
  487. # Convert to a list of arrays. This is needed as input by
  488. # the `SingleAgentEpisode`.
  489. obs = [obs[i, ...] for i in range(obs.shape[0])]
  490. # Otherwise observations are only compressed inside of the
  491. # batch column (if at all).
  492. elif isinstance(obs, np.ndarray):
  493. # Unpack observations, if they are compressed otherwise we
  494. # simply convert to a list, which is needed by the
  495. # `SingleAgentEpisode`.
  496. obs = (
  497. unpack_if_needed(obs.tolist())
  498. if schema[Columns.OBS] in input_compress_columns
  499. else obs.tolist()
  500. )
  501. else:
  502. raise TypeError(
  503. f"Unknown observation type: {type(obs)}. When mapping "
  504. "from old recorded `SampleBatches` batched "
  505. "observations should be either of type `np.array` "
  506. "or - if the column is compressed - of `str` type."
  507. )
  508. if schema[Columns.NEXT_OBS] in batch:
  509. # Append the last `new_obs` to get the correct length of
  510. # observations.
  511. obs.append(
  512. unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1])
  513. if schema[Columns.OBS] in input_compress_columns
  514. else batch[schema[Columns.NEXT_OBS]][i][-1]
  515. )
  516. else:
  517. # Otherwise we duplicate the last observation.
  518. obs.append(obs[-1])
  519. # Check, if we have `done`, `truncated`, or `terminated`s in
  520. # the batch.
  521. if (
  522. schema[Columns.TRUNCATEDS] in batch
  523. and schema[Columns.TERMINATEDS] in batch
  524. ):
  525. truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
  526. terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
  527. elif (
  528. schema[Columns.TRUNCATEDS] in batch
  529. and schema[Columns.TERMINATEDS] not in batch
  530. ):
  531. truncated = batch[schema[Columns.TRUNCATEDS]][i][-1]
  532. terminated = False
  533. elif (
  534. schema[Columns.TRUNCATEDS] not in batch
  535. and schema[Columns.TERMINATEDS] in batch
  536. ):
  537. terminated = batch[schema[Columns.TERMINATEDS]][i][-1]
  538. truncated = False
  539. elif "done" in batch:
  540. terminated = batch["done"][i][-1]
  541. truncated = False
  542. # Otherwise, if no `terminated`, nor `truncated` nor `done`
  543. # is given, we consider the episode as terminated.
  544. else:
  545. terminated = True
  546. truncated = False
  547. # Create a `SingleAgentEpisode`.
  548. episode = SingleAgentEpisode(
  549. # If the recorded episode has an ID we use this ID,
  550. # otherwise we generate a new one.
  551. id_=str(batch[schema[Columns.EPS_ID]][i][0])
  552. if schema[Columns.EPS_ID] in batch
  553. else uuid.uuid4().hex,
  554. agent_id=agent_id,
  555. observations=obs,
  556. infos=(
  557. batch[schema[Columns.INFOS]][i]
  558. if schema[Columns.INFOS] in batch
  559. else [{}] * len(obs)
  560. ),
  561. # Actions might be (a) serialized. We unserialize them here.
  562. actions=(
  563. unpack_if_needed(batch[schema[Columns.ACTIONS]][i])
  564. if Columns.ACTIONS in input_compress_columns
  565. else batch[schema[Columns.ACTIONS]][i]
  566. ),
  567. rewards=batch[schema[Columns.REWARDS]][i],
  568. terminated=terminated,
  569. truncated=truncated,
  570. # TODO (simon): Results in zero-length episodes in connector.
  571. # t_started=batch[Columns.T if Columns.T in batch else
  572. # "unroll_id"][i][0],
  573. # TODO (simon): Single-dimensional columns are not supported.
  574. # Extra model outputs might be serialized. We unserialize them here
  575. # if needed.
  576. # TODO (simon): Check, if we need here also reconversion from
  577. # JSONable in case of composite spaces.
  578. extra_model_outputs={
  579. k: unpack_if_needed(v[i])
  580. if k in input_compress_columns
  581. else v[i]
  582. for k, v in batch.items()
  583. if (
  584. k not in schema
  585. and k not in schema.values()
  586. and k not in ["dones", "agent_index", "type"]
  587. )
  588. },
  589. len_lookback_buffer=0,
  590. )
  591. # Numpy'ized, if necessary.
  592. # TODO (simon, sven): Check, if we should convert all data to lists
  593. # before. Right now only obs are lists.
  594. if to_numpy:
  595. episode.to_numpy()
  596. episodes.append(episode)
  597. # Note, `map_batches` expects a `Dict` as return value.
  598. return {"episodes": episodes}