| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730 |
- import functools
- import logging
- import pickle
- import time
- import warnings
- from typing import Any, Callable, Dict, List, Optional, Tuple, Union
- from packaging import version
- from ray.air.constants import TRAINING_ITERATION
- from ray.tune.result import DEFAULT_METRIC
- from ray.tune.search import (
- UNDEFINED_METRIC_MODE,
- UNDEFINED_SEARCH_SPACE,
- UNRESOLVED_SEARCH_SPACE,
- Searcher,
- )
- from ray.tune.search.sample import (
- Categorical,
- Domain,
- Float,
- Integer,
- LogUniform,
- Quantized,
- Uniform,
- )
- from ray.tune.search.variant_generator import parse_spec_vars
- from ray.tune.utils.util import flatten_dict, unflatten_dict, validate_warmstart
- try:
- import optuna as ot
- from optuna.distributions import BaseDistribution as OptunaDistribution
- from optuna.samplers import BaseSampler
- from optuna.storages import BaseStorage
- from optuna.trial import Trial as OptunaTrial, TrialState as OptunaTrialState
- except ImportError:
- ot = None
- OptunaDistribution = None
- BaseSampler = None
- BaseStorage = None
- OptunaTrialState = None
- OptunaTrial = None
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Threshold in seconds: a warning is emitted when the define-by-run function
# takes longer than this to execute.
DEFINE_BY_RUN_WARN_THRESHOLD_S = 1 # 1 is arbitrary
- class _OptunaTrialSuggestCaptor:
- """Utility to capture returned values from Optuna's suggest_ methods.
- This will wrap around the ``optuna.Trial` object and decorate all
- `suggest_` callables with a function capturing the returned value,
- which will be saved in the ``captured_values`` dict.
- """
- def __init__(self, ot_trial: OptunaTrial) -> None:
- self.ot_trial = ot_trial
- self.captured_values: Dict[str, Any] = {}
- def _get_wrapper(self, func: Callable) -> Callable:
- @functools.wraps(func)
- def wrapper(*args, **kwargs):
- # name is always the first arg for suggest_ methods
- name = kwargs.get("name", args[0])
- ret = func(*args, **kwargs)
- self.captured_values[name] = ret
- return ret
- return wrapper
- def __getattr__(self, item_name: str) -> Any:
- item = getattr(self.ot_trial, item_name)
- if item_name.startswith("suggest_") and callable(item):
- return self._get_wrapper(item)
- return item
- class OptunaSearch(Searcher):
- """A wrapper around Optuna to provide trial suggestions.
- `Optuna <https://optuna.org/>`_ is a hyperparameter optimization library.
- In contrast to other libraries, it employs define-by-run style
- hyperparameter definitions.
- This Searcher is a thin wrapper around Optuna's search algorithms.
- You can pass any Optuna sampler, which will be used to generate
- hyperparameter suggestions.
- Multi-objective optimization is supported.
- Args:
- space: Hyperparameter search space definition for
- Optuna's sampler. This can be either a :class:`dict` with
- parameter names as keys and ``optuna.distributions`` as values,
- or a Callable - in which case, it should be a define-by-run
- function using ``optuna.trial`` to obtain the hyperparameter
- values. The function should return either a :class:`dict` of
- constant values with names as keys, or None.
- For more information, see https://optuna.readthedocs.io\
- /en/stable/tutorial/10_key_features/002_configurations.html.
- .. warning::
- No actual computation should take place in the define-by-run
- function. Instead, put the training logic inside the function
- or class trainable passed to ``tune.Tuner()``.
- metric: The training result objective value attribute. If
- None but a mode was passed, the anonymous metric ``_metric``
- will be used per default. Can be a list of metrics for
- multi-objective optimization.
- mode: One of {min, max}. Determines whether objective is
- minimizing or maximizing the metric attribute. Can be a list of
- modes for multi-objective optimization (corresponding to
- ``metric``).
- points_to_evaluate: Initial parameter suggestions to be run
- first. This is for when you already have some good parameters
- you want to run first to help the algorithm make better suggestions
- for future parameters. Needs to be a list of dicts containing the
- configurations.
- sampler: Optuna sampler used to
- draw hyperparameter configurations. Defaults to ``MOTPESampler``
- for multi-objective optimization with Optuna<2.9.0, and
- ``TPESampler`` in every other case.
- See https://optuna.readthedocs.io/en/stable/reference/samplers/index.html
- for available Optuna samplers.
- .. warning::
- Please note that with Optuna 2.10.0 and earlier
- default ``MOTPESampler``/``TPESampler`` suffer
- from performance issues when dealing with a large number of
- completed trials (approx. >100). This will manifest as
- a delay when suggesting new configurations.
- This is an Optuna issue and may be fixed in a future
- Optuna release.
- study_name: Optuna study name that uniquely identifies the trial
- results. Defaults to ``"optuna"``.
- storage: Optuna storage used for storing trial results to
- storages other than in-memory storage,
- for instance optuna.storages.RDBStorage.
- seed: Seed to initialize sampler with. This parameter is only
- used when ``sampler=None``. In all other cases, the sampler
- you pass should be initialized with the seed already.
- evaluated_rewards: If you have previously evaluated the
- parameters passed in as points_to_evaluate you can avoid
- re-running those trials by passing in the reward attributes
- as a list so the optimiser can be told the results without
- needing to re-compute the trial. Must be the same length as
- points_to_evaluate.
- .. warning::
- When using ``evaluated_rewards``, the search space ``space``
- must be provided as a :class:`dict` with parameter names as
- keys and ``optuna.distributions`` instances as values. The
- define-by-run search space definition is not yet supported with
- this functionality.
- Tune automatically converts search spaces to Optuna's format:
- .. code-block:: python
- from ray.tune.search.optuna import OptunaSearch
- config = {
- "a": tune.uniform(6, 8),
- "b": tune.loguniform(1e-4, 1e-2),
- }
- optuna_search = OptunaSearch(
- metric="loss",
- mode="min")
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=optuna_search,
- ),
- param_space=config,
- )
- tuner.fit()
- If you would like to pass the search space manually, the code would
- look like this:
- .. code-block:: python
- from ray.tune.search.optuna import OptunaSearch
- import optuna
- space = {
- "a": optuna.distributions.FloatDistribution(6, 8),
- "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
- }
- optuna_search = OptunaSearch(
- space,
- metric="loss",
- mode="min")
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=optuna_search,
- ),
- )
- tuner.fit()
- # Equivalent Optuna define-by-run function approach:
- def define_search_space(trial: optuna.Trial):
- trial.suggest_float("a", 6, 8)
- trial.suggest_float("b", 1e-4, 1e-2, log=True)
- # training logic goes into trainable, this is just
- # for search space definition
- optuna_search = OptunaSearch(
- define_search_space,
- metric="loss",
- mode="min")
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=optuna_search,
- ),
- )
- tuner.fit()
- Multi-objective optimization is supported:
- .. code-block:: python
- from ray.tune.search.optuna import OptunaSearch
- import optuna
- space = {
- "a": optuna.distributions.FloatDistribution(6, 8),
- "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
- }
- # Note you have to specify metric and mode here instead of
- # in tune.TuneConfig
- optuna_search = OptunaSearch(
- space,
- metric=["loss1", "loss2"],
- mode=["min", "max"])
- # Do not specify metric and mode here!
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=optuna_search,
- ),
- )
- tuner.fit()
- You can pass configs that will be evaluated first using
- ``points_to_evaluate``:
- .. code-block:: python
- from ray.tune.search.optuna import OptunaSearch
- import optuna
- space = {
- "a": optuna.distributions.FloatDistribution(6, 8),
- "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
- }
- optuna_search = OptunaSearch(
- space,
- points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}],
- metric="loss",
- mode="min")
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=optuna_search,
- ),
- )
- tuner.fit()
- Avoid re-running evaluated trials by passing the rewards together with
- ``points_to_evaluate``:
- .. code-block:: python
- from ray.tune.search.optuna import OptunaSearch
- import optuna
- space = {
- "a": optuna.distributions.FloatDistribution(6, 8),
- "b": optuna.distributions.FloatDistribution(1e-4, 1e-2, log=True),
- }
- optuna_search = OptunaSearch(
- space,
- points_to_evaluate=[{"a": 6.5, "b": 5e-4}, {"a": 7.5, "b": 1e-3}],
- evaluated_rewards=[0.89, 0.42],
- metric="loss",
- mode="min")
- tuner = tune.Tuner(
- trainable,
- tune_config=tune.TuneConfig(
- search_alg=optuna_search,
- ),
- )
- tuner.fit()
- .. versionadded:: 0.8.8
- """
def __init__(
    self,
    space: Optional[
        Union[
            Dict[str, "OptunaDistribution"],
            List[Tuple],
            Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
        ]
    ] = None,
    metric: Optional[Union[str, List[str]]] = None,
    mode: Optional[Union[str, List[str]]] = None,
    points_to_evaluate: Optional[List[Dict]] = None,
    sampler: Optional["BaseSampler"] = None,
    study_name: Optional[str] = None,
    storage: Optional["BaseStorage"] = None,
    seed: Optional[int] = None,
    evaluated_rewards: Optional[List] = None,
):
    """Initialize the searcher and, when a search space is given, the study.

    The arguments are described in the class docstring. A non-empty ``dict``
    space that contains Ray Tune domains is converted to Optuna
    distributions (a warning is logged); any other non-empty ``dict`` is
    flattened with ``"/"`` as delimiter so nested dicts are supported. Any
    other kind of ``space`` (e.g. a define-by-run callable) is stored as is.

    Raises:
        AssertionError: If Optuna is not installed, or if ``sampler`` or
            ``storage`` is passed but is not an instance of
            ``optuna.samplers.BaseSampler`` or
            ``optuna.storages.BaseStorage`` respectively.
    """
    assert ot is not None, "Optuna must be installed! Run `pip install optuna`."
    super(OptunaSearch, self).__init__(metric=metric, mode=mode)

    if isinstance(space, dict) and space:
        _, domain_vars, grid_vars = parse_spec_vars(space)
        if domain_vars or grid_vars:
            logger.warning(
                UNRESOLVED_SEARCH_SPACE.format(par="space", cls=type(self).__name__)
            )
            space = self.convert_search_space(space)
        else:
            # Flatten to support nested dicts
            space = flatten_dict(space, "/")

    self._space = space
    self._points_to_evaluate = points_to_evaluate or []
    self._evaluated_rewards = evaluated_rewards
    # Fixed default study name for in-memory storage
    self._study_name = study_name or "optuna"

    if sampler:
        # Validate the sampler unconditionally; previously this check was
        # skipped whenever a seed was passed alongside the sampler.
        assert isinstance(sampler, BaseSampler), (
            "You can only pass an instance of "
            "`optuna.samplers.BaseSampler` "
            "as a sampler to `OptunaSearch`."
        )
        # ``seed=0`` is a legitimate seed, so test against ``None`` rather
        # than truthiness when deciding whether a seed is being ignored.
        if seed is not None:
            logger.warning(
                "You passed an initialized sampler to `OptunaSearch`. The "
                "`seed` parameter has to be passed to the sampler directly "
                "and will be ignored."
            )
    self._sampler = sampler
    self._seed = seed

    if storage:
        assert isinstance(storage, BaseStorage), (
            "The `storage` parameter in `OptunaSearch` must be an instance "
            "of `optuna.storages.BaseStorage`."
        )
    # If storage is not provided, just set self._storage to None
    # so that the default in-memory storage is used.
    self._storage = storage

    self._completed_trials = set()
    self._ot_trials = {}
    self._ot_study = None

    # Without a search space the study is created later, once
    # ``set_search_properties`` supplies one.
    if self._space:
        self._setup_study(mode)
