value_predictions.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. import numpy as np
  2. from ray.util.annotations import DeveloperAPI
  3. @DeveloperAPI
  def compute_value_targets(
      values,
      rewards,
      terminateds,
      truncateds,
      gamma: float,
      lambda_: float,
  ) -> np.ndarray:
      """Computes value function (vf) targets given vf predictions and rewards.

      The targets are computed backwards in time via the recursion:

          target[t] = rewards[t]
                      + gamma * (1 - lambda_) * v[t + 1]
                      + continues[t] * gamma * lambda_ * target[t + 1]

      where `v` are the vf predictions with all terminal timesteps forced to
      0.0, `v` one step past the end of the batch is 0.0, and
      `continues = 1.0 - terminateds`. Right after a truncated timestep `t` has
      been processed, the running target is reset to `v[t]`, such that timestep
      `t - 1` bootstraps from the vf prediction at the truncation.

      Note that advantages can then easily be computed via the formula:
      advantages = targets - vf_predictions

      Args:
          values: 1D array-like with the vf predictions, one per timestep.
          rewards: 1D array-like with the rewards, one per timestep.
          terminateds: 1D array-like with 1.0/True at terminal timesteps and
              0.0/False everywhere else.
          truncateds: 1D array-like whose truthy entries mark truncated
              timesteps.
          gamma: The discount factor.
          lambda_: The lambda parameter mixing one-step bootstrapping
              (`lambda_=0.0`) with the recursively computed target
              (`lambda_=1.0`).

      Returns:
          A 1D float32 array of value targets with one entry per timestep. Empty
          if the inputs are empty.
      """
      # Accept plain Python sequences as well as arrays (no-op for arrays).
      values = np.asarray(values)
      rewards = np.asarray(rewards)
      terminateds = np.asarray(terminateds)
      truncateds = np.asarray(truncateds)

      continues = 1.0 - terminateds
      # Force-set all values at terminals (not at truncations!) to 0.0.
      masked_values = values * continues
      # Pad with 0.0, such that `padded_values[1:]` holds the "next" value for
      # each timestep (0.0 past the end of the batch).
      padded_values = np.append(masked_values, 0.0)
      intermediates = rewards + gamma * (1 - lambda_) * padded_values[1:]

      num_timesteps = intermediates.shape[0]
      # Fill the targets in place, walking backwards through time. Preallocating
      # (instead of stacking a list) also makes empty inputs work.
      value_targets = np.zeros(num_timesteps, dtype=np.float32)
      last = padded_values[-1]
      for t in reversed(range(num_timesteps)):
          last = intermediates[t] + continues[t] * gamma * lambda_ * last
          value_targets[t] = last
          # Do not leak targets across a truncation boundary: The preceding
          # timestep bootstraps from the vf prediction at `t` instead.
          if truncateds[t]:
              last = masked_values[t]

      return value_targets
  def extract_bootstrapped_values(vf_preds, episode_lengths, T):
      """Returns a bootstrapped value batch given value predictions.

      Note that the incoming value predictions must have happened over
      (artificially) elongated episodes (by 1 timestep at the end). This way, we
      can either extract the `vf_preds` at these extra timesteps (as "bootstrap
      values") or skip over them entirely if they lie in the middle of the
      T-slices.

      For example, given an episodes structure like this:

          01234a 0123456b 01c 012- 0123e 012-

      where each episode is separated by a space and goes from 0 to n and ends in
      an artificially elongated timestep (denoted by 'a', 'b', 'c', '-', or 'e'),
      where '-' means that the episode was terminated and the bootstrap value at
      the end should be zero and 'a', 'b', 'c', etc.. represent truncated episode
      ends with computed vf estimates.

      The output for the above sequence (and T=4) should then be:

          4 3 b 2 3 -

      Args:
          vf_preds: The computed value function predictions over the artificially
              elongated episodes (by one timestep at the end). Only read via
              integer indexing; never modified.
          episode_lengths: The original (correct) episode lengths, NOT counting
              the artificially added timestep at the end. Any iterable of ints
              (list, tuple, array); never modified.
          T: The size of the time dimension by which to slice the data. Must be
              positive. Note that the sum of all episode lengths
              (`sum(episode_lengths)`) must be divisible by T.

      Returns:
          The batch of bootstrapped values (one per T-slice) as a numpy array.

      Raises:
          ValueError: If `T` is not positive or if the sum of the episode lengths
              is not divisible by `T`.
      """
      # A zero `T` would fail in the modulo below and a negative one would never
      # make progress through the episodes.
      if T <= 0:
          raise ValueError(f"`T` must be a positive integer, but got {T}!")

      # Materialize once: Supports lists, tuples, arrays, and generators alike and
      # guarantees that the caller's object is never altered.
      lengths = list(episode_lengths)
      total_length = sum(lengths)
      if total_length % T != 0:
          raise ValueError(
              "Can only extract bootstrapped values if the sum of episode lengths "
              f"({total_length}) is dividable by the given T ({T})!"
          )

      bootstrapped_values = []
      # `vf_preds[offset + T]` is always the bootstrap value of the T-slice that
      # is currently being filled. Tracking an offset (rather than re-slicing
      # `vf_preds`) avoids copying list inputs over and over.
      offset = 0
      # Number of real timesteps that already completed (shorter) episodes have
      # contributed to the current T-slice.
      carried = 0

      for eps_len in lengths:
          remaining = carried + eps_len

          # We can make another T-stride inside this episode ->
          # - Use a vf prediction within the episode as bootstrapped value.
          # - Continue within the same episode.
          while remaining > T:
              bootstrapped_values.append(vf_preds[offset + T])
              offset += T
              remaining -= T

          # The T-stride ends exactly at the end of this episode ->
          # - Use the value function prediction at the artificially added
          #   timestep as bootstrapped value.
          # - Skip that additional timestep and move on with the next episode.
          if remaining == T:
              bootstrapped_values.append(vf_preds[offset + T])
              offset += T + 1
              carried = 0
          # The episode ends inside the current T-stride ->
          # - Its artificial timestep is not needed: Shift the window by one
          #   entry such that `offset + T` still points at the slice's end.
          # - Carry the already covered timesteps over to the next episode.
          else:
              offset += 1
              carried = remaining

      return np.array(bootstrapped_values)