trial_scheduler.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. from typing import TYPE_CHECKING, Dict, Optional
  2. from ray.air._internal.usage import tag_scheduler
  3. from ray.tune.experiment import Trial
  4. from ray.tune.result import DEFAULT_METRIC
  5. from ray.util.annotations import DeveloperAPI, PublicAPI
  6. if TYPE_CHECKING:
  7. from ray.tune.execution.tune_controller import TuneController
  8. @DeveloperAPI
  9. class TrialScheduler:
  10. """Interface for implementing a Trial Scheduler class.
  11. Note to Tune developers: If a new scheduler is added, please update
  12. `air/_internal/usage.py`.
  13. """
  14. CONTINUE = "CONTINUE" #: Status for continuing trial execution
  15. PAUSE = "PAUSE" #: Status for pausing trial execution
  16. STOP = "STOP" #: Status for stopping trial execution
  17. # Caution: Temporary and anti-pattern! This means Scheduler calls
  18. # into Executor directly without going through TrialRunner.
  19. # TODO(xwjiang): Deprecate this after we control the interaction
  20. # between schedulers and executor.
  21. NOOP = "NOOP"
  22. _metric = None
  23. _supports_buffered_results = True
  24. def __init__(self):
  25. tag_scheduler(self)
  26. @property
  27. def metric(self):
  28. return self._metric
  29. @property
  30. def supports_buffered_results(self):
  31. return self._supports_buffered_results
  32. def set_search_properties(
  33. self, metric: Optional[str], mode: Optional[str], **spec
  34. ) -> bool:
  35. """Pass search properties to scheduler.
  36. This method acts as an alternative to instantiating schedulers
  37. that react to metrics with their own `metric` and `mode` parameters.
  38. Args:
  39. metric: Metric to optimize
  40. mode: One of ["min", "max"]. Direction to optimize.
  41. **spec: Any kwargs for forward compatibility.
  42. Info like Experiment.PUBLIC_KEYS is provided through here.
  43. """
  44. if self._metric and metric:
  45. return False
  46. if metric:
  47. self._metric = metric
  48. if self._metric is None:
  49. # Per default, use anonymous metric
  50. self._metric = DEFAULT_METRIC
  51. return True
  52. def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
  53. """Called when a new trial is added to the trial runner."""
  54. raise NotImplementedError
  55. def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
  56. """Notification for the error of trial.
  57. This will only be called when the trial is in the RUNNING state."""
  58. raise NotImplementedError
  59. def on_trial_result(
  60. self, tune_controller: "TuneController", trial: Trial, result: Dict
  61. ) -> str:
  62. """Called on each intermediate result returned by a trial.
  63. At this point, the trial scheduler can make a decision by returning
  64. one of CONTINUE, PAUSE, and STOP. This will only be called when the
  65. trial is in the RUNNING state."""
  66. raise NotImplementedError
  67. def on_trial_complete(
  68. self, tune_controller: "TuneController", trial: Trial, result: Dict
  69. ):
  70. """Notification for the completion of trial.
  71. This will only be called when the trial is in the RUNNING state and
  72. either completes naturally or by manual termination."""
  73. raise NotImplementedError
  74. def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
  75. """Called to remove trial.
  76. This is called when the trial is in PAUSED or PENDING state. Otherwise,
  77. call `on_trial_complete`."""
  78. raise NotImplementedError
  79. def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
  80. """Called to choose a new trial to run.
  81. This should return one of the trials in tune_controller that is in
  82. the PENDING or PAUSED state. This function must be idempotent.
  83. If no trial is ready, return None."""
  84. raise NotImplementedError
  85. def debug_string(self) -> str:
  86. """Returns a human readable message for printing to the console."""
  87. raise NotImplementedError
  88. def save(self, checkpoint_path: str):
  89. """Save trial scheduler to a checkpoint"""
  90. raise NotImplementedError
  91. def restore(self, checkpoint_path: str):
  92. """Restore trial scheduler from checkpoint."""
  93. raise NotImplementedError
  94. @PublicAPI
  95. class FIFOScheduler(TrialScheduler):
  96. """Simple scheduler that just runs trials in submission order."""
  97. def __init__(self):
  98. super().__init__()
  99. def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
  100. pass
  101. def on_trial_error(self, tune_controller: "TuneController", trial: Trial):
  102. pass
  103. def on_trial_result(
  104. self, tune_controller: "TuneController", trial: Trial, result: Dict
  105. ) -> str:
  106. return TrialScheduler.CONTINUE
  107. def on_trial_complete(
  108. self, tune_controller: "TuneController", trial: Trial, result: Dict
  109. ):
  110. pass
  111. def on_trial_remove(self, tune_controller: "TuneController", trial: Trial):
  112. pass
  113. def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
  114. for trial in tune_controller.get_trials():
  115. if trial.status == Trial.PENDING:
  116. return trial
  117. for trial in tune_controller.get_trials():
  118. if trial.status == Trial.PAUSED:
  119. return trial
  120. return None
  121. def debug_string(self) -> str:
  122. return "Using FIFO scheduling algorithm."