maximum_iteration.py 656 B

12345678910111213141516171819202122232425
  1. from collections import defaultdict
  2. from typing import Dict
  3. from ray.tune.stopper.stopper import Stopper
  4. from ray.util.annotations import PublicAPI
  5. @PublicAPI
  6. class MaximumIterationStopper(Stopper):
  7. """Stop trials after reaching a maximum number of iterations
  8. Args:
  9. max_iter: Number of iterations before stopping a trial.
  10. """
  11. def __init__(self, max_iter: int):
  12. self._max_iter = max_iter
  13. self._iter = defaultdict(lambda: 0)
  14. def __call__(self, trial_id: str, result: Dict):
  15. self._iter[trial_id] += 1
  16. return self._iter[trial_id] >= self._max_iter
  17. def stop_all(self):
  18. return False