| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144 |
from typing import List, Tuple, Union

import numpy as np

from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.util.annotations import DeveloperAPI
- @DeveloperAPI
def add_one_ts_to_episodes_and_truncate(episodes: List[SingleAgentEpisode]):
    """Extends each episode by one artificial timestep and marks it truncated.

    For every episode, the most recent observation, info, action, and every
    entry of `extra_model_outputs` is duplicated and appended; a reward of 0.0
    is appended; and the episode's timestep counter `t` is incremented by one.
    Episodes that are not yet done (neither terminated nor truncated) get
    `is_truncated=True`. Each episode is validated after the modification.

    This is useful for value function bootstrapping, where a forward pass is
    required for the very last timestep within the episode, i.e. using the
    following input dict: {
      obs=[final obs],
      state=[final state output],
      prev. reward=[final reward],
      etc..
    }

    Args:
        episodes: The list of SingleAgentEpisode objects to extend by one
            timestep and to flag as truncated if necessary. The episodes are
            modified in place.

    Returns:
        A list with the `is_truncated` value each episode had before this call,
        in the same order as `episodes` (so the episodes can be properly
        restored later into their original states).
    """
    truncated_before = []

    for ep in episodes:
        # Remember the flag as it was before it is (possibly) overwritten below.
        truncated_before.append(ep.is_truncated)

        # Advance the episode's timestep counter.
        ep.t += 1

        # Duplicate the latest entry of each per-timestep buffer through the
        # buffers' own `append` API, which accepts (possibly complex) structs.
        for buffer_name in ("observations", "infos", "actions"):
            buffer = getattr(ep, buffer_name)
            buffer.append(buffer[-1])

        # The artificial timestep carries no reward.
        ep.rewards.append(0.0)

        for outputs in ep.extra_model_outputs.values():
            outputs.append(outputs[-1])

        # Artificially truncate still-running episodes for the upcoming GAE
        # computations.
        if not ep.is_done:
            ep.is_truncated = True

        # Make sure the extended episode is still internally consistent.
        ep.validate()

    return truncated_before
- @DeveloperAPI
def remove_last_ts_from_data(
    episode_lens: List[int],
    *data: np.ndarray,
) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
    """Removes the last timestep of every episode from each given data item.

    Each item in `data` is a concatenated sequence of per-episode data.
    For example, if `episode_lens` is [2, 4], then a data item is a
    shape=(6,) ndarray. The corresponding returned value has shape (4,),
    meaning both episodes have been shortened by exactly one timestep to
    lengths 1 and 3.

    .. code-block:: python

        import numpy as np

        # A single data item is returned as a bare array.
        unpadded = remove_last_ts_from_data(
            [5, 3],
            np.array([0, 1, 2, 3, 4, 0, 1, 2]),
        )
        assert (unpadded == [0, 1, 2, 3, 0, 1]).all()

        # Several data items are returned as a tuple of arrays.
        unpadded = remove_last_ts_from_data(
            [4, 2, 3],
            np.array([0, 1, 2, 3, 0, 1, 0, 1, 2]),
            np.array([4, 5, 6, 7, 2, 3, 3, 4, 5]),
        )
        assert (unpadded[0] == [0, 1, 2, 0, 0, 1]).all()
        assert (unpadded[1] == [4, 5, 6, 2, 3, 4]).all()

    Args:
        episode_lens: A list of current episode lengths. The returned
            data will have the same lengths minus 1 timestep. An entry of 0
            contributes nothing to the output.
        data: The data items (np.ndarrays) representing concatenated episodes
            to be shortened by one timestep per episode. Each item is sliced
            and re-concatenated along its first axis. For an item with
            `shape=(n,)`, the result has `shape=(n-len(episode_lens),)` (each
            episode gets shortened by one timestep).

    Returns:
        The shortened data. If exactly one data item was passed in, the
        shortened array itself is returned; otherwise a tuple with one
        shortened array per data item (an empty tuple if no data items were
        given). If `episode_lens` is empty, each returned array is empty.
    """
    # Figure out the slices to apply to each data item based on the given
    # episode_lens. Each slice keeps all but the last timestep of one episode.
    slices = []
    offset = 0
    for len_ in episode_lens:
        # Clamp the stop index: for a zero-length episode, `offset - 1` would
        # be -1 when offset == 0, which numpy interprets as "up to the last
        # element" rather than as an empty range.
        stop = max(offset, offset + len_ - 1)
        slices.append(slice(offset, stop))
        offset += len_

    # Compile the return data by slicing off one timestep at the end of
    # each episode.
    ret = []
    for d in data:
        if slices:
            ret.append(np.concatenate([d[s] for s in slices]))
        else:
            # `np.concatenate` raises on an empty sequence; with no episodes
            # the result is simply an empty array of matching dtype.
            ret.append(np.asarray(d)[:0])

    # Preserve the historic return contract: a bare array for a single data
    # item, a tuple otherwise.
    return ret[0] if len(ret) == 1 else tuple(ret)
- @DeveloperAPI
def remove_last_ts_from_episodes_and_restore_truncateds(
    episodes: List[SingleAgentEpisode],
    orig_truncateds: List[bool],
) -> None:
    """Reverts the effects of `add_one_ts_to_episodes_and_truncate`.

    For every episode, the timestep counter `t` is decremented by one, the
    last entry of the observations, infos, actions, rewards, and of every
    `extra_model_outputs` buffer is popped, and `is_truncated` is set back to
    the matching value from `orig_truncateds`. The episodes are modified in
    place.

    Args:
        episodes: The list of SingleAgentEpisode objects to shorten by one
            timestep and whose truncation flags should be restored.
        orig_truncateds: A list of the original episodes' truncated values to
            be applied to the `episodes`, in the same order. Episodes and
            flags are paired positionally via `zip`, so pairing stops at the
            shorter of the two lists.
    """
    for ep, was_truncated in zip(episodes, orig_truncateds):
        # Step the timestep counter back by one.
        ep.t -= 1

        # Drop the extra timestep's entry from each per-timestep buffer.
        for buffer_name in ("observations", "infos", "actions", "rewards"):
            getattr(ep, buffer_name).pop()
        for outputs in ep.extra_model_outputs.values():
            outputs.pop()

        # Put the original truncation flag back in place.
        ep.is_truncated = was_truncated
|