| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091 |
- import numpy as np
- from ray.tune.stopper.stopper import Stopper
- from ray.util.annotations import PublicAPI
@PublicAPI
class ExperimentPlateauStopper(Stopper):
    """Early stop the experiment when a metric plateaued across trials.

    Stops the entire experiment when the metric has plateaued
    for more than the given amount of iterations specified in
    the patience parameter.

    The stopper keeps the ``top`` best metric values seen so far across
    every reported result. ``trial_id`` is not used, so values from all
    trials are pooled and a single trial can contribute several of them.
    The metric is considered to have plateaued once ``top`` values have
    been collected and their standard deviation is at most ``std``.

    Args:
        metric: The metric to be monitored. Every result passed to the
            stopper must contain this key.
        std: The minimal standard deviation after which
            the tuning process has to stop.
        top: The number of best reported values to consider.
        mode: The mode to select the top results.
            Can either be "min" or "max".
        patience: Number of consecutive plateaued results to wait for
            before stopping. Because the streak counter is incremented
            before it is compared, ``0`` and ``1`` behave the same: both
            stop on the first plateaued result.

    Raises:
        ValueError: If the mode parameter is not "min" nor "max".
        ValueError: If the top parameter is not an integer
            greater than 1.
        ValueError: If the standard deviation parameter is not
            a strictly positive number.
        ValueError: If the patience parameter is not
            a non-negative integer.
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.001,
        top: int = 10,
        mode: str = "min",
        patience: int = 0,
    ):
        if mode not in ("min", "max"):
            raise ValueError("The mode parameter can only be either min or max.")
        if not isinstance(top, int) or top <= 1:
            raise ValueError(
                "Top results to consider must be"
                " a positive integer greater than one."
            )
        if not isinstance(patience, int) or patience < 0:
            # 0 is accepted (it is the default), so the bound is non-negative.
            raise ValueError("Patience must be a non-negative integer.")
        # Accept any real number (ints and NumPy scalars included) except
        # bools. ``not std > 0`` also rejects NaN: a NaN threshold would make
        # ``np.std(...) <= std`` always False, so the stopper would never fire.
        if (
            isinstance(std, bool)
            or not isinstance(std, (int, float, np.integer, np.floating))
            or not std > 0
        ):
            raise ValueError(
                "The standard deviation must be a strictly positive number."
            )
        self._mode = mode
        self._metric = metric
        self._patience = patience
        # Length of the current streak of consecutive plateaued results.
        self._iterations = 0
        self._std = std
        self._top = top
        # Best metric values seen so far; never longer than ``top``.
        self._top_values = []

    def __call__(self, trial_id, result):
        """Return a boolean representing if the tuning has to stop.

        Args:
            trial_id: Identifier of the reporting trial. It is not used:
                values from all trials are pooled together.
            result: Result dict; must contain the monitored metric key.
        """
        self._top_values.append(result[self._metric])
        # Keep only the ``top`` best values: the smallest ones for "min",
        # the largest ones for "max".
        ordered = sorted(self._top_values)
        if self._mode == "min":
            self._top_values = ordered[: self._top]
        else:
            self._top_values = ordered[-self._top :]

        # Extend the streak of plateaued results, or reset it as soon as
        # a result leaves the top values un-plateaued.
        if self.has_plateaued():
            self._iterations += 1
        else:
            self._iterations = 0

        return self.stop_all()

    def has_plateaued(self):
        """Return True when ``top`` values are collected and their std <= ``std``."""
        return (
            len(self._top_values) == self._top and np.std(self._top_values) <= self._std
        )

    def stop_all(self):
        """Return whether to stop and prevent trials from starting."""
        return self.has_plateaued() and self._iterations >= self._patience
|