stopper.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import abc
  2. from typing import Any, Dict
  3. from ray.util.annotations import PublicAPI
  4. @PublicAPI
  5. class Stopper(abc.ABC):
  6. """Base class for implementing a Tune experiment stopper.
  7. Allows users to implement experiment-level stopping via ``stop_all``. By
  8. default, this class does not stop any trials. Subclasses need to
  9. implement ``__call__`` and ``stop_all``.
  10. Examples:
  11. >>> import time
  12. >>> from ray import tune
  13. >>> from ray.tune import Stopper
  14. >>>
  15. >>> class TimeStopper(Stopper):
  16. ... def __init__(self):
  17. ... self._start = time.time()
  18. ... self._deadline = 2 # Stop all trials after 2 seconds
  19. ...
  20. ... def __call__(self, trial_id, result):
  21. ... return False
  22. ...
  23. ... def stop_all(self):
  24. ... return time.time() - self._start > self._deadline
  25. ...
  26. >>> def train_fn(config):
  27. ... for i in range(100):
  28. ... time.sleep(1)
  29. ... tune.report({"iter": i})
  30. ...
  31. >>> tuner = tune.Tuner(
  32. ... train_fn,
  33. ... tune_config=tune.TuneConfig(num_samples=2),
  34. ... run_config=tune.RunConfig(stop=TimeStopper()),
  35. ... )
  36. >>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS
  37. [ignore]...
  38. """
  39. def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool:
  40. """Returns true if the trial should be terminated given the result."""
  41. raise NotImplementedError
  42. def stop_all(self) -> bool:
  43. """Returns true if the experiment should be terminated."""
  44. raise NotImplementedError
  45. @PublicAPI
  46. class CombinedStopper(Stopper):
  47. """Combine several stoppers via 'OR'.
  48. Args:
  49. *stoppers: Stoppers to be combined.
  50. Examples:
  51. >>> import numpy as np
  52. >>> from ray import tune
  53. >>> from ray.tune.stopper import (
  54. ... CombinedStopper,
  55. ... MaximumIterationStopper,
  56. ... TrialPlateauStopper,
  57. ... )
  58. >>>
  59. >>> stopper = CombinedStopper(
  60. ... MaximumIterationStopper(max_iter=10),
  61. ... TrialPlateauStopper(metric="my_metric"),
  62. ... )
  63. >>> def train_fn(config):
  64. ... for i in range(15):
  65. ... tune.report({"my_metric": np.random.normal(0, 1 - i / 15)})
  66. ...
  67. >>> tuner = tune.Tuner(
  68. ... train_fn,
  69. ... run_config=tune.RunConfig(stop=stopper),
  70. ... )
  71. >>> print("[ignore]"); result_grid = tuner.fit() # doctest: +ELLIPSIS
  72. [ignore]...
  73. >>> all(result.metrics["training_iteration"] <= 20 for result in result_grid)
  74. True
  75. """
  76. def __init__(self, *stoppers: Stopper):
  77. self._stoppers = stoppers
  78. def __call__(self, trial_id: str, result: Dict[str, Any]) -> bool:
  79. return any(s(trial_id, result) for s in self._stoppers)
  80. def stop_all(self) -> bool:
  81. return any(s.stop_all() for s in self._stoppers)