| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176 |
- import copy
- import logging
- from typing import Dict, List, Optional
- from ray.tune.search.searcher import Searcher
- from ray.tune.search.util import _set_search_properties_backwards_compatible
- from ray.util.annotations import PublicAPI
- logger = logging.getLogger(__name__)  # Module-level logger; ConcurrencyLimiter.suggest logs to it at debug level.
@PublicAPI
class ConcurrencyLimiter(Searcher):
    """A wrapper algorithm for limiting the number of concurrent trials.

    Certain Searchers have their own internal logic for limiting
    the number of concurrent trials. If such a Searcher is passed to a
    ``ConcurrencyLimiter``, the ``max_concurrent`` of the
    ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
    of the Searcher. The ``ConcurrencyLimiter`` will then let the
    Searcher's internal logic take over.

    Args:
        searcher: Searcher object that the
            ConcurrencyLimiter will manage.
        max_concurrent: Maximum concurrent samples from the underlying
            searcher.
        batch: Whether to wait for all concurrent samples
            to finish before updating the underlying searcher.

    Example:

    .. code-block:: python

        from ray import tune
        from ray.tune.search import ConcurrencyLimiter
        from ray.tune.search.hyperopt import HyperOptSearch

        search_alg = HyperOptSearch(metric="accuracy")
        search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
        tuner = tune.Tuner(
            trainable,
            tune_config=tune.TuneConfig(
                search_alg=search_alg
            ),
        )
        tuner.fit()
    """

    def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
        """Wrap ``searcher`` and cap its number of live suggestions.

        Raises:
            AssertionError: If ``max_concurrent`` is not a positive ``int``.
            RuntimeError: If ``searcher`` is not a ``Searcher`` instance.
        """
        assert type(max_concurrent) is int and max_concurrent > 0
        # Validate the wrapped object before any state is set up.
        if not isinstance(searcher, Searcher):
            raise RuntimeError(
                f"The `ConcurrencyLimiter` only works with `Searcher` "
                f"objects (got {type(searcher)}). Please try to pass "
                f"`max_concurrent` to the search generator directly."
            )

        self.searcher = searcher
        self.max_concurrent = max_concurrent
        self.batch = batch
        # IDs of trials that received a suggestion and were not yet released.
        self.live_trials = set()
        # Counter of live trials whose completion has not been reported yet.
        self.num_unfinished_live_trials = 0
        # Batch mode only: trial_id -> (result, error), held until the flush.
        self.cached_results = {}
        # Default to limiting here; may be switched off just below.
        self._limit_concurrency = True

        self._set_searcher_max_concurrency()

        super().__init__(metric=self.searcher.metric, mode=self.searcher.mode)

    def _set_searcher_max_concurrency(self):
        """Offer ``max_concurrent`` to the wrapped searcher.

        If the searcher has special logic for handling max concurrency
        (its ``set_max_concurrency`` returns a truthy value), the
        ConcurrencyLimiter does no limiting of its own.
        """
        self._limit_concurrency = not self.searcher.set_max_concurrency(
            self.max_concurrent
        )

    def set_max_concurrency(self, max_concurrent: int) -> bool:
        """Replace the concurrency cap stored on this limiter.

        Returns:
            Always ``True``.
        """
        # Determine if this behavior is acceptable, or if it should
        # raise an exception.
        self.max_concurrent = max_concurrent
        return True

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Forward search properties to the wrapped searcher.

        The concurrency cap is offered to the wrapped searcher again before
        the properties are passed on.

        Returns:
            Whatever ``_set_search_properties_backwards_compatible`` returns
            for the wrapped searcher's ``set_search_properties``.
        """
        self._set_searcher_max_concurrency()
        return _set_search_properties_backwards_compatible(
            self.searcher.set_search_properties, metric, mode, config, **spec
        )

    def suggest(self, trial_id: str) -> Optional[Dict]:
        """Ask the wrapped searcher for a suggestion, honoring the cap.

        Returns:
            The wrapped searcher's suggestion, or ``None`` when the number of
            live trials has already reached ``max_concurrent``.

        Raises:
            AssertionError: If ``trial_id`` is already a live trial.
        """
        if not self._limit_concurrency:
            return self.searcher.suggest(trial_id)

        assert (
            trial_id not in self.live_trials
        ), f"Trial ID {trial_id} must be unique: already found in set."

        if len(self.live_trials) >= self.max_concurrent:
            # Pass every value as a lazy logging argument so that a "%" in
            # the trial ID cannot corrupt the format string.
            logger.debug(
                "Not providing a suggestion for %s due to "
                "concurrency limit: %s/%s.",
                trial_id,
                len(self.live_trials),
                self.max_concurrent,
            )
            return None

        suggestion = self.searcher.suggest(trial_id)
        # Only a real suggestion occupies a concurrency slot.
        if suggestion not in (None, Searcher.FINISHED):
            self.live_trials.add(trial_id)
            self.num_unfinished_live_trials += 1
        return suggestion

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ):
        """Report a finished trial and release its concurrency slot.

        Completions for trial IDs that are not live are ignored. In batch
        mode the completion is cached, and the wrapped searcher is only
        notified once no unfinished live trials remain.
        """
        if not self._limit_concurrency:
            return self.searcher.on_trial_complete(trial_id, result=result, error=error)

        if trial_id not in self.live_trials:
            return
        elif self.batch:
            self.cached_results[trial_id] = (result, error)
            self.num_unfinished_live_trials -= 1
            if self.num_unfinished_live_trials <= 0:
                # Update the underlying searcher once the
                # full batch is completed.
                for cached_id, cached_outcome in self.cached_results.items():
                    cached_result, cached_error = cached_outcome
                    self.searcher.on_trial_complete(
                        cached_id, result=cached_result, error=cached_error
                    )
                    self.live_trials.remove(cached_id)
                self.cached_results = {}
                self.num_unfinished_live_trials = 0
            else:
                return
        else:
            self.searcher.on_trial_complete(trial_id, result=result, error=error)
            self.live_trials.remove(trial_id)
            self.num_unfinished_live_trials -= 1

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        """Forward an intermediate result to the wrapped searcher."""
        self.searcher.on_trial_result(trial_id, result)

    def add_evaluated_point(
        self,
        parameters: Dict,
        value: float,
        error: bool = False,
        pruned: bool = False,
        intermediate_values: Optional[List[float]] = None,
    ):
        """Forward an already evaluated point to the wrapped searcher."""
        return self.searcher.add_evaluated_point(
            parameters, value, error, pruned, intermediate_values
        )

    def get_state(self) -> Dict:
        """Return a deep copy of this limiter's attributes.

        The wrapped searcher is left out of the returned state.
        """
        state = self.__dict__.copy()
        del state["searcher"]
        return copy.deepcopy(state)

    def set_state(self, state: Dict):
        """Update this limiter's attributes from ``state``."""
        self.__dict__.update(state)

    def save(self, checkpoint_path: str):
        """Delegate saving to the wrapped searcher."""
        self.searcher.save(checkpoint_path)

    def restore(self, checkpoint_path: str):
        """Delegate restoring to the wrapped searcher."""
        self.searcher.restore(checkpoint_path)

    # BOHB Specific.
    # TODO(team-ml): Refactor alongside HyperBandForBOHB
    def on_pause(self, trial_id: str):
        """Forward a pause notification to the wrapped searcher."""
        self.searcher.on_pause(trial_id)

    def on_unpause(self, trial_id: str):
        """Forward an unpause notification to the wrapped searcher."""
        self.searcher.on_unpause(trial_id)
|