def _setup_study(self, mode: Union[str, list]):
    """Create the Optuna study and seed it with the warm-start points.

    Args:
        mode: ``"min"``/``"max"``, or a list of those for multi-objective
            optimization; determines the study direction(s).

    Raises:
        ValueError: If no metric is set while the stored mode is a list.
    """
    multi_objective = isinstance(mode, list)

    if self._metric is None and self._mode:
        if isinstance(self._mode, list):
            raise ValueError(
                "If ``mode`` is a list (multi-objective optimization "
                "case), ``metric`` must be defined."
            )
        # Only a mode was passed: fall back to the anonymous metric.
        self._metric = DEFAULT_METRIC

    pruner = ot.pruners.NopPruner()

    if self._sampler:
        sampler = self._sampler
    else:
        # MOTPESampler deprecated in Optuna>=2.9.0, so it is only picked for
        # multi-objective runs on older Optuna versions.
        use_motpe = multi_objective and (
            version.parse(ot.__version__) < version.parse("2.9.0")
        )
        if use_motpe:
            sampler = ot.samplers.MOTPESampler(seed=self._seed)
        else:
            sampler = ot.samplers.TPESampler(seed=self._seed)

    def to_direction(single_mode):
        return "minimize" if single_mode == "min" else "maximize"

    if multi_objective:
        study_direction_args = {"directions": [to_direction(m) for m in mode]}
    else:
        study_direction_args = {"direction": to_direction(mode)}

    self._ot_study = ot.study.create_study(
        storage=self._storage,
        sampler=sampler,
        pruner=pruner,
        study_name=self._study_name,
        load_if_exists=True,
        **study_direction_args,
    )

    warm_start_points = self._points_to_evaluate
    if warm_start_points:
        validate_warmstart(
            self._space,
            warm_start_points,
            self._evaluated_rewards,
            validate_point_name_lengths=not callable(self._space),
        )

    if self._evaluated_rewards:
        # Results are already known: register them as finished trials.
        for point, reward in zip(warm_start_points, self._evaluated_rewards):
            self.add_evaluated_point(point, reward)
    else:
        # No results yet: queue the points so they are suggested first.
        for point in warm_start_points:
            self._ot_study.enqueue_trial(point)
