ax_search.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. import copy
  2. import logging
  3. from typing import Dict, List, Optional, Union
  4. import numpy as np
  5. from ray import cloudpickle
  6. from ray.tune.result import DEFAULT_METRIC
  7. from ray.tune.search import (
  8. UNDEFINED_METRIC_MODE,
  9. UNDEFINED_SEARCH_SPACE,
  10. UNRESOLVED_SEARCH_SPACE,
  11. Searcher,
  12. )
  13. from ray.tune.search.sample import (
  14. Categorical,
  15. Float,
  16. Integer,
  17. LogUniform,
  18. Quantized,
  19. Uniform,
  20. )
  21. from ray.tune.search.variant_generator import parse_spec_vars
  22. from ray.tune.utils.util import flatten_dict, unflatten_list_dict
  23. try:
  24. import ax
  25. from ax.service.ax_client import AxClient
  26. except ImportError:
  27. ax = AxClient = None
  28. # This exception only exists in newer Ax releases for python 3.7
  29. try:
  30. from ax.exceptions.core import DataRequiredError
  31. from ax.exceptions.generation_strategy import MaxParallelismReachedException
  32. except ImportError:
  33. MaxParallelismReachedException = DataRequiredError = Exception
  34. logger = logging.getLogger(__name__)
  35. class AxSearch(Searcher):
  36. """Uses `Ax <https://ax.dev/>`_ to optimize hyperparameters.
  37. Ax is a platform for understanding, managing, deploying, and
  38. automating adaptive experiments. Ax provides an easy to use
  39. interface with BoTorch, a flexible, modern library for Bayesian
  40. optimization in PyTorch. More information can be found in https://ax.dev/.
  41. To use this search algorithm, you must install Ax:
  42. .. code-block:: bash
  43. $ pip install ax-platform
  44. Parameters:
  45. space: Parameters in the experiment search space.
  46. Required elements in the dictionaries are: "name" (name of
  47. this parameter, string), "type" (type of the parameter: "range",
  48. "fixed", or "choice", string), "bounds" for range parameters
  49. (list of two values, lower bound first), "values" for choice
  50. parameters (list of values), and "value" for fixed parameters
  51. (single value).
  52. metric: Name of the metric used as objective in this
  53. experiment. This metric must be present in `raw_data` argument
  54. to `log_data`. This metric must also be present in the dict
  55. reported/returned by the Trainable. If None but a mode was passed,
  56. the `ray.tune.result.DEFAULT_METRIC` will be used per default.
  57. mode: One of {min, max}. Determines whether objective is
  58. minimizing or maximizing the metric attribute. Defaults to "max".
  59. points_to_evaluate: Initial parameter suggestions to be run
  60. first. This is for when you already have some good parameters
  61. you want to run first to help the algorithm make better suggestions
  62. for future parameters. Needs to be a list of dicts containing the
  63. configurations.
  64. parameter_constraints: Parameter constraints, such as
  65. "x3 >= x4" or "x3 + x4 >= 2".
  66. outcome_constraints: Outcome constraints of form
  67. "metric_name >= bound", like "m1 <= 3."
  68. ax_client: Optional AxClient instance. If this is set, do
  69. not pass any values to these parameters: `space`, `metric`,
  70. `parameter_constraints`, `outcome_constraints`.
  71. **ax_kwargs: Passed to AxClient instance. Ignored if `AxClient` is not
  72. None.
  73. Tune automatically converts search spaces to Ax's format:
  74. .. code-block:: python
  75. from ray import tune
  76. from ray.tune.search.ax import AxSearch
  77. config = {
  78. "x1": tune.uniform(0.0, 1.0),
  79. "x2": tune.uniform(0.0, 1.0)
  80. }
  81. def easy_objective(config):
  82. for i in range(100):
  83. intermediate_result = config["x1"] + config["x2"] * i
  84. tune.report({"score": intermediate_result})
  85. ax_search = AxSearch()
  86. tuner = tune.Tuner(
  87. easy_objective,
  88. tune_config=tune.TuneConfig(
  89. search_alg=ax_search,
  90. metric="score",
  91. mode="max",
  92. ),
  93. param_space=config,
  94. )
  95. tuner.fit()
  96. If you would like to pass the search space manually, the code would
  97. look like this:
  98. .. code-block:: python
  99. from ray import tune
  100. from ray.tune.search.ax import AxSearch
  101. parameters = [
  102. {"name": "x1", "type": "range", "bounds": [0.0, 1.0]},
  103. {"name": "x2", "type": "range", "bounds": [0.0, 1.0]},
  104. ]
  105. def easy_objective(config):
  106. for i in range(100):
  107. intermediate_result = config["x1"] + config["x2"] * i
  108. tune.report({"score": intermediate_result})
  109. ax_search = AxSearch(space=parameters, metric="score", mode="max")
  110. tuner = tune.Tuner(
  111. easy_objective,
  112. tune_config=tune.TuneConfig(
  113. search_alg=ax_search,
  114. ),
  115. )
  116. tuner.fit()
  117. """
  118. def __init__(
  119. self,
  120. space: Optional[Union[Dict, List[Dict]]] = None,
  121. metric: Optional[str] = None,
  122. mode: Optional[str] = None,
  123. points_to_evaluate: Optional[List[Dict]] = None,
  124. parameter_constraints: Optional[List] = None,
  125. outcome_constraints: Optional[List] = None,
  126. ax_client: Optional[AxClient] = None,
  127. **ax_kwargs,
  128. ):
  129. assert (
  130. ax is not None
  131. ), """Ax must be installed!
  132. You can install AxSearch with the command:
  133. `pip install ax-platform`."""
  134. if mode:
  135. assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
  136. super(AxSearch, self).__init__(
  137. metric=metric,
  138. mode=mode,
  139. )
  140. self._ax = ax_client
  141. self._ax_kwargs = ax_kwargs or {}
  142. if isinstance(space, dict) and space:
  143. resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
  144. if domain_vars or grid_vars:
  145. logger.warning(
  146. UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self))
  147. )
  148. space = self.convert_search_space(space)
  149. self._space = space
  150. self._parameter_constraints = parameter_constraints
  151. self._outcome_constraints = outcome_constraints
  152. self._points_to_evaluate = copy.deepcopy(points_to_evaluate)
  153. self._parameters = []
  154. self._live_trial_mapping = {}
  155. if self._ax or self._space:
  156. self._setup_experiment()
  157. def _setup_experiment(self):
  158. if self._metric is None and self._mode:
  159. # If only a mode was passed, use anonymous metric
  160. self._metric = DEFAULT_METRIC
  161. if not self._ax:
  162. self._ax = AxClient(**self._ax_kwargs)
  163. try:
  164. exp = self._ax.experiment
  165. has_experiment = True
  166. except ValueError:
  167. has_experiment = False
  168. if not has_experiment:
  169. if not self._space:
  170. raise ValueError(
  171. "You have to create an Ax experiment by calling "
  172. "`AxClient.create_experiment()`, or you should pass an "
  173. "Ax search space as the `space` parameter to `AxSearch`, "
  174. "or pass a `param_space` dict to `tune.Tuner()`."
  175. )
  176. if self._mode not in ["min", "max"]:
  177. raise ValueError(
  178. "Please specify the `mode` argument when initializing "
  179. "the `AxSearch` object or pass it to `tune.TuneConfig()`."
  180. )
  181. self._ax.create_experiment(
  182. parameters=self._space,
  183. objective_name=self._metric,
  184. parameter_constraints=self._parameter_constraints,
  185. outcome_constraints=self._outcome_constraints,
  186. minimize=self._mode != "max",
  187. )
  188. else:
  189. if any(
  190. [
  191. self._space,
  192. self._parameter_constraints,
  193. self._outcome_constraints,
  194. self._mode,
  195. self._metric,
  196. ]
  197. ):
  198. raise ValueError(
  199. "If you create the Ax experiment yourself, do not pass "
  200. "values for these parameters to `AxSearch`: {}.".format(
  201. [
  202. "space",
  203. "parameter_constraints",
  204. "outcome_constraints",
  205. "mode",
  206. "metric",
  207. ]
  208. )
  209. )
  210. exp = self._ax.experiment
  211. # Update mode and metric from experiment if it has been passed
  212. self._mode = "min" if exp.optimization_config.objective.minimize else "max"
  213. self._metric = exp.optimization_config.objective.metric.name
  214. self._parameters = list(exp.parameters)
  215. if self._ax._enforce_sequential_optimization:
  216. logger.warning(
  217. "Detected sequential enforcement. Be sure to use "
  218. "a ConcurrencyLimiter."
  219. )
  220. def set_search_properties(
  221. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  222. ):
  223. if self._ax:
  224. return False
  225. space = self.convert_search_space(config)
  226. self._space = space
  227. if metric:
  228. self._metric = metric
  229. if mode:
  230. self._mode = mode
  231. self._setup_experiment()
  232. return True
  233. def suggest(self, trial_id: str) -> Optional[Dict]:
  234. if not self._ax:
  235. raise RuntimeError(
  236. UNDEFINED_SEARCH_SPACE.format(
  237. cls=self.__class__.__name__, space="space"
  238. )
  239. )
  240. if not self._metric or not self._mode:
  241. raise RuntimeError(
  242. UNDEFINED_METRIC_MODE.format(
  243. cls=self.__class__.__name__, metric=self._metric, mode=self._mode
  244. )
  245. )
  246. if self._points_to_evaluate:
  247. config = self._points_to_evaluate.pop(0)
  248. parameters, trial_index = self._ax.attach_trial(config)
  249. else:
  250. try:
  251. parameters, trial_index = self._ax.get_next_trial()
  252. except (MaxParallelismReachedException, DataRequiredError):
  253. return None
  254. self._live_trial_mapping[trial_id] = trial_index
  255. try:
  256. suggested_config = unflatten_list_dict(parameters)
  257. except AssertionError:
  258. # Fails to unflatten if keys are out of order, which only happens
  259. # if search space includes a list with both constants and
  260. # tunable hyperparameters:
  261. # Ex: "a": [1, tune.uniform(2, 3), 4]
  262. suggested_config = unflatten_list_dict(
  263. {k: parameters[k] for k in sorted(parameters.keys())}
  264. )
  265. return suggested_config
  266. def on_trial_complete(self, trial_id, result=None, error=False):
  267. """Notification for the completion of trial.
  268. Data of form key value dictionary of metric names and values.
  269. """
  270. if result:
  271. self._process_result(trial_id, result)
  272. self._live_trial_mapping.pop(trial_id)
  273. def _process_result(self, trial_id, result):
  274. ax_trial_index = self._live_trial_mapping[trial_id]
  275. metrics_to_include = [self._metric] + [
  276. oc.metric.name
  277. for oc in self._ax.experiment.optimization_config.outcome_constraints
  278. ]
  279. metric_dict = {}
  280. for key in metrics_to_include:
  281. val = result[key]
  282. if np.isnan(val) or np.isinf(val):
  283. # Don't report trials with NaN metrics to Ax
  284. self._ax.abandon_trial(
  285. trial_index=ax_trial_index,
  286. reason=f"nan/inf metrics reported by {trial_id}",
  287. )
  288. return
  289. metric_dict[key] = (val, None)
  290. self._ax.complete_trial(trial_index=ax_trial_index, raw_data=metric_dict)
  291. @staticmethod
  292. def convert_search_space(spec: Dict):
  293. resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
  294. if grid_vars:
  295. raise ValueError(
  296. "Grid search parameters cannot be automatically converted "
  297. "to an Ax search space."
  298. )
  299. # Flatten and resolve again after checking for grid search.
  300. spec = flatten_dict(spec, prevent_delimiter=True)
  301. resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
  302. def resolve_value(par, domain):
  303. sampler = domain.get_sampler()
  304. if isinstance(sampler, Quantized):
  305. logger.warning(
  306. "AxSearch does not support quantization. Dropped quantization."
  307. )
  308. sampler = sampler.sampler
  309. if isinstance(domain, Float):
  310. if isinstance(sampler, LogUniform):
  311. return {
  312. "name": par,
  313. "type": "range",
  314. "bounds": [domain.lower, domain.upper],
  315. "value_type": "float",
  316. "log_scale": True,
  317. }
  318. elif isinstance(sampler, Uniform):
  319. return {
  320. "name": par,
  321. "type": "range",
  322. "bounds": [domain.lower, domain.upper],
  323. "value_type": "float",
  324. "log_scale": False,
  325. }
  326. elif isinstance(domain, Integer):
  327. if isinstance(sampler, LogUniform):
  328. return {
  329. "name": par,
  330. "type": "range",
  331. "bounds": [domain.lower, domain.upper - 1],
  332. "value_type": "int",
  333. "log_scale": True,
  334. }
  335. elif isinstance(sampler, Uniform):
  336. return {
  337. "name": par,
  338. "type": "range",
  339. "bounds": [domain.lower, domain.upper - 1],
  340. "value_type": "int",
  341. "log_scale": False,
  342. }
  343. elif isinstance(domain, Categorical):
  344. if isinstance(sampler, Uniform):
  345. return {"name": par, "type": "choice", "values": domain.categories}
  346. raise ValueError(
  347. "AxSearch does not support parameters of type "
  348. "`{}` with samplers of type `{}`".format(
  349. type(domain).__name__, type(domain.sampler).__name__
  350. )
  351. )
  352. # Parameter name is e.g. "a/b/c" for nested dicts,
  353. # "a/d/0", "a/d/1" for nested lists (using the index in the list)
  354. fixed_values = [
  355. {"name": "/".join(str(p) for p in path), "type": "fixed", "value": val}
  356. for path, val in resolved_vars
  357. ]
  358. resolved_values = [
  359. resolve_value("/".join(str(p) for p in path), domain)
  360. for path, domain in domain_vars
  361. ]
  362. return fixed_values + resolved_values
  363. def save(self, checkpoint_path: str):
  364. save_object = self.__dict__
  365. with open(checkpoint_path, "wb") as outputFile:
  366. cloudpickle.dump(save_object, outputFile)
  367. def restore(self, checkpoint_path: str):
  368. with open(checkpoint_path, "rb") as inputFile:
  369. save_object = cloudpickle.load(inputFile)
  370. self.__dict__.update(save_object)