| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- from typing import TYPE_CHECKING, Dict, List, Optional, Union
- from ray.util.annotations import DeveloperAPI
- if TYPE_CHECKING:
- from ray.tune.experiment import Experiment
@DeveloperAPI
class SearchAlgorithm:
    """Interface of an event handler API for hyperparameter search.

    Unlike TrialSchedulers, SearchAlgorithms will not have the ability
    to modify the execution (i.e., stop and pause trials).

    Trials added manually (i.e., via the Client API) will also notify
    this class upon new events, so custom search algorithms should
    maintain a list of the trial IDs generated by this class.

    See also: `ray.tune.search.BasicVariantGenerator`.
    """

    # Class-level defaults; an instance shadows them the first time it
    # assigns to the attribute.
    # Flipped to True by ``set_finished()`` and reported by ``is_finished()``.
    _finished = False

    # Metric name recorded by ``set_search_properties()``; None until set.
    _metric = None

    @property
    def metric(self) -> Optional[str]:
        """The metric this search algorithm optimizes, or None if unset."""
        return self._metric

    def set_search_properties(
        self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
    ) -> bool:
        """Pass search properties to search algorithm.

        This method acts as an alternative to instantiating search algorithms
        with their own specific search spaces. Instead they can accept a
        Tune config through this method.

        The search algorithm will usually pass this method to their
        ``Searcher`` instance.

        Args:
            metric: Metric to optimize.
            mode: One of ["min", "max"]. Direction to optimize.
            config: Tune config dict.
            **spec: Any kwargs for forward compatibility.
                Info like Experiment.PUBLIC_KEYS is provided through here.

        Returns:
            False if a metric is already set on this search algorithm and a
            truthy ``metric`` is passed again; nothing is changed in that
            case. True otherwise, after recording ``metric`` if it is truthy.
            This base implementation does not use ``mode``, ``config`` or
            ``spec``.
        """
        if self._metric and metric:
            # Refuse to overwrite a metric that has already been configured.
            return False
        if metric:
            self._metric = metric
        return True

    @property
    def total_samples(self) -> int:
        """Get number of total trials to be generated.

        The base implementation always returns 0.
        """
        return 0

    def add_configurations(
        self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
    ) -> None:
        """Tracks given experiment specifications.

        Args:
            experiments: Experiments to run.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def next_trial(self):
        """Returns single Trial object to be queued into the TrialRunner.

        Returns:
            A Trial object.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def on_trial_result(self, trial_id: str, result: Dict) -> None:
        """Called on each intermediate result returned by a trial.

        This will only be called when the trial is in the RUNNING state.
        The base implementation does nothing.

        Args:
            trial_id: Identifier for the trial.
            result: Result dictionary.
        """
        pass

    def on_trial_complete(
        self, trial_id: str, result: Optional[Dict] = None, error: bool = False
    ) -> None:
        """Notification for the completion of trial.

        The base implementation does nothing.

        Args:
            trial_id: Identifier for the trial.
            result: Defaults to None. A dict will
                be provided with this notification when the trial is in
                the RUNNING state AND either completes naturally or
                by manual termination.
            error: Defaults to False. True if the trial is in
                the RUNNING state and errors.
        """
        pass

    def is_finished(self) -> bool:
        """Returns True if no trials left to be queued into TrialRunner.

        Can return True before all trials have finished executing. The base
        implementation reports the flag set by ``set_finished()``.
        """
        return self._finished

    def set_finished(self) -> None:
        """Marks the search algorithm as finished."""
        self._finished = True

    def has_checkpoint(self, dirpath: str) -> bool:
        """Returns whether a restorable checkpoint exists in ``dirpath``.

        Should return False if restoring is not implemented; the base
        implementation always returns False.
        """
        return False

    def save_to_dir(self, dirpath: str, **kwargs) -> None:
        """Saves a search algorithm. The base implementation is a no-op."""
        pass

    def restore_from_dir(self, dirpath: str) -> None:
        """Restores a search algorithm along with its wrapped state.

        The base implementation is a no-op.
        """
        pass
|