postprocessing.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. from typing import Dict, Optional
  2. import numpy as np
  3. from ray.rllib.policy.policy import Policy
  4. from ray.rllib.policy.sample_batch import SampleBatch
  5. from ray.rllib.utils.annotations import DeveloperAPI, OldAPIStack
  6. from ray.rllib.utils.numpy import convert_to_numpy
  7. from ray.rllib.utils.typing import AgentID, TensorType
  8. @DeveloperAPI
  9. class Postprocessing:
  10. """Constant definitions for postprocessing."""
  11. ADVANTAGES = "advantages"
  12. VALUE_TARGETS = "value_targets"
  13. @OldAPIStack
  14. def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None:
  15. """Rewrites `batch` to encode n-step rewards, terminateds, truncateds, and next-obs.
  16. Observations and actions remain unaffected. At the end of the trajectory,
  17. n is truncated to fit in the traj length.
  18. Args:
  19. n_step: The number of steps to look ahead and adjust.
  20. gamma: The discount factor.
  21. batch: The SampleBatch to adjust (in place).
  22. Examples:
  23. n-step=3
  24. Trajectory=o0 r0 d0, o1 r1 d1, o2 r2 d2, o3 r3 d3, o4 r4 d4=True o5
  25. gamma=0.9
  26. Returned trajectory:
  27. 0: o0 [r0 + 0.9*r1 + 0.9^2*r2 + 0.9^3*r3] d3 o0'=o3
  28. 1: o1 [r1 + 0.9*r2 + 0.9^2*r3 + 0.9^3*r4] d4 o1'=o4
  29. 2: o2 [r2 + 0.9*r3 + 0.9^2*r4] d4 o1'=o5
  30. 3: o3 [r3 + 0.9*r4] d4 o3'=o5
  31. 4: o4 r4 d4 o4'=o5
  32. """
  33. assert (
  34. batch.is_single_trajectory()
  35. ), "Unexpected terminated|truncated in middle of trajectory!"
  36. len_ = len(batch)
  37. # Shift NEXT_OBS, TERMINATEDS, and TRUNCATEDS.
  38. batch[SampleBatch.NEXT_OBS] = np.concatenate(
  39. [
  40. batch[SampleBatch.OBS][n_step:],
  41. np.stack([batch[SampleBatch.NEXT_OBS][-1]] * min(n_step, len_)),
  42. ],
  43. axis=0,
  44. )
  45. batch[SampleBatch.TERMINATEDS] = np.concatenate(
  46. [
  47. batch[SampleBatch.TERMINATEDS][n_step - 1 :],
  48. np.tile(batch[SampleBatch.TERMINATEDS][-1], min(n_step - 1, len_)),
  49. ],
  50. axis=0,
  51. )
  52. # Only fix `truncateds`, if present in the batch.
  53. if SampleBatch.TRUNCATEDS in batch:
  54. batch[SampleBatch.TRUNCATEDS] = np.concatenate(
  55. [
  56. batch[SampleBatch.TRUNCATEDS][n_step - 1 :],
  57. np.tile(batch[SampleBatch.TRUNCATEDS][-1], min(n_step - 1, len_)),
  58. ],
  59. axis=0,
  60. )
  61. # Change rewards in place.
  62. for i in range(len_):
  63. for j in range(1, n_step):
  64. if i + j < len_:
  65. batch[SampleBatch.REWARDS][i] += (
  66. gamma**j * batch[SampleBatch.REWARDS][i + j]
  67. )
  68. @OldAPIStack
  69. def compute_advantages(
  70. rollout: SampleBatch,
  71. last_r: float,
  72. gamma: float = 0.9,
  73. lambda_: float = 1.0,
  74. use_gae: bool = True,
  75. use_critic: bool = True,
  76. rewards: TensorType = None,
  77. vf_preds: TensorType = None,
  78. ):
  79. """Given a rollout, compute its value targets and the advantages.
  80. Args:
  81. rollout: SampleBatch of a single trajectory.
  82. last_r: Value estimation for last observation.
  83. gamma: Discount factor.
  84. lambda_: Parameter for GAE.
  85. use_gae: Using Generalized Advantage Estimation.
  86. use_critic: Whether to use critic (value estimates). Setting
  87. this to False will use 0 as baseline.
  88. rewards: Override the reward values in rollout.
  89. vf_preds: Override the value function predictions in rollout.
  90. Returns:
  91. SampleBatch with experience from rollout and processed rewards.
  92. """
  93. assert (
  94. SampleBatch.VF_PREDS in rollout or not use_critic
  95. ), "use_critic=True but values not found"
  96. assert use_critic or not use_gae, "Can't use gae without using a value function"
  97. last_r = convert_to_numpy(last_r)
  98. if rewards is None:
  99. rewards = rollout[SampleBatch.REWARDS]
  100. if vf_preds is None and use_critic:
  101. vf_preds = rollout[SampleBatch.VF_PREDS]
  102. if use_gae:
  103. vpred_t = np.concatenate([vf_preds, np.array([last_r])])
  104. delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
  105. # This formula for the advantage comes from:
  106. # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438
  107. rollout[Postprocessing.ADVANTAGES] = discount_cumsum(delta_t, gamma * lambda_)
  108. rollout[Postprocessing.VALUE_TARGETS] = (
  109. rollout[Postprocessing.ADVANTAGES] + vf_preds
  110. ).astype(np.float32)
  111. else:
  112. rewards_plus_v = np.concatenate([rewards, np.array([last_r])])
  113. discounted_returns = discount_cumsum(rewards_plus_v, gamma)[:-1].astype(
  114. np.float32
  115. )
  116. if use_critic:
  117. rollout[Postprocessing.ADVANTAGES] = discounted_returns - vf_preds
  118. rollout[Postprocessing.VALUE_TARGETS] = discounted_returns
  119. else:
  120. rollout[Postprocessing.ADVANTAGES] = discounted_returns
  121. rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like(
  122. rollout[Postprocessing.ADVANTAGES]
  123. )
  124. rollout[Postprocessing.ADVANTAGES] = rollout[Postprocessing.ADVANTAGES].astype(
  125. np.float32
  126. )
  127. return rollout
  128. @OldAPIStack
  129. def compute_gae_for_sample_batch(
  130. policy: Policy,
  131. sample_batch: SampleBatch,
  132. other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
  133. episode=None,
  134. ) -> SampleBatch:
  135. """Adds GAE (generalized advantage estimations) to a trajectory.
  136. The trajectory contains only data from one episode and from one agent.
  137. - If `config.batch_mode=truncate_episodes` (default), sample_batch may
  138. contain a truncated (at-the-end) episode, in case the
  139. `config.rollout_fragment_length` was reached by the sampler.
  140. - If `config.batch_mode=complete_episodes`, sample_batch will contain
  141. exactly one episode (no matter how long).
  142. New columns can be added to sample_batch and existing ones may be altered.
  143. Args:
  144. policy: The Policy used to generate the trajectory (`sample_batch`)
  145. sample_batch: The SampleBatch to postprocess.
  146. other_agent_batches: Optional dict of AgentIDs mapping to other
  147. agents' trajectory data (from the same episode).
  148. NOTE: The other agents use the same policy.
  149. episode: Optional multi-agent episode object in which the agents
  150. operated.
  151. Returns:
  152. The postprocessed, modified SampleBatch (or a new one).
  153. """
  154. # Compute the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need for the
  155. # following `last_r` arg in `compute_advantages()`.
  156. sample_batch = compute_bootstrap_value(sample_batch, policy)
  157. vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
  158. rewards = np.array(sample_batch[SampleBatch.REWARDS])
  159. # We need to squeeze out the time dimension if there is one
  160. # Sanity check that both have the same shape
  161. if len(vf_preds.shape) == 2:
  162. assert vf_preds.shape == rewards.shape
  163. vf_preds = np.squeeze(vf_preds, axis=1)
  164. rewards = np.squeeze(rewards, axis=1)
  165. squeezed = True
  166. else:
  167. squeezed = False
  168. # Adds the policy logits, VF preds, and advantages to the batch,
  169. # using GAE ("generalized advantage estimation") or not.
  170. batch = compute_advantages(
  171. rollout=sample_batch,
  172. last_r=sample_batch[SampleBatch.VALUES_BOOTSTRAPPED][-1],
  173. gamma=policy.config["gamma"],
  174. lambda_=policy.config["lambda"],
  175. use_gae=policy.config["use_gae"],
  176. use_critic=policy.config.get("use_critic", True),
  177. vf_preds=vf_preds,
  178. rewards=rewards,
  179. )
  180. if squeezed:
  181. # If we needed to squeeze rewards and vf_preds, we need to unsqueeze
  182. # advantages again for it to have the same shape
  183. batch[Postprocessing.ADVANTAGES] = np.expand_dims(
  184. batch[Postprocessing.ADVANTAGES], axis=1
  185. )
  186. return batch
  187. @OldAPIStack
  188. def compute_bootstrap_value(sample_batch: SampleBatch, policy: Policy) -> SampleBatch:
  189. """Performs a value function computation at the end of a trajectory.
  190. If the trajectory is terminated (not truncated), will not use the value function,
  191. but assume that the value of the last timestep is 0.0.
  192. In all other cases, will use the given policy's value function to compute the
  193. "bootstrapped" value estimate at the end of the given trajectory. To do so, the
  194. very last observation (sample_batch[NEXT_OBS][-1]) and - if applicable -
  195. the very last state output (sample_batch[STATE_OUT][-1]) wil be used as inputs to
  196. the value function.
  197. The thus computed value estimate will be stored in a new column of the
  198. `sample_batch`: SampleBatch.VALUES_BOOTSTRAPPED. Thereby, values at all timesteps
  199. in this column are set to 0.0, except or the last timestep, which receives the
  200. computed bootstrapped value.
  201. This is done, such that in any loss function (which processes raw, intact
  202. trajectories, such as those of IMPALA and APPO) can use this new column as follows:
  203. Example: numbers=ts in episode, '|'=episode boundary (terminal),
  204. X=bootstrapped value (!= 0.0 b/c ts=12 is not a terminal).
  205. ts=5 is NOT a terminal.
  206. T: 8 9 10 11 12 <- no terminal
  207. VF_PREDS: . . . . .
  208. VALUES_BOOTSTRAPPED: 0 0 0 0 X
  209. Args:
  210. sample_batch: The SampleBatch (single trajectory) for which to compute the
  211. bootstrap value at the end. This SampleBatch will be altered in place
  212. (by adding a new column: SampleBatch.VALUES_BOOTSTRAPPED).
  213. policy: The Policy object, whose value function to use.
  214. Returns:
  215. The altered SampleBatch (with the extra SampleBatch.VALUES_BOOTSTRAPPED
  216. column).
  217. """
  218. # Trajectory is actually complete -> last r=0.0.
  219. if sample_batch[SampleBatch.TERMINATEDS][-1]:
  220. last_r = 0.0
  221. # Trajectory has been truncated -> last r=VF estimate of last obs.
  222. else:
  223. # Input dict is provided to us automatically via the Model's
  224. # requirements. It's a single-timestep (last one in trajectory)
  225. # input_dict.
  226. # Create an input dict according to the Policy's requirements.
  227. input_dict = sample_batch.get_single_step_input_dict(
  228. policy.view_requirements, index="last"
  229. )
  230. last_r = policy._value(**input_dict)
  231. vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS])
  232. # We need to squeeze out the time dimension if there is one
  233. if len(vf_preds.shape) == 2:
  234. vf_preds = np.squeeze(vf_preds, axis=1)
  235. squeezed = True
  236. else:
  237. squeezed = False
  238. # Set the SampleBatch.VALUES_BOOTSTRAPPED field to VF_PREDS[1:] + the
  239. # very last timestep (where this bootstrapping value is actually needed), which
  240. # we set to the computed `last_r`.
  241. sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate(
  242. [
  243. convert_to_numpy(vf_preds[1:]),
  244. np.array([convert_to_numpy(last_r)], dtype=np.float32),
  245. ],
  246. axis=0,
  247. )
  248. if squeezed:
  249. sample_batch[SampleBatch.VF_PREDS] = np.expand_dims(vf_preds, axis=1)
  250. sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.expand_dims(
  251. sample_batch[SampleBatch.VALUES_BOOTSTRAPPED], axis=1
  252. )
  253. return sample_batch
  254. @OldAPIStack
  255. def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray:
  256. """Calculates the discounted cumulative sum over a reward sequence `x`.
  257. y[t] - discount*y[t+1] = x[t]
  258. reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t]
  259. Args:
  260. gamma: The discount factor gamma.
  261. Returns:
  262. The sequence containing the discounted cumulative sums
  263. for each individual reward in `x` till the end of the trajectory.
  264. .. testcode::
  265. :skipif: True
  266. x = np.array([0.0, 1.0, 2.0, 3.0])
  267. gamma = 0.9
  268. discount_cumsum(x, gamma)
  269. .. testoutput::
  270. array([0.0 + 0.9*1.0 + 0.9^2*2.0 + 0.9^3*3.0,
  271. 1.0 + 0.9*2.0 + 0.9^2*3.0,
  272. 2.0 + 0.9*3.0,
  273. 3.0])
  274. """
  275. # Import scipy here to avoid import error when framework is tensorflow.
  276. import scipy
  277. return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1]