zoopt_search.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379
  1. import copy
  2. import logging
  3. from typing import Dict, List, Optional, Tuple
  4. import ray
  5. import ray.cloudpickle as pickle
  6. from ray.tune.result import DEFAULT_METRIC
  7. from ray.tune.search import (
  8. UNDEFINED_METRIC_MODE,
  9. UNDEFINED_SEARCH_SPACE,
  10. UNRESOLVED_SEARCH_SPACE,
  11. Searcher,
  12. )
  13. from ray.tune.search.sample import (
  14. Categorical,
  15. Domain,
  16. Float,
  17. Integer,
  18. Quantized,
  19. Uniform,
  20. )
  21. from ray.tune.search.variant_generator import parse_spec_vars
  22. from ray.tune.utils.util import unflatten_dict
  23. try:
  24. import zoopt
  25. from zoopt import Solution, ValueType
  26. except ImportError:
  27. zoopt = None
  28. Solution = ValueType = None
  29. logger = logging.getLogger(__name__)
  30. class ZOOptSearch(Searcher):
  31. """A wrapper around ZOOpt to provide trial suggestions.
  32. ZOOptSearch is a library for derivative-free optimization. It is backed by
  33. the `ZOOpt <https://github.com/polixir/ZOOpt>`__ package. Currently,
  34. Asynchronous Sequential RAndomized COordinate Shrinking (ASRacos)
  35. is implemented in Tune.
  36. To use ZOOptSearch, install zoopt (>=0.4.1): ``pip install -U zoopt``.
  37. Tune automatically converts search spaces to ZOOpt"s format:
  38. .. code-block:: python
  39. from ray import tune
  40. from ray.tune.search.zoopt import ZOOptSearch
  41. "config": {
  42. "iterations": 10, # evaluation times
  43. "width": tune.uniform(-10, 10),
  44. "height": tune.uniform(-10, 10)
  45. }
  46. zoopt_search_config = {
  47. "parallel_num": 8, # how many workers to parallel
  48. }
  49. zoopt_search = ZOOptSearch(
  50. algo="Asracos", # only support Asracos currently
  51. budget=20, # must match `num_samples` in `tune.TuneConfig()`.
  52. dim_dict=dim_dict,
  53. metric="mean_loss",
  54. mode="min",
  55. **zoopt_search_config
  56. )
  57. tuner = tune.Tuner(
  58. my_objective,
  59. tune_config=tune.TuneConfig(
  60. search_alg=zoopt_search,
  61. num_samples=20
  62. ),
  63. run_config=tune.RunConfig(
  64. name="zoopt_search",
  65. stop={"timesteps_total": 10}
  66. ),
  67. param_space=config
  68. )
  69. tuner.fit()
  70. If you would like to pass the search space manually, the code would
  71. look like this:
  72. .. code-block:: python
  73. from ray import tune
  74. from ray.tune.search.zoopt import ZOOptSearch
  75. from zoopt import ValueType
  76. dim_dict = {
  77. "height": (ValueType.CONTINUOUS, [-10, 10], 1e-2),
  78. "width": (ValueType.DISCRETE, [-10, 10], False),
  79. "layers": (ValueType.GRID, [4, 8, 16])
  80. }
  81. "config": {
  82. "iterations": 10, # evaluation times
  83. }
  84. zoopt_search_config = {
  85. "parallel_num": 8, # how many workers to parallel
  86. }
  87. zoopt_search = ZOOptSearch(
  88. algo="Asracos", # only support Asracos currently
  89. budget=20, # must match `num_samples` in `tune.TuneConfig()`.
  90. dim_dict=dim_dict,
  91. metric="mean_loss",
  92. mode="min",
  93. **zoopt_search_config
  94. )
  95. tuner = tune.Tuner(
  96. my_objective,
  97. tune_config=tune.TuneConfig(
  98. search_alg=zoopt_search,
  99. num_samples=20
  100. ),
  101. run_config=tune.RunConfig(
  102. name="zoopt_search",
  103. stop={"timesteps_total": 10}
  104. ),
  105. )
  106. tuner.fit()
  107. Parameters:
  108. algo: To specify an algorithm in zoopt you want to use.
  109. Only support ASRacos currently.
  110. budget: Number of samples.
  111. dim_dict: Dimension dictionary.
  112. For continuous dimensions: (continuous, search_range, precision);
  113. For discrete dimensions: (discrete, search_range, has_order);
  114. For grid dimensions: (grid, grid_list).
  115. More details can be found in zoopt package.
  116. metric: The training result objective value attribute. If None
  117. but a mode was passed, the anonymous metric `_metric` will be used
  118. per default.
  119. mode: One of {min, max}. Determines whether objective is
  120. minimizing or maximizing the metric attribute.
  121. points_to_evaluate: Initial parameter suggestions to be run
  122. first. This is for when you already have some good parameters
  123. you want to run first to help the algorithm make better suggestions
  124. for future parameters. Needs to be a list of dicts containing the
  125. configurations.
  126. parallel_num: How many workers to parallel. Note that initial
  127. phase may start less workers than this number. More details can
  128. be found in zoopt package.
  129. """
  130. optimizer = None
  131. def __init__(
  132. self,
  133. algo: str = "asracos",
  134. budget: Optional[int] = None,
  135. dim_dict: Optional[Dict] = None,
  136. metric: Optional[str] = None,
  137. mode: Optional[str] = None,
  138. points_to_evaluate: Optional[List[Dict]] = None,
  139. parallel_num: int = 1,
  140. **kwargs
  141. ):
  142. assert (
  143. zoopt is not None
  144. ), "ZOOpt not found - please install zoopt by `pip install -U zoopt`."
  145. assert budget is not None, "`budget` should not be None!"
  146. if mode:
  147. assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
  148. _algo = algo.lower()
  149. assert _algo in [
  150. "asracos",
  151. "sracos",
  152. ], "`algo` must be in ['asracos', 'sracos'] currently"
  153. self._algo = _algo
  154. if isinstance(dim_dict, dict) and dim_dict:
  155. resolved_vars, domain_vars, grid_vars = parse_spec_vars(dim_dict)
  156. if domain_vars or grid_vars:
  157. logger.warning(
  158. UNRESOLVED_SEARCH_SPACE.format(par="dim_dict", cls=type(self))
  159. )
  160. dim_dict = self.convert_search_space(dim_dict, join=True)
  161. self._dim_dict = dim_dict
  162. self._budget = budget
  163. self._metric = metric
  164. if mode == "max":
  165. self._metric_op = -1.0
  166. elif mode == "min":
  167. self._metric_op = 1.0
  168. self._points_to_evaluate = copy.deepcopy(points_to_evaluate)
  169. self._live_trial_mapping = {}
  170. self._dim_keys = []
  171. self.solution_dict = {}
  172. self.best_solution_list = []
  173. self.optimizer = None
  174. self.kwargs = kwargs
  175. self.parallel_num = parallel_num
  176. super(ZOOptSearch, self).__init__(metric=self._metric, mode=mode)
  177. if self._dim_dict:
  178. self._setup_zoopt()
  179. def _setup_zoopt(self):
  180. if self._metric is None and self._mode:
  181. # If only a mode was passed, use anonymous metric
  182. self._metric = DEFAULT_METRIC
  183. _dim_list = []
  184. for k in self._dim_dict:
  185. self._dim_keys.append(k)
  186. _dim_list.append(self._dim_dict[k])
  187. init_samples = None
  188. if self._points_to_evaluate:
  189. logger.warning(
  190. "`points_to_evaluate` is ignored by ZOOpt in versions <= 0.4.1."
  191. )
  192. init_samples = [
  193. Solution(x=tuple(point[dim] for dim in self._dim_keys))
  194. for point in self._points_to_evaluate
  195. ]
  196. dim = zoopt.Dimension2(_dim_list)
  197. par = zoopt.Parameter(budget=self._budget, init_samples=init_samples)
  198. if self._algo == "sracos" or self._algo == "asracos":
  199. from zoopt.algos.opt_algorithms.racos.sracos import SRacosTune
  200. self.optimizer = SRacosTune(
  201. dimension=dim,
  202. parameter=par,
  203. parallel_num=self.parallel_num,
  204. **self.kwargs
  205. )
  206. def set_search_properties(
  207. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  208. ) -> bool:
  209. if self._dim_dict:
  210. return False
  211. space = self.convert_search_space(config)
  212. self._dim_dict = space
  213. if metric:
  214. self._metric = metric
  215. if mode:
  216. self._mode = mode
  217. if self._mode == "max":
  218. self._metric_op = -1.0
  219. elif self._mode == "min":
  220. self._metric_op = 1.0
  221. self._setup_zoopt()
  222. return True
  223. def suggest(self, trial_id: str) -> Optional[Dict]:
  224. if not self._dim_dict or not self.optimizer:
  225. raise RuntimeError(
  226. UNDEFINED_SEARCH_SPACE.format(
  227. cls=self.__class__.__name__, space="dim_dict"
  228. )
  229. )
  230. if not self._metric or not self._mode:
  231. raise RuntimeError(
  232. UNDEFINED_METRIC_MODE.format(
  233. cls=self.__class__.__name__, metric=self._metric, mode=self._mode
  234. )
  235. )
  236. _solution = self.optimizer.suggest()
  237. if _solution == "FINISHED":
  238. if ray.__version__ >= "0.8.7":
  239. return Searcher.FINISHED
  240. else:
  241. return None
  242. if _solution:
  243. self.solution_dict[str(trial_id)] = _solution
  244. _x = _solution.get_x()
  245. new_trial = dict(zip(self._dim_keys, _x))
  246. self._live_trial_mapping[trial_id] = new_trial
  247. return unflatten_dict(new_trial)
  248. def on_trial_complete(
  249. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  250. ):
  251. """Notification for the completion of trial."""
  252. if result:
  253. _solution = self.solution_dict[str(trial_id)]
  254. _best_solution_so_far = self.optimizer.complete(
  255. _solution, self._metric_op * result[self._metric]
  256. )
  257. if _best_solution_so_far:
  258. self.best_solution_list.append(_best_solution_so_far)
  259. del self._live_trial_mapping[trial_id]
  260. def save(self, checkpoint_path: str):
  261. save_object = self.__dict__
  262. with open(checkpoint_path, "wb") as outputFile:
  263. pickle.dump(save_object, outputFile)
  264. def restore(self, checkpoint_path: str):
  265. with open(checkpoint_path, "rb") as inputFile:
  266. save_object = pickle.load(inputFile)
  267. self.__dict__.update(save_object)
  268. @staticmethod
  269. def convert_search_space(spec: Dict, join: bool = False) -> Dict[str, Tuple]:
  270. spec = copy.deepcopy(spec)
  271. resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
  272. if not domain_vars and not grid_vars:
  273. return {}
  274. if grid_vars:
  275. raise ValueError(
  276. "Grid search parameters cannot be automatically converted "
  277. "to a ZOOpt search space."
  278. )
  279. def resolve_value(domain: Domain) -> Tuple:
  280. quantize = None
  281. sampler = domain.get_sampler()
  282. if isinstance(sampler, Quantized):
  283. quantize = sampler.q
  284. sampler = sampler.sampler
  285. if isinstance(domain, Float):
  286. precision = quantize or 1e-12
  287. if isinstance(sampler, Uniform):
  288. return (
  289. ValueType.CONTINUOUS,
  290. [domain.lower, domain.upper],
  291. precision,
  292. )
  293. elif isinstance(domain, Integer):
  294. if isinstance(sampler, Uniform):
  295. return (ValueType.DISCRETE, [domain.lower, domain.upper - 1], True)
  296. elif isinstance(domain, Categorical):
  297. # Categorical variables would use ValueType.DISCRETE with
  298. # has_partial_order=False, however, currently we do not
  299. # keep track of category values and cannot automatically
  300. # translate back and forth between them.
  301. if isinstance(sampler, Uniform):
  302. return (ValueType.GRID, domain.categories)
  303. raise ValueError(
  304. "ZOOpt does not support parameters of type "
  305. "`{}` with samplers of type `{}`".format(
  306. type(domain).__name__, type(domain.sampler).__name__
  307. )
  308. )
  309. conv_spec = {
  310. "/".join(path): resolve_value(domain) for path, domain in domain_vars
  311. }
  312. if join:
  313. spec.update(conv_spec)
  314. conv_spec = spec
  315. return conv_spec