hb_bohb.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. import logging
  2. from typing import TYPE_CHECKING, Dict, Optional
  3. from ray.tune.experiment import Trial
  4. from ray.tune.schedulers.hyperband import HyperBandScheduler
  5. from ray.tune.schedulers.trial_scheduler import TrialScheduler
  6. from ray.util import PublicAPI
  7. if TYPE_CHECKING:
  8. from ray.tune.execution.tune_controller import TuneController
  9. logger = logging.getLogger(__name__)
  10. @PublicAPI
  11. class HyperBandForBOHB(HyperBandScheduler):
  12. """Extends HyperBand early stopping algorithm for BOHB.
  13. This implementation removes the ``HyperBandScheduler`` pipelining. This
  14. class introduces key changes:
  15. 1. Trials are now placed so that the bracket with the largest size is
  16. filled first.
  17. 2. Trials will be paused even if the bracket is not filled. This allows
  18. BOHB to insert new trials into the training.
  19. See ray.tune.schedulers.HyperBandScheduler for parameter docstring.
  20. """
  21. def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
  22. """Adds new trial.
  23. On a new trial add, if current bracket is not filled, add to current
  24. bracket. Else, if current band is not filled, create new bracket, add
  25. to current bracket. Else, create new iteration, create new bracket,
  26. add to bracket.
  27. """
  28. if not self._metric or not self._metric_op:
  29. raise ValueError(
  30. "{} has been instantiated without a valid `metric` ({}) or "
  31. "`mode` ({}) parameter. Either pass these parameters when "
  32. "instantiating the scheduler, or pass them as parameters "
  33. "to `tune.TuneConfig()`".format(
  34. self.__class__.__name__, self._metric, self._mode
  35. )
  36. )
  37. cur_bracket = self._state["bracket"]
  38. cur_band = self._hyperbands[self._state["band_idx"]]
  39. if cur_bracket is None or cur_bracket.filled():
  40. retry = True
  41. while retry:
  42. # if current iteration is filled, create new iteration
  43. if self._cur_band_filled():
  44. cur_band = []
  45. self._hyperbands.append(cur_band)
  46. self._state["band_idx"] += 1
  47. # MAIN CHANGE HERE - largest bracket first!
  48. # cur_band will always be less than s_max_1 or else filled
  49. s = self._s_max_1 - len(cur_band) - 1
  50. assert s >= 0, "Current band is filled!"
  51. if self._get_r0(s) == 0:
  52. logger.debug("BOHB: Bracket too small - Retrying...")
  53. cur_bracket = None
  54. else:
  55. retry = False
  56. cur_bracket = self._create_bracket(s)
  57. cur_band.append(cur_bracket)
  58. self._state["bracket"] = cur_bracket
  59. self._state["bracket"].add_trial(trial)
  60. self._trial_info[trial] = cur_bracket, self._state["band_idx"]
  61. def on_trial_result(
  62. self, tune_controller: "TuneController", trial: Trial, result: Dict
  63. ) -> str:
  64. """If bracket is finished, all trials will be stopped.
  65. If a given trial finishes and bracket iteration is not done,
  66. the trial will be paused and resources will be given up.
  67. This scheduler will not start trials but will stop trials.
  68. The current running trial will not be handled,
  69. as the trialrunner will be given control to handle it."""
  70. result["hyperband_info"] = {}
  71. bracket, _ = self._trial_info[trial]
  72. bracket.update_trial_stats(trial, result)
  73. if bracket.continue_trial(trial):
  74. return TrialScheduler.CONTINUE
  75. result["hyperband_info"]["budget"] = bracket._cumul_r
  76. # MAIN CHANGE HERE!
  77. statuses = [(t, t.status) for t in bracket._live_trials]
  78. if not bracket.filled() or any(
  79. status != Trial.PAUSED for t, status in statuses if t is not trial
  80. ):
  81. # BOHB Specific. This hack existed in old Ray versions
  82. # and was removed, but it needs to be brought back
  83. # as otherwise the BOHB doesn't behave as intended.
  84. # The default concurrency limiter works by discarding
  85. # new suggestions if there are more running trials
  86. # than the limit. That doesn't take into account paused
  87. # trials. With BOHB, this leads to N trials finishing
  88. # completely and then another N trials starting,
  89. # instead of trials being paused and resumed in brackets
  90. # as intended.
  91. # There should be a better API for this.
  92. # TODO(team-ml): Refactor alongside HyperBandForBOHB
  93. tune_controller.search_alg.searcher.on_pause(trial.trial_id)
  94. return TrialScheduler.PAUSE
  95. logger.debug(f"Processing bracket after trial {trial} result")
  96. action = self._process_bracket(tune_controller, bracket)
  97. if action == TrialScheduler.PAUSE:
  98. tune_controller.search_alg.searcher.on_pause(trial.trial_id)
  99. return action
  100. def _unpause_trial(self, tune_controller: "TuneController", trial: Trial):
  101. # Hack. See comment in on_trial_result
  102. tune_controller.search_alg.searcher.on_unpause(trial.trial_id)
  103. def choose_trial_to_run(
  104. self, tune_controller: "TuneController", allow_recurse: bool = True
  105. ) -> Optional[Trial]:
  106. """Fair scheduling within iteration by completion percentage.
  107. List of trials not used since all trials are tracked as state
  108. of scheduler. If iteration is occupied (ie, no trials to run),
  109. then look into next iteration.
  110. """
  111. for hyperband in self._hyperbands:
  112. # band will have None entries if no resources
  113. # are to be allocated to that bracket.
  114. scrubbed = [b for b in hyperband if b is not None]
  115. for bracket in scrubbed:
  116. for trial in bracket.current_trials():
  117. if (
  118. trial.status == Trial.PAUSED
  119. and trial in bracket.trials_to_unpause
  120. ) or trial.status == Trial.PENDING:
  121. return trial
  122. # MAIN CHANGE HERE!
  123. if not any(t.status == Trial.RUNNING for t in tune_controller.get_trials()):
  124. for hyperband in self._hyperbands:
  125. for bracket in hyperband:
  126. if bracket and any(
  127. trial.status == Trial.PAUSED
  128. for trial in bracket.current_trials()
  129. ):
  130. # This will change the trial state
  131. logger.debug("Processing bracket since no trial is running.")
  132. self._process_bracket(tune_controller, bracket)
  133. # If there are pending trials now, suggest one.
  134. # This is because there might be both PENDING and
  135. # PAUSED trials now, and PAUSED trials will raise
  136. # an error before the trial runner tries again.
  137. if allow_recurse and any(
  138. (
  139. trial.status == Trial.PAUSED
  140. and trial in bracket.trials_to_unpause
  141. )
  142. or trial.status == Trial.PENDING
  143. for trial in bracket.current_trials()
  144. ):
  145. return self.choose_trial_to_run(
  146. tune_controller, allow_recurse=False
  147. )
  148. # MAIN CHANGE HERE!
  149. return None