__init__.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. from ray._common.utils import get_function_args
  2. from ray.tune.search.basic_variant import BasicVariantGenerator
  3. from ray.tune.search.concurrency_limiter import ConcurrencyLimiter
  4. from ray.tune.search.repeater import Repeater
  5. from ray.tune.search.search_algorithm import SearchAlgorithm
  6. from ray.tune.search.search_generator import SearchGenerator
  7. from ray.tune.search.searcher import Searcher
  8. from ray.tune.search.variant_generator import grid_search
  9. from ray.util import PublicAPI
  10. def _import_variant_generator():
  11. return BasicVariantGenerator
  12. def _import_ax_search():
  13. from ray.tune.search.ax.ax_search import AxSearch
  14. return AxSearch
  15. def _import_hyperopt_search():
  16. from ray.tune.search.hyperopt.hyperopt_search import HyperOptSearch
  17. return HyperOptSearch
  18. def _import_bayesopt_search():
  19. from ray.tune.search.bayesopt.bayesopt_search import BayesOptSearch
  20. return BayesOptSearch
  21. def _import_bohb_search():
  22. from ray.tune.search.bohb.bohb_search import TuneBOHB
  23. return TuneBOHB
  24. def _import_nevergrad_search():
  25. from ray.tune.search.nevergrad.nevergrad_search import NevergradSearch
  26. return NevergradSearch
  27. def _import_optuna_search():
  28. from ray.tune.search.optuna.optuna_search import OptunaSearch
  29. return OptunaSearch
  30. def _import_zoopt_search():
  31. from ray.tune.search.zoopt.zoopt_search import ZOOptSearch
  32. return ZOOptSearch
  33. def _import_hebo_search():
  34. from ray.tune.search.hebo.hebo_search import HEBOSearch
  35. return HEBOSearch
  36. SEARCH_ALG_IMPORT = {
  37. "variant_generator": _import_variant_generator,
  38. "random": _import_variant_generator,
  39. "ax": _import_ax_search,
  40. "hyperopt": _import_hyperopt_search,
  41. "bayesopt": _import_bayesopt_search,
  42. "bohb": _import_bohb_search,
  43. "nevergrad": _import_nevergrad_search,
  44. "optuna": _import_optuna_search,
  45. "zoopt": _import_zoopt_search,
  46. "hebo": _import_hebo_search,
  47. }
  48. @PublicAPI(stability="beta")
  49. def create_searcher(
  50. search_alg,
  51. **kwargs,
  52. ):
  53. """Instantiate a search algorithm based on the given string.
  54. This is useful for swapping between different search algorithms.
  55. Args:
  56. search_alg: The search algorithm to use.
  57. metric: The training result objective value attribute. Stopping
  58. procedures will use this attribute.
  59. mode: One of {min, max}. Determines whether objective is
  60. minimizing or maximizing the metric attribute.
  61. **kwargs: Additional parameters.
  62. These keyword arguments will be passed to the initialization
  63. function of the chosen class.
  64. Returns:
  65. ray.tune.search.Searcher: The search algorithm.
  66. Example:
  67. >>> from ray import tune # doctest: +SKIP
  68. >>> search_alg = tune.create_searcher('ax') # doctest: +SKIP
  69. """
  70. search_alg = search_alg.lower()
  71. if search_alg not in SEARCH_ALG_IMPORT:
  72. raise ValueError(
  73. f"The `search_alg` argument must be one of "
  74. f"{list(SEARCH_ALG_IMPORT)}. "
  75. f"Got: {search_alg}"
  76. )
  77. SearcherClass = SEARCH_ALG_IMPORT[search_alg]()
  78. search_alg_args = get_function_args(SearcherClass)
  79. trimmed_kwargs = {k: v for k, v in kwargs.items() if k in search_alg_args}
  80. return SearcherClass(**trimmed_kwargs)
  81. UNRESOLVED_SEARCH_SPACE = str(
  82. "You passed a `{par}` parameter to {cls} that contained unresolved search "
  83. "space definitions. {cls} should however be instantiated with fully "
  84. "configured search spaces only. To use Ray Tune's automatic search space "
  85. "conversion, pass the space definition as part of the `param_space` argument "
  86. "to `tune.Tuner()` instead."
  87. )
  88. UNDEFINED_SEARCH_SPACE = str(
  89. "Trying to sample a configuration from {cls}, but no search "
  90. "space has been defined. Either pass the `{space}` argument when "
  91. "instantiating the search algorithm, or pass a `param_space` to "
  92. "`tune.Tuner()`."
  93. )
  94. UNDEFINED_METRIC_MODE = str(
  95. "Trying to sample a configuration from {cls}, but the `metric` "
  96. "({metric}) or `mode` ({mode}) parameters have not been set. "
  97. "Either pass these arguments when instantiating the search algorithm, "
  98. "or pass them to `tune.TuneConfig()`."
  99. )
  100. __all__ = [
  101. "SearchAlgorithm",
  102. "Searcher",
  103. "ConcurrencyLimiter",
  104. "Repeater",
  105. "BasicVariantGenerator",
  106. "grid_search",
  107. "SearchGenerator",
  108. "UNRESOLVED_SEARCH_SPACE",
  109. "UNDEFINED_SEARCH_SPACE",
  110. "UNDEFINED_METRIC_MODE",
  111. ]