tuner_internal.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. import copy
  2. import io
  3. import logging
  4. import math
  5. from pathlib import Path
  6. from typing import (
  7. TYPE_CHECKING,
  8. Any,
  9. Callable,
  10. Dict,
  11. List,
  12. Optional,
  13. Tuple,
  14. Type,
  15. Union,
  16. )
  17. import pyarrow.fs
  18. import ray.cloudpickle as pickle
  19. import ray.train
  20. import ray.tune
  21. from ray.air._internal.uri_utils import URI
  22. from ray.air._internal.usage import AirEntrypoint
  23. from ray.train._internal.storage import StorageContext, get_fs_and_path
  24. from ray.train.constants import (
  25. V2_MIGRATION_GUIDE_MESSAGE,
  26. _v2_migration_warnings_enabled,
  27. )
  28. from ray.train.utils import _log_deprecation_warning
  29. from ray.tune import (
  30. Experiment,
  31. ExperimentAnalysis,
  32. ResumeConfig,
  33. RunConfig,
  34. TuneConfig,
  35. TuneError,
  36. )
  37. from ray.tune.registry import is_function_trainable
  38. from ray.tune.result_grid import ResultGrid
  39. from ray.tune.trainable import Trainable
  40. from ray.tune.tune import _Config, run
  41. from ray.tune.utils import flatten_dict
  42. from ray.util import inspect_serializability
  43. if TYPE_CHECKING:
  44. from ray.train.trainer import BaseTrainer
  45. from ray.util.queue import Queue
  46. _TUNER_PKL = "tuner.pkl"
  47. _TRAINABLE_KEY = "_trainable"
  48. _CONVERTED_TRAINABLE_KEY = "_converted_trainable"
  49. _PARAM_SPACE_KEY = "_param_space"
  50. _EXPERIMENT_ANALYSIS_KEY = "_experiment_analysis"
  51. logger = logging.getLogger(__name__)
  52. TrainableType = Union[str, Callable, Type[Trainable]]
  53. TrainableTypeOrTrainer = Union[TrainableType, "BaseTrainer"]
  54. class TunerInternal:
  55. """The real implementation behind external facing ``Tuner``.
  56. The external facing ``Tuner`` multiplexes between local Tuner and remote Tuner
  57. depending on whether in Ray client mode.
  58. In Ray client mode, external ``Tuner`` wraps ``TunerInternal`` into a remote actor,
  59. which is guaranteed to be placed on head node.
  60. ``TunerInternal`` can be constructed from fresh, in which case, ``trainable`` needs
  61. to be provided, together with optional ``param_space``, ``tune_config`` and
  62. ``run_config``.
  63. It can also be restored from a previous failed run (given ``restore_path``).
  64. Args:
  65. restore_path: The path from where the Tuner can be restored. If provided, None
  66. of the rest args are needed.
  67. resume_config: Resume config to configure which trials to continue.
  68. trainable: The trainable to be tuned.
  69. param_space: Search space of the tuning job.
  70. One thing to note is that both preprocessor and dataset can be tuned here.
  71. tune_config: Tuning algorithm specific configs.
  72. Refer to ray.tune.tune_config.TuneConfig for more info.
  73. run_config: Runtime configuration that is specific to individual trials.
  74. If passed, this will overwrite the run config passed to the Trainer,
  75. if applicable. Refer to ray.tune.RunConfig for more info.
  76. """
  77. def __init__(
  78. self,
  79. restore_path: str = None,
  80. storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
  81. resume_config: Optional[ResumeConfig] = None,
  82. trainable: Optional[TrainableTypeOrTrainer] = None,
  83. param_space: Optional[Dict[str, Any]] = None,
  84. tune_config: Optional[TuneConfig] = None,
  85. run_config: Optional[RunConfig] = None,
  86. _tuner_kwargs: Optional[Dict] = None,
  87. _entrypoint: AirEntrypoint = AirEntrypoint.TUNER,
  88. ):
  89. from ray.train.trainer import BaseTrainer
  90. if isinstance(trainable, BaseTrainer):
  91. if _v2_migration_warnings_enabled():
  92. _log_deprecation_warning(
  93. "The Ray Train + Ray Tune integration has been reworked. "
  94. "Passing a Trainer to the Tuner is deprecated and will be removed "
  95. "in a future release. "
  96. f"{V2_MIGRATION_GUIDE_MESSAGE}"
  97. )
  98. run_config = self._choose_run_config(
  99. tuner_run_config=run_config,
  100. trainer=trainable,
  101. param_space=param_space,
  102. )
  103. self._tune_config = tune_config or TuneConfig()
  104. self._run_config = copy.copy(run_config) or RunConfig()
  105. self._entrypoint = _entrypoint
  106. # Restore from Tuner checkpoint.
  107. if restore_path:
  108. self._restore_from_path_or_uri(
  109. path_or_uri=restore_path,
  110. trainable=trainable,
  111. overwrite_param_space=param_space,
  112. resume_config=resume_config,
  113. storage_filesystem=storage_filesystem,
  114. )
  115. return
  116. # Start from fresh
  117. if not trainable:
  118. raise TuneError("You need to provide a trainable to tune.")
  119. if self._entrypoint == AirEntrypoint.TUNER and not isinstance(
  120. self._run_config, ray.tune.RunConfig
  121. ):
  122. if _v2_migration_warnings_enabled():
  123. _log_deprecation_warning(
  124. "The `RunConfig` class should be imported from `ray.tune` "
  125. "when passing it to the Tuner. Please update your imports. "
  126. f"{V2_MIGRATION_GUIDE_MESSAGE}"
  127. )
  128. self.trainable = trainable
  129. assert self.converted_trainable
  130. self._validate_trainable(self.converted_trainable)
  131. self.param_space = param_space
  132. self._resume_config = None
  133. self._is_restored = False
  134. self._tuner_kwargs = copy.deepcopy(_tuner_kwargs) or {}
  135. self._experiment_analysis = None
  136. self._run_config.name = (
  137. self._run_config.name
  138. or StorageContext.get_experiment_dir_name(self.converted_trainable)
  139. )
  140. # The storage context here is only used to access the resolved
  141. # storage fs and experiment path, in order to avoid duplicating that logic.
  142. # This is NOT the storage context object that gets passed to remote workers.
  143. storage = StorageContext(
  144. storage_path=self._run_config.storage_path,
  145. experiment_dir_name=self._run_config.name,
  146. storage_filesystem=self._run_config.storage_filesystem,
  147. )
  148. fs = storage.storage_filesystem
  149. fs.create_dir(storage.experiment_fs_path)
  150. with fs.open_output_stream(
  151. Path(storage.experiment_fs_path, _TUNER_PKL).as_posix()
  152. ) as f:
  153. f.write(pickle.dumps(self.__getstate__()))
  154. def get_run_config(self) -> RunConfig:
  155. return self._run_config
  156. # For Jupyter output with Ray Client
  157. def set_run_config_and_remote_string_queue(
  158. self, run_config: RunConfig, string_queue: "Queue"
  159. ):
  160. self._run_config = run_config
  161. self._tuner_kwargs["_remote_string_queue"] = string_queue
  162. def clear_remote_string_queue(self):
  163. self._tuner_kwargs.pop("_remote_string_queue", None)
  164. def _expected_utilization(self, cpus_per_trial, cpus_total):
  165. num_samples = self._tune_config.num_samples
  166. if num_samples < 0: # TODO: simplify this in Tune
  167. num_samples = math.inf
  168. concurrent_trials = self._tune_config.max_concurrent_trials or 0
  169. if concurrent_trials < 1: # TODO: simplify this in Tune
  170. concurrent_trials = math.inf
  171. actual_concurrency = min(
  172. (
  173. (cpus_total // cpus_per_trial) if cpus_per_trial else 0,
  174. num_samples,
  175. concurrent_trials,
  176. )
  177. )
  178. return (actual_concurrency * cpus_per_trial) / (cpus_total + 0.001)
  179. def _validate_trainable(
  180. self, trainable: TrainableType, required_trainable_name: Optional[str] = None
  181. ):
  182. """Determines whether or not the trainable is valid.
  183. This includes checks on the serializability of the trainable, as well
  184. asserting that the trainable name is as expected on restoration.
  185. This trainable name validation is needed due to an implementation detail
  186. where the trainable name (which is differently generated depending on
  187. the trainable type) is saved in the Trial metadata and needs to match
  188. upon restoration. This does not affect the typical path, since `Tuner.restore`
  189. expects the exact same trainable (which will have the same name).
  190. Raises:
  191. ValueError: if the trainable name does not match or if the trainable
  192. is not serializable.
  193. """
  194. try:
  195. pickle.dumps(trainable)
  196. except TypeError as e:
  197. sio = io.StringIO()
  198. inspect_serializability(trainable, print_file=sio)
  199. msg = (
  200. "The provided trainable is not serializable, which is a requirement "
  201. "since the trainable is serialized and deserialized when transferred "
  202. "to remote workers. See below for a trace of the non-serializable "
  203. "objects that were found in your trainable:\n"
  204. f"{sio.getvalue()}"
  205. )
  206. raise TypeError(msg) from e
  207. if not required_trainable_name:
  208. return
  209. trainable_name = Experiment.get_trainable_name(trainable)
  210. if trainable_name != required_trainable_name:
  211. raise ValueError(
  212. "Invalid `trainable` input to `Tuner.restore()`. To fix this error, "
  213. "pass in the same trainable that was used to initialize the Tuner. "
  214. "Got a trainable with identifier "
  215. f"'{trainable_name}' but expected '{required_trainable_name}'."
  216. )
  217. def _set_trainable_on_restore(
  218. self, trainable: TrainableType, old_trainable_name: Optional[str]
  219. ):
  220. from ray.train.base_trainer import BaseTrainer
  221. self.trainable = trainable
  222. assert self.converted_trainable
  223. self._validate_trainable(
  224. trainable=self.converted_trainable,
  225. required_trainable_name=old_trainable_name,
  226. )
  227. if isinstance(self.trainable, BaseTrainer):
  228. # Log a warning in case the user tries to modify the
  229. # `RunConfig` from the Trainer
  230. trainer: BaseTrainer = self.trainable
  231. # Only log if the Trainer has a non-default RunConfig
  232. if trainer.run_config != RunConfig():
  233. logger.warning(
  234. "The Tune experiment will restore using the original run's "
  235. "`RunConfig`. If you made any changes to the `RunConfig` "
  236. "within the Trainer you passed into `Tuner.restore`, "
  237. "they will be ignored in the resumed run."
  238. )
  239. trainer.run_config = self._run_config
  240. def _validate_param_space_on_restore(
  241. self,
  242. new_param_space: Dict[str, Any],
  243. flattened_param_space_keys: Optional[List[str]],
  244. ):
  245. """Determines whether the (optionally) re-specified `param_space` is valid.
  246. This method performs very loose validation on the new param_space to
  247. prevent users from trying to specify new hyperparameters to tune over.
  248. Raises:
  249. ValueError: if not all keys match the original param_space.
  250. """
  251. if flattened_param_space_keys is None:
  252. # Backwards compatibility: skip validation
  253. return
  254. keys = sorted(flatten_dict(new_param_space).keys())
  255. if keys != flattened_param_space_keys:
  256. raise ValueError(
  257. "Invalid `param_space` input to `Tuner.restore()`. To fix this error, "
  258. "pass in the same `param_space` that was used to initialize the Tuner. "
  259. "Only re-specify the `param_space` to refresh Ray object references "
  260. "that no longer exist due to restoring from a new Ray cluster session. "
  261. "It should not be used to introduce new hyperparameters to tune."
  262. f"\n\nGot: {keys}\nExpected: {flattened_param_space_keys}"
  263. )
  264. def _set_param_space_on_restore(
  265. self,
  266. param_space: Optional[Dict[str, Any]],
  267. flattened_param_space_keys: Optional[List[str]],
  268. ):
  269. self.param_space = param_space
  270. if self.param_space is not None:
  271. # param_space = None -> use the original param_space
  272. self._validate_param_space_on_restore(
  273. new_param_space=self.param_space,
  274. flattened_param_space_keys=flattened_param_space_keys,
  275. )
  276. def _load_tuner_state(
  277. self, tuner_state: Dict[str, Any]
  278. ) -> Tuple[Optional[str], Optional[List[str]]]:
  279. """Loads Tuner state from the previously saved `tuner.pkl`.
  280. Args:
  281. tuner_pkl_path: pathlib.Path of the `tuner.pkl` file saved during the
  282. original Tuner initialization.
  283. Returns:
  284. tuple: of `(old_trainable_name, flattened_param_space_keys)` used for
  285. validating the re-specified `trainable` and `param_space`.
  286. """
  287. # NOTE: These are magic keys used for validating restore args.
  288. old_trainable_name = tuner_state.pop("__trainable_name", None)
  289. flattened_param_space_keys = tuner_state.pop(
  290. "__flattened_param_space_keys", None
  291. )
  292. self.__setstate__(tuner_state)
  293. return old_trainable_name, flattened_param_space_keys
  294. def _restore_from_path_or_uri(
  295. self,
  296. path_or_uri: str,
  297. trainable: TrainableTypeOrTrainer,
  298. overwrite_param_space: Optional[Dict[str, Any]],
  299. resume_config: ResumeConfig,
  300. storage_filesystem: Optional[pyarrow.fs.FileSystem],
  301. ):
  302. fs, fs_path = get_fs_and_path(path_or_uri, storage_filesystem)
  303. with fs.open_input_file(Path(fs_path, _TUNER_PKL).as_posix()) as f:
  304. tuner_state = pickle.loads(f.readall())
  305. old_trainable_name, flattened_param_space_keys = self._load_tuner_state(
  306. tuner_state
  307. )
  308. # Perform validation and set the re-specified `trainable` and `param_space`
  309. self._set_trainable_on_restore(
  310. trainable=trainable, old_trainable_name=old_trainable_name
  311. )
  312. self._set_param_space_on_restore(
  313. param_space=overwrite_param_space,
  314. flattened_param_space_keys=flattened_param_space_keys,
  315. )
  316. # Update RunConfig to reflect changes in the experiment directory
  317. path_or_uri_obj = URI(path_or_uri)
  318. # Infer the `storage_path` and run `name` of the restored run using the
  319. # experiment directory.
  320. # Ex: ~/ray_results/exp_name -> ~/ray_results, exp_name
  321. # Ex: s3://bucket/exp_name -> s3://bucket, exp_name
  322. self._run_config.name = path_or_uri_obj.name
  323. self._run_config.storage_path = str(path_or_uri_obj.parent)
  324. # Update the storage_filesystem with the one passed in on restoration, if any.
  325. self._run_config.storage_filesystem = storage_filesystem
  326. # Load the experiment results at the point where it left off.
  327. try:
  328. self._experiment_analysis = ExperimentAnalysis(
  329. experiment_checkpoint_path=path_or_uri,
  330. default_metric=self._tune_config.metric,
  331. default_mode=self._tune_config.mode,
  332. storage_filesystem=storage_filesystem,
  333. )
  334. except Exception:
  335. self._experiment_analysis = None
  336. self._resume_config = resume_config
  337. self._is_restored = True
  338. def _choose_run_config(
  339. self,
  340. tuner_run_config: Optional[RunConfig],
  341. trainer: "BaseTrainer",
  342. param_space: Optional[Dict[str, Any]],
  343. ) -> RunConfig:
  344. """Chooses which `RunConfig` to use when multiple can be passed in
  345. through a Trainer or the Tuner itself.
  346. Args:
  347. tuner_run_config: The run config passed into the Tuner constructor.
  348. trainer: The Trainer instance to use with Tune, which may have
  349. a RunConfig specified by the user.
  350. param_space: The param space passed to the Tuner.
  351. Raises:
  352. ValueError: if the `run_config` is specified as a hyperparameter.
  353. """
  354. if param_space and "run_config" in param_space:
  355. raise ValueError(
  356. "`RunConfig` cannot be tuned as part of the `param_space`! "
  357. "Move the run config to be a parameter of the `Tuner`: "
  358. "Tuner(..., run_config=RunConfig(...))"
  359. )
  360. # Both Tuner RunConfig + Trainer RunConfig --> prefer Tuner RunConfig
  361. if tuner_run_config and trainer.run_config != ray.train.RunConfig():
  362. logger.info(
  363. "A `RunConfig` was passed to both the `Tuner` and the "
  364. f"`{trainer.__class__.__name__}`. The run config passed to "
  365. "the `Tuner` is the one that will be used."
  366. )
  367. return tuner_run_config
  368. # No Tuner RunConfig -> pass the Trainer config through
  369. # This returns either a user-specified config, or the default RunConfig
  370. # if nothing was provided to both the Trainer or Tuner.
  371. if not tuner_run_config:
  372. return trainer.run_config
  373. # Tuner RunConfig + No Trainer RunConfig --> Use the Tuner config
  374. return tuner_run_config
  375. def _process_scaling_config(self) -> None:
  376. """Converts ``self._param_space["scaling_config"]`` to a dict.
  377. The dict is converted back to a dataclass by the Trainer, after the
  378. Tune search specification is resolved.
  379. """
  380. # TODO: introduce `ray.tune.sample.TuneableDataclass` and allow Tune to
  381. # natively resolve specs with dataclasses.
  382. scaling_config = self._param_space.get("scaling_config")
  383. if not isinstance(scaling_config, ray.train.ScalingConfig):
  384. return
  385. self._param_space["scaling_config"] = scaling_config.__dict__.copy()
  386. @property
  387. def trainable(self) -> TrainableTypeOrTrainer:
  388. return self._trainable
  389. @property
  390. def converted_trainable(self) -> TrainableType:
  391. return self._converted_trainable
  392. @trainable.setter
  393. def trainable(self, trainable: TrainableTypeOrTrainer):
  394. self._trainable = trainable
  395. self._converted_trainable = self._convert_trainable(trainable)
  396. @property
  397. def param_space(self) -> Optional[Dict[str, Any]]:
  398. return self._param_space
  399. @param_space.setter
  400. def param_space(self, param_space: Optional[Dict[str, Any]]):
  401. # Handle any configs that adhere to the `to_dict` interface.
  402. # Ex: AlgorithmConfig from RLlib
  403. if isinstance(param_space, _Config):
  404. param_space = param_space.to_dict()
  405. if not isinstance(param_space, dict) and param_space is not None:
  406. raise ValueError(
  407. "The `param_space` passed to the `Tuner` must be a dict. "
  408. f"Got '{type(param_space)}' instead."
  409. )
  410. self._param_space = param_space
  411. if param_space:
  412. self._process_scaling_config()
  413. def _convert_trainable(self, trainable: TrainableTypeOrTrainer) -> TrainableType:
  414. """Converts a Trainer to a Tune trainable and saves the converted
  415. trainable. If not using a Trainer, this leaves the trainable as is."""
  416. from ray.train.trainer import BaseTrainer
  417. return (
  418. trainable.as_trainable()
  419. if isinstance(trainable, BaseTrainer)
  420. else trainable
  421. )
  422. def fit(self) -> ResultGrid:
  423. trainable = self.converted_trainable
  424. param_space = copy.deepcopy(self.param_space)
  425. if not self._is_restored:
  426. analysis = self._fit_internal(trainable, param_space)
  427. else:
  428. analysis = self._fit_resume(trainable, param_space)
  429. self._experiment_analysis = analysis
  430. return ResultGrid(self._experiment_analysis)
  431. def get_results(self) -> ResultGrid:
  432. if not self._experiment_analysis:
  433. raise RuntimeError(
  434. "Can't return results as experiment has not been run, yet. "
  435. "Call `Tuner.fit()` to run the experiment first."
  436. )
  437. return ResultGrid(self._experiment_analysis)
  438. def _get_tune_run_arguments(self, trainable: TrainableType) -> Dict[str, Any]:
  439. """Get tune.run arguments common for both new and resumed runs."""
  440. # Avoid overwriting the originally configured checkpoint config.
  441. checkpoint_config = copy.deepcopy(self._run_config.checkpoint_config)
  442. if checkpoint_config.checkpoint_frequency:
  443. # Function trainables (and thus most of our trainers) usually don't handle
  444. # this argument.
  445. handle_checkpoint_freq = getattr(
  446. trainable, "_handles_checkpoint_freq", None
  447. )
  448. if handle_checkpoint_freq is False:
  449. # If we specifically know this trainable doesn't support the
  450. # argument, raise an error
  451. raise ValueError(
  452. "You passed `checkpoint_frequency="
  453. f"{checkpoint_config.checkpoint_frequency}` to your "
  454. "CheckpointConfig, but this trainer does not support "
  455. "this argument. If you passed in a Trainer that takes in a "
  456. "custom training loop, you will need to "
  457. "report a checkpoint every `checkpoint_frequency` iterations "
  458. "within your training loop using "
  459. "`ray.tune.report(metrics=..., checkpoint=...)` "
  460. "to get this behavior."
  461. )
  462. elif handle_checkpoint_freq is True:
  463. # If we specifically support it, it's handled in the training loop,
  464. # so we disable tune's bookkeeping.
  465. checkpoint_config.checkpoint_frequency = 0
  466. # Otherwise, the trainable is not a Trainer and we just keep the
  467. # user-supplied value.
  468. # Function trainables will raise a runtime error later if set > 0
  469. if checkpoint_config.checkpoint_at_end is not None:
  470. # Again, function trainables usually don't handle this argument.
  471. handle_cp_at_end = getattr(trainable, "_handles_checkpoint_at_end", None)
  472. if handle_cp_at_end is False:
  473. # If we specifically know we don't support it, raise an error.
  474. raise ValueError(
  475. "You passed `checkpoint_at_end="
  476. f"{checkpoint_config.checkpoint_at_end}` "
  477. "to your CheckpointConfig, but this trainer does not support "
  478. "this argument. If you passed in a Trainer that takes in a "
  479. "custom training loop, you should include one last call to "
  480. "`ray.tune.report(metrics=..., checkpoint=...)` "
  481. "at the end of your training loop to get this behavior."
  482. )
  483. elif handle_cp_at_end is True:
  484. # If we specifically support it, it's handled in the training loop,
  485. # so we disable tune's internal bookkeeping.
  486. checkpoint_config.checkpoint_at_end = False
  487. # If this is a user-defined trainable, just keep the value
  488. # Function trainables will raise a runtime error later if set to True
  489. else:
  490. # Set default to False for function trainables and True for everything else
  491. if is_function_trainable(trainable):
  492. checkpoint_config.checkpoint_at_end = False
  493. else:
  494. checkpoint_config.checkpoint_at_end = True
  495. return dict(
  496. storage_path=self._run_config.storage_path,
  497. storage_filesystem=self._run_config.storage_filesystem,
  498. name=self._run_config.name,
  499. mode=self._tune_config.mode,
  500. metric=self._tune_config.metric,
  501. callbacks=self._run_config.callbacks,
  502. sync_config=self._run_config.sync_config,
  503. stop=self._run_config.stop,
  504. max_failures=self._run_config.failure_config.max_failures,
  505. checkpoint_config=checkpoint_config,
  506. raise_on_failed_trial=False,
  507. fail_fast=(self._run_config.failure_config.fail_fast),
  508. progress_reporter=self._run_config.progress_reporter,
  509. verbose=self._run_config.verbose,
  510. reuse_actors=self._tune_config.reuse_actors,
  511. max_concurrent_trials=self._tune_config.max_concurrent_trials,
  512. time_budget_s=self._tune_config.time_budget_s,
  513. trial_name_creator=self._tune_config.trial_name_creator,
  514. trial_dirname_creator=self._tune_config.trial_dirname_creator,
  515. _entrypoint=self._entrypoint,
  516. # Deprecated
  517. chdir_to_trial_dir=self._tune_config.chdir_to_trial_dir,
  518. )
  519. def _fit_internal(
  520. self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
  521. ) -> ExperimentAnalysis:
  522. """Fitting for a fresh Tuner."""
  523. args = {
  524. **self._get_tune_run_arguments(trainable),
  525. **dict(
  526. run_or_experiment=trainable,
  527. config=param_space,
  528. num_samples=self._tune_config.num_samples,
  529. search_alg=self._tune_config.search_alg,
  530. scheduler=self._tune_config.scheduler,
  531. log_to_file=self._run_config.log_to_file,
  532. ),
  533. **self._tuner_kwargs,
  534. }
  535. analysis = run(
  536. **args,
  537. )
  538. self.clear_remote_string_queue()
  539. return analysis
  540. def _fit_resume(
  541. self, trainable: TrainableType, param_space: Optional[Dict[str, Any]]
  542. ) -> ExperimentAnalysis:
  543. """Fitting for a restored Tuner."""
  544. assert self._resume_config
  545. args = {
  546. **self._get_tune_run_arguments(trainable),
  547. **dict(
  548. run_or_experiment=trainable,
  549. config=param_space,
  550. resume_config=self._resume_config,
  551. search_alg=self._tune_config.search_alg,
  552. scheduler=self._tune_config.scheduler,
  553. ),
  554. **self._tuner_kwargs,
  555. }
  556. analysis = run(**args)
  557. self.clear_remote_string_queue()
  558. return analysis
  559. def __getstate__(self):
  560. state = self.__dict__.copy()
  561. state["_tuner_kwargs"] = state["_tuner_kwargs"].copy()
  562. state["_tuner_kwargs"].pop("_remote_string_queue", None)
  563. state.pop(_TRAINABLE_KEY, None)
  564. trainable = state.pop(_CONVERTED_TRAINABLE_KEY, None)
  565. param_space = state.pop(_PARAM_SPACE_KEY, None)
  566. state.pop(_EXPERIMENT_ANALYSIS_KEY, None)
  567. state["__trainable_name"] = (
  568. Experiment.get_trainable_name(trainable) if trainable else None
  569. )
  570. state["__flattened_param_space_keys"] = (
  571. sorted(flatten_dict(param_space).keys())
  572. if param_space is not None
  573. else None
  574. )
  575. return state
  576. def __setstate__(self, state):
  577. # Make sure the magic metadata gets removed first.
  578. state.pop("__flattened_param_space_keys", None)
  579. state.pop("__trainable_name", None)
  580. self.__dict__.update(state)