function_stopper.py 1.1 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from typing import Callable, Dict
  2. from ray.tune.stopper.stopper import Stopper
  3. from ray.util.annotations import PublicAPI
  4. @PublicAPI
  5. class FunctionStopper(Stopper):
  6. """Provide a custom function to check if trial should be stopped.
  7. The passed function will be called after each iteration. If it returns
  8. True, the trial will be stopped.
  9. Args:
  10. function: Function that checks if a trial
  11. should be stopped. Must accept the `trial_id` string and `result`
  12. dictionary as arguments. Must return a boolean.
  13. """
  14. def __init__(self, function: Callable[[str, Dict], bool]):
  15. self._fn = function
  16. def __call__(self, trial_id, result):
  17. return self._fn(trial_id, result)
  18. def stop_all(self):
  19. return False
  20. @classmethod
  21. def is_valid_function(cls, fn):
  22. is_function = callable(fn) and not issubclass(type(fn), Stopper)
  23. if is_function and hasattr(fn, "stop_all"):
  24. raise ValueError(
  25. "Stop object must be ray.tune.Stopper subclass to be detected "
  26. "correctly."
  27. )
  28. return is_function