# polynomial_schedule.py
  1. from typing import Optional
  2. from ray.rllib.utils.annotations import OldAPIStack, override
  3. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  4. from ray.rllib.utils.schedules.schedule import Schedule
  5. from ray.rllib.utils.typing import TensorType
  6. tf1, tf, tfv = try_import_tf()
  7. torch, _ = try_import_torch()
  8. @OldAPIStack
  9. class PolynomialSchedule(Schedule):
  10. """Polynomial interpolation between `initial_p` and `final_p`.
  11. Over `schedule_timesteps`. After this many time steps, always returns
  12. `final_p`.
  13. """
  14. def __init__(
  15. self,
  16. schedule_timesteps: int,
  17. final_p: float,
  18. framework: Optional[str],
  19. initial_p: float = 1.0,
  20. power: float = 2.0,
  21. ):
  22. """Initializes a PolynomialSchedule instance.
  23. Args:
  24. schedule_timesteps: Number of time steps for which to
  25. linearly anneal initial_p to final_p
  26. final_p: Final output value.
  27. framework: The framework descriptor string, e.g. "tf",
  28. "torch", or None.
  29. initial_p: Initial output value.
  30. power: The exponent to use (default: quadratic).
  31. """
  32. super().__init__(framework=framework)
  33. assert schedule_timesteps > 0
  34. self.schedule_timesteps = schedule_timesteps
  35. self.final_p = final_p
  36. self.initial_p = initial_p
  37. self.power = power
  38. @override(Schedule)
  39. def _value(self, t: TensorType) -> TensorType:
  40. """Returns the result of:
  41. final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
  42. """
  43. if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
  44. t = t.float()
  45. t = min(t, self.schedule_timesteps)
  46. return (
  47. self.final_p
  48. + (self.initial_p - self.final_p)
  49. * (1.0 - (t / self.schedule_timesteps)) ** self.power
  50. )
  51. @override(Schedule)
  52. def _tf_value_op(self, t: TensorType) -> TensorType:
  53. t = tf.math.minimum(t, self.schedule_timesteps)
  54. return (
  55. self.final_p
  56. + (self.initial_p - self.final_p)
  57. * (1.0 - (t / self.schedule_timesteps)) ** self.power
  58. )