search_generator.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import copy
  2. import logging
  3. from typing import Dict, List, Optional, Union
  4. from ray.tune.error import TuneError
  5. from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list
  6. from ray.tune.experiment.config_parser import _create_trial_from_spec, _make_parser
  7. from ray.tune.search.search_algorithm import SearchAlgorithm
  8. from ray.tune.search.searcher import Searcher
  9. from ray.tune.search.util import _set_search_properties_backwards_compatible
  10. from ray.tune.search.variant_generator import _resolve_nested_dict, format_vars
  11. from ray.tune.utils.util import (
  12. _atomic_save,
  13. _load_newest_checkpoint,
  14. flatten_dict,
  15. merge_dicts,
  16. )
  17. from ray.util.annotations import DeveloperAPI
  18. logger = logging.getLogger(__name__)
  19. def _warn_on_repeater(searcher, total_samples):
  20. from ray.tune.search.repeater import _warn_num_samples
  21. _warn_num_samples(searcher, total_samples)
  22. @DeveloperAPI
  23. class SearchGenerator(SearchAlgorithm):
  24. """Generates trials to be passed to the TrialRunner.
  25. Uses the provided ``searcher`` object to generate trials. This class
  26. transparently handles repeating trials with score aggregation
  27. without embedding logic into the Searcher.
  28. Args:
  29. searcher: Search object that subclasses the Searcher base class. This
  30. is then used for generating new hyperparameter samples.
  31. """
  32. CKPT_FILE_TMPL = "search_gen_state-{}.json"
  33. def __init__(self, searcher: Searcher):
  34. assert issubclass(
  35. type(searcher), Searcher
  36. ), "Searcher should be subclassing Searcher."
  37. self.searcher = searcher
  38. self._parser = _make_parser()
  39. self._experiment = None
  40. self._counter = 0 # Keeps track of number of trials created.
  41. self._total_samples = 0 # int: total samples to evaluate.
  42. self._finished = False
  43. @property
  44. def metric(self):
  45. return self.searcher.metric
  46. def set_search_properties(
  47. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  48. ) -> bool:
  49. return _set_search_properties_backwards_compatible(
  50. self.searcher.set_search_properties, metric, mode, config, **spec
  51. )
  52. @property
  53. def total_samples(self):
  54. return self._total_samples
  55. def add_configurations(
  56. self, experiments: Union[Experiment, List[Experiment], Dict[str, Dict]]
  57. ):
  58. """Registers experiment specifications.
  59. Arguments:
  60. experiments: Experiments to run.
  61. """
  62. assert not self._experiment
  63. logger.debug("added configurations")
  64. experiment_list = _convert_to_experiment_list(experiments)
  65. assert (
  66. len(experiment_list) == 1
  67. ), "SearchAlgorithms can only support 1 experiment at a time."
  68. self._experiment = experiment_list[0]
  69. experiment_spec = self._experiment.spec
  70. self._total_samples = self._experiment.spec.get("num_samples", 1)
  71. _warn_on_repeater(self.searcher, self._total_samples)
  72. if "run" not in experiment_spec:
  73. raise TuneError("Must specify `run` in {}".format(experiment_spec))
  74. def next_trial(self):
  75. """Provides one Trial object to be queued into the TrialRunner.
  76. Returns:
  77. Trial: Returns a single trial.
  78. """
  79. if not self.is_finished():
  80. return self.create_trial_if_possible(self._experiment.spec)
  81. return None
  82. def create_trial_if_possible(self, experiment_spec: Dict) -> Optional[Trial]:
  83. logger.debug("creating trial")
  84. trial_id = Trial.generate_id()
  85. suggested_config = self.searcher.suggest(trial_id)
  86. if suggested_config == Searcher.FINISHED:
  87. self._finished = True
  88. logger.debug("Searcher has finished.")
  89. return
  90. if suggested_config is None:
  91. return
  92. spec = copy.deepcopy(experiment_spec)
  93. spec["config"] = merge_dicts(spec["config"], copy.deepcopy(suggested_config))
  94. # Create a new trial_id if duplicate trial is created
  95. flattened_config = _resolve_nested_dict(spec["config"])
  96. self._counter += 1
  97. tag = "{0}_{1}".format(str(self._counter), format_vars(flattened_config))
  98. trial = _create_trial_from_spec(
  99. spec,
  100. self._parser,
  101. evaluated_params=flatten_dict(suggested_config),
  102. experiment_tag=tag,
  103. trial_id=trial_id,
  104. )
  105. return trial
  106. def on_trial_result(self, trial_id: str, result: Dict):
  107. """Notifies the underlying searcher."""
  108. self.searcher.on_trial_result(trial_id, result)
  109. def on_trial_complete(
  110. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  111. ):
  112. self.searcher.on_trial_complete(trial_id=trial_id, result=result, error=error)
  113. def is_finished(self) -> bool:
  114. return self._counter >= self._total_samples or self._finished
  115. def get_state(self) -> Dict:
  116. return {
  117. "counter": self._counter,
  118. "total_samples": self._total_samples,
  119. "finished": self._finished,
  120. "experiment": self._experiment,
  121. }
  122. def set_state(self, state: Dict):
  123. self._counter = state["counter"]
  124. self._total_samples = state["total_samples"]
  125. self._finished = state["finished"]
  126. self._experiment = state["experiment"]
  127. def has_checkpoint(self, dirpath: str):
  128. return bool(_load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*")))
  129. def save_to_dir(self, dirpath: str, session_str: str):
  130. """Saves self + searcher to dir.
  131. Separates the "searcher" from its wrappers (concurrency, repeating).
  132. This allows the user to easily restore a given searcher.
  133. The save operation is atomic (write/swap).
  134. Args:
  135. dirpath: Filepath to experiment dir.
  136. session_str: Unique identifier of the current run
  137. session.
  138. """
  139. searcher = self.searcher
  140. search_alg_state = self.get_state()
  141. while hasattr(searcher, "searcher"):
  142. searcher_name = type(searcher).__name__
  143. if searcher_name in search_alg_state:
  144. logger.warning(
  145. "There was a duplicate when saving {}. "
  146. "Restore may not work properly.".format(searcher_name)
  147. )
  148. else:
  149. search_alg_state["name:" + searcher_name] = searcher.get_state()
  150. searcher = searcher.searcher
  151. base_searcher = searcher
  152. # We save the base searcher separately for users to easily
  153. # separate the searcher.
  154. base_searcher.save_to_dir(dirpath, session_str)
  155. file_name = self.CKPT_FILE_TMPL.format(session_str)
  156. _atomic_save(
  157. state=search_alg_state,
  158. checkpoint_dir=dirpath,
  159. file_name=file_name,
  160. tmp_file_name=f"tmp-{file_name}",
  161. )
  162. def restore_from_dir(self, dirpath: str):
  163. """Restores self + searcher + search wrappers from dirpath."""
  164. searcher = self.searcher
  165. search_alg_state = _load_newest_checkpoint(
  166. dirpath, self.CKPT_FILE_TMPL.format("*")
  167. )
  168. if not search_alg_state:
  169. raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
  170. while hasattr(searcher, "searcher"):
  171. searcher_name = "name:" + type(searcher).__name__
  172. if searcher_name not in search_alg_state:
  173. names = [
  174. key.split("name:")[1]
  175. for key in search_alg_state
  176. if key.startswith("name:")
  177. ]
  178. logger.warning(
  179. "{} was not found in the experiment "
  180. "state when restoring. Found {}.".format(searcher_name, names)
  181. )
  182. else:
  183. searcher.set_state(search_alg_state.pop(searcher_name))
  184. searcher = searcher.searcher
  185. base_searcher = searcher
  186. logger.debug(f"searching base {base_searcher}")
  187. base_searcher.restore_from_dir(dirpath)
  188. self.set_state(search_alg_state)