single_agent_episode.py 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836
  1. import copy
  2. import functools
  3. import time
  4. import uuid
  5. from collections import defaultdict
  6. from typing import Any, Dict, List, Optional, SupportsFloat, Union
  7. import gymnasium as gym
  8. import numpy as np
  9. import tree
  10. from gymnasium.core import ActType, ObsType
  11. from ray._common.deprecation import Deprecated
  12. from ray.rllib.core.columns import Columns
  13. from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
  14. from ray.rllib.policy.sample_batch import SampleBatch
  15. from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict
  16. from ray.rllib.utils.typing import AgentID, ModuleID
  17. from ray.util.annotations import PublicAPI
  18. @PublicAPI(stability="alpha")
  19. class SingleAgentEpisode:
  20. """A class representing RL environment episodes for individual agents.
  21. SingleAgentEpisode stores observations, info dicts, actions, rewards, and all
  22. module outputs (e.g. state outs, action logp, etc..) for an individual agent within
  23. some single-agent or multi-agent environment.
  24. The two main APIs to add data to an ongoing episode are the `add_env_reset()`
  25. and `add_env_step()` methods, which should be called passing the outputs of the
  26. respective gym.Env API calls: `env.reset()` and `env.step()`.
  27. A SingleAgentEpisode might also only represent a chunk of an episode, which is
  28. useful for cases, in which partial (non-complete episode) sampling is performed
  29. and collected episode data has to be returned before the actual gym.Env episode has
  30. finished (see `SingleAgentEpisode.cut()`). In order to still maintain visibility
  31. onto past experiences within such a "cut" episode, SingleAgentEpisode instances
  32. can have a "lookback buffer" of n timesteps at their beginning (left side), which
  33. solely exists for the purpose of compiling extra data (e.g. "prev. reward"), but
  34. is not considered part of the finished/packaged episode (b/c the data in the
  35. lookback buffer is already part of a previous episode chunk).
  36. Powerful getter methods, such as `get_observations()` help collect different types
  37. of data from the episode at individual time indices or time ranges, including the
  38. "lookback buffer" range described above. For example, to extract the last 4 rewards
  39. of an ongoing episode, one can call `self.get_rewards(slice(-4, None))` or
  40. `self.rewards[-4:]`. This would work, even if the ongoing SingleAgentEpisode is
  41. a continuation chunk from a much earlier started episode, as long as it has a
  42. lookback buffer size of sufficient size.
  43. Examples:
  44. .. testcode::
  45. import gymnasium as gym
  46. import numpy as np
  47. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  48. # Construct a new episode (without any data in it yet).
  49. episode = SingleAgentEpisode()
  50. assert len(episode) == 0
  51. # Fill the episode with some data (10 timesteps).
  52. env = gym.make("CartPole-v1")
  53. obs, infos = env.reset()
  54. episode.add_env_reset(obs, infos)
  55. # Even with the initial obs/infos, the episode is still considered len=0.
  56. assert len(episode) == 0
  57. for _ in range(5):
  58. action = env.action_space.sample()
  59. obs, reward, term, trunc, infos = env.step(action)
  60. episode.add_env_step(
  61. observation=obs,
  62. action=action,
  63. reward=reward,
  64. terminated=term,
  65. truncated=trunc,
  66. infos=infos,
  67. )
  68. assert len(episode) == 5
  69. # We can now access information from the episode via the getter APIs.
  70. # Get the last 3 rewards (in a batch of size 3).
  71. episode.get_rewards(slice(-3, None)) # same as `episode.rewards[-3:]`
  72. # Get the most recent action (single item, not batched).
  73. # This works regardless of the action space or whether the episode has
  74. # been numpy'ized or not (see below).
  75. episode.get_actions(-1) # same as episode.actions[-1]
  76. # Looking back from ts=1, get the previous 4 rewards AND fill with 0.0
  77. # in case we go over the beginning (ts=0). So we would expect
  78. # [0.0, 0.0, 0.0, r0] to be returned here, where r0 is the very first received
  79. # reward in the episode:
  80. episode.get_rewards(slice(-4, 0), neg_index_as_lookback=True, fill=0.0)
  81. # Note the use of fill=0.0 here (fill everything that's out of range with this
  82. # value) AND the argument `neg_index_as_lookback=True`, which interprets
  83. # negative indices as being left of ts=0 (e.g. -1 being the timestep before
  84. # ts=0).
  85. # Assuming we had a complex action space (nested gym.spaces.Dict) with one or
  86. # more elements being Discrete or MultiDiscrete spaces:
  87. # 1) The `fill=...` argument would still work, filling all spaces (Boxes,
  88. # Discrete) with that provided value.
  89. # 2) Setting the flag `one_hot_discrete=True` would convert those discrete
  90. # sub-components automatically into one-hot (or multi-one-hot) tensors.
  91. # This simplifies the task of having to provide the previous 4 (nested and
  92. # partially discrete/multi-discrete) actions for each timestep within a training
  93. # batch, thereby filling timesteps before the episode started with 0.0s and
  94. # one-hot'ing the discrete/multi-discrete components in these actions:
  95. episode = SingleAgentEpisode(action_space=gym.spaces.Dict({
  96. "a": gym.spaces.Discrete(3),
  97. "b": gym.spaces.MultiDiscrete([2, 3]),
  98. "c": gym.spaces.Box(-1.0, 1.0, (2,)),
  99. }))
  100. # ... fill episode with data ...
  101. episode.add_env_reset(observation=0)
  102. # ... from a few steps.
  103. episode.add_env_step(
  104. observation=1,
  105. action={"a":0, "b":np.array([1, 2]), "c":np.array([.5, -.5], np.float32)},
  106. reward=1.0,
  107. )
  108. # In your connector
  109. prev_4_a = []
  110. # Note here that len(episode) does NOT include the lookback buffer.
  111. for ts in range(len(episode)):
  112. prev_4_a.append(
  113. episode.get_actions(
  114. indices=slice(ts - 4, ts),
  115. # Make sure negative indices are interpreted as
  116. # "into lookback buffer"
  117. neg_index_as_lookback=True,
  118. # Zero-out everything even further before the lookback buffer.
  119. fill=0.0,
  120. # Take care of discrete components (get ready as NN input).
  121. one_hot_discrete=True,
  122. )
  123. )
  124. # Finally, convert from list of batch items to a struct (same as action space)
  125. # of batched (numpy) arrays, in which all leafs have B==len(prev_4_a).
  126. from ray.rllib.utils.spaces.space_utils import batch
  127. prev_4_actions_col = batch(prev_4_a)
  128. """
  # Fixed set of instance attributes. Because the class declares `__slots__`,
  # instances carry no per-instance `__dict__`: any attribute assigned on `self`
  # must either be listed here or be provided by a descriptor (e.g. a property)
  # defined on the class. Names with a leading underscore are internal state.
  __slots__ = (
      "actions",
      "agent_id",
      "extra_model_outputs",
      "id_",
      "infos",
      "is_terminated",
      "is_truncated",
      "module_id",
      "multi_agent_episode_id",
      "observations",
      "rewards",
      "t",
      "t_started",
      "_action_space",
      "_last_added_observation",
      "_last_added_infos",
      "_last_step_time",
      "_observation_space",
      "_start_time",
      "_custom_data",
  )
  def __init__(
      self,
      id_: Optional[str] = None,
      *,
      observations: Optional[Union[List[ObsType], InfiniteLookbackBuffer]] = None,
      observation_space: Optional[gym.Space] = None,
      infos: Optional[Union[List[Dict], InfiniteLookbackBuffer]] = None,
      actions: Optional[Union[List[ActType], InfiniteLookbackBuffer]] = None,
      action_space: Optional[gym.Space] = None,
      rewards: Optional[Union[List[SupportsFloat], InfiniteLookbackBuffer]] = None,
      terminated: bool = False,
      truncated: bool = False,
      extra_model_outputs: Optional[Dict[str, Any]] = None,
      t_started: Optional[int] = None,
      len_lookback_buffer: Union[int, str] = "auto",
      agent_id: Optional[AgentID] = None,
      module_id: Optional[ModuleID] = None,
      multi_agent_episode_id: Optional[int] = None,
  ):
      """Initializes a SingleAgentEpisode instance.

      This constructor can be called with or without already sampled data, part of
      which might then go into the lookback buffer.

      Args:
          id_: Unique identifier for this episode. If no ID is provided the
              constructor generates a unique hexadecimal code for the id.
          observations: Either a list of individual observations from a sampling or
              an already instantiated `InfiniteLookbackBuffer` object (possibly
              with observation data in it). If a list, will construct the buffer
              automatically (given the data and the `len_lookback_buffer` argument).
          observation_space: An optional gym.Space, which all individual observations
              should abide to. If not None and this SingleAgentEpisode is numpy'ized
              (via the `self.to_numpy()` method), and data is appended or set, the new
              data will be checked for correctness.
          infos: Either a list of individual info dicts from a sampling or
              an already instantiated `InfiniteLookbackBuffer` object (possibly
              with info dicts in it). If a list, will construct the buffer
              automatically (given the data and the `len_lookback_buffer` argument).
              If None or empty, one empty dict per provided observation is used.
          actions: Either a list of individual actions from a sampling or
              an already instantiated `InfiniteLookbackBuffer` object (possibly
              with action data in it). If a list, will construct the buffer
              automatically (given the data and the `len_lookback_buffer` argument).
          action_space: An optional gym.Space, which all individual actions
              should abide to. If not None and this SingleAgentEpisode is numpy'ized
              (via the `self.to_numpy()` method), and data is appended or set, the new
              data will be checked for correctness.
          rewards: Either a list of individual rewards from a sampling or
              an already instantiated `InfiniteLookbackBuffer` object (possibly
              with reward data in it). If a list, will construct the buffer
              automatically (given the data and the `len_lookback_buffer` argument).
          terminated: A boolean indicating, if the episode is already terminated.
          truncated: A boolean indicating, if the episode has been truncated.
          extra_model_outputs: A dict mapping string keys to either lists of
              individual extra model output tensors (e.g. `action_logp` or
              `state_outs`) from a sampling or to already instantiated
              `InfiniteLookbackBuffer` object (possibly with extra model output data
              in it). If mapping is to lists, will construct the buffers automatically
              (given the data and the `len_lookback_buffer` argument).
          t_started: Optional. The (global) timestep at which this episode (chunk)
              starts, excluding a possible lookback buffer. Defaults to 0 if not
              provided. The current timestep `self.t` is initialized to
              `len(self.rewards) + t_started`.
          len_lookback_buffer: The size of the (optional) lookback buffers to keep in
              front of this Episode for each type of data (observations, actions,
              etc..). If larger than 0, the first `len_lookback_buffer`
              items of each type of data are interpreted as NOT part of this actual
              episode chunk, but instead serve as "historical" record that may be
              viewed and used to derive new data from. For example, it might be
              necessary to have a lookback buffer of four if you would like to do
              observation frame stacking and your episode has been cut and you're now
              operating on a new chunk (continuing from the cut one). Then, for the
              first 3 items, you would have to be able to look back into the old
              chunk's data.
              If `len_lookback_buffer` is "auto" (default), will interpret all
              provided data in the constructor as part of the lookback buffers.
              A value larger than the number of provided rewards is reduced to
              that number.
          agent_id: An optional AgentID indicating which agent this episode belongs
              to. This information is stored under `self.agent_id` and only serves
              reference purposes.
          module_id: An optional ModuleID indicating which RLModule this episode
              belongs to. Normally, this information is obtained by querying an
              `agent_to_module_mapping_fn` with a given agent ID. This information
              is stored under `self.module_id` and only serves reference purposes.
          multi_agent_episode_id: An optional EpisodeID of the encapsulating
              `MultiAgentEpisode` that this `SingleAgentEpisode` belongs to.
      """
      self.id_ = id_ or uuid.uuid4().hex

      self.agent_id = agent_id
      self.module_id = module_id
      self.multi_agent_episode_id = multi_agent_episode_id

      # If the lookback buffer length is "auto" (or exceeds the amount of data
      # given), treat all already provided data as lookback: the lookback length
      # becomes the number of provided rewards for all data types.
      len_rewards = len(rewards) if rewards is not None else 0
      if len_lookback_buffer == "auto" or len_lookback_buffer > len_rewards:
          len_lookback_buffer = len_rewards

      # Default infos: one empty dict per provided observation, so that
      # observations and infos always have the same length (see `validate()`).
      infos = infos or [{} for _ in range(len(observations or []))]

      # Observations: t0 (initial obs) to T.
      # The private field is initialized first, before `observation_space` is
      # assigned further below.
      self._observation_space = None
      if isinstance(observations, InfiniteLookbackBuffer):
          self.observations = observations
      else:
          self.observations = InfiniteLookbackBuffer(
              data=observations,
              lookback=len_lookback_buffer,
          )
      # NOTE(review): `observation_space` is not listed in `__slots__`, so this
      # assignment presumably goes through a property setter defined elsewhere in
      # the class (which is why it happens after `self.observations` exists) --
      # confirm.
      self.observation_space = observation_space
      # Infos: t0 (initial info) to T.
      if isinstance(infos, InfiniteLookbackBuffer):
          self.infos = infos
      else:
          self.infos = InfiniteLookbackBuffer(
              data=infos,
              lookback=len_lookback_buffer,
          )
      # Actions: t1 to T.
      self._action_space = None
      if isinstance(actions, InfiniteLookbackBuffer):
          self.actions = actions
      else:
          self.actions = InfiniteLookbackBuffer(
              data=actions,
              lookback=len_lookback_buffer,
          )
      # NOTE(review): same as for `observation_space` above -- presumably a
      # property setter defined elsewhere in the class; confirm.
      self.action_space = action_space
      # Rewards: t1 to T. The rewards buffer is given a scalar float32 Box space.
      if isinstance(rewards, InfiniteLookbackBuffer):
          self.rewards = rewards
      else:
          self.rewards = InfiniteLookbackBuffer(
              data=rewards,
              lookback=len_lookback_buffer,
              space=gym.spaces.Box(float("-inf"), float("inf"), (), np.float32),
          )

      # obs[-1] is the final observation in the episode.
      self.is_terminated = terminated
      # obs[-1] is the last obs in a truncated-by-the-env episode (there will be no
      # more observations in following chunks for this episode).
      self.is_truncated = truncated

      # Extra model outputs, e.g. `action_dist_input` needed in the batch.
      self.extra_model_outputs = {}
      for k, v in (extra_model_outputs or {}).items():
          if isinstance(v, InfiniteLookbackBuffer):
              self.extra_model_outputs[k] = v
          else:
              # We cannot use the defaultdict's own constructor here as this would
              # auto-set the lookback buffer to 0 (there is no data passed to that
              # constructor). Then, when we manually have to set the data property,
              # the lookback buffer would still be (incorrectly) 0.
              self.extra_model_outputs[k] = InfiniteLookbackBuffer(
                  data=v, lookback=len_lookback_buffer
              )

      # The (global) timestep when this episode (possibly an episode chunk) started,
      # excluding a possible lookback buffer.
      self.t_started = t_started or 0
      # The current (global) timestep in the episode (possibly an episode chunk).
      self.t = len(self.rewards) + self.t_started

      # Cache for custom data. May be used to store custom metrics from within a
      # callback for the ongoing episode (e.g. render images).
      self._custom_data = {}

      # Keep timer stats on deltas between steps.
      self._start_time = None
      self._last_step_time = None

      # Most recently added observation/infos (updated by `add_env_reset()` and
      # `add_env_step()`).
      self._last_added_observation = None
      self._last_added_infos = None

      # Validate the episode data thus far.
      self.validate()
  def add_env_reset(
      self,
      observation: ObsType,
      infos: Optional[Dict] = None,
  ) -> None:
      """Stores the output of an `env.reset()` call as this episode's first entry.

      Only the initial observation and the initial info dict are recorded; no
      action or reward exists yet, so the timestep counters remain untouched.

      Args:
          observation: The initial observation returned by `env.reset()`.
          infos: An (optional) info dict returned by `env.reset()`. A missing or
              empty value is stored as a new empty dict.
      """
      # The episode must be untouched: neither reset nor done, no observations
      # stored, and both timestep counters still at 0.
      assert not (self.is_reset or self.is_done)
      assert len(self.observations) == 0
      assert self.t == self.t_started and self.t_started == 0

      reset_infos = infos if infos else {}

      self.observations.append(observation)
      self.infos.append(reset_infos)

      # Remember what was added most recently.
      self._last_added_observation = observation
      self._last_added_infos = reset_infos

      # Make sure the stored data is still consistent.
      self.validate()

      # The episode timer starts with the reset.
      self._start_time = time.perf_counter()
  def add_env_step(
      self,
      observation: ObsType,
      action: ActType,
      reward: SupportsFloat,
      infos: Optional[Dict[str, Any]] = None,
      *,
      terminated: bool = False,
      truncated: bool = False,
      extra_model_outputs: Optional[Dict[str, Any]] = None,
  ) -> None:
      """Adds results of an `env.step()` call (including the action) to this episode.

      This data consists of an observation and info dict, an action, a reward,
      terminated/truncated flags, and extra model outputs (e.g. action probabilities
      or RNN internal state outputs).

      Args:
          observation: The next observation received from the environment after(!)
              taking `action`.
          action: The last action used by the agent during the call to `env.step()`.
          reward: The last reward received by the agent after taking `action`.
          infos: The last info received from the environment after taking `action`.
              A missing or empty value is stored as a new empty dict.
          terminated: A boolean indicating, if the environment has been
              terminated (after taking `action`).
          truncated: A boolean indicating, if the environment has been
              truncated (after taking `action`).
          extra_model_outputs: The last timestep's specific model outputs.
              These are normally outputs of an RLModule that were computed along with
              `action`, e.g. `action_logp` or `action_dist_inputs`.

      Raises:
          AssertionError: If the episode is already done, if a (periodic) space
              check fails, or if `validate()` finds inconsistent data lengths.
      """
      # Cannot add data to an already done episode.
      assert (
          not self.is_done
      ), "The agent is already done: no data can be added to its episode."

      self.observations.append(observation)
      self.actions.append(action)
      self.rewards.append(reward)
      infos = infos or {}
      self.infos.append(infos)
      self.t += 1
      if extra_model_outputs is not None:
          for k, v in extra_model_outputs.items():
              if k not in self.extra_model_outputs:
                  # First time this key shows up: start a new buffer for it.
                  self.extra_model_outputs[k] = InfiniteLookbackBuffer([v])
              else:
                  self.extra_model_outputs[k].append(v)
      self.is_terminated = terminated
      self.is_truncated = truncated

      # Remember what was added most recently.
      self._last_added_observation = observation
      self._last_added_infos = infos

      # Only check spaces if numpy'ized AND only on every 100th timestep (the
      # check is a sanity check and is skipped on all other steps).
      # The former condition `self.t % 100` was truthy on every step EXCEPT
      # multiples of 100, i.e. the exact opposite of the intended sampling.
      if self.is_numpy and self.t % 100 == 0:
          if self.observation_space is not None:
              assert self.observation_space.contains(observation), (
                  f"`observation` {observation} does NOT fit SingleAgentEpisode's "
                  f"observation_space: {self.observation_space}!"
              )
          if self.action_space is not None:
              assert self.action_space.contains(action), (
                  f"`action` {action} does NOT fit SingleAgentEpisode's "
                  f"action_space: {self.action_space}!"
              )

      # Validate our data.
      self.validate()

      # Step time stats. If the timer was never started (no `add_env_reset()` on
      # this object), start it now.
      self._last_step_time = time.perf_counter()
      if self._start_time is None:
          self._start_time = self._last_step_time
  def validate(self) -> None:
      """Checks that the lengths of all stored data are consistent.

      Observations and infos must always have the same length. An episode without
      any observation must not hold rewards, actions, or extra model outputs.
      Otherwise there must be exactly one more observation (and info) than
      rewards, actions, and items per extra-model-output key, because of the
      initial (reset) observation.

      Raises:
          AssertionError: If any of these length relations is violated.
      """
      num_obs = len(self.observations)
      num_infos = len(self.infos)
      assert num_obs == num_infos

      num_rewards = len(self.rewards)
      num_actions = len(self.actions)

      # Nothing has been added yet: everything else must be empty, too.
      if num_obs == 0:
          assert num_infos == num_rewards == num_actions == 0
          for key, buffer in self.extra_model_outputs.items():
              assert len(buffer) == 0, (key, buffer, buffer.data, len(buffer))
          return

      # One more obs/info than rewards/actions (reset/last-obs logic of an MDP).
      assert num_obs == num_infos == num_rewards + 1 == num_actions + 1, (
          num_obs,
          num_infos,
          num_rewards,
          num_actions,
      )
      num_steps = num_obs - 1
      for buffer in self.extra_model_outputs.values():
          assert len(buffer) == num_steps
  @property
  def custom_data(self):
      """The dict used as a cache for custom, user-defined data of this episode.

      The very same dict object is returned on each access, so changes made by
      the caller are stored in the episode.
      """
      return self._custom_data
  @property
  def is_reset(self) -> bool:
      """True, if at least one observation has been stored in this episode."""
      num_observations = len(self.observations)
      return num_observations > 0
  @property
  def is_numpy(self) -> bool:
      """Whether this episode's data has been numpy'ized already.

      The rewards buffer serves as the indicator: its `finalized` flag is
      returned as-is.
      """
      rewards_buffer = self.rewards
      return rewards_buffer.finalized
  @property
  def is_done(self) -> bool:
      """Whether the episode is over, i.e. either terminated or truncated.

      A done episode cannot be extended anymore: both `self.add_env_step()` and
      `self.concat_episode()` assert that this property is False.
      """
      terminated = self.is_terminated
      return terminated if terminated else self.is_truncated
  def to_numpy(self) -> "SingleAgentEpisode":
      """Finalizes this episode's data buffers, i.e. numpy'izes the episode.

      The lists of (possibly complex) items held by this episode (e.g. if there is
      a dict observation space) are converted into (possibly complex) structs,
      whose leafs are numpy arrays. Each of these leaf arrays has the same length
      (batch dimension) as the original list had.

      The observations buffer is always finalized. Actions, rewards and all extra
      model outputs are only finalized if the episode holds at least one timestep
      (`len(self) > 0`).

      Note that the data under Columns.INFOS is NEVER numpy'ized and remains a
      list (normally, a list of the original, env-returned dicts). This is due to
      the heterogeneous nature of INFOS returned by envs, which would make it
      unwieldy to convert this information to numpy arrays.

      This method is meant to be called once no more data will be added to this
      episode via `self.add_env_step()`.

      Examples:

      .. testcode::

          import numpy as np

          from ray.rllib.env.single_agent_episode import SingleAgentEpisode

          episode = SingleAgentEpisode(
              observations=[0, 1, 2, 3],
              actions=[1, 2, 3],
              rewards=[1, 2, 3],
              # Note: terminated/truncated have nothing to do with an episode
              # being numpy'ized or not (via the `self.to_numpy()` method)!
              terminated=False,
              len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
          )
          # Episode has not been numpy'ized yet.
          assert not episode.is_numpy
          # We are still operating on lists.
          assert episode.get_observations([1]) == [1]
          assert episode.get_observations(slice(None, 2)) == [0, 1]
          # We can still add data (and even add the terminated=True flag).
          episode.add_env_step(
              observation=4,
              action=4,
              reward=4,
              terminated=True,
          )
          # Still NOT numpy'ized.
          assert not episode.is_numpy

          # Numpy'ize the episode.
          episode.to_numpy()
          assert episode.is_numpy

          # No more data should be added from here on (the episode is also
          # terminated at this point).
          # episode.add_env_step(observation=5, action=5, reward=5)

          # Everything is now numpy arrays (with 0-axis of size
          # B=[len of requested slice]).
          assert isinstance(episode.get_observations([1]), np.ndarray)  # B=1
          assert isinstance(episode.actions[0:2], np.ndarray)  # B=2
          assert isinstance(episode.rewards[1:4], np.ndarray)  # B=3

      Returns:
          This `SingleAgentEpisode` object with the converted numpy data.
      """
      # Observations are always finalized, even if no step has been taken yet.
      self.observations.finalize()

      # Per-step data is only finalized if at least one step exists.
      if len(self) > 0:
          per_step_buffers = [
              self.actions,
              self.rewards,
              *self.extra_model_outputs.values(),
          ]
          for buffer in per_step_buffers:
              buffer.finalize()

      return self
  def concat_episode(self, other: "SingleAgentEpisode") -> None:
      """Appends the episode chunk `other` to the right side of `self`, in place.

      In order for this to work, both chunks (`self` and `other`) must fit
      together. This is checked via the IDs (must be identical), the timestep
      counters (`self.t` must be the same as `other.t_started`), as well as the
      observations at the concatenation boundary (the last observation of `self`
      must equal the first observation of `other`). Also, `self.is_done` must not
      be True, meaning `self.is_terminated` and `self.is_truncated` are both
      False.

      Nothing is returned; `self` is extended by the data of `other` and takes
      over `other`'s current timestep and terminated/truncated state.

      Args:
          other: The other `SingleAgentEpisode` to be concatenated to this one.

      Raises:
          AssertionError: If the two chunks do not fit together (see above) or if
              `other` has an extra-model-output key that `self` does not have.
      """
      # Same episode ID on both sides.
      assert other.id_ == self.id_
      # NOTE (sven): This is what we agreed on. As the replay buffers must be
      # able to concatenate.
      assert not self.is_done
      # `other` has to start exactly where `self` ends.
      assert self.t == other.t_started, f"{self.t=}, {other.t_started=}"
      # `other` must be consistent in itself.
      other.validate()

      # The boundary observations must have the same (nested) structure and
      # every single leaf must be equal.
      first_obs_of_other = other.observations[0]
      last_obs_of_self = self.observations[-1]
      tree.assert_same_structure(first_obs_of_other, last_obs_of_self)
      leaves_equal = tree.flatten(
          tree.map_structure(np.array_equal, first_obs_of_other, last_obs_of_self)
      )
      assert np.all(leaves_equal)

      # Drop our own last observation and infos; `other` brings its own first
      # observation and infos for that very timestep.
      self.observations.pop()
      self.infos.pop()

      # Extend our buffers, one after the other, with the data of `other`.
      for own_buffer, fetch_from_other in (
          (self.observations, other.get_observations),
          (self.actions, other.get_actions),
          (self.rewards, other.get_rewards),
          (self.infos, other.get_infos),
      ):
          own_buffer.extend(fetch_from_other())

      # Take over the timestep and the done-state of `other`.
      self.t = other.t
      if other.is_terminated:
          self.is_terminated = True
      elif other.is_truncated:
          self.is_truncated = True

      # Every extra-model-output key of `other` must already exist in `self`.
      for key in other.extra_model_outputs:
          assert key in self.extra_model_outputs
          self.extra_model_outputs[key].extend(other.get_extra_model_outputs(key))

      # Merge with `other`'s custom_data, but give `other` priority b/c we assume
      # that as a follow-up chunk of `self` other has a more complete version of
      # `custom_data`.
      self.custom_data.update(other.custom_data)

      # Check the consistency of the concatenated data.
      self.validate()
  def cut(self, len_lookback_buffer: int = 0) -> "SingleAgentEpisode":
      """Creates a zero-length successor chunk that continues this episode.

      The successor carries the same ID as `self` and starts at `self`'s current
      timestep.

      Without a lookback buffer (`len_lookback_buffer=0`), the successor only
      holds the most recent observation (and info) of `self`, so its length is
      0 (no steps taken yet). With `len_lookback_buffer` > 0, the successor
      additionally receives that many observations, actions, rewards, etc.
      from the right side (end) of `self`. For example, with
      `len_lookback_buffer=2`, the successor's lookback buffer actions are
      identical to `self.actions[-2:]`.

      Use this method when you have to stop building an episode chunk (b/c you
      must return it from somewhere), but need a new episode instance to keep
      building the actual gym.Env episode later. Via the `len_lookback_buffer`
      argument, the continuing chunk (successor) can still "look back" into
      this predecessor's data (at least to some extent, depending on the value
      of `len_lookback_buffer`).

      Args:
          len_lookback_buffer: The number of timesteps to carry over into the
              new chunk as "lookback buffer". A lookback buffer is additional
              data on the left side of the actual episode data, kept for
              visibility purposes only (it is not part of the new chunk
              itself). For example, if `self` ends in actions 5, 6, 7, and 8,
              and we call `self.cut(len_lookback_buffer=2)`, the returned chunk
              already has actions 7 and 8 in it, but still
              `t_started`==t==8 (not 7!) and a length of 0. If `self` does not
              hold enough data yet to fulfil the request, the value of
              `len_lookback_buffer` is automatically adjusted (lowered).

      Returns:
          The successor Episode chunk of this one, with the same ID and state,
          whose only (non-lookback) observation is the last observation in
          `self`.
      """
      assert not self.is_done
      assert len_lookback_buffer >= 0

      # Observations and infos always include the most recent item (even if
      # lookback is 0). Similar to an initial `env.reset()`.
      obs_and_infos_window = slice(-(len_lookback_buffer + 1), None)
      # Everything else only consists of lookback data (empty slice, if no
      # lookback was requested).
      if len_lookback_buffer > 0:
          rest_window = slice(-len_lookback_buffer, None)
      else:
          rest_window = slice(None, 0)

      carried_observations = self.get_observations(indices=obs_and_infos_window)
      carried_infos = self.get_infos(indices=obs_and_infos_window)
      carried_actions = self.get_actions(indices=rest_window)
      carried_rewards = self.get_rewards(indices=rest_window)
      carried_extra_model_outputs = {}
      for key in self.extra_model_outputs.keys():
          carried_extra_model_outputs[key] = self.get_extra_model_outputs(
              key, rest_window
          )

      successor = SingleAgentEpisode(
          # Same ID.
          id_=self.id_,
          observations=carried_observations,
          observation_space=self.observation_space,
          infos=carried_infos,
          actions=carried_actions,
          action_space=self.action_space,
          rewards=carried_rewards,
          extra_model_outputs=carried_extra_model_outputs,
          # Continue with self's current timestep.
          t_started=self.t,
          # Use the length of the provided data as lookback buffer.
          len_lookback_buffer="auto",
      )

      # The successor continues with its own deep copy of `self`'s custom data.
      successor._custom_data = copy.deepcopy(self.custom_data)

      return successor
  633. # TODO (sven): Distinguish between:
  634. # - global index: This is the absolute, global timestep whose values always
  635. # start from 0 (at the env reset). So doing get_observations(0, global_ts=True)
  636. # should always return the exact 1st observation (reset obs), no matter what. In
  637. # case we are in an episode chunk and `fill` or a sufficient lookback buffer is
  638. # provided, this should yield a result. Otherwise, error.
  639. # - global index=False -> indices are relative to the chunk start. If a chunk has
  640. # t_started=6 and we ask for index=0, then return observation at timestep 6
  641. # (t_started).
  def get_observations(
      self,
      indices: Optional[Union[int, List[int], slice]] = None,
      *,
      neg_index_as_lookback: bool = False,
      fill: Optional[Any] = None,
      one_hot_discrete: bool = False,
  ) -> Any:
      """Looks up single observations (or batches of them) in this episode.

      Args:
          indices: Which observation(s) to return.
              A single int returns the individual observation stored at that
              index.
              A list of ints gathers the individual observations at these
              indices into a batch of size len(indices).
              A slice object returns the respective range of observations.
              Negative indices are by default interpreted as "before the end",
              unless `neg_index_as_lookback=True`, in which case they are
              interpreted as "before ts=0" (going back into the lookback
              buffer).
              If None, all observations (from ts=0 to the end) are returned.
          neg_index_as_lookback: If True, negative values in `indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with observations
              [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range
              (ts=0 item is 7), will respond to
              `get_observations(-1, neg_index_as_lookback=True)` with `6` and to
              `get_observations(slice(-2, 1), neg_index_as_lookback=True)` with
              `[5, 6, 7]`.
          fill: An optional value used to fill up the returned results at the
              boundaries. Filling only happens if the requested index range's
              start/stop boundaries exceed the episode's boundaries (including
              the lookback buffer on the left side). This is handy if users
              don't want to worry about reaching such boundaries and want to
              zero-pad. For example, an episode with observations
              [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning
              observations `10` and `11` are part of the lookback buffer) will
              respond to `get_observations(slice(-7, -2), fill=0.0)` with
              `[0.0, 0.0, 10, 11, 12]`.
          one_hot_discrete: If True, returns one-hot vectors (instead of
              int-values) for those sub-components of a (possibly complex)
              observation space that are Discrete or MultiDiscrete. Note that if
              `fill=0` and the requested `indices` are out of the range of our
              data, the returned one-hot vectors will actually be zero-hot (all
              slots zero).

      Examples:

      .. testcode::

          import gymnasium as gym

          from ray.rllib.env.single_agent_episode import SingleAgentEpisode
          from ray.rllib.utils.test_utils import check

          episode = SingleAgentEpisode(
              # Discrete(4) observations (ints between 0 and 4 (excl.))
              observation_space=gym.spaces.Discrete(4),
              observations=[0, 1, 2, 3],
              actions=[1, 2, 3], rewards=[1, 2, 3],  # <- not relevant for this demo
              len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
          )
          # Plain usage (`indices` arg only).
          check(episode.get_observations(-1), 3)
          check(episode.get_observations(0), 0)
          check(episode.get_observations([0, 2]), [0, 2])
          check(episode.get_observations([-1, 0]), [3, 0])
          check(episode.get_observations(slice(None, 2)), [0, 1])
          check(episode.get_observations(slice(-2, None)), [2, 3])
          # Using `fill=...` (requesting slices beyond the boundaries).
          check(episode.get_observations(slice(-6, -2), fill=-9), [-9, -9, 0, 1])
          check(episode.get_observations(slice(2, 5), fill=-7), [2, 3, -7])
          # Using `one_hot_discrete=True`.
          check(episode.get_observations(2, one_hot_discrete=True), [0, 0, 1, 0])
          check(episode.get_observations(3, one_hot_discrete=True), [0, 0, 0, 1])
          check(episode.get_observations(
              slice(0, 3),
              one_hot_discrete=True,
          ), [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]])
          # Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
          check(episode.get_observations(
              -1,
              neg_index_as_lookback=True,  # -1 means one left of ts=0
              fill=0.0,
              one_hot_discrete=True,
          ), [0, 0, 0, 0])  # <- all 0s one-hot tensor (note difference to [1 0 0 0]!)

      Returns:
          The collected observations.
          As a 0-axis batch, if there are several `indices` or a list of exactly
          one index provided OR `indices` is a slice object.
          As single item (B=0 -> no additional 0-axis) if `indices` is a single
          int.
      """
      # The observations buffer does all the work; simply forward the options.
      lookup_options = {
          "indices": indices,
          "neg_index_as_lookback": neg_index_as_lookback,
          "fill": fill,
          "one_hot_discrete": one_hot_discrete,
      }
      return self.observations.get(**lookup_options)
  def get_infos(
      self,
      indices: Optional[Union[int, List[int], slice]] = None,
      *,
      neg_index_as_lookback: bool = False,
      fill: Optional[Any] = None,
  ) -> Any:
      """Looks up single info dicts (or lists of them) in this episode.

      Args:
          indices: Which info dict(s) to return.
              A single int returns the individual info dict stored at that
              index.
              A list of ints gathers the individual info dicts at these indices
              into a list of size len(indices).
              A slice object returns the respective range of info dicts.
              Negative indices are by default interpreted as "before the end",
              unless `neg_index_as_lookback=True`, in which case they are
              interpreted as "before ts=0" (going back into the lookback
              buffer).
              If None, all infos (from ts=0 to the end) are returned.
          neg_index_as_lookback: If True, negative values in `indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with infos
              [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the
              first 3 items are the lookback buffer (ts=0 item is {"a": 7}),
              will respond to `get_infos(-1, neg_index_as_lookback=True)` with
              `{"l":6}` and to
              `get_infos(slice(-2, 1), neg_index_as_lookback=True)` with
              `[{"l":5}, {"l":6}, {"a":7}]`.
          fill: An optional value used to fill up the returned results at the
              boundaries. Filling only happens if the requested index range's
              start/stop boundaries exceed the episode's boundaries (including
              the lookback buffer on the left side). This is handy if users
              don't want to worry about reaching such boundaries and want to
              auto-fill. For example, an episode with infos
              [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and a lookback
              buffer size of 2 (meaning infos {"l":10}, {"l":11} are part of the
              lookback buffer) will respond to
              `get_infos(slice(-7, -2), fill={"o": 0.0})` with
              `[{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]`.

      Examples:

      .. testcode::

          from ray.rllib.env.single_agent_episode import SingleAgentEpisode

          episode = SingleAgentEpisode(
              infos=[{"a":0}, {"b":1}, {"c":2}, {"d":3}],
              # The following is needed, but not relevant for this demo.
              observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3],
              len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
          )
          # Plain usage (`indices` arg only).
          episode.get_infos(-1)  # {"d":3}
          episode.get_infos(0)  # {"a":0}
          episode.get_infos([0, 2])  # [{"a":0},{"c":2}]
          episode.get_infos([-1, 0])  # [{"d":3},{"a":0}]
          episode.get_infos(slice(None, 2))  # [{"a":0},{"b":1}]
          episode.get_infos(slice(-2, None))  # [{"c":2},{"d":3}]
          # Using `fill=...` (requesting slices beyond the boundaries).
          # TODO (sven): This would require a space being provided. Maybe we can
          #  skip this check for infos, which don't have a space anyways.
          # episode.get_infos(slice(-5, -3), fill={"o":-1})  # [{"o":-1},{"a":0}]
          # episode.get_infos(slice(3, 5), fill={"o":-2})  # [{"d":3},{"o":-2}]

      Returns:
          The collected info dicts.
          As a list, if there are several `indices` or a list of exactly one
          index provided OR `indices` is a slice object.
          As single item (one info dict) if `indices` is a single int.
      """
      # The infos buffer does all the work; simply forward the options.
      infos_buffer = self.infos
      return infos_buffer.get(
          indices=indices, neg_index_as_lookback=neg_index_as_lookback, fill=fill
      )
  def get_actions(
      self,
      indices: Optional[Union[int, List[int], slice]] = None,
      *,
      neg_index_as_lookback: bool = False,
      fill: Optional[Any] = None,
      one_hot_discrete: bool = False,
  ) -> Any:
      """Looks up single actions (or batches of them) in this episode.

      Args:
          indices: Which action(s) to return.
              A single int returns the individual action stored at that index.
              A list of ints gathers the individual actions at these indices
              into a batch of size len(indices).
              A slice object returns the respective range of actions.
              Negative indices are by default interpreted as "before the end",
              unless `neg_index_as_lookback=True`, in which case they are
              interpreted as "before ts=0" (going back into the lookback
              buffer).
              If None, all actions (from ts=0 to the end) are returned.
          neg_index_as_lookback: If True, negative values in `indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with actions [4, 5, 6, 7, 8, 9],
              where [4, 5, 6] is the lookback buffer range (ts=0 item is 7),
              will respond to `get_actions(-1, neg_index_as_lookback=True)`
              with `6` and to
              `get_actions(slice(-2, 1), neg_index_as_lookback=True)` with
              `[5, 6, 7]`.
          fill: An optional value used to fill up the returned results at the
              boundaries. Filling only happens if the requested index range's
              start/stop boundaries exceed the episode's boundaries (including
              the lookback buffer on the left side). This is handy if users
              don't want to worry about reaching such boundaries and want to
              zero-pad. For example, an episode with actions
              [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning
              actions `10` and `11` are part of the lookback buffer) will
              respond to `get_actions(slice(-7, -2), fill=0.0)` with
              `[0.0, 0.0, 10, 11, 12]`.
          one_hot_discrete: If True, returns one-hot vectors (instead of
              int-values) for those sub-components of a (possibly complex)
              action space that are Discrete or MultiDiscrete. Note that if
              `fill=0` and the requested `indices` are out of the range of our
              data, the returned one-hot vectors will actually be zero-hot (all
              slots zero).

      Examples:

      .. testcode::

          import gymnasium as gym
          from ray.rllib.env.single_agent_episode import SingleAgentEpisode

          episode = SingleAgentEpisode(
              # Discrete(4) actions (ints between 0 and 4 (excl.))
              action_space=gym.spaces.Discrete(4),
              actions=[1, 2, 3],
              observations=[0, 1, 2, 3], rewards=[1, 2, 3],  # <- not relevant here
              len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
          )
          # Plain usage (`indices` arg only).
          episode.get_actions(-1)  # 3
          episode.get_actions(0)  # 1
          episode.get_actions([0, 2])  # [1, 3]
          episode.get_actions([-1, 0])  # [3, 1]
          episode.get_actions(slice(None, 2))  # [1, 2]
          episode.get_actions(slice(-2, None))  # [2, 3]
          # Using `fill=...` (requesting slices beyond the boundaries).
          episode.get_actions(slice(-5, -2), fill=-9)  # [-9, -9, 1]
          episode.get_actions(slice(1, 5), fill=-7)  # [2, 3, -7, -7]
          # Using `one_hot_discrete=True`.
          episode.get_actions(1, one_hot_discrete=True)  # [0 0 1 0] (action=2)
          episode.get_actions(2, one_hot_discrete=True)  # [0 0 0 1] (action=3)
          episode.get_actions(
              slice(0, 2),
              one_hot_discrete=True,
          )   # [[0 1 0 0], [0 0 1 0]] (actions=1 and 2)
          # Special case: Using `fill=0.0` AND `one_hot_discrete=True`.
          episode.get_actions(
              -1,
              neg_index_as_lookback=True,  # -1 means one left of ts=0
              fill=0.0,
              one_hot_discrete=True,
          )  # [0 0 0 0]  <- all 0s one-hot tensor (note difference to [1 0 0 0]!)

      Returns:
          The collected actions.
          As a 0-axis batch, if there are several `indices` or a list of exactly
          one index provided OR `indices` is a slice object.
          As single item (B=0 -> no additional 0-axis) if `indices` is a single
          int.
      """
      # The actions buffer does all the work; simply forward the options.
      lookup_options = dict(
          indices=indices,
          neg_index_as_lookback=neg_index_as_lookback,
          fill=fill,
          one_hot_discrete=one_hot_discrete,
      )
      return self.actions.get(**lookup_options)
  def get_rewards(
      self,
      indices: Optional[Union[int, List[int], slice]] = None,
      *,
      neg_index_as_lookback: bool = False,
      fill: Optional[float] = None,
  ) -> Any:
      """Looks up single rewards (or batches of them) in this episode.

      Args:
          indices: Which reward(s) to return.
              A single int returns the individual reward stored at that index.
              A list of ints gathers the individual rewards at these indices
              into a batch of size len(indices).
              A slice object returns the respective range of rewards.
              Negative indices are by default interpreted as "before the end",
              unless `neg_index_as_lookback=True`, in which case they are
              interpreted as "before ts=0" (going back into the lookback
              buffer).
              If None, all rewards (from ts=0 to the end) are returned.
          neg_index_as_lookback: If True, negative values in `indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with rewards [4, 5, 6, 7, 8, 9],
              where [4, 5, 6] is the lookback buffer range (ts=0 item is 7),
              will respond to `get_rewards(-1, neg_index_as_lookback=True)`
              with `6` and to
              `get_rewards(slice(-2, 1), neg_index_as_lookback=True)` with
              `[5, 6, 7]`.
          fill: An optional float value used to fill up the returned results at
              the boundaries. Filling only happens if the requested index
              range's start/stop boundaries exceed the episode's boundaries
              (including the lookback buffer on the left side). This is handy if
              users don't want to worry about reaching such boundaries and want
              to zero-pad. For example, an episode with rewards
              [10, 11, 12, 13, 14] and a lookback buffer size of 2 (meaning
              rewards `10` and `11` are part of the lookback buffer) will
              respond to `get_rewards(slice(-7, -2), fill=0.0)` with
              `[0.0, 0.0, 10, 11, 12]`.

      Examples:

      .. testcode::

          from ray.rllib.env.single_agent_episode import SingleAgentEpisode

          episode = SingleAgentEpisode(
              rewards=[1.0, 2.0, 3.0],
              observations=[0, 1, 2, 3], actions=[1, 2, 3],  # <- not relevant here
              len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
          )
          # Plain usage (`indices` arg only).
          episode.get_rewards(-1)  # 3.0
          episode.get_rewards(0)  # 1.0
          episode.get_rewards([0, 2])  # [1.0, 3.0]
          episode.get_rewards([-1, 0])  # [3.0, 1.0]
          episode.get_rewards(slice(None, 2))  # [1.0, 2.0]
          episode.get_rewards(slice(-2, None))  # [2.0, 3.0]
          # Using `fill=...` (requesting slices beyond the boundaries).
          episode.get_rewards(slice(-5, -2), fill=0.0)  # [0.0, 0.0, 1.0]
          episode.get_rewards(slice(1, 5), fill=0.0)  # [2.0, 3.0, 0.0, 0.0]

      Returns:
          The collected rewards.
          As a 0-axis batch, if there are several `indices` or a list of exactly
          one index provided OR `indices` is a slice object.
          As single item (B=0 -> no additional 0-axis) if `indices` is a single
          int.
      """
      # The rewards buffer does all the work; simply forward the options.
      rewards_buffer = self.rewards
      return rewards_buffer.get(
          indices=indices, neg_index_as_lookback=neg_index_as_lookback, fill=fill
      )
  def get_extra_model_outputs(
      self,
      key: str,
      indices: Optional[Union[int, List[int], slice]] = None,
      *,
      neg_index_as_lookback: bool = False,
      fill: Optional[Any] = None,
  ) -> Any:
      """Returns extra model outputs (under given key) from this episode.

      Args:
          key: The `key` within `self.extra_model_outputs` to extract data for.
          indices: A single int is interpreted as an index, from which to return
              an individual extra model output stored under `key` at index.
              A list of ints is interpreted as a list of indices from which to
              gather individual extra model outputs in a batch of size
              len(indices).
              A slice object is interpreted as a range of extra model outputs to
              be returned. Thereby, negative indices by default are interpreted
              as "before the end" unless the `neg_index_as_lookback=True` option
              is used, in which case negative indices are interpreted as
              "before ts=0", meaning going back into the lookback buffer.
              If None, will return all extra model outputs (from ts=0 to the
              end).
          neg_index_as_lookback: If True, negative values in `indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with
              extra_model_outputs['a'] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is
              the lookback buffer range (ts=0 item is 7), will respond to
              `get_extra_model_outputs("a", -1, neg_index_as_lookback=True)`
              with `6` and to `get_extra_model_outputs("a", slice(-2, 1),
              neg_index_as_lookback=True)` with `[5, 6, 7]`.
          fill: An optional value to use for filling up the returned results at
              the boundaries. This filling only happens if the requested index
              range's start/stop boundaries exceed the episode's boundaries
              (including the lookback buffer on the left side). This comes in
              very handy, if users don't want to worry about reaching such
              boundaries and want to zero-pad.
              For example, an episode with
              extra_model_outputs["b"] = [10, 11, 12, 13, 14] and lookback
              buffer size of 2 (meaning `10` and `11` are part of the lookback
              buffer) will respond to
              `get_extra_model_outputs("b", slice(-7, -2), fill=0.0)` with
              `[0.0, 0.0, 10, 11, 12]`.
              TODO (sven): This would require a space being provided. Maybe we
              can automatically infer the space from existing data?

      Examples:

      .. testcode::

          from ray.rllib.env.single_agent_episode import SingleAgentEpisode

          episode = SingleAgentEpisode(
              extra_model_outputs={"mo": [1, 2, 3]},
              len_lookback_buffer=0,  # no lookback; all data is actually "in" episode
              # The following is needed, but not relevant for this demo.
              observations=[0, 1, 2, 3], actions=[1, 2, 3], rewards=[1, 2, 3],
          )

          # Plain usage (`indices` arg only).
          episode.get_extra_model_outputs("mo", -1)  # 3
          episode.get_extra_model_outputs("mo", 1)  # 2
          episode.get_extra_model_outputs("mo", [0, 2])  # [1, 3]
          episode.get_extra_model_outputs("mo", [-1, 0])  # [3, 1]
          episode.get_extra_model_outputs("mo", slice(None, 2))  # [1, 2]
          episode.get_extra_model_outputs("mo", slice(-2, None))  # [2, 3]
          # Using `fill=...` (requesting slices beyond the boundaries).
          # TODO (sven): This would require a space being provided. Maybe we can
          #  automatically infer the space from existing data?
          # episode.get_extra_model_outputs("mo", slice(-5, -2), fill=0)  # [0, 0, 1]
          # episode.get_extra_model_outputs("mo", slice(2, 5), fill=-1)  # [3, -1, -1]

      Returns:
          The collected extra_model_outputs[`key`].
          As a 0-axis batch, if there are several `indices` or a list of exactly
          one index provided OR `indices` is a slice object.
          As single item (B=0 -> no additional 0-axis) if `indices` is a single
          int.

      Raises:
          KeyError: If `key` is not found in `self.extra_model_outputs`.
          AssertionError: If the value stored under `key` is not an
              `InfiniteLookbackBuffer` (e.g. because it was written directly
              into `self.extra_model_outputs`).
      """
      value = self.extra_model_outputs[key]
      # The expected case is: `value` is a `InfiniteLookbackBuffer`.
      if isinstance(value, InfiniteLookbackBuffer):
          return value.get(
              indices=indices,
              neg_index_as_lookback=neg_index_as_lookback,
              fill=fill,
          )
      # TODO (sven): This does not seem to be solid yet. Users should NOT be able
      #  to just write directly into our buffers. Instead, use:
      #  `self.set_extra_model_outputs(key, new_data, at_indices=...)` and if key
      #  is not known, add a new buffer to the `extra_model_outputs` dict.
      # Raise explicitly (rather than via `assert False`, which is stripped when
      # running python with -O), so that non-buffer values are always rejected.
      raise AssertionError(
          f"`extra_model_outputs['{key}']` must be an `InfiniteLookbackBuffer`, "
          f"but is of type {type(value).__name__}!"
      )
  def set_observations(
      self,
      *,
      new_data,
      at_indices: Optional[Union[int, List[int], slice]] = None,
      neg_index_as_lookback: bool = False,
  ) -> None:
      """Replaces all (or a part) of this Episode's observations with `new_data`.

      An episode's observation data is managed by an `InfiniteLookbackBuffer`
      object and thus cannot be written to directly. Normally, individual,
      current observations enter the episode through `self.add_env_step` or,
      more directly (and manually), through
      `self.observations.append|extend()`.

      Certain postprocessing steps, however, have to rewrite the entirety (or a
      slice) of an episode's observations, which is what
      `self.set_observations()` is for.

      Args:
          new_data: The new observation data to overwrite existing data with.
              As long as this episode is not numpy'ized yet, this may be a list
              of individual observation(s). If this episode has already been
              numpy'ized, this should be a (possibly complex) struct matching
              the observation space, whose leafs have a batch size of exactly
              the size of the to-be-overwritten slice or segment (provided by
              `at_indices`).
          at_indices: Which position(s) to overwrite.
              A single int is one index to overwrite with `new_data` (which is
              expected to be a single observation).
              A list of ints is a list of indices, all of which to overwrite
              with `new_data` (which is expected to be of the same size as
              `len(at_indices)`).
              A slice object is a range of indices to overwrite with `new_data`
              (which is expected to be of the same size as the provided slice).
              Negative indices are by default interpreted as "before the end",
              unless `neg_index_as_lookback=True`, in which case they are
              interpreted as "before ts=0" (going back into the lookback
              buffer).
          neg_index_as_lookback: If True, negative values in `at_indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with
              observations = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback
              buffer range (ts=0 item is 7), will handle a call to
              `set_observations(new_data=individual_observation, at_indices=-1,
              neg_index_as_lookback=True)` by overwriting the value of 6 in our
              observations buffer with the provided "individual_observation".

      Raises:
          IndexError: If the provided `at_indices` do not match the size of
              `new_data`.
      """
      # The observations buffer does all the work; simply forward the options.
      observations_buffer = self.observations
      observations_buffer.set(
          new_data=new_data,
          at_indices=at_indices,
          neg_index_as_lookback=neg_index_as_lookback,
      )
  def set_actions(
      self,
      *,
      new_data,
      at_indices: Optional[Union[int, List[int], slice]] = None,
      neg_index_as_lookback: bool = False,
  ) -> None:
      """Replaces all (or a part) of this Episode's actions with `new_data`.

      An episode's action data is managed by an `InfiniteLookbackBuffer` object
      and thus cannot be written to directly. Normally, individual, current
      actions enter the episode through `self.add_env_step` or, more directly
      (and manually), through `self.actions.append|extend()`.

      Certain postprocessing steps, however, have to rewrite the entirety (or a
      slice) of an episode's actions, which is what `self.set_actions()` is
      for.

      Args:
          new_data: The new action data to overwrite existing data with.
              As long as this episode is not numpy'ized yet, this may be a list
              of individual action(s). If this episode has already been
              numpy'ized, this should be a (possibly complex) struct matching
              the action space, whose leafs have a batch size of exactly the
              size of the to-be-overwritten slice or segment (provided by
              `at_indices`).
          at_indices: Which position(s) to overwrite.
              A single int is one index to overwrite with `new_data` (which is
              expected to be a single action).
              A list of ints is a list of indices, all of which to overwrite
              with `new_data` (which is expected to be of the same size as
              `len(at_indices)`).
              A slice object is a range of indices to overwrite with `new_data`
              (which is expected to be of the same size as the provided slice).
              Negative indices are by default interpreted as "before the end",
              unless `neg_index_as_lookback=True`, in which case they are
              interpreted as "before ts=0" (going back into the lookback
              buffer).
          neg_index_as_lookback: If True, negative values in `at_indices` are
              interpreted as "before ts=0", meaning going back into the lookback
              buffer. For example, an episode with
              actions = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback
              buffer range (ts=0 item is 7), will handle a call to
              `set_actions(new_data=individual_action, at_indices=-1,
              neg_index_as_lookback=True)` by overwriting the value of 6 in our
              actions buffer with the provided "individual_action".

      Raises:
          IndexError: If the provided `at_indices` do not match the size of
              `new_data`.
      """
      # The actions buffer does all the work; simply forward the options.
      overwrite_options = {
          "new_data": new_data,
          "at_indices": at_indices,
          "neg_index_as_lookback": neg_index_as_lookback,
      }
      self.actions.set(**overwrite_options)
  1146. def set_rewards(
  1147. self,
  1148. *,
  1149. new_data,
  1150. at_indices: Optional[Union[int, List[int], slice]] = None,
  1151. neg_index_as_lookback: bool = False,
  1152. ) -> None:
  1153. """Overwrites all or some of this Episode's rewards with the provided data.
  1154. Note that an episode's reward data cannot be written to directly as it is
  1155. managed by a `InfiniteLookbackBuffer` object. Normally, individual, current
  1156. rewards are added to the episode either by calling `self.add_env_step` or
  1157. more directly (and manually) via `self.rewards.append|extend()`.
  1158. However, for certain postprocessing steps, the entirety (or a slice) of an
  1159. episode's rewards might have to be rewritten, which is when
  1160. `self.set_rewards()` should be used.
  1161. Args:
  1162. new_data: The new reward data to overwrite existing data with.
  1163. This may be a list of individual reward(s) in case this episode
  1164. is still not numpy'ized yet. In case this episode has already been
  1165. numpy'ized, this should be a np.ndarray with a length exactly
  1166. the size of the to-be-overwritten slice or segment (provided by
  1167. `at_indices`).
  1168. at_indices: A single int is interpreted as one index, which to overwrite
  1169. with `new_data` (which is expected to be a single reward).
  1170. A list of ints is interpreted as a list of indices, all of which to
  1171. overwrite with `new_data` (which is expected to be of the same size
  1172. as `len(at_indices)`).
  1173. A slice object is interpreted as a range of indices to be overwritten
  1174. with `new_data` (which is expected to be of the same size as the
  1175. provided slice).
  1176. Thereby, negative indices by default are interpreted as "before the end"
  1177. unless the `neg_index_as_lookback=True` option is used, in which case
  1178. negative indices are interpreted as "before ts=0", meaning going back
  1179. into the lookback buffer.
  1180. neg_index_as_lookback: If True, negative values in `at_indices` are
  1181. interpreted as "before ts=0", meaning going back into the lookback
  1182. buffer. For example, an episode with
  1183. rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
  1184. lookback buffer range (ts=0 item is 7), will handle a call to
  1185. `set_rewards(individual_reward, -1,
  1186. neg_index_as_lookback=True)` by overwriting the value of 6 in our
  1187. rewards buffer with the provided "individual_reward".
  1188. Raises:
  1189. IndexError: If the provided `at_indices` do not match the size of
  1190. `new_data`.
  1191. """
  1192. self.rewards.set(
  1193. new_data=new_data,
  1194. at_indices=at_indices,
  1195. neg_index_as_lookback=neg_index_as_lookback,
  1196. )
  1197. def set_extra_model_outputs(
  1198. self,
  1199. *,
  1200. key,
  1201. new_data,
  1202. at_indices: Optional[Union[int, List[int], slice]] = None,
  1203. neg_index_as_lookback: bool = False,
  1204. ) -> None:
  1205. """Overwrites all or some of this Episode's extra model outputs with `new_data`.
  1206. Note that an episode's `extra_model_outputs` data cannot be written to directly
  1207. as it is managed by a `InfiniteLookbackBuffer` object. Normally, individual,
  1208. current `extra_model_output` values are added to the episode either by calling
  1209. `self.add_env_step` or more directly (and manually) via
  1210. `self.extra_model_outputs[key].append|extend()`. However, for certain
  1211. postprocessing steps, the entirety (or a slice) of an episode's
  1212. `extra_model_outputs` might have to be rewritten or a new key (a new type of
  1213. `extra_model_outputs`) must be inserted, which is when
  1214. `self.set_extra_model_outputs()` should be used.
  1215. Args:
  1216. key: The `key` within `self.extra_model_outputs` to override data on or
  1217. to insert as a new key into `self.extra_model_outputs`.
  1218. new_data: The new data to overwrite existing data with.
  1219. This may be a list of individual reward(s) in case this episode
  1220. is still not numpy'ized yet. In case this episode has already been
  1221. numpy'ized, this should be a np.ndarray with a length exactly
  1222. the size of the to-be-overwritten slice or segment (provided by
  1223. `at_indices`).
  1224. at_indices: A single int is interpreted as one index, which to overwrite
  1225. with `new_data` (which is expected to be a single reward).
  1226. A list of ints is interpreted as a list of indices, all of which to
  1227. overwrite with `new_data` (which is expected to be of the same size
  1228. as `len(at_indices)`).
  1229. A slice object is interpreted as a range of indices to be overwritten
  1230. with `new_data` (which is expected to be of the same size as the
  1231. provided slice).
  1232. Thereby, negative indices by default are interpreted as "before the end"
  1233. unless the `neg_index_as_lookback=True` option is used, in which case
  1234. negative indices are interpreted as "before ts=0", meaning going back
  1235. into the lookback buffer.
  1236. neg_index_as_lookback: If True, negative values in `at_indices` are
  1237. interpreted as "before ts=0", meaning going back into the lookback
  1238. buffer. For example, an episode with
  1239. rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
  1240. lookback buffer range (ts=0 item is 7), will handle a call to
  1241. `set_rewards(individual_reward, -1,
  1242. neg_index_as_lookback=True)` by overwriting the value of 6 in our
  1243. rewards buffer with the provided "individual_reward".
  1244. Raises:
  1245. IndexError: If the provided `at_indices` do not match the size of
  1246. `new_data`.
  1247. """
  1248. # Record already exists -> Set existing record's data to new values.
  1249. assert key in self.extra_model_outputs
  1250. self.extra_model_outputs[key].set(
  1251. new_data=new_data,
  1252. at_indices=at_indices,
  1253. neg_index_as_lookback=neg_index_as_lookback,
  1254. )
  1255. def slice(
  1256. self,
  1257. slice_: slice,
  1258. *,
  1259. len_lookback_buffer: Optional[int] = None,
  1260. ) -> "SingleAgentEpisode":
  1261. """Returns a slice of this episode with the given slice object.
  1262. For example, if `self` contains o0 (the reset observation), o1, o2, o3, and o4
  1263. and the actions a1, a2, a3, and a4 (len of `self` is 4), then a call to
  1264. `self.slice(slice(1, 3))` would return a new SingleAgentEpisode with
  1265. observations o1, o2, and o3, and actions a2 and a3. Note here that there is
  1266. always one observation more in an episode than there are actions (and rewards
  1267. and extra model outputs) due to the initial observation received after an env
  1268. reset.
  1269. .. testcode::
  1270. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  1271. from ray.rllib.utils.test_utils import check
  1272. # Generate a simple multi-agent episode.
  1273. observations = [0, 1, 2, 3, 4, 5]
  1274. actions = [1, 2, 3, 4, 5]
  1275. rewards = [0.1, 0.2, 0.3, 0.4, 0.5]
  1276. episode = SingleAgentEpisode(
  1277. observations=observations,
  1278. actions=actions,
  1279. rewards=rewards,
  1280. len_lookback_buffer=0, # all given data is part of the episode
  1281. )
  1282. slice_1 = episode[:1]
  1283. check(slice_1.observations, [0, 1])
  1284. check(slice_1.actions, [1])
  1285. check(slice_1.rewards, [0.1])
  1286. slice_2 = episode[-2:]
  1287. check(slice_2.observations, [3, 4, 5])
  1288. check(slice_2.actions, [4, 5])
  1289. check(slice_2.rewards, [0.4, 0.5])
  1290. Args:
  1291. slice_: The slice object to use for slicing. This should exclude the
  1292. lookback buffer, which will be prepended automatically to the returned
  1293. slice.
  1294. len_lookback_buffer: If not None, forces the returned slice to try to have
  1295. this number of timesteps in its lookback buffer (if available). If None
  1296. (default), tries to make the returned slice's lookback as large as the
  1297. current lookback buffer of this episode (`self`).
  1298. Returns:
  1299. The new SingleAgentEpisode representing the requested slice.
  1300. """
  1301. # Translate `slice_` into one that only contains 0-or-positive ints and will
  1302. # NOT contain any None.
  1303. start = slice_.start
  1304. stop = slice_.stop
  1305. # Start is None -> 0.
  1306. if start is None:
  1307. start = 0
  1308. # Start is negative -> Interpret index as counting "from end".
  1309. elif start < 0:
  1310. start = len(self) + start
  1311. # Stop is None -> Set stop to our len (one ts past last valid index).
  1312. if stop is None:
  1313. stop = len(self)
  1314. # Stop is negative -> Interpret index as counting "from end".
  1315. elif stop < 0:
  1316. stop = len(self) + stop
  1317. step = slice_.step if slice_.step is not None else 1
  1318. # Figure out, whether slicing stops at the very end of this episode to know
  1319. # whether `self.is_terminated/is_truncated` should be kept as-is.
  1320. keep_done = stop == len(self)
  1321. # Provide correct timestep- and pre-buffer information.
  1322. t_started = self.t_started + start
  1323. _lb = (
  1324. len_lookback_buffer
  1325. if len_lookback_buffer is not None
  1326. else self.observations.lookback
  1327. )
  1328. if (
  1329. start >= 0
  1330. and start - _lb < 0
  1331. and self.observations.lookback < (_lb - start)
  1332. ):
  1333. _lb = self.observations.lookback + start
  1334. observations = InfiniteLookbackBuffer(
  1335. data=self.get_observations(
  1336. slice(start - _lb, stop + 1, step),
  1337. neg_index_as_lookback=True,
  1338. ),
  1339. lookback=_lb,
  1340. space=self.observation_space,
  1341. )
  1342. _lb = (
  1343. len_lookback_buffer
  1344. if len_lookback_buffer is not None
  1345. else self.infos.lookback
  1346. )
  1347. if start >= 0 and start - _lb < 0 and self.infos.lookback < (_lb - start):
  1348. _lb = self.infos.lookback + start
  1349. infos = InfiniteLookbackBuffer(
  1350. data=self.get_infos(
  1351. slice(start - _lb, stop + 1, step),
  1352. neg_index_as_lookback=True,
  1353. ),
  1354. lookback=_lb,
  1355. )
  1356. _lb = (
  1357. len_lookback_buffer
  1358. if len_lookback_buffer is not None
  1359. else self.actions.lookback
  1360. )
  1361. if start >= 0 and start - _lb < 0 and self.actions.lookback < (_lb - start):
  1362. _lb = self.actions.lookback + start
  1363. actions = InfiniteLookbackBuffer(
  1364. data=self.get_actions(
  1365. slice(start - _lb, stop, step),
  1366. neg_index_as_lookback=True,
  1367. ),
  1368. lookback=_lb,
  1369. space=self.action_space,
  1370. )
  1371. _lb = (
  1372. len_lookback_buffer
  1373. if len_lookback_buffer is not None
  1374. else self.rewards.lookback
  1375. )
  1376. if start >= 0 and start - _lb < 0 and self.rewards.lookback < (_lb - start):
  1377. _lb = self.rewards.lookback + start
  1378. rewards = InfiniteLookbackBuffer(
  1379. data=self.get_rewards(
  1380. slice(start - _lb, stop, step),
  1381. neg_index_as_lookback=True,
  1382. ),
  1383. lookback=_lb,
  1384. )
  1385. extra_model_outputs = {}
  1386. for k, v in self.extra_model_outputs.items():
  1387. _lb = len_lookback_buffer if len_lookback_buffer is not None else v.lookback
  1388. if start >= 0 and start - _lb < 0 and v.lookback < (_lb - start):
  1389. _lb = v.lookback + start
  1390. extra_model_outputs[k] = InfiniteLookbackBuffer(
  1391. data=self.get_extra_model_outputs(
  1392. key=k,
  1393. indices=slice(start - _lb, stop, step),
  1394. neg_index_as_lookback=True,
  1395. ),
  1396. lookback=_lb,
  1397. )
  1398. return SingleAgentEpisode(
  1399. id_=self.id_,
  1400. # In the following, offset `start`s automatically by lookbacks.
  1401. observations=observations,
  1402. observation_space=self.observation_space,
  1403. infos=infos,
  1404. actions=actions,
  1405. action_space=self.action_space,
  1406. rewards=rewards,
  1407. extra_model_outputs=extra_model_outputs,
  1408. terminated=(self.is_terminated if keep_done else False),
  1409. truncated=(self.is_truncated if keep_done else False),
  1410. t_started=t_started,
  1411. )
  1412. def get_data_dict(self):
  1413. """Converts a SingleAgentEpisode into a data dict mapping str keys to data.
  1414. The keys used are:
  1415. Columns.EPS_ID, T, OBS, INFOS, ACTIONS, REWARDS, TERMINATEDS, TRUNCATEDS,
  1416. and those in `self.extra_model_outputs`.
  1417. Returns:
  1418. A data dict mapping str keys to data records.
  1419. """
  1420. t = list(range(self.t_started, self.t))
  1421. terminateds = [False] * (len(self) - 1) + [self.is_terminated]
  1422. truncateds = [False] * (len(self) - 1) + [self.is_truncated]
  1423. eps_id = [self.id_] * len(self)
  1424. if self.is_numpy:
  1425. t = np.array(t)
  1426. terminateds = np.array(terminateds)
  1427. truncateds = np.array(truncateds)
  1428. eps_id = np.array(eps_id)
  1429. return dict(
  1430. {
  1431. # Trivial 1D data (compiled above).
  1432. Columns.TERMINATEDS: terminateds,
  1433. Columns.TRUNCATEDS: truncateds,
  1434. Columns.T: t,
  1435. Columns.EPS_ID: eps_id,
  1436. # Retrieve obs, infos, actions, rewards using our get_... APIs,
  1437. # which return all relevant timesteps (excluding the lookback
  1438. # buffer!). Slice off last obs and infos to have the same number
  1439. # of them as we have actions and rewards.
  1440. Columns.OBS: self.get_observations(slice(None, -1)),
  1441. Columns.INFOS: self.get_infos(slice(None, -1)),
  1442. Columns.ACTIONS: self.get_actions(),
  1443. Columns.REWARDS: self.get_rewards(),
  1444. },
  1445. # All `extra_model_outs`: Same as obs: Use get_... API.
  1446. **{
  1447. k: self.get_extra_model_outputs(k)
  1448. for k in self.extra_model_outputs.keys()
  1449. },
  1450. )
  1451. def get_sample_batch(self) -> SampleBatch:
  1452. """Converts this `SingleAgentEpisode` into a `SampleBatch`.
  1453. Returns:
  1454. A SampleBatch containing all of this episode's data.
  1455. """
  1456. return SampleBatch(self.get_data_dict())
  1457. def get_return(self) -> float:
  1458. """Calculates an episode's return, excluding the lookback buffer's rewards.
  1459. The return is computed by a simple sum, neglecting the discount factor.
  1460. Note that if `self` is a continuation chunk (resulting from a call to
  1461. `self.cut()`), the previous chunk's rewards are NOT counted and thus NOT
  1462. part of the returned reward sum.
  1463. Returns:
  1464. The sum of rewards collected during this episode, excluding possible data
  1465. inside the lookback buffer and excluding possible data in a predecessor
  1466. chunk.
  1467. """
  1468. return sum(self.get_rewards())
  1469. def get_duration_s(self) -> float:
  1470. """Returns the duration of this Episode (chunk) in seconds."""
  1471. if self._last_step_time is None:
  1472. return 0.0
  1473. return self._last_step_time - self._start_time
  1474. def env_steps(self) -> int:
  1475. """Returns the number of environment steps.
  1476. Note, this episode instance could be a chunk of an actual episode.
  1477. Returns:
  1478. An integer that counts the number of environment steps this episode instance
  1479. has seen.
  1480. """
  1481. return len(self)
  1482. def agent_steps(self) -> int:
  1483. """Returns the number of agent steps.
  1484. Note, these are identical to the environment steps for a single-agent episode.
  1485. Returns:
  1486. An integer counting the number of agent steps executed during the time this
  1487. episode instance records.
  1488. """
  1489. return self.env_steps()
  1490. def get_state(self) -> Dict[str, Any]:
  1491. """Returns the pickable state of an episode.
  1492. The data in the episode is stored into a dictionary. Note that episodes
  1493. can also be generated from states (see `SingleAgentEpisode.from_state()`).
  1494. Returns:
  1495. A dict containing all the data from the episode.
  1496. """
  1497. infos = self.infos.get_state()
  1498. infos["data"] = np.array([info if info else None for info in infos["data"]])
  1499. return {
  1500. "id_": self.id_,
  1501. "agent_id": self.agent_id,
  1502. "module_id": self.module_id,
  1503. "multi_agent_episode_id": self.multi_agent_episode_id,
  1504. # Note, all data is stored in `InfiniteLookbackBuffer`s.
  1505. "observations": self.observations.get_state(),
  1506. "actions": self.actions.get_state(),
  1507. "rewards": self.rewards.get_state(),
  1508. "infos": self.infos.get_state(),
  1509. "extra_model_outputs": {
  1510. k: v.get_state() if v else v
  1511. for k, v in self.extra_model_outputs.items()
  1512. }
  1513. if len(self.extra_model_outputs) > 0
  1514. else None,
  1515. "is_terminated": self.is_terminated,
  1516. "is_truncated": self.is_truncated,
  1517. "t_started": self.t_started,
  1518. "t": self.t,
  1519. "_observation_space": gym_space_to_dict(self._observation_space)
  1520. if self._observation_space
  1521. else None,
  1522. "_action_space": gym_space_to_dict(self._action_space)
  1523. if self._action_space
  1524. else None,
  1525. "_start_time": self._start_time,
  1526. "_last_step_time": self._last_step_time,
  1527. "custom_data": self.custom_data,
  1528. }
  1529. @staticmethod
  1530. def from_state(state: Dict[str, Any]) -> "SingleAgentEpisode":
  1531. """Creates a new `SingleAgentEpisode` instance from a state dict.
  1532. Args:
  1533. state: The state dict, as returned by `self.get_state()`.
  1534. Returns:
  1535. A new `SingleAgentEpisode` instance with the data from the state dict.
  1536. """
  1537. # Create an empy episode instance.
  1538. episode = SingleAgentEpisode(id_=state["id_"])
  1539. # Load all the data from the state dict into the episode.
  1540. episode.agent_id = state["agent_id"]
  1541. episode.module_id = state["module_id"]
  1542. episode.multi_agent_episode_id = state["multi_agent_episode_id"]
  1543. # Convert data back to `InfiniteLookbackBuffer`s.
  1544. episode.observations = InfiniteLookbackBuffer.from_state(state["observations"])
  1545. episode.actions = InfiniteLookbackBuffer.from_state(state["actions"])
  1546. episode.rewards = InfiniteLookbackBuffer.from_state(state["rewards"])
  1547. episode.infos = InfiniteLookbackBuffer.from_state(state["infos"])
  1548. episode.extra_model_outputs = (
  1549. defaultdict(
  1550. functools.partial(
  1551. InfiniteLookbackBuffer, lookback=episode.observations.lookback
  1552. ),
  1553. {
  1554. k: InfiniteLookbackBuffer.from_state(v)
  1555. for k, v in state["extra_model_outputs"].items()
  1556. },
  1557. )
  1558. if state["extra_model_outputs"]
  1559. else defaultdict(
  1560. functools.partial(
  1561. InfiniteLookbackBuffer, lookback=episode.observations.lookback
  1562. ),
  1563. )
  1564. )
  1565. episode.is_terminated = state["is_terminated"]
  1566. episode.is_truncated = state["is_truncated"]
  1567. episode.t_started = state["t_started"]
  1568. episode.t = state["t"]
  1569. # We need to convert the spaces to dictionaries for serialization.
  1570. episode._observation_space = (
  1571. gym_space_from_dict(state["_observation_space"])
  1572. if state["_observation_space"]
  1573. else None
  1574. )
  1575. episode._action_space = (
  1576. gym_space_from_dict(state["_action_space"])
  1577. if state["_action_space"]
  1578. else None
  1579. )
  1580. episode._start_time = state["_start_time"]
  1581. episode._last_step_time = state["_last_step_time"]
  1582. episode._custom_data = state.get("custom_data", {})
  1583. # Validate the episode.
  1584. episode.validate()
  1585. return episode
  1586. @property
  1587. def observation_space(self):
  1588. return self._observation_space
  1589. @observation_space.setter
  1590. def observation_space(self, value):
  1591. self._observation_space = self.observations.space = value
  1592. @property
  1593. def action_space(self):
  1594. return self._action_space
  1595. @action_space.setter
  1596. def action_space(self, value):
  1597. self._action_space = self.actions.space = value
  1598. def __len__(self) -> int:
  1599. """Returning the length of an episode.
  1600. The length of an episode is defined by the length of its data, excluding
  1601. the lookback buffer data. The length is the number of timesteps an agent has
  1602. stepped through an environment thus far.
  1603. The length is 0 in case of an episode whose env has NOT been reset yet, but
  1604. also 0 right after the `env.reset()` data has been added via
  1605. `self.add_env_reset()`. Only after the first call to `env.step()` (and
  1606. `self.add_env_step()`, the length will be 1.
  1607. Returns:
  1608. An integer, defining the length of an episode.
  1609. """
  1610. return self.t - self.t_started
  1611. def __repr__(self):
  1612. return (
  1613. f"SAEps(len={len(self)} done={self.is_done} "
  1614. f"R={self.get_return()} id_={self.id_})"
  1615. )
  1616. def __getitem__(self, item: slice) -> "SingleAgentEpisode":
  1617. """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:]."""
  1618. if isinstance(item, slice):
  1619. return self.slice(slice_=item)
  1620. else:
  1621. raise NotImplementedError(
  1622. f"SingleAgentEpisode does not support getting item '{item}'! "
  1623. "Only slice objects allowed with the syntax: `episode[a:b]`."
  1624. )
  1625. @Deprecated(new="SingleAgentEpisode.custom_data[some-key] = ...", error=True)
  1626. def add_temporary_timestep_data(self):
  1627. pass
  1628. @Deprecated(new="SingleAgentEpisode.custom_data[some-key]", error=True)
  1629. def get_temporary_timestep_data(self):
  1630. pass