import copy
import datetime
import logging
import pprint as pp
import traceback
from functools import partial
from pathlib import Path
from pickle import PicklingError
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Sequence,
    Type,
    Union,
)

import ray
from ray.exceptions import RpcError
from ray.train._internal.storage import StorageContext
from ray.train.constants import DEFAULT_STORAGE_PATH
from ray.tune import CheckpointConfig, SyncConfig
from ray.tune.error import TuneError
from ray.tune.registry import is_function_trainable, register_trainable
from ray.tune.stopper import CombinedStopper, FunctionStopper, Stopper, TimeoutStopper
from ray.util.annotations import Deprecated, DeveloperAPI

if TYPE_CHECKING:
    import pyarrow.fs

    from ray.tune import PlacementGroupFactory
    from ray.tune.experiment import Trial

logger = logging.getLogger(__name__)


def _validate_log_to_file(log_to_file):
    """Validate ``tune.RunConfig``'s ``log_to_file`` parameter.

    Args:
        log_to_file: One of: a falsy value (no log files), ``True`` (use the
            file names ``"stdout"`` and ``"stderr"``), a string (one file name
            used for both streams), or a Sequence of length 2 (stdout and
            stderr file names, respectively).

    Returns:
        Tuple of validated relative stdout and stderr filenames. Both entries
        are ``None`` when ``log_to_file`` is falsy.

    Raises:
        ValueError: If a Sequence whose length is not 2, or a value of an
            unsupported type, is passed.
    """
    if not log_to_file:
        stdout_file = stderr_file = None
    elif isinstance(log_to_file, bool):
        # Falsy values were handled above, so this can only be `True`.
        stdout_file = "stdout"
        stderr_file = "stderr"
    elif isinstance(log_to_file, str):
        # Must be tested before `Sequence`: a `str` is itself a Sequence.
        stdout_file = stderr_file = log_to_file
    elif isinstance(log_to_file, Sequence):
        if len(log_to_file) != 2:
            raise ValueError(
                "If you pass a Sequence to `log_to_file` it has to have "
                "a length of 2 (for stdout and stderr, respectively). The "
                "Sequence you passed has length {}.".format(len(log_to_file))
            )
        stdout_file, stderr_file = log_to_file
    else:
        raise ValueError(
            "You can pass a boolean, a string, or a Sequence of length 2 to "
            "`log_to_file`, but you passed something else ({}).".format(
                type(log_to_file)
            )
        )
    return stdout_file, stderr_file


