trial_plateau.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from collections import defaultdict, deque
  2. from typing import Dict, Optional
  3. import numpy as np
  4. from ray.tune.stopper.stopper import Stopper
  5. from ray.util.annotations import PublicAPI
  6. @PublicAPI
  7. class TrialPlateauStopper(Stopper):
  8. """Early stop single trials when they reached a plateau.
  9. When the standard deviation of the `metric` result of a trial is
  10. below a threshold `std`, the trial plateaued and will be stopped
  11. early.
  12. Args:
  13. metric: Metric to check for convergence.
  14. std: Maximum metric standard deviation to decide if a
  15. trial plateaued. Defaults to 0.01.
  16. num_results: Number of results to consider for stdev
  17. calculation.
  18. grace_period: Minimum number of timesteps before a trial
  19. can be early stopped
  20. metric_threshold (Optional[float]):
  21. Minimum or maximum value the result has to exceed before it can
  22. be stopped early.
  23. mode: If a `metric_threshold` argument has been
  24. passed, this must be one of [min, max]. Specifies if we optimize
  25. for a large metric (max) or a small metric (min). If max, the
  26. `metric_threshold` has to be exceeded, if min the value has to
  27. be lower than `metric_threshold` in order to early stop.
  28. """
  29. def __init__(
  30. self,
  31. metric: str,
  32. std: float = 0.01,
  33. num_results: int = 4,
  34. grace_period: int = 4,
  35. metric_threshold: Optional[float] = None,
  36. mode: Optional[str] = None,
  37. ):
  38. self._metric = metric
  39. self._mode = mode
  40. self._std = std
  41. self._num_results = num_results
  42. self._grace_period = grace_period
  43. self._metric_threshold = metric_threshold
  44. if self._metric_threshold:
  45. if mode not in ["min", "max"]:
  46. raise ValueError(
  47. f"When specifying a `metric_threshold`, the `mode` "
  48. f"argument has to be one of [min, max]. "
  49. f"Got: {mode}"
  50. )
  51. self._iter = defaultdict(lambda: 0)
  52. self._trial_results = defaultdict(lambda: deque(maxlen=self._num_results))
  53. def __call__(self, trial_id: str, result: Dict):
  54. metric_result = result.get(self._metric)
  55. self._trial_results[trial_id].append(metric_result)
  56. self._iter[trial_id] += 1
  57. # If still in grace period, do not stop yet
  58. if self._iter[trial_id] < self._grace_period:
  59. return False
  60. # If not enough results yet, do not stop yet
  61. if len(self._trial_results[trial_id]) < self._num_results:
  62. return False
  63. # If metric threshold value not reached, do not stop yet
  64. if self._metric_threshold is not None:
  65. if self._mode == "min" and metric_result > self._metric_threshold:
  66. return False
  67. elif self._mode == "max" and metric_result < self._metric_threshold:
  68. return False
  69. # Calculate stdev of last `num_results` results
  70. try:
  71. current_std = np.std(self._trial_results[trial_id])
  72. except Exception:
  73. current_std = float("inf")
  74. # If stdev is lower than threshold, stop early.
  75. return current_std < self._std
  76. def stop_all(self):
  77. return False