search_algorithm.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from typing import TYPE_CHECKING, Dict, List, Optional, Union
  2. from ray.util.annotations import DeveloperAPI
  3. if TYPE_CHECKING:
  4. from ray.tune.experiment import Experiment
  5. @DeveloperAPI
  6. class SearchAlgorithm:
  7. """Interface of an event handler API for hyperparameter search.
  8. Unlike TrialSchedulers, SearchAlgorithms will not have the ability
  9. to modify the execution (i.e., stop and pause trials).
  10. Trials added manually (i.e., via the Client API) will also notify
  11. this class upon new events, so custom search algorithms should
  12. maintain a list of trials ID generated from this class.
  13. See also: `ray.tune.search.BasicVariantGenerator`.
  14. """
  15. _finished = False
  16. _metric = None
  17. @property
  18. def metric(self):
  19. return self._metric
  20. def set_search_properties(
  21. self, metric: Optional[str], mode: Optional[str], config: Dict, **spec
  22. ) -> bool:
  23. """Pass search properties to search algorithm.
  24. This method acts as an alternative to instantiating search algorithms
  25. with their own specific search spaces. Instead they can accept a
  26. Tune config through this method.
  27. The search algorithm will usually pass this method to their
  28. ``Searcher`` instance.
  29. Args:
  30. metric: Metric to optimize
  31. mode: One of ["min", "max"]. Direction to optimize.
  32. config: Tune config dict.
  33. **spec: Any kwargs for forward compatibility.
  34. Info like Experiment.PUBLIC_KEYS is provided through here.
  35. """
  36. if self._metric and metric:
  37. return False
  38. if metric:
  39. self._metric = metric
  40. return True
  41. @property
  42. def total_samples(self):
  43. """Get number of total trials to be generated"""
  44. return 0
  45. def add_configurations(
  46. self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
  47. ):
  48. """Tracks given experiment specifications.
  49. Arguments:
  50. experiments: Experiments to run.
  51. """
  52. raise NotImplementedError
  53. def next_trial(self):
  54. """Returns single Trial object to be queued into the TrialRunner.
  55. Returns:
  56. trial: Returns a Trial object.
  57. """
  58. raise NotImplementedError
  59. def on_trial_result(self, trial_id: str, result: Dict):
  60. """Called on each intermediate result returned by a trial.
  61. This will only be called when the trial is in the RUNNING state.
  62. Arguments:
  63. trial_id: Identifier for the trial.
  64. result: Result dictionary.
  65. """
  66. pass
  67. def on_trial_complete(
  68. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  69. ):
  70. """Notification for the completion of trial.
  71. Arguments:
  72. trial_id: Identifier for the trial.
  73. result: Defaults to None. A dict will
  74. be provided with this notification when the trial is in
  75. the RUNNING state AND either completes naturally or
  76. by manual termination.
  77. error: Defaults to False. True if the trial is in
  78. the RUNNING state and errors.
  79. """
  80. pass
  81. def is_finished(self) -> bool:
  82. """Returns True if no trials left to be queued into TrialRunner.
  83. Can return True before all trials have finished executing.
  84. """
  85. return self._finished
  86. def set_finished(self):
  87. """Marks the search algorithm as finished."""
  88. self._finished = True
  89. def has_checkpoint(self, dirpath: str) -> bool:
  90. """Should return False if restoring is not implemented."""
  91. return False
  92. def save_to_dir(self, dirpath: str, **kwargs):
  93. """Saves a search algorithm."""
  94. pass
  95. def restore_from_dir(self, dirpath: str):
  96. """Restores a search algorithm along with its wrapped state."""
  97. pass