optuna_search.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730
  1. import functools
  2. import logging
  3. import pickle
  4. import time
  5. import warnings
  6. from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  7. from packaging import version
  8. from ray.air.constants import TRAINING_ITERATION
  9. from ray.tune.result import DEFAULT_METRIC
  10. from ray.tune.search import (
  11. UNDEFINED_METRIC_MODE,
  12. UNDEFINED_SEARCH_SPACE,
  13. UNRESOLVED_SEARCH_SPACE,
  14. Searcher,
  15. )
  16. from ray.tune.search.sample import (
  17. Categorical,
  18. Domain,
  19. Float,
  20. Integer,
  21. LogUniform,
  22. Quantized,
  23. Uniform,
  24. )
  25. from ray.tune.search.variant_generator import parse_spec_vars
  26. from ray.tune.utils.util import flatten_dict, unflatten_dict, validate_warmstart
  27. try:
  28. import optuna as ot
  29. from optuna.distributions import BaseDistribution as OptunaDistribution
  30. from optuna.samplers import BaseSampler
  31. from optuna.storages import BaseStorage
  32. from optuna.trial import Trial as OptunaTrial, TrialState as OptunaTrialState
  33. except ImportError:
  34. ot = None
  35. OptunaDistribution = None
  36. BaseSampler = None
  37. BaseStorage = None
  38. OptunaTrialState = None
  39. OptunaTrial = None
  40. logger = logging.getLogger(__name__)
  41. # print a warning if define by run function takes longer than this to execute
  42. DEFINE_BY_RUN_WARN_THRESHOLD_S = 1 # 1 is arbitrary
  43. class _OptunaTrialSuggestCaptor:
  44. """Utility to capture returned values from Optuna's suggest_ methods.
  45. This will wrap around the ``optuna.Trial` object and decorate all
  46. `suggest_` callables with a function capturing the returned value,
  47. which will be saved in the ``captured_values`` dict.
  48. """
  49. def __init__(self, ot_trial: OptunaTrial) -> None:
  50. self.ot_trial = ot_trial
  51. self.captured_values: Dict[str, Any] = {}
  52. def _get_wrapper(self, func: Callable) -> Callable:
  53. @functools.wraps(func)
  54. def wrapper(*args, **kwargs):
  55. # name is always the first arg for suggest_ methods
  56. name = kwargs.get("name", args[0])
  57. ret = func(*args, **kwargs)
  58. self.captured_values[name] = ret
  59. return ret
  60. return wrapper
  61. def __getattr__(self, item_name: str) -> Any:
  62. item = getattr(self.ot_trial, item_name)
  63. if item_name.startswith("suggest_") and callable(item):
  64. return self._get_wrapper(item)
  65. return item
  66. class OptunaSearch(Searcher):
  67. """A wrapper around Optuna to provide trial suggestions.
  68. `Optuna <https://optuna.org/>`_ is a hyperparameter optimization library.
  69. In contrast to other libraries, it employs define-by-run style
  70. hyperparameter definitions.
  71. This Searcher is a thin wrapper around Optuna's search algorithms.
  72. You can pass any Optuna sampler, which will be used to generate
  73. hyperparameter suggestions.
  74. Multi-objective optimization is supported.
  75. Args:
  76. space: Hyperparameter search space definition for
  77. Optuna's sampler. This can be either a :class:`dict` with
  78. parameter names as keys and ``optuna.distributions`` as values,
  79. or a Callable - in which case, it should be a define-by-run
  80. function using ``optuna.trial`` to obtain the hyperparameter
  81. values. The function should return either a :class:`dict` of
  82. constant values with names as keys, or None.
  83. For more information, see https://optuna.readthedocs.io\
  84. /en/stable/tutorial/10_key_features/002_configurations.html.
  85. .. warning::
  86. No actual computation should take place in the define-by-run
  87. function. Instead, put the training logic inside the function
  88. or class trainable passed to ``tune.Tuner()``.
  89. metric: The training result objective value attribute. If
  90. None but a mode was passed, the anonymous metric ``_metric``
  91. will be used per default. Can be a list of metrics for
  92. multi-objective optimization.
  93. mode: One of {min, max}. Determines whether objective is
  94. minimizing or maximizing the metric attribute. Can be a list of
  95. modes for multi-objective optimization (corresponding to
  96. ``metric``).
  97. points_to_evaluate: Initial parameter suggestions to be run
  98. first. This is for when you already have some good parameters
  99. you want to run first to help the algorithm make better suggestions
  100. for future parameters. Needs to be a list of dicts containing the
  101. configurations.
  102. sampler: Optuna sampler used to
  103. draw hyperparameter configurations. Defaults to ``MOTPESampler``
  104. for multi-objective optimization with Optuna<2.9.0, and
  105. ``TPESampler`` in every other case.
  106. See https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
  107. for available Optuna samplers.
  108. .. warning::
  109. Please note that with Optuna 2.10.0 and earlier
  110. default ``MOTPESampler``/``TPESampler`` suffer
  111. from performance issues when dealing with a large number of
  112. completed trials (approx. >100). This will manifest as
  113. a delay when suggesting new configurations.
  114. This is an Optuna issue and may be fixed in a future
  115. Optuna release.
  116. study_name: Optuna study name that uniquely identifies the trial
  117. results. Defaults to ``"optuna"``.
  118. storage: Optuna storage used for storing trial results to
  119. storages other than in-memory storage,
  120. for instance optuna.storages.RDBStorage.
  121. seed: Seed to initialize sampler with. This parameter is only
  122. used when ``sampler=None``. In all other cases, the sampler
  123. you pass should be initialized with the seed already.
  124. evaluated_rewards: If you have previously evaluated the
  125. parameters passed in as points_to_evaluate you can avoid
  126. re-running those trials by passing in the reward attributes
  127. as a list so the optimiser can be told the results without
  128. needing to re-compute the trial. Must be the same length as
  129. points_to_evaluate.
  130. .. warning::
  131. When using ``evaluated_rewards``, the search space ``space``
  132. must be provided as a :class:`dict` with parameter names as
  133. keys and ``optuna.distributions`` instances as values. The
  134. define-by-run search space definition is not yet supported with
  135. this functionality.
  136. Tune automatically converts search spaces to Optuna's format:
  137. .. code-block:: python
  138. from ray.tune.search.optuna import OptunaSearch
  139. config = {
  140. "a": tune.uniform(6, 8)
  141. "b": tune.loguniform(1e-4, 1e-2)
  142. }
  143. optuna_search = OptunaSearch(
  144. metric="loss",
  145. mode="min")
  146. tuner = tune.Tuner(
  147. trainable,
  148. tune_config=tune.TuneConfig(
  149. search_alg=optuna_search,
  150. ),
  151. param_space=config,
  152. )
  153. tuner.fit()
  154. If you would like to pass the search space manually, the code would
  155. look like this:
  156. .. code-block:: python
  157. from ray.tune.search.optuna import OptunaSearch
  158. import optuna
  159. space = {
  160. "a": optuna.distributions.FloatDistribution(6, 8),
  161. "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
  162. }
  163. optuna_search = OptunaSearch(
  164. space,
  165. metric="loss",
  166. mode="min")
  167. tuner = tune.Tuner(
  168. trainable,
  169. tune_config=tune.TuneConfig(
  170. search_alg=optuna_search,
  171. ),
  172. )
  173. tuner.fit()
  174. # Equivalent Optuna define-by-run function approach:
  175. def define_search_space(trial: optuna.Trial):
  176. trial.suggest_float("a", 6, 8)
  177. trial.suggest_float("b", 1e-4, 1e-2, log=True)
  178. # training logic goes into trainable, this is just
  179. # for search space definition
  180. optuna_search = OptunaSearch(
  181. define_search_space,
  182. metric="loss",
  183. mode="min")
  184. tuner = tune.Tuner(
  185. trainable,
  186. tune_config=tune.TuneConfig(
  187. search_alg=optuna_search,
  188. ),
  189. )
  190. tuner.fit()
  191. Multi-objective optimization is supported:
  192. .. code-block:: python
  193. from ray.tune.search.optuna import OptunaSearch
  194. import optuna
  195. space = {
  196. "a": optuna.distributions.FloatDistribution(6, 8),
  197. "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
  198. }
  199. # Note you have to specify metric and mode here instead of
  200. # in tune.TuneConfig
  201. optuna_search = OptunaSearch(
  202. space,
  203. metric=["loss1", "loss2"],
  204. mode=["min", "max"])
  205. # Do not specify metric and mode here!
  206. tuner = tune.Tuner(
  207. trainable,
  208. tune_config=tune.TuneConfig(
  209. search_alg=optuna_search,
  210. ),
  211. )
  212. tuner.fit()
  213. You can pass configs that will be evaluated first using
  214. ``points_to_evaluate``:
  215. .. code-block:: python
  216. from ray.tune.search.optuna import OptunaSearch
  217. import optuna
  218. space = {
  219. "a": optuna.distributions.FloatDistribution(6, 8),
  220. "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
  221. }
  222. optuna_search = OptunaSearch(
  223. space,
  224. points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
  225. metric="loss",
  226. mode="min")
  227. tuner = tune.Tuner(
  228. trainable,
  229. tune_config=tune.TuneConfig(
  230. search_alg=optuna_search,
  231. ),
  232. )
  233. tuner.fit()
  234. Avoid re-running evaluated trials by passing the rewards together with
  235. `points_to_evaluate`:
  236. .. code-block:: python
  237. from ray.tune.search.optuna import OptunaSearch
  238. import optuna
  239. space = {
  240. "a": optuna.distributions.FloatDistribution(6, 8),
  241. "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
  242. }
  243. optuna_search = OptunaSearch(
  244. space,
  245. points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}]
  246. evaluated_rewards=[0.89, 0.42]
  247. metric="loss",
  248. mode="min")
  249. tuner = tune.Tuner(
  250. trainable,
  251. tune_config=tune.TuneConfig(
  252. search_alg=optuna_search,
  253. ),
  254. )
  255. tuner.fit()
  256. .. versionadded:: 0.8.8
  257. """
  258. def __init__(
  259. self,
  260. space: Optional[
  261. Union[
  262. Dict[str, "OptunaDistribution"],
  263. List[Tuple],
  264. Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
  265. ]
  266. ] = None,
  267. metric: Optional[Union[str, List[str]]] = None,
  268. mode: Optional[Union[str, List[str]]] = None,
  269. points_to_evaluate: Optional[List[Dict]] = None,
  270. sampler: Optional["BaseSampler"] = None,
  271. study_name: Optional[str] = None,
  272. storage: Optional["BaseStorage"] = None,
  273. seed: Optional[int] = None,
  274. evaluated_rewards: Optional[List] = None,
  275. ):
  276. assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
  277. super(OptunaSearch, self).__init__(metric=metric, mode=mode)
  278. if isinstance(space, dict) and space:
  279. resolved_vars, domain_vars, grid_vars = parse_spec_vars(space)
  280. if domain_vars or grid_vars:
  281. logger.warning(
  282. UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self).__name__)
  283. )
  284. space = self.convert_search_space(space)
  285. else:
  286. # Flatten to support nested dicts
  287. space = flatten_dict(space, "/")
  288. self._space = space
  289. self._points_to_evaluate = points_to_evaluate or []
  290. self._evaluated_rewards = evaluated_rewards
  291. if study_name:
  292. self._study_name = study_name
  293. else:
  294. self._study_name = "optuna" # Fixed study name for in-memory storage
  295. if sampler and seed:
  296. logger.warning(
  297. "You passed an initialized sampler to `OptunaSearch`. The "
  298. "`seed` parameter has to be passed to the sampler directly "
  299. "and will be ignored."
  300. )
  301. elif sampler:
  302. assert isinstance(sampler, BaseSampler), (
  303. "You can only pass an instance of "
  304. "`optuna.samplers.BaseSampler` "
  305. "as a sampler to `OptunaSearcher`."
  306. )
  307. self._sampler = sampler
  308. self._seed = seed
  309. if storage:
  310. assert isinstance(storage, BaseStorage), (
  311. "The `storage` parameter in `OptunaSearcher` must be an instance "
  312. "of `optuna.storages.BaseStorage`."
  313. )
  314. # If storage is not provided, just set self._storage to None
  315. # so that the default in-memory storage is used.
  316. self._storage = storage
  317. self._completed_trials = set()
  318. self._ot_trials = {}
  319. self._ot_study = None
  320. if self._space:
  321. self._setup_study(mode)
  322. def _setup_study(self, mode: Union[str, list]):
  323. if self._metric is None and self._mode:
  324. if isinstance(self._mode, list):
  325. raise ValueError(
  326. "If ``mode`` is a list (multi-objective optimization "
  327. "case), ``metric`` must be defined."
  328. )
  329. # If only a mode was passed, use anonymous metric
  330. self._metric = DEFAULT_METRIC
  331. pruner = ot.pruners.NopPruner()
  332. if self._sampler:
  333. sampler = self._sampler
  334. elif isinstance(mode, list) and version.parse(ot.__version__) < version.parse(
  335. "2.9.0"
  336. ):
  337. # MOTPESampler deprecated in Optuna>=2.9.0
  338. sampler = ot.samplers.MOTPESampler(seed=self._seed)
  339. else:
  340. sampler = ot.samplers.TPESampler(seed=self._seed)
  341. if isinstance(mode, list):
  342. study_direction_args = dict(
  343. directions=["minimize" if m == "min" else "maximize" for m in mode],
  344. )
  345. else:
  346. study_direction_args = dict(
  347. direction="minimize" if mode == "min" else "maximize",
  348. )
  349. self._ot_study = ot.study.create_study(
  350. storage=self._storage,
  351. sampler=sampler,
  352. pruner=pruner,
  353. study_name=self._study_name,
  354. load_if_exists=True,
  355. **study_direction_args,
  356. )
  357. if self._points_to_evaluate:
  358. validate_warmstart(
  359. self._space,
  360. self._points_to_evaluate,
  361. self._evaluated_rewards,
  362. validate_point_name_lengths=not callable(self._space),
  363. )
  364. if self._evaluated_rewards:
  365. for point, reward in zip(
  366. self._points_to_evaluate, self._evaluated_rewards
  367. ):
  368. self.add_evaluated_point(point, reward)
  369. else:
  370. for point in self._points_to_evaluate:
  371. self._ot_study.enqueue_trial(point)
  372. def set_search_properties(
  373. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  374. ) -> bool:
  375. if self._space:
  376. return False
  377. space = self.convert_search_space(config)
  378. self._space = space
  379. if metric:
  380. self._metric = metric
  381. if mode:
  382. self._mode = mode
  383. self._setup_study(self._mode)
  384. return True
  385. def _suggest_from_define_by_run_func(
  386. self,
  387. func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
  388. ot_trial: "OptunaTrial",
  389. ) -> Dict:
  390. captor = _OptunaTrialSuggestCaptor(ot_trial)
  391. time_start = time.time()
  392. ret = func(captor)
  393. time_taken = time.time() - time_start
  394. if time_taken > DEFINE_BY_RUN_WARN_THRESHOLD_S:
  395. warnings.warn(
  396. "Define-by-run function passed in the `space` argument "
  397. f"took {time_taken} seconds to "
  398. "run. Ensure that actual computation, training takes "
  399. "place inside Tune's train functions or Trainables "
  400. "passed to `tune.Tuner()`."
  401. )
  402. if ret is not None:
  403. if not isinstance(ret, dict):
  404. raise TypeError(
  405. "The return value of the define-by-run function "
  406. "passed in the `space` argument should be "
  407. "either None or a `dict` with `str` keys. "
  408. f"Got {type(ret)}."
  409. )
  410. if not all(isinstance(k, str) for k in ret.keys()):
  411. raise TypeError(
  412. "At least one of the keys in the dict returned by the "
  413. "define-by-run function passed in the `space` argument "
  414. "was not a `str`."
  415. )
  416. return {**captor.captured_values, **ret} if ret else captor.captured_values
  417. def suggest(self, trial_id: str) -> Optional[Dict]:
  418. if not self._space:
  419. raise RuntimeError(
  420. UNDEFINED_SEARCH_SPACE.format(
  421. cls=self.__class__.__name__, space="space"
  422. )
  423. )
  424. if not self._metric or not self._mode:
  425. raise RuntimeError(
  426. UNDEFINED_METRIC_MODE.format(
  427. cls=self.__class__.__name__, metric=self._metric, mode=self._mode
  428. )
  429. )
  430. if callable(self._space):
  431. # Define-by-run case
  432. if trial_id not in self._ot_trials:
  433. self._ot_trials[trial_id] = self._ot_study.ask()
  434. ot_trial = self._ot_trials[trial_id]
  435. params = self._suggest_from_define_by_run_func(self._space, ot_trial)
  436. else:
  437. # Use Optuna ask interface (since version 2.6.0)
  438. if trial_id not in self._ot_trials:
  439. self._ot_trials[trial_id] = self._ot_study.ask(
  440. fixed_distributions=self._space
  441. )
  442. ot_trial = self._ot_trials[trial_id]
  443. params = ot_trial.params
  444. return unflatten_dict(params)
  445. def on_trial_result(self, trial_id: str, result: Dict):
  446. if isinstance(self.metric, list):
  447. # Optuna doesn't support incremental results
  448. # for multi-objective optimization
  449. return
  450. if trial_id in self._completed_trials:
  451. logger.warning(
  452. f"Received additional result for trial {trial_id}, but "
  453. f"it already finished. Result: {result}"
  454. )
  455. return
  456. metric = result[self.metric]
  457. step = result[TRAINING_ITERATION]
  458. ot_trial = self._ot_trials[trial_id]
  459. ot_trial.report(metric, step)
  460. def on_trial_complete(
  461. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  462. ):
  463. if trial_id in self._completed_trials:
  464. logger.warning(
  465. f"Received additional completion for trial {trial_id}, but "
  466. f"it already finished. Result: {result}"
  467. )
  468. return
  469. ot_trial = self._ot_trials[trial_id]
  470. if result:
  471. if isinstance(self.metric, list):
  472. val = [result.get(metric, None) for metric in self.metric]
  473. else:
  474. val = result.get(self.metric, None)
  475. else:
  476. val = None
  477. ot_trial_state = OptunaTrialState.COMPLETE
  478. if val is None:
  479. if error:
  480. ot_trial_state = OptunaTrialState.FAIL
  481. else:
  482. ot_trial_state = OptunaTrialState.PRUNED
  483. try:
  484. self._ot_study.tell(ot_trial, val, state=ot_trial_state)
  485. except Exception as exc:
  486. logger.warning(exc) # E.g. if NaN was reported
  487. self._completed_trials.add(trial_id)
  488. def add_evaluated_point(
  489. self,
  490. parameters: Dict,
  491. value: float,
  492. error: bool = False,
  493. pruned: bool = False,
  494. intermediate_values: Optional[List[float]] = None,
  495. ):
  496. if not self._space:
  497. raise RuntimeError(
  498. UNDEFINED_SEARCH_SPACE.format(
  499. cls=self.__class__.__name__, space="space"
  500. )
  501. )
  502. if not self._metric or not self._mode:
  503. raise RuntimeError(
  504. UNDEFINED_METRIC_MODE.format(
  505. cls=self.__class__.__name__, metric=self._metric, mode=self._mode
  506. )
  507. )
  508. if callable(self._space):
  509. raise TypeError(
  510. "Define-by-run function passed in `space` argument is not "
  511. "yet supported when using `evaluated_rewards`. Please provide "
  512. "an `OptunaDistribution` dict or pass a Ray Tune "
  513. "search space to `tune.Tuner()`."
  514. )
  515. ot_trial_state = OptunaTrialState.COMPLETE
  516. if error:
  517. ot_trial_state = OptunaTrialState.FAIL
  518. elif pruned:
  519. ot_trial_state = OptunaTrialState.PRUNED
  520. if intermediate_values:
  521. intermediate_values_dict = dict(enumerate(intermediate_values))
  522. else:
  523. intermediate_values_dict = None
  524. # If the trial state is FAILED, the value must be `None` in Optuna==4.1.0
  525. # Reference: https://github.com/optuna/optuna/pull/5211
  526. # This is a temporary fix for the issue that Optuna enforces the value
  527. # to be `None` if the trial state is FAILED.
  528. # TODO (hpguo): A better solution may requires us to update the base class
  529. # to allow the `value` arg in `add_evaluated_point` being `Optional[float]`.
  530. if ot_trial_state == OptunaTrialState.FAIL:
  531. value = None
  532. trial = ot.trial.create_trial(
  533. state=ot_trial_state,
  534. value=value,
  535. params=parameters,
  536. distributions=self._space,
  537. intermediate_values=intermediate_values_dict,
  538. )
  539. self._ot_study.add_trial(trial)
  540. def save(self, checkpoint_path: str):
  541. save_object = self.__dict__.copy()
  542. with open(checkpoint_path, "wb") as outputFile:
  543. pickle.dump(save_object, outputFile)
  544. def restore(self, checkpoint_path: str):
  545. with open(checkpoint_path, "rb") as inputFile:
  546. save_object = pickle.load(inputFile)
  547. if isinstance(save_object, dict):
  548. self.__dict__.update(save_object)
  549. else:
  550. # Backwards compatibility
  551. (
  552. self._sampler,
  553. self._ot_trials,
  554. self._ot_study,
  555. self._points_to_evaluate,
  556. self._evaluated_rewards,
  557. ) = save_object
  558. @staticmethod
  559. def convert_search_space(spec: Dict) -> Dict[str, Any]:
  560. resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
  561. if not domain_vars and not grid_vars:
  562. return {}
  563. if grid_vars:
  564. raise ValueError(
  565. "Grid search parameters cannot be automatically converted "
  566. "to an Optuna search space."
  567. )
  568. # Flatten and resolve again after checking for grid search.
  569. spec = flatten_dict(spec, prevent_delimiter=True)
  570. resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
  571. def resolve_value(domain: Domain) -> ot.distributions.BaseDistribution:
  572. quantize = None
  573. sampler = domain.get_sampler()
  574. if isinstance(sampler, Quantized):
  575. quantize = sampler.q
  576. sampler = sampler.sampler
  577. if isinstance(sampler, LogUniform):
  578. logger.warning(
  579. "Optuna does not handle quantization in loguniform "
  580. "sampling. The parameter will be passed but it will "
  581. "probably be ignored."
  582. )
  583. if isinstance(domain, Float):
  584. if isinstance(sampler, LogUniform):
  585. if quantize:
  586. logger.warning(
  587. "Optuna does not support both quantization and "
  588. "sampling from LogUniform. Dropped quantization."
  589. )
  590. return ot.distributions.FloatDistribution(
  591. domain.lower, domain.upper, log=True
  592. )
  593. elif isinstance(sampler, Uniform):
  594. if quantize:
  595. return ot.distributions.FloatDistribution(
  596. domain.lower, domain.upper, step=quantize
  597. )
  598. return ot.distributions.FloatDistribution(
  599. domain.lower, domain.upper
  600. )
  601. elif isinstance(domain, Integer):
  602. if isinstance(sampler, LogUniform):
  603. return ot.distributions.IntDistribution(
  604. domain.lower, domain.upper - 1, step=quantize or 1, log=True
  605. )
  606. elif isinstance(sampler, Uniform):
  607. # Upper bound should be inclusive for quantization and
  608. # exclusive otherwise
  609. return ot.distributions.IntDistribution(
  610. domain.lower,
  611. domain.upper - int(bool(not quantize)),
  612. step=quantize or 1,
  613. )
  614. elif isinstance(domain, Categorical):
  615. if isinstance(sampler, Uniform):
  616. return ot.distributions.CategoricalDistribution(domain.categories)
  617. raise ValueError(
  618. "Optuna search does not support parameters of type "
  619. "`{}` with samplers of type `{}`".format(
  620. type(domain).__name__, type(domain.sampler).__name__
  621. )
  622. )
  623. # Parameter name is e.g. "a/b/c" for nested dicts
  624. values = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}
  625. return values