exponential_schedule.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. from typing import Optional
  2. from ray.rllib.utils.annotations import OldAPIStack, override
  3. from ray.rllib.utils.framework import try_import_torch
  4. from ray.rllib.utils.schedules.schedule import Schedule
  5. from ray.rllib.utils.typing import TensorType
  6. torch, _ = try_import_torch()
  7. @OldAPIStack
  8. class ExponentialSchedule(Schedule):
  9. """Exponential decay schedule from `initial_p` to `final_p`.
  10. Reduces output over `schedule_timesteps`. After this many time steps
  11. always returns `final_p`.
  12. """
  13. def __init__(
  14. self,
  15. schedule_timesteps: int,
  16. framework: Optional[str] = None,
  17. initial_p: float = 1.0,
  18. decay_rate: float = 0.1,
  19. ):
  20. """Initializes a ExponentialSchedule instance.
  21. Args:
  22. schedule_timesteps: Number of time steps for which to
  23. linearly anneal initial_p to final_p.
  24. framework: The framework descriptor string, e.g. "tf",
  25. "torch", or None.
  26. initial_p: Initial output value.
  27. decay_rate: The percentage of the original value after
  28. 100% of the time has been reached (see formula above).
  29. >0.0: The smaller the decay-rate, the stronger the decay.
  30. 1.0: No decay at all.
  31. """
  32. super().__init__(framework=framework)
  33. assert schedule_timesteps > 0
  34. self.schedule_timesteps = schedule_timesteps
  35. self.initial_p = initial_p
  36. self.decay_rate = decay_rate
  37. @override(Schedule)
  38. def _value(self, t: TensorType) -> TensorType:
  39. """Returns the result of: initial_p * decay_rate ** (`t`/t_max)."""
  40. if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
  41. t = t.float()
  42. return self.initial_p * self.decay_rate ** (t / self.schedule_timesteps)