# experiment_plateau.py

import numbers

import numpy as np

from ray.tune.stopper.stopper import Stopper
from ray.util.annotations import PublicAPI


  4. @PublicAPI
  5. class ExperimentPlateauStopper(Stopper):
  6. """Early stop the experiment when a metric plateaued across trials.
  7. Stops the entire experiment when the metric has plateaued
  8. for more than the given amount of iterations specified in
  9. the patience parameter.
  10. Args:
  11. metric: The metric to be monitored.
  12. std: The minimal standard deviation after which
  13. the tuning process has to stop.
  14. top: The number of best models to consider.
  15. mode: The mode to select the top results.
  16. Can either be "min" or "max".
  17. patience: Number of epochs to wait for
  18. a change in the top models.
  19. Raises:
  20. ValueError: If the mode parameter is not "min" nor "max".
  21. ValueError: If the top parameter is not an integer
  22. greater than 1.
  23. ValueError: If the standard deviation parameter is not
  24. a strictly positive float.
  25. ValueError: If the patience parameter is not
  26. a strictly positive integer.
  27. """
  28. def __init__(
  29. self,
  30. metric: str,
  31. std: float = 0.001,
  32. top: int = 10,
  33. mode: str = "min",
  34. patience: int = 0,
  35. ):
  36. if mode not in ("min", "max"):
  37. raise ValueError("The mode parameter can only be either min or max.")
  38. if not isinstance(top, int) or top <= 1:
  39. raise ValueError(
  40. "Top results to consider must be"
  41. " a positive integer greater than one."
  42. )
  43. if not isinstance(patience, int) or patience < 0:
  44. raise ValueError("Patience must be a strictly positive integer.")
  45. if not isinstance(std, float) or std <= 0:
  46. raise ValueError(
  47. "The standard deviation must be a strictly positive float number."
  48. )
  49. self._mode = mode
  50. self._metric = metric
  51. self._patience = patience
  52. self._iterations = 0
  53. self._std = std
  54. self._top = top
  55. self._top_values = []
  56. def __call__(self, trial_id, result):
  57. """Return a boolean representing if the tuning has to stop."""
  58. self._top_values.append(result[self._metric])
  59. if self._mode == "min":
  60. self._top_values = sorted(self._top_values)[: self._top]
  61. else:
  62. self._top_values = sorted(self._top_values)[-self._top :]
  63. # If the current iteration has to stop
  64. if self.has_plateaued():
  65. # we increment the total counter of iterations
  66. self._iterations += 1
  67. else:
  68. # otherwise we reset the counter
  69. self._iterations = 0
  70. # and then call the method that re-executes
  71. # the checks, including the iterations.
  72. return self.stop_all()
  73. def has_plateaued(self):
  74. return (
  75. len(self._top_values) == self._top and np.std(self._top_values) <= self._std
  76. )
  77. def stop_all(self):
  78. """Return whether to stop and prevent trials from starting."""
  79. return self.has_plateaued() and self._iterations >= self._patience