| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199 |
- import copy
- import logging
- from typing import Dict, List, Optional
- import numpy as np
- from ray.tune.search.searcher import Searcher
- from ray.tune.search.util import _set_search_properties_backwards_compatible
- from ray.util import PublicAPI
- logger = logging.getLogger(__name__)
- TRIAL_INDEX = "__trial_index__"
- """str: A constant value representing the repeat index of the trial."""
- def _warn_num_samples(searcher: Searcher, num_samples: int):
- if isinstance(searcher, Repeater) and num_samples % searcher.repeat:
- logger.warning(
- "`num_samples` is now expected to be the total number of trials, "
- "including the repeat trials. For example, set num_samples=15 if "
- "you intend to obtain 3 search algorithm suggestions and repeat "
- "each suggestion 5 times. Any leftover trials "
- "(num_samples mod repeat) will be ignored."
- )
- class _TrialGroup:
- """Internal class for grouping trials of same parameters.
- This is used when repeating trials for reducing training variance.
- Args:
- primary_trial_id: Trial ID of the "primary trial".
- This trial is the one that the Searcher is aware of.
- config: Suggested configuration shared across all trials
- in the trial group.
- max_trials: Max number of trials to execute within this group.
- """
- def __init__(self, primary_trial_id: str, config: Dict, max_trials: int = 1):
- assert type(config) is dict, "config is not a dict, got {}".format(config)
- self.primary_trial_id = primary_trial_id
- self.config = config
- self._trials = {primary_trial_id: None}
- self.max_trials = max_trials
- def add(self, trial_id: str):
- assert len(self._trials) < self.max_trials
- self._trials.setdefault(trial_id, None)
- def full(self) -> bool:
- return len(self._trials) == self.max_trials
- def report(self, trial_id: str, score: float):
- assert trial_id in self._trials
- if score is None:
- raise ValueError("Internal Error: Score cannot be None.")
- self._trials[trial_id] = score
- def finished_reporting(self) -> bool:
- return (
- None not in self._trials.values() and len(self._trials) == self.max_trials
- )
- def scores(self) -> List[Optional[float]]:
- return list(self._trials.values())
- def count(self) -> int:
- return len(self._trials)
- @PublicAPI
- class Repeater(Searcher):
- """A wrapper algorithm for repeating trials of same parameters.
- Set tune.TuneConfig(num_samples=...) to be a multiple of `repeat`. For example,
- set num_samples=15 if you intend to obtain 3 search algorithm suggestions
- and repeat each suggestion 5 times. Any leftover trials
- (num_samples mod repeat) will be ignored.
- It is recommended that you do not run an early-stopping TrialScheduler
- simultaneously.
- Args:
- searcher: Searcher object that the
- Repeater will optimize. Note that the Searcher
- will only see 1 trial among multiple repeated trials.
- The result/metric passed to the Searcher upon
- trial completion will be averaged among all repeats.
- repeat: Number of times to generate a trial with a repeated
- configuration. Defaults to 1.
- set_index: Sets a tune.search.repeater.TRIAL_INDEX in
- Trainable/Function config which corresponds to the index of the
- repeated trial. This can be used for seeds. Defaults to True.
- Example:
- .. code-block:: python
- from ray.tune.search import Repeater
- search_alg = BayesOptSearch(...)
- re_search_alg = Repeater(search_alg, repeat=10)
- # Repeat 2 samples 10 times each.
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=re_search_alg,
- num_samples=20,
- ),
- )
- tuner.fit()
- """
- def __init__(self, searcher: Searcher, repeat: int = 1, set_index: bool = True):
- self.searcher = searcher
- self.repeat = repeat
- self._set_index = set_index
- self._groups = []
- self._trial_id_to_group = {}
- self._current_group = None
- super(Repeater, self).__init__(
- metric=self.searcher.metric, mode=self.searcher.mode
- )
- def suggest(self, trial_id: str) -> Optional[Dict]:
- if self._current_group is None or self._current_group.full():
- config = self.searcher.suggest(trial_id)
- if config is None:
- return config
- self._current_group = _TrialGroup(
- trial_id, copy.deepcopy(config), max_trials=self.repeat
- )
- self._groups.append(self._current_group)
- index_in_group = 0
- else:
- index_in_group = self._current_group.count()
- self._current_group.add(trial_id)
- config = self._current_group.config.copy()
- if self._set_index:
- config[TRIAL_INDEX] = index_in_group
- self._trial_id_to_group[trial_id] = self._current_group
- return config
- def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, **kwargs):
- """Stores the score for and keeps track of a completed trial.
- Stores the metric of a trial as nan if any of the following conditions
- are met:
- 1. ``result`` is empty or not provided.
- 2. ``result`` is provided but no metric was provided.
- """
- if trial_id not in self._trial_id_to_group:
- logger.error(
- "Trial {} not in group; cannot report score. "
- "Seen trials: {}".format(trial_id, list(self._trial_id_to_group))
- )
- trial_group = self._trial_id_to_group[trial_id]
- if not result or self.searcher.metric not in result:
- score = np.nan
- else:
- score = result[self.searcher.metric]
- trial_group.report(trial_id, score)
- if trial_group.finished_reporting():
- scores = trial_group.scores()
- self.searcher.on_trial_complete(
- trial_group.primary_trial_id,
- result={self.searcher.metric: np.nanmean(scores)},
- **kwargs
- )
- def get_state(self) -> Dict:
- self_state = self.__dict__.copy()
- del self_state["searcher"]
- return self_state
- def set_state(self, state: Dict):
- self.__dict__.update(state)
- def save(self, checkpoint_path: str):
- self.searcher.save(checkpoint_path)
- def restore(self, checkpoint_path: str):
- self.searcher.restore(checkpoint_path)
- def set_search_properties(
- self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
- ) -> bool:
- return _set_search_properties_backwards_compatible(
- self.searcher.set_search_properties, metric, mode, config, **spec
- )
|