def set_search_properties(
    self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
) -> bool:
    """Configure search space, metric and mode when none was given up front.

    Args:
        metric: Metric to optimize; only applied if truthy.
        mode: Optimization mode; only applied if truthy.
        config: Search space that is converted via ``convert_search_space``.
        **spec: Accepted but not used here.

    Returns:
        ``False`` if a search space is already set (nothing is changed),
        ``True`` after the space was converted and the study was set up.
    """
    # A space that is already in place is never overridden.
    if self._space:
        return False

    self._space = self.convert_search_space(config)

    # Keep the existing metric/mode unless a truthy replacement is passed.
    for attribute, override in (("_metric", metric), ("_mode", mode)):
        if override:
            setattr(self, attribute, override)

    self._setup_study(self._mode)
    return True
def _suggest_from_define_by_run_func(
    self,
    func: Callable[["OptunaTrial"], Optional[Dict[str, Any]]],
    ot_trial: "OptunaTrial",
) -> Dict:
    """Run a define-by-run function and collect the parameters it produced.

    ``func`` is called with a captor wrapping ``ot_trial`` so that every
    value returned by a ``suggest_`` call is recorded. A warning is issued
    if the call takes longer than ``DEFINE_BY_RUN_WARN_THRESHOLD_S`` seconds.

    Args:
        func: The define-by-run function passed in the ``space`` argument.
        ot_trial: The Optuna trial handed (wrapped) to ``func``.

    Returns:
        The captured suggestions, updated with the dict returned by
        ``func`` (if it returned a non-empty one).

    Raises:
        TypeError: If ``func`` returns something other than ``None`` or a
            dict with only ``str`` keys.
    """
    captor = _OptunaTrialSuggestCaptor(ot_trial)

    # Measure the duration with a monotonic clock; the wall clock can be
    # adjusted while ``func`` runs and would then report a bogus duration.
    time_start = time.perf_counter()
    ret = func(captor)
    time_taken = time.perf_counter() - time_start
    if time_taken > DEFINE_BY_RUN_WARN_THRESHOLD_S:
        warnings.warn(
            "Define-by-run function passed in the `space` argument "
            f"took {time_taken} seconds to "
            "run. Ensure that actual computation, training takes "
            "place inside Tune's train functions or Trainables "
            "passed to `tune.Tuner()`."
        )

    if ret is None:
        return captor.captured_values

    if not isinstance(ret, dict):
        raise TypeError(
            "The return value of the define-by-run function "
            "passed in the `space` argument should be "
            "either None or a `dict` with `str` keys. "
            f"Got {type(ret)}."
        )
    if not all(isinstance(k, str) for k in ret):
        raise TypeError(
            "At least one of the keys in the dict returned by the "
            "define-by-run function passed in the `space` argument "
            "was not a `str`."
        )
    # Constants returned by the function take precedence over captured
    # suggestions with the same name.
    return {**captor.captured_values, **ret} if ret else captor.captured_values
