async_hyperband.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. import logging
  2. import pickle
  3. from typing import TYPE_CHECKING, Dict, Optional, Union
  4. import numpy as np
  5. from ray.tune.experiment import Trial
  6. from ray.tune.result import DEFAULT_METRIC
  7. from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
  8. from ray.util import PublicAPI
  9. if TYPE_CHECKING:
  10. from ray.tune.execution.tune_controller import TuneController
  11. logger = logging.getLogger(__name__)
  12. @PublicAPI
  13. class AsyncHyperBandScheduler(FIFOScheduler):
  14. """Implements the Async Successive Halving.
  15. This should provide similar theoretical performance as HyperBand but
  16. avoid straggler issues that HyperBand faces. One implementation detail
  17. is when using multiple brackets, trial allocation to bracket is done
  18. randomly with over a softmax probability.
  19. See https://arxiv.org/abs/1810.05934
  20. Args:
  21. time_attr: A training result attr to use for comparing time.
  22. Note that you can pass in something non-temporal such as
  23. `training_iteration` as a measure of progress, the only requirement
  24. is that the attribute should increase monotonically.
  25. metric: The training result objective value attribute. Stopping
  26. procedures will use this attribute. If None but a mode was passed,
  27. the `ray.tune.result.DEFAULT_METRIC` will be used per default.
  28. mode: One of {min, max}. Determines whether objective is
  29. minimizing or maximizing the metric attribute.
  30. max_t: max time units per trial. Trials will be stopped after
  31. max_t time units (determined by time_attr) have passed.
  32. grace_period: Only stop trials at least this old in time.
  33. The units are the same as the attribute named by `time_attr`.
  34. reduction_factor: Used to set halving rate and amount. This
  35. is simply a unit-less scalar.
  36. brackets: Number of brackets. Each bracket has a different
  37. halving rate, specified by the reduction factor.
  38. stop_last_trials: Whether to terminate the trials after
  39. reaching max_t. Defaults to True.
  40. """
  41. def __init__(
  42. self,
  43. time_attr: str = "training_iteration",
  44. metric: Optional[str] = None,
  45. mode: Optional[str] = None,
  46. max_t: int = 100,
  47. grace_period: int = 1,
  48. reduction_factor: float = 4,
  49. brackets: int = 1,
  50. stop_last_trials: bool = True,
  51. ):
  52. assert max_t > 0, "Max (time_attr) not valid!"
  53. assert max_t >= grace_period, "grace_period must be <= max_t!"
  54. assert grace_period > 0, "grace_period must be positive!"
  55. assert reduction_factor > 1, "Reduction Factor not valid!"
  56. assert brackets > 0, "brackets must be positive!"
  57. if mode:
  58. assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
  59. super().__init__()
  60. self._reduction_factor = reduction_factor
  61. self._max_t = max_t
  62. self._trial_info = {} # Stores Trial -> Bracket
  63. # Tracks state for new trial add
  64. self._brackets = [
  65. _Bracket(
  66. grace_period,
  67. max_t,
  68. reduction_factor,
  69. s,
  70. stop_last_trials=stop_last_trials,
  71. )
  72. for s in range(brackets)
  73. ]
  74. self._counter = 0 # for
  75. self._num_stopped = 0
  76. self._metric = metric
  77. self._mode = mode
  78. self._metric_op = None
  79. if self._mode == "max":
  80. self._metric_op = 1.0
  81. elif self._mode == "min":
  82. self._metric_op = -1.0
  83. self._time_attr = time_attr
  84. self._stop_last_trials = stop_last_trials
  85. def set_search_properties(
  86. self, metric: Optional[str], mode: Optional[str], **spec
  87. ) -> bool:
  88. if self._metric and metric:
  89. return False
  90. if self._mode and mode:
  91. return False
  92. if metric:
  93. self._metric = metric
  94. if mode:
  95. self._mode = mode
  96. if self._mode == "max":
  97. self._metric_op = 1.0
  98. elif self._mode == "min":
  99. self._metric_op = -1.0
  100. if self._metric is None and self._mode:
  101. # If only a mode was passed, use anonymous metric
  102. self._metric = DEFAULT_METRIC
  103. return True
  104. def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
  105. if not self._metric or not self._metric_op:
  106. raise ValueError(
  107. "{} has been instantiated without a valid `metric` ({}) or "
  108. "`mode` ({}) parameter. Either pass these parameters when "
  109. "instantiating the scheduler, or pass them as parameters "
  110. "to `tune.TuneConfig()`".format(
  111. self.__class__.__name__, self._metric, self._mode
  112. )
  113. )
  114. sizes = np.array([len(b._rungs) for b in self._brackets])
  115. probs = np.e ** (sizes - sizes.max())
  116. normalized = probs / probs.sum()
  117. idx = np.random.choice(len(self._brackets), p=normalized)
  118. self._trial_info[trial.trial_id] = self._brackets[idx]
  119. def on_trial_result(
  120. self, tune_controller: "TuneController", trial: Trial, result: Dict
  121. ) -> str:
  122. action = TrialScheduler.CONTINUE
  123. if self._time_attr not in result or self._metric not in result:
  124. return action
  125. if result[self._time_attr] >= self._max_t and self._stop_last_trials:
  126. action = TrialScheduler.STOP
  127. else:
  128. bracket = self._trial_info[trial.trial_id]
  129. action = bracket.on_result(
  130. trial, result[self._time_attr], self._metric_op * result[self._metric]
  131. )
  132. if action == TrialScheduler.STOP:
  133. self._num_stopped += 1
  134. return action
  135. def on_trial_complete(
  136. self, tune_controller: "TuneController", trial: Trial, result: Dict
  137. ):
  138. if self._time_attr not in result or self._metric not in result:
  139. return
  140. bracket = self._trial_info[trial.trial_id]
  141. bracket.on_result(
  142. trial, result[self._time_attr], self._metric_op * result[self._metric]
  143. )
  144. del self._trial_info[trial.trial_id]
  145. def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
  146. del self._trial_info[trial.trial_id]
  147. def debug_string(self) -> str:
  148. out = "Using AsyncHyperBand: num_stopped={}".format(self._num_stopped)
  149. out += "\n" + "\n".join([b.debug_str() for b in self._brackets])
  150. return out
  151. def save(self, checkpoint_path: str):
  152. save_object = self.__dict__
  153. with open(checkpoint_path, "wb") as outputFile:
  154. pickle.dump(save_object, outputFile)
  155. def restore(self, checkpoint_path: str):
  156. with open(checkpoint_path, "rb") as inputFile:
  157. save_object = pickle.load(inputFile)
  158. self.__dict__.update(save_object)
  159. class _Bracket:
  160. """Bookkeeping system to track the cutoffs.
  161. Rungs are created in reversed order so that we can more easily find
  162. the correct rung corresponding to the current iteration of the result.
  163. Example:
  164. >>> trial1, trial2, trial3 = ... # doctest: +SKIP
  165. >>> b = _Bracket(1, 10, 2, 0) # doctest: +SKIP
  166. >>> # CONTINUE
  167. >>> b.on_result(trial1, 1, 2) # doctest: +SKIP
  168. >>> # CONTINUE
  169. >>> b.on_result(trial2, 1, 4) # doctest: +SKIP
  170. >>> # rungs are reversed
  171. >>> b.cutoff(b._rungs[-1][1]) == 3.0 # doctest: +SKIP
  172. # STOP
  173. >>> b.on_result(trial3, 1, 1) # doctest: +SKIP
  174. >>> b.cutoff(b._rungs[3][1]) == 2.0 # doctest: +SKIP
  175. """
  176. def __init__(
  177. self,
  178. min_t: int,
  179. max_t: int,
  180. reduction_factor: float,
  181. s: int,
  182. stop_last_trials: bool = True,
  183. ):
  184. self.rf = reduction_factor
  185. MAX_RUNGS = int(np.log(max_t / min_t) / np.log(self.rf) - s + 1)
  186. self._rungs = [
  187. (min_t * self.rf ** (k + s), {}) for k in reversed(range(MAX_RUNGS))
  188. ]
  189. self._stop_last_trials = stop_last_trials
  190. def cutoff(self, recorded) -> Optional[Union[int, float, complex, np.ndarray]]:
  191. if not recorded:
  192. return None
  193. return np.nanpercentile(list(recorded.values()), (1 - 1 / self.rf) * 100)
  194. def on_result(self, trial: Trial, cur_iter: int, cur_rew: Optional[float]) -> str:
  195. action = TrialScheduler.CONTINUE
  196. for milestone, recorded in self._rungs:
  197. if (
  198. cur_iter >= milestone
  199. and trial.trial_id in recorded
  200. and not self._stop_last_trials
  201. ):
  202. # If our result has been recorded for this trial already, the
  203. # decision to continue training has already been made. Thus we can
  204. # skip new cutoff calculation and just continue training.
  205. # We can also break as milestones are descending.
  206. break
  207. if cur_iter < milestone or trial.trial_id in recorded:
  208. continue
  209. else:
  210. cutoff = self.cutoff(recorded)
  211. if cutoff is not None and cur_rew < cutoff:
  212. action = TrialScheduler.STOP
  213. if cur_rew is None:
  214. logger.warning(
  215. "Reward attribute is None! Consider"
  216. " reporting using a different field."
  217. )
  218. else:
  219. recorded[trial.trial_id] = cur_rew
  220. break
  221. return action
  222. def debug_str(self) -> str:
  223. # TODO: fix up the output for this
  224. iters = " | ".join(
  225. [
  226. "Iter {:.3f}: {}".format(milestone, self.cutoff(recorded))
  227. for milestone, recorded in self._rungs
  228. ]
  229. )
  230. return "Bracket: " + iters
  231. ASHAScheduler = AsyncHyperBandScheduler
  232. if __name__ == "__main__":
  233. sched = AsyncHyperBandScheduler(grace_period=1, max_t=10, reduction_factor=2)
  234. print(sched.debug_string())
  235. bracket = sched._brackets[0]
  236. print(bracket.cutoff({str(i): i for i in range(20)}))