import logging import os import random import time from collections import defaultdict from pathlib import Path from typing import Dict from ray.tune.callback import Callback from ray.tune.experiment import Trial logger = logging.getLogger(__name__) class FailureInjectorCallback(Callback): """Adds random failure injection to the TrialExecutor.""" def __init__( self, config_path="~/ray_bootstrap_config.yaml", probability=0.1, time_between_checks=0, disable=False, ): self.probability = probability self.config_path = Path(config_path).expanduser().as_posix() self.disable = disable self.time_between_checks = time_between_checks # Initialize with current time so we don't fail right away self.last_fail_check = time.monotonic() def on_step_begin(self, **info): if not os.path.exists(self.config_path): return if time.monotonic() < self.last_fail_check + self.time_between_checks: return self.last_fail_check = time.monotonic() import click from ray.autoscaler._private.commands import kill_node failures = 0 max_failures = 3 # With 10% probability inject failure to a worker. if random.random() < self.probability and not self.disable: # With 10% probability fully terminate the node. should_terminate = random.random() < self.probability while failures < max_failures: try: kill_node( self.config_path, yes=True, hard=should_terminate, override_cluster_name=None, ) return except click.exceptions.ClickException: failures += 1 logger.exception( "Killing random node failed in attempt " "{}. " "Retrying {} more times".format( str(failures), str(max_failures - failures) ) ) class TrialStatusSnapshot: """A sequence of statuses of trials as they progress. If all trials keep previous status, no snapshot is taken. """ def __init__(self): self._snapshot = [] def append(self, new_snapshot: Dict[str, str]): """May append a new snapshot to the sequence.""" if not new_snapshot: # Don't add an empty snapshot. return if not self._snapshot or new_snapshot != self._snapshot[-1]: self._snapshot.append(new_snapshot) def max_running_trials(self) -> int: """Outputs the max number of running trials at a given time. Usually used to assert certain number given resource restrictions. """ result = 0 for snapshot in self._snapshot: count = 0 for trial_id in snapshot: if snapshot[trial_id] == Trial.RUNNING: count += 1 result = max(result, count) return result def all_trials_are_terminated(self) -> bool: """True if all trials are terminated.""" if not self._snapshot: return False last_snapshot = self._snapshot[-1] return all( last_snapshot[trial_id] == Trial.TERMINATED for trial_id in last_snapshot ) class TrialStatusSnapshotTaker(Callback): """Collects a sequence of statuses of trials as they progress. If all trials keep previous status, no snapshot is taken. """ def __init__(self, snapshot: TrialStatusSnapshot): self._snapshot = snapshot def on_step_end(self, iteration, trials, **kwargs): new_snapshot = defaultdict(str) for trial in trials: new_snapshot[trial.trial_id] = trial.status self._snapshot.append(new_snapshot)