# concurrency_limiter.py (6.2 KB)

import copy
import logging
from typing import Dict, List, Optional, Set, Tuple

from ray.tune.search.searcher import Searcher
from ray.tune.search.util import _set_search_properties_backwards_compatible
from ray.util.annotations import PublicAPI

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  8. @PublicAPI
  9. class ConcurrencyLimiter(Searcher):
  10. """A wrapper algorithm for limiting the number of concurrent trials.
  11. Certain Searchers have their own internal logic for limiting
  12. the number of concurrent trials. If such a Searcher is passed to a
  13. ``ConcurrencyLimiter``, the ``max_concurrent`` of the
  14. ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
  15. of the Searcher. The ``ConcurrencyLimiter`` will then let the
  16. Searcher's internal logic take over.
  17. Args:
  18. searcher: Searcher object that the
  19. ConcurrencyLimiter will manage.
  20. max_concurrent: Maximum concurrent samples from the underlying
  21. searcher.
  22. batch: Whether to wait for all concurrent samples
  23. to finish before updating the underlying searcher.
  24. Example:
  25. .. code-block:: python
  26. from ray.tune.search import ConcurrencyLimiter
  27. search_alg = HyperOptSearch(metric="accuracy")
  28. search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
  29. tuner = tune.Tuner(
  30. trainable,
  31. tune_config=tune.TuneConfig(
  32. search_alg=search_alg
  33. ),
  34. )
  35. tuner.fit()
  36. """
  37. def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
  38. assert type(max_concurrent) is int and max_concurrent > 0
  39. self.searcher = searcher
  40. self.max_concurrent = max_concurrent
  41. self.batch = batch
  42. self.live_trials = set()
  43. self.num_unfinished_live_trials = 0
  44. self.cached_results = {}
  45. self._limit_concurrency = True
  46. if not isinstance(searcher, Searcher):
  47. raise RuntimeError(
  48. f"The `ConcurrencyLimiter` only works with `Searcher` "
  49. f"objects (got {type(searcher)}). Please try to pass "
  50. f"`max_concurrent` to the search generator directly."
  51. )
  52. self._set_searcher_max_concurrency()
  53. super(ConcurrencyLimiter, self).__init__(
  54. metric=self.searcher.metric, mode=self.searcher.mode
  55. )
  56. def _set_searcher_max_concurrency(self):
  57. # If the searcher has special logic for handling max concurrency,
  58. # we do not do anything inside the ConcurrencyLimiter
  59. self._limit_concurrency = not self.searcher.set_max_concurrency(
  60. self.max_concurrent
  61. )
  62. def set_max_concurrency(self, max_concurrent: int) -> bool:
  63. # Determine if this behavior is acceptable, or if it should
  64. # raise an exception.
  65. self.max_concurrent = max_concurrent
  66. return True
  67. def set_search_properties(
  68. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  69. ) -> bool:
  70. self._set_searcher_max_concurrency()
  71. return _set_search_properties_backwards_compatible(
  72. self.searcher.set_search_properties, metric, mode, config, **spec
  73. )
  74. def suggest(self, trial_id: str) -> Optional[Dict]:
  75. if not self._limit_concurrency:
  76. return self.searcher.suggest(trial_id)
  77. assert (
  78. trial_id not in self.live_trials
  79. ), f"Trial ID {trial_id} must be unique: already found in set."
  80. if len(self.live_trials) >= self.max_concurrent:
  81. logger.debug(
  82. f"Not providing a suggestion for {trial_id} due to "
  83. "concurrency limit: %s/%s.",
  84. len(self.live_trials),
  85. self.max_concurrent,
  86. )
  87. return
  88. suggestion = self.searcher.suggest(trial_id)
  89. if suggestion not in (None, Searcher.FINISHED):
  90. self.live_trials.add(trial_id)
  91. self.num_unfinished_live_trials += 1
  92. return suggestion
  93. def on_trial_complete(
  94. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  95. ):
  96. if not self._limit_concurrency:
  97. return self.searcher.on_trial_complete(trial_id, result=result, error=error)
  98. if trial_id not in self.live_trials:
  99. return
  100. elif self.batch:
  101. self.cached_results[trial_id] = (result, error)
  102. self.num_unfinished_live_trials -= 1
  103. if self.num_unfinished_live_trials <= 0:
  104. # Update the underlying searcher once the
  105. # full batch is completed.
  106. for trial_id, (result, error) in self.cached_results.items():
  107. self.searcher.on_trial_complete(
  108. trial_id, result=result, error=error
  109. )
  110. self.live_trials.remove(trial_id)
  111. self.cached_results = {}
  112. self.num_unfinished_live_trials = 0
  113. else:
  114. return
  115. else:
  116. self.searcher.on_trial_complete(trial_id, result=result, error=error)
  117. self.live_trials.remove(trial_id)
  118. self.num_unfinished_live_trials -= 1
  119. def on_trial_result(self, trial_id: str, result: Dict) -> None:
  120. self.searcher.on_trial_result(trial_id, result)
  121. def add_evaluated_point(
  122. self,
  123. parameters: Dict,
  124. value: float,
  125. error: bool = False,
  126. pruned: bool = False,
  127. intermediate_values: Optional[List[float]] = None,
  128. ):
  129. return self.searcher.add_evaluated_point(
  130. parameters, value, error, pruned, intermediate_values
  131. )
  132. def get_state(self) -> Dict:
  133. state = self.__dict__.copy()
  134. del state["searcher"]
  135. return copy.deepcopy(state)
  136. def set_state(self, state: Dict):
  137. self.__dict__.update(state)
  138. def save(self, checkpoint_path: str):
  139. self.searcher.save(checkpoint_path)
  140. def restore(self, checkpoint_path: str):
  141. self.searcher.restore(checkpoint_path)
  142. # BOHB Specific.
  143. # TODO(team-ml): Refactor alongside HyperBandForBOHB
  144. def on_pause(self, trial_id: str):
  145. self.searcher.on_pause(trial_id)
  146. def on_unpause(self, trial_id: str):
  147. self.searcher.on_unpause(trial_id)