def suggest(self, trial_id: str) -> Optional[Dict]:
    """Return the parameters for ``trial_id``, asking Optuna when needed.

    A new Optuna trial is requested the first time a ``trial_id`` is seen;
    later calls with the same id reuse the stored trial.

    Raises:
        RuntimeError: If the search space, the metric or the mode is not
            set.
    """
    searcher_name = self.__class__.__name__
    if not self._space:
        raise RuntimeError(
            UNDEFINED_SEARCH_SPACE.format(cls=searcher_name, space="space")
        )
    if not (self._metric and self._mode):
        raise RuntimeError(
            UNDEFINED_METRIC_MODE.format(
                cls=searcher_name, metric=self._metric, mode=self._mode
            )
        )

    define_by_run = callable(self._space)

    if trial_id not in self._ot_trials:
        if define_by_run:
            new_trial = self._ot_study.ask()
        else:
            # Use Optuna ask interface (since version 2.6.0)
            new_trial = self._ot_study.ask(fixed_distributions=self._space)
        self._ot_trials[trial_id] = new_trial
    ot_trial = self._ot_trials[trial_id]

    if define_by_run:
        params = self._suggest_from_define_by_run_func(self._space, ot_trial)
    else:
        params = ot_trial.params

    return unflatten_dict(params)
