| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124 |
- import logging
- import os
- import random
- import time
- from collections import defaultdict
- from pathlib import Path
- from typing import Dict
- from ray.tune.callback import Callback
- from ray.tune.experiment import Trial
- logger = logging.getLogger(__name__)
class FailureInjectorCallback(Callback):
    """Tune callback that randomly injects node failures.

    At the beginning of a step, and no more often than every
    ``time_between_checks`` seconds, a random draw is compared against
    ``probability``. On a hit, ``kill_node`` from the Ray autoscaler is
    invoked for the cluster config at ``config_path``.

    Args:
        config_path: Path to the cluster config file; ``~`` is expanded.
            If the file does not exist, the callback does nothing.
        probability: Threshold compared against ``random.random()``, both
            for deciding whether to kill a node and for deciding whether
            the kill is a hard one.
        time_between_checks: Minimum number of seconds (monotonic clock)
            between two checks.
        disable: If truthy, no node is ever killed.
    """

    def __init__(
        self,
        config_path="~/ray_bootstrap_config.yaml",
        probability=0.1,
        time_between_checks=0,
        disable=False,
    ):
        expanded_path = Path(config_path).expanduser()
        self.config_path = expanded_path.as_posix()
        self.probability = probability
        self.time_between_checks = time_between_checks
        self.disable = disable
        # Start the interval at construction time, so a non-zero
        # ``time_between_checks`` also delays the very first check.
        self.last_fail_check = time.monotonic()

    def on_step_begin(self, **info):
        """Possibly kill a node; tries ``kill_node`` up to three times."""
        if not os.path.exists(self.config_path):
            return

        next_check_due = self.last_fail_check + self.time_between_checks
        if time.monotonic() < next_check_due:
            return
        self.last_fail_check = time.monotonic()

        # Imported lazily, only once a check is actually due.
        import click
        from ray.autoscaler._private.commands import kill_node

        max_attempts = 3

        # The draw happens before ``disable`` is consulted.
        roll = random.random()
        if not (roll < self.probability) or self.disable:
            return

        # A second, independent draw selects a hard kill
        # (full termination of the node).
        hard_kill = random.random() < self.probability

        for attempt in range(1, max_attempts + 1):
            try:
                kill_node(
                    self.config_path,
                    yes=True,
                    hard=hard_kill,
                    override_cluster_name=None,
                )
            except click.exceptions.ClickException:
                remaining = max_attempts - attempt
                logger.exception(
                    "Killing random node failed in attempt {}. "
                    "Retrying {} more times".format(attempt, remaining)
                )
            else:
                # The kill went through; nothing left to do in this step.
                return
class TrialStatusSnapshot:
    """A sequence of statuses of trials as they progress.

    Each snapshot maps a trial id to that trial's status. A snapshot is
    only recorded when it differs from the most recently recorded one, so
    the sequence contains status *changes* rather than one entry per step.
    """

    def __init__(self):
        self._snapshot = []

    def append(self, new_snapshot: Dict[str, str]):
        """Record ``new_snapshot`` unless it is empty or unchanged.

        Args:
            new_snapshot: Mapping of trial id to status. A shallow copy is
                stored, so the caller may keep mutating and re-appending
                the same mapping without altering recorded history.
        """
        if not new_snapshot:
            # Don't add an empty snapshot.
            return
        if not self._snapshot or new_snapshot != self._snapshot[-1]:
            # Copy instead of keeping a reference: otherwise a caller that
            # reuses one dict would rewrite earlier entries, and the
            # comparison above would compare the dict with itself.
            self._snapshot.append(dict(new_snapshot))

    def max_running_trials(self) -> int:
        """Return the largest number of RUNNING trials in any snapshot.

        Usually used to assert a certain number given resource
        restrictions. Returns 0 when no snapshot has been recorded.
        """
        return max(
            (
                sum(1 for status in snapshot.values() if status == Trial.RUNNING)
                for snapshot in self._snapshot
            ),
            default=0,
        )

    def all_trials_are_terminated(self) -> bool:
        """Return True if every trial in the latest snapshot is TERMINATED.

        Returns False when no snapshot has been recorded yet.
        """
        if not self._snapshot:
            return False
        latest = self._snapshot[-1]
        return all(status == Trial.TERMINATED for status in latest.values())
class TrialStatusSnapshotTaker(Callback):
    """Callback that feeds per-step trial statuses into a snapshot sequence.

    At the end of each step the id and status of every trial are gathered
    into one mapping and handed to the ``append`` method of the wrapped
    ``TrialStatusSnapshot``.
    """

    def __init__(self, snapshot: TrialStatusSnapshot):
        self._snapshot = snapshot

    def on_step_end(self, iteration, trials, **kwargs):
        """Collect ``trial_id -> status`` for ``trials`` and append it."""
        statuses = defaultdict(str)
        statuses.update((trial.trial_id, trial.status) for trial in trials)
        self._snapshot.append(statuses)
|