episodes.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from typing import List, Tuple
  2. import numpy as np
  3. from ray.rllib.env.single_agent_episode import SingleAgentEpisode
  4. from ray.util.annotations import DeveloperAPI
  5. @DeveloperAPI
  6. def add_one_ts_to_episodes_and_truncate(episodes: List[SingleAgentEpisode]):
  7. """Adds an artificial timestep to an episode at the end.
  8. In detail: The last observations, infos, actions, and all `extra_model_outputs`
  9. will be duplicated and appended to each episode's data. An extra 0.0 reward
  10. will be appended to the episode's rewards. The episode's timestep will be
  11. increased by 1. Also, adds the truncated=True flag to each episode if the
  12. episode is not already done (terminated or truncated).
  13. Useful for value function bootstrapping, where it is required to compute a
  14. forward pass for the very last timestep within the episode,
  15. i.e. using the following input dict: {
  16. obs=[final obs],
  17. state=[final state output],
  18. prev. reward=[final reward],
  19. etc..
  20. }
  21. Args:
  22. episodes: The list of SingleAgentEpisode objects to extend by one timestep
  23. and add a truncation flag if necessary.
  24. Returns:
  25. A list of the original episodes' truncated values (so the episodes can be
  26. properly restored later into their original states).
  27. """
  28. orig_truncateds = []
  29. for episode in episodes:
  30. orig_truncateds.append(episode.is_truncated)
  31. # Add timestep.
  32. episode.t += 1
  33. # Use the episode API that allows appending (possibly complex) structs
  34. # to the data.
  35. episode.observations.append(episode.observations[-1])
  36. episode.infos.append(episode.infos[-1])
  37. episode.actions.append(episode.actions[-1])
  38. episode.rewards.append(0.0)
  39. for v in episode.extra_model_outputs.values():
  40. v.append(v[-1])
  41. # Artificially make this episode truncated for the upcoming GAE
  42. # computations.
  43. if not episode.is_done:
  44. episode.is_truncated = True
  45. # Validate to make sure, everything is in order.
  46. episode.validate()
  47. return orig_truncateds
  48. @DeveloperAPI
  49. def remove_last_ts_from_data(
  50. episode_lens: List[int],
  51. *data: Tuple[np._typing.NDArray],
  52. ) -> Tuple[np._typing.NDArray]:
  53. """Removes the last timesteps from each given data item.
  54. Each item in data is a concatenated sequence of episodes data.
  55. For example if `episode_lens` is [2, 4], then data is a shape=(6,)
  56. ndarray. The returned corresponding value will have shape (4,), meaning
  57. both episodes have been shortened by exactly one timestep to 1 and 3.
  58. ..testcode::
  59. from ray.rllib.algorithms.ppo.ppo_learner import PPOLearner
  60. import numpy as np
  61. unpadded = PPOLearner._remove_last_ts_from_data(
  62. [5, 3],
  63. np.array([0, 1, 2, 3, 4, 0, 1, 2]),
  64. )
  65. assert (unpadded[0] == [0, 1, 2, 3, 0, 1]).all()
  66. unpadded = PPOLearner._remove_last_ts_from_data(
  67. [4, 2, 3],
  68. np.array([0, 1, 2, 3, 0, 1, 0, 1, 2]),
  69. np.array([4, 5, 6, 7, 2, 3, 3, 4, 5]),
  70. )
  71. assert (unpadded[0] == [0, 1, 2, 0, 0, 1]).all()
  72. assert (unpadded[1] == [4, 5, 6, 2, 3, 4]).all()
  73. Args:
  74. episode_lens: A list of current episode lengths. The returned
  75. data will have the same lengths minus 1 timestep.
  76. data: A tuple of data items (np.ndarrays) representing concatenated episodes
  77. to be shortened by one timestep per episode.
  78. Note that only arrays with `shape=(n,)` are supported! The
  79. returned data will have `shape=(n-len(episode_lens),)` (each
  80. episode gets shortened by one timestep).
  81. Returns:
  82. A tuple of new data items shortened by one timestep.
  83. """
  84. # Figure out the new slices to apply to each data item based on
  85. # the given episode_lens.
  86. slices = []
  87. sum = 0
  88. for len_ in episode_lens:
  89. slices.append(slice(sum, sum + len_ - 1))
  90. sum += len_
  91. # Compiling return data by slicing off one timestep at the end of
  92. # each episode.
  93. ret = []
  94. for d in data:
  95. ret.append(np.concatenate([d[s] for s in slices]))
  96. return tuple(ret) if len(ret) > 1 else ret[0]
  97. @DeveloperAPI
  98. def remove_last_ts_from_episodes_and_restore_truncateds(
  99. episodes: List[SingleAgentEpisode],
  100. orig_truncateds: List[bool],
  101. ) -> None:
  102. """Reverts the effects of `_add_ts_to_episodes_and_truncate`.
  103. Args:
  104. episodes: The list of SingleAgentEpisode objects to extend by one timestep
  105. and add a truncation flag if necessary.
  106. orig_truncateds: A list of the original episodes' truncated values to be
  107. applied to the `episodes`.
  108. """
  109. # Fix all episodes.
  110. for episode, orig_truncated in zip(episodes, orig_truncateds):
  111. # Reduce timesteps by 1.
  112. episode.t -= 1
  113. # Remove all extra timestep data from the episode's buffers.
  114. episode.observations.pop()
  115. episode.infos.pop()
  116. episode.actions.pop()
  117. episode.rewards.pop()
  118. for v in episode.extra_model_outputs.values():
  119. v.pop()
  120. # Fix the truncateds flag again.
  121. episode.is_truncated = orig_truncated