def on_trial_result(self, trial_id: str, result: Dict):
    """Report an intermediate result of ``trial_id`` to its Optuna trial.

    Nothing is reported when the metric is a list (multi-objective case) or
    when the trial has already been completed.
    """
    objective = self.metric
    if isinstance(objective, list):
        # Optuna doesn't support incremental results
        # for multi-objective optimization
        return

    if trial_id in self._completed_trials:
        logger.warning(
            f"Received additional result for trial {trial_id}, but "
            f"it already finished. Result: {result}"
        )
        return

    reported_value, iteration = result[objective], result[TRAINING_ITERATION]
    self._ot_trials[trial_id].report(reported_value, iteration)
def on_trial_complete(
    self, trial_id: str, result: Optional[Dict] = None, error: bool = False
):
    """Tell Optuna the final outcome of ``trial_id``.

    Args:
        trial_id: Id of the finished trial; must have been suggested before.
        result: Final result dict. The objective value(s) are looked up
            under the configured metric name(s).
        error: Whether the trial errored. Only consulted when no objective
            value is available.

    The trial is told as ``COMPLETE`` when a value is available; otherwise
    as ``FAIL`` if ``error`` is set and as ``PRUNED`` if not. Exceptions
    raised by ``study.tell`` are logged as warnings, and the trial is marked
    as completed on the Tune side in either case.
    """
    if trial_id in self._completed_trials:
        logger.warning(
            f"Received additional completion for trial {trial_id}, but "
            f"it already finished. Result: {result}"
        )
        return

    ot_trial = self._ot_trials[trial_id]

    val = None
    if result:
        if isinstance(self.metric, list):
            val = [result.get(metric, None) for metric in self.metric]
            # A list is never ``None``, so without this check a result that
            # misses some objectives would be told as COMPLETE with ``None``
            # entries instead of taking the FAIL/PRUNED path below.
            if any(objective is None for objective in val):
                val = None
        else:
            val = result.get(self.metric, None)

    ot_trial_state = OptunaTrialState.COMPLETE
    if val is None:
        if error:
            ot_trial_state = OptunaTrialState.FAIL
        else:
            ot_trial_state = OptunaTrialState.PRUNED

    try:
        self._ot_study.tell(ot_trial, val, state=ot_trial_state)
    except Exception as exc:
        # Deliberately best-effort: a rejected value must not crash tuning.
        logger.warning(exc)  # E.g. if NaN was reported

    self._completed_trials.add(trial_id)
