tune_controller.py 83 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184
  1. import copy
  2. import json
  3. import logging
  4. import os
  5. import time
  6. import traceback
  7. import warnings
  8. from collections import defaultdict, deque
  9. from datetime import datetime
  10. from functools import partial
  11. from pathlib import Path
  12. from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  13. import ray
  14. from ray.air import ResourceRequest
  15. from ray.air.constants import TIME_THIS_ITER_S
  16. from ray.air.execution import PlacementGroupResourceManager, ResourceManager
  17. from ray.air.execution._internal import RayActorManager, TrackedActor
  18. from ray.exceptions import RayActorError, RayTaskError
  19. from ray.train._internal.session import _FutureTrainingResult, _TrainingResult
  20. from ray.train._internal.storage import StorageContext
  21. from ray.tune import CheckpointConfig
  22. from ray.tune.callback import Callback, CallbackList
  23. from ray.tune.error import TuneError, _AbortTrialExecution, _TuneStopTrialError
  24. from ray.tune.execution.class_cache import _ActorClassCache
  25. from ray.tune.execution.experiment_state import (
  26. _ExperimentCheckpointManager,
  27. _find_newest_experiment_checkpoint,
  28. )
  29. from ray.tune.execution.insufficient_resources_manager import (
  30. _InsufficientResourcesManager,
  31. )
  32. from ray.tune.execution.placement_groups import PlacementGroupFactory
  33. from ray.tune.experiment import Experiment, Trial
  34. from ray.tune.experiment.trial import (
  35. _change_working_directory,
  36. _get_trainable_kwargs,
  37. _Location,
  38. _noop_logger_creator,
  39. _TrialInfo,
  40. )
  41. from ray.tune.result import (
  42. DEBUG_METRICS,
  43. DEFAULT_METRIC,
  44. DONE,
  45. RESULT_DUPLICATE,
  46. SHOULD_CHECKPOINT,
  47. STDERR_FILE,
  48. STDOUT_FILE,
  49. TRIAL_INFO,
  50. )
  51. from ray.tune.schedulers import FIFOScheduler, TrialScheduler
  52. from ray.tune.search import BasicVariantGenerator, SearchAlgorithm
  53. from ray.tune.stopper import NoopStopper, Stopper
  54. from ray.tune.tune_config import ResumeConfig
  55. from ray.tune.utils import flatten_dict, warn_if_slow
  56. from ray.tune.utils.log import Verbosity, _dedup_logs, has_verbosity
  57. from ray.tune.utils.object_cache import _ObjectCache
  58. from ray.tune.utils.resource_updater import _ResourceUpdater
  59. from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder
  60. from ray.util.annotations import DeveloperAPI
  61. from ray.util.debug import log_once
  62. logger = logging.getLogger(__name__)
  63. @DeveloperAPI
  64. class TuneController:
  # Filename template for this controller's experiment state file; formatted
  # with the session timestamp string (see ``experiment_state_file_name``).
  CKPT_FILE_TMPL = "experiment_state-{}.json"
  # Value that a string ``fail_fast`` must equal after ``.upper()`` in
  # ``__init__``. Per the warning emitted there, this mode may skip cleanup of
  # Ray processes, file descriptors and temporary files.
  RAISE = "RAISE"
  def __init__(
      self,
      *,
      search_alg: Optional[SearchAlgorithm] = None,
      placeholder_resolvers: Optional[Dict[Tuple, Any]] = None,
      scheduler: Optional[TrialScheduler] = None,
      stopper: Optional[Stopper] = None,
      resume_config: Optional[ResumeConfig] = None,
      fail_fast: Union[bool, str] = False,
      checkpoint_period: Optional[Union[str, int]] = None,
      callbacks: Optional[List[Callback]] = None,
      metric: Optional[str] = None,
      trial_checkpoint_config: Optional[CheckpointConfig] = None,
      storage: Optional[StorageContext] = None,
      reuse_actors: bool = False,
      resource_manager_factory: Optional[Callable[[], ResourceManager]] = None,
      _trainer_api: bool = False,
  ):
      """Set up all controller bookkeeping and optionally resume a prior run.

      Args:
          search_alg: Source of new trials. Defaults to
              ``BasicVariantGenerator()``.
          placeholder_resolvers: If set, passed to
              ``trial.resolve_config_placeholders`` for every added trial.
          scheduler: Trial scheduler. Defaults to ``FIFOScheduler()``.
          stopper: Experiment stopper. Defaults to ``NoopStopper()``.
          resume_config: If given, ``self.resume`` is called as the last step
              of construction. On failure the error is logged. It is re-raised
              only if ``fail_fast`` is truthy; otherwise the controller logs
              "Restarting experiment." and continues with ``resumed`` False.
          fail_fast: ``True``/``False`` or the string ``"raise"`` (any case).
              Any other string raises ``ValueError``.
          checkpoint_period: Experiment checkpoint period. When ``None`` it is
              read from ``TUNE_GLOBAL_CHECKPOINT_S`` (default ``"auto"``).
          callbacks: Callbacks; wrapped in a ``CallbackList``.
          metric: Stored as ``self._metric``.
          trial_checkpoint_config: Per-trial checkpoint config. Its
              ``num_to_keep`` is forwarded to the experiment checkpoint
              manager. Defaults to ``CheckpointConfig()``.
          storage: Storage context used for experiment state files.
          reuse_actors: Whether trial actors may be reused.
          resource_manager_factory: Zero-argument callable returning the
              resource manager. Defaults to ``PlacementGroupResourceManager()``.
          _trainer_api: Forwarded as ``for_train`` to the
              insufficient-resources manager.

      Raises:
          ValueError: If ``fail_fast`` is a string other than ``"raise"``.
          Exception: Whatever ``resume`` raised, when ``fail_fast`` is truthy.
      """
      # Resource manager backing the actor manager; callers may inject one.
      if resource_manager_factory:
          resource_manager = resource_manager_factory()
      else:
          resource_manager = PlacementGroupResourceManager()

      self._actor_manager = RayActorManager(resource_manager=resource_manager)

      self._class_cache = _ActorClassCache()

      # Resource status
      self._resource_updater = _ResourceUpdater(None)

      # Actor <-> Trial mappings
      self._actor_to_trial: Dict[TrackedActor, Trial] = {}
      self._trial_to_actor: Dict[Trial, TrackedActor] = {}

      # Resources <-> pending Trials. ``add_trial`` keys this by
      # ``trial.placement_group_factory``.
      self._resources_to_pending_trials: Dict[
          ResourceRequest, Set[Trial]
      ] = defaultdict(set)

      # Keep track of trial states. The pending list mirrors the pending set
      # but preserves insertion (FIFO) order.
      self._pending_trials: Set[Trial] = set()
      self._pending_trials_list: List[Trial] = []

      self._running_trials: Set[Trial] = set()
      self._paused_trials: Set[Trial] = set()
      self._stopped_trials: Set[Trial] = set()
      self._failed_trials: Set[Trial] = set()

      self._resetting_trials: Set[Trial] = set()
      self._staged_trials: Set[Trial] = set()

      # Started actors. NOTE(review): this was previously commented "Removed
      # actors"; the name suggests started actors — confirm at the call sites
      # that populate it.
      self._started_actors: Set[TrackedActor] = set()

      # Map of tracked actors -> timestamp
      # The timestamp is when we requested the stop.
      # We track these actors here to force a
      # cleanup after some time (as they might be hanging).
      # Todo: This timeout logic should be moved into the actor manager.

      # This map is populated whenever we request an actor stop:
      #  - Regular STOP decision
      #  - Removing an actor because its trial REUSEs a different trial's actor
      #  - Removing a cached actor because it's not needed anymore
      # Actors are only tracked in this map if they actually started (not if they
      # were only requested but never started).
      # Actors are removed from this map:
      #  - When the STOP resolved and the actor actually stopped
      #  - When they are forcefully cleaned up after the timeout.
      self._stopping_actors: Dict[TrackedActor, float] = {}
      # Oldest timestamp in ``_stopping_actors`` (inf when there is none). It
      # lets ``_cleanup_stopping_actors`` return early without sorting.
      self._earliest_stopping_actor: float = float("inf")
      # Seconds a stopping actor may linger before it is force-killed.
      self._actor_cleanup_timeout: int = int(
          os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "600")
      )
      # Grace period, in seconds since the stop request. During this window a
      # forced cleanup still waits on a started actor instead of killing it.
      self._actor_force_cleanup_timeout: int = 10

      # Reuse actors
      self._reuse_actors = reuse_actors
      self._actor_cache = _ObjectCache(may_keep_one=True)

      # Trial metadata for experiment checkpoints
      self._trials_to_cache: Set[Trial] = set()
      self._trial_metadata: Dict[str, str] = {}

      # TRAINING: result buffering knobs, overridable via environment variables
      self._buffer_length = int(os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1))
      self._buffer_min_time_s = float(os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.0))
      self._buffer_max_time_s = float(
          os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0)
      )

      # Legacy TrialRunner init
      self._search_alg = search_alg or BasicVariantGenerator()
      self._placeholder_resolvers = placeholder_resolvers
      self._scheduler_alg = scheduler or FIFOScheduler()
      self._callbacks = CallbackList(callbacks or [])
      self._insufficient_resources_manager = _InsufficientResourcesManager(
          for_train=_trainer_api
      )
      self._pending_trial_queue_times = {}

      self._max_pending_trials = _get_max_pending_trials(self._search_alg)

      self._storage = storage
      self._metric = metric

      self._total_time = 0
      # Incremented once per ``step``.
      self._iteration = 0
      self._has_errored = False
      self._fail_fast = fail_fast
      # A string ``fail_fast`` is normalized to upper case. Only "RAISE" is
      # accepted (with a warning); any other string is rejected.
      if isinstance(self._fail_fast, str):
          self._fail_fast = self._fail_fast.upper()
          if self._fail_fast == self.RAISE:
              warnings.warn(
                  "fail_fast='raise' detected. Be careful when using this "
                  "mode as resources (such as Ray processes, "
                  "file descriptors, and temporary files) may not be "
                  "cleaned up properly. To use "
                  "a safer mode, use fail_fast=True."
              )
          else:
              raise ValueError(
                  "fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}."
              )

      self._print_trial_errors = bool(
          int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1"))
      )

      self._trials: List[Trial] = []
      self._live_trials: Set[Trial] = set()  # Set of non-terminated trials
      self._cached_trial_decisions = {}
      self._queued_trial_decisions = {}

      self._stop_queue = []
      self._should_stop_experiment = False  # used by TuneServer

      self._stopper = stopper or NoopStopper()

      self._start_time = time.time()

      # Session string; also names this session's experiment state file.
      self._session_str = datetime.fromtimestamp(self._start_time).strftime(
          "%Y-%m-%d_%H-%M-%S"
      )

      # Experiment-state checkpoint period: explicit value, else env var,
      # else "auto".
      if checkpoint_period is None:
          checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto")

      self._checkpoint_period = checkpoint_period
      self._trial_checkpoint_config = trial_checkpoint_config or CheckpointConfig()
      self._checkpoint_manager = self._create_checkpoint_manager()

      self._resumed = False

      # Resume runs last so every attribute above exists first.
      if resume_config is not None:
          # Use the metadata file to restore TuneController state
          try:
              self.resume(resume_config=resume_config)
              self._resumed = True
          except Exception as e:
              if has_verbosity(Verbosity.V3_TRIAL_DETAILS):
                  logger.error(str(e))
              logger.exception("Failed to restore the run state.")
              if self._fail_fast:
                  raise
              logger.info("Restarting experiment.")
      else:
          logger.debug("Starting a new experiment.")
  def _wrapped(self):
      """Build the wrapper object handed to schedulers and searchers.

      The controller is wrapped in a ``TrialRunnerWrapper``, paired with a
      ``_FakeRayTrialExecutor``. The whitelists below name the runner and
      executor attributes passed to the wrapper (presumably the only ones it
      exposes — confirm in ``TrialRunnerWrapper``).
      """
      allowed_runner_attrs = {
          "search_alg",
          "get_trials",
          "get_live_trials",
          "_set_trial_status",
          "pause_trial",
          "stop_trial",
          "_schedule_trial_save",
      }
      allowed_executor_attrs = {
          "has_resources_for_trial",
          "pause_trial",
          "save",
          "_resource_updater",
      }
      fake_executor = _FakeRayTrialExecutor(self)
      return TrialRunnerWrapper(
          self,
          trial_executor=fake_executor,
          runner_whitelist_attr=allowed_runner_attrs,
          executor_whitelist_attr=allowed_executor_attrs,
      )
  @property
  def resumed(self):
      """True if ``resume`` completed without error during ``__init__``."""
      return self._resumed

  @property
  def search_alg(self):
      """The search algorithm that supplies new trials."""
      return self._search_alg

  @property
  def scheduler_alg(self):
      """The trial scheduler used by this controller."""
      return self._scheduler_alg
  237. def setup_experiments(
  238. self, experiments: List[Experiment], total_num_samples: int
  239. ) -> None:
  240. """Obtains any necessary information from experiments.
  241. Mainly used to setup callbacks.
  242. Args:
  243. experiments: List of Experiments
  244. to use.
  245. total_num_samples: Total number of samples
  246. factoring in grid search samplers.
  247. """
  248. experiment = experiments[0]
  249. spec = experiment.public_spec if experiment else {}
  250. spec["total_num_samples"] = total_num_samples
  251. self._callbacks.setup(**spec)
  252. def end_experiment_callbacks(self) -> None:
  253. """Calls ``on_experiment_end`` method in callbacks."""
  254. self._callbacks.on_experiment_end(trials=self._trials)
  255. @property
  256. def experiment_state_file_name(self) -> str:
  257. return self.CKPT_FILE_TMPL.format(self._session_str)
  258. @property
  259. def experiment_state_path(self) -> str:
  260. """Returns the local experiment checkpoint path."""
  261. return Path(
  262. self._storage.experiment_driver_staging_path,
  263. self.experiment_state_file_name,
  264. ).as_posix()
  @property
  def experiment_path(self) -> str:
      """The experiment directory path on the storage filesystem.

      Read from ``self._storage.experiment_fs_path``. This is the same
      location ``resume`` searches for experiment state files.
      """
      return self._storage.experiment_fs_path
  def _create_checkpoint_manager(self):
      """Construct the manager that saves and syncs experiment state.

      The trial checkpoint config's ``num_to_keep`` is forwarded as
      ``sync_every_n_trial_checkpoints``.
      """
      sync_every = self._trial_checkpoint_config.num_to_keep
      return _ExperimentCheckpointManager(
          storage=self._storage,
          checkpoint_period=self._checkpoint_period,
          sync_every_n_trial_checkpoints=sync_every,
      )
  def save_to_dir(self):
      """Save TuneController state to the local staging experiment directory.

      This includes:
      - trial states
      - TuneController internal state (all the serializable attributes)
      - the searcher state
      - the callback states

      The experiment state JSON is first written to a temporary file next to
      the target. It is then moved over the target with ``os.replace``. The
      state file name is fixed per session and rewritten on every
      checkpoint. Writing in place meant an error during serialization, or a
      crash mid-write, truncated the last good state file.
      """
      # Get state from trial executor and runner
      runner_state = {
          # Trials
          "trial_data": list(self._get_trial_checkpoints().values()),
          # Experiment data
          "runner_data": self.__getstate__(),
          # Metadata
          "stats": {"start_time": self._start_time},
      }

      driver_staging_path = self._storage.experiment_driver_staging_path
      os.makedirs(driver_staging_path, exist_ok=True)

      state_path = Path(driver_staging_path, self.experiment_state_file_name)
      # The ``.tmp`` suffix means the partial file does not end in ``.json``
      # like real state files. NOTE(review): assumed to be what
      # ``_find_newest_experiment_checkpoint`` matches on — confirm.
      tmp_path = state_path.with_name(state_path.name + ".tmp")
      try:
          with open(tmp_path, "w") as f:
              json.dump(runner_state, f, cls=TuneFunctionEncoder)
          # Same directory, so this is a rename; atomic on POSIX.
          os.replace(tmp_path, state_path)
      except BaseException:
          # Best effort: don't leave a partial file in the staging dir.
          try:
              os.remove(tmp_path)
          except OSError:
              pass
          raise

      self._search_alg.save_to_dir(driver_staging_path, session_str=self._session_str)
      self._callbacks.save_to_dir(driver_staging_path, session_str=self._session_str)
  300. def checkpoint(self, force: bool = False, wait: bool = False):
  301. self._checkpoint_manager.sync_up_experiment_state(
  302. save_fn=self.save_to_dir, force=force, wait=wait
  303. )
  def _requeue_restored_trials(
      self, trials: List[Trial], resume_config: ResumeConfig
  ):
      """Re-add restored trials according to ``resume_config``.

      Trials are processed in order of most recent result first. Each trial's
      status selects which resume setting applies: ERROR uses ``errored``,
      TERMINATED uses ``finished``, and anything else uses ``unfinished``.

      - RESUME keeps the same trial object (and ID), clears its recorded
        error files and sets it back to PENDING.
      - RESTART adds ``trial.reset()`` with ``restore_path`` cleared.
      - SKIP keeps the trial and sets it to TERMINATED unless it is ERROR.

      Raises:
          ValueError: On an unrecognized resume type.
      """
      by_recency = sorted(
          trials, key=lambda t: t.run_metadata.last_result_time, reverse=True
      )
      resume_types = ResumeConfig.ResumeType
      for restored in by_recency:
          if restored.status == Trial.ERROR:
              mode = resume_config.errored
          elif restored.status == Trial.TERMINATED:
              mode = resume_config.finished
          else:  # Unfinished (PENDING, RUNNING, PAUSED)
              mode = resume_config.unfinished

          requeued = None
          if mode == resume_types.RESUME:
              # Same object, so the trial ID is kept.
              requeued = restored
              requeued.run_metadata.error_filename = None
              requeued.run_metadata.pickled_error_filename = None
              requeued.set_status(Trial.PENDING)
          elif mode == resume_types.RESTART:
              requeued = restored.reset()
              requeued.restore_path = None
          elif mode == resume_types.SKIP:
              requeued = restored
              # Mark as terminated to skip it; errored trials stay ERROR.
              if requeued.status != Trial.ERROR:
                  requeued.set_status(Trial.TERMINATED)
          else:
              raise ValueError(f"Unknown resume type: {mode}")

          assert requeued is not None
          self.add_trial(requeued)
  337. def _restore_trials(self, experiment_state: Dict) -> List[Trial]:
  338. trials = []
  339. for trial_json_state, trial_runtime_metadata in experiment_state["trial_data"]:
  340. trial = Trial.from_json_state(trial_json_state)
  341. trial.restore_run_metadata(trial_runtime_metadata)
  342. # The following properties may be updated on restoration
  343. # Ex: moved local/cloud experiment directory
  344. # Propagate updated storage ctx properties to the trial's restored copy.
  345. new_storage = copy.copy(trial.storage)
  346. new_storage.storage_filesystem = self._storage.storage_filesystem
  347. new_storage.storage_fs_path = self._storage.storage_fs_path
  348. new_storage.experiment_dir_name = self._storage.experiment_dir_name
  349. # ATTN: `trial.set_storage` is used intentionally, since it
  350. # also updates the absolute paths and filesystem of tracked checkpoints.
  351. trial.set_storage(new_storage)
  352. # Avoid creating logdir in client mode for returned trial results,
  353. # since the dir might not be creatable locally.
  354. # TODO(ekl) this is kind of a hack.
  355. if not ray.util.client.ray.is_connected():
  356. trial.init_local_path() # Create logdir if it does not exist
  357. trials.append(trial)
  358. # NOTE: The restored run should reuse the same driver staging directory.
  359. self._storage._timestamp = trials[0].storage._timestamp
  360. return trials
  def resume(self, resume_config: ResumeConfig):
      """Restore a previous run of this experiment from its saved state.

      Requires user to manually re-register their objects.

      The steps run in order:

      1. Load the newest experiment state file under the storage context's
         experiment path.
      2. Restore controller attributes via ``__setstate__``.
      3. Rebuild the trials.
      4. Sync searcher and callback state down to the driver staging
         directory and restore whichever of them has saved state there.
      5. Re-queue the trials per ``resume_config``.

      Raises:
          ValueError: If no experiment state file can be found.
      """
      experiment_fs_path = self._storage.experiment_fs_path
      fs = self._storage.storage_filesystem

      # Locate the newest state file and load controller state from it.
      newest = _find_newest_experiment_checkpoint(experiment_fs_path, fs=fs)
      if newest is None:
          raise ValueError(
              f"Tried to resume experiment from directory "
              f"'{experiment_fs_path}', but no "
              f"experiment state file of the form '{TuneController.CKPT_FILE_TMPL}' "
              "was found. This is expected if you are launching a new experiment."
          )

      logger.info(
          "Restoring the run from the latest experiment state file: "
          f"{Path(newest).name}"
      )
      with fs.open_input_stream(newest) as stream:
          raw_state = stream.readall()
      experiment_state = json.loads(raw_state, cls=TuneFunctionDecoder)

      self.__setstate__(experiment_state["runner_data"])

      # Rebuild the trials as they were when the run stopped.
      restored_trials = self._restore_trials(experiment_state)

      # Pull searcher/callback state into the driver staging dir. The path is
      # read after ``_restore_trials`` because that call updates the
      # storage timestamp.
      self._checkpoint_manager.sync_down_experiment_state()
      staging_dir = self._storage.experiment_driver_staging_path
      if self._search_alg.has_checkpoint(staging_dir):
          self._search_alg.restore_from_dir(staging_dir)
      if self._callbacks.can_restore(staging_dir):
          self._callbacks.restore_from_dir(staging_dir)

      # Put trials back in the queue depending on their status.
      self._requeue_restored_trials(restored_trials, resume_config)
  397. def update_max_pending_trials(self, max_pending_trials: Optional[int] = None):
  398. self._max_pending_trials = max_pending_trials or _get_max_pending_trials(
  399. self._search_alg
  400. )
  401. def update_pending_trial_resources(
  402. self, resources: Union[dict, PlacementGroupFactory]
  403. ):
  404. """Update trial resources when resuming from checkpoint.
  405. Only updating the pending ones.
  406. """
  407. assert resources
  408. if isinstance(resources, dict) and "gpu" not in resources:
  409. resources["gpu"] = 0
  410. for trial in self._trials:
  411. if trial.status == Trial.PENDING:
  412. trial.update_resources(resources=resources)
  413. def is_finished(self):
  414. """Returns whether all trials have finished running."""
  415. # The checks here are partly redundant but optimized for quick
  416. # evaluation. Specifically, if there are live trials, we check
  417. # these live trials first. Only if none of the live trials is
  418. # live anymore do we loop over all trials for a final check.
  419. trials_done = (
  420. len(self._live_trials) == 0
  421. or all(trial.is_finished() for trial in self._live_trials)
  422. ) and all(trial.is_finished() for trial in self._trials)
  423. return trials_done and self._search_alg.is_finished()
  424. def get_trial(self, tid):
  425. trial = [t for t in self._trials if t.trial_id == tid]
  426. return trial[0] if trial else None
  def get_trials(self):
      """Returns the list of trials managed by this controller.

      This is the internal list itself, not a copy. Callers usually should not
      mutate it or the trial state directly.
      """
      return self._trials

  def get_live_trials(self):
      """Returns the set of live (non-terminated) trials.

      ``add_trial`` inserts a trial here unless its status is
      ``Trial.TERMINATED``. This returns the internal set, not a copy.
      """
      return self._live_trials
  def add_trial(self, trial: Trial):
      """Register a new trial with the controller.

      Trials may be added at any time. The following happens in order:

      - Config placeholders are resolved, if resolvers were given.
      - The placement group factory is created.
      - The trial is recorded and the scheduler is notified.
      - The trial is marked for checkpointing.
      - The trial is filed into the bookkeeping for its current status.

      Args:
          trial: Trial to queue.
      """
      # Resolve placeholder references in the config before anything reads it.
      if self._placeholder_resolvers:
          trial.resolve_config_placeholders(self._placeholder_resolvers)
      # Needs the resolved config.
      trial.create_placement_group_factory()

      self._trials.append(trial)
      if trial.status != Trial.TERMINATED:
          self._live_trials.add(trial)
      with warn_if_slow("scheduler.on_trial_add"):
          self._scheduler_alg.on_trial_add(self._wrapped(), trial)
      self._mark_trial_to_checkpoint(trial)

      logger.debug(f"Adding trial {trial} with status {trial.status}")

      trials_by_status = {
          Trial.PENDING: self._pending_trials,
          Trial.RUNNING: self._running_trials,
          Trial.PAUSED: self._paused_trials,
          Trial.TERMINATED: self._stopped_trials,
          Trial.ERROR: self._failed_trials,
      }
      current = trial.status
      trials_by_status[current].add(trial)

      if current == Trial.PENDING:
          # Keep FIFO order and the resources -> pending trials index in sync.
          self._pending_trials_list.append(trial)
          self._resources_to_pending_trials[trial.placement_group_factory].add(trial)
  465. def _update_trial_queue(self, blocking: bool = False, timeout: int = 600) -> bool:
  466. """Adds next trials to queue if possible.
  467. Note that the timeout is currently unexposed to the user.
  468. Args:
  469. blocking: Blocks until either a trial is available
  470. or is_finished (timeout or search algorithm finishes).
  471. timeout: Seconds before blocking times out.
  472. Returns:
  473. Boolean indicating if a new trial was created or not.
  474. """
  475. trial = self._search_alg.next_trial()
  476. if blocking and not trial:
  477. start = time.time()
  478. # Checking `is_finished` instead of _search_alg.is_finished
  479. # is fine because blocking only occurs if all trials are
  480. # finished and search_algorithm is not yet finished
  481. while (
  482. not trial and not self.is_finished() and time.time() - start < timeout
  483. ):
  484. logger.debug("Blocking for next trial...")
  485. trial = self._search_alg.next_trial()
  486. time.sleep(1)
  487. if trial:
  488. self.add_trial(trial)
  489. return True
  490. return False
  491. def _used_resources_string(self) -> str:
  492. allocated_resources = self._actor_manager.get_live_actors_resources()
  493. return self._resource_updater.debug_string(allocated_resources)
  def on_step_begin(self):
      """Hook run at the start of ``step``: refresh available resources."""
      self._resource_updater.update_avail_resources()

  def on_step_end(self):
      """Hook run at the end of ``step``: non-forced actor cleanup.

      Both cleanups run with ``force_all=False``. Each acts only when its
      own conditions are met.
      """
      self._cleanup_cached_actors(force_all=False)
      self._cleanup_stopping_actors(force_all=False)
  499. def _cleanup_cached_actors(self, force_all: bool = False):
  500. if (
  501. self._search_alg.is_finished()
  502. and not self._staged_trials
  503. and self._actor_cache.total_max_objects == 0
  504. ):
  505. # If there are no more trials coming in, no trials are pending execution,
  506. # and we don't explicitly want to cache objects, we can evict the full
  507. # cache.
  508. force_all = True
  509. for tracked_actor in self._actor_cache.flush_cached_objects(
  510. force_all=force_all
  511. ):
  512. logger.debug(f"Cleaning up cached actor: {tracked_actor}")
  513. # Unset termination callbacks as no trial is associated
  514. tracked_actor.set_on_stop(None)
  515. tracked_actor.set_on_error(None)
  516. self._remove_actor(tracked_actor=tracked_actor)
  def _cleanup_stopping_actors(self, force_all: bool = False):
      """Force-kill actors whose requested stop has been pending too long.

      Entries of ``self._stopping_actors`` older than
      ``self._actor_cleanup_timeout`` seconds are processed, or all entries if
      ``force_all`` is set. For each one:

      - While the entry is younger than ``self._actor_force_cleanup_timeout``
        and the actor has started, the actor manager is pumped for events
        instead, giving the actor time to stop on its own.
      - Otherwise a started actor is removed with ``kill=True``, and the
        entry is dropped from ``_stopping_actors``.

      Finally ``self._earliest_stopping_actor`` is refreshed for the next
      call's early-exit check.

      Args:
          force_all: Ignore the ``_actor_cleanup_timeout`` age threshold and
              process every stopping actor (still subject to the grace
              period above).
      """
      # NOTE(review): timestamps are compared against ``time.monotonic()``,
      # so they are presumably recorded with it — confirm where entries are
      # added to ``_stopping_actors``.
      now = time.monotonic()

      if (
          not force_all
          and now - self._earliest_stopping_actor <= self._actor_cleanup_timeout
      ):
          # If the earliest actor to timeout has not reached the timeout, return
          return

      # This is a bit costly, so we want to avoid running it too often
      # Oldest stop request first.
      times = deque(
          sorted(
              [
                  (timestamp, tracked_actor)
                  for tracked_actor, timestamp in self._stopping_actors.items()
              ],
              key=lambda item: item[0],
          )
      )

      # NOTE(review): with force_all=False the loop only runs for entries older
      # than _actor_cleanup_timeout (default 600s). The grace branch below
      # needs an age under _actor_force_cleanup_timeout (10s). So the grace
      # branch is effectively only reachable when force_all=True.
      while times and (
          force_all or time.monotonic() - times[0][0] > self._actor_cleanup_timeout
      ):
          if (
              time.monotonic() - times[0][0] < self._actor_force_cleanup_timeout
          ) and self._actor_manager.is_actor_started(tracked_actor=times[0][1]):
              # Even if force_all=True, we give the actors time to clean up.
              # The head is not popped: the loop re-checks it after pumping
              # one round of events.
              self._actor_manager.next(timeout=1)
              continue

          _, tracked_actor = times.popleft()

          if tracked_actor not in self._stopping_actors:
              # The stop already resolved (presumably while pumping events
              # above) and the entry was removed elsewhere.
              continue

          if self._actor_manager.is_actor_started(tracked_actor=tracked_actor):
              logger.debug(f"Forcefully killing actor: {tracked_actor}")
              self._actor_manager.remove_actor(tracked_actor=tracked_actor, kill=True)
          self._stopping_actors.pop(tracked_actor)

      # Remember the oldest remaining timestamp for the early-exit check. It
      # may be stale if that actor already stopped, which only costs an extra
      # pass.
      if times:
          self._earliest_stopping_actor = times[0][0]
      else:
          self._earliest_stopping_actor = float("inf")
  def step(self):
      """Run one iteration of the controller loop.

      The iteration runs these steps in order:

      1. Step-begin hooks.
      2. Ask the searcher for more trials.
      3. Start actors for added trials.
      4. Let the actor manager handle the next event, with a 0.1s timeout.
         If nothing was handled and no actors are live, warn about
         potentially insufficient resources.
      5. Maybe stop the experiment.
      6. Checkpoint the experiment state.
      7. Increment the iteration counter.
      8. Step-end hooks.

      Raises:
          TuneError: If called after all trials have finished.
          Exception: Any error from ``checkpoint``, after logging a warning.
      """
      if self.is_finished():
          raise TuneError("Called step when all trials finished?")

      with warn_if_slow("on_step_begin"):
          self.on_step_begin()

      with warn_if_slow("callbacks.on_step_begin"):
          self._callbacks.on_step_begin(
              iteration=self._iteration, trials=self._trials
          )

      # Ask searcher for more trials
      self._maybe_update_trial_queue()

      # Start actors for added trials
      self._maybe_add_actors()

      # Handle one event
      if not self._actor_manager.next(timeout=0.1):
          # If there are no actors running, warn about potentially
          # insufficient resources
          if not self._actor_manager.num_live_actors:
              self._insufficient_resources_manager.on_no_available_trials(
                  self.get_trials()
              )

      # Maybe stop whole experiment
      self._stop_experiment_if_needed()

      # Maybe save experiment state
      try:
          self.checkpoint()
      except Exception as e:
          logger.warning(f"Trial controller checkpointing failed: {str(e)}")
          # Bare ``raise`` re-raises the in-flight exception unchanged
          # (idiomatic; ``raise e`` adds a redundant traceback entry).
          raise

      self._iteration += 1

      with warn_if_slow("on_step_end"):
          self.on_step_end()
      with warn_if_slow("callbacks.on_step_end"):
          self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials)
  def _set_trial_status(self, trial: Trial, status: str):
      """Move ``trial`` to ``status`` and keep the status bookkeeping in sync.

      Each status (PENDING, RUNNING, PAUSED, TERMINATED, ERROR) has a set of
      trials; the trial is moved from the set of its current status into the
      set of the new one. A trial becoming PENDING is also appended to
      ``self._pending_trials_list``, which keeps FIFO order (see
      ``_maybe_add_actors``). It is also registered in
      ``self._resources_to_pending_trials`` under its placement group factory,
      so cached actors can be matched to it. For any other status the trial is
      removed from that resource index.

      Setting the status a trial already has only logs and returns.
      """
      previous = trial.status
      if previous == status:
          logger.debug(f"Trial {trial} already has status {status}. Skipping update.")
          return

      trials_by_status = {
          Trial.PENDING: self._pending_trials,
          Trial.RUNNING: self._running_trials,
          Trial.PAUSED: self._paused_trials,
          Trial.TERMINATED: self._stopped_trials,
          Trial.ERROR: self._failed_trials,
      }

      logger.debug(f"Setting status for trial {trial} from {previous} to {status}")

      assert trial in trials_by_status[previous], (trial, previous)
      assert trial not in trials_by_status[status], (trial, status)

      trials_by_status[previous].remove(trial)
      trials_by_status[status].add(trial)

      pending_with_same_resources = self._resources_to_pending_trials[
          trial.placement_group_factory
      ]
      if status == Trial.PENDING:
          # The FIFO list is append-only: entries whose trial is no longer in
          # ``self._pending_trials`` are skipped when the list is consumed.
          self._pending_trials_list.append(trial)
          pending_with_same_resources.add(trial)
      else:
          pending_with_same_resources.discard(trial)

      trial.set_status(status)
  627. def _get_trial_checkpoints(self) -> Dict[str, str]:
  628. for trial in self._trials_to_cache:
  629. self._trial_metadata[trial.trial_id] = trial.get_json_state()
  630. self._trials_to_cache.clear()
  631. return self._trial_metadata
  632. def _mark_trial_to_checkpoint(self, trial: Trial):
  633. self._trials_to_cache.add(trial)
  634. ###
  635. # UPDATE TRIALS
  636. def _maybe_update_trial_queue(self):
  637. """Ask the searcher for more trials."""
  638. if self._search_alg.is_finished():
  639. return
  640. dont_wait_for_trial = (
  641. self._pending_trials or self._running_trials or self._paused_trials
  642. )
  643. while len(self._pending_trials) < self._max_pending_trials:
  644. if not self._update_trial_queue(blocking=not dont_wait_for_trial):
  645. break
  646. dont_wait_for_trial = True
  def _cleanup_trials(self):
      """Tear down all trial actors at the end of an experiment.

      The cleanup proceeds in order:

      1. Schedule a stop for every trial that still owns an actor.
      2. Flush all cached actors.
      3. Advance the actor manager for up to 5 seconds while it still
         tracks actors.
      4. Force cleanup of actors that are still stopping.
      5. Clean up the actor manager itself.
      """
      logger.debug("CLEANING UP all trials")

      # Iterate over a snapshot of the actors: stopping a trial mutates
      # ``self._actor_to_trial``.
      for tracked_actor in list(self._actor_to_trial):
          owner = self._actor_to_trial[tracked_actor]
          logger.debug(
              f"Scheduling trial stop at end of experiment (trial {owner}): "
              f"{tracked_actor}"
          )
          self._schedule_trial_stop(owner)

      self._cleanup_cached_actors(force_all=True)

      wait_started = time.monotonic()
      while self._actor_manager.num_total_actors:
          if time.monotonic() - wait_started >= 5:
              break
          if _dedup_logs("actor_manager_cleanup", str(wait_started)):
              logger.debug(
                  "Waiting for actor manager to clean up final state [dedup]"
              )
          self._actor_manager.next(timeout=1)

      logger.debug("Force cleanup of remaining actors")
      self._cleanup_stopping_actors(force_all=True)
      self._actor_manager.cleanup()
  668. def _remove_actor(self, tracked_actor: TrackedActor):
  669. stop_future = self._actor_manager.schedule_actor_task(
  670. tracked_actor, "stop", _return_future=True
  671. )
  672. now = time.monotonic()
  673. if self._actor_manager.remove_actor(
  674. tracked_actor, kill=False, stop_future=stop_future
  675. ):
  676. # If the actor was previously alive, track
  677. self._stopping_actors[tracked_actor] = now
  678. self._earliest_stopping_actor = min(self._earliest_stopping_actor, now)
  679. ###
  680. # ADD ACTORS
  def _maybe_add_actors(self) -> None:
      """Add actors for pending and paused trials.

      For trials that have not been staged yet, an actor is requested. For
      trials that have already been staged, we try to reuse a cached actor.

      1. The trial the scheduler chooses to run is handled first.
      2. Then PENDING trials are handled in FIFO order, while fewer than
         ``self._max_pending_trials`` actors are pending.
      3. Lastly, cached actors are assigned to pending trials with matching
         resources. This can be the case when a trial has not been staged
         yet, for instance because too many trials were already staging.
      """
      ###
      # 1: Start trial that the scheduler wants to run
      with warn_if_slow("choose_trial_to_run"):
          trial_to_run = self._scheduler_alg.choose_trial_to_run(self._wrapped())

      if trial_to_run:
          if _dedup_logs("trial_to_run_chosen", trial_to_run.trial_id):
              logger.debug(
                  f"Chose trial to run from scheduler: {trial_to_run} [dedup]"
              )
          if (
              trial_to_run not in self._staged_trials
              and trial_to_run not in self._trial_to_actor
          ):
              logger.debug(f"Staging trial to run: {trial_to_run}")
              self._set_trial_status(trial_to_run, Trial.PENDING)
              self._staged_trials.add(trial_to_run)
              self._actor_cache.increase_max(trial_to_run.placement_group_factory)
              # schedule_trial_actor also potentially uses cached actors
              self._schedule_trial_actor(trial_to_run)
          else:
              # Otherwise, only try to use the cached actor
              if _dedup_logs("trial_to_run_reuse", trial_to_run.trial_id):
                  logger.debug(
                      f"Trying to re-use actor for trial to run: {trial_to_run} "
                      f"[dedup]"
                  )
              self._maybe_reuse_cached_actor(trial_to_run)

      ###
      # 2: Start trials that are PENDING
      def _start_pending_candidates(candidates: List[Trial]) -> List[Trial]:
          """Start actors for ``candidates`` in FIFO order; return the new list.

          The returned list holds the trials that already own an actor,
          followed by the candidates that were not visited because the
          pending-actor limit was reached.
          """
          kept = []
          # Walk by index instead of ``list.pop(0)``, which is O(n) per call
          # and made this loop quadratic in the number of queued trials.
          # Re-checking ``len`` each iteration still picks up trials that
          # ``_set_trial_status`` appends to the same list while we iterate.
          index = 0
          while index < len(candidates):
              if self._actor_manager.num_pending_actors >= self._max_pending_trials:
                  break
              trial = candidates[index]
              index += 1
              # If the trial is part of the list, but not of the set,
              # we just ignore it. Removing it from the list on status
              # change is too expensive.
              if trial not in self._pending_trials:
                  continue
              if trial in self._trial_to_actor:
                  kept.append(trial)
                  continue
              if trial in self._staged_trials:
                  self._maybe_reuse_cached_actor(trial)
                  continue
              logger.debug(f"Scheduling actor for enqueued trial: {trial}")
              self._staged_trials.add(trial)
              self._actor_cache.increase_max(trial.placement_group_factory)
              self._schedule_trial_actor(trial)
          return kept + candidates[index:]

      self._pending_trials_list = _start_pending_candidates(self._pending_trials_list)

      ###
      # 3: Start any trial that can be started with a cached actor
      if self._actor_cache.num_cached_objects:
          for resource in self._resources_to_pending_trials:
              if not self._resources_to_pending_trials[resource]:
                  continue

              if not self._actor_cache.has_cached_object(resource):
                  continue

              start_trial = self._resources_to_pending_trials[resource].pop()
              logger.debug(
                  f"Trying to re-use actor for enqueued trial: {start_trial}"
              )
              if not self._maybe_reuse_cached_actor(start_trial):
                  # No actor could be assigned; keep the trial indexed.
                  self._resources_to_pending_trials[resource].add(start_trial)
              elif start_trial not in self._staged_trials:
                  self._staged_trials.add(start_trial)
                  self._actor_cache.increase_max(
                      start_trial.placement_group_factory
                  )
  def _maybe_reuse_cached_actor(self, trial: Trial) -> bool:
      """Try to give ``trial`` a cached actor with matching resources.

      A trial that is already resetting counts as handled. Otherwise, if the
      actor cache holds an actor for the trial's placement group factory, that
      actor is taken from the cache. Any actor previously scheduled for the
      trial is unlinked and removed via ``_remove_actor``. The cached actor is
      then linked to the trial, its Ray actor handle is attached, and a reset
      with the trial's config and experiment tag is scheduled.

      Returns:
          True if the trial is (now or already) being reset onto a cached
          actor, False if no matching cached actor was available.
      """
      if trial in self._resetting_trials:
          return True

      resources = trial.placement_group_factory
      if not self._actor_cache.has_cached_object(resources):
          return False

      reused_actor = self._actor_cache.pop_cached_object(resources)
      logger.debug(f"Reusing ACTOR for trial {trial}: {reused_actor}")

      if trial in self._trial_to_actor:
          previous_actor = self._trial_to_actor.pop(trial)
          self._actor_to_trial.pop(previous_actor)
          logger.debug(f"Removing ORIGINAL ACTOR for trial {trial}: {previous_actor}")
          self._remove_actor(tracked_actor=previous_actor)

      self._trial_to_actor[trial] = reused_actor
      self._actor_to_trial[reused_actor] = trial

      # Todo: get rid of Trial.runner
      live_entry = self._actor_manager._live_actors_to_ray_actors_resources[
          reused_actor
      ]
      trial.set_ray_actor(live_entry[0])

      self._schedule_trial_reset(trial, trial.config, trial.experiment_tag)
      return True
  def _schedule_trial_actor(self, trial: Trial):
      """Provide an actor for a PENDING ``trial``.

      The trial's local path is initialized and the trial is marked for the
      next experiment checkpoint. If a cached actor can be reused, nothing
      else happens. Otherwise a new actor of the trial's trainable class is
      requested from the actor manager with the trial's placement group
      factory; the controller's start/stop/failure callbacks are registered
      on it.

      If the trainable cannot be resolved, the trial fails with an
      ``_AbortTrialExecution`` error and is stopped.

      Raises:
          RuntimeError: If a new actor would be requested while the trial
              still has an actor assigned.
      """
      logger.debug(f"Trying to schedule new ACTOR for trial {trial}")
      assert trial.status == Trial.PENDING

      trial.init_local_path()
      # Checkpoint metadata early to help mitigate logdir duplication.
      self._mark_trial_to_checkpoint(trial)

      if self._maybe_reuse_cached_actor(trial):
          return

      # Safeguard: requesting a second actor would leak the first one.
      if trial in self._trial_to_actor:
          raise RuntimeError(
              f"Tried to request a new actor for trial {trial}, but an old "
              f"actor still exists. This can lead to leaked resources. The old "
              f"actor should be removed first. "
              f"This is an internal problem in Ray Tune. If you encounter this "
              f"error, please raise an issue on "
              f"https://github.com/ray-project/ray/issues"
          )

      trainable_cls = trial.get_trainable_cls()
      if not trainable_cls:
          abort_error = _AbortTrialExecution(
              f"Invalid trainable: {trial.trainable_name}. If you passed "
              f"a string, make sure the trainable was registered before."
          )
          trial.handle_error(abort_error)
          self._schedule_trial_stop(trial, exception=abort_error)
          return

      actor_cls = self._class_cache.get(trainable_cls)
      trial.set_location(_Location())
      actor_kwargs = _get_trainable_kwargs(trial=trial)

      with _change_working_directory(trial):
          new_actor = self._actor_manager.add_actor(
              cls=actor_cls,
              resource_request=trial.placement_group_factory,
              kwargs=actor_kwargs,
              on_start=self._actor_started,
              on_stop=self._actor_stopped,
              on_error=self._actor_failed,
          )

      self._trial_to_actor[trial] = new_actor
      self._actor_to_trial[new_actor] = trial
      logger.debug(
          f"Scheduled new ACTOR for trial {trial}: {new_actor}. "
          f"Resources: {trial.placement_group_factory}"
      )
  838. def _unstage_trial_with_resources(self, trial: Trial):
  839. """Unstage trial, or one with the same resources as ``trial``."""
  840. # Case 1: The trial we started was staged. Just remove it
  841. if trial in self._staged_trials:
  842. self._staged_trials.remove(trial)
  843. self._actor_cache.decrease_max(trial.placement_group_factory)
  844. return
  845. # Case 2: We staged a trial "A" with the same resources, but our trial "B"
  846. # was selected by the scheduler to run. The resource manager does not care
  847. # about "trials", it just cares about resources being available. Thus we
  848. # look for a staged trial with the same resource requirements and remove it
  849. resource_request = trial.placement_group_factory
  850. # Remove staged trial with same resource requirements
  851. candidate_trial = None
  852. for staged_trial in self._staged_trials:
  853. staged_resources = staged_trial.placement_group_factory
  854. if staged_resources == resource_request:
  855. candidate_trial = staged_trial
  856. break
  857. if candidate_trial:
  858. self._staged_trials.remove(candidate_trial)
  859. self._actor_cache.decrease_max(candidate_trial.placement_group_factory)
  860. return
  861. raise RuntimeError(
  862. "Started a trial with resources requested by a different trial, but "
  863. "this trial was lost. This is an error in Ray Tune's execution "
  864. "logic. Please raise a GitHub issue at "
  865. "https://github.com/ray-project/ray/issues"
  866. )
  def _maybe_cache_trial_actor(self, trial: Trial) -> bool:
      """Hand the actor of ``trial`` to the actor cache for reuse, if useful.

      Nothing is cached in any of these cases:

      * actor reuse is disabled;
      * the search is over and no trial is staged anymore;
      * the trial's actor has not started, has failed, or was never reported
        as started.

      Otherwise the actor is offered to ``self._actor_cache.cache_object``
      under the trial's placement group factory. If the cache accepts it, the
      trial/actor links are removed and the trial's Ray actor handle is
      cleared.

      The cache is meant to keep only as many actors as pending resource
      requests with the same requirements need. For example, with 6 running
      trials and 4 staged actors, at most 4 of the running trials' actors are
      cached. While the cache is still empty, an actor is always cached.
      ``_cleanup_cached_actors`` later re-checks whether a cached actor is
      still needed. Keeping an actor while no trial is staged is safe,
      because it does not hold resources another trial is waiting for.

      Returns:
          True if the actor was cached and detached from ``trial``.
      """
      if not self._reuse_actors:
          return False

      if self._search_alg.is_finished() and not self._staged_trials:
          logger.debug(
              f"Not caching actor of trial {trial} as the search is over "
              f"and no more trials are staged."
          )
          return False

      tracked_actor = self._trial_to_actor[trial]
      manager = self._actor_manager
      actor_is_usable = (
          manager.is_actor_started(tracked_actor)
          and not manager.is_actor_failed(tracked_actor)
          and tracked_actor in self._started_actors
      )
      if not actor_is_usable:
          logger.debug(
              f"Not caching actor of trial {trial} as it has not been started, yet: "
              f"{tracked_actor}"
          )
          return False

      accepted = self._actor_cache.cache_object(
          trial.placement_group_factory, tracked_actor
      )
      if not accepted:
          logger.debug(
              f"Could not cache actor of trial {trial} for "
              "reuse, as there are no pending trials "
              "requiring its resources."
          )
          return False

      logger.debug(f"Caching actor of trial {trial} for re-use: {tracked_actor}")
      detached_actor = self._trial_to_actor.pop(trial)
      self._actor_to_trial.pop(detached_actor)
      trial.set_ray_actor(None)
      return True
  def _actor_started(self, tracked_actor: TrackedActor, log: str = "STARTED"):
      """Actor-manager callback: the actor of a trial has started.

      The callback:

      1. Records the actor as started.
      2. Unstages the trial, or a staged trial with the same resources.
      3. Attaches the Ray actor handle to the trial.
      4. Notifies callbacks of the trial start.
      5. Sets the trial RUNNING and marks it for checkpointing.
      6. Schedules a restore; if ``_schedule_trial_restore`` returns a falsy
         value, schedules training instead.

      Args:
          tracked_actor: The actor that started.
          log: Label used in the debug log line.
      """
      self._started_actors.add(tracked_actor)
      trial = self._actor_to_trial[tracked_actor]
      logger.debug(f"Actor {log} for trial {trial}: {tracked_actor}")

      self._unstage_trial_with_resources(trial)

      live_entry = self._actor_manager._live_actors_to_ray_actors_resources[
          tracked_actor
      ]
      trial.set_ray_actor(live_entry[0])

      self._callbacks.on_trial_start(
          iteration=self._iteration, trials=self._trials, trial=trial
      )
      self._set_trial_status(trial, Trial.RUNNING)
      self._mark_trial_to_checkpoint(trial)

      restore_scheduled = self._schedule_trial_restore(trial)
      if not restore_scheduled:
          self._schedule_trial_train(trial)
  def _actor_stopped(self, tracked_actor: TrackedActor):
      """Actor-manager callback: ``tracked_actor`` has stopped.

      If the actor still belongs to a trial, the trial/actor links are removed
      and the trial's Ray actor handle is cleared. In every case the actor is
      dropped from the stopping and started bookkeeping.
      """
      if tracked_actor in self._actor_to_trial:
          owner = self._actor_to_trial.pop(tracked_actor)
          logger.debug(f"Actor STOPPED for trial {owner}: {tracked_actor}")
          self._trial_to_actor.pop(owner)
          owner.set_ray_actor(None)

      logger.debug(f"Actor STOPPED: {tracked_actor}")

      self._started_actors.discard(tracked_actor)
      self._stopping_actors.pop(tracked_actor, None)
  def _actor_failed(self, tracked_actor: TrackedActor, exception: Exception):
      """Actor-manager callback: ``tracked_actor`` failed.

      If the trial was still PENDING or PAUSED (the actor failed in its
      creation task), the trial is first set to RUNNING, because
      ``_process_trial_failure`` expects that. It is also unstaged so it can
      be re-scheduled.

      The failure is then handled via ``_trial_task_failure`` and the actor's
      task futures are cleared. The actor's stop/error callbacks are unset
      before it is removed without killing; the stopped callback is then
      invoked explicitly.
      """
      trial = self._actor_to_trial[tracked_actor]
      logger.debug(
          f"Actor FAILED for trial {trial}: {tracked_actor}. "
          f"Exception: {exception}"
      )

      not_yet_running = trial in self._pending_trials or trial in self._paused_trials
      if not_yet_running:
          # First, set to running (needed downstream in _process_trial_failure)
          self._set_trial_status(trial, Trial.RUNNING)
          logger.debug(
              f"Trial {trial} failed in its creation task. Unstaging "
              f"to allow it to be re-scheduled."
          )
          self._unstage_trial_with_resources(trial)

      self._trial_task_failure(trial, exception=exception)
      self._actor_manager.clear_actor_task_futures(tracked_actor)

      # Detach the callbacks before removal; the stop callback is invoked
      # explicitly below instead.
      tracked_actor.set_on_stop(None)
      tracked_actor.set_on_error(None)
      self._actor_manager.remove_actor(tracked_actor, kill=False)

      self._actor_stopped(tracked_actor)
  def _schedule_trial_task(
      self,
      trial: Trial,
      method_name: str,
      args: Optional[Tuple] = None,
      kwargs: Optional[Dict] = None,
      on_result: Optional[Callable[[Trial, Any], None]] = None,
      on_error: Optional[Callable[[Trial, Exception], None]] = None,
      _return_future: bool = False,
  ) -> Optional[ray.ObjectRef]:
      """Schedule an actor task future for a trial.

      This is a wrapper around ``ActorManager.schedule_actor_task``. It looks
      up the tracked actor of ``trial`` to kick off the task. It also wraps the
      callbacks so that they receive the trial instead of the tracked actor.

      Exceptions raised by ``on_result``/``on_error`` are re-raised unchanged
      if they are ``TuneError`` instances or if fail-fast is set to RAISE.
      Otherwise they are wrapped in a ``TuneError`` that carries the formatted
      traceback and is chained to the original exception.

      Args:
          trial: Trial whose actor should run the task.
          method_name: Name of the actor method to call.
          args: Positional arguments for the method (default: none).
          kwargs: Keyword arguments for the method (default: none).
          on_result: Called as ``on_result(trial, *result_args,
              **result_kwargs)`` when the task resolves.
          on_error: Called as ``on_error(trial, exception)`` when the task
              fails.
          _return_future: If True, return the scheduled future.

      Returns:
          The task future if ``_return_future`` is True, else None.
      """
      tracked_actor = self._trial_to_actor[trial]

      _on_result = None
      _on_error = None

      args = args or tuple()
      kwargs = kwargs or {}

      def _reraise_callback_error(kind: str, error: Exception):
          # Must be called from inside an ``except`` block: the formatted
          # traceback is taken from the exception currently being handled.
          logger.debug(
              f"Error handling {method_name.upper()} {kind} "
              f"for trial {trial}: {error}"
          )
          # ``isinstance`` is required: an identity check against the
          # ``TuneError`` class is never true for an exception instance.
          if isinstance(error, TuneError) or self._fail_fast == self.RAISE:
              raise error
          raise TuneError(traceback.format_exc()) from error

      if on_result:

          def _on_result(tracked_actor: TrackedActor, *args, **kwargs):
              assert trial == self._actor_to_trial[tracked_actor]
              logger.debug(
                  f"Future {method_name.upper()} RESOLVED for trial {trial}: "
                  f"{args}, {kwargs}"
              )
              try:
                  on_result(trial, *args, **kwargs)
              except Exception as e:
                  _reraise_callback_error("result", e)

      if on_error:

          def _on_error(tracked_actor: TrackedActor, exception: Exception):
              # If the actor failed, it has already been cleaned up.
              if tracked_actor not in self._actor_to_trial:
                  assert isinstance(exception, RayActorError), type(exception)
              else:
                  assert trial == self._actor_to_trial[tracked_actor]

              logger.debug(
                  f"Future {method_name.upper()} FAILED for trial {trial}: "
                  f"{exception}"
              )
              try:
                  on_error(trial, exception)
              except Exception as e:
                  _reraise_callback_error("failure", e)

      logger.debug(f"Future {method_name.upper()} SCHEDULED for trial {trial}")

      with _change_working_directory(trial):
          future = self._actor_manager.schedule_actor_task(
              tracked_actor=tracked_actor,
              method_name=method_name,
              args=args,
              kwargs=kwargs,
              on_result=_on_result,
              on_error=_on_error,
              _return_future=_return_future,
          )
      return future if _return_future else None
  def _queue_decision(self, trial, decision):
      """Merge ``decision`` into the queued decision for ``trial``.

      The first decision for a trial is stored as-is. A stored STOP is never
      overridden, since stopping always takes precedence. Any other stored
      decision is replaced only by a later STOP or PAUSE. A queued CONTINUE
      therefore always stems from the first result and was never updated
      afterwards.
      """
      trial_id = trial.trial_id
      previous = self._queued_trial_decisions.setdefault(trial_id, decision)
      escalates = decision is TrialScheduler.STOP or decision is TrialScheduler.PAUSE
      if previous is not TrialScheduler.STOP and escalates:
          self._queued_trial_decisions[trial_id] = decision
  def _execute_action(self, trial: Trial, decision: str, after_save: bool = False):
      """Carry out a scheduler ``decision`` for ``trial``.

      Args:
          trial: Trial to act on.
          decision: Scheduling decision to undertake. The actions are:

              * CONTINUE: schedule another training step.
              * PAUSE: pause the trial.
              * STOP: stop the trial.
              * NOOP: do nothing.
          after_save: Whether this runs right after a save. A PAUSE then
              does not request another checkpoint.

      Raises:
          ValueError: For any other decision.
      """
      if decision == TrialScheduler.CONTINUE:
          self._schedule_trial_train(trial)
          return
      if decision == TrialScheduler.PAUSE:
          self.pause_trial(trial, should_checkpoint=not after_save)
          return
      if decision == TrialScheduler.STOP:
          self.stop_trial(trial)
          return
      if decision != TrialScheduler.NOOP:
          raise ValueError("Invalid decision: {}".format(decision))
  def _maybe_execute_queued_decision(self, trial: Trial, after_save: bool = False):
      """Pop and execute the queued decision for ``trial``, if there is one.

      At this point ``self._queued_trial_decisions`` holds the final decision
      derived from all processed results (see ``_queue_decision``).
      """
      decision = self._queued_trial_decisions.pop(trial.trial_id, None)
      if not decision:
          return
      logger.debug(f"Executing final queued decision for {trial}: {decision}")
      self._execute_action(trial, decision, after_save=after_save)
  def _stop_experiment_if_needed(self):
      """Stop all unfinished trials if the experiment should end.

      The experiment ends in any of these cases:

      * the stopper's ``stop_all()`` returns True;
      * fail-fast is enabled and a trial has errored;
      * ``self._should_stop_experiment`` is set.

      In that case the search algorithm is marked finished, and a stop is
      scheduled for every trial that is neither ERROR nor TERMINATED.
      """
      fail_fast = self._fail_fast and self._has_errored
      if not (self._stopper.stop_all() or fail_fast or self._should_stop_experiment):
          return

      self._search_alg.set_finished()
      finished_statuses = {Trial.ERROR, Trial.TERMINATED}
      for trial in self._trials:
          if trial.status not in finished_statuses:
              self._schedule_trial_stop(trial)
  1081. ###
  1082. # Failure
  def _trial_task_failure(self, trial: Trial, exception: Exception):
      """Handle a failed task (or actor) of ``trial``.

      With fail-fast set to RAISE the exception is re-raised immediately.
      Otherwise the exception is logged if ``self._print_trial_errors`` is
      set, and then handed to ``_process_trial_failure``.
      """
      if self._fail_fast == self.RAISE:
          raise exception
      if self._print_trial_errors:
          logger.error(f"Trial task failed for trial {trial}", exc_info=exception)
      self._process_trial_failure(trial, exception=exception)
  def _process_trial_failure(
      self,
      trial: Trial,
      exception: Union[TuneError, RayTaskError, RayActorError],
  ):
      """Handle trial failure: recover the trial if possible, else fail it.

      The experiment is flagged as having errored, and the error is recorded
      on the trial. A RUNNING trial whose ``should_recover()`` is true is
      recovered via ``_try_recover``, and callbacks are notified.

      Otherwise, a RUNNING or PENDING trial is reported as errored to the
      scheduler, the search algorithm and the callbacks, and is stopped with
      the exception. Trials in any other status are left as they are.

      Args:
          trial: Failed trial.
          exception: Exception prior to invoking this method.
      """
      self._has_errored = True
      trial.handle_error(exception)

      if trial.status == Trial.RUNNING and trial.should_recover():
          self._try_recover(trial, exc=exception)
          self._callbacks.on_trial_recover(
              iteration=self._iteration, trials=self._trials, trial=trial
          )
          return

      if trial.status not in {Trial.RUNNING, Trial.PENDING}:
          return

      self._scheduler_alg.on_trial_error(self, trial)
      self._search_alg.on_trial_complete(trial.trial_id, error=True)
      self._schedule_trial_stop(trial, exception=exception)
      self._callbacks.on_trial_error(
          iteration=self._iteration, trials=self._trials, trial=trial
      )
  def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None):
      """Stop ``trial`` and cache or release its actor.

      Trials that are already ERROR are left alone. Otherwise the stop
      proceeds as follows:

      1. A saving trial gets STOP as its cached decision.
      2. The trial's in-flight save/restore state is cleared.
      3. Its status becomes ERROR (when ``exception`` is given) or
         TERMINATED, and its location is reset.
      4. If the trial has a live actor, the actor's task futures are cleared
         and the trial is marked for checkpointing.
      5. Without an exception, the actor is first offered to the actor cache.
         Otherwise, or if it is not cached, the actor is unlinked from the
         trial and gracefully removed via ``_remove_actor``.

      Args:
          trial: Trial to stop.
          exception: Error that caused the stop, if any.
      """
      if trial.status == Trial.ERROR:
          logger.debug(f"Not requesting trial STOP as it is ERROR already: {trial}")
          return

      logger.debug(f"Requesting to STOP actor for trial {trial}")

      if trial.is_saving:
          logger.debug(
              f"Trial {trial} is currently saving/pausing. Scheduling STOP after "
              f"save resolved."
          )
          self._cached_trial_decisions[trial.trial_id] = TrialScheduler.STOP

      trial.temporary_state.saving_to = None
      trial.temporary_state.restoring_from = None

      final_status = Trial.ERROR if exception else Trial.TERMINATED
      self._set_trial_status(trial, final_status)
      trial.set_location(_Location())

      if trial not in self._trial_to_actor:
          logger.debug(f"Will not STOP trial actor as it is not live: {trial}")
          return

      live_actor = self._trial_to_actor[trial]
      self._actor_manager.clear_actor_task_futures(tracked_actor=live_actor)
      self._mark_trial_to_checkpoint(trial)

      if not exception and self._maybe_cache_trial_actor(trial):
          # The actor now lives in the actor cache for reuse.
          return

      logger.debug(f"Terminating actor for trial {trial}: {live_actor}")
      live_actor = self._trial_to_actor.pop(trial)
      self._actor_to_trial.pop(live_actor)
      trial.set_ray_actor(None)
      self._remove_actor(tracked_actor=live_actor)
  def stop_trial(self, trial):
      """The canonical implementation of stopping a trial.

      Trials may be in any external status when this function is called:

      * ERROR or TERMINATED: nothing happens.
      * PENDING or PAUSED: calls ``on_trial_remove`` for the scheduler and
        ``on_trial_complete()`` for the search algorithm.
      * RUNNING: calls ``on_trial_complete`` for the scheduler and the search
        algorithm with the flattened last result.

      Afterwards callbacks are notified, a graceful stop is scheduled and the
      trial is removed from the live trials. The caller must ensure there is
      no outstanding future to be handled for the trial; if there is, it
      would be discarded.

      Any error raised while stopping is logged. It is re-raised when
      fail-fast is set to RAISE. Otherwise it is passed to
      ``_process_trial_failure``, wrapped in ``_TuneStopTrialError`` unless it
      already is a ``TuneError``.
      """
      try:
          if trial.status in [Trial.ERROR, Trial.TERMINATED]:
              return
          elif trial.status in [Trial.PENDING, Trial.PAUSED]:
              self._scheduler_alg.on_trial_remove(self, trial)
              self._search_alg.on_trial_complete(trial.trial_id)
          # Compare by value like the branches above: an identity check on a
          # status string fails for equal but non-identical strings, which
          # would silently skip the completion hooks.
          elif trial.status == Trial.RUNNING:
              # By this time trial.last_result should have been
              # updated already.
              self._scheduler_alg.on_trial_complete(
                  self, trial, flatten_dict(trial.last_result)
              )
              self._search_alg.on_trial_complete(
                  trial.trial_id, result=flatten_dict(trial.last_result)
              )

          self._callbacks.on_trial_complete(
              iteration=self._iteration, trials=self._trials, trial=trial
          )
          self._schedule_graceful_trial_stop(trial)
          self._live_trials.discard(trial)
      except Exception as e:
          logger.exception("Trial %s: Error stopping trial.", trial)
          if self._fail_fast == self.RAISE:
              raise
          if isinstance(e, TuneError):
              self._process_trial_failure(trial, exception=e)
          else:
              self._process_trial_failure(
                  trial, _TuneStopTrialError(traceback.format_exc())
              )
  def _schedule_graceful_trial_stop(self, trial: Trial):
      """Schedule an export of ``trial`` and stop it unless it errored.

      The export is always scheduled. The stop is skipped for trials in ERROR
      status, which ``_schedule_trial_stop`` would ignore anyway.
      """
      self._schedule_trial_export(trial)
      # Use the status constant (as everywhere else) instead of a literal.
      if trial.status != Trial.ERROR:
          self._schedule_trial_stop(trial)
  def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True):
      """Pause ``trial``, optionally saving it first.

      Requests for trials without a live actor are ignored. With
      ``should_checkpoint``, PAUSE is recorded as the trial's cached decision
      and a save is scheduled. Presumably the pause completes once that save
      resolves; that handling is not shown here. Without a checkpoint, the
      trial is stopped and set to PAUSED immediately.
      """
      if trial not in self._trial_to_actor:
          logger.debug(
              f"Trial PAUSE requested for trial {trial} but trial is already "
              f"stopping. Ignoring."
          )
          return

      if not should_checkpoint:
          self._schedule_trial_stop(trial)
          self._set_trial_status(trial, Trial.PAUSED)
          return

      self._cached_trial_decisions[trial.trial_id] = TrialScheduler.PAUSE
      self._schedule_trial_save(trial=trial)
  1201. ###
  1202. # TRAIN
  def _schedule_trial_train(self, trial: Trial):
      """Schedule the next training step of ``trial`` on its actor.

      ``train_buffered(buffer_length, buffer_time_s)`` is used when
      ``_maybe_buffer_training`` allows a buffer longer than one result;
      plain ``train()`` is used otherwise. Results are handled by
      ``_on_training_result`` and failures by ``_trial_task_failure``.
      """
      buffer_length, buffer_time_s = self._maybe_buffer_training(trial)
      if buffer_length > 1:
          method_name, args = "train_buffered", (buffer_length, buffer_time_s)
      else:
          method_name, args = "train", ()

      logger.debug(f"Scheduling future {method_name.upper()} for trial {trial}")
      self._schedule_trial_task(
          trial=trial,
          method_name=method_name,
          args=args,
          on_result=self._on_training_result,
          on_error=self._trial_task_failure,
      )
  1218. def _maybe_buffer_training(self, trial: Trial) -> Tuple[int, float]:
  1219. buffer_time_s = max(
  1220. self._buffer_min_time_s,
  1221. min(self._buffer_max_time_s, self._actor_manager.num_actor_tasks // 10),
  1222. )
  1223. buffer_length = self._buffer_length
  1224. if buffer_length > 1 and trial.checkpoint_at_end:
  1225. # If a trial checkpoint can be triggered externally,
  1226. # it is not safe to buffer results.
  1227. if log_once("trial_executor_buffer_checkpoint"):
  1228. logger.warning(
  1229. "Disabling buffered training as you passed "
  1230. "`checkpoint_at_end` to `tune.CheckpointConfig()`."
  1231. )
  1232. return 1, buffer_time_s
  1233. if buffer_length > 1 and trial.checkpoint_freq > 0:
  1234. return min(buffer_length, trial.checkpoint_freq), buffer_time_s
  1235. return buffer_length, buffer_time_s
  1236. ###
  1237. # RESULT
  def _on_training_result(self, trial, result):
      """Handle the resolved result of a ``train``/``train_buffered`` task.

      A single result is wrapped in a list so that buffered and unbuffered
      results are processed alike. Afterwards any queued scheduling decision
      for the trial is executed.
      """
      results = result if isinstance(result, list) else [result]
      with warn_if_slow("process_trial_result"):
          self._process_trial_results(trial, results)
      self._maybe_execute_queued_decision(trial, after_save=False)
  def _process_trial_results(self, trial, results):
      """Process a batch of (possibly buffered) training results in order.

      Each result is handled by ``_process_trial_result``. Processing stops
      after the first result whose decision is STOP; later results are
      ignored.

      A ``None`` decision means a non-training future (e.g. a save) was
      scheduled. The remaining results are still processed, but a one-time
      warning is logged, because buffered training ran ahead of the requested
      checkpoint.

      Args:
          trial: Trial the results belong to.
          results: Results in the order they were reported.
      """
      logger.debug(f"Processing trial results for trial {trial}: {results}")
      with warn_if_slow(
          "process_trial_results",
          message="Processing trial results took {duration:.3f} s, "
          "which may be a performance bottleneck. Please consider "
          "reporting results less frequently to Ray Tune.",
      ):
          for i, result in enumerate(results):
              with warn_if_slow("process_trial_result"):
                  decision = self._process_trial_result(trial, result)
              if decision is None:
                  # A non-training future (e.g. a save) was scheduled while
                  # more buffered results are pending. They are still
                  # processed; warn that training ran ahead of the checkpoint.
                  # ``i`` was just processed, so ``len(results) - i - 1``
                  # results remain.
                  remaining = len(results) - i - 1
                  if remaining > 0:
                      if log_once("tune_controller_buffer_checkpoint"):
                          logger.warning(
                              f"Trial {trial} has a non-training future "
                              f"scheduled but {remaining} results "
                              f"left to process. This means that a "
                              f"checkpoint was requested, but buffered "
                              f"training was continued before it was "
                              f"saved. Consider using non-buffered "
                              f"training by setting the env variable "
                              f"`TUNE_RESULT_BUFFER_LENGTH=1`."
                          )
              elif decision == TrialScheduler.STOP:
                  # If the decision is to stop the trial,
                  # ignore all results that came after that.
                  break
    def _process_trial_result(self, trial: Trial, result: dict[str, Any]):
        """Process a single result reported by ``trial`` and decide what to do next.

        In order, this method:
        1. Tags ``result`` with the trial id.
        2. Substitutes the trial's last result (marked ``done``) for a
           duplicate result.
        3. Validates the reported metrics.
        4. Obtains a decision: STOP from the stopper or the trial's own stop
           criteria, otherwise the scheduler's decision.
        5. Notifies the search algorithm, unless stopping.
        6. Notifies the callbacks, unless the result is a duplicate.
        7. Stores the result as the trial's last result and may schedule a
           checkpoint.

        Args:
            trial: Trial that reported the result.
            result: Result dict reported by the trial. It is updated in place
                (``trial_id`` is always set on it).

        Returns:
            The decision, after queueing it via ``_queue_decision``; or ``None``
            if the trial is saving. In that case the decision may be cached in
            ``_cached_trial_decisions`` and executed after the save has been
            processed.
        """
        # Tag the result with the id of the trial that produced it.
        result.update(trial_id=trial.trial_id)
        is_duplicate = RESULT_DUPLICATE in result
        force_checkpoint = False
        # TrialScheduler and SearchAlgorithm still receive a
        # notification because there may be special handling for
        # the `on_trial_complete` hook.
        if is_duplicate:
            logger.debug("Trial finished without logging 'done'.")
            # Continue with the trial's last result instead of the duplicate
            # marker, flagged as done.
            result = trial.last_result
            result.update(done=True)

        # NOTE(review): for a duplicate result this adds the last result's
        # TIME_THIS_ITER_S a second time — confirm this is intended.
        self._total_time += result.get(TIME_THIS_ITER_S, 0)

        # Validation, the trial's stop criteria, the scheduler and the search
        # algorithm all receive the flattened result.
        flat_result = flatten_dict(result)
        self._validate_result_metrics(flat_result)

        # The stopper or the trial's own stop criteria force STOP; otherwise
        # the scheduler decides.
        if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result):
            decision = TrialScheduler.STOP
        else:
            with warn_if_slow("scheduler.on_trial_result"):
                decision = self._scheduler_alg.on_trial_result(
                    self._wrapped(), trial, flat_result
                )
        if decision == TrialScheduler.STOP:
            result.update(done=True)
        else:
            # Only updating search alg if the trial is not to be stopped.
            with warn_if_slow("search_alg.on_trial_result"):
                self._search_alg.on_trial_result(trial.trial_id, flat_result)

        # If this is not a duplicate result, the callbacks should
        # be informed about the result.
        if not is_duplicate:
            with warn_if_slow("callbacks.on_trial_result"):
                self._callbacks.on_trial_result(
                    iteration=self._iteration,
                    trials=self._trials,
                    trial=trial,
                    # NOTE: Allow user callbacks to modify the Trial result in place.
                    result=result,
                )
            # Read after the callbacks ran; they may have modified `result`.
            force_checkpoint = result.get(SHOULD_CHECKPOINT, False)
        trial.update_last_result(result)
        # Include in next experiment checkpoint
        self._mark_trial_to_checkpoint(trial)

        # Checkpoints to disk. This should be checked even if
        # the scheduler decision is STOP or PAUSE. Note that
        # PAUSE only checkpoints to memory and does not update
        # the global checkpoint state.
        if decision != TrialScheduler.PAUSE:
            # TODO(justinvyu): This is a temporary hack to fix pausing trials.
            # We already schedule a save task in `pause_trial`, so no need
            # to do it again here.
            self._checkpoint_trial_if_needed(trial, force=force_checkpoint)

        if trial.is_saving:
            logger.debug(f"Caching trial decision for trial {trial}: {decision}")
            # Cache decision to execute on after the save is processed.
            # This prevents changing the trial's state or kicking off
            # another training step prematurely.
            if not self._cached_trial_decisions.get(trial.trial_id) or decision in {
                TrialScheduler.PAUSE,
                TrialScheduler.STOP,
            }:
                # If already set, only overwrite if it's a PAUSE or STOP. This is
                # to avoid that CONTINUE decisions from a training step that resolve
                # late overwrite PAUSE/STOP decision.
                self._cached_trial_decisions[trial.trial_id] = decision
            # Returning None signals to `_process_trial_results` that a
            # non-training future is pending for this trial.
            return None
        else:
            self._queue_decision(trial, decision)
            return decision
  1343. def _validate_result_metrics(self, result):
  1344. """
  1345. Check if any of the required metrics was not reported
  1346. in the last result. If the only items are ``done`` or any of
  1347. DEBUG_METRICS, this means that no result was ever received and
  1348. the trial just returned. This is also okay and will not raise
  1349. an error.
  1350. This will ignore checking for the DEFAULT_METRIC.
  1351. """
  1352. if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and (
  1353. len({k for k in result if k not in list(DEBUG_METRICS) + [DONE]}) > 1
  1354. ):
  1355. base_metric = self._metric if self._metric != DEFAULT_METRIC else None
  1356. scheduler_metric = (
  1357. self._scheduler_alg.metric
  1358. if self._scheduler_alg.metric != DEFAULT_METRIC
  1359. else None
  1360. )
  1361. search_metrics = (
  1362. self._search_alg.metric
  1363. if self._search_alg.metric != DEFAULT_METRIC
  1364. else None
  1365. )
  1366. if isinstance(search_metrics, str):
  1367. search_metrics = [search_metrics]
  1368. if base_metric and base_metric not in result:
  1369. report_metric = base_metric
  1370. location = "tune.TuneConfig()"
  1371. elif scheduler_metric and scheduler_metric not in result:
  1372. report_metric = scheduler_metric
  1373. location = type(self._scheduler_alg).__name__
  1374. elif search_metrics and any(
  1375. search_metric not in result for search_metric in search_metrics
  1376. ):
  1377. report_metric = list(
  1378. filter(
  1379. lambda search_metric: search_metric not in result,
  1380. search_metrics,
  1381. )
  1382. )
  1383. if len(report_metric) == 1:
  1384. report_metric = report_metric[0]
  1385. location = type(self._search_alg).__name__
  1386. else:
  1387. report_metric = None
  1388. location = None
  1389. if report_metric:
  1390. raise ValueError(
  1391. "Trial returned a result which did not include the "
  1392. "specified metric(s) `{}` that `{}` expects. "
  1393. "Make sure your calls to `tune.report()` include the "
  1394. "metric, or set the "
  1395. "TUNE_DISABLE_STRICT_METRIC_CHECKING "
  1396. "environment variable to 1. Result: {}".format(
  1397. report_metric, location, result
  1398. )
  1399. )
  1400. ###
  1401. # SAVE
  1402. def _schedule_trial_save(
  1403. self,
  1404. trial: Trial,
  1405. result: Optional[Dict] = None,
  1406. ) -> Optional[_FutureTrainingResult]:
  1407. if trial not in self._trial_to_actor:
  1408. logger.debug(
  1409. f"Trial SAVE requested for trial {trial} but trial is already "
  1410. f"stopping. Ignoring."
  1411. )
  1412. return None
  1413. result = result or trial.last_result
  1414. future = self._schedule_trial_task(
  1415. trial=trial,
  1416. method_name="save",
  1417. on_result=self._on_saving_result,
  1418. on_error=self._trial_task_failure,
  1419. _return_future=True,
  1420. )
  1421. # TODO(justinvyu): `trial.saving_to` (and trial.is_saving) is needed
  1422. # in order to prevent a done=True result from executing a STOP decision
  1423. # (which clears all futures) before the save gets processed.
  1424. # Keep this in for now while `train` and `save` are 2 separate steps.
  1425. trial.temporary_state.saving_to = _FutureTrainingResult(future)
  1426. # `trial.saving_to` holds a future training result -- this is only used
  1427. # in the case of PBT to block until the checkpoint is ready.
  1428. # In all other situations, the checkpoint future is processed by the
  1429. # actor event manager when it is ready.
  1430. return trial.temporary_state.saving_to
  1431. def _on_saving_result(self, trial, checkpoint_value: _TrainingResult):
  1432. with warn_if_slow("process_trial_save"):
  1433. self._process_trial_save(trial, checkpoint_value)
  1434. with warn_if_slow("callbacks.on_trial_save"):
  1435. self._callbacks.on_trial_save(
  1436. iteration=self._iteration, trials=self._trials, trial=trial
  1437. )
  1438. self._maybe_execute_queued_decision(trial, after_save=True)
  1439. def _process_trial_save(self, trial: Trial, checkpoint_value: _TrainingResult):
  1440. """Processes a trial save.
  1441. Acts on the decision cached during the last `_process_trial` call.
  1442. Args:
  1443. trial: Trial being saved.
  1444. """
  1445. logger.debug("Trial %s: Processing trial save.", trial)
  1446. try:
  1447. if not checkpoint_value.checkpoint:
  1448. logger.debug(f"Got empty checkpoint for trial {trial}")
  1449. else:
  1450. try:
  1451. self._callbacks.on_checkpoint(
  1452. iteration=self._iteration,
  1453. trials=self._trials,
  1454. trial=trial,
  1455. checkpoint=checkpoint_value.checkpoint,
  1456. )
  1457. except Exception:
  1458. logger.warning(
  1459. "Error encountered during processing of callbacks. "
  1460. "Ray Train/Tune recently changed the checkpoint interface "
  1461. "that is passed to callbacks. If you implemented your own "
  1462. "callback with an `on_checkpoint` handler, please review "
  1463. "the checkpoint interface and adjust your code "
  1464. "accordingly."
  1465. )
  1466. raise
  1467. trial.on_checkpoint(checkpoint_value)
  1468. self._checkpoint_manager.on_trial_checkpoint(trial)
  1469. self._mark_trial_to_checkpoint(trial)
  1470. except Exception:
  1471. logger.exception(
  1472. "Trial %s: Error handling checkpoint %s", trial, checkpoint_value
  1473. )
  1474. trial.temporary_state.saving_to = None
  1475. decision = self._cached_trial_decisions.pop(trial.trial_id, None)
  1476. if decision and checkpoint_value:
  1477. self._queue_decision(trial, decision)
  1478. def _checkpoint_trial_if_needed(self, trial, force=False):
  1479. """Checkpoints trial based off trial.last_result."""
  1480. if trial.should_checkpoint() or force:
  1481. # Save trial runtime if possible.
  1482. if trial.temporary_state.ray_actor:
  1483. self._schedule_trial_save(trial)
  1484. ###
  1485. # RESTORE
  1486. def _schedule_trial_restore(self, trial: Trial) -> bool:
  1487. checkpoint_result = trial.latest_checkpoint_result
  1488. if not checkpoint_result:
  1489. logger.debug(f"Not restoring trial {trial}: No checkpoint found.")
  1490. return False
  1491. # TODO(justinvyu): Is this really needed?
  1492. trial.temporary_state.restoring_from = checkpoint_result
  1493. method_name = "restore"
  1494. args = (checkpoint_result,)
  1495. self._schedule_trial_task(
  1496. trial=trial,
  1497. method_name=method_name,
  1498. args=args,
  1499. kwargs={},
  1500. on_result=self._on_restoring_result,
  1501. on_error=self._trial_task_failure,
  1502. )
  1503. return True
  1504. def _on_restoring_result(self, trial: Trial, result: Any):
  1505. self._process_trial_restore(trial)
  1506. def _process_trial_restore(self, trial: Trial):
  1507. """Processes a trial restore.
  1508. Args:
  1509. trial: Trial being restored.
  1510. """
  1511. logger.debug("Trial %s: Processing trial restore.", trial)
  1512. trial.on_restore()
  1513. logger.debug("Trial %s: Restore processed successfully", trial)
  1514. self._set_trial_status(trial, Trial.RUNNING)
  1515. self._schedule_trial_train(trial)
  1516. self._live_trials.add(trial)
  1517. def _try_recover(
  1518. self, trial: Trial, exc: Union[TuneError, RayTaskError, RayActorError]
  1519. ):
  1520. """Tries to recover trial.
  1521. Notifies SearchAlgorithm and Scheduler if failure to recover.
  1522. Args:
  1523. trial: Trial to recover.
  1524. exc: Exception prior to invoking this method.
  1525. """
  1526. self._cached_trial_decisions.pop(trial.trial_id, None)
  1527. # Resetting this, in case that the trial is in saving status when it crashes.
  1528. if trial.is_saving:
  1529. trial.temporary_state.saving_to = None
  1530. self._schedule_trial_stop(trial, exception=exc)
  1531. logger.debug("Trial %s: Notifying Scheduler and requeueing.", trial)
  1532. self._requeue_trial(trial)
  1533. def _requeue_trial(self, trial):
  1534. """Notification to TrialScheduler and requeue trial.
  1535. This does not notify the SearchAlgorithm because the function
  1536. evaluation is still in progress.
  1537. """
  1538. self._scheduler_alg.on_trial_error(self, trial)
  1539. self._set_trial_status(trial, status=Trial.PENDING)
  1540. # TODO(rliaw): Right now, this pushes the trial to the end of queue
  1541. # because restoration can be expensive. However, this is not
  1542. # ideal since it just hides the issue - a better fix would
  1543. # be to use an actor table to detect the IP of the Trainable
  1544. # and rsync the files there.
  1545. # See https://github.com/ray-project/ray/issues/5168
  1546. self._trials.pop(self._trials.index(trial))
  1547. self._trials.append(trial)
  1548. self._live_trials.add(trial)
  1549. with warn_if_slow("scheduler.on_trial_add"):
  1550. self._scheduler_alg.on_trial_add(self._wrapped(), trial)
  1551. ###
  1552. # EXPORT
  1553. def _schedule_trial_export(self, trial: Trial):
  1554. if not trial.export_formats or len(trial.export_formats) <= 0:
  1555. return
  1556. # Todo: We are waiting here synchronously until the task resolved.
  1557. # Instead, we should schedule the trial stop after the export resolved.
  1558. # This requires changes in TrialRunner, which we can remove once the
  1559. # legacy execution path has been removed.
  1560. future = self._schedule_trial_task(
  1561. trial=trial,
  1562. method_name="export_model",
  1563. args=(trial.export_formats,),
  1564. on_result=None,
  1565. on_error=self._trial_task_failure,
  1566. _return_future=True,
  1567. )
  1568. self._actor_manager._actor_task_events.resolve_future(future)
  1569. ###
  1570. # RESET
  1571. def _schedule_trial_reset(
  1572. self,
  1573. trial: Trial,
  1574. new_config: Dict,
  1575. new_experiment_tag: str,
  1576. ):
  1577. trial.set_experiment_tag(new_experiment_tag)
  1578. trial.set_config(new_config)
  1579. # Pass magic variables
  1580. extra_config = copy.deepcopy(new_config)
  1581. extra_config[TRIAL_INFO] = _TrialInfo(trial)
  1582. stdout_file, stderr_file = trial.log_to_file
  1583. extra_config[STDOUT_FILE] = stdout_file
  1584. extra_config[STDERR_FILE] = stderr_file
  1585. logger_creator = partial(
  1586. _noop_logger_creator, logdir=trial.storage.trial_working_directory
  1587. )
  1588. self._resetting_trials.add(trial)
  1589. self._schedule_trial_task(
  1590. trial=trial,
  1591. method_name="reset",
  1592. args=(extra_config,),
  1593. kwargs={
  1594. "logger_creator": logger_creator,
  1595. "storage": trial.storage,
  1596. },
  1597. on_result=self._on_trial_reset,
  1598. on_error=self._trial_task_failure,
  1599. )
  1600. def _on_trial_reset(self, trial: Trial, success: bool):
  1601. self._resetting_trials.remove(trial)
  1602. if not success:
  1603. info = (
  1604. "Trainable runner reuse requires reset_config() to be "
  1605. "implemented and return True."
  1606. )
  1607. logger.error(f"Could not re-use actor for trial {trial}: {info}")
  1608. exception = _AbortTrialExecution(info)
  1609. trial.handle_error(exception)
  1610. self._schedule_trial_stop(trial, exception=exception)
  1611. return
  1612. tracked_actor = self._trial_to_actor[trial]
  1613. self._actor_started(tracked_actor, log="REUSED")
  1614. def request_stop_trial(self, trial):
  1615. self._stop_queue.append(trial)
  1616. def request_stop_experiment(self):
  1617. self._should_stop_experiment = True
  1618. def _process_stop_requests(self):
  1619. while self._stop_queue:
  1620. t = self._stop_queue.pop()
  1621. self.stop_trial(t)
  1622. def pause_trial(self, trial: Trial, should_checkpoint: bool = True):
  1623. """Pause a trial and reset the necessary state variables for resuming later.
  1624. Args:
  1625. trial: Trial to pause.
  1626. should_checkpoint: Whether or not an in-memory checkpoint should be created
  1627. for this paused trial. Defaults to True.
  1628. """
  1629. # NOTE: The cached trial decision is not needed since we will overrule this
  1630. # decision with PAUSE.
  1631. self._cached_trial_decisions.pop(trial.trial_id, None)
  1632. self._schedule_trial_pause(trial, should_checkpoint=should_checkpoint)
  1633. def cleanup(self):
  1634. """Cleanup trials and callbacks."""
  1635. self._cleanup_trials()
  1636. self.end_experiment_callbacks()
  1637. def __getstate__(self):
  1638. """Gets state for trial.
  1639. Note that this is not used as a pickling override as
  1640. does not have all fields.
  1641. """
  1642. state = self.__dict__.copy()
  1643. for k in [
  1644. "_trials",
  1645. "_live_trials",
  1646. "_stop_queue",
  1647. "_search_alg",
  1648. "_placeholder_resolvers",
  1649. "_scheduler_alg",
  1650. "_pending_trial_queue_times",
  1651. "_callbacks",
  1652. "_checkpoint_manager",
  1653. "_storage",
  1654. "_insufficient_resources_manager",
  1655. "_actor_manager",
  1656. "_class_cache",
  1657. "_resource_updater",
  1658. "_trials_to_cache",
  1659. "_trial_metadata",
  1660. "_actor_to_trial",
  1661. "_trial_to_actor",
  1662. "_resources_to_pending_trials",
  1663. "_pending_trials",
  1664. "_pending_trials_list",
  1665. "_running_trials",
  1666. "_paused_trials",
  1667. "_stopped_trials",
  1668. "_failed_trials",
  1669. "_resetting_trials",
  1670. "_started_actors",
  1671. "_stopping_actors",
  1672. "_staged_trials",
  1673. "_actor_cache",
  1674. ]:
  1675. del state[k]
  1676. return state
  1677. def __setstate__(self, state):
  1678. # Use session_str from previous checkpoint if does not exist
  1679. session_str = state.pop("_session_str")
  1680. self.__dict__.setdefault("_session_str", session_str)
  1681. # Use start_time from previous checkpoint if does not exist
  1682. start_time = state.pop("_start_time")
  1683. self.__dict__.setdefault("_start_time", start_time)
  1684. self.__dict__.update(state)
  1685. self._checkpoint_manager = self._create_checkpoint_manager()
  1686. class _TrialExecutorWrapper:
  1687. """Wraps around TrialExecutor class, intercepts API calls and warns users
  1688. of restricted API access.
  1689. This is meant to facilitate restricting
  1690. the current API exposure of TrialExecutor by TrialScheduler.
  1691. """
  1692. def __init__(
  1693. self,
  1694. trial_executor: "_FakeRayTrialExecutor",
  1695. whitelist_attr: Optional[set] = None,
  1696. ):
  1697. self._trial_executor = trial_executor
  1698. self._whitelist_attr = whitelist_attr or set()
  1699. for attr in self._whitelist_attr:
  1700. assert hasattr(self._trial_executor, attr)
  1701. def __getattr__(self, attr):
  1702. if attr not in self._whitelist_attr:
  1703. if log_once("restrict_accessing_trial_executor"):
  1704. logger.warning(
  1705. f"You are trying to access {attr} interface of "
  1706. f"TrialExecutor in TrialScheduler, which is being "
  1707. f"restricted. If you believe it is reasonable for "
  1708. f"your scheduler to access this TrialExecutor API, "
  1709. f"please reach out to Ray team on GitHub. A more "
  1710. f"strict API access pattern would be enforced "
  1711. f"starting 1.12.0"
  1712. )
  1713. return getattr(self._trial_executor, attr)
  1714. @DeveloperAPI
  1715. class TrialRunnerWrapper:
  1716. """Wraps around TrialRunner class, intercepts API calls and warns users
  1717. of restricted API access.
  1718. This is meant to facilitate restricting
  1719. the current API exposure of TrialRunner by TrialScheduler.
  1720. """
  1721. _EXECUTOR_ATTR = "trial_executor"
  1722. def __init__(
  1723. self,
  1724. tune_controller: TuneController,
  1725. trial_executor: Any,
  1726. runner_whitelist_attr: Optional[set] = None,
  1727. executor_whitelist_attr: Optional[set] = None,
  1728. ):
  1729. self._tune_controller = tune_controller
  1730. self._trial_executor = _TrialExecutorWrapper(
  1731. trial_executor, executor_whitelist_attr
  1732. )
  1733. self._runner_whitelist_attr = runner_whitelist_attr or set()
  1734. for attr in self._runner_whitelist_attr:
  1735. assert hasattr(self, attr)
  1736. def __getattr__(self, attr):
  1737. if attr == self._EXECUTOR_ATTR:
  1738. return self._trial_executor
  1739. if attr not in self._runner_whitelist_attr:
  1740. if log_once("restrict_accessing_tune_controller"):
  1741. logger.warning(
  1742. f"You are trying to access {attr} interface of "
  1743. f"TrialRunner in TrialScheduler, which is being "
  1744. f"restricted. If you believe it is reasonable for "
  1745. f"your scheduler to access this TrialRunner API, "
  1746. f"please reach out to Ray team on GitHub. A more "
  1747. f"strict API access pattern would be enforced "
  1748. f"starting 1.12s.0"
  1749. )
  1750. return getattr(self._tune_controller, attr)
  1751. def _get_max_pending_trials(search_alg: SearchAlgorithm) -> int:
  1752. max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto")
  1753. if max_pending_trials != "auto":
  1754. return int(max_pending_trials)
  1755. # Else, auto detect.
  1756. # Only BasicVariantGenerator supports > 1 pending trials.
  1757. # This is because we don't want to generate too many trials
  1758. # before we fit the searcher model.
  1759. if not isinstance(search_alg, BasicVariantGenerator):
  1760. return 1
  1761. # Allow up to at least 200 pending trials to trigger fast autoscaling
  1762. min_autoscaling_rate = 200
  1763. # Allow more pending trials for larger clusters (based on number of CPUs)
  1764. cluster_cpus = ray.cluster_resources().get("CPU", 1.0)
  1765. max_pending_trials = max(min_autoscaling_rate, int(cluster_cpus * 1.1))
  1766. if max_pending_trials > min_autoscaling_rate:
  1767. logger.warning(
  1768. f"The maximum number of pending trials has been "
  1769. f"automatically set to the number of available "
  1770. f"cluster CPUs, which is high "
  1771. f"({max_pending_trials} CPUs/pending trials). "
  1772. f"If you're running an experiment with a large number "
  1773. f"of trials, this could lead to scheduling overhead. "
  1774. f"In this case, consider setting the "
  1775. f"`TUNE_MAX_PENDING_TRIALS_PG` environment variable "
  1776. f"to the desired maximum number of concurrent pending trials."
  1777. )
  1778. return max_pending_trials
  1779. class _FakeRayTrialExecutor:
  1780. """The TuneController does not use a RayTrialExecutor anymore.
  1781. Instead, we pass this fake executor for searchers/schedulers to use
  1782. as an interface.
  1783. In the future, we should have the searchers/schedulers either interact with
  1784. the tune controller, or define a different API for more fine-grained scheduler
  1785. control.
  1786. """
  1787. def __init__(self, tune_controller: TuneController):
  1788. self._tune_controller = tune_controller
  1789. def pause_trial(self, trial: Trial, should_checkpoint: bool = True):
  1790. return self._tune_controller._schedule_trial_pause(
  1791. trial, should_checkpoint=should_checkpoint
  1792. )
  1793. def save(
  1794. self,
  1795. trial: Trial,
  1796. result: Optional[Dict] = None,
  1797. ) -> Optional[_FutureTrainingResult]:
  1798. return self._tune_controller._schedule_trial_save(trial=trial, result=result)
  1799. def has_resources_for_trial(self, trial: Trial):
  1800. return True
  1801. @property
  1802. def _resource_updater(self):
  1803. return self._tune_controller._resource_updater
  1804. def force_reconcilation_on_next_step_end(self):
  1805. pass