multi_agent_episode.py 131 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780
  1. import copy
  2. import time
  3. import uuid
  4. from collections import defaultdict
  5. from typing import (
  6. Any,
  7. Callable,
  8. Collection,
  9. DefaultDict,
  10. Dict,
  11. List,
  12. Optional,
  13. Set,
  14. Union,
  15. )
  16. import gymnasium as gym
  17. from ray._common.deprecation import Deprecated
  18. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  19. from ray.rllib.env.utils.infinite_lookback_buffer import InfiniteLookbackBuffer
  20. from ray.rllib.policy.sample_batch import MultiAgentBatch
  21. from ray.rllib.utils import force_list
  22. from ray.rllib.utils.error import MultiAgentEnvError
  23. from ray.rllib.utils.spaces.space_utils import batch
  24. from ray.rllib.utils.typing import AgentID, ModuleID, MultiAgentDict
  25. from ray.util.annotations import PublicAPI
# TODO (simon): Include cases in which the number of agents in an
# episode is shrinking or growing during the episode itself.
  28. @PublicAPI(stability="alpha")
  29. class MultiAgentEpisode:
  30. """Stores multi-agent episode data.
  31. The central attribute of the class is the timestep mapping
  32. `self.env_t_to_agent_t` that maps AgentIDs to their specific environment steps to
  33. the agent's own scale/timesteps.
  34. Each AgentID in the `MultiAgentEpisode` has its own `SingleAgentEpisode` object
  35. in which this agent's data is stored. Together with the env_t_to_agent_t mapping,
  36. we can extract information either on any individual agent's time scale or from
  37. the (global) multi-agent environment time scale.
  38. Extraction of data from a MultiAgentEpisode happens via the getter APIs, e.g.
  39. `get_observations()`, which work analogous to the ones implemented in the
  40. `SingleAgentEpisode` class.
  41. Note that recorded `terminateds`/`truncateds` come as simple
  42. `MultiAgentDict`s mapping AgentID to bools and thus have no assignment to a
  43. certain timestep (analogous to a SingleAgentEpisode's single `terminated/truncated`
  44. boolean flag). Instead we assign it to the last observation recorded.
  45. Theoretically, there could occur edge cases in some environments
  46. where an agent receives partial rewards and then terminates without
  47. a last observation. In these cases, we duplicate the last observation.
  48. Also, if no initial observation has been received yet for an agent, but
  49. some rewards for this same agent already occurred, we delete the agent's data
  50. up to here, b/c there is nothing to learn from these "premature" rewards.
  51. """
  # Explicit list of the instance attributes a `MultiAgentEpisode` carries.
  # All of them are assigned in `__init__`.
  __slots__ = (
      # Episode identity and agent -> module mapping.
      "id_",
      "agent_to_module_mapping_fn",
      "_agent_to_module_mapping",
      # Per-agent spaces (or an empty dict if none were provided).
      "observation_space",
      "action_space",
      # Global env timesteps and per-agent starting timesteps of this chunk.
      "env_t_started",
      "env_t",
      "agent_t_started",
      # Per-agent mapping from env timesteps to agent timesteps.
      "env_t_to_agent_t",
      # Caches for "hanging" actions/rewards/extra model outputs, i.e. data of
      # agents that acted but have not yet received their next observation.
      "_hanging_actions_end",
      "_hanging_extra_model_outputs_end",
      "_hanging_rewards_end",
      "_hanging_rewards_begin",
      # Episode-wide done flags (driven by the `__all__` key).
      "is_terminated",
      "is_truncated",
      # AgentID -> SingleAgentEpisode.
      "agent_episodes",
      # Timer stats, lookback buffer length and custom user data.
      "_last_step_time",
      "_len_lookback_buffers",
      "_start_time",
      "_custom_data",
  )
  # Marker stored in an agent's `env_t_to_agent_t` mapping for those env
  # timesteps at which the agent did NOT step.
  SKIP_ENV_TS_TAG = "S"
  def __init__(
      self,
      id_: Optional[str] = None,
      *,
      observations: Optional[List[MultiAgentDict]] = None,
      observation_space: Optional[gym.Space] = None,
      infos: Optional[List[MultiAgentDict]] = None,
      actions: Optional[List[MultiAgentDict]] = None,
      action_space: Optional[gym.Space] = None,
      rewards: Optional[List[MultiAgentDict]] = None,
      terminateds: Union[MultiAgentDict, bool] = False,
      truncateds: Union[MultiAgentDict, bool] = False,
      extra_model_outputs: Optional[List[MultiAgentDict]] = None,
      env_t_started: Optional[int] = None,
      agent_t_started: Optional[Dict[AgentID, int]] = None,
      len_lookback_buffer: Union[int, str] = "auto",
      agent_episode_ids: Optional[Dict[AgentID, str]] = None,
      agent_module_ids: Optional[Dict[AgentID, ModuleID]] = None,
      agent_to_module_mapping_fn: Optional[
          Callable[[AgentID, "MultiAgentEpisode"], ModuleID]
      ] = None,
  ):
      """Initializes a `MultiAgentEpisode`.

      Args:
          id_: Optional string identifying this episode. If None (or empty), a
              random hexadecimal ID is generated. A user-provided ID must be
              unique, because episode chunks are concatenated based on it.
          observations: Optional list of dicts mapping AgentIDs to
              observations. If given, it should be consistent with the other
              episode data (actions, rewards, etc.) with respect to list
              lengths and AgentIDs.
          observation_space: Optional `gym.spaces.Dict` mapping AgentIDs to the
              individual agents' observation spaces. If given and this episode
              is numpy'ized (via `self.to_numpy()`), newly appended or set data
              is checked against these spaces.
          infos: Optional list of dicts mapping AgentIDs to info dicts. If
              given, it should be consistent with the other episode data
              (observations, rewards, etc.) with respect to list lengths and
              AgentIDs.
          actions: Optional list of dicts mapping AgentIDs to actions. If
              given, it should be consistent with the other episode data
              (observations, rewards, etc.) with respect to list lengths and
              AgentIDs.
          action_space: Optional `gym.spaces.Dict` mapping AgentIDs to the
              individual agents' action spaces. If given and this episode is
              numpy'ized (via `self.to_numpy()`), newly appended or set data is
              checked against these spaces.
          rewards: Optional list of dicts mapping AgentIDs to rewards. If
              given, it should be consistent with the other episode data
              (observations, actions, etc.) with respect to list lengths and
              AgentIDs. Its length also determines `self.env_t` and - with
              `len_lookback_buffer="auto"` - the lookback buffer length.
          terminateds: Either a bool stating whether the episode is terminated
              or a MultiAgentDict mapping AgentIDs to per-agent terminated
              flags. The special `__all__` key in such a dict states whether
              the episode is terminated for all agents. Defaults to False
              (episode not terminated).
          truncateds: Either a bool stating whether the episode is truncated
              or a MultiAgentDict mapping AgentIDs to per-agent truncated
              flags. The special `__all__` key in such a dict states whether
              the episode is truncated for all agents. Defaults to False
              (episode not truncated).
          extra_model_outputs: Optional list of dicts mapping AgentIDs to their
              extra model outputs. Each such "output" is itself a dict mapping
              str keys to model output values, e.g. under `key=STATE_OUT` the
              internal state outputs of that agent.
          env_t_started: The env timestep (int) at which this episode chunk
              starts. Only larger than zero, if a chunk of an already ongoing
              episode is created, e.g. by slicing or by calling `cut()` on an
              ongoing episode.
          agent_t_started: Optional dict mapping AgentIDs to the (agent-local)
              timestep at which the respective SingleAgentEpisode chunk
              started.
          len_lookback_buffer: Size of the lookback buffers kept in front of
              this episode for each type of data (observations, actions,
              etc.). If larger than 0, the first `len_lookback_buffer` items of
              each type of data are NOT treated as part of this episode chunk,
              but as "historical" records that can be looked at to derive new
              data. For example, observation frame stacking over four frames
              on a chunk that continues a previously cut episode needs a
              lookback buffer of four, such that the first 3 items can look
              back into the old chunk's data.
              If "auto" (default), all data provided to the constructor is
              treated as lookback buffer.
          agent_episode_ids: Optional dict mapping AgentIDs to the IDs of their
              corresponding `SingleAgentEpisode`s. If None, each
              `SingleAgentEpisode` in `MultiAgentEpisode.agent_episodes`
              generates its own hexadecimal code. User-provided IDs must be
              unique, because the agents' `SingleAgentEpisode` instances are
              concatenated or recreated based on them.
          agent_module_ids: Optional dict mapping AgentIDs to their ModuleIDs.
              Such a mapping is valid for the entire episode (agents do not
              change their module within one episode). For agents already
              mapped through this dict, `agent_to_module_mapping_fn` is NOT
              used again.
          agent_to_module_mapping_fn: Callable taking an AgentID and a
              MultiAgentEpisode and returning a ModuleID. Used to map agents
              that have not been mapped yet (because they just entered the
              episode). The resulting ModuleID is only stored inside the
              agent's SingleAgentEpisode object. If None,
              `AlgorithmConfig.DEFAULT_AGENT_TO_MODULE_MAPPING_FN` is used.
      """
      # Use the given ID or generate a random hexadecimal one.
      self.id_: str = id_ if id_ else uuid.uuid4().hex

      # Fall back to the default mapping function, if none was provided.
      if agent_to_module_mapping_fn is None:
          from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

          agent_to_module_mapping_fn = (
              AlgorithmConfig.DEFAULT_AGENT_TO_MODULE_MAPPING_FN
          )
      self.agent_to_module_mapping_fn = agent_to_module_mapping_fn

      # A user may - e.g. via callbacks - force a mapping through the
      # `module_for()` API even before the agent has entered the episode (and
      # before its SingleAgentEpisode exists). All mappings already done are
      # therefore kept in this dict.
      self._agent_to_module_mapping: Dict[AgentID, ModuleID] = (
          agent_module_ids if agent_module_ids else {}
      )

      # No explicit lookback buffer length given -> Treat all provided data as
      # lookback buffer.
      if len_lookback_buffer == "auto":
          len_lookback_buffer = len(rewards) if rewards else 0
      self._len_lookback_buffers = len_lookback_buffer

      self.observation_space = observation_space if observation_space else {}
      self.action_space = action_space if action_space else {}

      # Replace falsy flags (False/None/empty dict) by empty dicts. Only a
      # literal `True` remains a bool from here on.
      if not terminateds:
          terminateds = {}
      if not truncateds:
          truncateds = {}

      # The env timestep at which this chunk started (excluding a possible
      # lookback buffer) and the global last env timestep of the episode.
      self.env_t_started = env_t_started if env_t_started else 0
      num_reward_items = 0 if rewards is None else len(rewards)
      self.env_t = (
          num_reward_items - self._len_lookback_buffers + self.env_t_started
      )
      self.agent_t_started = defaultdict(int, agent_t_started or {})

      # Correspondence between env steps and agent steps. Each AgentID maps to
      # an InfiniteLookbackBuffer, in which:
      # - The indices of the items are env timesteps, with index=0 for
      #   `env.reset()` and increasing by 1 with each `env.step()` call.
      # - The values at these indices are the agent timesteps happening at the
      #   respective env timesteps. The special value `self.SKIP_ENV_TS_TAG`
      #   means the agent did NOT step at that env timestep.
      # Agents that are part of the reset obs thus start their mapping with
      # [0 ...], all other agents with [self.SKIP_ENV_TS_TAG, ...].
      self.env_t_to_agent_t: DefaultDict[
          AgentID, InfiniteLookbackBuffer
      ] = defaultdict(InfiniteLookbackBuffer)

      # Caches for hanging actions/rewards/extra_model_outputs.
      # If an agent receives an observation (and then sends an action), but
      # does not immediately receive a next observation, its "hanging" action
      # (plus related rewards and extra model outputs) is kept in the caches
      # postfixed with `_end` until the next observation arrives.
      self._hanging_actions_end = {}
      self._hanging_extra_model_outputs_end = defaultdict(dict)
      self._hanging_rewards_end = defaultdict(float)
      # After a `cut()` or `slice()`, the rewards that were already "hanging"
      # in the preceding episode slice need to be stored as well.
      self._hanging_rewards_begin = defaultdict(float)

      # Episode-wide done flags. For an ongoing episode, `__all__` should be
      # False for both.
      if isinstance(terminateds, bool):
          self.is_terminated = terminateds
      else:
          self.is_terminated = terminateds.get("__all__", False)
      if isinstance(truncateds, bool):
          self.is_truncated = truncateds
      else:
          self.is_truncated = truncateds.get("__all__", False)

      # The individual agents' SingleAgentEpisode objects, filled from the
      # provided data.
      self.agent_episodes: Dict[AgentID, SingleAgentEpisode] = {}
      self._init_single_agent_episodes(
          agent_module_ids=agent_module_ids,
          agent_episode_ids=agent_episode_ids,
          observations=observations,
          infos=infos,
          actions=actions,
          rewards=rewards,
          terminateds=terminateds,
          truncateds=truncateds,
          extra_model_outputs=extra_model_outputs,
      )

      # Cache for custom data, e.g. custom metrics stored from within a
      # callback for the ongoing episode (such as render images).
      self._custom_data = {}

      # Timer stats on deltas between steps (not started yet).
      self._start_time = self._last_step_time = None

      # Validate ourselves.
      self.validate()
  def add_env_reset(
      self,
      *,
      observations: MultiAgentDict,
      infos: Optional[MultiAgentDict] = None,
  ) -> None:
      """Stores the initial (reset) observations and infos in this episode.

      For each agent found in `observations`, the env-step to agent-step
      mapping is started with 0, a `SingleAgentEpisode` is created (if the
      agent does not have one yet) and the agent's reset data is added to it.
      Afterwards, the episode is validated and its timer is started.

      Args:
          observations: A dict mapping AgentIDs to initial observations.
              Agents not contained in this dict are not touched by this call.
          infos: An optional dict mapping AgentIDs to initial info dicts. Not
              every agent needs to have an info dict. If not None, the
              AgentIDs in `infos` must be a subset of those in
              `observations`, i.e. an agent must not have an info dict
              without also having an observation.
      """
      # The episode must neither be done nor have stepped yet: Both
      # `self.env_t` and `self.env_t_started` have to be (and stay) 0.
      assert not self.is_done
      assert self.env_t == self.env_t_started and self.env_t_started == 0

      if not infos:
          infos = {}

      for agent_id, initial_obs in observations.items():
          # Agents that are part of the reset obs start their mapping with
          # 0 (env_t) -> 0 (agent_t).
          self.env_t_to_agent_t[agent_id].append(0)

          # Look up the agent's SingleAgentEpisode or create (and register) it.
          if agent_id in self.agent_episodes:
              sa_episode = self.agent_episodes[agent_id]
          else:
              sa_episode = SingleAgentEpisode(
                  agent_id=agent_id,
                  module_id=self.module_for(agent_id),
                  multi_agent_episode_id=self.id_,
                  observation_space=self.observation_space.get(agent_id),
                  action_space=self.action_space.get(agent_id),
              )
              self.agent_episodes[agent_id] = sa_episode

          # Hand the initial observation (and info dict, if any) to the agent's
          # own episode.
          sa_episode.add_env_reset(
              observation=initial_obs,
              infos=infos.get(agent_id),
          )

      # Validate our data.
      self.validate()

      # Start the timer for this episode.
      self._start_time = time.perf_counter()
  316. def add_env_step(
  317. self,
  318. observations: MultiAgentDict,
  319. actions: MultiAgentDict,
  320. rewards: MultiAgentDict,
  321. infos: Optional[MultiAgentDict] = None,
  322. *,
  323. terminateds: Optional[MultiAgentDict] = None,
  324. truncateds: Optional[MultiAgentDict] = None,
  325. extra_model_outputs: Optional[MultiAgentDict] = None,
  326. ) -> None:
  327. """Adds a timestep to the episode.
  328. Args:
  329. observations: A dictionary mapping agent IDs to their corresponding
  330. next observations. Note that some agents may not have stepped at this
  331. timestep.
  332. actions: Mandatory. A dictionary mapping agent IDs to their
  333. corresponding actions. Note that some agents may not have stepped at
  334. this timestep.
  335. rewards: Mandatory. A dictionary mapping agent IDs to their
  336. corresponding observations. Note that some agents may not have stepped
  337. at this timestep.
  338. infos: A dictionary mapping agent IDs to their
  339. corresponding info. Note that some agents may not have stepped at this
  340. timestep.
  341. terminateds: A dictionary mapping agent IDs to their `terminated` flags,
  342. indicating, whether the environment has been terminated for them.
  343. A special `__all__` key indicates that the episode is terminated for
  344. all agent IDs.
  345. terminateds: A dictionary mapping agent IDs to their `truncated` flags,
  346. indicating, whether the environment has been truncated for them.
  347. A special `__all__` key indicates that the episode is `truncated` for
  348. all agent IDs.
  349. extra_model_outputs: A dictionary mapping agent IDs to their
  350. corresponding specific model outputs (also in a dictionary; e.g.
  351. `vf_preds` for PPO).
  352. """
  353. # Cannot add data to an already done episode.
  354. if self.is_done:
  355. raise MultiAgentEnvError(
  356. "Cannot call `add_env_step` on a MultiAgentEpisode that is already "
  357. "done!"
  358. )
  359. infos = infos or {}
  360. terminateds = terminateds or {}
  361. truncateds = truncateds or {}
  362. extra_model_outputs = extra_model_outputs or {}
  363. # Increase (global) env step by one.
  364. self.env_t += 1
  365. # Find out, whether this episode is terminated/truncated (for all agents).
  366. # Case 1: all agents are terminated or all are truncated.
  367. self.is_terminated = terminateds.get("__all__", False)
  368. self.is_truncated = truncateds.get("__all__", False)
  369. # Find all agents that were done at prior timesteps and add the agents that are
  370. # done at the present timestep.
  371. agents_done = set(
  372. [aid for aid, sa_eps in self.agent_episodes.items() if sa_eps.is_done]
  373. + [aid for aid in terminateds if terminateds[aid]]
  374. + [aid for aid in truncateds if truncateds[aid]]
  375. )
  376. # Case 2: Some agents are truncated and the others are terminated -> Declare
  377. # this episode as terminated.
  378. if all(aid in set(agents_done) for aid in self.agent_ids):
  379. self.is_terminated = True
  380. # For all agents that are not stepping in this env step, but that are not done
  381. # yet -> Add a skip tag to their env- to agent-step mappings.
  382. stepped_agent_ids = set(observations.keys())
  383. for agent_id, env_t_to_agent_t in self.env_t_to_agent_t.items():
  384. if agent_id not in stepped_agent_ids:
  385. env_t_to_agent_t.append(self.SKIP_ENV_TS_TAG)
  386. # Loop through all agent IDs that we received data for in this step:
  387. # Those found in observations, actions, and rewards.
  388. agent_ids_with_data = (
  389. set(observations.keys())
  390. | set(actions.keys())
  391. | set(rewards.keys())
  392. | set(terminateds.keys())
  393. | set(truncateds.keys())
  394. | set(
  395. self.agent_episodes.keys()
  396. if terminateds.get("__all__") or truncateds.get("__all__")
  397. else set()
  398. )
  399. ) - {"__all__"}
  400. for agent_id in agent_ids_with_data:
  401. if agent_id not in self.agent_episodes:
  402. sa_episode = SingleAgentEpisode(
  403. agent_id=agent_id,
  404. module_id=self.module_for(agent_id),
  405. multi_agent_episode_id=self.id_,
  406. observation_space=self.observation_space.get(agent_id),
  407. action_space=self.action_space.get(agent_id),
  408. )
  409. else:
  410. sa_episode = self.agent_episodes[agent_id]
  411. # Collect value to be passed (at end of for-loop) into `add_env_step()`
  412. # call.
  413. _observation = observations.get(agent_id)
  414. _action = actions.get(agent_id)
  415. _reward = rewards.get(agent_id)
  416. _infos = infos.get(agent_id)
  417. _terminated = terminateds.get(agent_id, False) or self.is_terminated
  418. _truncated = truncateds.get(agent_id, False) or self.is_truncated
  419. _extra_model_outputs = extra_model_outputs.get(agent_id)
  420. # The value to place into the env- to agent-step map for this agent ID.
  421. # _agent_step = self.SKIP_ENV_TS_TAG
  422. # Agents, whose SingleAgentEpisode had already been done before this
  423. # step should NOT have received any data in this step.
  424. if sa_episode.is_done and any(
  425. v is not None
  426. for v in [_observation, _action, _reward, _infos, _extra_model_outputs]
  427. ):
  428. raise MultiAgentEnvError(
  429. f"Agent {agent_id} already had its `SingleAgentEpisode.is_done` "
  430. f"set to True, but still received data in a following step! "
  431. f"obs={_observation} act={_action} rew={_reward} info={_infos} "
  432. f"extra_model_outputs={_extra_model_outputs}."
  433. )
  434. _reward = _reward or 0.0
  435. # CASE 1: A complete agent step is available (in one env step).
  436. # -------------------------------------------------------------
  437. # We have an observation and an action for this agent ->
  438. # Add the agent step to the single agent episode.
  439. # ... action -> next obs + reward ...
  440. if _observation is not None and _action is not None:
  441. if agent_id not in rewards:
  442. raise MultiAgentEnvError(
  443. f"Agent {agent_id} acted (and received next obs), but did NOT "
  444. f"receive any reward from the env!"
  445. )
  446. # CASE 2: Step gets completed with a hanging action OR first observation.
  447. # ------------------------------------------------------------------------
  448. # We have an observation, but no action ->
  449. # a) Action (and extra model outputs) must be hanging already. Also use
  450. # collected hanging rewards and extra_model_outputs.
  451. # b) The observation is the first observation for this agent ID.
  452. elif _observation is not None and _action is None:
  453. _action = self._hanging_actions_end.pop(agent_id, None)
  454. # We have a hanging action (the agent had acted after the previous
  455. # observation, but the env had not responded - until now - with another
  456. # observation).
  457. # ...[hanging action] ... ... -> next obs + (reward)? ...
  458. if _action is not None:
  459. # Get the extra model output if available.
  460. _extra_model_outputs = self._hanging_extra_model_outputs_end.pop(
  461. agent_id, None
  462. )
  463. _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward
  464. # First observation for this agent, we have no hanging action.
  465. # ... [done]? ... -> [1st obs for agent ID]
  466. else:
  467. # The agent is already done -> The agent thus has never stepped once
  468. # and we do not have to create a SingleAgentEpisode for it.
  469. if _terminated or _truncated:
  470. self._del_hanging(agent_id)
  471. continue
  472. # This must be the agent's initial observation.
  473. else:
  474. # Prepend n skip tags to this agent's mapping + the initial [0].
  475. assert agent_id not in self.env_t_to_agent_t
  476. self.env_t_to_agent_t[agent_id].extend(
  477. [self.SKIP_ENV_TS_TAG] * self.env_t + [0]
  478. )
  479. self.env_t_to_agent_t[
  480. agent_id
  481. ].lookback = self._len_lookback_buffers
  482. # Make `add_env_reset` call and continue with next agent.
  483. sa_episode.add_env_reset(observation=_observation, infos=_infos)
  484. # Add possible reward to begin cache.
  485. self._hanging_rewards_begin[agent_id] += _reward
  486. # Now that the SAEps is valid, add it to our dict.
  487. self.agent_episodes[agent_id] = sa_episode
  488. continue
  489. # CASE 3: Step is started (by an action), but not completed (no next obs).
  490. # ------------------------------------------------------------------------
  491. # We have no observation, but we have a hanging action (used when we receive
  492. # the next obs for this agent in the future).
  493. elif agent_id not in observations and agent_id in actions:
  494. # Agent got truncated -> Error b/c we would need a last (truncation)
  495. # observation for this (otherwise, e.g. bootstrapping would not work).
  496. # [previous obs] [action] (hanging) ... ... [truncated]
  497. if _truncated:
  498. raise MultiAgentEnvError(
  499. f"Agent {agent_id} acted and then got truncated, but did NOT "
  500. "receive a last (truncation) observation, required for e.g. "
  501. "value function bootstrapping!"
  502. )
  503. # Agent got terminated.
  504. # [previous obs] [action] (hanging) ... ... [terminated]
  505. elif _terminated:
  506. # If the agent was terminated and no observation is provided,
  507. # duplicate the previous one (this is a technical "fix" to properly
  508. # complete the single agent episode; this last observation is never
  509. # used for learning anyway).
  510. _observation = sa_episode._last_added_observation
  511. _infos = sa_episode._last_added_infos
  512. # Agent is still alive.
  513. # [previous obs] [action] (hanging) ...
  514. else:
  515. # Hanging action, reward, and extra_model_outputs.
  516. assert agent_id not in self._hanging_actions_end
  517. self._hanging_actions_end[agent_id] = _action
  518. self._hanging_rewards_end[agent_id] = _reward
  519. self._hanging_extra_model_outputs_end[
  520. agent_id
  521. ] = _extra_model_outputs
  522. # CASE 4: Step has started in the past and is still ongoing (no observation,
  523. # no action).
  524. # --------------------------------------------------------------------------
  525. # Record reward and terminated/truncated flags.
  526. else:
  527. _action = self._hanging_actions_end.get(agent_id)
  528. # Agent is done.
  529. if _terminated or _truncated:
  530. # If the agent has NOT stepped, we treat it as not being
  531. # part of this episode.
  532. # ... ... [other agents doing stuff] ... ... [agent done]
  533. if _action is None:
  534. self._del_hanging(agent_id)
  535. continue
  536. # Agent got truncated -> Error b/c we would need a last (truncation)
  537. # observation for this (otherwise, e.g. bootstrapping would not
  538. # work).
  539. if _truncated:
  540. raise MultiAgentEnvError(
  541. f"Agent {agent_id} acted and then got truncated, but did "
  542. "NOT receive a last (truncation) observation, required "
  543. "for e.g. value function bootstrapping!"
  544. )
  545. # [obs] ... ... [hanging action] ... ... [done]
  546. # If the agent was terminated and no observation is provided,
  547. # duplicate the previous one (this is a technical "fix" to properly
  548. # complete the single agent episode; this last observation is never
  549. # used for learning anyway).
  550. _observation = sa_episode._last_added_observation
  551. _infos = sa_episode._last_added_infos
  552. # `_action` is already `get` above. We don't need to pop out from
  553. # the cache as it gets wiped out anyway below b/c the agent is
  554. # done.
  555. _extra_model_outputs = self._hanging_extra_model_outputs_end.pop(
  556. agent_id, None
  557. )
  558. _reward = self._hanging_rewards_end.pop(agent_id, 0.0) + _reward
  559. # The agent is still alive, just add current reward to cache.
  560. else:
  561. # But has never stepped in this episode -> add to begin cache.
  562. if agent_id not in self.agent_episodes:
  563. self._hanging_rewards_begin[agent_id] += _reward
  564. # Otherwise, add to end cache.
  565. else:
  566. self._hanging_rewards_end[agent_id] += _reward
  567. # If agent is stepping, add timestep to `SingleAgentEpisode`.
  568. if _observation is not None:
  569. sa_episode.add_env_step(
  570. observation=_observation,
  571. action=_action,
  572. reward=_reward,
  573. infos=_infos,
  574. terminated=_terminated,
  575. truncated=_truncated,
  576. extra_model_outputs=_extra_model_outputs,
  577. )
  578. # Update the env- to agent-step mapping.
  579. self.env_t_to_agent_t[agent_id].append(
  580. len(sa_episode) + sa_episode.observations.lookback
  581. )
  582. # Agent is also done. -> Erase all hanging values for this agent
  583. # (they should be empty at this point anyways).
  584. if _terminated or _truncated:
  585. self._del_hanging(agent_id)
  586. # Validate our data.
  587. self.validate()
  588. # Step time stats.
  589. self._last_step_time = time.perf_counter()
  590. if self._start_time is None:
  591. self._start_time = self._last_step_time
  592. def validate(self) -> None:
  593. """Validates the episode's data.
  594. This function ensures that the data stored to a `MultiAgentEpisode` is
  595. in order (e.g. that the correct number of observations, actions, rewards
  596. are there).
  597. """
  598. for eps in self.agent_episodes.values():
  599. eps.validate()
  600. # TODO (sven): Validate MultiAgentEpisode specifics, like the timestep mappings,
  601. # action/reward caches, etc..
  602. @property
  603. def custom_data(self):
  604. return self._custom_data
  605. @property
  606. def is_reset(self) -> bool:
  607. """Returns True if `self.add_env_reset()` has already been called."""
  608. return any(
  609. len(sa_episode.observations) > 0
  610. for sa_episode in self.agent_episodes.values()
  611. )
  612. @property
  613. def is_numpy(self) -> bool:
  614. """True, if the data in this episode is already stored as numpy arrays."""
  615. is_numpy = next(iter(self.agent_episodes.values())).is_numpy
  616. # Make sure that all single agent's episodes' `is_numpy` flags are the same.
  617. if not all(eps.is_numpy is is_numpy for eps in self.agent_episodes.values()):
  618. raise RuntimeError(
  619. f"Only some SingleAgentEpisode objects in {self} are converted to "
  620. f"numpy, others are not!"
  621. )
  622. return is_numpy
  623. @property
  624. def is_done(self):
  625. """Whether the episode is actually done (terminated or truncated).
  626. A done episode cannot be continued via `self.add_env_step()` or being
  627. concatenated on its right-side with another episode chunk or being
  628. succeeded via `self.cut()`.
  629. Note that in a multi-agent environment this does not necessarily
  630. correspond to single agents having terminated or being truncated.
  631. `self.is_terminated` should be `True`, if all agents are terminated and
  632. `self.is_truncated` should be `True`, if all agents are truncated. If
  633. only one or more (but not all!) agents are `terminated/truncated the
  634. `MultiAgentEpisode.is_terminated/is_truncated` should be `False`. This
  635. information about single agent's terminated/truncated states can always
  636. be retrieved from the `SingleAgentEpisode`s inside the 'MultiAgentEpisode`
  637. one.
  638. If all agents are either terminated or truncated, but in a mixed fashion,
  639. i.e. some are terminated and others are truncated: This is currently
  640. undefined and could potentially be a problem (if a user really implemented
  641. such a multi-agent env that behaves this way).
  642. Returns:
  643. Boolean defining if an episode has either terminated or truncated.
  644. """
  645. return self.is_terminated or self.is_truncated
  646. def to_numpy(self) -> "MultiAgentEpisode":
  647. """Converts this Episode's list attributes to numpy arrays.
  648. This means in particular that this episodes' lists (per single agent) of
  649. (possibly complex) data (e.g. an agent having a dict obs space) will be
  650. converted to (possibly complex) structs, whose leafs are now numpy arrays.
  651. Each of these leaf numpy arrays will have the same length (batch dimension)
  652. as the length of the original lists.
  653. Note that Columns.INFOS are NEVER numpy'ized and will remain a list
  654. (normally, a list of the original, env-returned dicts). This is due to the
  655. heterogeneous nature of INFOS returned by envs, which would make it unwieldy to
  656. convert this information to numpy arrays.
  657. After calling this method, no further data may be added to this episode via
  658. the `self.add_env_step()` method.
  659. Examples:
  660. .. testcode::
  661. import numpy as np
  662. from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
  663. from ray.rllib.env.tests.test_multi_agent_episode import (
  664. TestMultiAgentEpisode
  665. )
  666. # Create some multi-agent episode data.
  667. (
  668. observations,
  669. actions,
  670. rewards,
  671. terminateds,
  672. truncateds,
  673. infos,
  674. ) = TestMultiAgentEpisode._mock_multi_agent_records()
  675. # Define the agent ids.
  676. agent_ids = ["agent_1", "agent_2", "agent_3", "agent_4", "agent_5"]
  677. episode = MultiAgentEpisode(
  678. observations=observations,
  679. infos=infos,
  680. actions=actions,
  681. rewards=rewards,
  682. # Note: terminated/truncated have nothing to do with an episode
  683. # being converted `to_numpy` or not (via the `self.to_numpy()` method)!
  684. terminateds=terminateds,
  685. truncateds=truncateds,
  686. len_lookback_buffer=0, # no lookback; all data is actually "in" episode
  687. )
  688. # Episode has not been numpy'ized yet.
  689. assert not episode.is_numpy
  690. # We are still operating on lists.
  691. assert (
  692. episode.get_observations(
  693. indices=[1],
  694. agent_ids="agent_1",
  695. ) == {"agent_1": [1]}
  696. )
  697. # Numpy'ized the episode.
  698. episode.to_numpy()
  699. assert episode.is_numpy
  700. # Everything is now numpy arrays (with 0-axis of size
  701. # B=[len of requested slice]).
  702. assert (
  703. isinstance(episode.get_observations(
  704. indices=[1],
  705. agent_ids="agent_1",
  706. )["agent_1"], np.ndarray)
  707. )
  708. Returns:
  709. This `MultiAgentEpisode` object with the converted numpy data.
  710. """
  711. for agent_id, agent_eps in self.agent_episodes.copy().items():
  712. agent_eps.to_numpy()
  713. return self
  714. def concat_episode(self, other: "MultiAgentEpisode") -> None:
  715. """Adds the given `other` MultiAgentEpisode to the right side of `self`.
  716. In order for this to work, both chunks (`self` and `other`) must fit
  717. together that are split through `cut`. For sequential multi-agent environments
  718. using slice might cause problems from hanging observation/actions.
  719. This is checked by the IDs (must be identical), the time step counters
  720. (`self.env_t` must be the same as `other.env_t_started`), as well as the
  721. observations/infos of the individual agents at the concatenation boundaries.
  722. Also, `self.is_done` must not be True, meaning `self.is_terminated` and
  723. `self.is_truncated` are both False.
  724. Args:
  725. other: The other `MultiAgentEpisode` to be concatenated to this one.
  726. Returns:
  727. A `MultiAgentEpisode` instance containing the concatenated data
  728. from both episodes (`self` and `other`).
  729. """
  730. # Make sure the IDs match.
  731. assert other.id_ == self.id_
  732. # NOTE (sven): This is what we agreed on. As the replay buffers must be
  733. # able to concatenate.
  734. assert not self.is_done
  735. # Make sure the timesteps match.
  736. assert self.env_t == other.env_t_started
  737. # Validate `other`.
  738. other.validate()
  739. # Concatenate the individual SingleAgentEpisodes from both chunks.
  740. all_agent_ids = set(self.agent_ids) | set(other.agent_ids)
  741. for agent_id in all_agent_ids:
  742. sa_episode = self.agent_episodes.get(agent_id)
  743. # If agent is only in the new episode chunk -> Store all the data of `other`
  744. # wrt agent in `self`.
  745. if sa_episode is None:
  746. self.agent_episodes[agent_id] = other.agent_episodes[agent_id]
  747. self.env_t_to_agent_t[agent_id] = other.env_t_to_agent_t[agent_id]
  748. self.agent_t_started[agent_id] = other.agent_t_started[agent_id]
  749. self._copy_hanging(agent_id, other)
  750. # If the agent was done in `self`, ignore and continue. There should not be
  751. # any data of that agent in `other`.
  752. elif sa_episode.is_done:
  753. continue
  754. # If the agent has data in both chunks, concatenate on the single-agent
  755. # level, thereby making sure the hanging values (begin and end) match.
  756. elif agent_id in other.agent_episodes:
  757. sa_episode.concat_episode(other.agent_episodes[agent_id])
  758. # Override `self`'s hanging (end) values with `other`'s hanging (end).
  759. if agent_id in other._hanging_actions_end:
  760. self._hanging_actions_end[agent_id] = copy.deepcopy(
  761. other._hanging_actions_end[agent_id]
  762. )
  763. self._hanging_rewards_end[agent_id] = other._hanging_rewards_end[
  764. agent_id
  765. ]
  766. self._hanging_extra_model_outputs_end[agent_id] = copy.deepcopy(
  767. other._hanging_extra_model_outputs_end[agent_id]
  768. )
  769. # Concatenate the env- to agent-timestep mappings.
  770. j = self.env_t
  771. for i, val in enumerate(other.env_t_to_agent_t[agent_id][1:]):
  772. if val == self.SKIP_ENV_TS_TAG:
  773. self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)
  774. else:
  775. self.env_t_to_agent_t[agent_id].append(i + 1 + j)
  776. # Otherwise, the agent is only in `self` and not done. All data is stored
  777. # already -> skip
  778. # else: pass
  779. # Update all timestep counters.
  780. self.env_t = other.env_t
  781. # Check, if the episode is terminated or truncated.
  782. if other.is_terminated:
  783. self.is_terminated = True
  784. elif other.is_truncated:
  785. self.is_truncated = True
  786. # Merge with `other`'s custom_data, but give `other` priority b/c we assume
  787. # that as a follow-up chunk of `self` other has a more complete version of
  788. # `custom_data`.
  789. self.custom_data.update(other.custom_data)
  790. # Validate.
  791. self.validate()
  792. def cut(self, len_lookback_buffer: int = 0) -> "MultiAgentEpisode":
  793. """Returns a successor episode chunk (of len=0) continuing from this Episode.
  794. The successor will have the same ID as `self`.
  795. If no lookback buffer is requested (len_lookback_buffer=0), the successor's
  796. observations will be the last observation(s) of `self` and its length will
  797. therefore be 0 (no further steps taken yet). If `len_lookback_buffer` > 0,
  798. the returned successor will have `len_lookback_buffer` observations (and
  799. actions, rewards, etc..) taken from the right side (end) of `self`. For example
  800. if `len_lookback_buffer=2`, the returned successor's lookback buffer actions
  801. will be identical to the results of `self.get_actions([-2, -1])`.
  802. This method is useful if you would like to discontinue building an episode
  803. chunk (b/c you have to return it from somewhere), but would like to have a new
  804. episode instance to continue building the actual gym.Env episode at a later
  805. time. Vie the `len_lookback_buffer` argument, the continuing chunk (successor)
  806. will still be able to "look back" into this predecessor episode's data (at
  807. least to some extend, depending on the value of `len_lookback_buffer`).
  808. Args:
  809. len_lookback_buffer: The number of environment timesteps to take along into
  810. the new chunk as "lookback buffer". A lookback buffer is additional data
  811. on the left side of the actual episode data for visibility purposes
  812. (but without actually being part of the new chunk). For example, if
  813. `self` ends in actions: agent_1=5,6,7 and agent_2=6,7, and we call
  814. `self.cut(len_lookback_buffer=2)`, the returned chunk will have
  815. actions 6 and 7 for both agents already in it, but still
  816. `t_started`==t==8 (not 7!) and a length of 0. If there is not enough
  817. data in `self` yet to fulfil the `len_lookback_buffer` request, the
  818. value of `len_lookback_buffer` is automatically adjusted (lowered).
  819. Returns:
  820. The successor Episode chunk of this one with the same ID and state and the
  821. only observation being the last observation in self.
  822. """
  823. assert len_lookback_buffer >= 0
  824. if self.is_done:
  825. raise RuntimeError(
  826. "Can't call `MultiAgentEpisode.cut()` when the episode is already done!"
  827. )
  828. # If there is hanging data (e.g. actions) in the agents' caches, we might have
  829. # to re-adjust the lookback len further into the past to make sure that these
  830. # agents have at least one observation to look back to. Otherwise, the timestep
  831. # that got cut into will be "lost" for learning from it.
  832. orig_len_lb = len_lookback_buffer
  833. for agent_id, agent_actions in self._hanging_actions_end.items():
  834. assert self.env_t_to_agent_t[agent_id].get(-1) == self.SKIP_ENV_TS_TAG
  835. for i in range(orig_len_lb, len(self.env_t_to_agent_t[agent_id].data) + 1):
  836. if self.env_t_to_agent_t[agent_id].get(-i) != self.SKIP_ENV_TS_TAG:
  837. len_lookback_buffer = max(len_lookback_buffer, i - 1)
  838. break
  839. # Initialize this episode chunk with the most recent observations
  840. # and infos (even if lookback is zero). Similar to an initial `env.reset()`
  841. indices_obs_and_infos = slice(-len_lookback_buffer - 1, None)
  842. indices_rest = (
  843. slice(-len_lookback_buffer, None)
  844. if len_lookback_buffer > 0
  845. else slice(None, 0) # -> empty slice
  846. )
  847. observations = self.get_observations(
  848. indices=indices_obs_and_infos, return_list=True
  849. )
  850. infos = self.get_infos(indices=indices_obs_and_infos, return_list=True)
  851. actions = self.get_actions(indices=indices_rest, return_list=True)
  852. rewards = self.get_rewards(indices=indices_rest, return_list=True)
  853. extra_model_outputs = self.get_extra_model_outputs(
  854. key=None, # all keys
  855. indices=indices_rest,
  856. return_list=True,
  857. )
  858. successor = MultiAgentEpisode(
  859. # Same ID.
  860. id_=self.id_,
  861. observations=observations,
  862. observation_space=self.observation_space,
  863. infos=infos,
  864. actions=actions,
  865. action_space=self.action_space,
  866. rewards=rewards,
  867. # List of MADicts, mapping agent IDs to their respective extra model output
  868. # dicts.
  869. extra_model_outputs=extra_model_outputs,
  870. terminateds=self.get_terminateds(),
  871. truncateds=self.get_truncateds(),
  872. # Continue with `self`'s current timesteps.
  873. env_t_started=self.env_t,
  874. agent_t_started={
  875. aid: self.agent_episodes[aid].t
  876. for aid in self.agent_ids
  877. if not self.agent_episodes[aid].is_done
  878. },
  879. # Same AgentIDs and SingleAgentEpisode IDs.
  880. agent_episode_ids=self.agent_episode_ids,
  881. agent_module_ids={
  882. aid: self.agent_episodes[aid].module_id for aid in self.agent_ids
  883. },
  884. agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
  885. # All data we provided to the c'tor goes into the lookback buffer.
  886. len_lookback_buffer="auto",
  887. )
  888. # Copy over the hanging (end) values into the hanging (begin) caches of the
  889. # successor.
  890. successor._hanging_rewards_begin = self._hanging_rewards_end.copy()
  891. # Deepcopy all custom data in `self` to be continued in the cut episode.
  892. successor._custom_data = copy.deepcopy(self.custom_data)
  893. return successor
  894. @property
  895. def agent_ids(self) -> Set[AgentID]:
  896. """Returns the agent ids."""
  897. return set(self.agent_episodes.keys())
  898. @property
  899. def agent_episode_ids(self) -> MultiAgentDict:
  900. """Returns ids from each agent's `SingleAgentEpisode`."""
  901. return {
  902. agent_id: agent_eps.id_
  903. for agent_id, agent_eps in self.agent_episodes.items()
  904. }
  905. def module_for(self, agent_id: AgentID) -> Optional[ModuleID]:
  906. """Returns the ModuleID for a given AgentID.
  907. Forces the agent-to-module mapping to be performed (via
  908. `self.agent_to_module_mapping_fn`), if this has not been done yet.
  909. Note that all such mappings are stored in the `self._agent_to_module_mapping`
  910. property.
  911. Args:
  912. agent_id: The AgentID to get a mapped ModuleID for.
  913. Returns:
  914. The ModuleID mapped to from the given `agent_id`.
  915. """
  916. if agent_id not in self._agent_to_module_mapping:
  917. module_id = self._agent_to_module_mapping[
  918. agent_id
  919. ] = self.agent_to_module_mapping_fn(agent_id, self)
  920. return module_id
  921. else:
  922. return self._agent_to_module_mapping[agent_id]
  923. def get_observations(
  924. self,
  925. indices: Optional[Union[int, List[int], slice]] = None,
  926. agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
  927. *,
  928. env_steps: bool = True,
  929. # global_indices: bool = False,
  930. neg_index_as_lookback: bool = False,
  931. fill: Optional[Any] = None,
  932. one_hot_discrete: bool = False,
  933. return_list: bool = False,
  934. ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
  935. """Returns agents' observations or batched ranges thereof from this episode.
  936. Args:
  937. indices: A single int is interpreted as an index, from which to return the
  938. individual observation stored at this index.
  939. A list of ints is interpreted as a list of indices from which to gather
  940. individual observations in a batch of size len(indices).
  941. A slice object is interpreted as a range of observations to be returned.
  942. Thereby, negative indices by default are interpreted as "before the end"
  943. unless the `neg_index_as_lookback=True` option is used, in which case
  944. negative indices are interpreted as "before ts=0", meaning going back
  945. into the lookback buffer.
  946. If None, will return all observations (from ts=0 to the end).
  947. agent_ids: An optional collection of AgentIDs or a single AgentID to get
  948. observations for. If None, will return observations for all agents in
  949. this episode.
  950. env_steps: Whether `indices` should be interpreted as environment time steps
  951. (True) or per-agent timesteps (False).
  952. neg_index_as_lookback: If True, negative values in `indices` are
  953. interpreted as "before ts=0", meaning going back into the lookback
  954. buffer. For example, an episode with agent A's observations
  955. [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the lookback buffer range
  956. (ts=0 item is 7), will respond to `get_observations(-1, agent_ids=[A],
  957. neg_index_as_lookback=True)` with {A: `6`} and to
  958. `get_observations(slice(-2, 1), agent_ids=[A],
  959. neg_index_as_lookback=True)` with {A: `[5, 6, 7]`}.
  960. fill: An optional value to use for filling up the returned results at
  961. the boundaries. This filling only happens if the requested index range's
  962. start/stop boundaries exceed the episode's boundaries (including the
  963. lookback buffer on the left side). This comes in very handy, if users
  964. don't want to worry about reaching such boundaries and want to zero-pad.
  965. For example, an episode with agent A' observations [10, 11, 12, 13, 14]
  966. and lookback buffer size of 2 (meaning observations `10` and `11` are
  967. part of the lookback buffer) will respond to
  968. `get_observations(slice(-7, -2), agent_ids=[A], fill=0.0)` with
  969. `{A: [0.0, 0.0, 10, 11, 12]}`.
  970. one_hot_discrete: If True, will return one-hot vectors (instead of
  971. int-values) for those sub-components of a (possibly complex) observation
  972. space that are Discrete or MultiDiscrete. Note that if `fill=0` and the
  973. requested `indices` are out of the range of our data, the returned
  974. one-hot vectors will actually be zero-hot (all slots zero).
  975. return_list: Whether to return a list of multi-agent dicts (instead of
  976. a single multi-agent dict of lists/structs). False by default. This
  977. option can only be used when `env_steps` is True due to the fact the
  978. such a list can only be interpreted as one env step per list item
  979. (would not work with agent steps).
  980. Returns:
  981. A dictionary mapping agent IDs to observations (at the given
  982. `indices`). If `env_steps` is True, only agents that have stepped
  983. (were ready) at the given env step `indices` are returned (i.e. not all
  984. agent IDs are necessarily in the keys).
  985. If `return_list` is True, returns a list of MultiAgentDicts (mapping agent
  986. IDs to observations) instead.
  987. """
  988. return self._get(
  989. what="observations",
  990. indices=indices,
  991. agent_ids=agent_ids,
  992. env_steps=env_steps,
  993. neg_index_as_lookback=neg_index_as_lookback,
  994. fill=fill,
  995. one_hot_discrete=one_hot_discrete,
  996. return_list=return_list,
  997. )
  998. def get_infos(
  999. self,
  1000. indices: Optional[Union[int, List[int], slice]] = None,
  1001. agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
  1002. *,
  1003. env_steps: bool = True,
  1004. neg_index_as_lookback: bool = False,
  1005. fill: Optional[Any] = None,
  1006. return_list: bool = False,
  1007. ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
  1008. """Returns agents' info dicts or list (ranges) thereof from this episode.
  1009. Args:
  1010. indices: A single int is interpreted as an index, from which to return the
  1011. individual info dict stored at this index.
  1012. A list of ints is interpreted as a list of indices from which to gather
  1013. individual info dicts in a list of size len(indices).
  1014. A slice object is interpreted as a range of info dicts to be returned.
  1015. Thereby, negative indices by default are interpreted as "before the end"
  1016. unless the `neg_index_as_lookback=True` option is used, in which case
  1017. negative indices are interpreted as "before ts=0", meaning going back
  1018. into the lookback buffer.
  1019. If None, will return all infos (from ts=0 to the end).
  1020. agent_ids: An optional collection of AgentIDs or a single AgentID to get
  1021. info dicts for. If None, will return info dicts for all agents in
  1022. this episode.
  1023. env_steps: Whether `indices` should be interpreted as environment time steps
  1024. (True) or per-agent timesteps (False).
  1025. neg_index_as_lookback: If True, negative values in `indices` are
  1026. interpreted as "before ts=0", meaning going back into the lookback
  1027. buffer. For example, an episode with agent A's info dicts
  1028. [{"l":4}, {"l":5}, {"l":6}, {"a":7}, {"b":8}, {"c":9}], where the
  1029. first 3 items are the lookback buffer (ts=0 item is {"a": 7}), will
  1030. respond to `get_infos(-1, agent_ids=A, neg_index_as_lookback=True)`
  1031. with `{A: {"l":6}}` and to
  1032. `get_infos(slice(-2, 1), agent_ids=A, neg_index_as_lookback=True)`
  1033. with `{A: [{"l":5}, {"l":6}, {"a":7}]}`.
  1034. fill: An optional value to use for filling up the returned results at
  1035. the boundaries. This filling only happens if the requested index range's
  1036. start/stop boundaries exceed the episode's boundaries (including the
  1037. lookback buffer on the left side). This comes in very handy, if users
  1038. don't want to worry about reaching such boundaries and want to
  1039. auto-fill. For example, an episode with agent A's infos being
  1040. [{"l":10}, {"l":11}, {"a":12}, {"b":13}, {"c":14}] and lookback buffer
  1041. size of 2 (meaning infos {"l":10}, {"l":11} are part of the lookback
  1042. buffer) will respond to `get_infos(slice(-7, -2), agent_ids=A,
  1043. fill={"o": 0.0})` with
  1044. `{A: [{"o":0.0}, {"o":0.0}, {"l":10}, {"l":11}, {"a":12}]}`.
  1045. return_list: Whether to return a list of multi-agent dicts (instead of
  1046. a single multi-agent dict of lists/structs). False by default. This
  1047. option can only be used when `env_steps` is True due to the fact the
  1048. such a list can only be interpreted as one env step per list item
  1049. (would not work with agent steps).
  1050. Returns:
  1051. A dictionary mapping agent IDs to observations (at the given
  1052. `indices`). If `env_steps` is True, only agents that have stepped
  1053. (were ready) at the given env step `indices` are returned (i.e. not all
  1054. agent IDs are necessarily in the keys).
  1055. If `return_list` is True, returns a list of MultiAgentDicts (mapping agent
  1056. IDs to infos) instead.
  1057. """
  1058. return self._get(
  1059. what="infos",
  1060. indices=indices,
  1061. agent_ids=agent_ids,
  1062. env_steps=env_steps,
  1063. neg_index_as_lookback=neg_index_as_lookback,
  1064. fill=fill,
  1065. return_list=return_list,
  1066. )
  def get_actions(
      self,
      indices: Optional[Union[int, List[int], slice]] = None,
      agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
      *,
      env_steps: bool = True,
      neg_index_as_lookback: bool = False,
      fill: Optional[Any] = None,
      one_hot_discrete: bool = False,
      return_list: bool = False,
  ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
      """Returns the actions of some or all agents at the requested indices.

      Thin convenience wrapper: all arguments are forwarded unchanged to
      `self._get()` with `what="actions"`.

      Args:
          indices: Which action(s) to return. An int selects the single action
              stored at that index. A list of ints gathers the actions at those
              indices into a batch of size `len(indices)`. A slice selects a
              range of actions. Negative values count "from the end" by default;
              with `neg_index_as_lookback=True` they count "back from ts=0",
              i.e. into the lookback buffer. If None, all actions (from ts=0 to
              the end) are returned.
          agent_ids: A single AgentID or a collection of AgentIDs to restrict
              the result to. If None, actions of all agents in this episode are
              returned.
          env_steps: Whether `indices` refer to environment timesteps (True) or
              to per-agent timesteps (False).
          neg_index_as_lookback: If True, negative values in `indices` address
              the lookback buffer ("before ts=0"). For example, with agent A's
              actions being [4, 5, 6, 7, 8, 9], of which [4, 5, 6] is the
              lookback buffer (the ts=0 item is 7),
              `get_actions(-1, agent_ids=[A], neg_index_as_lookback=True)`
              yields {A: `6`} and
              `get_actions(slice(-2, 1), agent_ids=[A],
              neg_index_as_lookback=True)` yields {A: `[5, 6, 7]`}.
          fill: Optional value used to pad the result wherever the requested
              range exceeds the episode's boundaries (including the lookback
              buffer on the left side). Useful for zero-padding without having
              to care about these boundaries. For example, with agent A's
              actions being [10, 11, 12, 13, 14] and a lookback buffer of size 2
              (`10` and `11` are in the lookback buffer),
              `get_actions(slice(-7, -2), agent_ids=[A], fill=0.0)` yields
              `{A: [0.0, 0.0, 10, 11, 12]}`.
          one_hot_discrete: If True, Discrete or MultiDiscrete sub-components
              of the (possibly complex) action space are returned as one-hot
              vectors rather than int values. Note that with `fill=0` and
              out-of-range `indices`, the filled-in vectors are "zero-hot" (all
              slots zero).
          return_list: If True, return a list of multi-agent dicts (one per env
              step) instead of one multi-agent dict of lists/structs. Only
              meaningful with `env_steps=True`, because each list item can only
              be interpreted as one env step (not as an agent step).

      Returns:
          A dict mapping agent IDs to actions at the given `indices`. With
          `env_steps=True`, only agents that have stepped (were ready) at the
          given env step `indices` appear as keys. If `return_list` is True, a
          list of such MultiAgentDicts (mapping agent IDs to actions) instead.
      """
      # Collect the query once, then hand it to the generic getter.
      query = {
          "what": "actions",
          "indices": indices,
          "agent_ids": agent_ids,
          "env_steps": env_steps,
          "neg_index_as_lookback": neg_index_as_lookback,
          "fill": fill,
          "one_hot_discrete": one_hot_discrete,
          "return_list": return_list,
      }
      return self._get(**query)
  def get_rewards(
      self,
      indices: Optional[Union[int, List[int], slice]] = None,
      agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
      *,
      env_steps: bool = True,
      neg_index_as_lookback: bool = False,
      fill: Optional[float] = None,
      return_list: bool = False,
  ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
      """Returns the rewards of some or all agents at the requested indices.

      Thin convenience wrapper: all arguments are forwarded unchanged to
      `self._get()` with `what="rewards"`.

      Args:
          indices: Which reward(s) to return. An int selects the single reward
              stored at that index. A list of ints gathers the rewards at those
              indices into a batch of size `len(indices)`. A slice selects a
              range of rewards. Negative values count "from the end" by default;
              with `neg_index_as_lookback=True` they count "back from ts=0",
              i.e. into the lookback buffer. If None, all rewards (from ts=0 to
              the end) are returned.
          agent_ids: A single AgentID or a collection of AgentIDs to restrict
              the result to. If None, rewards of all agents in this episode are
              returned.
          env_steps: Whether `indices` refer to environment timesteps (True) or
              to per-agent timesteps (False).
          neg_index_as_lookback: If True, negative values in `indices` address
              the lookback buffer ("before ts=0"). For example, with agent A's
              rewards being [4, 5, 6, 7, 8, 9], of which [4, 5, 6] is the
              lookback buffer (the ts=0 item is 7),
              `get_rewards(-1, agent_ids=[A], neg_index_as_lookback=True)`
              yields {A: `6`} and
              `get_rewards(slice(-2, 1), agent_ids=[A],
              neg_index_as_lookback=True)` yields {A: `[5, 6, 7]`}.
          fill: Optional float used to pad the result wherever the requested
              range exceeds the episode's boundaries (including the lookback
              buffer on the left side). Useful for zero-padding without having
              to care about these boundaries. For example, with agent A's
              rewards being [10, 11, 12, 13, 14] and a lookback buffer of size 2
              (`10` and `11` are in the lookback buffer),
              `get_rewards(slice(-7, -2), agent_ids=[A], fill=0.0)` yields
              `{A: [0.0, 0.0, 10, 11, 12]}`.
          return_list: If True, return a list of multi-agent dicts (one per env
              step) instead of one multi-agent dict of lists/structs. Only
              meaningful with `env_steps=True`, because each list item can only
              be interpreted as one env step (not as an agent step).

      Returns:
          A dict mapping agent IDs to rewards at the given `indices`. With
          `env_steps=True`, only agents that have stepped (were ready) at the
          given env step `indices` appear as keys. If `return_list` is True, a
          list of such MultiAgentDicts (mapping agent IDs to rewards) instead.
      """
      # Collect the query once, then hand it to the generic getter.
      query = {
          "what": "rewards",
          "indices": indices,
          "agent_ids": agent_ids,
          "env_steps": env_steps,
          "neg_index_as_lookback": neg_index_as_lookback,
          "fill": fill,
          "return_list": return_list,
      }
      return self._get(**query)
  def get_extra_model_outputs(
      self,
      key: Optional[str] = None,
      indices: Optional[Union[int, List[int], slice]] = None,
      agent_ids: Optional[Union[Collection[AgentID], AgentID]] = None,
      *,
      env_steps: bool = True,
      neg_index_as_lookback: bool = False,
      fill: Optional[Any] = None,
      return_list: bool = False,
  ) -> Union[MultiAgentDict, List[MultiAgentDict]]:
      """Returns agents' extra model outputs at the requested indices.

      Thin convenience wrapper: all arguments are forwarded to `self._get()`
      with `what="extra_model_outputs"`; `key` is passed on as
      `extra_model_outputs_key`.

      Args:
          key: The key within each agent's extra_model_outputs dict to extract
              data for. If None, data of all extra model output keys is
              returned.
          indices: Which extra model output(s) to return. An int selects the
              single item stored at that index. A list of ints gathers the
              items at those indices into a batch of size `len(indices)`. A
              slice selects a range of items. Negative values count "from the
              end" by default; with `neg_index_as_lookback=True` they count
              "back from ts=0", i.e. into the lookback buffer. If None, all
              extra model outputs (from ts=0 to the end) are returned.
          agent_ids: A single AgentID or a collection of AgentIDs to restrict
              the result to. If None, extra model outputs of all agents in this
              episode are returned.
          env_steps: Whether `indices` refer to environment timesteps (True) or
              to per-agent timesteps (False).
          neg_index_as_lookback: If True, negative values in `indices` address
              the lookback buffer ("before ts=0"). For example, with agent A's
              stored values being [4, 5, 6, 7, 8, 9], of which [4, 5, 6] is the
              lookback buffer (the ts=0 item is 7), index `-1` addresses `6`
              and `slice(-2, 1)` addresses `[5, 6, 7]`.
          fill: Optional value used to pad the result wherever the requested
              range exceeds the episode's boundaries (including the lookback
              buffer on the left side). For example, with agent A's stored
              values being [10, 11, 12, 13, 14] and a lookback buffer of size 2
              (`10` and `11` are in the lookback buffer), `slice(-7, -2)` with
              `fill=0.0` addresses `[0.0, 0.0, 10, 11, 12]`.
          return_list: If True, return a list of multi-agent dicts (one per env
              step) instead of one multi-agent dict of lists/structs. Only
              meaningful with `env_steps=True`, because each list item can only
              be interpreted as one env step (not as an agent step).

      Returns:
          A dict mapping agent IDs to extra model outputs at the given
          `indices`. With `env_steps=True`, only agents that have stepped (were
          ready) at the given env step `indices` appear as keys. If
          `return_list` is True, a list of such MultiAgentDicts (mapping agent
          IDs to extra_model_outputs) instead.
      """
      # Collect the query once, then hand it to the generic getter.
      query = {
          "what": "extra_model_outputs",
          "extra_model_outputs_key": key,
          "indices": indices,
          "agent_ids": agent_ids,
          "env_steps": env_steps,
          "neg_index_as_lookback": neg_index_as_lookback,
          "fill": fill,
          "return_list": return_list,
      }
      return self._get(**query)
  def get_terminateds(self) -> MultiAgentDict:
      """Returns the current `terminated` flags of all agents plus "__all__".

      Returns:
          A dict mapping each agent ID in `self.agent_ids` to the
          `is_terminated` flag of that agent's single-agent episode. The
          additional key "__all__" holds this multi-agent episode's own
          `is_terminated` flag.
      """
      flags = {}
      for aid in self.agent_ids:
          flags[aid] = self.agent_episodes[aid].is_terminated
      # Episode-level flag goes last, under the special "__all__" key.
      flags["__all__"] = self.is_terminated
      return flags
  def get_truncateds(self) -> MultiAgentDict:
      """Returns the current `truncated` flags of all agents plus "__all__".

      Returns:
          A dict mapping each agent ID in `self.agent_ids` to the
          `is_truncated` flag of that agent's single-agent episode. The
          additional key "__all__" holds this multi-agent episode's own
          `is_truncated` flag.
      """
      truncateds = {
          agent_id: self.agent_episodes[agent_id].is_truncated
          for agent_id in self.agent_ids
      }
      # The episode-level entry must report truncation (not termination), so
      # that it is consistent with the per-agent entries in this dict.
      truncateds["__all__"] = self.is_truncated
      return truncateds
  def slice(
      self,
      slice_: slice,
      *,
      len_lookback_buffer: Optional[int] = None,
  ) -> "MultiAgentEpisode":
      """Returns a slice of this episode with the given slice object.

      Works analogous to
      :py:meth:`~ray.rllib.env.single_agent_episode.SingleAgentEpisode.slice`

      However, the important differences are:
      - `slice_` is provided in (global) env steps, not agent steps.
      - In case `slice_` ends - for a certain agent - in an env step, where that
        particular agent does not have an observation, the previous observation
        will be included, but the next action and sum of rewards until this
        point will be stored in the agent's hanging values caches for the
        returned MultiAgentEpisode slice.

      .. testcode::

          from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
          from ray.rllib.utils.test_utils import check

          # Generate a simple multi-agent episode.
          observations = [
              {"a0": 0, "a1": 0},  # 0
              {"a1": 1},  # 1
              {"a1": 2},  # 2
              {"a0": 3, "a1": 3},  # 3
              {"a0": 4},  # 4
          ]
          # Actions are the same as observations (except for last obs, which
          # doesn't have an action).
          actions = observations[:-1]
          # Make up a reward for each action.
          rewards = [
              {aid: r / 10 + 0.1 for aid, r in o.items()}
              for o in observations
          ]
          episode = MultiAgentEpisode(
              observations=observations,
              actions=actions,
              rewards=rewards,
              len_lookback_buffer=0,
          )

          # Slice the episode and check results.
          slice = episode[1:3]
          a0 = slice.agent_episodes["a0"]
          a1 = slice.agent_episodes["a1"]
          check((a0.observations, a1.observations), ([3], [1, 2, 3]))
          check((a0.actions, a1.actions), ([], [1, 2]))
          check((a0.rewards, a1.rewards), ([], [0.2, 0.3]))
          check((a0.is_done, a1.is_done), (False, False))

          # If a slice ends in a "gap" for an agent, expect actions and rewards
          # to be cached for this agent.
          slice = episode[:2]
          a0 = slice.agent_episodes["a0"]
          check(a0.observations, [0])
          check(a0.actions, [])
          check(a0.rewards, [])
          check(slice._hanging_actions_end["a0"], 0)
          check(slice._hanging_rewards_end["a0"], 0.1)

      Args:
          slice_: The slice object to use for slicing. This should exclude the
              lookback buffer, which will be prepended automatically to the
              returned slice.
          len_lookback_buffer: If not None, forces the returned slice to try to
              have this number of timesteps in its lookback buffer (if
              available). If None (default), tries to make the returned slice's
              lookback as large as the current lookback buffer of this episode
              (`self`).

      Returns:
          The new MultiAgentEpisode representing the requested slice.

      Raises:
          NotImplementedError: If `slice_.step` is neither 1 nor None.
          ValueError: If the lookback buffers in this episode do not all have
              the exact same size.
      """
      if slice_.step not in [1, None]:
          raise NotImplementedError(
              "Slicing MultiAgentEnv with a step other than 1 (you used"
              f" {slice_.step}) is not supported!"
          )

      # Translate `slice_` into one that only contains 0-or-positive ints and
      # will NOT contain any None.
      start = slice_.start
      stop = slice_.stop

      # Start is None -> 0.
      if start is None:
          start = 0
      # Start is negative -> Interpret index as counting "from end".
      elif start < 0:
          start = max(len(self) + start, 0)
      # Start is larger than len(self) -> Clip to len(self).
      elif start > len(self):
          start = len(self)

      # Stop is None -> Set stop to our len (one ts past last valid index).
      if stop is None:
          stop = len(self)
      # Stop is negative -> Interpret index as counting "from end".
      elif stop < 0:
          stop = max(len(self) + stop, 0)
      # Stop is larger than len(self) -> Clip to len(self).
      elif stop > len(self):
          stop = len(self)

      # All lookback buffers (observations, actions, rewards, and every extra
      # model output) of all agents must have the same size. The first agent's
      # observation lookback serves as the reference.
      # This is an explicit check (not `assert`), so that it is still enforced
      # when Python runs with assertions disabled (`-O`).
      ref_lookback = None
      for sa_episode in self.agent_episodes.values():
          if ref_lookback is None:
              ref_lookback = sa_episode.observations.lookback
          lookbacks_match = (
              sa_episode.observations.lookback == ref_lookback
              and sa_episode.actions.lookback == ref_lookback
              and sa_episode.rewards.lookback == ref_lookback
              and all(
                  ilb.lookback == ref_lookback
                  for ilb in sa_episode.extra_model_outputs.values()
              )
          )
          if not lookbacks_match:
              raise ValueError(
                  "Can only slice a MultiAgentEpisode if all lookback buffers in "
                  "this episode have the exact same size!"
              )

      # Determine terminateds/truncateds and when (in agent timesteps) the
      # single-agent episode slices start.
      terminateds = {}
      truncateds = {}
      agent_t_started = {}
      for aid, sa_episode in self.agent_episodes.items():
          mapping = self.env_t_to_agent_t[aid]
          # If the (agent) timestep directly at the slice stop boundary is equal
          # to the length of the single-agent episode of this agent -> Use the
          # single-agent episode's terminated/truncated flags.
          # If `stop` is already beyond this agent's single-agent episode, then
          # we don't have to keep track of this: The MultiAgentEpisode
          # initializer will automatically determine that this agent must be
          # done (b/c it has no action following its final observation).
          if (
              stop < len(mapping)
              and mapping[stop] != self.SKIP_ENV_TS_TAG
              and len(sa_episode) == mapping[stop]
          ):
              terminateds[aid] = sa_episode.is_terminated
              truncateds[aid] = sa_episode.is_truncated
          # Determine this agent's t_started: The agent timestep of the first
          # env step at or after `start`, at which this agent is not skipped.
          if start < len(mapping):
              for i in range(start, len(mapping)):
                  if mapping[i] != self.SKIP_ENV_TS_TAG:
                      agent_t_started[aid] = sa_episode.t_started + mapping[i]
                      break

      # The slice as a whole is terminated/truncated only if every agent has a
      # truthy entry (agents without an entry count as not terminated/truncated).
      terminateds["__all__"] = all(
          terminateds.get(aid) for aid in self.agent_episodes
      )
      truncateds["__all__"] = all(truncateds.get(aid) for aid in self.agent_episodes)

      # Determine all other slice contents.
      _lb = len_lookback_buffer if len_lookback_buffer is not None else ref_lookback
      # If the requested lookback reaches further back than what is available
      # before `start` (our own lookback plus `start` env steps), shrink it to
      # what is available.
      if start - _lb < 0 and ref_lookback < (_lb - start):
          _lb = ref_lookback + start
      # Observations span one more step (`stop + 1`) than actions/rewards/extra
      # model outputs.
      observations = self.get_observations(
          slice(start - _lb, stop + 1),
          neg_index_as_lookback=True,
          return_list=True,
      )
      actions = self.get_actions(
          slice(start - _lb, stop),
          neg_index_as_lookback=True,
          return_list=True,
      )
      rewards = self.get_rewards(
          slice(start - _lb, stop),
          neg_index_as_lookback=True,
          return_list=True,
      )
      extra_model_outputs = self.get_extra_model_outputs(
          indices=slice(start - _lb, stop),
          neg_index_as_lookback=True,
          return_list=True,
      )

      # Create the actual slice to be returned.
      ma_episode = MultiAgentEpisode(
          id_=self.id_,
          # In the following, offset `start`s automatically by lookbacks.
          observations=observations,
          observation_space=self.observation_space,
          actions=actions,
          action_space=self.action_space,
          rewards=rewards,
          extra_model_outputs=extra_model_outputs,
          terminateds=terminateds,
          truncateds=truncateds,
          len_lookback_buffer=_lb,
          env_t_started=self.env_t_started + start,
          agent_episode_ids={
              aid: eid.id_ for aid, eid in self.agent_episodes.items()
          },
          agent_t_started=agent_t_started,
          agent_module_ids=self._agent_to_module_mapping,
          agent_to_module_mapping_fn=self.agent_to_module_mapping_fn,
      )

      # Numpy'ize slice if `self` is also finalized.
      if self.is_numpy:
          ma_episode.to_numpy()

      return ma_episode
  def __len__(self):
      """Returns the length of this `MultiAgentEpisode` in env steps.

      The length is the difference between the episode's current env timestep
      (`self.env_t`) and the env timestep at which it started
      (`self.env_t_started`).

      Returns:
          The number of env steps between `self.env_t_started` and `self.env_t`.
      """
      elapsed_env_steps = self.env_t - self.env_t_started
      return elapsed_env_steps
  def __repr__(self):
      """Returns a short summary string of this episode.

      The summary contains the episode's length, its `is_done` flag, the
      return of each single-agent episode (keyed by agent ID), and the
      episode's ID.
      """
      returns_by_agent = {}
      for aid, sa_eps in self.agent_episodes.items():
          returns_by_agent[aid] = sa_eps.get_return()

      head = f"MAEps(len={len(self)} done={self.is_done} "
      tail = f"Rs={returns_by_agent} id_={self.id_})"
      return head + tail
  def print(self) -> None:
      """Prints this MultiAgentEpisode as a table of observations for the agents.

      The output consists of one header line listing the env timesteps
      (starting at `-lookback`) followed by one row per agent. In each row, an
      "x" marks an env timestep that is not tagged as skipped for that agent
      in `self.env_t_to_agent_t`; skipped timesteps are left blank.
      """
      # Find the maximum timestep across all agents to determine the grid width.
      max_ts = max(ts.len_incl_lookback() for ts in self.env_t_to_agent_t.values())
      # The first agent's mapping provides the lookback size for the header.
      lookback = next(iter(self.env_t_to_agent_t.values())).lookback
      # Measure the printed (str) form of each agent ID, so that agent IDs that
      # are not strings (and thus have no `len()`) can be handled as well.
      longest_agent = max(len(str(aid)) for aid in self.agent_ids)

      # Construct the header.
      header = (
          "ts"
          + (" " * longest_agent)
          + " ".join(str(i) for i in range(-lookback, max_ts - lookback))
          + "\n"
      )

      # Construct each agent's row.
      rows = []
      for agent, inf_buffer in self.env_t_to_agent_t.items():
          # Pad shorter agent IDs up to the width of the longest one.
          row = f"{agent} " + (" " * (longest_agent - len(str(agent))))
          for t in inf_buffer.data:
              # Skipped env step for this agent -> Leave the cell blank. Use
              # the class' skip tag rather than a hard-coded literal.
              # NOTE(review): blank cells are 1 char wide, "x" cells are 3
              # chars wide - confirm the intended column alignment.
              if t == self.SKIP_ENV_TS_TAG:
                  row += " "
              # Mark the step with an x.
              else:
                  row += " x "
          # Remove trailing space for alignment.
          rows.append(row.rstrip())

      # Join all components into a final string
      print(header + "\n".join(rows))
  def get_state(self) -> Dict[str, Any]:
      """Returns the state of this multi-agent episode.

      The returned state holds everything `MultiAgentEpisode.from_state()`
      reads in order to recreate the episode.

      Returns:
          A dictionary holding this episode's attributes under their attribute
          names. The single-agent episodes are stored under "agent_episodes" as
          a list of `(agent_id, single_agent_episode_state)` tuples.
      """
      state = {}

      # Attributes that are stored as-is under their own attribute name.
      # TODO (simon): Check, if we can store the `InfiniteLookbackBuffer`
      #  (concerns the "env_t_to_agent_t" entry).
      for attr in (
          "id_",
          "agent_to_module_mapping_fn",
          "_agent_to_module_mapping",
          "observation_space",
          "action_space",
          "env_t_started",
          "env_t",
          "agent_t_started",
          "env_t_to_agent_t",
          "_hanging_actions_end",
          "_hanging_extra_model_outputs_end",
          "_hanging_rewards_end",
          "_hanging_rewards_begin",
          "is_terminated",
          "is_truncated",
      ):
          state[attr] = getattr(self, attr)

      # Single-agent episodes are stored as a list of (agent ID, state) pairs.
      state["agent_episodes"] = [
          (agent_id, agent_eps.get_state())
          for agent_id, agent_eps in self.agent_episodes.items()
      ]

      for attr in ("_start_time", "_last_step_time", "custom_data"):
          state[attr] = getattr(self, attr)

      return state
  @staticmethod
  def from_state(state: Dict[str, Any]) -> "MultiAgentEpisode":
      """Creates a multi-agent episode from a state dictionary.

      See `MultiAgentEpisode.get_state()` for how to create such a state. The
      state has to be complete, i.e. it must contain all keys read below
      (only "custom_data" is optional).

      Args:
          state: A dict containing all data required to recreate a
              `MultiAgentEpisode`. See `MultiAgentEpisode.get_state()`.

      Returns:
          A `MultiAgentEpisode` instance created from the state data.
      """
      # Start from an otherwise empty episode that only carries the stored ID.
      restored = MultiAgentEpisode(id_=state["id_"])

      # Attributes that are stored in the state under their own attribute name.
      for attr in (
          "agent_to_module_mapping_fn",
          "_agent_to_module_mapping",
          "observation_space",
          "action_space",
          "env_t_started",
          "env_t",
          "agent_t_started",
          "env_t_to_agent_t",
          "_hanging_actions_end",
          "_hanging_extra_model_outputs_end",
          "_hanging_rewards_end",
          "_hanging_rewards_begin",
          "is_terminated",
          "is_truncated",
      ):
          setattr(restored, attr, state[attr])

      # Rebuild each single-agent episode from its (agent ID, state) pair.
      sa_episodes = {}
      for agent_id, agent_state in state["agent_episodes"]:
          sa_episodes[agent_id] = SingleAgentEpisode.from_state(agent_state)
      restored.agent_episodes = sa_episodes

      restored._start_time = state["_start_time"]
      restored._last_step_time = state["_last_step_time"]
      # Falls back to an empty dict if the state has no "custom_data" key.
      restored._custom_data = state.get("custom_data", {})

      # Validate the episode.
      restored.validate()

      return restored
  def get_sample_batch(self) -> MultiAgentBatch:
      """Converts this `MultiAgentEpisode` into a `MultiAgentBatch`.

      Each `SingleAgentEpisode` in `self.agent_episodes` that has stepped at
      least once is converted via its own `get_sample_batch()`. The number of
      env steps of this episode (`self.env_t - self.env_t_started`) is passed
      as the returned MultiAgentBatch's `env_steps`.

      Returns:
          A MultiAgentBatch containing all of this episode's data.
      """
      # TODO (simon): Check, if timesteps should be converted into global
      #  timesteps instead of agent steps.
      policy_batches = {}
      for agent_id, agent_eps in self.agent_episodes.items():
          # Note, only agents that have stepped are included into the batch.
          if not (agent_eps.t - agent_eps.t_started > 0):
              continue
          policy_batches[agent_id] = agent_eps.get_sample_batch()

      return MultiAgentBatch(
          policy_batches=policy_batches,
          env_steps=self.env_t - self.env_t_started,
      )
  def get_return(
      self,
      include_hanging_rewards: bool = False,
  ) -> float:
      """Returns the return summed over all agents.

      Args:
          include_hanging_rewards: Whether to also add the hanging rewards to
              the overall return. Agents might have received partial rewards,
              i.e. rewards without an observation. These are stored in the
              "hanging" caches (begin and end) for each agent and added up
              until the next observation is received by that agent.

      Returns:
          The sum of all single-agent episodes' returns (plus all values in the
          hanging-rewards caches, if `include_hanging_rewards` is True).
      """
      total = sum(sa_eps.get_return() for sa_eps in self.agent_episodes.values())
      if not include_hanging_rewards:
          return total

      # Add the "begin" cache first, then the "end" cache, value by value.
      for cache in (self._hanging_rewards_begin, self._hanging_rewards_end):
          for hanging_reward in cache.values():
              total += hanging_reward
      return total
  def get_agents_to_act(self) -> Set[AgentID]:
      """Returns the agent IDs that must send an action to the next `env.step()`.

      Those are the agents that have an observation at the most recent env step
      (`self.get_observations(-1)`) and whose single-agent episode is not done.

      Returns:
          A set of AgentIDs that are supposed to send actions to the next
          `env.step()` call.
      """
      agents_to_act = set()
      for aid in self.get_observations(-1).keys():
          # Agents whose single-agent episode is done do not act anymore.
          if self.agent_episodes[aid].is_done:
              continue
          agents_to_act.add(aid)
      return agents_to_act
  def get_agents_that_stepped(self) -> Set[AgentID]:
      """Returns the agent IDs of those agents that just finished stepping.

      These are all agents that have an observation logged at the last env
      timestep (`self.get_observations(-1)`). This may include agents whose
      single-agent episode just terminated or truncated.

      Returns:
          A set of AgentIDs of those agents that have a most recent observation
          on the env timestep scale, regardless of whether their single-agent
          episodes are done or not.
      """
      latest_observations = self.get_observations(-1)
      return set(latest_observations.keys())
  def get_duration_s(self) -> float:
      """Returns the duration of this Episode (chunk) in seconds.

      Returns:
          The difference between `self._last_step_time` and `self._start_time`,
          or 0.0 if `self._last_step_time` is None.
      """
      last_step_time = self._last_step_time
      return 0.0 if last_step_time is None else last_step_time - self._start_time
  def set_observations(
      self,
      *,
      new_data: MultiAgentDict,
      at_indices: Optional[Union[int, List[int], slice]] = None,
      neg_index_as_lookback: bool = False,
  ) -> None:
      """Overwrites observations of some or all single-agent episodes.

      This is a helper method to batch `SingleAgentEpisode.set_observations`:
      for each agent ID in `new_data`, that agent's single-agent episode is
      called with the agent's data and the given `at_indices` and
      `neg_index_as_lookback`. For more detail, see
      `SingleAgentEpisode.set_observations`.

      Agents are processed in the iteration order of `new_data`. Agents
      processed before an unknown agent ID is hit have already been
      overwritten when the KeyError is raised.

      Args:
          new_data: A dict mapping agent IDs to new observation data. Each
              value is the new observation data to overwrite existing data
              with. This may be a list of individual observation(s) in case
              this episode is still not numpy'ized yet. In case this episode
              has already been numpy'ized, this should be a (possibly complex)
              struct matching the observation space and with a batch size of
              its leafs exactly the size of the to-be-overwritten slice or
              segment (provided by `at_indices`).
          at_indices: A single int is interpreted as one index, which to
              overwrite with the new data (which is expected to be a single
              observation). A list of ints is interpreted as a list of indices,
              all of which to overwrite with the new data (which is expected to
              be of the same size as `len(at_indices)`). A slice object is
              interpreted as a range of indices to be overwritten with the new
              data (which is expected to be of the same size as the provided
              slice). Negative indices by default are interpreted as "before
              the end" unless `neg_index_as_lookback=True`, in which case they
              are interpreted as "before ts=0", meaning going back into the
              lookback buffer.
          neg_index_as_lookback: If True, negative values in `at_indices` are
              interpreted as "before ts=0", meaning going back into the
              lookback buffer. For example, for an agent with
              observations = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
              lookback buffer range (ts=0 item is 7), a call with
              `at_indices=-1, neg_index_as_lookback=True` overwrites the value
              of 6 in that agent's observations buffer with the provided
              individual observation.

      Raises:
          KeyError: If an agent ID in `new_data` is not part of this episode.
          IndexError: If the provided `at_indices` do not match the size of
              the new data.
      """
      # Settings shared by all per-agent calls.
      shared_kwargs = {
          "at_indices": at_indices,
          "neg_index_as_lookback": neg_index_as_lookback,
      }
      for agent_id, agent_observations in new_data.items():
          if agent_id not in self.agent_episodes:
              raise KeyError(f"AgentID '{agent_id}' not found in this episode.")
          sa_episode = self.agent_episodes[agent_id]
          sa_episode.set_observations(new_data=agent_observations, **shared_kwargs)
  1736. def set_actions(
  1737. self,
  1738. *,
  1739. new_data: MultiAgentDict,
  1740. at_indices: Optional[Union[int, List[int], slice]] = None,
  1741. neg_index_as_lookback: bool = False,
  1742. ) -> None:
  1743. """Overwrites all or some of this Episode's actions with the provided data.
  1744. This is a helper method to batch `SingleAgentEpisode.set_actions`.
  1745. For more detail, see `SingleAgentEpisode.set_actions`.
  1746. Args:
  1747. new_data: A dict mapping agent IDs to new action data.
  1748. Each value in the dict is the new action data to overwrite existing data with.
  1749. This may be a list of individual action(s) in case this episode
  1750. is still not numpy'ized yet. In case this episode has already been
  1751. numpy'ized, this should be (possibly complex) struct matching the
  1752. action space and with a batch size of its leafs exactly the size
  1753. of the to-be-overwritten slice or segment (provided by `at_indices`).
  1754. at_indices: A single int is interpreted as one index, which to overwrite
  1755. with `new_data` (which is expected to be a single observation).
  1756. A list of ints is interpreted as a list of indices, all of which to
  1757. overwrite with `new_data` (which is expected to be of the same size
  1758. as `len(at_indices)`).
  1759. A slice object is interpreted as a range of indices to be overwritten
  1760. with `new_data` (which is expected to be of the same size as the
  1761. provided slice).
  1762. Thereby, negative indices by default are interpreted as "before the end"
  1763. unless the `neg_index_as_lookback=True` option is used, in which case
  1764. negative indices are interpreted as "before ts=0", meaning going back
  1765. into the lookback buffer.
  1766. neg_index_as_lookback: If True, negative values in `at_indices` are
  1767. interpreted as "before ts=0", meaning going back into the lookback
  1768. buffer. For example, an episode with
  1769. actions = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
  1770. lookback buffer range (ts=0 item is 7), will handle a call to
  1771. `set_actions(individual_action, -1,
  1772. neg_index_as_lookback=True)` by overwriting the value of 6 in our
  1773. actions buffer with the provided "individual_action".
  1774. Raises:
  1775. IndexError: If the provided `at_indices` do not match the size of
  1776. `new_data`.
  1777. """
  1778. for agent_id, new_agent_data in new_data.items():
  1779. if agent_id not in self.agent_episodes:
  1780. raise KeyError(f"AgentID '{agent_id}' not found in this episode.")
  1781. self.agent_episodes[agent_id].set_actions(
  1782. new_data=new_agent_data,
  1783. at_indices=at_indices,
  1784. neg_index_as_lookback=neg_index_as_lookback,
  1785. )
  1786. def set_rewards(
  1787. self,
  1788. *,
  1789. new_data: MultiAgentDict,
  1790. at_indices: Optional[Union[int, List[int], slice]] = None,
  1791. neg_index_as_lookback: bool = False,
  1792. ) -> None:
  1793. """Overwrites all or some of this Episode's rewards with the provided data.
  1794. This is a helper method to batch `SingleAgentEpisode.set_rewards`.
  1795. For more detail, see `SingleAgentEpisode.set_rewards`.
  1796. Args:
  1797. new_data: A dict mapping agent IDs to new reward data.
  1798. Each value in the dict is the new reward data to overwrite existing data with.
  1799. This may be a list of individual reward(s) in case this episode
  1800. is still not numpy'ized yet. In case this episode has already been
  1801. numpy'ized, this should be a np.ndarray with a length exactly
  1802. the size of the to-be-overwritten slice or segment (provided by
  1803. `at_indices`).
  1804. at_indices: A single int is interpreted as one index, which to overwrite
  1805. with `new_data` (which is expected to be a single reward).
  1806. A list of ints is interpreted as a list of indices, all of which to
  1807. overwrite with `new_data` (which is expected to be of the same size
  1808. as `len(at_indices)`).
  1809. A slice object is interpreted as a range of indices to be overwritten
  1810. with `new_data` (which is expected to be of the same size as the
  1811. provided slice).
  1812. Thereby, negative indices by default are interpreted as "before the end"
  1813. unless the `neg_index_as_lookback=True` option is used, in which case
  1814. negative indices are interpreted as "before ts=0", meaning going back
  1815. into the lookback buffer.
  1816. neg_index_as_lookback: If True, negative values in `at_indices` are
  1817. interpreted as "before ts=0", meaning going back into the lookback
  1818. buffer. For example, an episode with
  1819. rewards = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
  1820. lookback buffer range (ts=0 item is 7), will handle a call to
  1821. `set_rewards(individual_reward, -1,
  1822. neg_index_as_lookback=True)` by overwriting the value of 6 in our
  1823. rewards buffer with the provided "individual_reward".
  1824. Raises:
  1825. IndexError: If the provided `at_indices` do not match the size of
  1826. `new_data`.
  1827. """
  1828. for agent_id, new_agent_data in new_data.items():
  1829. if agent_id not in self.agent_episodes:
  1830. raise KeyError(f"AgentID '{agent_id}' not found in this episode.")
  1831. self.agent_episodes[agent_id].set_rewards(
  1832. new_data=new_agent_data,
  1833. at_indices=at_indices,
  1834. neg_index_as_lookback=neg_index_as_lookback,
  1835. )
  1836. def set_extra_model_outputs(
  1837. self,
  1838. *,
  1839. key,
  1840. new_data: MultiAgentDict,
  1841. at_indices: Optional[Union[int, List[int], slice]] = None,
  1842. neg_index_as_lookback: bool = False,
  1843. ) -> None:
  1844. """Overwrites all or some of this Episode's extra model outputs with `new_data`.
  1845. This is a helper method to batch `SingleAgentEpisode.set_extra_model_outputs`.
  1846. For more detail, see `SingleAgentEpisode.set_extra_model_outputs`.
  1847. Args:
  1848. key: The `key` within `self.extra_model_outputs` to override data on or
  1849. to insert as a new key into `self.extra_model_outputs`.
  1850. new_data: A dict mapping agent IDs to new extra model outputs data.
  1851. Each value in the dict is the new extra model outputs data to overwrite existing data with.
  1852. This may be a list of individual reward(s) in case this episode
  1853. is still not numpy'ized yet. In case this episode has already been
  1854. numpy'ized, this should be a np.ndarray with a length exactly
  1855. the size of the to-be-overwritten slice or segment (provided by
  1856. `at_indices`).
  1857. at_indices: A single int is interpreted as one index, which to overwrite
  1858. with `new_data` (which is expected to be a single extra model output).
  1859. A list of ints is interpreted as a list of indices, all of which to
  1860. overwrite with `new_data` (which is expected to be of the same size
  1861. as `len(at_indices)`).
  1862. A slice object is interpreted as a range of indices to be overwritten
  1863. with `new_data` (which is expected to be of the same size as the
  1864. provided slice).
  1865. Thereby, negative indices by default are interpreted as "before the end"
  1866. unless the `neg_index_as_lookback=True` option is used, in which case
  1867. negative indices are interpreted as "before ts=0", meaning going back
  1868. into the lookback buffer.
  1869. neg_index_as_lookback: If True, negative values in `at_indices` are
  1870. interpreted as "before ts=0", meaning going back into the lookback
  1871. buffer. For example, an episode with
  1872. extra_model_outputs[key][agent_id] = [4, 5, 6, 7, 8, 9], where [4, 5, 6] is the
  1873. lookback buffer range (ts=0 item is 7), will handle a call to
  1874. `set_extra_model_outputs(key, individual_output, -1,
  1875. neg_index_as_lookback=True)` by overwriting the value of 6 in our
  1876. extra_model_outputs[key][agent_id] buffer with the provided "individual_output".
  1877. Raises:
  1878. IndexError: If the provided `at_indices` do not match the size of
  1879. `new_data`.
  1880. """
  1881. for agent_id, new_agent_data in new_data.items():
  1882. if agent_id not in self.agent_episodes:
  1883. raise KeyError(f"AgentID '{agent_id}' not found in this episode.")
  1884. self.agent_episodes[agent_id].set_extra_model_outputs(
  1885. key=key,
  1886. new_data=new_agent_data,
  1887. at_indices=at_indices,
  1888. neg_index_as_lookback=neg_index_as_lookback,
  1889. )
  1890. def env_steps(self) -> int:
  1891. """Returns the number of environment steps.
  1892. Note, this episode instance could be a chunk of an actual episode.
  1893. Returns:
  1894. An integer that counts the number of environment steps this episode instance
  1895. has seen.
  1896. """
  1897. return len(self)
  1898. def agent_steps(self) -> int:
  1899. """Number of agent steps.
  1900. Note, there are >= 1 agent steps per environment step.
  1901. Returns:
  1902. An integer counting the number of agent steps executed during the time this
  1903. episode instance records.
  1904. """
  1905. return sum(len(eps) for eps in self.agent_episodes.values())
  1906. def __getitem__(self, item: slice) -> "MultiAgentEpisode":
  1907. """Enable squared bracket indexing- and slicing syntax, e.g. episode[-4:]."""
  1908. if isinstance(item, slice):
  1909. return self.slice(slice_=item)
  1910. else:
  1911. raise NotImplementedError(
  1912. f"MultiAgentEpisode does not support getting item '{item}'! "
  1913. "Only slice objects allowed with the syntax: `episode[a:b]`."
  1914. )
    def _init_single_agent_episodes(
        self,
        *,
        agent_module_ids: Optional[Dict[AgentID, ModuleID]] = None,
        agent_episode_ids: Optional[Dict[AgentID, str]] = None,
        observations: Optional[List[MultiAgentDict]] = None,
        actions: Optional[List[MultiAgentDict]] = None,
        rewards: Optional[List[MultiAgentDict]] = None,
        infos: Optional[List[MultiAgentDict]] = None,
        terminateds: Union[MultiAgentDict, bool] = False,
        truncateds: Union[MultiAgentDict, bool] = False,
        extra_model_outputs: Optional[List[MultiAgentDict]] = None,
    ) -> None:
        """Builds the per-agent `SingleAgentEpisode`s from per-env-step data.

        The given lists hold one multi-agent dict per env step. They are split
        into per-agent lists, from which one `SingleAgentEpisode` per agent is
        created and stored in `self.agent_episodes`. Along the way, this method
        fills `self.env_t_to_agent_t` and the `self._hanging_*_end` caches.
        Does nothing if `observations` is None.

        Args:
            agent_module_ids: Optional explicit mapping from agent ID to module ID.
            agent_episode_ids: Optional mapping from agent ID to the ID to use for
                that agent's `SingleAgentEpisode`. Its keys seed the set of known
                agent IDs.
            observations: One multi-agent observation dict per env step.
            actions: One multi-agent action dict per env step. If None, `rewards`
                and `extra_model_outputs` must be empty/None as well.
            rewards: One multi-agent reward dict per env step.
            infos: One multi-agent info dict per env step. If None, empty dicts
                are used.
            terminateds: Per-agent terminated flags.
            truncateds: Per-agent truncated flags.
            extra_model_outputs: One multi-agent dict of extra model outputs per
                env step. If None, empty dicts are used.

        NOTE(review): `terminateds`/`truncateds` are annotated as possibly `bool`
        (default False), but the code below calls `.get()` on both and assigns
        items on `terminateds`. Presumably the caller always passes dicts when
        `observations` is given - confirm.
        """
        if observations is None:
            return
        if actions is None:
            assert not rewards
            assert not extra_model_outputs
            actions = []
            rewards = []
            extra_model_outputs = []

        # Infos and `extra_model_outputs` are allowed to be None -> Fill them with
        # proper dummy values, if so.
        if infos is None:
            infos = [{} for _ in range(len(observations))]
        if extra_model_outputs is None:
            extra_model_outputs = [{} for _ in range(len(actions))]

        # Per-agent accumulators, filled while stepping through the env steps.
        observations_per_agent = defaultdict(list)
        infos_per_agent = defaultdict(list)
        actions_per_agent = defaultdict(list)
        rewards_per_agent = defaultdict(list)
        extra_model_outputs_per_agent = defaultdict(list)
        done_per_agent = defaultdict(bool)
        # Each agent starts with the global lookback length; it is reduced below
        # for every env step the agent did not take part in.
        len_lookback_buffer_per_agent = defaultdict(lambda: self._len_lookback_buffers)

        all_agent_ids = set(
            agent_episode_ids.keys() if agent_episode_ids is not None else []
        )
        agent_module_ids = agent_module_ids or {}

        # Step through all observations and interpret these as the (global) env steps.
        for data_idx, (obs, inf) in enumerate(zip(observations, infos)):
            # If we do have actions/extra outs/rewards for this timestep, use the data.
            # It may be that these lists have the same length as the observations list,
            # in which case the data will be cached (agent did step/send an action,
            # but the step has not been concluded yet by the env).
            act = actions[data_idx] if len(actions) > data_idx else {}
            extra_outs = (
                extra_model_outputs[data_idx]
                if len(extra_model_outputs) > data_idx
                else {}
            )
            rew = rewards[data_idx] if len(rewards) > data_idx else {}

            for agent_id, agent_obs in obs.items():
                all_agent_ids.add(agent_id)

                observations_per_agent[agent_id].append(agent_obs)
                infos_per_agent[agent_id].append(inf.get(agent_id, {}))

                # Pull out hanging action (if not first obs for this agent) and
                # complete step for agent.
                if len(observations_per_agent[agent_id]) > 1:
                    actions_per_agent[agent_id].append(
                        self._hanging_actions_end.pop(agent_id)
                    )
                    extra_model_outputs_per_agent[agent_id].append(
                        self._hanging_extra_model_outputs_end.pop(agent_id)
                    )
                    rewards_per_agent[agent_id].append(
                        self._hanging_rewards_end.pop(agent_id)
                    )
                # First obs for this agent. Make sure the agent's mapping is
                # appropriately prepended with self.SKIP_ENV_TS_TAG tags.
                else:
                    if agent_id not in self.env_t_to_agent_t:
                        self.env_t_to_agent_t[agent_id].extend(
                            [self.SKIP_ENV_TS_TAG] * data_idx
                        )
                        # The skipped env steps are not available to this agent
                        # as lookback data.
                        len_lookback_buffer_per_agent[agent_id] -= data_idx

                # Agent is still continuing (has an action for the next step).
                if agent_id in act:
                    # Always push actions/extra outputs into cache, then remove them
                    # from there, once the next observation comes in. Same for rewards.
                    self._hanging_actions_end[agent_id] = act[agent_id]
                    self._hanging_extra_model_outputs_end[agent_id] = extra_outs.get(
                        agent_id, {}
                    )
                    self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0)
                # Agent is done (has no action for the next step).
                elif terminateds.get(agent_id) or truncateds.get(agent_id):
                    done_per_agent[agent_id] = True
                # There is more (global) action/reward data. This agent must therefore
                # be done. Automatically add it to `done_per_agent` and `terminateds`.
                # Note that this writes into the `terminateds` dict passed in.
                elif data_idx < len(observations) - 1:
                    done_per_agent[agent_id] = terminateds[agent_id] = True

                # Update env_t_to_agent_t mapping.
                self.env_t_to_agent_t[agent_id].append(
                    len(observations_per_agent[agent_id]) - 1
                )

            # Those agents that did NOT step:
            # - Get self.SKIP_ENV_TS_TAG added to their env_t_to_agent_t mapping.
            # - Get their reward (if any) added up.
            for agent_id in all_agent_ids:
                if agent_id not in obs and agent_id not in done_per_agent:
                    self.env_t_to_agent_t[agent_id].append(self.SKIP_ENV_TS_TAG)
                    # If we are still in the global lookback buffer segment, deduct 1
                    # from this agents' lookback buffer, b/c we don't want the agent
                    # to use this (missing) obs/data in its single-agent lookback.
                    if (
                        len(self.env_t_to_agent_t[agent_id])
                        - self._len_lookback_buffers
                        <= 0
                    ):
                        len_lookback_buffer_per_agent[agent_id] -= 1
                    self._hanging_rewards_end[agent_id] += rew.get(agent_id, 0.0)

        # - Validate per-agent data.
        # - Fix lookback buffers of env_t_to_agent_t mappings.
        # Iterate over a copy of the keys, since entries may be deleted below.
        for agent_id in list(self.env_t_to_agent_t.keys()):
            # Skip agent if it doesn't seem to have any data.
            if agent_id not in observations_per_agent:
                del self.env_t_to_agent_t[agent_id]
                continue
            # Each agent must have exactly one more observation/info than it has
            # completed actions/rewards/extra model outputs.
            assert (
                len(observations_per_agent[agent_id])
                == len(infos_per_agent[agent_id])
                == len(actions_per_agent[agent_id]) + 1
                == len(extra_model_outputs_per_agent[agent_id]) + 1
                == len(rewards_per_agent[agent_id]) + 1
            )
            self.env_t_to_agent_t[agent_id].lookback = self._len_lookback_buffers

        # Now create the individual episodes from the collected per-agent data.
        for agent_id, agent_obs in observations_per_agent.items():
            # If agent only has a single obs AND is already done, remove all its traces
            # from this MultiAgentEpisode.
            if len(agent_obs) == 1 and done_per_agent.get(agent_id):
                self._del_agent(agent_id)
                continue

            # Try to figure out the module ID for this agent.
            # If not provided explicitly by the user that initializes this episode
            # object, try our mapping function.
            # NOTE(review): the mapping function is passed as the `.get()` default
            # and is therefore called even when an explicit module ID exists.
            module_id = agent_module_ids.get(
                agent_id, self.agent_to_module_mapping_fn(agent_id, self)
            )
            # Create this agent's SingleAgentEpisode.
            sa_episode = SingleAgentEpisode(
                id_=(
                    agent_episode_ids.get(agent_id)
                    if agent_episode_ids is not None
                    else None
                ),
                agent_id=agent_id,
                module_id=module_id,
                multi_agent_episode_id=self.id_,
                observations=agent_obs,
                observation_space=self.observation_space.get(agent_id),
                infos=infos_per_agent[agent_id],
                actions=actions_per_agent[agent_id],
                action_space=self.action_space.get(agent_id),
                rewards=rewards_per_agent[agent_id],
                # Convert the list of per-step dicts into a dict of per-key lists,
                # using the keys of the first step's dict.
                extra_model_outputs=(
                    {
                        k: [i[k] for i in extra_model_outputs_per_agent[agent_id]]
                        for k in extra_model_outputs_per_agent[agent_id][0].keys()
                    }
                    if extra_model_outputs_per_agent[agent_id]
                    else None
                ),
                terminated=terminateds.get(agent_id, False),
                truncated=truncateds.get(agent_id, False),
                t_started=self.agent_t_started[agent_id],
                # The per-agent lookback may have been reduced below 0 above.
                len_lookback_buffer=max(len_lookback_buffer_per_agent[agent_id], 0),
            )
            # .. and store it.
            self.agent_episodes[agent_id] = sa_episode
  2085. def _get(
  2086. self,
  2087. *,
  2088. what,
  2089. indices,
  2090. agent_ids=None,
  2091. env_steps=True,
  2092. neg_index_as_lookback=False,
  2093. fill=None,
  2094. one_hot_discrete=False,
  2095. return_list=False,
  2096. extra_model_outputs_key=None,
  2097. ):
  2098. agent_ids = set(force_list(agent_ids)) or self.agent_ids
  2099. kwargs = dict(
  2100. what=what,
  2101. indices=indices,
  2102. agent_ids=agent_ids,
  2103. neg_index_as_lookback=neg_index_as_lookback,
  2104. fill=fill,
  2105. # Rewards and infos do not support one_hot_discrete option.
  2106. one_hot_discrete=dict(
  2107. {} if not one_hot_discrete else {"one_hot_discrete": one_hot_discrete}
  2108. ),
  2109. extra_model_outputs_key=extra_model_outputs_key,
  2110. )
  2111. # User specified agent timesteps (indices) -> Simply delegate everything
  2112. # to the individual agents' SingleAgentEpisodes.
  2113. if env_steps is False:
  2114. if return_list:
  2115. raise ValueError(
  2116. f"`MultiAgentEpisode.get_{what}()` can't be called with both "
  2117. "`env_steps=False` and `return_list=True`!"
  2118. )
  2119. return self._get_data_by_agent_steps(**kwargs)
  2120. # User specified env timesteps (indices) -> We need to translate them for each
  2121. # agent into agent-timesteps.
  2122. # Return a list of individual per-env-timestep multi-agent dicts.
  2123. elif return_list:
  2124. return self._get_data_by_env_steps_as_list(**kwargs)
  2125. # Return a single multi-agent dict with lists/arrays as leafs.
  2126. else:
  2127. return self._get_data_by_env_steps(**kwargs)
  2128. def _get_data_by_agent_steps(
  2129. self,
  2130. *,
  2131. what,
  2132. indices,
  2133. agent_ids,
  2134. neg_index_as_lookback,
  2135. fill,
  2136. one_hot_discrete,
  2137. extra_model_outputs_key,
  2138. ):
  2139. # Return requested data by agent-steps.
  2140. ret = {}
  2141. # For each agent, we retrieve the data through passing the given indices into
  2142. # the SingleAgentEpisode of that agent.
  2143. for agent_id, sa_episode in self.agent_episodes.items():
  2144. if agent_id not in agent_ids:
  2145. continue
  2146. inf_lookback_buffer = getattr(sa_episode, what)
  2147. hanging_val = self._get_hanging_value(what, agent_id)
  2148. # User wants a specific `extra_model_outputs` key.
  2149. if extra_model_outputs_key is not None:
  2150. inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key]
  2151. hanging_val = hanging_val[extra_model_outputs_key]
  2152. agent_value = inf_lookback_buffer.get(
  2153. indices=indices,
  2154. neg_index_as_lookback=neg_index_as_lookback,
  2155. fill=fill,
  2156. _add_last_ts_value=hanging_val,
  2157. **one_hot_discrete,
  2158. )
  2159. if agent_value is None or agent_value == []:
  2160. continue
  2161. ret[agent_id] = agent_value
  2162. return ret
  2163. def _get_data_by_env_steps_as_list(
  2164. self,
  2165. *,
  2166. what: str,
  2167. indices: Union[int, slice, List[int]],
  2168. agent_ids: Collection[AgentID],
  2169. neg_index_as_lookback: bool,
  2170. fill: Any,
  2171. one_hot_discrete,
  2172. extra_model_outputs_key: str,
  2173. ) -> List[MultiAgentDict]:
  2174. # Collect indices for each agent first, so we can construct the list in
  2175. # the next step.
  2176. agent_indices = {}
  2177. for agent_id in self.agent_episodes.keys():
  2178. if agent_id not in agent_ids:
  2179. continue
  2180. agent_indices[agent_id] = self.env_t_to_agent_t[agent_id].get(
  2181. indices,
  2182. neg_index_as_lookback=neg_index_as_lookback,
  2183. fill=self.SKIP_ENV_TS_TAG,
  2184. # For those records where there is no "hanging" last timestep (all
  2185. # other than obs and infos), we have to ignore the last entry in
  2186. # the env_t_to_agent_t mappings.
  2187. _ignore_last_ts=what not in ["observations", "infos"],
  2188. )
  2189. if not agent_indices:
  2190. return []
  2191. ret = []
  2192. for i in range(len(next(iter(agent_indices.values())))):
  2193. ret2 = {}
  2194. for agent_id, idxes in agent_indices.items():
  2195. hanging_val = self._get_hanging_value(what, agent_id)
  2196. (
  2197. inf_lookback_buffer,
  2198. indices_to_use,
  2199. ) = self._get_inf_lookback_buffer_or_dict(
  2200. agent_id,
  2201. what,
  2202. extra_model_outputs_key,
  2203. hanging_val,
  2204. filter_for_skip_indices=idxes[i],
  2205. )
  2206. if (
  2207. what == "extra_model_outputs"
  2208. and not inf_lookback_buffer
  2209. and not hanging_val
  2210. ):
  2211. continue
  2212. agent_value = self._get_single_agent_data_by_index(
  2213. what=what,
  2214. inf_lookback_buffer=inf_lookback_buffer,
  2215. agent_id=agent_id,
  2216. index_incl_lookback=indices_to_use,
  2217. fill=fill,
  2218. one_hot_discrete=one_hot_discrete,
  2219. extra_model_outputs_key=extra_model_outputs_key,
  2220. hanging_val=hanging_val,
  2221. )
  2222. if agent_value is not None:
  2223. ret2[agent_id] = agent_value
  2224. ret.append(ret2)
  2225. return ret
  2226. def _get_data_by_env_steps(
  2227. self,
  2228. *,
  2229. what: str,
  2230. indices: Union[int, slice, List[int]],
  2231. agent_ids: Collection[AgentID],
  2232. neg_index_as_lookback: bool,
  2233. fill: Any,
  2234. one_hot_discrete: bool,
  2235. extra_model_outputs_key: str,
  2236. ) -> MultiAgentDict:
  2237. ignore_last_ts = what not in ["observations", "infos"]
  2238. ret = {}
  2239. for agent_id, sa_episode in self.agent_episodes.items():
  2240. if agent_id not in agent_ids:
  2241. continue
  2242. hanging_val = self._get_hanging_value(what, agent_id)
  2243. agent_indices = self.env_t_to_agent_t[agent_id].get(
  2244. indices,
  2245. neg_index_as_lookback=neg_index_as_lookback,
  2246. fill=self.SKIP_ENV_TS_TAG if fill is not None else None,
  2247. # For those records where there is no "hanging" last timestep (all
  2248. # other than obs and infos), we have to ignore the last entry in
  2249. # the env_t_to_agent_t mappings.
  2250. _ignore_last_ts=ignore_last_ts,
  2251. )
  2252. inf_lookback_buffer, agent_indices = self._get_inf_lookback_buffer_or_dict(
  2253. agent_id,
  2254. what,
  2255. extra_model_outputs_key,
  2256. hanging_val,
  2257. filter_for_skip_indices=agent_indices,
  2258. )
  2259. if isinstance(agent_indices, list):
  2260. agent_values = self._get_single_agent_data_by_env_step_indices(
  2261. what=what,
  2262. agent_id=agent_id,
  2263. indices_incl_lookback=agent_indices,
  2264. fill=fill,
  2265. one_hot_discrete=one_hot_discrete,
  2266. hanging_val=hanging_val,
  2267. extra_model_outputs_key=extra_model_outputs_key,
  2268. )
  2269. if len(agent_values) > 0:
  2270. ret[agent_id] = agent_values
  2271. else:
  2272. agent_values = self._get_single_agent_data_by_index(
  2273. what=what,
  2274. inf_lookback_buffer=inf_lookback_buffer,
  2275. agent_id=agent_id,
  2276. index_incl_lookback=agent_indices,
  2277. fill=fill,
  2278. one_hot_discrete=one_hot_discrete,
  2279. extra_model_outputs_key=extra_model_outputs_key,
  2280. hanging_val=hanging_val,
  2281. )
  2282. if agent_values is not None:
  2283. ret[agent_id] = agent_values
  2284. return ret
  2285. def _get_single_agent_data_by_index(
  2286. self,
  2287. *,
  2288. what: str,
  2289. inf_lookback_buffer: InfiniteLookbackBuffer,
  2290. agent_id: AgentID,
  2291. index_incl_lookback: Union[int, str],
  2292. fill: Any,
  2293. one_hot_discrete: dict,
  2294. extra_model_outputs_key: str,
  2295. hanging_val: Any,
  2296. ) -> Any:
  2297. sa_episode = self.agent_episodes[agent_id]
  2298. if index_incl_lookback == self.SKIP_ENV_TS_TAG:
  2299. # We don't want to fill -> Skip this agent.
  2300. if fill is None:
  2301. return
  2302. # Provide filled value for this agent.
  2303. return getattr(sa_episode, f"get_{what}")(
  2304. indices=1000000000000,
  2305. neg_index_as_lookback=False,
  2306. fill=fill,
  2307. **dict(
  2308. {}
  2309. if extra_model_outputs_key is None
  2310. else {"key": extra_model_outputs_key}
  2311. ),
  2312. **one_hot_discrete,
  2313. )
  2314. # No skip timestep -> Provide value at given index for this agent.
  2315. # Special case: extra_model_outputs and key=None (return all keys as
  2316. # a dict). Note that `inf_lookback_buffer` is NOT an infinite lookback
  2317. # buffer, but a dict mapping keys to individual infinite lookback
  2318. # buffers.
  2319. elif what == "extra_model_outputs" and extra_model_outputs_key is None:
  2320. assert hanging_val is None or isinstance(hanging_val, dict)
  2321. ret = {}
  2322. if inf_lookback_buffer:
  2323. for key, sub_buffer in inf_lookback_buffer.items():
  2324. ret[key] = sub_buffer.get(
  2325. indices=index_incl_lookback - sub_buffer.lookback,
  2326. neg_index_as_lookback=True,
  2327. fill=fill,
  2328. _add_last_ts_value=(
  2329. None if hanging_val is None else hanging_val[key]
  2330. ),
  2331. **one_hot_discrete,
  2332. )
  2333. else:
  2334. for key in hanging_val.keys():
  2335. ret[key] = InfiniteLookbackBuffer().get(
  2336. indices=index_incl_lookback,
  2337. neg_index_as_lookback=True,
  2338. fill=fill,
  2339. _add_last_ts_value=hanging_val[key],
  2340. **one_hot_discrete,
  2341. )
  2342. return ret
  2343. # Extract data directly from the infinite lookback buffer object.
  2344. else:
  2345. return inf_lookback_buffer.get(
  2346. indices=index_incl_lookback - inf_lookback_buffer.lookback,
  2347. neg_index_as_lookback=True,
  2348. fill=fill,
  2349. _add_last_ts_value=hanging_val,
  2350. **one_hot_discrete,
  2351. )
  2352. def _get_single_agent_data_by_env_step_indices(
  2353. self,
  2354. *,
  2355. what: str,
  2356. agent_id: AgentID,
  2357. indices_incl_lookback: Union[int, str],
  2358. fill: Optional[Any] = None,
  2359. one_hot_discrete: bool = False,
  2360. extra_model_outputs_key: Optional[str] = None,
  2361. hanging_val: Optional[Any] = None,
  2362. ) -> Any:
  2363. """Returns single data item from the episode based on given (env step) indices.
  2364. The returned data item will have a batch size that matches the env timesteps
  2365. defined via `indices_incl_lookback`.
  2366. Args:
  2367. what: A (str) descriptor of what data to collect. Must be one of
  2368. "observations", "infos", "actions", "rewards", or "extra_model_outputs".
  2369. indices_incl_lookback: A list of ints specifying, which indices
  2370. to pull from the InfiniteLookbackBuffer defined by `agent_id` and `what`
  2371. (and maybe `extra_model_outputs_key`). Note that these indices
  2372. disregard the special logic of the lookback buffer. Meaning if one
  2373. index in `indices_incl_lookback` is 0, then the first value in the
  2374. lookback buffer should be returned, not the first value after the
  2375. lookback buffer (which would be normal behavior for pulling items from
  2376. an `InfiniteLookbackBuffer` object).
  2377. agent_id: The individual agent ID to pull data for. Used to lookup the
  2378. `SingleAgentEpisode` object for this agent in `self`.
  2379. fill: An optional float value to use for filling up the returned results at
  2380. the boundaries. This filling only happens if the requested index range's
  2381. start/stop boundaries exceed the buffer's boundaries (including the
  2382. lookback buffer on the left side). This comes in very handy, if users
  2383. don't want to worry about reaching such boundaries and want to zero-pad.
  2384. For example, a buffer with data [10, 11, 12, 13, 14] and lookback
  2385. buffer size of 2 (meaning `10` and `11` are part of the lookback buffer)
  2386. will respond to `indices_incl_lookback=[-1, -2, 0]` and `fill=0.0`
  2387. with `[0.0, 0.0, 10]`.
  2388. one_hot_discrete: If True, will return one-hot vectors (instead of
  2389. int-values) for those sub-components of a (possibly complex) space
  2390. that are Discrete or MultiDiscrete. Note that if `fill=0` and the
  2391. requested `indices_incl_lookback` are out of the range of our data, the
  2392. returned one-hot vectors will actually be zero-hot (all slots zero).
  2393. extra_model_outputs_key: Only if what is "extra_model_outputs", this
  2394. specifies the sub-key (str) inside the extra_model_outputs dict, e.g.
  2395. STATE_OUT or ACTION_DIST_INPUTS.
  2396. hanging_val: In case we are pulling actions, rewards, or extra_model_outputs
  2397. data, there might be information "hanging" (cached). For example,
  2398. if an agent receives an observation o0 and then immediately sends an
  2399. action a0 back, but then does NOT immediately reveive a next
  2400. observation, a0 is now cached (not fully logged yet with this
  2401. episode). The currently cached value must be provided here to be able
  2402. to return it in case the index is -1 (most recent timestep).
  2403. Returns:
  2404. A data item corresponding to the provided args.
  2405. """
  2406. sa_episode = self.agent_episodes[agent_id]
  2407. inf_lookback_buffer = getattr(sa_episode, what)
  2408. if extra_model_outputs_key is not None:
  2409. inf_lookback_buffer = inf_lookback_buffer[extra_model_outputs_key]
  2410. # If there are self.SKIP_ENV_TS_TAG items in `indices_incl_lookback` and user
  2411. # wants to fill these (together with outside-episode-bounds indices) ->
  2412. # Provide these skipped timesteps as filled values.
  2413. if self.SKIP_ENV_TS_TAG in indices_incl_lookback and fill is not None:
  2414. single_fill_value = inf_lookback_buffer.get(
  2415. indices=1000000000000,
  2416. neg_index_as_lookback=False,
  2417. fill=fill,
  2418. **one_hot_discrete,
  2419. )
  2420. ret = []
  2421. for i in indices_incl_lookback:
  2422. if i == self.SKIP_ENV_TS_TAG:
  2423. ret.append(single_fill_value)
  2424. else:
  2425. ret.append(
  2426. inf_lookback_buffer.get(
  2427. indices=i - getattr(sa_episode, what).lookback,
  2428. neg_index_as_lookback=True,
  2429. fill=fill,
  2430. _add_last_ts_value=hanging_val,
  2431. **one_hot_discrete,
  2432. )
  2433. )
  2434. if self.is_numpy:
  2435. ret = batch(ret)
  2436. else:
  2437. # Filter these indices out up front.
  2438. indices = [
  2439. i - inf_lookback_buffer.lookback
  2440. for i in indices_incl_lookback
  2441. if i != self.SKIP_ENV_TS_TAG
  2442. ]
  2443. ret = inf_lookback_buffer.get(
  2444. indices=indices,
  2445. neg_index_as_lookback=True,
  2446. fill=fill,
  2447. _add_last_ts_value=hanging_val,
  2448. **one_hot_discrete,
  2449. )
  2450. return ret
  2451. def _get_hanging_value(self, what: str, agent_id: AgentID) -> Any:
  2452. """Returns the hanging action/reward/extra_model_outputs for given agent."""
  2453. if what == "actions":
  2454. return self._hanging_actions_end.get(agent_id)
  2455. elif what == "extra_model_outputs":
  2456. return self._hanging_extra_model_outputs_end.get(agent_id)
  2457. elif what == "rewards":
  2458. return self._hanging_rewards_end.get(agent_id)
  2459. def _copy_hanging(self, agent_id: AgentID, other: "MultiAgentEpisode") -> None:
  2460. """Copies hanging action, reward, extra_model_outputs from `other` to `self."""
  2461. if agent_id in other._hanging_rewards_begin:
  2462. self._hanging_rewards_begin[agent_id] = other._hanging_rewards_begin[
  2463. agent_id
  2464. ]
  2465. if agent_id in other._hanging_rewards_end:
  2466. self._hanging_actions_end[agent_id] = copy.deepcopy(
  2467. other._hanging_actions_end[agent_id]
  2468. )
  2469. self._hanging_rewards_end[agent_id] = other._hanging_rewards_end[agent_id]
  2470. self._hanging_extra_model_outputs_end[agent_id] = copy.deepcopy(
  2471. other._hanging_extra_model_outputs_end[agent_id]
  2472. )
  2473. def _del_hanging(self, agent_id: AgentID) -> None:
  2474. """Deletes all hanging action, reward, extra_model_outputs of given agent."""
  2475. self._hanging_rewards_begin.pop(agent_id, None)
  2476. self._hanging_actions_end.pop(agent_id, None)
  2477. self._hanging_extra_model_outputs_end.pop(agent_id, None)
  2478. self._hanging_rewards_end.pop(agent_id, None)
  2479. def _del_agent(self, agent_id: AgentID) -> None:
  2480. """Deletes all data of given agent from this episode."""
  2481. self._del_hanging(agent_id)
  2482. self.agent_episodes.pop(agent_id, None)
  2483. self.agent_ids.discard(agent_id)
  2484. self.env_t_to_agent_t.pop(agent_id, None)
  2485. self._agent_to_module_mapping.pop(agent_id, None)
  2486. self.agent_t_started.pop(agent_id, None)
  2487. def _get_inf_lookback_buffer_or_dict(
  2488. self,
  2489. agent_id: AgentID,
  2490. what: str,
  2491. extra_model_outputs_key: Optional[str] = None,
  2492. hanging_val: Optional[Any] = None,
  2493. filter_for_skip_indices=None,
  2494. ):
  2495. """Returns a single InfiniteLookbackBuffer or a dict of such.
  2496. In case `what` is "extra_model_outputs" AND `extra_model_outputs_key` is None,
  2497. a dict is returned. In all other cases, a single InfiniteLookbackBuffer is
  2498. returned.
  2499. """
  2500. inf_lookback_buffer_or_dict = inf_lookback_buffer = getattr(
  2501. self.agent_episodes[agent_id], what
  2502. )
  2503. if what == "extra_model_outputs":
  2504. if extra_model_outputs_key is not None:
  2505. inf_lookback_buffer = inf_lookback_buffer_or_dict[
  2506. extra_model_outputs_key
  2507. ]
  2508. elif inf_lookback_buffer_or_dict:
  2509. inf_lookback_buffer = next(iter(inf_lookback_buffer_or_dict.values()))
  2510. elif filter_for_skip_indices is not None:
  2511. return inf_lookback_buffer_or_dict, filter_for_skip_indices
  2512. else:
  2513. return inf_lookback_buffer_or_dict
  2514. if filter_for_skip_indices is not None:
  2515. inf_lookback_buffer_len = (
  2516. len(inf_lookback_buffer)
  2517. + inf_lookback_buffer.lookback
  2518. + (hanging_val is not None)
  2519. )
  2520. ignore_last_ts = what not in ["observations", "infos"]
  2521. if isinstance(filter_for_skip_indices, list):
  2522. filter_for_skip_indices = [
  2523. "S" if ignore_last_ts and i == inf_lookback_buffer_len else i
  2524. for i in filter_for_skip_indices
  2525. ]
  2526. elif ignore_last_ts and filter_for_skip_indices == inf_lookback_buffer_len:
  2527. filter_for_skip_indices = "S"
  2528. return inf_lookback_buffer_or_dict, filter_for_skip_indices
  2529. else:
  2530. return inf_lookback_buffer_or_dict
  2531. @Deprecated(new="MultiAgentEpisode.custom_data[some-key] = ...", error=True)
  2532. def add_temporary_timestep_data(self):
  2533. pass
  2534. @Deprecated(new="MultiAgentEpisode.custom_data[some-key]", error=True)
  2535. def get_temporary_timestep_data(self):
  2536. pass