hyperband.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. import collections
  2. import logging
  3. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
  4. import numpy as np
  5. from ray.tune.error import TuneError
  6. from ray.tune.experiment import Trial
  7. from ray.tune.result import DEFAULT_METRIC
  8. from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
  9. from ray.util.annotations import PublicAPI
  10. if TYPE_CHECKING:
  11. from ray.tune.execution.tune_controller import TuneController
  12. logger = logging.getLogger(__name__)
  13. # Implementation notes:
  14. # This implementation contains 3 logical levels.
  15. # Each HyperBand iteration is a "band". There can be multiple
  16. # bands running at once, and there can be 1 band that is incomplete.
  17. #
  18. # In each band, there are at most `s` + 1 brackets.
  19. # `s` is a value determined by given parameters, and assigned on
  20. # a cyclic basis.
  21. #
  22. # In each bracket, there are at most `n(s)` trials, indicating that
  23. # `n` is a function of `s`. These trials go through a series of
  24. # halving procedures, dropping lowest performers. Multiple
  25. # brackets are running at once.
  26. #
  27. # Trials added will be inserted into the most recent bracket
  28. # and band and will spill over to new brackets/bands accordingly.
  29. #
  30. # This maintains the bracket size and max trial count per band
  31. # to 5 and 117 respectively, which correspond to that of
  32. # `max_attr=81, eta=3` from the blog post. Trials will fill up
  33. # from smallest bracket to largest, with largest
  34. # having the most rounds of successive halving.
  35. @PublicAPI
  36. class HyperBandScheduler(FIFOScheduler):
  37. """Implements the HyperBand early stopping algorithm.
  38. HyperBandScheduler early stops trials using the HyperBand optimization
  39. algorithm. It divides trials into brackets of varying sizes, and
  40. periodically early stops low-performing trials within each bracket.
  41. To use this implementation of HyperBand with Tune, all you need
  42. to do is specify the max length of time a trial can run `max_t`, the time
  43. units `time_attr`, the name of the reported objective value `metric`,
  44. and if `metric` is to be maximized or minimized (`mode`).
  45. We automatically determine reasonable values for the other
  46. HyperBand parameters based on the given values.
  47. For example, to limit trials to 10 minutes and early stop based on the
  48. `episode_mean_reward` attr, construct:
  49. ``HyperBand('time_total_s', 'episode_reward_mean', max_t=600)``
  50. Note that Tune's stopping criteria will be applied in conjunction with
  51. HyperBand's early stopping mechanisms.
  52. See also: https://blog.ml.cmu.edu/2018/12/12/massively-parallel-hyperparameter-optimization/
  53. Args:
  54. time_attr: The training result attr to use for comparing time.
  55. Note that you can pass in something non-temporal such as
  56. `training_iteration` as a measure of progress, the only requirement
  57. is that the attribute should increase monotonically.
  58. metric: The training result objective value attribute. Stopping
  59. procedures will use this attribute. If None but a mode was passed,
  60. the `ray.tune.result.DEFAULT_METRIC` will be used per default.
  61. mode: One of {min, max}. Determines whether objective is
  62. minimizing or maximizing the metric attribute.
  63. max_t: max time units per trial. Trials will be stopped after
  64. max_t time units (determined by time_attr) have passed.
  65. The scheduler will terminate trials after this time has passed.
  66. Note that this is different from the semantics of `max_t` as
  67. mentioned in the original HyperBand paper.
  68. reduction_factor: Same as `eta`. Determines how sharp
  69. the difference is between bracket space-time allocation ratios.
  70. stop_last_trials: Whether to terminate the trials after
  71. reaching max_t. Defaults to True.
  72. """ # noqa: E501
  73. _supports_buffered_results = False
  74. def __init__(
  75. self,
  76. time_attr: str = "training_iteration",
  77. metric: Optional[str] = None,
  78. mode: Optional[str] = None,
  79. max_t: int = 81,
  80. reduction_factor: float = 3,
  81. stop_last_trials: bool = True,
  82. ):
  83. assert max_t > 0, "Max (time_attr) not valid!"
  84. if mode:
  85. assert mode in ["min", "max"], "`mode` must be 'min' or 'max'!"
  86. super().__init__()
  87. self._eta = reduction_factor
  88. self._s_max_1 = int(np.round(np.log(max_t) / np.log(reduction_factor))) + 1
  89. self._max_t_attr = max_t
  90. # bracket max trials
  91. self._get_n0 = lambda s: int(np.ceil(self._s_max_1 / (s + 1) * self._eta**s))
  92. # bracket initial iterations
  93. self._get_r0 = lambda s: int((max_t * self._eta ** (-s)))
  94. self._hyperbands = [[]] # list of hyperband iterations
  95. self._trial_info = {} # Stores Trial -> Bracket, Band Iteration
  96. # Tracks state for new trial add
  97. self._state = {"bracket": None, "band_idx": 0}
  98. self._num_stopped = 0
  99. self._metric = metric
  100. self._mode = mode
  101. self._metric_op = None
  102. if self._mode == "max":
  103. self._metric_op = 1.0
  104. elif self._mode == "min":
  105. self._metric_op = -1.0
  106. self._time_attr = time_attr
  107. self._stop_last_trials = stop_last_trials
  108. def set_search_properties(
  109. self, metric: Optional[str], mode: Optional[str], **spec
  110. ) -> bool:
  111. if self._metric and metric:
  112. return False
  113. if self._mode and mode:
  114. return False
  115. if metric:
  116. self._metric = metric
  117. if mode:
  118. self._mode = mode
  119. if self._mode == "max":
  120. self._metric_op = 1.0
  121. elif self._mode == "min":
  122. self._metric_op = -1.0
  123. if self._metric is None and self._mode:
  124. # If only a mode was passed, use anonymous metric
  125. self._metric = DEFAULT_METRIC
  126. return True
  127. def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
  128. """Adds new trial.
  129. On a new trial add, if current bracket is not filled,
  130. add to current bracket. Else, if current band is not filled,
  131. create new bracket, add to current bracket.
  132. Else, create new iteration, create new bracket, add to bracket."""
  133. if not self._metric or not self._metric_op:
  134. raise ValueError(
  135. "{} has been instantiated without a valid `metric` ({}) or "
  136. "`mode` ({}) parameter. Either pass these parameters when "
  137. "instantiating the scheduler, or pass them as parameters "
  138. "to `tune.TuneConfig()`".format(
  139. self.__class__.__name__, self._metric, self._mode
  140. )
  141. )
  142. cur_bracket = self._state["bracket"]
  143. cur_band = self._hyperbands[self._state["band_idx"]]
  144. if cur_bracket is None or cur_bracket.filled():
  145. retry = True
  146. while retry:
  147. # if current iteration is filled, create new iteration
  148. if self._cur_band_filled():
  149. cur_band = []
  150. self._hyperbands.append(cur_band)
  151. self._state["band_idx"] += 1
  152. # cur_band will always be less than s_max_1 or else filled
  153. s = len(cur_band)
  154. assert s < self._s_max_1, "Current band is filled!"
  155. if self._get_r0(s) == 0:
  156. logger.info("Bracket too small - Retrying...")
  157. cur_bracket = None
  158. else:
  159. retry = False
  160. cur_bracket = self._create_bracket(s)
  161. cur_band.append(cur_bracket)
  162. self._state["bracket"] = cur_bracket
  163. self._state["bracket"].add_trial(trial)
  164. self._trial_info[trial] = cur_bracket, self._state["band_idx"]
  165. def _create_bracket(self, s):
  166. return _Bracket(
  167. time_attr=self._time_attr,
  168. max_trials=self._get_n0(s),
  169. init_t_attr=self._get_r0(s),
  170. max_t_attr=self._max_t_attr,
  171. eta=self._eta,
  172. s=s,
  173. stop_last_trials=self._stop_last_trials,
  174. )
  175. def _cur_band_filled(self) -> bool:
  176. """Checks if the current band is filled.
  177. The size of the current band should be equal to s_max_1"""
  178. cur_band = self._hyperbands[self._state["band_idx"]]
  179. return len(cur_band) == self._s_max_1
  180. def on_trial_result(
  181. self, tune_controller: "TuneController", trial: Trial, result: Dict
  182. ):
  183. """If bracket is finished, all trials will be stopped.
  184. If a given trial finishes and bracket iteration is not done,
  185. the trial will be paused and resources will be given up.
  186. This scheduler will not start trials but will stop trials.
  187. The current running trial will not be handled,
  188. as the trialrunner will be given control to handle it."""
  189. bracket, _ = self._trial_info[trial]
  190. bracket.update_trial_stats(trial, result)
  191. if bracket.continue_trial(trial):
  192. return TrialScheduler.CONTINUE
  193. logger.debug(f"Processing bracket after trial {trial} result")
  194. action = self._process_bracket(tune_controller, bracket)
  195. logger.debug(
  196. f"{action} for {trial} on "
  197. f"{self._time_attr}={result.get(self._time_attr)}"
  198. )
  199. return action
  200. def _process_bracket(
  201. self, tune_controller: "TuneController", bracket: "_Bracket"
  202. ) -> str:
  203. """This is called whenever a trial makes progress.
  204. When all live trials in the bracket have no more iterations left,
  205. Trials will be successively halved. If bracket is done, all
  206. non-running trials will be stopped and cleaned up,
  207. and during each halving phase, bad trials will be stopped while good
  208. trials will return to "PENDING".
  209. Note some implicit conditions here: In ``on_trial_result`` a trial is
  210. either continued (e.g. if it didn't reach the time threshold for the bracket)
  211. or this method (``_process_bracket``) is called. If there are other trials left
  212. that still haven't reached the threshold, the trial is PAUSED. This means
  213. that when the bracket is actually processed (``bracket.cur_iter_done``), there
  214. is at most one RUNNING trial (which is the trial that is currently processed)
  215. and the rest are either PAUSED (as explained above) or TERMINATED/ERRORED
  216. (if they finish separately).
  217. """
  218. action = TrialScheduler.PAUSE
  219. if bracket.cur_iter_done():
  220. if bracket.finished():
  221. bracket.cleanup_full(tune_controller)
  222. return TrialScheduler.STOP
  223. bracket.is_being_processed = True
  224. good, bad = bracket.successive_halving(self._metric, self._metric_op)
  225. logger.debug(
  226. f"Processing {len(good)} good and {len(bad)} bad trials in "
  227. f"bracket {bracket}.\n"
  228. f"Good: {good}\nBad: {bad}"
  229. )
  230. # kill bad trials
  231. self._num_stopped += len(bad)
  232. for t in bad:
  233. if t.status == Trial.PAUSED or t.is_saving:
  234. logger.debug(f"Stopping other trial {str(t)}")
  235. tune_controller.stop_trial(t)
  236. elif t.status == Trial.RUNNING:
  237. # See the docstring: There can only be at most one RUNNING
  238. # trial, which is the current trial.
  239. logger.debug(f"Stopping current trial {str(t)}")
  240. bracket.cleanup_trial(t)
  241. action = TrialScheduler.STOP
  242. else:
  243. # Trials cannot be ERROR/TERMINATED, as then they would have
  244. # been removed from the bracket (in `bracket.cleanup_trial`).
  245. # Trials cannot be PENDING, as then they wouldn't have reported
  246. # enough results to finish the bracket, and it wouldn't be
  247. # processed.
  248. raise TuneError(
  249. f"Trial with unexpected bad status encountered: "
  250. f"{str(t)} is {t.status}"
  251. )
  252. # ready the good trials - if trial is too far ahead, don't continue
  253. for t in good:
  254. if bracket.continue_trial(t):
  255. # The scheduler should have cleaned up this trial already.
  256. assert t.status not in (Trial.ERROR, Trial.TERMINATED), (
  257. f"Good trial {t.trial_id} is in an invalid state: {t.status}\n"
  258. "Expected trial to be either PAUSED, PENDING, or RUNNING.\n"
  259. "If you encounter this, please file an issue on the Ray Github."
  260. )
  261. if t.status == Trial.PAUSED or t.is_saving:
  262. logger.debug(f"Unpausing trial {str(t)}")
  263. self._unpause_trial(tune_controller, t)
  264. bracket.trials_to_unpause.add(t)
  265. elif t.status == Trial.RUNNING:
  266. # See the docstring: There can only be at most one RUNNING
  267. # trial, which is the current trial.
  268. logger.debug(f"Continuing current trial {str(t)}")
  269. action = TrialScheduler.CONTINUE
  270. # else: PENDING trial (from a previous unpause) should stay as is.
  271. elif bracket.finished() and bracket.stop_last_trials:
  272. # Scheduler decides to not continue trial because the bracket
  273. # reached max_t. In this case, stop the trials
  274. if t.status == Trial.PAUSED or t.is_saving:
  275. logger.debug(f"Bracket finished. Stopping other trial {str(t)}")
  276. tune_controller.stop_trial(t)
  277. elif t.status == Trial.RUNNING:
  278. # See the docstring: There can only be at most one RUNNING
  279. # trial, which is the current trial.
  280. logger.debug(
  281. f"Bracket finished. Stopping current trial {str(t)}"
  282. )
  283. bracket.cleanup_trial(t)
  284. action = TrialScheduler.STOP
  285. return action
  286. def _unpause_trial(self, tune_controller: "TuneController", trial: Trial):
  287. """No-op by default."""
  288. return
  289. def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
  290. """Notification when trial terminates.
  291. Trial info is removed from bracket. Triggers halving if bracket is
  292. not finished."""
  293. bracket, _ = self._trial_info[trial]
  294. bracket.cleanup_trial(trial)
  295. if not bracket.finished() and not bracket.is_being_processed:
  296. logger.debug(f"Processing bracket after trial {trial} removed")
  297. self._process_bracket(tune_controller, bracket)
  298. def on_trial_complete(
  299. self, tune_controller: "TuneController", trial: Trial, result: Dict
  300. ):
  301. """Cleans up trial info from bracket if trial completed early."""
  302. self.on_trial_remove(tune_controller, trial)
  303. def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
  304. """Cleans up trial info from bracket if trial errored early."""
  305. self.on_trial_remove(tune_controller, trial)
  306. def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
  307. """Fair scheduling within iteration by completion percentage.
  308. List of trials not used since all trials are tracked as state
  309. of scheduler. If iteration is occupied (ie, no trials to run),
  310. then look into next iteration.
  311. """
  312. for hyperband in self._hyperbands:
  313. # band will have None entries if no resources
  314. # are to be allocated to that bracket.
  315. scrubbed = [b for b in hyperband if b is not None]
  316. for bracket in sorted(scrubbed, key=lambda b: b.completion_percentage()):
  317. for trial in bracket.current_trials():
  318. if (
  319. trial.status == Trial.PAUSED
  320. and trial in bracket.trials_to_unpause
  321. ) or trial.status == Trial.PENDING:
  322. return trial
  323. return None
  324. def debug_string(self) -> str:
  325. """This provides a progress notification for the algorithm.
  326. For each bracket, the algorithm will output a string as follows:
  327. Bracket(Max Size (n)=5, Milestone (r)=33, completed=14.6%):
  328. {PENDING: 2, RUNNING: 3, TERMINATED: 2}
  329. "Max Size" indicates the max number of pending/running experiments
  330. set according to the Hyperband algorithm.
  331. "Milestone" indicates the iterations a trial will run for before
  332. the next halving will occur.
  333. "Completed" indicates an approximate progress metric. Some brackets,
  334. like ones that are unfilled, will not reach 100%.
  335. """
  336. out = "Using HyperBand: "
  337. out += "num_stopped={} total_brackets={}".format(
  338. self._num_stopped, sum(len(band) for band in self._hyperbands)
  339. )
  340. for i, band in enumerate(self._hyperbands):
  341. out += "\nRound #{}:".format(i)
  342. for bracket in band:
  343. if bracket:
  344. out += "\n {}".format(bracket)
  345. return out
  346. def state(self) -> Dict[str, int]:
  347. return {
  348. "num_brackets": sum(len(band) for band in self._hyperbands),
  349. "num_stopped": self._num_stopped,
  350. }
  351. class _Bracket:
  352. """Logical object for tracking Hyperband bracket progress. Keeps track
  353. of proper parameters as designated by HyperBand.
  354. Also keeps track of progress to ensure good scheduling.
  355. """
  356. def __init__(
  357. self,
  358. time_attr: str,
  359. max_trials: int,
  360. init_t_attr: int,
  361. max_t_attr: int,
  362. eta: float,
  363. s: int,
  364. stop_last_trials: bool = True,
  365. ):
  366. self._live_trials = {} # maps trial -> current result
  367. self._all_trials = []
  368. self._time_attr = time_attr # attribute to
  369. self._n = self._n0 = max_trials
  370. self._r = self._r0 = init_t_attr
  371. self._max_t_attr = max_t_attr
  372. self._cumul_r = self._r0
  373. self._eta = eta
  374. self._halves = s
  375. self._total_work = self._calculate_total_work(self._n0, self._r0, s)
  376. self._completed_progress = 0
  377. self.stop_last_trials = stop_last_trials
  378. self.is_being_processed = False
  379. self.trials_to_unpause = set()
  380. def add_trial(self, trial: Trial):
  381. """Add trial to bracket assuming bracket is not filled.
  382. At a later iteration, a newly added trial will be given equal
  383. opportunity to catch up."""
  384. assert not self.filled(), "Cannot add trial to filled bracket!"
  385. self._live_trials[trial] = None
  386. self._all_trials.append(trial)
  387. def cur_iter_done(self) -> bool:
  388. """Checks if all iterations have completed.
  389. TODO(rliaw): also check that `t.iterations == self._r`"""
  390. return all(
  391. self._get_result_time(result) >= self._cumul_r
  392. for result in self._live_trials.values()
  393. )
  394. def finished(self) -> bool:
  395. if not self.stop_last_trials:
  396. return False
  397. return self._halves == 0 and self.cur_iter_done()
  398. def current_trials(self) -> List[Trial]:
  399. return list(self._live_trials)
  400. def continue_trial(self, trial: Trial) -> bool:
  401. result = self._live_trials[trial]
  402. if not self.stop_last_trials and self._halves == 0:
  403. return True
  404. elif self._get_result_time(result) < self._cumul_r:
  405. logger.debug(
  406. f"Continuing trial {trial} as it hasn't reached the time threshold "
  407. f"{self._cumul_r}, yet."
  408. )
  409. return True
  410. return False
  411. def filled(self) -> bool:
  412. """Checks if bracket is filled.
  413. Only let new trials be added at current level minimizing the need
  414. to backtrack and bookkeep previous medians."""
  415. return len(self._live_trials) == self._n
  416. def successive_halving(
  417. self, metric: str, metric_op: float
  418. ) -> Tuple[List[Trial], List[Trial]]:
  419. if self._halves == 0 and not self.stop_last_trials:
  420. return self._live_trials, []
  421. assert self._halves > 0
  422. # "Halving" is a misnomer. We're actually reducing by factor `eta`.
  423. self._halves -= 1
  424. # If we had 8 trials in the bracket and eta=2, we will keep 4.
  425. # If we had 9 trials in the bracket and eta=3, we will keep 3.
  426. self._n = int(np.ceil(self._n / self._eta))
  427. # Likewise, we increase the number of iterations until we process the bracket
  428. # again.
  429. # Remember r0 = max_t * self._eta ** (-s)
  430. # Let max_t=16, eta=2, s=1. Then r0=8, and we calculate r1=16.
  431. # Let max_t=16, eta=2, s=2. Then r0=4, and we calculate r1=8, r2=16.
  432. # Let max_t=81, eta=3, s=1. Then r0=27, and we calculate r1=81.
  433. # Let max_t=81, eta=3, s=2. Then r0=9, and we calculate r1=27, r2=81.
  434. self._r *= self._eta
  435. self._r = int(min(self._r, self._max_t_attr))
  436. self._cumul_r = self._r
  437. sorted_trials = sorted(
  438. self._live_trials, key=lambda t: metric_op * self._live_trials[t][metric]
  439. )
  440. good, bad = sorted_trials[-self._n :], sorted_trials[: -self._n]
  441. return good, bad
  442. def update_trial_stats(self, trial: Trial, result: Dict):
  443. """Update result for trial. Called after trial has finished
  444. an iteration - will decrement iteration count.
  445. TODO(rliaw): The other alternative is to keep the trials
  446. in and make sure they're not set as pending later."""
  447. assert trial in self._live_trials
  448. assert self._get_result_time(result) >= 0
  449. observed_time = self._get_result_time(result)
  450. last_observed = self._get_result_time(self._live_trials[trial])
  451. delta = observed_time - last_observed
  452. if delta <= 0:
  453. logger.info(
  454. "Restoring from a previous point in time. "
  455. "Previous={}; Now={}".format(last_observed, observed_time)
  456. )
  457. self._completed_progress += delta
  458. self._live_trials[trial] = result
  459. self.trials_to_unpause.discard(trial)
  460. def cleanup_trial(self, trial: Trial):
  461. """Clean up statistics tracking for terminated trials (either by force
  462. or otherwise).
  463. This may cause bad trials to continue for a long time, in the case
  464. where all the good trials finish early and there are only bad trials
  465. left in a bracket with a large max-iteration."""
  466. self._live_trials.pop(trial, None)
  467. def cleanup_full(self, tune_controller: "TuneController"):
  468. """Cleans up bracket after bracket is completely finished.
  469. Lets the last trial continue to run until termination condition
  470. kicks in."""
  471. for trial in self.current_trials():
  472. if trial.status == Trial.PAUSED:
  473. tune_controller.stop_trial(trial)
  474. def completion_percentage(self) -> float:
  475. """Returns a progress metric.
  476. This will not be always finish with 100 since dead trials
  477. are dropped."""
  478. if self.finished():
  479. return 1.0
  480. return min(self._completed_progress / self._total_work, 1.0)
  481. def _get_result_time(self, result: Dict) -> float:
  482. if result is None:
  483. return 0
  484. return result[self._time_attr]
  485. def _calculate_total_work(self, n: int, r: float, s: int):
  486. work = 0
  487. cumulative_r = r
  488. for _ in range(s + 1):
  489. work += int(n) * int(r)
  490. n /= self._eta
  491. n = int(np.ceil(n))
  492. r *= self._eta
  493. r = int(min(r, self._max_t_attr - cumulative_r))
  494. return work
  495. def __repr__(self) -> str:
  496. status = ", ".join(
  497. [
  498. "Max Size (n)={}".format(self._n),
  499. "Milestone (r)={}".format(self._cumul_r),
  500. "completed={:.1%}".format(self.completion_percentage()),
  501. ]
  502. )
  503. counts = collections.Counter([t.status for t in self._all_trials])
  504. trial_statuses = ", ".join(
  505. sorted("{}: {}".format(k, v) for k, v in counts.items())
  506. )
  507. return "Bracket({}): {{{}}} ".format(status, trial_statuses)