| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- from collections import defaultdict, deque
- from typing import Dict, Optional
- import numpy as np
- from ray.tune.stopper.stopper import Stopper
- from ray.util.annotations import PublicAPI
@PublicAPI
class TrialPlateauStopper(Stopper):
    """Early stop single trials when they reached a plateau.

    When the standard deviation of the `metric` result of a trial is
    below a threshold `std`, the trial plateaued and will be stopped
    early.

    Args:
        metric: Metric to check for convergence.
        std: Maximum metric standard deviation to decide if a
            trial plateaued. Defaults to 0.01.
        num_results: Number of results to consider for stdev
            calculation.
        grace_period: Minimum number of timesteps before a trial
            can be early stopped
        metric_threshold: Minimum or maximum value the result has to
            exceed before it can be stopped early. A value of ``0`` is a
            valid threshold; only ``None`` disables the check.
        mode: If a `metric_threshold` argument has been
            passed, this must be one of [min, max]. Specifies if we optimize
            for a large metric (max) or a small metric (min). If max, the
            `metric_threshold` has to be exceeded, if min the value has to
            be lower than `metric_threshold` in order to early stop.

    Raises:
        ValueError: If `metric_threshold` is given and `mode` is not one
            of ``"min"`` or ``"max"``.
    """

    def __init__(
        self,
        metric: str,
        std: float = 0.01,
        num_results: int = 4,
        grace_period: int = 4,
        metric_threshold: Optional[float] = None,
        mode: Optional[str] = None,
    ):
        self._metric = metric
        self._mode = mode

        self._std = std
        self._num_results = num_results
        self._grace_period = grace_period
        self._metric_threshold = metric_threshold

        # Compare against None explicitly: a threshold of 0 is falsy but
        # perfectly valid, and must still trigger the `mode` validation.
        # Otherwise `__call__` would silently ignore the threshold.
        if self._metric_threshold is not None and mode not in ("min", "max"):
            raise ValueError(
                f"When specifying a `metric_threshold`, the `mode` "
                f"argument has to be one of [min, max]. "
                f"Got: {mode}"
            )

        # Number of results seen so far, per trial id.
        self._iter = defaultdict(int)
        # Sliding window holding the last `num_results` metric values,
        # per trial id.
        self._trial_results = defaultdict(lambda: deque(maxlen=self._num_results))

    def __call__(self, trial_id: str, result: Dict) -> bool:
        """Record the latest result of a trial and decide whether to stop it.

        Args:
            trial_id: Identifier of the trial that reported `result`.
            result: Result dict; the configured metric is looked up in it.

        Returns:
            True if the trial should be stopped because the standard
            deviation of its last `num_results` metric values is below
            `std`, False otherwise.
        """
        metric_result = result.get(self._metric)

        # The value is recorded (and the iteration counted) even when the
        # metric is missing, so the window and grace period accounting
        # stay in step with the number of reported results.
        self._trial_results[trial_id].append(metric_result)
        self._iter[trial_id] += 1

        # A missing metric cannot be compared against the threshold and
        # makes the stdev of the window undefined, so never stop on it.
        if metric_result is None:
            return False

        # If still in grace period, do not stop yet
        if self._iter[trial_id] < self._grace_period:
            return False

        # If not enough results yet, do not stop yet
        if len(self._trial_results[trial_id]) < self._num_results:
            return False

        # If metric threshold value not reached, do not stop yet
        if self._metric_threshold is not None:
            if self._mode == "min" and metric_result > self._metric_threshold:
                return False
            elif self._mode == "max" and metric_result < self._metric_threshold:
                return False

        # Calculate stdev of last `num_results` results. The window may
        # still contain non-numeric entries (e.g. an earlier missing
        # metric); treat any failure as "not plateaued".
        try:
            current_std = np.std(self._trial_results[trial_id])
        except Exception:
            current_std = float("inf")

        # If stdev is lower than threshold, stop early.
        return bool(current_std < self._std)

    def stop_all(self) -> bool:
        """Return False: this stopper only ever stops individual trials."""
        return False
|