import copy import json import logging import os import time import traceback import warnings from collections import defaultdict, deque from datetime import datetime from functools import partial from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import ray from ray.air import ResourceRequest from ray.air.constants import TIME_THIS_ITER_S from ray.air.execution import PlacementGroupResourceManager, ResourceManager from ray.air.execution._internal import RayActorManager, TrackedActor from ray.exceptions import RayActorError, RayTaskError from ray.train._internal.session import _FutureTrainingResult, _TrainingResult from ray.train._internal.storage import StorageContext from ray.tune import CheckpointConfig from ray.tune.callback import Callback, CallbackList from ray.tune.error import TuneError, _AbortTrialExecution, _TuneStopTrialError from ray.tune.execution.class_cache import _ActorClassCache from ray.tune.execution.experiment_state import ( _ExperimentCheckpointManager, _find_newest_experiment_checkpoint, ) from ray.tune.execution.insufficient_resources_manager import ( _InsufficientResourcesManager, ) from ray.tune.execution.placement_groups import PlacementGroupFactory from ray.tune.experiment import Experiment, Trial from ray.tune.experiment.trial import ( _change_working_directory, _get_trainable_kwargs, _Location, _noop_logger_creator, _TrialInfo, ) from ray.tune.result import ( DEBUG_METRICS, DEFAULT_METRIC, DONE, RESULT_DUPLICATE, SHOULD_CHECKPOINT, STDERR_FILE, STDOUT_FILE, TRIAL_INFO, ) from ray.tune.schedulers import FIFOScheduler, TrialScheduler from ray.tune.search import BasicVariantGenerator, SearchAlgorithm from ray.tune.stopper import NoopStopper, Stopper from ray.tune.tune_config import ResumeConfig from ray.tune.utils import flatten_dict, warn_if_slow from ray.tune.utils.log import Verbosity, _dedup_logs, has_verbosity from ray.tune.utils.object_cache import _ObjectCache from ray.tune.utils.resource_updater import _ResourceUpdater from ray.tune.utils.serialization import TuneFunctionDecoder, TuneFunctionEncoder from ray.util.annotations import DeveloperAPI from ray.util.debug import log_once logger = logging.getLogger(__name__) @DeveloperAPI class TuneController: CKPT_FILE_TMPL = "experiment_state-{}.json" RAISE = "RAISE" def __init__( self, *, search_alg: Optional[SearchAlgorithm] = None, placeholder_resolvers: Optional[Dict[Tuple, Any]] = None, scheduler: Optional[TrialScheduler] = None, stopper: Optional[Stopper] = None, resume_config: Optional[ResumeConfig] = None, fail_fast: bool = False, checkpoint_period: Union[str, int] = None, callbacks: Optional[List[Callback]] = None, metric: Optional[str] = None, trial_checkpoint_config: Optional[CheckpointConfig] = None, storage: Optional[StorageContext] = None, reuse_actors: bool = False, resource_manager_factory: Optional[Callable[[], ResourceManager]] = None, _trainer_api: bool = False, ): if resource_manager_factory: resource_manager = resource_manager_factory() else: resource_manager = PlacementGroupResourceManager() self._actor_manager = RayActorManager(resource_manager=resource_manager) self._class_cache = _ActorClassCache() # Resource status self._resource_updater = _ResourceUpdater(None) # Actor <-> Trial mappings self._actor_to_trial: Dict[TrackedActor, Trial] = {} self._trial_to_actor: Dict[Trial, TrackedActor] = {} # Resources <-> Trial self._resources_to_pending_trials: Dict[ ResourceRequest, Set[Trial] ] = defaultdict(set) # Keep track of actor states self._pending_trials: Set[Trial] = set() self._pending_trials_list: List[Trial] = [] self._running_trials: Set[Trial] = set() self._paused_trials: Set[Trial] = set() self._stopped_trials: Set[Trial] = set() self._failed_trials: Set[Trial] = set() self._resetting_trials: Set[Trial] = set() self._staged_trials: Set[Trial] = set() # Removed actors self._started_actors: Set[TrackedActor] = set() # Map of tracked actors -> timestamp # The timestamp is when we requested the stop. # We track these actors here to force a # cleanup after some time (as they might be hanging). # Todo: This timeout logic should be moved into the actor manager. # This map is populated whenever we request an actor stop: # - Regular STOP decision # - Removing an actor because its trial REUSEs a different trial's actor # - Removing a cached actor because it's not needed anymore # Actors are only tracked in this map if they actually started (not if they # were only requested but never started). # Actors are removed from this map: # - When the STOP resolved and the actor actually stopped # - When they are forcefully cleaned up after the timeout. self._stopping_actors: Dict[TrackedActor, float] = {} self._earliest_stopping_actor: float = float("inf") self._actor_cleanup_timeout: int = int( os.environ.get("TUNE_FORCE_TRIAL_CLEANUP_S", "600") ) self._actor_force_cleanup_timeout: int = 10 # Reuse actors self._reuse_actors = reuse_actors self._actor_cache = _ObjectCache(may_keep_one=True) # Trial metadata for experiment checkpoints self._trials_to_cache: Set[Trial] = set() self._trial_metadata: Dict[str, str] = {} # TRAINING self._buffer_length = int(os.getenv("TUNE_RESULT_BUFFER_LENGTH", 1)) self._buffer_min_time_s = float(os.getenv("TUNE_RESULT_BUFFER_MIN_TIME_S", 0.0)) self._buffer_max_time_s = float( os.getenv("TUNE_RESULT_BUFFER_MAX_TIME_S", 100.0) ) # Legacy TrialRunner init self._search_alg = search_alg or BasicVariantGenerator() self._placeholder_resolvers = placeholder_resolvers self._scheduler_alg = scheduler or FIFOScheduler() self._callbacks = CallbackList(callbacks or []) self._insufficient_resources_manager = _InsufficientResourcesManager( for_train=_trainer_api ) self._pending_trial_queue_times = {} self._max_pending_trials = _get_max_pending_trials(self._search_alg) self._storage = storage self._metric = metric self._total_time = 0 self._iteration = 0 self._has_errored = False self._fail_fast = fail_fast if isinstance(self._fail_fast, str): self._fail_fast = self._fail_fast.upper() if self._fail_fast == self.RAISE: warnings.warn( "fail_fast='raise' detected. Be careful when using this " "mode as resources (such as Ray processes, " "file descriptors, and temporary files) may not be " "cleaned up properly. To use " "a safer mode, use fail_fast=True." ) else: raise ValueError( "fail_fast must be one of {bool, RAISE}. " f"Got {self._fail_fast}." ) self._print_trial_errors = bool( int(os.environ.get("TUNE_PRINT_ALL_TRIAL_ERRORS", "1")) ) self._trials: List[Trial] = [] self._live_trials: Set[Trial] = set() # Set of non-terminated trials self._cached_trial_decisions = {} self._queued_trial_decisions = {} self._stop_queue = [] self._should_stop_experiment = False # used by TuneServer self._stopper = stopper or NoopStopper() self._start_time = time.time() self._session_str = datetime.fromtimestamp(self._start_time).strftime( "%Y-%m-%d_%H-%M-%S" ) if checkpoint_period is None: checkpoint_period = os.getenv("TUNE_GLOBAL_CHECKPOINT_S", "auto") self._checkpoint_period = checkpoint_period self._trial_checkpoint_config = trial_checkpoint_config or CheckpointConfig() self._checkpoint_manager = self._create_checkpoint_manager() self._resumed = False if resume_config is not None: # Use the metadata file to restore TuneController state try: self.resume(resume_config=resume_config) self._resumed = True except Exception as e: if has_verbosity(Verbosity.V3_TRIAL_DETAILS): logger.error(str(e)) logger.exception("Failed to restore the run state.") if self._fail_fast: raise logger.info("Restarting experiment.") else: logger.debug("Starting a new experiment.") def _wrapped(self): """Return wrapped tune controller to be passed to scheduler/searchers.""" return TrialRunnerWrapper( self, trial_executor=_FakeRayTrialExecutor(self), runner_whitelist_attr={ "search_alg", "get_trials", "get_live_trials", "_set_trial_status", "pause_trial", "stop_trial", "_schedule_trial_save", }, executor_whitelist_attr={ "has_resources_for_trial", "pause_trial", "save", "_resource_updater", }, ) @property def resumed(self): return self._resumed @property def search_alg(self): return self._search_alg @property def scheduler_alg(self): return self._scheduler_alg def setup_experiments( self, experiments: List[Experiment], total_num_samples: int ) -> None: """Obtains any necessary information from experiments. Mainly used to setup callbacks. Args: experiments: List of Experiments to use. total_num_samples: Total number of samples factoring in grid search samplers. """ experiment = experiments[0] spec = experiment.public_spec if experiment else {} spec["total_num_samples"] = total_num_samples self._callbacks.setup(**spec) def end_experiment_callbacks(self) -> None: """Calls ``on_experiment_end`` method in callbacks.""" self._callbacks.on_experiment_end(trials=self._trials) @property def experiment_state_file_name(self) -> str: return self.CKPT_FILE_TMPL.format(self._session_str) @property def experiment_state_path(self) -> str: """Returns the local experiment checkpoint path.""" return Path( self._storage.experiment_driver_staging_path, self.experiment_state_file_name, ).as_posix() @property def experiment_path(self) -> str: return self._storage.experiment_fs_path def _create_checkpoint_manager(self): return _ExperimentCheckpointManager( storage=self._storage, checkpoint_period=self._checkpoint_period, sync_every_n_trial_checkpoints=self._trial_checkpoint_config.num_to_keep, ) def save_to_dir(self): """Save TuneController state to the local staging experiment directory. This includes: - trial states - TuneController internal state (all the serializable attributes) - the searcher state - the callback states """ # Get state from trial executor and runner runner_state = { # Trials "trial_data": list(self._get_trial_checkpoints().values()), # Experiment data "runner_data": self.__getstate__(), # Metadata "stats": {"start_time": self._start_time}, } driver_staging_path = self._storage.experiment_driver_staging_path os.makedirs(driver_staging_path, exist_ok=True) with open( Path(driver_staging_path, self.experiment_state_file_name), "w", ) as f: json.dump(runner_state, f, cls=TuneFunctionEncoder) self._search_alg.save_to_dir(driver_staging_path, session_str=self._session_str) self._callbacks.save_to_dir(driver_staging_path, session_str=self._session_str) def checkpoint(self, force: bool = False, wait: bool = False): self._checkpoint_manager.sync_up_experiment_state( save_fn=self.save_to_dir, force=force, wait=wait ) def _requeue_restored_trials( self, trials: List[Trial], resume_config: ResumeConfig ): # Set trial statuses according to the resume configuration for trial in sorted( trials, key=lambda t: t.run_metadata.last_result_time, reverse=True ): if trial.status == Trial.ERROR: resume_type = resume_config.errored elif trial.status == Trial.TERMINATED: resume_type = resume_config.finished else: # Unfinished (PENDING, RUNNING, PAUSED) resume_type = resume_config.unfinished trial_to_add = None if resume_type == ResumeConfig.ResumeType.RESUME: # Keep trial ID on resume trial_to_add = trial trial_to_add.run_metadata.error_filename = None trial_to_add.run_metadata.pickled_error_filename = None trial_to_add.set_status(Trial.PENDING) elif resume_type == ResumeConfig.ResumeType.RESTART: trial_to_add = trial.reset() trial_to_add.restore_path = None elif resume_type == ResumeConfig.ResumeType.SKIP: trial_to_add = trial if trial_to_add.status != Trial.ERROR: # Set the status to terminated to skip it. # Keep errored trial status as ERROR. trial_to_add.set_status(Trial.TERMINATED) else: raise ValueError(f"Unknown resume type: {resume_type}") assert trial_to_add is not None self.add_trial(trial_to_add) def _restore_trials(self, experiment_state: Dict) -> List[Trial]: trials = [] for trial_json_state, trial_runtime_metadata in experiment_state["trial_data"]: trial = Trial.from_json_state(trial_json_state) trial.restore_run_metadata(trial_runtime_metadata) # The following properties may be updated on restoration # Ex: moved local/cloud experiment directory # Propagate updated storage ctx properties to the trial's restored copy. new_storage = copy.copy(trial.storage) new_storage.storage_filesystem = self._storage.storage_filesystem new_storage.storage_fs_path = self._storage.storage_fs_path new_storage.experiment_dir_name = self._storage.experiment_dir_name # ATTN: `trial.set_storage` is used intentionally, since it # also updates the absolute paths and filesystem of tracked checkpoints. trial.set_storage(new_storage) # Avoid creating logdir in client mode for returned trial results, # since the dir might not be creatable locally. # TODO(ekl) this is kind of a hack. if not ray.util.client.ray.is_connected(): trial.init_local_path() # Create logdir if it does not exist trials.append(trial) # NOTE: The restored run should reuse the same driver staging directory. self._storage._timestamp = trials[0].storage._timestamp return trials def resume(self, resume_config: ResumeConfig): """Resumes all checkpointed trials from previous run. Requires user to manually re-register their objects. Also stops all ongoing trials. """ # 1. Restore TuneController state # Find newest state file newest_state_path = _find_newest_experiment_checkpoint( self._storage.experiment_fs_path, fs=self._storage.storage_filesystem ) if newest_state_path is None: raise ValueError( f"Tried to resume experiment from directory " f"'{self._storage.experiment_fs_path}', but no " f"experiment state file of the form '{TuneController.CKPT_FILE_TMPL}' " "was found. This is expected if you are launching a new experiment." ) logger.info( "Restoring the run from the latest experiment state file: " f"{Path(newest_state_path).name}" ) with self._storage.storage_filesystem.open_input_stream(newest_state_path) as f: experiment_state = json.loads(f.readall(), cls=TuneFunctionDecoder) self.__setstate__(experiment_state["runner_data"]) # 2. Get the trial states that the run left off at. trials = self._restore_trials(experiment_state) # 3. Restore search algorithm and callback state # Download the search algorithm and callback state to the driver staging dir. self._checkpoint_manager.sync_down_experiment_state() driver_staging_dir = self._storage.experiment_driver_staging_path if self._search_alg.has_checkpoint(driver_staging_dir): self._search_alg.restore_from_dir(driver_staging_dir) if self._callbacks.can_restore(driver_staging_dir): self._callbacks.restore_from_dir(driver_staging_dir) # 4. Re-queue trials as needed, depending on their status. self._requeue_restored_trials(trials, resume_config) def update_max_pending_trials(self, max_pending_trials: Optional[int] = None): self._max_pending_trials = max_pending_trials or _get_max_pending_trials( self._search_alg ) def update_pending_trial_resources( self, resources: Union[dict, PlacementGroupFactory] ): """Update trial resources when resuming from checkpoint. Only updating the pending ones. """ assert resources if isinstance(resources, dict) and "gpu" not in resources: resources["gpu"] = 0 for trial in self._trials: if trial.status == Trial.PENDING: trial.update_resources(resources=resources) def is_finished(self): """Returns whether all trials have finished running.""" # The checks here are partly redundant but optimized for quick # evaluation. Specifically, if there are live trials, we check # these live trials first. Only if none of the live trials is # live anymore do we loop over all trials for a final check. trials_done = ( len(self._live_trials) == 0 or all(trial.is_finished() for trial in self._live_trials) ) and all(trial.is_finished() for trial in self._trials) return trials_done and self._search_alg.is_finished() def get_trial(self, tid): trial = [t for t in self._trials if t.trial_id == tid] return trial[0] if trial else None def get_trials(self): """Returns the list of trials managed by this TrialRunner. Note that the caller usually should not mutate trial state directly. """ return self._trials def get_live_trials(self): """Returns the set of trials that are not in Trial.TERMINATED state.""" return self._live_trials def add_trial(self, trial: Trial): """Adds a new trial to this TrialRunner. Trials may be added at any time. Args: trial: Trial to queue. """ # If the config map has had all the references replaced with placeholders, # resolve them before adding the trial. if self._placeholder_resolvers: trial.resolve_config_placeholders(self._placeholder_resolvers) # With trial.config resolved, create placement group factory if needed. trial.create_placement_group_factory() self._trials.append(trial) if trial.status != Trial.TERMINATED: self._live_trials.add(trial) with warn_if_slow("scheduler.on_trial_add"): self._scheduler_alg.on_trial_add(self._wrapped(), trial) self._mark_trial_to_checkpoint(trial) logger.debug(f"Adding trial {trial} with status {trial.status}") status_str_map = { Trial.PENDING: self._pending_trials, Trial.RUNNING: self._running_trials, Trial.PAUSED: self._paused_trials, Trial.TERMINATED: self._stopped_trials, Trial.ERROR: self._failed_trials, } status_str_map[trial.status].add(trial) if trial.status == Trial.PENDING: self._pending_trials_list.append(trial) self._resources_to_pending_trials[trial.placement_group_factory].add(trial) def _update_trial_queue(self, blocking: bool = False, timeout: int = 600) -> bool: """Adds next trials to queue if possible. Note that the timeout is currently unexposed to the user. Args: blocking: Blocks until either a trial is available or is_finished (timeout or search algorithm finishes). timeout: Seconds before blocking times out. Returns: Boolean indicating if a new trial was created or not. """ trial = self._search_alg.next_trial() if blocking and not trial: start = time.time() # Checking `is_finished` instead of _search_alg.is_finished # is fine because blocking only occurs if all trials are # finished and search_algorithm is not yet finished while ( not trial and not self.is_finished() and time.time() - start < timeout ): logger.debug("Blocking for next trial...") trial = self._search_alg.next_trial() time.sleep(1) if trial: self.add_trial(trial) return True return False def _used_resources_string(self) -> str: allocated_resources = self._actor_manager.get_live_actors_resources() return self._resource_updater.debug_string(allocated_resources) def on_step_begin(self): self._resource_updater.update_avail_resources() def on_step_end(self): self._cleanup_cached_actors(force_all=False) self._cleanup_stopping_actors(force_all=False) def _cleanup_cached_actors(self, force_all: bool = False): if ( self._search_alg.is_finished() and not self._staged_trials and self._actor_cache.total_max_objects == 0 ): # If there are no more trials coming in, no trials are pending execution, # and we don't explicitly want to cache objects, we can evict the full # cache. force_all = True for tracked_actor in self._actor_cache.flush_cached_objects( force_all=force_all ): logger.debug(f"Cleaning up cached actor: {tracked_actor}") # Unset termination callbacks as no trial is associated tracked_actor.set_on_stop(None) tracked_actor.set_on_error(None) self._remove_actor(tracked_actor=tracked_actor) def _cleanup_stopping_actors(self, force_all: bool = False): now = time.monotonic() if ( not force_all and now - self._earliest_stopping_actor <= self._actor_cleanup_timeout ): # If the earliest actor to timeout has not reached the timeout, return return # This is a bit costly, so we want to avoid running it too often times = deque( sorted( [ (timestamp, tracked_actor) for tracked_actor, timestamp in self._stopping_actors.items() ], key=lambda item: item[0], ) ) while times and ( force_all or time.monotonic() - times[0][0] > self._actor_cleanup_timeout ): if ( time.monotonic() - times[0][0] < self._actor_force_cleanup_timeout ) and self._actor_manager.is_actor_started(tracked_actor=times[0][1]): # Even if force_all=True, we give the actors time to clean up self._actor_manager.next(timeout=1) continue _, tracked_actor = times.popleft() if tracked_actor not in self._stopping_actors: # Actor stopping has been handled by the block above continue if self._actor_manager.is_actor_started(tracked_actor=tracked_actor): logger.debug(f"Forcefully killing actor: {tracked_actor}") self._actor_manager.remove_actor(tracked_actor=tracked_actor, kill=True) self._stopping_actors.pop(tracked_actor) if times: self._earliest_stopping_actor = times[0][0] else: self._earliest_stopping_actor = float("inf") def step(self): if self.is_finished(): raise TuneError("Called step when all trials finished?") with warn_if_slow("on_step_begin"): self.on_step_begin() with warn_if_slow("callbacks.on_step_begin"): self._callbacks.on_step_begin( iteration=self._iteration, trials=self._trials ) # Ask searcher for more trials self._maybe_update_trial_queue() # Start actors for added trials self._maybe_add_actors() # Handle one event if not self._actor_manager.next(timeout=0.1): # If there are no actors running, warn about potentially # insufficient resources if not self._actor_manager.num_live_actors: self._insufficient_resources_manager.on_no_available_trials( self.get_trials() ) # Maybe stop whole experiment self._stop_experiment_if_needed() # Maybe save experiment state try: self.checkpoint() except Exception as e: logger.warning(f"Trial controller checkpointing failed: {str(e)}") raise e self._iteration += 1 with warn_if_slow("on_step_end"): self.on_step_end() with warn_if_slow("callbacks.on_step_end"): self._callbacks.on_step_end(iteration=self._iteration, trials=self._trials) def _set_trial_status(self, trial: Trial, status: str): """Set trial to a specific status. This will keep track of trials with specific statuses in sets. For PENDING and PAUSED trials we also keep a list of trials to be able to retain FIFO ordering. See ``_maybe_add_actors`` for details. Lastly we also keep a mapping from resources to pending/paused trials to be able to efficiently start trials for cached actors. """ current_status = trial.status if current_status == status: logger.debug(f"Trial {trial} already has status {status}. Skipping update.") return status_str_map = { Trial.PENDING: self._pending_trials, Trial.RUNNING: self._running_trials, Trial.PAUSED: self._paused_trials, Trial.TERMINATED: self._stopped_trials, Trial.ERROR: self._failed_trials, } logger.debug( f"Setting status for trial {trial} from {current_status} to {status}" ) assert trial in status_str_map[current_status], (trial, current_status) assert trial not in status_str_map[status], (trial, status) status_str_map[current_status].remove(trial) status_str_map[status].add(trial) # We keep a log for pending trials for FIFO scheduling. # We do not need to remove from this list as we will just discard # items that are in this list but not in the respective set. if status == Trial.PENDING: self._pending_trials_list.append(trial) self._resources_to_pending_trials[trial.placement_group_factory].add(trial) else: self._resources_to_pending_trials[trial.placement_group_factory].discard( trial ) trial.set_status(status) def _get_trial_checkpoints(self) -> Dict[str, str]: for trial in self._trials_to_cache: self._trial_metadata[trial.trial_id] = trial.get_json_state() self._trials_to_cache.clear() return self._trial_metadata def _mark_trial_to_checkpoint(self, trial: Trial): self._trials_to_cache.add(trial) ### # UPDATE TRIALS def _maybe_update_trial_queue(self): """Ask the searcher for more trials.""" if self._search_alg.is_finished(): return dont_wait_for_trial = ( self._pending_trials or self._running_trials or self._paused_trials ) while len(self._pending_trials) < self._max_pending_trials: if not self._update_trial_queue(blocking=not dont_wait_for_trial): break dont_wait_for_trial = True def _cleanup_trials(self): logger.debug("CLEANING UP all trials") for tracked_actor in list(self._actor_to_trial): trial = self._actor_to_trial[tracked_actor] logger.debug( f"Scheduling trial stop at end of experiment (trial {trial}): " f"{tracked_actor}" ) self._schedule_trial_stop(trial) # Clean up cached actors now self._cleanup_cached_actors(force_all=True) start = time.monotonic() while time.monotonic() - start < 5 and self._actor_manager.num_total_actors: if _dedup_logs("actor_manager_cleanup", str(start)): logger.debug( "Waiting for actor manager to clean up final state [dedup]" ) self._actor_manager.next(timeout=1) logger.debug("Force cleanup of remaining actors") self._cleanup_stopping_actors(force_all=True) self._actor_manager.cleanup() def _remove_actor(self, tracked_actor: TrackedActor): stop_future = self._actor_manager.schedule_actor_task( tracked_actor, "stop", _return_future=True ) now = time.monotonic() if self._actor_manager.remove_actor( tracked_actor, kill=False, stop_future=stop_future ): # If the actor was previously alive, track self._stopping_actors[tracked_actor] = now self._earliest_stopping_actor = min(self._earliest_stopping_actor, now) ### # ADD ACTORS def _maybe_add_actors(self) -> None: """Add actors for pending and paused trials. For actors that have not been staged, yet, we request an actor. For actors that have been staged, already, we try to reuse a cached actor. First, we handle the trial that the scheduler chooses to run. Then, we handle all trials that are pending. Lastly, we see if we have cached actors that we can assign to a pending or paused trial. This can be the case when a trial has not been staged, yet, for instance because the number of staging trials was too large. """ ### # 1: Start trial that the scheduler wants to run with warn_if_slow("choose_trial_to_run"): trial_to_run = self._scheduler_alg.choose_trial_to_run(self._wrapped()) if trial_to_run: if _dedup_logs("trial_to_run_chosen", trial_to_run.trial_id): logger.debug( f"Chose trial to run from scheduler: {trial_to_run} [dedup]" ) if ( trial_to_run not in self._staged_trials and trial_to_run not in self._trial_to_actor ): logger.debug(f"Staging trial to run: {trial_to_run}") self._set_trial_status(trial_to_run, Trial.PENDING) self._staged_trials.add(trial_to_run) self._actor_cache.increase_max(trial_to_run.placement_group_factory) # schedule_trial_actor also potentially uses cached actors self._schedule_trial_actor(trial_to_run) else: # Otherwise, only try to use the cached actor if _dedup_logs("trial_to_run_reuse", trial_to_run.trial_id): logger.debug( f"Trying to re-use actor for trial to run: {trial_to_run} " f"[dedup]" ) self._maybe_reuse_cached_actor(trial_to_run) ### # 2: Start trials that are PENDING def _maybe_add_actors(candidates: List[Trial]): new_candidates = [] while candidates: if self._actor_manager.num_pending_actors >= self._max_pending_trials: break trial = candidates.pop(0) # If the trial is part of the list, but not of the set, # we just ignore it. Removing it from the list on status # change is too expensive. if trial not in self._pending_trials: continue if trial in self._trial_to_actor: new_candidates.append(trial) continue if trial in self._staged_trials: self._maybe_reuse_cached_actor(trial) continue logger.debug(f"Scheduling actor for enqueued trial: {trial}") self._staged_trials.add(trial) self._actor_cache.increase_max(trial.placement_group_factory) self._schedule_trial_actor(trial) return new_candidates + candidates self._pending_trials_list = _maybe_add_actors(self._pending_trials_list) ### # 3: Start any trial that can be started with a cached actor if self._actor_cache.num_cached_objects: for resource in self._resources_to_pending_trials: if not self._resources_to_pending_trials[resource]: continue if not self._actor_cache.has_cached_object(resource): continue start_trial = self._resources_to_pending_trials[resource].pop() logger.debug( f"Trying to re-use actor for enqueued trial: {start_trial}" ) if not self._maybe_reuse_cached_actor(start_trial): self._resources_to_pending_trials[resource].add(start_trial) else: if start_trial not in self._staged_trials: self._staged_trials.add(start_trial) self._actor_cache.increase_max( start_trial.placement_group_factory ) def _maybe_reuse_cached_actor(self, trial: Trial) -> bool: """Maybe reuse a cached actor for a trial. If an actor has been scheduled for the trial already, this will remove the original actor. """ if trial in self._resetting_trials: return True resource_request = trial.placement_group_factory if not self._actor_cache.has_cached_object(resource_request): return False cached_actor = self._actor_cache.pop_cached_object(resource_request) logger.debug(f"Reusing ACTOR for trial {trial}: {cached_actor}") if trial in self._trial_to_actor: original_actor = self._trial_to_actor.pop(trial) self._actor_to_trial.pop(original_actor) logger.debug(f"Removing ORIGINAL ACTOR for trial {trial}: {original_actor}") self._remove_actor(tracked_actor=original_actor) self._trial_to_actor[trial] = cached_actor self._actor_to_trial[cached_actor] = trial # Todo: get rid of Trial.runner ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[ cached_actor ][0] trial.set_ray_actor(ray_actor) self._schedule_trial_reset(trial, trial.config, trial.experiment_tag) return True def _schedule_trial_actor(self, trial: Trial): """Schedule an actor for a trial. If a cached actor is available, use it. Otherwise, request a new actor. """ logger.debug(f"Trying to schedule new ACTOR for trial {trial}") assert trial.status == Trial.PENDING trial.init_local_path() # We checkpoint metadata here to try mitigating logdir duplication self._mark_trial_to_checkpoint(trial) if self._maybe_reuse_cached_actor(trial): return # Safeguard if trial in self._trial_to_actor: raise RuntimeError( f"Tried to request a new actor for trial {trial}, but an old " f"actor still exists. This can lead to leaked resources. The old " f"actor should be removed first. " f"This is an internal problem in Ray Tune. If you encounter this " f"error, please raise an issue on " f"https://github.com/ray-project/ray/issues" ) trainable_cls = trial.get_trainable_cls() if not trainable_cls: exception = _AbortTrialExecution( f"Invalid trainable: {trial.trainable_name}. If you passed " f"a string, make sure the trainable was registered before." ) trial.handle_error(exception) self._schedule_trial_stop(trial, exception=exception) return _actor_cls = self._class_cache.get(trainable_cls) trial.set_location(_Location()) trainable_kwargs = _get_trainable_kwargs(trial=trial) with _change_working_directory(trial): tracked_actor = self._actor_manager.add_actor( cls=_actor_cls, resource_request=trial.placement_group_factory, kwargs=trainable_kwargs, on_start=self._actor_started, on_stop=self._actor_stopped, on_error=self._actor_failed, ) self._trial_to_actor[trial] = tracked_actor self._actor_to_trial[tracked_actor] = trial logger.debug( f"Scheduled new ACTOR for trial {trial}: {tracked_actor}. " f"Resources: {trial.placement_group_factory}" ) def _unstage_trial_with_resources(self, trial: Trial): """Unstage trial, or one with the same resources as ``trial``.""" # Case 1: The trial we started was staged. Just remove it if trial in self._staged_trials: self._staged_trials.remove(trial) self._actor_cache.decrease_max(trial.placement_group_factory) return # Case 2: We staged a trial "A" with the same resources, but our trial "B" # was selected by the scheduler to run. The resource manager does not care # about "trials", it just cares about resources being available. Thus we # look for a staged trial with the same resource requirements and remove it resource_request = trial.placement_group_factory # Remove staged trial with same resource requirements candidate_trial = None for staged_trial in self._staged_trials: staged_resources = staged_trial.placement_group_factory if staged_resources == resource_request: candidate_trial = staged_trial break if candidate_trial: self._staged_trials.remove(candidate_trial) self._actor_cache.decrease_max(candidate_trial.placement_group_factory) return raise RuntimeError( "Started a trial with resources requested by a different trial, but " "this trial was lost. This is an error in Ray Tune's execution " "logic. Please raise a GitHub issue at " "https://github.com/ray-project/ray/issues" ) def _maybe_cache_trial_actor(self, trial: Trial) -> bool: """Cache trial actor for reuse, if needed. We will only cache as many actors as are needed to fulfill any pending resource requests for actors with the same resource requirements. E.g. if we have 6 running trials and 4 additional staged actors, we will only cache up to 4 of the running trial actors when they finish. One exception is the case when we have no cached actors, yet. In that case, we will always cache the actor in this method. Later, in `_cleanup_cached_actors`, we will check again if we need this cached actor. That method will keep the actor if we don't have any staged trials, because we don't know at that point if the next trial might require the same resources. But because there is no staged trial, it is safe to keep the actor around, as it won't occupy resources needed by another trial until it's staged. """ if not self._reuse_actors: return False if self._search_alg.is_finished() and not self._staged_trials: logger.debug( f"Not caching actor of trial {trial} as the search is over " f"and no more trials are staged." ) return False tracked_actor = self._trial_to_actor[trial] if ( not self._actor_manager.is_actor_started(tracked_actor) or self._actor_manager.is_actor_failed(tracked_actor) or tracked_actor not in self._started_actors ): logger.debug( f"Not caching actor of trial {trial} as it has not been started, yet: " f"{tracked_actor}" ) return False if not self._actor_cache.cache_object( trial.placement_group_factory, tracked_actor ): logger.debug( f"Could not cache actor of trial {trial} for " "reuse, as there are no pending trials " "requiring its resources." ) return False logger.debug(f"Caching actor of trial {trial} for re-use: {tracked_actor}") tracked_actor = self._trial_to_actor.pop(trial) self._actor_to_trial.pop(tracked_actor) trial.set_ray_actor(None) return True def _actor_started(self, tracked_actor: TrackedActor, log: str = "STARTED"): self._started_actors.add(tracked_actor) trial = self._actor_to_trial[tracked_actor] logger.debug(f"Actor {log} for trial {trial}: {tracked_actor}") self._unstage_trial_with_resources(trial) ray_actor = self._actor_manager._live_actors_to_ray_actors_resources[ tracked_actor ][0] trial.set_ray_actor(ray_actor) self._callbacks.on_trial_start( iteration=self._iteration, trials=self._trials, trial=trial ) self._set_trial_status(trial, Trial.RUNNING) self._mark_trial_to_checkpoint(trial) if not self._schedule_trial_restore(trial): self._schedule_trial_train(trial) def _actor_stopped(self, tracked_actor: TrackedActor): if tracked_actor in self._actor_to_trial: trial = self._actor_to_trial.pop(tracked_actor) logger.debug(f"Actor STOPPED for trial {trial}: {tracked_actor}") self._trial_to_actor.pop(trial) trial.set_ray_actor(None) logger.debug(f"Actor STOPPED: {tracked_actor}") self._stopping_actors.pop(tracked_actor, None) self._started_actors.discard(tracked_actor) def _actor_failed(self, tracked_actor: TrackedActor, exception: Exception): trial = self._actor_to_trial[tracked_actor] logger.debug( f"Actor FAILED for trial {trial}: {tracked_actor}. " f"Exception: {exception}" ) if trial in (self._pending_trials | self._paused_trials): # First, set to running (needed downstream in _process_trial_failure) self._set_trial_status(trial, Trial.RUNNING) logger.debug( f"Trial {trial} failed in its creation task. Unstaging " f"to allow it to be re-scheduled." ) self._unstage_trial_with_resources(trial) self._trial_task_failure(trial, exception=exception) self._actor_manager.clear_actor_task_futures(tracked_actor) # Clean up actor tracked_actor.set_on_stop(None) tracked_actor.set_on_error(None) self._actor_manager.remove_actor(tracked_actor, kill=False) # Trigger actor stopped callback self._actor_stopped(tracked_actor) def _schedule_trial_task( self, trial: Trial, method_name: str, args: Optional[Tuple] = None, kwargs: Optional[Dict] = None, on_result: Optional[Callable[[Trial, Any], None]] = None, on_error: Optional[Callable[[Trial, Exception], None]] = None, _return_future: bool = False, ) -> Optional[ray.ObjectRef]: """Schedule an actor task future for a trial. This is a wrapper around ``ActorManager.schedule_actor_task``. This method retrieves the tracked actor for a trial to kick off the task. It also wraps around the callbacks, retrieving the trial object given the tracked actor. """ tracked_actor = self._trial_to_actor[trial] _on_result = None _on_error = None args = args or tuple() kwargs = kwargs or {} if on_result: def _on_result(tracked_actor: TrackedActor, *args, **kwargs): assert trial == self._actor_to_trial[tracked_actor] logger.debug( f"Future {method_name.upper()} RESOLVED for trial {trial}: " f"{args}, {kwargs}" ) try: on_result(trial, *args, **kwargs) except Exception as e: logger.debug( f"Error handling {method_name.upper()} result " f"for trial {trial}: {e}" ) if e is TuneError or self._fail_fast == self.RAISE: raise e else: raise TuneError(traceback.format_exc()) if on_error: def _on_error(tracked_actor: TrackedActor, exception: Exception): # If the actor failed, it has already been cleaned up. if tracked_actor not in self._actor_to_trial: assert isinstance(exception, RayActorError), type(exception) else: assert trial == self._actor_to_trial[tracked_actor] logger.debug( f"Future {method_name.upper()} FAILED for trial {trial}: " f"{exception}" ) try: on_error(trial, exception) except Exception as e: logger.debug( f"Error handling {method_name.upper()} failure " f"for trial {trial}: {e}" ) if e is TuneError or self._fail_fast == self.RAISE: raise e else: raise TuneError(traceback.format_exc()) logger.debug(f"Future {method_name.upper()} SCHEDULED for trial {trial}") with _change_working_directory(trial): future = self._actor_manager.schedule_actor_task( tracked_actor=tracked_actor, method_name=method_name, args=args, kwargs=kwargs, on_result=_on_result, on_error=_on_error, _return_future=_return_future, ) if _return_future: return future def _queue_decision(self, trial, decision): # Get old decision, setting it to the current decision if it isn't set old_decision = self._queued_trial_decisions.setdefault(trial.trial_id, decision) # Stopping always takes precedence. If we decided to stop, just quit if old_decision is TrialScheduler.STOP: return # The old decision wasn't STOP. We update the decision only if it is # STOP or PAUSE. The action will only be CONTINUE if it was set by # the first received result and was never updated after that. if decision is TrialScheduler.STOP or decision is TrialScheduler.PAUSE: self._queued_trial_decisions[trial.trial_id] = decision def _execute_action(self, trial: Trial, decision: str, after_save: bool = False): """Executes action based on decision. Args: trial: Trial to act on. decision: Scheduling decision to undertake. """ if decision == TrialScheduler.CONTINUE: self._schedule_trial_train(trial) elif decision == TrialScheduler.PAUSE: self.pause_trial(trial, should_checkpoint=not after_save) elif decision == TrialScheduler.STOP: self.stop_trial(trial) elif decision == TrialScheduler.NOOP: pass else: raise ValueError("Invalid decision: {}".format(decision)) def _maybe_execute_queued_decision(self, trial: Trial, after_save: bool = False): # `self._queued_trial_decisions` now contains a final decision # based on all results final_decision = self._queued_trial_decisions.pop(trial.trial_id, None) if final_decision: logger.debug( f"Executing final queued decision for {trial}: {final_decision}" ) self._execute_action(trial, final_decision, after_save=after_save) def _stop_experiment_if_needed(self): """Stops all trials.""" fail_fast = self._fail_fast and self._has_errored if self._stopper.stop_all() or fail_fast or self._should_stop_experiment: self._search_alg.set_finished() [ self._schedule_trial_stop(t) for t in self._trials if t.status not in {Trial.ERROR, Trial.TERMINATED} ] ### # Failure def _trial_task_failure(self, trial: Trial, exception: Exception): if self._fail_fast == self.RAISE: raise exception else: if self._print_trial_errors: logger.error(f"Trial task failed for trial {trial}", exc_info=exception) self._process_trial_failure(trial, exception=exception) def _process_trial_failure( self, trial: Trial, exception: Union[TuneError, RayTaskError, RayActorError], ): """Handle trial failure. Attempt trial recovery if possible, clean up state otherwise. Args: trial: Failed trial. exception: Exception prior to invoking this method. """ self._has_errored = True trial.handle_error(exception) if trial.status == Trial.RUNNING and trial.should_recover(): self._try_recover(trial, exc=exception) self._callbacks.on_trial_recover( iteration=self._iteration, trials=self._trials, trial=trial ) elif trial.status in {Trial.RUNNING, Trial.PENDING}: self._scheduler_alg.on_trial_error(self, trial) self._search_alg.on_trial_complete(trial.trial_id, error=True) self._schedule_trial_stop(trial, exception=exception) self._callbacks.on_trial_error( iteration=self._iteration, trials=self._trials, trial=trial ) def _schedule_trial_stop(self, trial: Trial, exception: Optional[Exception] = None): if trial.status == Trial.ERROR: logger.debug(f"Not requesting trial STOP as it is ERROR already: {trial}") return logger.debug(f"Requesting to STOP actor for trial {trial}") if trial.is_saving: logger.debug( f"Trial {trial} is currently saving/pausing. Scheduling STOP after " f"save resolved." ) self._cached_trial_decisions[trial.trial_id] = TrialScheduler.STOP trial.temporary_state.saving_to = None trial.temporary_state.restoring_from = None self._set_trial_status(trial, Trial.ERROR if exception else Trial.TERMINATED) trial.set_location(_Location()) if trial not in self._trial_to_actor: logger.debug(f"Will not STOP trial actor as it is not live: {trial}") return tracked_actor = self._trial_to_actor[trial] self._actor_manager.clear_actor_task_futures(tracked_actor=tracked_actor) self._mark_trial_to_checkpoint(trial) if not exception and self._maybe_cache_trial_actor(trial): # Trial runner has been cached return logger.debug(f"Terminating actor for trial {trial}: {tracked_actor}") tracked_actor = self._trial_to_actor.pop(trial) self._actor_to_trial.pop(tracked_actor) trial.set_ray_actor(None) self._remove_actor(tracked_actor=tracked_actor) def stop_trial(self, trial): """The canonical implementation of stopping a trial. Trials may be in any external status when this function is called. If trial is in state PENDING or PAUSED, calls `on_trial_remove` for scheduler and `on_trial_complete()` for search_alg. If trial is in state RUNNING, calls `on_trial_complete` for scheduler and search_alg if RUNNING. Caller to ensure that there is no outstanding future to be handled for the trial. If there is, the future would be discarded. """ try: if trial.status in [Trial.ERROR, Trial.TERMINATED]: return elif trial.status in [Trial.PENDING, Trial.PAUSED]: self._scheduler_alg.on_trial_remove(self, trial) self._search_alg.on_trial_complete(trial.trial_id) elif trial.status is Trial.RUNNING: # By this time trial.last_result should have been # updated already. self._scheduler_alg.on_trial_complete( self, trial, flatten_dict(trial.last_result) ) self._search_alg.on_trial_complete( trial.trial_id, result=flatten_dict(trial.last_result) ) self._callbacks.on_trial_complete( iteration=self._iteration, trials=self._trials, trial=trial ) self._schedule_graceful_trial_stop(trial) self._live_trials.discard(trial) except Exception as e: logger.exception("Trial %s: Error stopping trial.", trial) if self._fail_fast == self.RAISE: raise if isinstance(e, TuneError): self._process_trial_failure(trial, exception=e) else: self._process_trial_failure( trial, _TuneStopTrialError(traceback.format_exc()) ) def _schedule_graceful_trial_stop(self, trial: Trial): self._schedule_trial_export(trial) if trial.status != "ERROR": self._schedule_trial_stop(trial) def _schedule_trial_pause(self, trial: Trial, should_checkpoint: bool = True): if trial not in self._trial_to_actor: logger.debug( f"Trial PAUSE requested for trial {trial} but trial is already " f"stopping. Ignoring." ) return if should_checkpoint: self._cached_trial_decisions[trial.trial_id] = TrialScheduler.PAUSE self._schedule_trial_save(trial=trial) else: self._schedule_trial_stop(trial) self._set_trial_status(trial, Trial.PAUSED) ### # TRAIN def _schedule_trial_train(self, trial: Trial): args = () method_name = "train" buffer_length, buffer_time_s = self._maybe_buffer_training(trial) if buffer_length > 1: method_name = "train_buffered" args = (buffer_length, buffer_time_s) logger.debug(f"Scheduling future {method_name.upper()} for trial {trial}") self._schedule_trial_task( trial=trial, method_name=method_name, args=args, on_result=self._on_training_result, on_error=self._trial_task_failure, ) def _maybe_buffer_training(self, trial: Trial) -> Tuple[int, float]: buffer_time_s = max( self._buffer_min_time_s, min(self._buffer_max_time_s, self._actor_manager.num_actor_tasks // 10), ) buffer_length = self._buffer_length if buffer_length > 1 and trial.checkpoint_at_end: # If a trial checkpoint can be triggered externally, # it is not safe to buffer results. if log_once("trial_executor_buffer_checkpoint"): logger.warning( "Disabling buffered training as you passed " "`checkpoint_at_end` to `tune.CheckpointConfig()`." ) return 1, buffer_time_s if buffer_length > 1 and trial.checkpoint_freq > 0: return min(buffer_length, trial.checkpoint_freq), buffer_time_s return buffer_length, buffer_time_s ### # RESULT def _on_training_result(self, trial, result): if not isinstance(result, list): result = [result] with warn_if_slow("process_trial_result"): self._process_trial_results(trial, result) self._maybe_execute_queued_decision(trial, after_save=False) def _process_trial_results(self, trial, results): logger.debug(f"Processing trial results for trial {trial}: {results}") with warn_if_slow( "process_trial_results", message="Processing trial results took {duration:.3f} s, " "which may be a performance bottleneck. Please consider " "reporting results less frequently to Ray Tune.", ): for i, result in enumerate(results): with warn_if_slow("process_trial_result"): decision = self._process_trial_result(trial, result) if decision is None: # If we didn't get a decision, this means a # non-training future (e.g. a save) was scheduled. # We do not allow processing more results then. if i < len(results) - 1: if log_once("tune_controller_buffer_checkpoint"): logger.warning( f"Trial {trial} has a non-training future " f"scheduled but {len(results) - i} results " f"left to process. This means that a " f"checkpoint was requested, but buffered " f"training was continued before it was " f"saved. Consider using non-buffered " f"training by setting the env variable " f"`TUNE_RESULT_BUFFER_LENGTH=1`." ) elif decision == TrialScheduler.STOP: # If the decision is to stop the trial, # ignore all results that came after that. break def _process_trial_result(self, trial: Trial, result: dict[str, Any]): result.update(trial_id=trial.trial_id) is_duplicate = RESULT_DUPLICATE in result force_checkpoint = False # TrialScheduler and SearchAlgorithm still receive a # notification because there may be special handling for # the `on_trial_complete` hook. if is_duplicate: logger.debug("Trial finished without logging 'done'.") result = trial.last_result result.update(done=True) self._total_time += result.get(TIME_THIS_ITER_S, 0) flat_result = flatten_dict(result) self._validate_result_metrics(flat_result) if self._stopper(trial.trial_id, result) or trial.should_stop(flat_result): decision = TrialScheduler.STOP else: with warn_if_slow("scheduler.on_trial_result"): decision = self._scheduler_alg.on_trial_result( self._wrapped(), trial, flat_result ) if decision == TrialScheduler.STOP: result.update(done=True) else: # Only updating search alg if the trial is not to be stopped. with warn_if_slow("search_alg.on_trial_result"): self._search_alg.on_trial_result(trial.trial_id, flat_result) # If this is not a duplicate result, the callbacks should # be informed about the result. if not is_duplicate: with warn_if_slow("callbacks.on_trial_result"): self._callbacks.on_trial_result( iteration=self._iteration, trials=self._trials, trial=trial, # NOTE: Allow user callbacks to modify the Trial result in place. result=result, ) force_checkpoint = result.get(SHOULD_CHECKPOINT, False) trial.update_last_result(result) # Include in next experiment checkpoint self._mark_trial_to_checkpoint(trial) # Checkpoints to disk. This should be checked even if # the scheduler decision is STOP or PAUSE. Note that # PAUSE only checkpoints to memory and does not update # the global checkpoint state. if decision != TrialScheduler.PAUSE: # TODO(justinvyu): This is a temporary hack to fix pausing trials. # We already schedule a save task in `pause_trial`, so no need # to do it again here. self._checkpoint_trial_if_needed(trial, force=force_checkpoint) if trial.is_saving: logger.debug(f"Caching trial decision for trial {trial}: {decision}") # Cache decision to execute on after the save is processed. # This prevents changing the trial's state or kicking off # another training step prematurely. if not self._cached_trial_decisions.get(trial.trial_id) or decision in { TrialScheduler.PAUSE, TrialScheduler.STOP, }: # If already set, only overwrite if it's a PAUSE or STOP. This is # to avoid that CONTINUE decisions from a training step that resolve # late overwrite PAUSE/STOP decision. self._cached_trial_decisions[trial.trial_id] = decision return None else: self._queue_decision(trial, decision) return decision def _validate_result_metrics(self, result): """ Check if any of the required metrics was not reported in the last result. If the only items are ``done`` or any of DEBUG_METRICS, this means that no result was ever received and the trial just returned. This is also okay and will not raise an error. This will ignore checking for the DEFAULT_METRIC. """ if int(os.environ.get("TUNE_DISABLE_STRICT_METRIC_CHECKING", 0)) != 1 and ( len({k for k in result if k not in list(DEBUG_METRICS) + [DONE]}) > 1 ): base_metric = self._metric if self._metric != DEFAULT_METRIC else None scheduler_metric = ( self._scheduler_alg.metric if self._scheduler_alg.metric != DEFAULT_METRIC else None ) search_metrics = ( self._search_alg.metric if self._search_alg.metric != DEFAULT_METRIC else None ) if isinstance(search_metrics, str): search_metrics = [search_metrics] if base_metric and base_metric not in result: report_metric = base_metric location = "tune.TuneConfig()" elif scheduler_metric and scheduler_metric not in result: report_metric = scheduler_metric location = type(self._scheduler_alg).__name__ elif search_metrics and any( search_metric not in result for search_metric in search_metrics ): report_metric = list( filter( lambda search_metric: search_metric not in result, search_metrics, ) ) if len(report_metric) == 1: report_metric = report_metric[0] location = type(self._search_alg).__name__ else: report_metric = None location = None if report_metric: raise ValueError( "Trial returned a result which did not include the " "specified metric(s) `{}` that `{}` expects. " "Make sure your calls to `tune.report()` include the " "metric, or set the " "TUNE_DISABLE_STRICT_METRIC_CHECKING " "environment variable to 1. Result: {}".format( report_metric, location, result ) ) ### # SAVE def _schedule_trial_save( self, trial: Trial, result: Optional[Dict] = None, ) -> Optional[_FutureTrainingResult]: if trial not in self._trial_to_actor: logger.debug( f"Trial SAVE requested for trial {trial} but trial is already " f"stopping. Ignoring." ) return None result = result or trial.last_result future = self._schedule_trial_task( trial=trial, method_name="save", on_result=self._on_saving_result, on_error=self._trial_task_failure, _return_future=True, ) # TODO(justinvyu): `trial.saving_to` (and trial.is_saving) is needed # in order to prevent a done=True result from executing a STOP decision # (which clears all futures) before the save gets processed. # Keep this in for now while `train` and `save` are 2 separate steps. trial.temporary_state.saving_to = _FutureTrainingResult(future) # `trial.saving_to` holds a future training result -- this is only used # in the case of PBT to block until the checkpoint is ready. # In all other situations, the checkpoint future is processed by the # actor event manager when it is ready. return trial.temporary_state.saving_to def _on_saving_result(self, trial, checkpoint_value: _TrainingResult): with warn_if_slow("process_trial_save"): self._process_trial_save(trial, checkpoint_value) with warn_if_slow("callbacks.on_trial_save"): self._callbacks.on_trial_save( iteration=self._iteration, trials=self._trials, trial=trial ) self._maybe_execute_queued_decision(trial, after_save=True) def _process_trial_save(self, trial: Trial, checkpoint_value: _TrainingResult): """Processes a trial save. Acts on the decision cached during the last `_process_trial` call. Args: trial: Trial being saved. """ logger.debug("Trial %s: Processing trial save.", trial) try: if not checkpoint_value.checkpoint: logger.debug(f"Got empty checkpoint for trial {trial}") else: try: self._callbacks.on_checkpoint( iteration=self._iteration, trials=self._trials, trial=trial, checkpoint=checkpoint_value.checkpoint, ) except Exception: logger.warning( "Error encountered during processing of callbacks. " "Ray Train/Tune recently changed the checkpoint interface " "that is passed to callbacks. If you implemented your own " "callback with an `on_checkpoint` handler, please review " "the checkpoint interface and adjust your code " "accordingly." ) raise trial.on_checkpoint(checkpoint_value) self._checkpoint_manager.on_trial_checkpoint(trial) self._mark_trial_to_checkpoint(trial) except Exception: logger.exception( "Trial %s: Error handling checkpoint %s", trial, checkpoint_value ) trial.temporary_state.saving_to = None decision = self._cached_trial_decisions.pop(trial.trial_id, None) if decision and checkpoint_value: self._queue_decision(trial, decision) def _checkpoint_trial_if_needed(self, trial, force=False): """Checkpoints trial based off trial.last_result.""" if trial.should_checkpoint() or force: # Save trial runtime if possible. if trial.temporary_state.ray_actor: self._schedule_trial_save(trial) ### # RESTORE def _schedule_trial_restore(self, trial: Trial) -> bool: checkpoint_result = trial.latest_checkpoint_result if not checkpoint_result: logger.debug(f"Not restoring trial {trial}: No checkpoint found.") return False # TODO(justinvyu): Is this really needed? trial.temporary_state.restoring_from = checkpoint_result method_name = "restore" args = (checkpoint_result,) self._schedule_trial_task( trial=trial, method_name=method_name, args=args, kwargs={}, on_result=self._on_restoring_result, on_error=self._trial_task_failure, ) return True def _on_restoring_result(self, trial: Trial, result: Any): self._process_trial_restore(trial) def _process_trial_restore(self, trial: Trial): """Processes a trial restore. Args: trial: Trial being restored. """ logger.debug("Trial %s: Processing trial restore.", trial) trial.on_restore() logger.debug("Trial %s: Restore processed successfully", trial) self._set_trial_status(trial, Trial.RUNNING) self._schedule_trial_train(trial) self._live_trials.add(trial) def _try_recover( self, trial: Trial, exc: Union[TuneError, RayTaskError, RayActorError] ): """Tries to recover trial. Notifies SearchAlgorithm and Scheduler if failure to recover. Args: trial: Trial to recover. exc: Exception prior to invoking this method. """ self._cached_trial_decisions.pop(trial.trial_id, None) # Resetting this, in case that the trial is in saving status when it crashes. if trial.is_saving: trial.temporary_state.saving_to = None self._schedule_trial_stop(trial, exception=exc) logger.debug("Trial %s: Notifying Scheduler and requeueing.", trial) self._requeue_trial(trial) def _requeue_trial(self, trial): """Notification to TrialScheduler and requeue trial. This does not notify the SearchAlgorithm because the function evaluation is still in progress. """ self._scheduler_alg.on_trial_error(self, trial) self._set_trial_status(trial, status=Trial.PENDING) # TODO(rliaw): Right now, this pushes the trial to the end of queue # because restoration can be expensive. However, this is not # ideal since it just hides the issue - a better fix would # be to use an actor table to detect the IP of the Trainable # and rsync the files there. # See https://github.com/ray-project/ray/issues/5168 self._trials.pop(self._trials.index(trial)) self._trials.append(trial) self._live_trials.add(trial) with warn_if_slow("scheduler.on_trial_add"): self._scheduler_alg.on_trial_add(self._wrapped(), trial) ### # EXPORT def _schedule_trial_export(self, trial: Trial): if not trial.export_formats or len(trial.export_formats) <= 0: return # Todo: We are waiting here synchronously until the task resolved. # Instead, we should schedule the trial stop after the export resolved. # This requires changes in TrialRunner, which we can remove once the # legacy execution path has been removed. future = self._schedule_trial_task( trial=trial, method_name="export_model", args=(trial.export_formats,), on_result=None, on_error=self._trial_task_failure, _return_future=True, ) self._actor_manager._actor_task_events.resolve_future(future) ### # RESET def _schedule_trial_reset( self, trial: Trial, new_config: Dict, new_experiment_tag: str, ): trial.set_experiment_tag(new_experiment_tag) trial.set_config(new_config) # Pass magic variables extra_config = copy.deepcopy(new_config) extra_config[TRIAL_INFO] = _TrialInfo(trial) stdout_file, stderr_file = trial.log_to_file extra_config[STDOUT_FILE] = stdout_file extra_config[STDERR_FILE] = stderr_file logger_creator = partial( _noop_logger_creator, logdir=trial.storage.trial_working_directory ) self._resetting_trials.add(trial) self._schedule_trial_task( trial=trial, method_name="reset", args=(extra_config,), kwargs={ "logger_creator": logger_creator, "storage": trial.storage, }, on_result=self._on_trial_reset, on_error=self._trial_task_failure, ) def _on_trial_reset(self, trial: Trial, success: bool): self._resetting_trials.remove(trial) if not success: info = ( "Trainable runner reuse requires reset_config() to be " "implemented and return True." ) logger.error(f"Could not re-use actor for trial {trial}: {info}") exception = _AbortTrialExecution(info) trial.handle_error(exception) self._schedule_trial_stop(trial, exception=exception) return tracked_actor = self._trial_to_actor[trial] self._actor_started(tracked_actor, log="REUSED") def request_stop_trial(self, trial): self._stop_queue.append(trial) def request_stop_experiment(self): self._should_stop_experiment = True def _process_stop_requests(self): while self._stop_queue: t = self._stop_queue.pop() self.stop_trial(t) def pause_trial(self, trial: Trial, should_checkpoint: bool = True): """Pause a trial and reset the necessary state variables for resuming later. Args: trial: Trial to pause. should_checkpoint: Whether or not an in-memory checkpoint should be created for this paused trial. Defaults to True. """ # NOTE: The cached trial decision is not needed since we will overrule this # decision with PAUSE. self._cached_trial_decisions.pop(trial.trial_id, None) self._schedule_trial_pause(trial, should_checkpoint=should_checkpoint) def cleanup(self): """Cleanup trials and callbacks.""" self._cleanup_trials() self.end_experiment_callbacks() def __getstate__(self): """Gets state for trial. Note that this is not used as a pickling override as does not have all fields. """ state = self.__dict__.copy() for k in [ "_trials", "_live_trials", "_stop_queue", "_search_alg", "_placeholder_resolvers", "_scheduler_alg", "_pending_trial_queue_times", "_callbacks", "_checkpoint_manager", "_storage", "_insufficient_resources_manager", "_actor_manager", "_class_cache", "_resource_updater", "_trials_to_cache", "_trial_metadata", "_actor_to_trial", "_trial_to_actor", "_resources_to_pending_trials", "_pending_trials", "_pending_trials_list", "_running_trials", "_paused_trials", "_stopped_trials", "_failed_trials", "_resetting_trials", "_started_actors", "_stopping_actors", "_staged_trials", "_actor_cache", ]: del state[k] return state def __setstate__(self, state): # Use session_str from previous checkpoint if does not exist session_str = state.pop("_session_str") self.__dict__.setdefault("_session_str", session_str) # Use start_time from previous checkpoint if does not exist start_time = state.pop("_start_time") self.__dict__.setdefault("_start_time", start_time) self.__dict__.update(state) self._checkpoint_manager = self._create_checkpoint_manager() class _TrialExecutorWrapper: """Wraps around TrialExecutor class, intercepts API calls and warns users of restricted API access. This is meant to facilitate restricting the current API exposure of TrialExecutor by TrialScheduler. """ def __init__( self, trial_executor: "_FakeRayTrialExecutor", whitelist_attr: Optional[set] = None, ): self._trial_executor = trial_executor self._whitelist_attr = whitelist_attr or set() for attr in self._whitelist_attr: assert hasattr(self._trial_executor, attr) def __getattr__(self, attr): if attr not in self._whitelist_attr: if log_once("restrict_accessing_trial_executor"): logger.warning( f"You are trying to access {attr} interface of " f"TrialExecutor in TrialScheduler, which is being " f"restricted. If you believe it is reasonable for " f"your scheduler to access this TrialExecutor API, " f"please reach out to Ray team on GitHub. A more " f"strict API access pattern would be enforced " f"starting 1.12.0" ) return getattr(self._trial_executor, attr) @DeveloperAPI class TrialRunnerWrapper: """Wraps around TrialRunner class, intercepts API calls and warns users of restricted API access. This is meant to facilitate restricting the current API exposure of TrialRunner by TrialScheduler. """ _EXECUTOR_ATTR = "trial_executor" def __init__( self, tune_controller: TuneController, trial_executor: Any, runner_whitelist_attr: Optional[set] = None, executor_whitelist_attr: Optional[set] = None, ): self._tune_controller = tune_controller self._trial_executor = _TrialExecutorWrapper( trial_executor, executor_whitelist_attr ) self._runner_whitelist_attr = runner_whitelist_attr or set() for attr in self._runner_whitelist_attr: assert hasattr(self, attr) def __getattr__(self, attr): if attr == self._EXECUTOR_ATTR: return self._trial_executor if attr not in self._runner_whitelist_attr: if log_once("restrict_accessing_tune_controller"): logger.warning( f"You are trying to access {attr} interface of " f"TrialRunner in TrialScheduler, which is being " f"restricted. If you believe it is reasonable for " f"your scheduler to access this TrialRunner API, " f"please reach out to Ray team on GitHub. A more " f"strict API access pattern would be enforced " f"starting 1.12s.0" ) return getattr(self._tune_controller, attr) def _get_max_pending_trials(search_alg: SearchAlgorithm) -> int: max_pending_trials = os.getenv("TUNE_MAX_PENDING_TRIALS_PG", "auto") if max_pending_trials != "auto": return int(max_pending_trials) # Else, auto detect. # Only BasicVariantGenerator supports > 1 pending trials. # This is because we don't want to generate too many trials # before we fit the searcher model. if not isinstance(search_alg, BasicVariantGenerator): return 1 # Allow up to at least 200 pending trials to trigger fast autoscaling min_autoscaling_rate = 200 # Allow more pending trials for larger clusters (based on number of CPUs) cluster_cpus = ray.cluster_resources().get("CPU", 1.0) max_pending_trials = max(min_autoscaling_rate, int(cluster_cpus * 1.1)) if max_pending_trials > min_autoscaling_rate: logger.warning( f"The maximum number of pending trials has been " f"automatically set to the number of available " f"cluster CPUs, which is high " f"({max_pending_trials} CPUs/pending trials). " f"If you're running an experiment with a large number " f"of trials, this could lead to scheduling overhead. " f"In this case, consider setting the " f"`TUNE_MAX_PENDING_TRIALS_PG` environment variable " f"to the desired maximum number of concurrent pending trials." ) return max_pending_trials class _FakeRayTrialExecutor: """The TuneController does not use a RayTrialExecutor anymore. Instead, we pass this fake executor for searchers/schedulers to use as an interface. In the future, we should have the searchers/schedulers either interact with the tune controller, or define a different API for more fine-grained scheduler control. """ def __init__(self, tune_controller: TuneController): self._tune_controller = tune_controller def pause_trial(self, trial: Trial, should_checkpoint: bool = True): return self._tune_controller._schedule_trial_pause( trial, should_checkpoint=should_checkpoint ) def save( self, trial: Trial, result: Optional[Dict] = None, ) -> Optional[_FutureTrainingResult]: return self._tune_controller._schedule_trial_save(trial=trial, result=result) def has_resources_for_trial(self, trial: Trial): return True @property def _resource_updater(self): return self._tune_controller._resource_updater def force_reconcilation_on_next_step_end(self): pass