def add_evaluated_point(
    self,
    parameters: Dict,
    value: Union[float, List[float]],
    error: bool = False,
    pruned: bool = False,
    intermediate_values: Optional[List[float]] = None,
):
    """Add an already evaluated configuration to the Optuna study.

    Args:
        parameters: The evaluated parameter configuration.
        value: The objective value. A list or tuple is passed on as the
            ``values`` of a multi-objective trial.
        error: Whether the evaluation failed; takes precedence over
            ``pruned``.
        pruned: Whether the evaluation was pruned.
        intermediate_values: Optional intermediate objective values; they
            are keyed by their position in the list.

    Raises:
        RuntimeError: If the search space, the metric or the mode is not
            set.
        TypeError: If the search space is a define-by-run function.
    """
    if not self._space:
        raise RuntimeError(
            UNDEFINED_SEARCH_SPACE.format(
                cls=self.__class__.__name__, space="space"
            )
        )
    if not self._metric or not self._mode:
        raise RuntimeError(
            UNDEFINED_METRIC_MODE.format(
                cls=self.__class__.__name__, metric=self._metric, mode=self._mode
            )
        )
    if callable(self._space):
        raise TypeError(
            "Define-by-run function passed in `space` argument is not "
            "yet supported when using `evaluated_rewards`. Please provide "
            "an `OptunaDistribution` dict or pass a Ray Tune "
            "search space to `tune.Tuner()`."
        )

    ot_trial_state = OptunaTrialState.COMPLETE
    if error:
        ot_trial_state = OptunaTrialState.FAIL
    elif pruned:
        ot_trial_state = OptunaTrialState.PRUNED

    if intermediate_values:
        intermediate_values_dict = dict(enumerate(intermediate_values))
    else:
        intermediate_values_dict = None

    # If the trial state is FAILED, the value must be `None` in Optuna==4.1.0
    # Reference: https://github.com/optuna/optuna/pull/5211
    # This is a temporary fix for the issue that Optuna enforces the value
    # to be `None` if the trial state is FAILED.
    # TODO (hpguo): A better solution may requires us to update the base class
    # to allow the `value` arg in `add_evaluated_point` being `Optional[float]`.
    if ot_trial_state == OptunaTrialState.FAIL:
        value = None

    # ``create_trial`` takes a single objective via ``value`` and several
    # objectives via ``values``; route sequences to the latter so that
    # multi-objective rewards are not passed as one scalar value.
    if isinstance(value, (list, tuple)):
        objective_kwargs = {"values": list(value)}
    else:
        objective_kwargs = {"value": value}

    trial = ot.trial.create_trial(
        state=ot_trial_state,
        params=parameters,
        distributions=self._space,
        intermediate_values=intermediate_values_dict,
        **objective_kwargs,
    )
    self._ot_study.add_trial(trial)
def save(self, checkpoint_path: str):
    """Pickle a shallow copy of this object's attributes to ``checkpoint_path``."""
    state = dict(self.__dict__)
    with open(checkpoint_path, "wb") as checkpoint_file:
        pickle.dump(state, checkpoint_file)
