searcher.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. import copy
  2. import glob
  3. import logging
  4. import os
  5. import uuid
  6. import warnings
  7. from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
  8. from ray.air._internal.usage import tag_searcher
  9. from ray.tune.search.util import _set_search_properties_backwards_compatible
  10. from ray.util.annotations import DeveloperAPI, PublicAPI
  11. from ray.util.debug import log_once
  12. if TYPE_CHECKING:
  13. from ray.tune.analysis import ExperimentAnalysis
  14. from ray.tune.experiment import Trial
  15. logger = logging.getLogger(__name__)
  16. @DeveloperAPI
  17. class Searcher:
  18. """Abstract class for wrapping suggesting algorithms.
  19. Custom algorithms can extend this class easily by overriding the
  20. `suggest` method provide generated parameters for the trials.
  21. Any subclass that implements ``__init__`` must also call the
  22. constructor of this class: ``super(Subclass, self).__init__(...)``.
  23. To track suggestions and their corresponding evaluations, the method
  24. `suggest` will be passed a trial_id, which will be used in
  25. subsequent notifications.
  26. Not all implementations support multi objectives.
  27. Note to Tune developers: If a new searcher is added, please update
  28. `air/_internal/usage.py`.
  29. Args:
  30. metric: The training result objective value attribute. If
  31. list then list of training result objective value attributes
  32. mode: If string One of {min, max}. If list then
  33. list of max and min, determines whether objective is minimizing
  34. or maximizing the metric attribute. Must match type of metric.
  35. .. code-block:: python
  36. class ExampleSearch(Searcher):
  37. def __init__(self, metric="mean_loss", mode="min", **kwargs):
  38. super(ExampleSearch, self).__init__(
  39. metric=metric, mode=mode, **kwargs)
  40. self.optimizer = Optimizer()
  41. self.configurations = {}
  42. def suggest(self, trial_id):
  43. configuration = self.optimizer.query()
  44. self.configurations[trial_id] = configuration
  45. def on_trial_complete(self, trial_id, result, **kwargs):
  46. configuration = self.configurations[trial_id]
  47. if result and self.metric in result:
  48. self.optimizer.update(configuration, result[self.metric])
  49. tuner = tune.Tuner(
  50. trainable_function,
  51. tune_config=tune.TuneConfig(
  52. search_alg=ExampleSearch()
  53. )
  54. )
  55. tuner.fit()
  56. """
  57. FINISHED = "FINISHED"
  58. CKPT_FILE_TMPL = "searcher-state-{}.pkl"
  59. def __init__(
  60. self,
  61. metric: Optional[str] = None,
  62. mode: Optional[str] = None,
  63. ):
  64. tag_searcher(self)
  65. self._metric = metric
  66. self._mode = mode
  67. if not mode or not metric:
  68. # Early return to avoid assertions
  69. return
  70. assert isinstance(
  71. metric, type(mode)
  72. ), "metric and mode must be of the same type"
  73. if isinstance(mode, str):
  74. assert mode in ["min", "max"], "if `mode` is a str must be 'min' or 'max'!"
  75. elif isinstance(mode, list):
  76. assert len(mode) == len(metric), "Metric and mode must be the same length"
  77. assert all(
  78. mod in ["min", "max", "obs"] for mod in mode
  79. ), "All of mode must be 'min' or 'max' or 'obs'!"
  80. else:
  81. raise ValueError("Mode most either be a list or string")
  82. def set_search_properties(
  83. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  84. ) -> bool:
  85. """Pass search properties to searcher.
  86. This method acts as an alternative to instantiating search algorithms
  87. with their own specific search spaces. Instead they can accept a
  88. Tune config through this method. A searcher should return ``True``
  89. if setting the config was successful, or ``False`` if it was
  90. unsuccessful, e.g. when the search space has already been set.
  91. Args:
  92. metric: Metric to optimize
  93. mode: One of ["min", "max"]. Direction to optimize.
  94. config: Tune config dict.
  95. **spec: Any kwargs for forward compatibility.
  96. Info like Experiment.PUBLIC_KEYS is provided through here.
  97. """
  98. return False
  99. def on_trial_result(self, trial_id: str, result: Dict) -> None:
  100. """Optional notification for result during training.
  101. Note that by default, the result dict may include NaNs or
  102. may not include the optimization metric. It is up to the
  103. subclass implementation to preprocess the result to
  104. avoid breaking the optimization process.
  105. Args:
  106. trial_id: A unique string ID for the trial.
  107. result: Dictionary of metrics for current training progress.
  108. Note that the result dict may include NaNs or
  109. may not include the optimization metric. It is up to the
  110. subclass implementation to preprocess the result to
  111. avoid breaking the optimization process.
  112. """
  113. pass
  114. def on_trial_complete(
  115. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  116. ) -> None:
  117. """Notification for the completion of trial.
  118. Typically, this method is used for notifying the underlying
  119. optimizer of the result.
  120. Args:
  121. trial_id: A unique string ID for the trial.
  122. result: Dictionary of metrics for current training progress.
  123. Note that the result dict may include NaNs or
  124. may not include the optimization metric. It is up to the
  125. subclass implementation to preprocess the result to
  126. avoid breaking the optimization process. Upon errors, this
  127. may also be None.
  128. error: True if the training process raised an error.
  129. """
  130. raise NotImplementedError
  131. def suggest(self, trial_id: str) -> Optional[Dict]:
  132. """Queries the algorithm to retrieve the next set of parameters.
  133. Arguments:
  134. trial_id: Trial ID used for subsequent notifications.
  135. Returns:
  136. dict | FINISHED | None: Configuration for a trial, if possible.
  137. If FINISHED is returned, Tune will be notified that
  138. no more suggestions/configurations will be provided.
  139. If None is returned, Tune will skip the querying of the
  140. searcher for this step.
  141. """
  142. raise NotImplementedError
  143. def add_evaluated_point(
  144. self,
  145. parameters: Dict,
  146. value: float,
  147. error: bool = False,
  148. pruned: bool = False,
  149. intermediate_values: Optional[List[float]] = None,
  150. ):
  151. """Pass results from a point that has been evaluated separately.
  152. This method allows for information from outside the
  153. suggest - on_trial_complete loop to be passed to the search
  154. algorithm.
  155. This functionality depends on the underlying search algorithm
  156. and may not be always available.
  157. Args:
  158. parameters: Parameters used for the trial.
  159. value: Metric value obtained in the trial.
  160. error: True if the training process raised an error.
  161. pruned: True if trial was pruned.
  162. intermediate_values: List of metric values for
  163. intermediate iterations of the result. None if not
  164. applicable.
  165. """
  166. raise NotImplementedError
  167. def add_evaluated_trials(
  168. self,
  169. trials_or_analysis: Union["Trial", List["Trial"], "ExperimentAnalysis"],
  170. metric: str,
  171. ):
  172. """Pass results from trials that have been evaluated separately.
  173. This method allows for information from outside the
  174. suggest - on_trial_complete loop to be passed to the search
  175. algorithm.
  176. This functionality depends on the underlying search algorithm
  177. and may not be always available (same as ``add_evaluated_point``.)
  178. Args:
  179. trials_or_analysis: Trials to pass results form to the searcher.
  180. metric: Metric name reported by trials used for
  181. determining the objective value.
  182. """
  183. if self.add_evaluated_point == Searcher.add_evaluated_point:
  184. raise NotImplementedError
  185. # lazy imports to avoid circular dependencies
  186. from ray.tune.analysis import ExperimentAnalysis
  187. from ray.tune.experiment import Trial
  188. from ray.tune.result import DONE
  189. if isinstance(trials_or_analysis, (list, tuple)):
  190. trials = trials_or_analysis
  191. elif isinstance(trials_or_analysis, Trial):
  192. trials = [trials_or_analysis]
  193. elif isinstance(trials_or_analysis, ExperimentAnalysis):
  194. trials = trials_or_analysis.trials
  195. else:
  196. raise NotImplementedError(
  197. "Expected input to be a `Trial`, a list of `Trial`s, or "
  198. f"`ExperimentAnalysis`, got: {trials_or_analysis}"
  199. )
  200. any_trial_had_metric = False
  201. def trial_to_points(trial: Trial) -> Dict[str, Any]:
  202. nonlocal any_trial_had_metric
  203. has_trial_been_pruned = (
  204. trial.status == Trial.TERMINATED
  205. and not trial.last_result.get(DONE, False)
  206. )
  207. has_trial_finished = (
  208. trial.status == Trial.TERMINATED and trial.last_result.get(DONE, False)
  209. )
  210. if not any_trial_had_metric:
  211. any_trial_had_metric = (
  212. metric in trial.last_result and has_trial_finished
  213. )
  214. if Trial.TERMINATED and metric not in trial.last_result:
  215. return None
  216. return dict(
  217. parameters=trial.config,
  218. value=trial.last_result.get(metric, None),
  219. error=trial.status == Trial.ERROR,
  220. pruned=has_trial_been_pruned,
  221. intermediate_values=None, # we do not save those
  222. )
  223. for trial in trials:
  224. kwargs = trial_to_points(trial)
  225. if kwargs:
  226. self.add_evaluated_point(**kwargs)
  227. if not any_trial_had_metric:
  228. warnings.warn(
  229. "No completed trial returned the specified metric. "
  230. "Make sure the name you have passed is correct. "
  231. )
  232. def save(self, checkpoint_path: str):
  233. """Save state to path for this search algorithm.
  234. Args:
  235. checkpoint_path: File where the search algorithm
  236. state is saved. This path should be used later when
  237. restoring from file.
  238. Example:
  239. .. code-block:: python
  240. search_alg = Searcher(...)
  241. tuner = tune.Tuner(
  242. cost,
  243. tune_config=tune.TuneConfig(
  244. search_alg=search_alg,
  245. num_samples=5
  246. ),
  247. param_space=config
  248. )
  249. results = tuner.fit()
  250. search_alg.save("./my_favorite_path.pkl")
  251. .. versionchanged:: 0.8.7
  252. Save is automatically called by `Tuner().fit()`. You can use
  253. `Tuner().restore()` to restore from an experiment directory
  254. such as `~/ray_results/trainable`.
  255. """
  256. raise NotImplementedError
  257. def restore(self, checkpoint_path: str):
  258. """Restore state for this search algorithm
  259. Args:
  260. checkpoint_path: File where the search algorithm
  261. state is saved. This path should be the same
  262. as the one provided to "save".
  263. Example:
  264. .. code-block:: python
  265. search_alg.save("./my_favorite_path.pkl")
  266. search_alg2 = Searcher(...)
  267. search_alg2 = ConcurrencyLimiter(search_alg2, 1)
  268. search_alg2.restore(checkpoint_path)
  269. tuner = tune.Tuner(
  270. cost,
  271. tune_config=tune.TuneConfig(
  272. search_alg=search_alg2,
  273. num_samples=5
  274. ),
  275. )
  276. tuner.fit()
  277. """
  278. raise NotImplementedError
  279. def set_max_concurrency(self, max_concurrent: int) -> bool:
  280. """Set max concurrent trials this searcher can run.
  281. This method will be called on the wrapped searcher by the
  282. ``ConcurrencyLimiter``. It is intended to allow for searchers
  283. which have custom, internal logic handling max concurrent trials
  284. to inherit the value passed to ``ConcurrencyLimiter``.
  285. If this method returns False, it signifies that no special
  286. logic for handling this case is present in the searcher.
  287. Args:
  288. max_concurrent: Number of maximum concurrent trials.
  289. """
  290. return False
  291. def get_state(self) -> Dict:
  292. raise NotImplementedError
  293. def set_state(self, state: Dict):
  294. raise NotImplementedError
  295. def save_to_dir(self, checkpoint_dir: str, session_str: str = "default"):
  296. """Automatically saves the given searcher to the checkpoint_dir.
  297. This is automatically used by Tuner().fit() during a Tune job.
  298. Args:
  299. checkpoint_dir: Filepath to experiment dir.
  300. session_str: Unique identifier of the current run
  301. session.
  302. """
  303. file_name = self.CKPT_FILE_TMPL.format(session_str)
  304. tmp_file_name = f".{str(uuid.uuid4())}-tmp-{file_name}"
  305. tmp_search_ckpt_path = os.path.join(checkpoint_dir, tmp_file_name)
  306. success = True
  307. try:
  308. self.save(tmp_search_ckpt_path)
  309. except NotImplementedError:
  310. if log_once("suggest:save_to_dir"):
  311. logger.warning("save not implemented for Searcher. Skipping save.")
  312. success = False
  313. if success and os.path.exists(tmp_search_ckpt_path):
  314. os.replace(
  315. tmp_search_ckpt_path,
  316. os.path.join(checkpoint_dir, file_name),
  317. )
  318. def restore_from_dir(self, checkpoint_dir: str):
  319. """Restores the state of a searcher from a given checkpoint_dir.
  320. Typically, you should use this function to restore from an
  321. experiment directory such as `~/ray_results/trainable`.
  322. .. code-block:: python
  323. tuner = tune.Tuner(
  324. cost,
  325. run_config=tune.RunConfig(
  326. name=self.experiment_name,
  327. storage_path="~/my_results",
  328. ),
  329. tune_config=tune.TuneConfig(
  330. search_alg=search_alg,
  331. num_samples=5
  332. ),
  333. param_space=config
  334. )
  335. tuner.fit()
  336. search_alg2 = Searcher()
  337. search_alg2.restore_from_dir(
  338. os.path.join("~/my_results", self.experiment_name)
  339. """
  340. pattern = self.CKPT_FILE_TMPL.format("*")
  341. full_paths = glob.glob(os.path.join(checkpoint_dir, pattern))
  342. if not full_paths:
  343. raise RuntimeError(
  344. "Searcher unable to find checkpoint in {}".format(checkpoint_dir)
  345. ) # TODO
  346. most_recent_checkpoint = max(full_paths)
  347. self.restore(most_recent_checkpoint)
  348. @property
  349. def metric(self) -> str:
  350. """The training result objective value attribute."""
  351. return self._metric
  352. @property
  353. def mode(self) -> str:
  354. """Specifies if minimizing or maximizing the metric."""
  355. return self._mode
  356. @PublicAPI
  357. class ConcurrencyLimiter(Searcher):
  358. """A wrapper algorithm for limiting the number of concurrent trials.
  359. Certain Searchers have their own internal logic for limiting
  360. the number of concurrent trials. If such a Searcher is passed to a
  361. ``ConcurrencyLimiter``, the ``max_concurrent`` of the
  362. ``ConcurrencyLimiter`` will override the ``max_concurrent`` value
  363. of the Searcher. The ``ConcurrencyLimiter`` will then let the
  364. Searcher's internal logic take over.
  365. Args:
  366. searcher: Searcher object that the
  367. ConcurrencyLimiter will manage.
  368. max_concurrent: Maximum concurrent samples from the underlying
  369. searcher.
  370. batch: Whether to wait for all concurrent samples
  371. to finish before updating the underlying searcher.
  372. Example:
  373. .. code-block:: python
  374. from ray.tune.search import ConcurrencyLimiter
  375. search_alg = HyperOptSearch(metric="accuracy")
  376. search_alg = ConcurrencyLimiter(search_alg, max_concurrent=2)
  377. tuner = tune.Tuner(
  378. trainable_function,
  379. tune_config=tune.TuneConfig(
  380. search_alg=search_alg
  381. ),
  382. )
  383. tuner.fit()
  384. """
  385. def __init__(self, searcher: Searcher, max_concurrent: int, batch: bool = False):
  386. assert type(max_concurrent) is int and max_concurrent > 0
  387. self.searcher = searcher
  388. self.max_concurrent = max_concurrent
  389. self.batch = batch
  390. self.live_trials = set()
  391. self.num_unfinished_live_trials = 0
  392. self.cached_results = {}
  393. self._limit_concurrency = True
  394. if not isinstance(searcher, Searcher):
  395. raise RuntimeError(
  396. f"The `ConcurrencyLimiter` only works with `Searcher` "
  397. f"objects (got {type(searcher)}). Please try to pass "
  398. f"`max_concurrent` to the search generator directly."
  399. )
  400. self._set_searcher_max_concurrency()
  401. super(ConcurrencyLimiter, self).__init__(
  402. metric=self.searcher.metric, mode=self.searcher.mode
  403. )
  404. def _set_searcher_max_concurrency(self):
  405. # If the searcher has special logic for handling max concurrency,
  406. # we do not do anything inside the ConcurrencyLimiter
  407. self._limit_concurrency = not self.searcher.set_max_concurrency(
  408. self.max_concurrent
  409. )
  410. def set_max_concurrency(self, max_concurrent: int) -> bool:
  411. # Determine if this behavior is acceptable, or if it should
  412. # raise an exception.
  413. self.max_concurrent = max_concurrent
  414. return True
  415. def set_search_properties(
  416. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  417. ) -> bool:
  418. self._set_searcher_max_concurrency()
  419. return _set_search_properties_backwards_compatible(
  420. self.searcher.set_search_properties, metric, mode, config, **spec
  421. )
  422. def suggest(self, trial_id: str) -> Optional[Dict]:
  423. if not self._limit_concurrency:
  424. return self.searcher.suggest(trial_id)
  425. assert (
  426. trial_id not in self.live_trials
  427. ), f"Trial ID {trial_id} must be unique: already found in set."
  428. if len(self.live_trials) >= self.max_concurrent:
  429. logger.debug(
  430. f"Not providing a suggestion for {trial_id} due to "
  431. "concurrency limit: %s/%s.",
  432. len(self.live_trials),
  433. self.max_concurrent,
  434. )
  435. return
  436. suggestion = self.searcher.suggest(trial_id)
  437. if suggestion not in (None, Searcher.FINISHED):
  438. self.live_trials.add(trial_id)
  439. self.num_unfinished_live_trials += 1
  440. return suggestion
  441. def on_trial_complete(
  442. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  443. ):
  444. if not self._limit_concurrency:
  445. return self.searcher.on_trial_complete(trial_id, result=result, error=error)
  446. if trial_id not in self.live_trials:
  447. return
  448. elif self.batch:
  449. self.cached_results[trial_id] = (result, error)
  450. self.num_unfinished_live_trials -= 1
  451. if self.num_unfinished_live_trials <= 0:
  452. # Update the underlying searcher once the
  453. # full batch is completed.
  454. for trial_id, (result, error) in self.cached_results.items():
  455. self.searcher.on_trial_complete(
  456. trial_id, result=result, error=error
  457. )
  458. self.live_trials.remove(trial_id)
  459. self.cached_results = {}
  460. self.num_unfinished_live_trials = 0
  461. else:
  462. return
  463. else:
  464. self.searcher.on_trial_complete(trial_id, result=result, error=error)
  465. self.live_trials.remove(trial_id)
  466. self.num_unfinished_live_trials -= 1
  467. def on_trial_result(self, trial_id: str, result: Dict) -> None:
  468. self.searcher.on_trial_result(trial_id, result)
  469. def add_evaluated_point(
  470. self,
  471. parameters: Dict,
  472. value: float,
  473. error: bool = False,
  474. pruned: bool = False,
  475. intermediate_values: Optional[List[float]] = None,
  476. ):
  477. return self.searcher.add_evaluated_point(
  478. parameters, value, error, pruned, intermediate_values
  479. )
  480. def get_state(self) -> Dict:
  481. state = self.__dict__.copy()
  482. del state["searcher"]
  483. return copy.deepcopy(state)
  484. def set_state(self, state: Dict):
  485. self.__dict__.update(state)
  486. def save(self, checkpoint_path: str):
  487. self.searcher.save(checkpoint_path)
  488. def restore(self, checkpoint_path: str):
  489. self.searcher.restore(checkpoint_path)