# schedule.py (2.2 KB)
  1. from abc import ABCMeta, abstractmethod
  2. from typing import Any, Union
  3. from ray.rllib.utils.annotations import OldAPIStack
  4. from ray.rllib.utils.framework import try_import_tf
  5. from ray.rllib.utils.typing import TensorType
  6. tf1, tf, tfv = try_import_tf()
  7. @OldAPIStack
  8. class Schedule(metaclass=ABCMeta):
  9. """Schedule classes implement various time-dependent scheduling schemas.
  10. - Constant behavior.
  11. - Linear decay.
  12. - Piecewise decay.
  13. - Exponential decay.
  14. Useful for backend-agnostic rate/weight changes for learning rates,
  15. exploration epsilons, beta parameters for prioritized replay, loss weights
  16. decay, etc..
  17. Each schedule can be called directly with the `t` (absolute time step)
  18. value and returns the value dependent on the Schedule and the passed time.
  19. """
  20. def __init__(self, framework):
  21. self.framework = framework
  22. def value(self, t: Union[int, TensorType]) -> Any:
  23. """Generates the value given a timestep (based on schedule's logic).
  24. Args:
  25. t: The time step. This could be a tf.Tensor.
  26. Returns:
  27. The calculated value depending on the schedule and `t`.
  28. """
  29. if self.framework in ["tf2", "tf"]:
  30. return self._tf_value_op(t)
  31. return self._value(t)
  32. def __call__(self, t: Union[int, TensorType]) -> Any:
  33. """Simply calls self.value(t). Implemented to make Schedules callable."""
  34. return self.value(t)
  35. @abstractmethod
  36. def _value(self, t: Union[int, TensorType]) -> Any:
  37. """
  38. Returns the value based on a time step input.
  39. Args:
  40. t: The time step. This could be a tf.Tensor.
  41. Returns:
  42. The calculated value depending on the schedule and `t`.
  43. """
  44. raise NotImplementedError
  45. def _tf_value_op(self, t: TensorType) -> TensorType:
  46. """
  47. Returns the tf-op that calculates the value based on a time step input.
  48. Args:
  49. t: The time step op (int tf.Tensor).
  50. Returns:
  51. The calculated value depending on the schedule and `t`.
  52. """
  53. # By default (most of the time), tf should work with python code.
  54. # Override only if necessary.
  55. return self._value(t)