def restore(self, checkpoint_path: str):
    """Load the state pickled at ``checkpoint_path`` into this object.

    A pickled ``dict`` is merged into the instance attributes. Any other
    payload is treated as the older 5-tuple layout ``(sampler, ot_trials,
    ot_study, points_to_evaluate, evaluated_rewards)``.

    NOTE: the file is read with ``pickle.load``, which can execute arbitrary
    code; only restore checkpoints from a trusted source.
    """
    with open(checkpoint_path, "rb") as checkpoint_file:
        state = pickle.load(checkpoint_file)

    if not isinstance(state, dict):
        # Backwards compatibility
        sampler, ot_trials, ot_study, points, rewards = state
        self._sampler = sampler
        self._ot_trials = ot_trials
        self._ot_study = ot_study
        self._points_to_evaluate = points
        self._evaluated_rewards = rewards
        return

    self.__dict__.update(state)
@staticmethod
def convert_search_space(spec: Dict) -> Dict[str, Any]:
    """Convert a Ray Tune search space into a dict of Optuna distributions.

    Args:
        spec: Ray Tune search space, possibly nested.

    Returns:
        A flat dict mapping parameter names (nested keys joined with
        ``"/"``) to Optuna distributions, or an empty dict if ``spec``
        contains neither domain nor grid search variables.

    Raises:
        ValueError: If ``spec`` contains grid search parameters, or a domain
            and sampler combination that has no conversion below.
    """
    _, domain_vars, grid_vars = parse_spec_vars(spec)

    if not domain_vars and not grid_vars:
        return {}

    if grid_vars:
        raise ValueError(
            "Grid search parameters cannot be automatically converted "
            "to an Optuna search space."
        )

    # Flatten and resolve again after checking for grid search.
    spec = flatten_dict(spec, prevent_delimiter=True)
    _, domain_vars, _ = parse_spec_vars(spec)

    def warn_dropped_quantization() -> None:
        logger.warning(
            "Optuna does not support both quantization and "
            "sampling from LogUniform. Dropped quantization."
        )

    def resolve_value(domain: Domain) -> "OptunaDistribution":
        """Map one Tune domain to the matching Optuna distribution."""
        quantize = None

        sampler = domain.get_sampler()
        if isinstance(sampler, Quantized):
            # Unwrap the quantizer and remember its increment.
            quantize = sampler.q
            sampler = sampler.sampler

        if isinstance(domain, Float):
            if isinstance(sampler, LogUniform):
                if quantize:
                    warn_dropped_quantization()
                return ot.distributions.FloatDistribution(
                    domain.lower, domain.upper, log=True
                )

            elif isinstance(sampler, Uniform):
                if quantize:
                    return ot.distributions.FloatDistribution(
                        domain.lower, domain.upper, step=quantize
                    )
                return ot.distributions.FloatDistribution(
                    domain.lower, domain.upper
                )

        elif isinstance(domain, Integer):
            if isinstance(sampler, LogUniform):
                # Optuna's ``IntDistribution`` is documented to accept only
                # ``step=1`` together with ``log=True``, so the quantization
                # increment cannot be forwarded and is dropped instead.
                if quantize:
                    warn_dropped_quantization()
                # Same bound handling as the uniform branch below: inclusive
                # upper bound for quantized domains, exclusive otherwise.
                return ot.distributions.IntDistribution(
                    domain.lower,
                    domain.upper - int(bool(not quantize)),
                    step=1,
                    log=True,
                )
            elif isinstance(sampler, Uniform):
                # Upper bound should be inclusive for quantization and
                # exclusive otherwise
                return ot.distributions.IntDistribution(
                    domain.lower,
                    domain.upper - int(bool(not quantize)),
                    step=quantize or 1,
                )

        elif isinstance(domain, Categorical):
            if isinstance(sampler, Uniform):
                return ot.distributions.CategoricalDistribution(domain.categories)

        raise ValueError(
            "Optuna search does not support parameters of type "
            "`{}` with samplers of type `{}`".format(
                type(domain).__name__, type(domain.sampler).__name__
            )
        )

    # Parameter name is e.g. "a/b/c" for nested dicts
    values = {"/".join(path): resolve_value(domain) for path, domain in domain_vars}

    return values
|