repeater.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. import copy
  2. import logging
  3. from typing import Dict, List, Optional
  4. import numpy as np
  5. from ray.tune.search.searcher import Searcher
  6. from ray.tune.search.util import _set_search_properties_backwards_compatible
  7. from ray.util import PublicAPI
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Config key under which Repeater.suggest() stores the index of a trial within
# its group of repeated trials (only when the Repeater has set_index enabled).
TRIAL_INDEX = "__trial_index__"
"""str: A constant value representing the repeat index of the trial."""
  11. def _warn_num_samples(searcher: Searcher, num_samples: int):
  12. if isinstance(searcher, Repeater) and num_samples % searcher.repeat:
  13. logger.warning(
  14. "`num_samples` is now expected to be the total number of trials, "
  15. "including the repeat trials. For example, set num_samples=15 if "
  16. "you intend to obtain 3 search algorithm suggestions and repeat "
  17. "each suggestion 5 times. Any leftover trials "
  18. "(num_samples mod repeat) will be ignored."
  19. )
  20. class _TrialGroup:
  21. """Internal class for grouping trials of same parameters.
  22. This is used when repeating trials for reducing training variance.
  23. Args:
  24. primary_trial_id: Trial ID of the "primary trial".
  25. This trial is the one that the Searcher is aware of.
  26. config: Suggested configuration shared across all trials
  27. in the trial group.
  28. max_trials: Max number of trials to execute within this group.
  29. """
  30. def __init__(self, primary_trial_id: str, config: Dict, max_trials: int = 1):
  31. assert type(config) is dict, "config is not a dict, got {}".format(config)
  32. self.primary_trial_id = primary_trial_id
  33. self.config = config
  34. self._trials = {primary_trial_id: None}
  35. self.max_trials = max_trials
  36. def add(self, trial_id: str):
  37. assert len(self._trials) < self.max_trials
  38. self._trials.setdefault(trial_id, None)
  39. def full(self) -> bool:
  40. return len(self._trials) == self.max_trials
  41. def report(self, trial_id: str, score: float):
  42. assert trial_id in self._trials
  43. if score is None:
  44. raise ValueError("Internal Error: Score cannot be None.")
  45. self._trials[trial_id] = score
  46. def finished_reporting(self) -> bool:
  47. return (
  48. None not in self._trials.values() and len(self._trials) == self.max_trials
  49. )
  50. def scores(self) -> List[Optional[float]]:
  51. return list(self._trials.values())
  52. def count(self) -> int:
  53. return len(self._trials)
  54. @PublicAPI
  55. class Repeater(Searcher):
  56. """A wrapper algorithm for repeating trials of same parameters.
  57. Set tune.TuneConfig(num_samples=...) to be a multiple of `repeat`. For example,
  58. set num_samples=15 if you intend to obtain 3 search algorithm suggestions
  59. and repeat each suggestion 5 times. Any leftover trials
  60. (num_samples mod repeat) will be ignored.
  61. It is recommended that you do not run an early-stopping TrialScheduler
  62. simultaneously.
  63. Args:
  64. searcher: Searcher object that the
  65. Repeater will optimize. Note that the Searcher
  66. will only see 1 trial among multiple repeated trials.
  67. The result/metric passed to the Searcher upon
  68. trial completion will be averaged among all repeats.
  69. repeat: Number of times to generate a trial with a repeated
  70. configuration. Defaults to 1.
  71. set_index: Sets a tune.search.repeater.TRIAL_INDEX in
  72. Trainable/Function config which corresponds to the index of the
  73. repeated trial. This can be used for seeds. Defaults to True.
  74. Example:
  75. .. code-block:: python
  76. from ray.tune.search import Repeater
  77. search_alg = BayesOptSearch(...)
  78. re_search_alg = Repeater(search_alg, repeat=10)
  79. # Repeat 2 samples 10 times each.
  80. tuner = tune.Tuner(
  81. trainable,
  82. tune_config=tune.TuneConfig(
  83. search_alg=re_search_alg,
  84. num_samples=20,
  85. ),
  86. )
  87. tuner.fit()
  88. """
  89. def __init__(self, searcher: Searcher, repeat: int = 1, set_index: bool = True):
  90. self.searcher = searcher
  91. self.repeat = repeat
  92. self._set_index = set_index
  93. self._groups = []
  94. self._trial_id_to_group = {}
  95. self._current_group = None
  96. super(Repeater, self).__init__(
  97. metric=self.searcher.metric, mode=self.searcher.mode
  98. )
  99. def suggest(self, trial_id: str) -> Optional[Dict]:
  100. if self._current_group is None or self._current_group.full():
  101. config = self.searcher.suggest(trial_id)
  102. if config is None:
  103. return config
  104. self._current_group = _TrialGroup(
  105. trial_id, copy.deepcopy(config), max_trials=self.repeat
  106. )
  107. self._groups.append(self._current_group)
  108. index_in_group = 0
  109. else:
  110. index_in_group = self._current_group.count()
  111. self._current_group.add(trial_id)
  112. config = self._current_group.config.copy()
  113. if self._set_index:
  114. config[TRIAL_INDEX] = index_in_group
  115. self._trial_id_to_group[trial_id] = self._current_group
  116. return config
  117. def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, **kwargs):
  118. """Stores the score for and keeps track of a completed trial.
  119. Stores the metric of a trial as nan if any of the following conditions
  120. are met:
  121. 1. ``result`` is empty or not provided.
  122. 2. ``result`` is provided but no metric was provided.
  123. """
  124. if trial_id not in self._trial_id_to_group:
  125. logger.error(
  126. "Trial {} not in group; cannot report score. "
  127. "Seen trials: {}".format(trial_id, list(self._trial_id_to_group))
  128. )
  129. trial_group = self._trial_id_to_group[trial_id]
  130. if not result or self.searcher.metric not in result:
  131. score = np.nan
  132. else:
  133. score = result[self.searcher.metric]
  134. trial_group.report(trial_id, score)
  135. if trial_group.finished_reporting():
  136. scores = trial_group.scores()
  137. self.searcher.on_trial_complete(
  138. trial_group.primary_trial_id,
  139. result={self.searcher.metric: np.nanmean(scores)},
  140. **kwargs
  141. )
  142. def get_state(self) -> Dict:
  143. self_state = self.__dict__.copy()
  144. del self_state["searcher"]
  145. return self_state
  146. def set_state(self, state: Dict):
  147. self.__dict__.update(state)
  148. def save(self, checkpoint_path: str):
  149. self.searcher.save(checkpoint_path)
  150. def restore(self, checkpoint_path: str):
  151. self.searcher.restore(checkpoint_path)
  152. def set_search_properties(
  153. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  154. ) -> bool:
  155. return _set_search_properties_backwards_compatible(
  156. self.searcher.set_search_properties, metric, mode, config, **spec
  157. )