timeout.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. import datetime
  2. import time
  3. from typing import Union
  4. from ray import logger
  5. from ray.tune.stopper.stopper import Stopper
  6. from ray.util.annotations import PublicAPI
  7. @PublicAPI
  8. class TimeoutStopper(Stopper):
  9. """Stops all trials after a certain timeout.
  10. This stopper is automatically created when the `time_budget_s`
  11. argument is passed to `tune.RunConfig()`.
  12. Args:
  13. timeout: Either a number specifying the timeout in seconds, or
  14. a `datetime.timedelta` object.
  15. """
  16. def __init__(self, timeout: Union[int, float, datetime.timedelta]):
  17. from datetime import timedelta
  18. if isinstance(timeout, timedelta):
  19. self._timeout_seconds = timeout.total_seconds()
  20. elif isinstance(timeout, (int, float)):
  21. self._timeout_seconds = timeout
  22. else:
  23. raise ValueError(
  24. "`timeout` parameter has to be either a number or a "
  25. "`datetime.timedelta` object. Found: {}".format(type(timeout))
  26. )
  27. self._budget = self._timeout_seconds
  28. # To account for setup overhead, set the last check time only after
  29. # the first call to `stop_all()`.
  30. self._last_check = None
  31. def __call__(self, trial_id, result):
  32. return False
  33. def stop_all(self):
  34. now = time.time()
  35. if self._last_check:
  36. taken = now - self._last_check
  37. self._budget -= taken
  38. self._last_check = now
  39. if self._budget <= 0:
  40. logger.info(
  41. f"Reached timeout of {self._timeout_seconds} seconds. "
  42. f"Stopping all trials."
  43. )
  44. return True
  45. return False
  46. def __setstate__(self, state: dict):
  47. state["_last_check"] = None
  48. self.__dict__.update(state)