scheduler.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from typing import Optional
  2. from ray.rllib.utils.framework import try_import_tf, try_import_torch
  3. from ray.rllib.utils.schedules.piecewise_schedule import PiecewiseSchedule
  4. from ray.rllib.utils.typing import LearningRateOrSchedule, TensorType
  5. from ray.util.annotations import DeveloperAPI
  6. _, tf, _ = try_import_tf()
  7. torch, _ = try_import_torch()
  8. @DeveloperAPI
  9. class Scheduler:
  10. """Class to manage a scheduled (framework-dependent) tensor variable.
  11. Uses the PiecewiseSchedule (for maximum configuration flexibility)
  12. """
  13. def __init__(
  14. self,
  15. fixed_value_or_schedule: LearningRateOrSchedule,
  16. *,
  17. framework: str = "torch",
  18. device: Optional[str] = None,
  19. ):
  20. """Initializes a Scheduler instance.
  21. Args:
  22. fixed_value_or_schedule: A fixed, constant value (in case no schedule should
  23. be used) or a schedule configuration in the format of
  24. [[timestep, value], [timestep, value], ...]
  25. Intermediary timesteps will be assigned to linerarly interpolated
  26. values. A schedule config's first entry must
  27. start with timestep 0, i.e.: [[0, initial_value], [...]].
  28. framework: The framework string, for which to create the tensor variable
  29. that hold the current value. This is the variable that can be used in
  30. the graph, e.g. in a loss function.
  31. device: Optional device (for torch) to place the tensor variable on.
  32. """
  33. self.framework = framework
  34. self.device = device
  35. self.use_schedule = isinstance(fixed_value_or_schedule, (list, tuple))
  36. if self.use_schedule:
  37. # Custom schedule, based on list of
  38. # ([ts], [value to be reached by ts])-tuples.
  39. self._schedule = PiecewiseSchedule(
  40. fixed_value_or_schedule,
  41. outside_value=fixed_value_or_schedule[-1][-1],
  42. framework=None,
  43. )
  44. # As initial tensor valie, use the first timestep's (must be 0) value.
  45. self._curr_value = self._create_tensor_variable(
  46. initial_value=fixed_value_or_schedule[0][1]
  47. )
  48. # If no schedule, pin (fix) given value.
  49. else:
  50. self._curr_value = fixed_value_or_schedule
  51. @staticmethod
  52. def validate(
  53. *,
  54. fixed_value_or_schedule: LearningRateOrSchedule,
  55. setting_name: str,
  56. description: str,
  57. ) -> None:
  58. """Performs checking of a certain schedule configuration.
  59. The first entry in `value_or_schedule` (if it's not a fixed value) must have a
  60. timestep of 0.
  61. Args:
  62. fixed_value_or_schedule: A fixed, constant value (in case no schedule should
  63. be used) or a schedule configuration in the format of
  64. [[timestep, value], [timestep, value], ...]
  65. Intermediary timesteps will be assigned to linerarly interpolated
  66. values. A schedule config's first entry must
  67. start with timestep 0, i.e.: [[0, initial_value], [...]].
  68. setting_name: The property name of the schedule setting (within a config),
  69. e.g. `lr` or `entropy_coeff`.
  70. description: A full text description of the property that's being scheduled,
  71. e.g. `learning rate`.
  72. Raises:
  73. ValueError: In case, errors are found in the schedule's format.
  74. """
  75. # Fixed (single) value.
  76. if (
  77. isinstance(fixed_value_or_schedule, (int, float))
  78. or fixed_value_or_schedule is None
  79. ):
  80. return
  81. if not isinstance(fixed_value_or_schedule, (list, tuple)) or (
  82. len(fixed_value_or_schedule) < 2
  83. ):
  84. raise ValueError(
  85. f"Invalid `{setting_name}` ({fixed_value_or_schedule}) specified! "
  86. f"Must be a list of 2 or more tuples, each of the form "
  87. f"(`timestep`, `{description} to reach`), for example "
  88. "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
  89. )
  90. elif fixed_value_or_schedule[0][0] != 0:
  91. raise ValueError(
  92. f"When providing a `{setting_name}` schedule, the first timestep must "
  93. f"be 0 and the corresponding lr value is the initial {description}! "
  94. f"You provided ts={fixed_value_or_schedule[0][0]} {description}="
  95. f"{fixed_value_or_schedule[0][1]}."
  96. )
  97. elif any(len(pair) != 2 for pair in fixed_value_or_schedule):
  98. raise ValueError(
  99. f"When providing a `{setting_name}` schedule, each tuple in the "
  100. f"schedule list must have exctly 2 items of the form "
  101. f"(`timestep`, `{description} to reach`), for example "
  102. "`[(0, 0.001), (1e6, 0.0001), (2e6, 0.00005)]`."
  103. )
  104. def get_current_value(self) -> TensorType:
  105. """Returns the current value (as a tensor variable).
  106. This method should be used in loss functions of other (in-graph) places
  107. where the current value is needed.
  108. Returns:
  109. The tensor variable (holding the current value to be used).
  110. """
  111. return self._curr_value
  112. def update(self, timestep: int) -> float:
  113. """Updates the underlying (framework specific) tensor variable.
  114. In case of a fixed value, this method does nothing and only returns the fixed
  115. value as-is.
  116. Args:
  117. timestep: The current timestep that the update might depend on.
  118. Returns:
  119. The current value of the tensor variable as a python float.
  120. """
  121. if self.use_schedule:
  122. python_value = self._schedule.value(t=timestep)
  123. if self.framework == "torch":
  124. self._curr_value.data = torch.tensor(python_value)
  125. else:
  126. self._curr_value.assign(python_value)
  127. else:
  128. python_value = self._curr_value
  129. return python_value
  130. def _create_tensor_variable(self, initial_value: float) -> TensorType:
  131. """Creates a framework-specific tensor variable to be scheduled.
  132. Args:
  133. initial_value: The initial (float) value for the variable to hold.
  134. Returns:
  135. The created framework-specific tensor variable.
  136. """
  137. if self.framework == "torch":
  138. return torch.tensor(
  139. initial_value,
  140. requires_grad=False,
  141. dtype=torch.float32,
  142. device=self.device,
  143. )
  144. else:
  145. return tf.Variable(
  146. initial_value,
  147. trainable=False,
  148. dtype=tf.float32,
  149. )