tune.py 46 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161
  1. import abc
  2. import copy
  3. import datetime
  4. import logging
  5. import os
  6. import signal
  7. import sys
  8. import threading
  9. import time
  10. import warnings
  11. from typing import (
  12. TYPE_CHECKING,
  13. Any,
  14. Callable,
  15. Dict,
  16. Mapping,
  17. Optional,
  18. Sequence,
  19. Type,
  20. Union,
  21. )
  22. import ray
  23. from ray.air._internal import usage as air_usage
  24. from ray.air._internal.usage import AirEntrypoint
  25. from ray.air.util.node import _force_on_current_node
  26. from ray.train.constants import _DEPRECATED_VALUE, RAY_CHDIR_TO_TRIAL_DIR
  27. from ray.tune import CheckpointConfig, SyncConfig
  28. from ray.tune.analysis import ExperimentAnalysis
  29. from ray.tune.callback import Callback
  30. from ray.tune.error import TuneError
  31. from ray.tune.execution.placement_groups import PlacementGroupFactory
  32. from ray.tune.execution.tune_controller import TuneController
  33. from ray.tune.experiment import Experiment, Trial, _convert_to_experiment_list
  34. from ray.tune.experimental.output import IS_NOTEBOOK, AirVerbosity, get_air_verbosity
  35. from ray.tune.impl.placeholder import create_resolvers_map, inject_placeholders
  36. from ray.tune.logger import TBXLoggerCallback
  37. from ray.tune.progress_reporter import (
  38. ProgressReporter,
  39. _detect_progress_metrics,
  40. _detect_reporter,
  41. _prepare_progress_reporter_for_ray_client,
  42. _stream_client_output,
  43. )
  44. from ray.tune.registry import get_trainable_cls
  45. # Must come last to avoid circular imports
  46. from ray.tune.schedulers import (
  47. FIFOScheduler,
  48. PopulationBasedTraining,
  49. PopulationBasedTrainingReplay,
  50. TrialScheduler,
  51. )
  52. from ray.tune.schedulers.util import (
  53. _set_search_properties_backwards_compatible as scheduler_set_search_props,
  54. )
  55. from ray.tune.search import (
  56. BasicVariantGenerator,
  57. ConcurrencyLimiter,
  58. SearchAlgorithm,
  59. Searcher,
  60. SearchGenerator,
  61. create_searcher,
  62. )
  63. from ray.tune.search.util import (
  64. _set_search_properties_backwards_compatible as searcher_set_search_props,
  65. )
  66. from ray.tune.search.variant_generator import _has_unresolved_values
  67. from ray.tune.stopper import Stopper
  68. from ray.tune.trainable import Trainable
  69. from ray.tune.tune_config import ResumeConfig
  70. from ray.tune.utils.callback import _create_default_callbacks
  71. from ray.tune.utils.log import Verbosity, has_verbosity, set_verbosity
  72. from ray.util.annotations import PublicAPI
  73. from ray.util.queue import Queue
  74. if TYPE_CHECKING:
  75. import pyarrow.fs
  76. from ray.tune.experimental.output import ProgressReporter as AirProgressReporter
  77. logger = logging.getLogger(__name__)
  78. def _get_trainable(
  79. run_identifier: Union[Experiment, str, Type, Callable]
  80. ) -> Optional[Type[Trainable]]:
  81. if isinstance(run_identifier, Experiment):
  82. run_identifier = run_identifier.run_identifier
  83. if isinstance(run_identifier, type):
  84. if not issubclass(run_identifier, Trainable):
  85. # If obscure dtype, assume it is overridden.
  86. return None
  87. trainable_cls = run_identifier
  88. elif callable(run_identifier):
  89. trainable_cls = run_identifier
  90. elif isinstance(run_identifier, str):
  91. trainable_cls = get_trainable_cls(run_identifier)
  92. else:
  93. return None
  94. return trainable_cls
  95. def _build_resume_config_from_legacy_config(
  96. resume: Union[str, bool]
  97. ) -> Optional[ResumeConfig]:
  98. """Converts the legacy resume (str, bool) to a ResumeConfig object.
  99. Returns None if resume is False.
  100. """
  101. if resume is False:
  102. return None
  103. if resume is True:
  104. return ResumeConfig()
  105. # Parse resume string, e.g. AUTO+ERRORED
  106. resume_settings = resume.split("+")
  107. resume_str = resume_settings[0]
  108. if resume_str in ("LOCAL", "REMOTE", "PROMPT", "ERRORED_ONLY"):
  109. raise DeprecationWarning(
  110. f"'{resume_str}' is deprecated. "
  111. "Please pass in one of (True, False, 'AUTO')."
  112. )
  113. resume_config = ResumeConfig()
  114. for setting in resume_settings[1:]:
  115. if setting == "ERRORED":
  116. resume_config = ResumeConfig(errored=ResumeConfig.ResumeType.RESUME)
  117. elif setting == "RESTART_ERRORED":
  118. resume_config = ResumeConfig(errored=ResumeConfig.ResumeType.RESTART)
  119. elif setting == "ERRORED_ONLY":
  120. resume_config = ResumeConfig(
  121. unfinished=ResumeConfig.ResumeType.SKIP,
  122. errored=ResumeConfig.ResumeType.RESUME,
  123. )
  124. elif setting == "RESTART_ERRORED_ONLY":
  125. resume_config = ResumeConfig(
  126. unfinished=ResumeConfig.ResumeType.SKIP,
  127. errored=ResumeConfig.ResumeType.RESTART,
  128. )
  129. else:
  130. raise ValueError(f"Invalid resume setting: '{setting}'")
  131. return resume_config
  132. def _check_default_resources_override(
  133. run_identifier: Union[Experiment, str, Type, Callable]
  134. ) -> bool:
  135. trainable_cls = _get_trainable(run_identifier)
  136. if not trainable_cls:
  137. # If no trainable, assume override
  138. return True
  139. return hasattr(trainable_cls, "default_resource_request") and (
  140. trainable_cls.default_resource_request.__code__
  141. != Trainable.default_resource_request.__code__
  142. )
  143. def _check_mixin(run_identifier: Union[Experiment, str, Type, Callable]) -> bool:
  144. trainable_cls = _get_trainable(run_identifier)
  145. if not trainable_cls:
  146. # Default to True
  147. return True
  148. return hasattr(trainable_cls, "__mixins__") or getattr(
  149. trainable_cls, "_is_mixin", False
  150. )
  151. def _check_gpus_in_resources(
  152. resources: Optional[Union[Dict, PlacementGroupFactory]]
  153. ) -> bool:
  154. if not resources:
  155. return False
  156. if isinstance(resources, PlacementGroupFactory):
  157. return bool(resources.required_resources.get("GPU", None))
  158. if isinstance(resources, dict):
  159. return bool(resources.get("gpu", None))
  160. def _report_progress(
  161. runner: TuneController, reporter: ProgressReporter, done: bool = False
  162. ):
  163. """Reports experiment progress.
  164. Args:
  165. runner: Trial runner to report on.
  166. reporter: Progress reporter.
  167. done: Whether this is the last progress report attempt.
  168. """
  169. trials = runner.get_trials()
  170. if reporter.should_report(trials, done=done):
  171. sched_debug_str = runner.scheduler_alg.debug_string()
  172. used_resources_str = runner._used_resources_string()
  173. reporter.report(trials, done, sched_debug_str, used_resources_str)
  174. def _report_air_progress(
  175. runner: TuneController, reporter: "AirProgressReporter", force: bool = False
  176. ):
  177. trials = runner.get_trials()
  178. reporter_args = []
  179. used_resources_string = runner._used_resources_string()
  180. reporter_args.append(used_resources_string)
  181. reporter.print_heartbeat(trials, *reporter_args, force=force)
  182. def _setup_signal_catching() -> threading.Event:
  183. original_handler = signal.getsignal(signal.SIGINT)
  184. experiment_interrupted_event = threading.Event()
  185. def signal_interrupt_tune_run(sig: int, frame):
  186. logger.warning(
  187. "Stop signal received (e.g. via SIGINT/Ctrl+C), ending Ray Tune run. "
  188. "This will try to checkpoint the experiment state one last time. "
  189. "Press CTRL+C (or send SIGINT/SIGKILL/SIGTERM) "
  190. "to skip. "
  191. )
  192. experiment_interrupted_event.set()
  193. # Restore original signal handler to react to future SIGINT signals.
  194. signal.signal(signal.SIGINT, original_handler)
  195. # We should only install the handler when it is safe to do so.
  196. # When tune.run() is called from worker thread, signal.signal will
  197. # fail.
  198. allow_signal_catching = True
  199. if threading.current_thread() != threading.main_thread():
  200. allow_signal_catching = False
  201. if allow_signal_catching:
  202. if not int(os.getenv("TUNE_DISABLE_SIGINT_HANDLER", "0")):
  203. signal.signal(signal.SIGINT, signal_interrupt_tune_run)
  204. # Always register SIGUSR1 if available (not available e.g. on Windows)
  205. if hasattr(signal, "SIGUSR1"):
  206. signal.signal(signal.SIGUSR1, signal_interrupt_tune_run)
  207. return experiment_interrupted_event
  208. def _ray_auto_init(entrypoint: str):
  209. """Initialize Ray unless already configured."""
  210. if os.environ.get("TUNE_DISABLE_AUTO_INIT") == "1":
  211. logger.info("'TUNE_DISABLE_AUTO_INIT=1' detected.")
  212. elif not ray.is_initialized():
  213. ray.init()
  214. logger.info(
  215. "Initializing Ray automatically. "
  216. "For cluster usage or custom Ray initialization, "
  217. f"call `ray.init(...)` before `{entrypoint}`."
  218. )
  219. class _Config(abc.ABC):
  220. def to_dict(self) -> dict:
  221. """Converts this configuration to a dict format."""
  222. raise NotImplementedError
  223. @PublicAPI
  224. def run(
  225. run_or_experiment: Union[str, Callable, Type],
  226. *,
  227. name: Optional[str] = None,
  228. metric: Optional[str] = None,
  229. mode: Optional[str] = None,
  230. stop: Optional[Union[Mapping, Stopper, Callable[[str, Mapping], bool]]] = None,
  231. time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None,
  232. config: Optional[Dict[str, Any]] = None,
  233. resources_per_trial: Union[
  234. None, Mapping[str, Union[float, int, Mapping]], PlacementGroupFactory
  235. ] = None,
  236. num_samples: int = 1,
  237. storage_path: Optional[str] = None,
  238. storage_filesystem: Optional["pyarrow.fs.FileSystem"] = None,
  239. search_alg: Optional[Union[Searcher, SearchAlgorithm, str]] = None,
  240. scheduler: Optional[Union[TrialScheduler, str]] = None,
  241. checkpoint_config: Optional[CheckpointConfig] = None,
  242. verbose: Optional[Union[int, AirVerbosity, Verbosity]] = None,
  243. progress_reporter: Optional[ProgressReporter] = None,
  244. log_to_file: bool = False,
  245. trial_name_creator: Optional[Callable[[Trial], str]] = None,
  246. trial_dirname_creator: Optional[Callable[[Trial], str]] = None,
  247. sync_config: Optional[SyncConfig] = None,
  248. export_formats: Optional[Sequence] = None,
  249. max_failures: int = 0,
  250. fail_fast: bool = False,
  251. restore: Optional[str] = None,
  252. resume: Optional[Union[bool, str]] = None,
  253. resume_config: Optional[ResumeConfig] = None,
  254. reuse_actors: bool = False,
  255. raise_on_failed_trial: bool = True,
  256. callbacks: Optional[Sequence[Callback]] = None,
  257. max_concurrent_trials: Optional[int] = None,
  258. # Deprecated
  259. keep_checkpoints_num: Optional[int] = None, # Deprecated (2.7)
  260. checkpoint_score_attr: Optional[str] = None, # Deprecated (2.7)
  261. checkpoint_freq: int = 0, # Deprecated (2.7)
  262. checkpoint_at_end: bool = False, # Deprecated (2.7)
  263. chdir_to_trial_dir: bool = _DEPRECATED_VALUE, # Deprecated (2.8)
  264. local_dir: Optional[str] = None,
  265. # == internal only ==
  266. _remote: Optional[bool] = None,
  267. # Passed by the Tuner.
  268. _remote_string_queue: Optional[Queue] = None,
  269. # Todo (krfricke): Find a better way to pass entrypoint information, e.g.
  270. # a context object or similar.
  271. _entrypoint: AirEntrypoint = AirEntrypoint.TUNE_RUN,
  272. ) -> ExperimentAnalysis:
  273. """Executes training.
  274. When a SIGINT signal is received (e.g. through Ctrl+C), the tuning run
  275. will gracefully shut down and checkpoint the latest experiment state.
  276. Sending SIGINT again (or SIGKILL/SIGTERM instead) will skip this step.
  277. Many aspects of Tune, such as the frequency of global checkpointing,
  278. maximum pending placement group trials and the path of the result
  279. directory be configured through environment variables. Refer to
  280. :ref:`tune-env-vars` for a list of environment variables available.
  281. Examples:
  282. .. code-block:: python
  283. # Run 10 trials (each trial is one instance of a Trainable). Tune runs
  284. # in parallel and automatically determines concurrency.
  285. tune.run(trainable, num_samples=10)
  286. # Run 1 trial, stop when trial has reached 10 iterations
  287. tune.run(my_trainable, stop={"training_iteration": 10})
  288. # automatically retry failed trials up to 3 times
  289. tune.run(my_trainable, stop={"training_iteration": 10}, max_failures=3)
  290. # Run 1 trial, search over hyperparameters, stop after 10 iterations.
  291. space = {"lr": tune.uniform(0, 1), "momentum": tune.uniform(0, 1)}
  292. tune.run(my_trainable, config=space, stop={"training_iteration": 10})
  293. # Resumes training if a previous machine crashed
  294. tune.run(
  295. my_trainable, config=space,
  296. storage_path=<path/to/dir>, name=<exp_name>, resume=True
  297. )
  298. Args:
  299. run_or_experiment: If function|class|str, this is the algorithm or
  300. model to train. This may refer to the name of a built-on algorithm
  301. (e.g. RLlib's DQN or PPO), a user-defined trainable
  302. function or class, or the string identifier of a
  303. trainable function or class registered in the tune registry.
  304. If Experiment, then Tune will execute training based on
  305. Experiment.spec. If you want to pass in a Python lambda, you
  306. will need to first register the function:
  307. ``tune.register_trainable("lambda_id", lambda x: ...)``. You can
  308. then use ``tune.run("lambda_id")``.
  309. metric: Metric to optimize. This metric should be reported
  310. with `tune.report()`. If set, will be passed to the search
  311. algorithm and scheduler.
  312. mode: Must be one of [min, max]. Determines whether objective is
  313. minimizing or maximizing the metric attribute. If set, will be
  314. passed to the search algorithm and scheduler.
  315. name: Name of experiment.
  316. stop: Stopping criteria. If dict,
  317. the keys may be any field in the return result of 'train()',
  318. whichever is reached first. If function, it must take (trial_id,
  319. result) as arguments and return a boolean (True if trial should be
  320. stopped, False otherwise). This can also be a subclass of
  321. ``ray.tune.Stopper``, which allows users to implement
  322. custom experiment-wide stopping (i.e., stopping an entire Tune
  323. run based on some time constraint).
  324. time_budget_s: Global time budget in
  325. seconds after which all trials are stopped. Can also be a
  326. ``datetime.timedelta`` object.
  327. config: Algorithm-specific configuration for Tune variant
  328. generation (e.g. env, hyperparams). Defaults to empty dict.
  329. Custom search algorithms may ignore this.
  330. resources_per_trial: Machine resources
  331. to allocate per trial, e.g. ``{"cpu": 64, "gpu": 8}``.
  332. Note that GPUs will not be assigned unless you specify them here.
  333. Defaults to 1 CPU and 0 GPUs in
  334. ``Trainable.default_resource_request()``. This can also
  335. be a PlacementGroupFactory object wrapping arguments to create a
  336. per-trial placement group.
  337. num_samples: Number of times to sample from the
  338. hyperparameter space. Defaults to 1. If `grid_search` is
  339. provided as an argument, the grid will be repeated
  340. `num_samples` of times. If this is -1, (virtually) infinite
  341. samples are generated until a stopping condition is met.
  342. storage_path: Path to store results at. Can be a local directory or
  343. a destination on cloud storage. Defaults to
  344. the local ``~/ray_results`` directory.
  345. search_alg: Search algorithm for
  346. optimization. You can also use the name of the algorithm.
  347. scheduler: Scheduler for executing
  348. the experiment. Choose among FIFO (default), MedianStopping,
  349. AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
  350. ray.tune.schedulers for more options. You can also use the
  351. name of the scheduler.
  352. verbose: 0, 1, or 2. Verbosity mode.
  353. 0 = silent, 1 = default, 2 = verbose. Defaults to 1.
  354. If the ``RAY_AIR_NEW_OUTPUT=1`` environment variable is set,
  355. uses the old verbosity settings:
  356. 0 = silent, 1 = only status updates, 2 = status and brief
  357. results, 3 = status and detailed results.
  358. progress_reporter: Progress reporter for reporting
  359. intermediate experiment progress. Defaults to CLIReporter if
  360. running in command-line, or JupyterNotebookReporter if running in
  361. a Jupyter notebook.
  362. log_to_file: Log stdout and stderr to files in
  363. Tune's trial directories. If this is `False` (default), no files
  364. are written. If `true`, outputs are written to `trialdir/stdout`
  365. and `trialdir/stderr`, respectively. If this is a single string,
  366. this is interpreted as a file relative to the trialdir, to which
  367. both streams are written. If this is a Sequence (e.g. a Tuple),
  368. it has to have length 2 and the elements indicate the files to
  369. which stdout and stderr are written, respectively.
  370. trial_name_creator: Optional function that takes in a Trial and returns
  371. its name (i.e. its string representation). Be sure to include some unique
  372. identifier (such as `Trial.trial_id`) in each trial's name.
  373. trial_dirname_creator: Optional function that takes in a trial and
  374. generates its trial directory name as a string. Be sure to include some
  375. unique identifier (such as `Trial.trial_id`) is used in each trial's
  376. directory name. Otherwise, trials could overwrite artifacts and checkpoints
  377. of other trials. The return value cannot be a path.
  378. chdir_to_trial_dir: Deprecated. Set the `RAY_CHDIR_TO_TRIAL_DIR` env var instead
  379. sync_config: Configuration object for syncing. See tune.SyncConfig.
  380. export_formats: List of formats that exported at the end of
  381. the experiment. Default is None.
  382. max_failures: Try to recover a trial at least this many times.
  383. Ray will recover from the latest checkpoint if present.
  384. Setting to -1 will lead to infinite recovery retries.
  385. Setting to 0 will disable retries. Defaults to 0.
  386. fail_fast: Whether to fail upon the first error.
  387. If fail_fast='raise' provided, Tune will automatically
  388. raise the exception received by the Trainable. fail_fast='raise'
  389. can easily leak resources and should be used with caution (it
  390. is best used with `ray.init(local_mode=True)`).
  391. restore: Path to checkpoint. Only makes sense to set if
  392. running 1 trial. Defaults to None.
  393. resume: One of [True, False, "AUTO"]. Can
  394. be suffixed with one or more of ["+ERRORED", "+ERRORED_ONLY",
  395. "+RESTART_ERRORED", "+RESTART_ERRORED_ONLY"] (e.g. ``AUTO+ERRORED``).
  396. `resume=True` and `resume="AUTO"` will attempt to resume from a
  397. checkpoint and otherwise start a new experiment.
  398. The suffix "+ERRORED" resets and reruns errored trials upon resume -
  399. previous trial artifacts will be left untouched. It will try to continue
  400. from the last observed checkpoint.
  401. The suffix "+RESTART_ERRORED" will instead start the errored trials from
  402. scratch. "+ERRORED_ONLY" and "+RESTART_ERRORED_ONLY" will disable
  403. resuming non-errored trials - they will be added as finished instead. New
  404. trials can still be generated by the search algorithm.
  405. resume_config: [Experimental] Config object that controls how to resume
  406. trials of different statuses. Can be used as a substitute to the
  407. `resume` suffixes described above.
  408. reuse_actors: Whether to reuse actors between different trials
  409. when possible. This can drastically speed up experiments that start
  410. and stop actors often (e.g., PBT in time-multiplexing mode). This
  411. requires trials to have the same resource requirements.
  412. Defaults to ``False``.
  413. raise_on_failed_trial: Raise TuneError if there exists failed
  414. trial (of ERROR state) when the experiments complete.
  415. callbacks: List of callbacks that will be called at different
  416. times in the training loop. Must be instances of the
  417. ``ray.tune.callback.Callback`` class. If not passed,
  418. `LoggerCallback` (json/csv/tensorboard) callbacks are automatically added.
  419. max_concurrent_trials: Maximum number of trials to run
  420. concurrently. Must be non-negative. If None or 0, no limit will
  421. be applied. This is achieved by wrapping the ``search_alg`` in
  422. a :class:`ConcurrencyLimiter`, and thus setting this argument
  423. will raise an exception if the ``search_alg`` is already a
  424. :class:`ConcurrencyLimiter`. Defaults to None.
  425. _remote: Whether to run the Tune driver in a remote function.
  426. This is disabled automatically if a custom trial executor is
  427. passed in. This is enabled by default in Ray client mode.
  428. local_dir: Deprecated. Use `storage_path` instead.
  429. keep_checkpoints_num: Deprecated. use checkpoint_config instead.
  430. checkpoint_score_attr: Deprecated. use checkpoint_config instead.
  431. checkpoint_freq: Deprecated. use checkpoint_config instead.
  432. checkpoint_at_end: Deprecated. use checkpoint_config instead.
  433. checkpoint_keep_all_ranks: Deprecated. use checkpoint_config instead.
  434. checkpoint_upload_from_workers: Deprecated. use checkpoint_config instead.
  435. Returns:
  436. ExperimentAnalysis: Object for experiment analysis.
  437. Raises:
  438. TuneError: Any trials failed and `raise_on_failed_trial` is True.
  439. """
  440. # NO CODE IS TO BE ADDED ABOVE THIS COMMENT
  441. # remote_run_kwargs must be defined before any other
  442. # code is ran to ensure that at this point,
  443. # `locals()` is equal to args and kwargs
  444. remote_run_kwargs = locals().copy()
  445. remote_run_kwargs.pop("_remote")
  446. if _entrypoint == AirEntrypoint.TRAINER:
  447. error_message_map = {
  448. "entrypoint": "<FrameworkTrainer>(...)",
  449. "search_space_arg": "param_space",
  450. "restore_entrypoint": '<FrameworkTrainer>.restore(path="{path}", ...)',
  451. }
  452. elif _entrypoint == AirEntrypoint.TUNER:
  453. error_message_map = {
  454. "entrypoint": "Tuner(...)",
  455. "search_space_arg": "param_space",
  456. "restore_entrypoint": 'Tuner.restore(path="{path}", trainable=...)',
  457. }
  458. elif _entrypoint == AirEntrypoint.TUNE_RUN_EXPERIMENTS:
  459. error_message_map = {
  460. "entrypoint": "tune.run_experiments(...)",
  461. "search_space_arg": "experiment=Experiment(config)",
  462. "restore_entrypoint": "tune.run_experiments(..., resume=True)",
  463. }
  464. else:
  465. error_message_map = {
  466. "entrypoint": "tune.run(...)",
  467. "search_space_arg": "config",
  468. "restore_entrypoint": "tune.run(..., resume=True)",
  469. }
  470. _ray_auto_init(entrypoint=error_message_map["entrypoint"])
  471. if _remote is None:
  472. _remote = ray.util.client.ray.is_connected()
  473. if verbose is None:
  474. # Default `verbose` value. For new output engine, this is AirVerbosity.VERBOSE.
  475. # For old output engine, this is Verbosity.V3_TRIAL_DETAILS
  476. verbose = get_air_verbosity(AirVerbosity.VERBOSE) or Verbosity.V3_TRIAL_DETAILS
  477. if _remote:
  478. if get_air_verbosity(verbose) is not None:
  479. logger.info(
  480. "[output] This uses the legacy output and progress reporter, "
  481. "as Ray client is not supported by the new engine. "
  482. "For more information, see "
  483. "https://github.com/ray-project/ray/issues/36949"
  484. )
  485. remote_run = ray.remote(num_cpus=0)(run)
  486. # Make sure tune.run is called on the sever node.
  487. remote_run = _force_on_current_node(remote_run)
  488. progress_reporter, string_queue = _prepare_progress_reporter_for_ray_client(
  489. progress_reporter, verbose, _remote_string_queue
  490. )
  491. # Override with detected progress reporter
  492. remote_run_kwargs["progress_reporter"] = progress_reporter
  493. remote_future = remote_run.remote(_remote=False, **remote_run_kwargs)
  494. _stream_client_output(
  495. remote_future,
  496. progress_reporter,
  497. string_queue,
  498. )
  499. return ray.get(remote_future)
  500. del remote_run_kwargs
  501. # TODO(justinvyu): [Deprecated] Remove in 2.30
  502. ENV_VAR_DEPRECATION_MESSAGE = (
  503. "The environment variable `{}` is deprecated. "
  504. "It is no longer used and will not have any effect. "
  505. "You should set the `storage_path` instead. Files will no longer be "
  506. "written to `~/ray_results` as long as `storage_path` is set."
  507. "See the docs: https://docs.ray.io/en/latest/train/user-guides/"
  508. "persistent-storage.html#setting-the-local-staging-directory"
  509. )
  510. if os.environ.get("TUNE_RESULT_DIR"):
  511. raise DeprecationWarning(ENV_VAR_DEPRECATION_MESSAGE.format("TUNE_RESULT_DIR"))
  512. if os.environ.get("RAY_AIR_LOCAL_CACHE_DIR"):
  513. raise DeprecationWarning(
  514. ENV_VAR_DEPRECATION_MESSAGE.format("RAY_AIR_LOCAL_CACHE_DIR")
  515. )
  516. if local_dir is not None:
  517. raise DeprecationWarning(
  518. "The `local_dir` argument is deprecated. "
  519. "You should set the `storage_path` instead. "
  520. "See the docs: https://docs.ray.io/en/latest/train/user-guides/"
  521. "persistent-storage.html#setting-the-local-staging-directory"
  522. )
  523. ray._common.usage.usage_lib.record_library_usage("tune")
  524. # Tracking environment variable usage here will also catch:
  525. # 1.) Tuner.fit() usage
  526. # 2.) Trainer.fit() usage
  527. # 3.) Ray client usage (env variables are inherited by the Ray runtime env)
  528. air_usage.tag_ray_air_env_vars()
  529. # Track the entrypoint to AIR:
  530. # Tuner.fit / Trainer.fit / tune.run / tune.run_experiments
  531. air_usage.tag_air_entrypoint(_entrypoint)
  532. all_start = time.time()
  533. if mode and mode not in ["min", "max"]:
  534. raise ValueError(
  535. f"The `mode` parameter passed to `{error_message_map['entrypoint']}` "
  536. "must be one of ['min', 'max']"
  537. )
  538. air_verbosity = get_air_verbosity(verbose)
  539. if air_verbosity is not None and IS_NOTEBOOK:
  540. logger.info(
  541. "[output] This uses the legacy output and progress reporter, "
  542. "as Jupyter notebooks are not supported by the new engine, yet. "
  543. "For more information, please see "
  544. "https://github.com/ray-project/ray/issues/36949"
  545. )
  546. air_verbosity = None
  547. if air_verbosity is not None:
  548. # Disable old output engine
  549. set_verbosity(0)
  550. else:
  551. # Use old output engine
  552. set_verbosity(verbose)
  553. config = config or {}
  554. if isinstance(config, _Config):
  555. config = config.to_dict()
  556. if not isinstance(config, dict):
  557. raise ValueError(
  558. f"The `{error_message_map['search_space_arg']}` passed to "
  559. f"`{error_message_map['entrypoint']}` must be a dict. "
  560. f"Got '{type(config)}' instead."
  561. )
  562. sync_config = sync_config or SyncConfig()
  563. checkpoint_config = checkpoint_config or CheckpointConfig()
  564. # For backward compatibility
  565. # TODO(jungong): remove after 2.7 release.
  566. if keep_checkpoints_num is not None:
  567. warnings.warn(
  568. "keep_checkpoints_num is deprecated and will be removed. "
  569. "use checkpoint_config.num_to_keep instead.",
  570. DeprecationWarning,
  571. )
  572. checkpoint_config.num_to_keep = keep_checkpoints_num
  573. if checkpoint_score_attr is not None:
  574. warnings.warn(
  575. "checkpoint_score_attr is deprecated and will be removed. "
  576. "use checkpoint_config.checkpoint_score_attribute instead.",
  577. DeprecationWarning,
  578. )
  579. if checkpoint_score_attr.startswith("min-"):
  580. warnings.warn(
  581. "using min- and max- prefixes to specify checkpoint score "
  582. "order is deprecated. Use CheckpointConfig.checkpoint_score_order "
  583. "instead",
  584. DeprecationWarning,
  585. )
  586. checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr[4:]
  587. checkpoint_config.checkpoint_score_order = "min"
  588. else:
  589. checkpoint_config.checkpoint_score_attribute = checkpoint_score_attr
  590. checkpoint_config.checkpoint_score_order = "max"
  591. checkpoint_config.score_attr = checkpoint_score_attr
  592. if checkpoint_freq > 0:
  593. warnings.warn(
  594. "checkpoint_freq is deprecated and will be removed. "
  595. "use checkpoint_config.checkpoint_frequency instead.",
  596. DeprecationWarning,
  597. )
  598. checkpoint_config.checkpoint_frequency = checkpoint_freq
  599. if checkpoint_at_end:
  600. warnings.warn(
  601. "checkpoint_at_end is deprecated and will be removed. "
  602. "use checkpoint_config.checkpoint_at_end instead.",
  603. DeprecationWarning,
  604. )
  605. checkpoint_config.checkpoint_at_end = checkpoint_at_end
  606. # TODO(justinvyu): [Deprecated] Remove in 2.11.
  607. if chdir_to_trial_dir != _DEPRECATED_VALUE:
  608. raise DeprecationWarning(
  609. "`chdir_to_trial_dir` is deprecated. "
  610. f"Use the {RAY_CHDIR_TO_TRIAL_DIR} environment variable instead. "
  611. "Set it to 0 to disable the default behavior of changing the "
  612. "working directory.",
  613. DeprecationWarning,
  614. )
  615. if num_samples == -1:
  616. num_samples = sys.maxsize
  617. # Create scheduler here as we need access to some of its properties
  618. if isinstance(scheduler, str):
  619. # importing at top level causes a recursive dependency
  620. from ray.tune.schedulers import create_scheduler
  621. scheduler = create_scheduler(scheduler)
  622. scheduler = scheduler or FIFOScheduler()
  623. if not scheduler.supports_buffered_results:
  624. # Result buffering with e.g. a Hyperband scheduler is a bad idea, as
  625. # hyperband tries to stop trials when processing brackets. With result
  626. # buffering, we might trigger this multiple times when evaluating
  627. # a single trial, which leads to unexpected behavior.
  628. env_result_buffer_length = os.getenv("TUNE_RESULT_BUFFER_LENGTH", "")
  629. if env_result_buffer_length:
  630. warnings.warn(
  631. f"You are using a {type(scheduler)} scheduler, but "
  632. f"TUNE_RESULT_BUFFER_LENGTH is set "
  633. f"({env_result_buffer_length}). This can lead to undesired "
  634. f"and faulty behavior, so the buffer length was forcibly set "
  635. f"to 1 instead."
  636. )
  637. os.environ["TUNE_RESULT_BUFFER_LENGTH"] = "1"
  638. if (
  639. isinstance(scheduler, (PopulationBasedTraining, PopulationBasedTrainingReplay))
  640. and not reuse_actors
  641. ):
  642. warnings.warn(
  643. "Consider boosting PBT performance by enabling `reuse_actors` as "
  644. "well as implementing `reset_config` for Trainable."
  645. )
  646. # Before experiments are created, we first clean up the passed in
  647. # Config dictionary by replacing all the non-primitive config values
  648. # with placeholders. This serves two purposes:
  649. # 1. we can replace and "fix" these objects if a Trial is restored.
  650. # 2. the config dictionary will then be compatible with all supported
  651. # search algorithms, since a lot of them do not support non-primitive
  652. # config values.
  653. placeholder_resolvers = create_resolvers_map()
  654. config = inject_placeholders(
  655. # Make a deep copy here to avoid modifying the original config dict.
  656. copy.deepcopy(config),
  657. placeholder_resolvers,
  658. )
  659. # TODO(justinvyu): We should remove the ability to pass a list of
  660. # trainables to tune.run.
  661. if isinstance(run_or_experiment, list):
  662. experiments = run_or_experiment
  663. else:
  664. experiments = [run_or_experiment]
  665. for i, exp in enumerate(experiments):
  666. if not isinstance(exp, Experiment):
  667. experiments[i] = Experiment(
  668. name=name,
  669. run=exp,
  670. stop=stop,
  671. time_budget_s=time_budget_s,
  672. config=config,
  673. resources_per_trial=resources_per_trial,
  674. num_samples=num_samples,
  675. storage_path=storage_path,
  676. storage_filesystem=storage_filesystem,
  677. sync_config=sync_config,
  678. checkpoint_config=checkpoint_config,
  679. trial_name_creator=trial_name_creator,
  680. trial_dirname_creator=trial_dirname_creator,
  681. log_to_file=log_to_file,
  682. export_formats=export_formats,
  683. max_failures=max_failures,
  684. restore=restore,
  685. )
  686. if fail_fast and max_failures != 0:
  687. raise ValueError("max_failures must be 0 if fail_fast=True.")
  688. if isinstance(search_alg, str):
  689. search_alg = create_searcher(search_alg)
  690. # if local_mode=True is set during ray.init().
  691. is_local_mode = ray._private.worker._mode() == ray._private.worker.LOCAL_MODE
  692. if is_local_mode:
  693. max_concurrent_trials = 1
  694. if not search_alg:
  695. search_alg = BasicVariantGenerator(max_concurrent=max_concurrent_trials or 0)
  696. elif max_concurrent_trials or is_local_mode:
  697. if isinstance(search_alg, ConcurrencyLimiter):
  698. if not is_local_mode:
  699. if search_alg.max_concurrent != max_concurrent_trials:
  700. raise ValueError(
  701. "You have specified `max_concurrent_trials="
  702. f"{max_concurrent_trials}`, but the `search_alg` is "
  703. "already a `ConcurrencyLimiter` with `max_concurrent="
  704. f"{search_alg.max_concurrent}. FIX THIS by setting "
  705. "`max_concurrent_trials=None`."
  706. )
  707. else:
  708. logger.warning(
  709. "You have specified `max_concurrent_trials="
  710. f"{max_concurrent_trials}`, but the `search_alg` is "
  711. "already a `ConcurrencyLimiter`. "
  712. "`max_concurrent_trials` will be ignored."
  713. )
  714. else:
  715. if max_concurrent_trials < 1:
  716. raise ValueError(
  717. "`max_concurrent_trials` must be greater or equal than 1, "
  718. f"got {max_concurrent_trials}."
  719. )
  720. if isinstance(search_alg, Searcher):
  721. search_alg = ConcurrencyLimiter(
  722. search_alg, max_concurrent=max_concurrent_trials
  723. )
  724. elif not is_local_mode:
  725. logger.warning(
  726. "You have passed a `SearchGenerator` instance as the "
  727. "`search_alg`, but `max_concurrent_trials` requires a "
  728. "`Searcher` instance`. `max_concurrent_trials` "
  729. "will be ignored."
  730. )
  731. if isinstance(search_alg, Searcher):
  732. search_alg = SearchGenerator(search_alg)
  733. if config and not searcher_set_search_props(
  734. search_alg.set_search_properties,
  735. metric,
  736. mode,
  737. config,
  738. **experiments[0].public_spec,
  739. ):
  740. if _has_unresolved_values(config):
  741. raise ValueError(
  742. f"You passed a `{error_message_map['search_space_arg']}` parameter to "
  743. f"`{error_message_map['entrypoint']}` with "
  744. "unresolved parameters, but the search algorithm was already "
  745. "instantiated with a search space. Make sure that `config` "
  746. "does not contain any more parameter definitions - include "
  747. "them in the search algorithm's search space if necessary."
  748. )
  749. if not scheduler_set_search_props(
  750. scheduler.set_search_properties, metric, mode, **experiments[0].public_spec
  751. ):
  752. raise ValueError(
  753. "You passed a `metric` or `mode` argument to "
  754. f"`{error_message_map['entrypoint']}`, but "
  755. "the scheduler you are using was already instantiated with their "
  756. "own `metric` and `mode` parameters. Either remove the arguments "
  757. f"from your scheduler or from `{error_message_map['entrypoint']}` args."
  758. )
  759. progress_metrics = _detect_progress_metrics(_get_trainable(run_or_experiment))
  760. air_usage.tag_storage_type(experiments[0].storage)
  761. # NOTE: Report callback telemetry before populating the list with default callbacks.
  762. # This tracks user-specified callback usage.
  763. air_usage.tag_callbacks(callbacks)
  764. # Create default logging + syncer callbacks
  765. callbacks = _create_default_callbacks(
  766. callbacks,
  767. air_verbosity=air_verbosity,
  768. entrypoint=_entrypoint,
  769. config=config,
  770. metric=metric,
  771. mode=mode,
  772. progress_metrics=progress_metrics,
  773. )
  774. # User Warning for GPUs
  775. if ray.cluster_resources().get("GPU", 0):
  776. if _check_gpus_in_resources(resources=resources_per_trial):
  777. # "gpu" is manually set.
  778. pass
  779. elif _check_default_resources_override(experiments[0].run_identifier):
  780. # "default_resources" is manually overridden.
  781. pass
  782. else:
  783. logger.warning(
  784. "Tune detects GPUs, but no trials are using GPUs. "
  785. "To enable trials to use GPUs, wrap `train_func` with "
  786. "`tune.with_resources(train_func, resources_per_trial={'gpu': 1})` "
  787. "which allows Tune to expose 1 GPU to each trial. "
  788. "For Ray Train Trainers, you can specify GPU resources "
  789. "through `ScalingConfig(use_gpu=True)`. "
  790. "You can also override "
  791. "`Trainable.default_resource_request` if using the "
  792. "Trainable API."
  793. )
  794. experiment_interrupted_event = _setup_signal_catching()
  795. if progress_reporter and air_verbosity is not None:
  796. logger.warning(
  797. "AIR_VERBOSITY is set, ignoring passed-in ProgressReporter for now."
  798. )
  799. progress_reporter = None
  800. if air_verbosity is None:
  801. is_trainer = _entrypoint == AirEntrypoint.TRAINER
  802. progress_reporter = progress_reporter or _detect_reporter(
  803. _trainer_api=is_trainer
  804. )
  805. if resume is not None:
  806. resume_config = resume_config or _build_resume_config_from_legacy_config(resume)
  807. runner_kwargs = dict(
  808. search_alg=search_alg,
  809. placeholder_resolvers=placeholder_resolvers,
  810. scheduler=scheduler,
  811. stopper=experiments[0].stopper,
  812. resume_config=resume_config,
  813. fail_fast=fail_fast,
  814. callbacks=callbacks,
  815. metric=metric,
  816. trial_checkpoint_config=experiments[0].checkpoint_config,
  817. reuse_actors=reuse_actors,
  818. storage=experiments[0].storage,
  819. _trainer_api=_entrypoint == AirEntrypoint.TRAINER,
  820. )
  821. runner = TuneController(**runner_kwargs)
  822. if not runner.resumed:
  823. for exp in experiments:
  824. search_alg.add_configurations([exp])
  825. # search_alg.total_samples has been updated, so we should
  826. # update the number of pending trials
  827. runner.update_max_pending_trials()
  828. else:
  829. logger.debug(
  830. "You have resumed the Tune run, which means that any newly specified "
  831. "`Experiment`s will be ignored. "
  832. "Tune will just continue what was previously running."
  833. )
  834. if resources_per_trial:
  835. runner.update_pending_trial_resources(resources_per_trial)
  836. # Calls setup on callbacks
  837. runner.setup_experiments(
  838. experiments=experiments, total_num_samples=search_alg.total_samples
  839. )
  840. tune_start = time.time()
  841. air_progress_reporter = None
  842. if air_verbosity is None:
  843. progress_reporter.setup(
  844. start_time=tune_start,
  845. total_samples=search_alg.total_samples,
  846. metric=metric,
  847. mode=mode,
  848. )
  849. else:
  850. from ray.tune.experimental.output import ProgressReporter as AirProgressReporter
  851. for callback in callbacks:
  852. if isinstance(callback, AirProgressReporter):
  853. air_progress_reporter = callback
  854. air_progress_reporter.setup(
  855. start_time=tune_start, total_samples=search_alg.total_samples
  856. )
  857. break
  858. experiment_local_path = runner._storage.experiment_driver_staging_path
  859. experiment_dir_name = runner._storage.experiment_dir_name
  860. if any(isinstance(cb, TBXLoggerCallback) for cb in callbacks):
  861. tensorboard_path = experiment_local_path
  862. else:
  863. tensorboard_path = None
  864. if air_progress_reporter:
  865. air_progress_reporter.experiment_started(
  866. experiment_name=experiment_dir_name,
  867. experiment_path=runner.experiment_path,
  868. searcher_str=search_alg.__class__.__name__,
  869. scheduler_str=scheduler.__class__.__name__,
  870. total_num_samples=search_alg.total_samples,
  871. tensorboard_path=tensorboard_path,
  872. )
  873. try:
  874. while not runner.is_finished() and not experiment_interrupted_event.is_set():
  875. runner.step()
  876. if has_verbosity(Verbosity.V1_EXPERIMENT):
  877. _report_progress(runner, progress_reporter)
  878. if air_verbosity is not None:
  879. _report_air_progress(runner, air_progress_reporter)
  880. except Exception:
  881. runner.cleanup()
  882. raise
  883. tune_taken = time.time() - tune_start
  884. final_sync_start = time.time()
  885. try:
  886. runner.checkpoint(force=True, wait=True)
  887. logger.info(
  888. "Wrote the latest version of all result files and experiment state to "
  889. f"'{runner.experiment_path}' in {time.time() - final_sync_start:.4f}s."
  890. )
  891. except Exception:
  892. logger.error(
  893. "Experiment state snapshotting failed:", exc_info=True, stack_info=True
  894. )
  895. if has_verbosity(Verbosity.V1_EXPERIMENT):
  896. _report_progress(runner, progress_reporter, done=True)
  897. if air_verbosity is not None:
  898. _report_air_progress(runner, air_progress_reporter, force=True)
  899. all_trials = runner.get_trials()
  900. runner.cleanup()
  901. incomplete_trials = []
  902. for trial in all_trials:
  903. if trial.status != Trial.TERMINATED:
  904. incomplete_trials += [trial]
  905. if incomplete_trials:
  906. if raise_on_failed_trial and not experiment_interrupted_event.is_set():
  907. raise TuneError("Trials did not complete", incomplete_trials)
  908. else:
  909. logger.error("Trials did not complete: %s", incomplete_trials)
  910. all_taken = time.time() - all_start
  911. if has_verbosity(Verbosity.V1_EXPERIMENT):
  912. logger.info(
  913. f"Total run time: {all_taken:.2f} seconds "
  914. f"({tune_taken:.2f} seconds for the tuning loop)."
  915. )
  916. if experiment_interrupted_event.is_set():
  917. restore_entrypoint = error_message_map["restore_entrypoint"].format(
  918. path=runner.experiment_path,
  919. )
  920. if _entrypoint == AirEntrypoint.TRAINER:
  921. logger.warning(
  922. f"Training has been interrupted, but the most recent state was saved.\n"
  923. f"Resume training with: {restore_entrypoint}"
  924. )
  925. else:
  926. logger.warning(
  927. f"Experiment has been interrupted, but the most recent state was "
  928. f"saved.\nResume experiment with: {restore_entrypoint}"
  929. )
  930. return ExperimentAnalysis(
  931. experiment_checkpoint_path=runner.experiment_path,
  932. default_metric=metric,
  933. default_mode=mode,
  934. trials=all_trials,
  935. storage_filesystem=experiments[0].storage.storage_filesystem,
  936. )
  937. @PublicAPI
  938. def run_experiments(
  939. experiments: Union[Experiment, Mapping, Sequence[Union[Experiment, Mapping]]],
  940. scheduler: Optional[TrialScheduler] = None,
  941. verbose: Optional[Union[int, AirVerbosity, Verbosity]] = None,
  942. progress_reporter: Optional[ProgressReporter] = None,
  943. resume: Optional[Union[bool, str]] = None,
  944. resume_config: Optional[ResumeConfig] = None,
  945. reuse_actors: bool = False,
  946. raise_on_failed_trial: bool = True,
  947. concurrent: bool = True,
  948. callbacks: Optional[Sequence[Callback]] = None,
  949. _remote: Optional[bool] = None,
  950. ):
  951. """Runs and blocks until all trials finish.
  952. Example:
  953. >>> from ray.tune.experiment import Experiment
  954. >>> from ray.tune.tune import run_experiments
  955. >>> def my_func(config): return {"score": 0}
  956. >>> experiment_spec = Experiment("experiment", my_func) # doctest: +SKIP
  957. >>> run_experiments(experiments=experiment_spec) # doctest: +SKIP
  958. >>> experiment_spec = {"experiment": {"run": my_func}} # doctest: +SKIP
  959. >>> run_experiments(experiments=experiment_spec) # doctest: +SKIP
  960. Returns:
  961. List of Trial objects, holding data for each executed trial.
  962. """
  963. if _remote is None:
  964. _remote = ray.util.client.ray.is_connected()
  965. _ray_auto_init(entrypoint="tune.run_experiments(...)")
  966. if verbose is None:
  967. # Default `verbose` value. For new output engine, this is AirVerbosity.VERBOSE.
  968. # For old output engine, this is Verbosity.V3_TRIAL_DETAILS
  969. verbose = get_air_verbosity(AirVerbosity.VERBOSE) or Verbosity.V3_TRIAL_DETAILS
  970. if _remote:
  971. if get_air_verbosity(verbose) is not None:
  972. logger.info(
  973. "[output] This uses the legacy output and progress reporter, "
  974. "as Ray client is not supported by the new engine. "
  975. "For more information, see "
  976. "https://github.com/ray-project/ray/issues/36949"
  977. )
  978. remote_run = ray.remote(num_cpus=0)(run_experiments)
  979. # Make sure tune.run_experiments is run on the server node.
  980. remote_run = _force_on_current_node(remote_run)
  981. return ray.get(
  982. remote_run.remote(
  983. experiments,
  984. scheduler,
  985. verbose,
  986. progress_reporter,
  987. resume,
  988. resume_config,
  989. reuse_actors,
  990. raise_on_failed_trial,
  991. concurrent,
  992. callbacks,
  993. _remote=False,
  994. )
  995. )
  996. # This is important to do this here
  997. # because it schematize the experiments
  998. # and it conducts the implicit registration.
  999. experiments = _convert_to_experiment_list(experiments)
  1000. tune_run_params = dict(
  1001. verbose=verbose,
  1002. progress_reporter=progress_reporter,
  1003. resume=resume,
  1004. resume_config=resume_config,
  1005. reuse_actors=reuse_actors,
  1006. raise_on_failed_trial=raise_on_failed_trial,
  1007. scheduler=scheduler,
  1008. callbacks=callbacks,
  1009. _entrypoint=AirEntrypoint.TUNE_RUN_EXPERIMENTS,
  1010. )
  1011. if concurrent:
  1012. return run(experiments, **tune_run_params).trials
  1013. else:
  1014. trials = []
  1015. for exp in experiments:
  1016. trials += run(exp, **tune_run_params).trials
  1017. return trials