@DeveloperAPI
class Experiment:
    """Tracks experiment specifications.

    Implicitly registers the Trainable if needed. The args here take
    the same meaning as the arguments defined in `tune.py:run`.

    .. code-block:: python

        experiment_spec = Experiment(
            "my_experiment_name",
            my_func,
            stop={"mean_accuracy": 100},
            config={
                "alpha": tune.grid_search([0.2, 0.4, 0.6]),
                "beta": tune.grid_search([1, 2]),
            },
            resources_per_trial={
                "cpu": 1,
                "gpu": 0
            },
            num_samples=10,
            max_failures=2)

    """

    # Keys that will be present in `public_spec` dict.
    PUBLIC_KEYS = {"stop", "num_samples", "time_budget_s"}

    # Class used to build `self.storage`; kept as a class attribute so it can
    # be swapped out (e.g. by a subclass).
    _storage_context_cls = StorageContext

    def __init__(
        self,
        name: str,
        run: Union[str, Callable, Type],
        *,
        stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping], bool]]] = None,
        time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None,
        config: Optional[Dict[str, Any]] = None,
        resources_per_trial: Union[
            None, Mapping[str, Union[float, int, Mapping]], "PlacementGroupFactory"
        ] = None,
        num_samples: int = 1,
        storage_path: Optional[str] = None,
        storage_filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        sync_config: Optional[Union[SyncConfig, dict]] = None,
        checkpoint_config: Optional[Union[CheckpointConfig, dict]] = None,
        trial_name_creator: Optional[Callable[["Trial"], str]] = None,
        trial_dirname_creator: Optional[Callable[["Trial"], str]] = None,
        log_to_file: Union[bool, str, Sequence] = False,
        export_formats: Optional[Sequence] = None,
        max_failures: int = 0,
        restore: Optional[str] = None,
        # Deprecated
        local_dir: Optional[str] = None,
    ):
        """Build the experiment spec and register the trainable if needed.

        Args:
            name: Experiment name. If falsy, a name is derived from ``run``
                via ``StorageContext.get_experiment_dir_name``.
            run: Trainable to run: a registered name, a function, or a class.
            stop: Stopping criteria. A dict is stored in the spec under
                ``"stop"``; a list of ``Stopper`` objects, a single
                ``Stopper``, or a valid stopping function is turned into
                ``self.stopper`` instead.
            time_budget_s: If truthy, a ``TimeoutStopper`` is created and
                combined with any stopper derived from ``stop``.
            config: Search space / trainable config. Must be a dict.
            resources_per_trial: Stored in the spec unchanged.
            num_samples: Stored in the spec unchanged.
            storage_path: Passed to the storage context. Defaults to
                ``DEFAULT_STORAGE_PATH``.
            storage_filesystem: Passed to the storage context.
            sync_config: Passed to the storage context.
            checkpoint_config: ``CheckpointConfig`` or a dict of its keyword
                arguments. Defaults to ``CheckpointConfig()``.
            trial_name_creator: Stored in the spec unchanged.
            trial_dirname_creator: Stored in the spec unchanged.
            log_to_file: See ``_validate_log_to_file`` for accepted values.
            export_formats: Stored in the spec; ``None`` becomes ``[]``.
            max_failures: Stored in the spec unchanged.
            restore: Optional path; stored as an absolute, user-expanded
                POSIX path.
            local_dir: Deprecated.

        Raises:
            ValueError: On an invalid ``checkpoint_config`` for a function
                trainable, a non-dict ``config``, or an invalid ``stop`` or
                ``log_to_file`` value.
            TuneError: If registering the trainable exceeds the gRPC resource
                limit.
        """
        # NOTE(review): `local_dir` is accepted but never read in this
        # constructor -- confirm whether a deprecation error is expected here.
        if isinstance(checkpoint_config, dict):
            checkpoint_config = CheckpointConfig(**checkpoint_config)
        else:
            checkpoint_config = checkpoint_config or CheckpointConfig()

        if is_function_trainable(run):
            if checkpoint_config.checkpoint_at_end:
                raise ValueError(
                    "'checkpoint_at_end' cannot be used with a function trainable. "
                    "You should include one last call to "
                    "`ray.tune.report(metrics=..., checkpoint=...)` "
                    "at the end of your training loop to get this behavior."
                )
            if checkpoint_config.checkpoint_frequency:
                raise ValueError(
                    "'checkpoint_frequency' cannot be set for a function trainable. "
                    "You will need to report a checkpoint every "
                    "`checkpoint_frequency` iterations within your training loop using "
                    "`ray.tune.report(metrics=..., checkpoint=...)` "
                    "to get this behavior."
                )
        try:
            self._run_identifier = Experiment.register_if_needed(run)
        except RpcError as e:
            if e.rpc_code == ray._raylet.GRPC_STATUS_CODE_RESOURCE_EXHAUSTED:
                raise TuneError(
                    f"The Trainable/training function is too large for grpc resource "
                    f"limit. Check that its definition is not implicitly capturing a "
                    f"large array or other object in scope. "
                    f"Tip: use tune.with_parameters() to put large objects "
                    f"in the Ray object store. \n"
                    f"Original exception: {traceback.format_exc()}"
                ) from e
            else:
                # Any other RPC failure is propagated untouched.
                raise

        if not name:
            name = StorageContext.get_experiment_dir_name(run)

        storage_path = storage_path or DEFAULT_STORAGE_PATH
        self.storage = self._storage_context_cls(
            storage_path=storage_path,
            storage_filesystem=storage_filesystem,
            sync_config=sync_config,
            experiment_dir_name=name,
        )
        logger.debug(f"StorageContext on the DRIVER:\n{self.storage}")

        config = config or {}
        if not isinstance(config, dict):
            raise ValueError(
                f"`Experiment(config)` must be a dict, got: {type(config)}. "
                "Please convert your search space to a dict before passing it in."
            )

        # `stop` ends up either as a plain dict of criteria (kept in the spec)
        # or as a Stopper object (kept on `self._stopper`), never both.
        self._stopper = None
        stopping_criteria = {}
        if not stop:
            pass
        elif isinstance(stop, list):
            bad_stoppers = [s for s in stop if not isinstance(s, Stopper)]
            if bad_stoppers:
                stopper_types = [type(s) for s in stop]
                raise ValueError(
                    "If you pass a list as the `stop` argument to "
                    "`tune.RunConfig()`, each element must be an instance of "
                    f"`tune.stopper.Stopper`. Got {stopper_types}."
                )
            self._stopper = CombinedStopper(*stop)
        elif isinstance(stop, dict):
            stopping_criteria = stop
        elif callable(stop):
            if FunctionStopper.is_valid_function(stop):
                self._stopper = FunctionStopper(stop)
            elif isinstance(stop, Stopper):
                self._stopper = stop
            else:
                raise ValueError(
                    "Provided stop object must be either a dict, "
                    "a function, or a subclass of "
                    f"`ray.tune.Stopper`. Got {type(stop)}."
                )
        else:
            raise ValueError(
                f"Invalid stop criteria: {stop}. Must be a "
                f"callable or dict. Got {type(stop)}."
            )

        if time_budget_s:
            if self._stopper:
                self._stopper = CombinedStopper(
                    self._stopper, TimeoutStopper(time_budget_s)
                )
            else:
                self._stopper = TimeoutStopper(time_budget_s)

        stdout_file, stderr_file = _validate_log_to_file(log_to_file)

        spec = {
            "run": self._run_identifier,
            "stop": stopping_criteria,
            "time_budget_s": time_budget_s,
            "config": config,
            "resources_per_trial": resources_per_trial,
            "num_samples": num_samples,
            "checkpoint_config": checkpoint_config,
            "trial_name_creator": trial_name_creator,
            "trial_dirname_creator": trial_dirname_creator,
            "log_to_file": (stdout_file, stderr_file),
            "export_formats": export_formats or [],
            "max_failures": max_failures,
            "restore": (
                Path(restore).expanduser().absolute().as_posix() if restore else None
            ),
            "storage": self.storage,
        }
        self.spec = spec

    @classmethod
    def from_json(cls, name: str, spec: dict):
        """Generates an Experiment object from JSON.

        The passed ``spec`` is deep-copied up front and is not modified.

        Args:
            name: Name of Experiment.
            spec: JSON configuration of experiment. Must contain a ``"run"``
                key; the remaining keys are passed to the constructor as
                keyword arguments.

        Returns:
            The constructed Experiment.

        Raises:
            TuneError: If ``"run"`` is missing from ``spec`` or if the
                constructor raises a ``TypeError``.
        """
        if "run" not in spec:
            raise TuneError("No trainable specified!")

        # Work on a private copy so the normalization below never leaks into
        # the caller's dict (or into its nested `config` dict).
        spec = copy.deepcopy(spec)

        # Special case the `env` param for RLlib by automatically
        # moving it into the `config` section.
        if "env" in spec:
            spec["config"] = spec.get("config", {})
            spec["config"]["env"] = spec["env"]
            del spec["env"]

        if "sync_config" in spec and isinstance(spec["sync_config"], dict):
            spec["sync_config"] = SyncConfig(**spec["sync_config"])

        if "checkpoint_config" in spec and isinstance(spec["checkpoint_config"], dict):
            spec["checkpoint_config"] = CheckpointConfig(**spec["checkpoint_config"])

        run_value = spec.pop("run")
        try:
            exp = cls(name, run_value, **spec)
        except TypeError as e:
            raise TuneError(
                f"Failed to load the following Tune experiment "
                f"specification:\n\n {pp.pformat(spec)}.\n\n"
                f"Please check that the arguments are valid. "
                f"Experiment creation failed with the following "
                f"error:\n {e}"
            ) from e
        return exp

    @classmethod
    def get_trainable_name(cls, run_object: Union[str, Callable, Type]):
        """Get Trainable name.

        Args:
            run_object: Trainable to run. If string, assumes it is an
                ID and does not modify it. Otherwise, returns a string
                corresponding to the run_object name.

        Returns:
            A string representing the trainable identifier.

        Raises:
            TuneError: if ``run_object`` passed in is invalid.
        """
        from ray.tune.search.sample import Domain

        if isinstance(run_object, str) or isinstance(run_object, Domain):
            return run_object
        elif isinstance(run_object, type) or callable(run_object):
            name = "DEFAULT"
            if hasattr(run_object, "_name"):
                name = run_object._name
            elif hasattr(run_object, "__name__"):
                fn_name = run_object.__name__
                if fn_name == "<lambda>":
                    # Lambda expressions carry the literal name "<lambda>".
                    name = "lambda"
                elif fn_name.startswith("<"):
                    # Other synthetic names fall back to the default.
                    name = "DEFAULT"
                else:
                    name = fn_name
            elif (
                isinstance(run_object, partial)
                and hasattr(run_object, "func")
                and hasattr(run_object.func, "__name__")
            ):
                # `functools.partial` objects have no `__name__` of their own;
                # use the name of the wrapped callable.
                name = run_object.func.__name__
            else:
                logger.warning("No name detected on trainable. Using {}.".format(name))
            return name
        else:
            raise TuneError("Improper 'run' - not string nor trainable.")

    @classmethod
    def register_if_needed(cls, run_object: Union[str, Callable, Type]):
        """Registers Trainable or Function at runtime.

        Assumes already registered if run_object is a string.
        Also, does not inspect interface of given run_object.

        Args:
            run_object: Trainable to run. If string,
                assumes it is an ID and does not modify it. Otherwise,
                returns a string corresponding to the run_object name.

        Returns:
            A string representing the trainable identifier.

        Raises:
            TypeError, PicklingError: Re-raised from ``register_trainable``
                with troubleshooting hints appended to the message.
        """
        from ray.tune.search.sample import Domain

        if isinstance(run_object, str):
            return run_object
        elif isinstance(run_object, Domain):
            logger.warning("Not registering trainable. Resolving as variant.")
            return run_object
        name = cls.get_trainable_name(run_object)
        try:
            register_trainable(name, run_object)
        except (TypeError, PicklingError) as e:
            extra_msg = (
                "Other options: "
                "\n-Try reproducing the issue by calling "
                "`pickle.dumps(trainable)`. "
                "\n-If the error is typing-related, try removing "
                "the type annotations and try again."
            )
            # Same exception type, extended message; the original traceback
            # context is deliberately suppressed.
            raise type(e)(str(e) + " " + extra_msg) from None
        return name

    @property
    def stopper(self):
        """The Stopper built from ``stop``/``time_budget_s``, or ``None``."""
        return self._stopper

    @property
    def local_path(self) -> Optional[str]:
        """The storage context's ``experiment_driver_staging_path``."""
        return self.storage.experiment_driver_staging_path

    @property
    @Deprecated("Replaced by `local_path`")
    def local_dir(self):
        # TODO(justinvyu): [Deprecated] Remove in 2.11.
        raise DeprecationWarning("Use `local_path` instead of `local_dir`.")

    @property
    def remote_path(self) -> Optional[str]:
        """The storage context's ``experiment_fs_path``."""
        return self.storage.experiment_fs_path

    @property
    def path(self) -> Optional[str]:
        """``remote_path`` if it is truthy, otherwise ``local_path``."""
        return self.remote_path or self.local_path

    @property
    def checkpoint_config(self):
        """The ``checkpoint_config`` entry of the spec."""
        return self.spec.get("checkpoint_config")

    @property
    @Deprecated("Replaced by `local_path`")
    def checkpoint_dir(self):
        # TODO(justinvyu): [Deprecated] Remove in 2.11.
        raise DeprecationWarning("Use `local_path` instead of `checkpoint_dir`.")

    @property
    def run_identifier(self):
        """Returns a string representing the trainable identifier."""
        return self._run_identifier

    @property
    def public_spec(self) -> Dict[str, Any]:
        """Returns the spec dict with only the public-facing keys.

        Intended to be used for passing information to callbacks,
        Searchers and Schedulers.
        """
        return {k: v for k, v in self.spec.items() if k in self.PUBLIC_KEYS}


def _convert_to_experiment_list(
    experiments: Union[None, Experiment, List[Experiment], Dict],
):
    """Produces a list of Experiment objects.

    Converts input from dict, single experiment, or list of
    experiments to list of experiments. If input is None,
    will return an empty list.

    Arguments:
        experiments: Experiments to run. A dict is interpreted as a mapping
            from experiment name to JSON spec and converted with
            ``Experiment.from_json``.

    Returns:
        List of experiments.

    Raises:
        TuneError: If the input is not (convertible to) a list that contains
            only ``Experiment`` objects.
    """
    exp_list = experiments

    # Transform list if necessary
    if experiments is None:
        exp_list = []
    elif isinstance(experiments, Experiment):
        exp_list = [experiments]
    elif isinstance(experiments, dict):
        exp_list = [
            Experiment.from_json(name, spec) for name, spec in experiments.items()
        ]

    # Validate exp_list
    if isinstance(exp_list, list) and all(
        isinstance(exp, Experiment) for exp in exp_list
    ):
        if len(exp_list) > 1:
            logger.info(
                "Running with multiple concurrent experiments. "
                "All experiments will be using the same SearchAlgorithm."
            )
    else:
        raise TuneError("Invalid argument: {}".format(experiments))

    return exp_list