constant_schedule.py 1002 B

1234567891011121314151617181920212223242526272829303132
  1. from typing import Optional
  2. from ray.rllib.utils.annotations import OldAPIStack, override
  3. from ray.rllib.utils.framework import try_import_tf
  4. from ray.rllib.utils.schedules.schedule import Schedule
  5. from ray.rllib.utils.typing import TensorType
  6. tf1, tf, tfv = try_import_tf()
  @OldAPIStack
  class ConstantSchedule(Schedule):
      """Schedule that yields one fixed value, whatever the time step is."""

      def __init__(self, value: float, framework: Optional[str] = None):
          """Creates a ConstantSchedule.

          Args:
              value: The fixed value this schedule hands back for every
                  time step.
              framework: The framework descriptor string, e.g. "tf",
                  "torch", or None. Forwarded unchanged to the parent
                  ``Schedule`` constructor.
          """
          super().__init__(framework=framework)
          # The single value served by both the plain and the tf code paths.
          self._v = value

      @override(Schedule)
      def _value(self, t: TensorType) -> TensorType:
          """Returns the stored constant; the time step `t` is not used."""
          constant_value = self._v
          return constant_value

      @override(Schedule)
      def _tf_value_op(self, t: TensorType) -> TensorType:
          """Returns the stored constant wrapped via `tf.constant`.

          The time step `t` is not used.
          """
          constant_value = self._v
          return tf.constant(constant_value)