# piecewise_schedule.py

  1. from typing import Callable, List, Optional, Tuple
  2. from ray.rllib.utils.annotations import override
  3. from ray.rllib.utils.framework import try_import_tf
  4. from ray.rllib.utils.schedules.schedule import Schedule
  5. from ray.rllib.utils.typing import TensorType
  6. from ray.util.annotations import DeveloperAPI
  7. tf1, tf, tfv = try_import_tf()
  8. def _linear_interpolation(left, right, alpha):
  9. return left + alpha * (right - left)
  10. @DeveloperAPI
  11. class PiecewiseSchedule(Schedule):
  12. """Implements a Piecewise Scheduler."""
  13. def __init__(
  14. self,
  15. endpoints: List[Tuple[int, float]],
  16. framework: Optional[str] = None,
  17. interpolation: Callable[
  18. [TensorType, TensorType, TensorType], TensorType
  19. ] = _linear_interpolation,
  20. outside_value: Optional[float] = None,
  21. ):
  22. """Initializes a PiecewiseSchedule instance.
  23. Args:
  24. endpoints: A list of tuples
  25. `(t, value)` such that the output
  26. is an interpolation (given by the `interpolation` callable)
  27. between two values.
  28. E.g.
  29. t=400 and endpoints=[(0, 20.0),(500, 30.0)]
  30. output=20.0 + 0.8 * (30.0 - 20.0) = 28.0
  31. NOTE: All the values for time must be sorted in an increasing
  32. order.
  33. framework: The framework descriptor string, e.g. "tf",
  34. "torch", or None.
  35. interpolation: A function that takes the left-value,
  36. the right-value and an alpha interpolation parameter
  37. (0.0=only left value, 1.0=only right value), which is the
  38. fraction of distance from left endpoint to right endpoint.
  39. outside_value: If t in call to `value` is
  40. outside of all the intervals in `endpoints` this value is
  41. returned. If None then an AssertionError is raised when outside
  42. value is requested.
  43. """
  44. super().__init__(framework=framework)
  45. idxes = [e[0] for e in endpoints]
  46. assert idxes == sorted(idxes)
  47. self.interpolation = interpolation
  48. self.outside_value = outside_value
  49. self.endpoints = [(int(e[0]), float(e[1])) for e in endpoints]
  50. @override(Schedule)
  51. def _value(self, t: TensorType) -> TensorType:
  52. # Find t in our list of endpoints.
  53. for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
  54. # When found, return an interpolation (default: linear).
  55. if l_t <= t < r_t:
  56. alpha = float(t - l_t) / (r_t - l_t)
  57. return self.interpolation(l, r, alpha)
  58. # t does not belong to any of the pieces, return `self.outside_value`.
  59. assert self.outside_value is not None
  60. return self.outside_value
  61. @override(Schedule)
  62. def _tf_value_op(self, t: TensorType) -> TensorType:
  63. assert self.outside_value is not None, (
  64. "tf-version of PiecewiseSchedule requires `outside_value` to be "
  65. "provided!"
  66. )
  67. endpoints = tf.cast(tf.stack([e[0] for e in self.endpoints] + [-1]), tf.int64)
  68. # Create all possible interpolation results.
  69. results_list = []
  70. for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
  71. alpha = tf.cast(t - l_t, tf.float32) / tf.cast(r_t - l_t, tf.float32)
  72. results_list.append(self.interpolation(l, r, alpha))
  73. # If t does not belong to any of the pieces, return `outside_value`.
  74. results_list.append(self.outside_value)
  75. results_list = tf.stack(results_list)
  76. # Return correct results tensor depending on where we find t.
  77. def _cond(i, x):
  78. x = tf.cast(x, tf.int64)
  79. return tf.logical_not(
  80. tf.logical_or(
  81. tf.equal(endpoints[i + 1], -1),
  82. tf.logical_and(endpoints[i] <= x, x < endpoints[i + 1]),
  83. )
  84. )
  85. def _body(i, x):
  86. return (i + 1, t)
  87. idx_and_t = tf.while_loop(_cond, _body, [tf.constant(0, dtype=tf.int64), t])
  88. return results_list[idx_and_t[0]]