mock.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. import logging
  2. import os
  3. import random
  4. import time
  5. from collections import defaultdict
  6. from pathlib import Path
  7. from typing import Dict
  8. from ray.tune.callback import Callback
  9. from ray.tune.experiment import Trial
  10. logger = logging.getLogger(__name__)
  11. class FailureInjectorCallback(Callback):
  12. """Adds random failure injection to the TrialExecutor."""
  13. def __init__(
  14. self,
  15. config_path="~/ray_bootstrap_config.yaml",
  16. probability=0.1,
  17. time_between_checks=0,
  18. disable=False,
  19. ):
  20. self.probability = probability
  21. self.config_path = Path(config_path).expanduser().as_posix()
  22. self.disable = disable
  23. self.time_between_checks = time_between_checks
  24. # Initialize with current time so we don't fail right away
  25. self.last_fail_check = time.monotonic()
  26. def on_step_begin(self, **info):
  27. if not os.path.exists(self.config_path):
  28. return
  29. if time.monotonic() < self.last_fail_check + self.time_between_checks:
  30. return
  31. self.last_fail_check = time.monotonic()
  32. import click
  33. from ray.autoscaler._private.commands import kill_node
  34. failures = 0
  35. max_failures = 3
  36. # With 10% probability inject failure to a worker.
  37. if random.random() < self.probability and not self.disable:
  38. # With 10% probability fully terminate the node.
  39. should_terminate = random.random() < self.probability
  40. while failures < max_failures:
  41. try:
  42. kill_node(
  43. self.config_path,
  44. yes=True,
  45. hard=should_terminate,
  46. override_cluster_name=None,
  47. )
  48. return
  49. except click.exceptions.ClickException:
  50. failures += 1
  51. logger.exception(
  52. "Killing random node failed in attempt "
  53. "{}. "
  54. "Retrying {} more times".format(
  55. str(failures), str(max_failures - failures)
  56. )
  57. )
  58. class TrialStatusSnapshot:
  59. """A sequence of statuses of trials as they progress.
  60. If all trials keep previous status, no snapshot is taken.
  61. """
  62. def __init__(self):
  63. self._snapshot = []
  64. def append(self, new_snapshot: Dict[str, str]):
  65. """May append a new snapshot to the sequence."""
  66. if not new_snapshot:
  67. # Don't add an empty snapshot.
  68. return
  69. if not self._snapshot or new_snapshot != self._snapshot[-1]:
  70. self._snapshot.append(new_snapshot)
  71. def max_running_trials(self) -> int:
  72. """Outputs the max number of running trials at a given time.
  73. Usually used to assert certain number given resource restrictions.
  74. """
  75. result = 0
  76. for snapshot in self._snapshot:
  77. count = 0
  78. for trial_id in snapshot:
  79. if snapshot[trial_id] == Trial.RUNNING:
  80. count += 1
  81. result = max(result, count)
  82. return result
  83. def all_trials_are_terminated(self) -> bool:
  84. """True if all trials are terminated."""
  85. if not self._snapshot:
  86. return False
  87. last_snapshot = self._snapshot[-1]
  88. return all(
  89. last_snapshot[trial_id] == Trial.TERMINATED for trial_id in last_snapshot
  90. )
  91. class TrialStatusSnapshotTaker(Callback):
  92. """Collects a sequence of statuses of trials as they progress.
  93. If all trials keep previous status, no snapshot is taken.
  94. """
  95. def __init__(self, snapshot: TrialStatusSnapshot):
  96. self._snapshot = snapshot
  97. def on_step_end(self, iteration, trials, **kwargs):
  98. new_snapshot = defaultdict(str)
  99. for trial in trials:
  100. new_snapshot[trial.trial_id] = trial.status
  101. self._snapshot.append(new_snapshot)