pbt.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195
  1. import copy
  2. import json
  3. import logging
  4. import math
  5. import os
  6. import random
  7. import shutil
  8. import warnings
  9. from pathlib import Path
  10. from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union
  11. from ray.air.constants import TRAINING_ITERATION
  12. from ray.train._internal.session import _FutureTrainingResult, _TrainingResult
  13. from ray.tune import Checkpoint
  14. from ray.tune.error import TuneError
  15. from ray.tune.experiment import Trial
  16. from ray.tune.result import DEFAULT_METRIC
  17. from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
  18. from ray.tune.search import SearchGenerator
  19. from ray.tune.search.sample import Domain, Function
  20. from ray.tune.search.variant_generator import format_vars
  21. from ray.tune.utils.util import SafeFallbackEncoder
  22. from ray.util import PublicAPI
  23. from ray.util.debug import log_once
  24. if TYPE_CHECKING:
  25. from ray.train import Checkpoint as TrainCheckpoint
  26. from ray.tune.execution.tune_controller import TuneController
  27. logger = logging.getLogger(__name__)
  28. class _PBTTrialState:
  29. """Internal PBT state tracked per-trial."""
  30. def __init__(self, trial: Trial):
  31. self.orig_tag = trial.experiment_tag
  32. self.last_score: Union[float, None] = None # Set on _save_trial_state
  33. self.last_checkpoint: Union[TrainCheckpoint, _FutureTrainingResult, None] = None
  34. self.last_perturbation_time: int = 0
  35. self.last_train_time: int = 0 # Used for synchronous mode
  36. self.last_result: Optional[
  37. dict[str, object]
  38. ] = None # Used for synchronous mode
  39. def __repr__(self) -> str:
  40. # Informative repr for easier debugging.
  41. return (
  42. self.__class__.__name__
  43. + "("
  44. + ", ".join(
  45. f"{k}={v}"
  46. for k, v in self.__dict__.items()
  47. if k
  48. in (
  49. "last_score",
  50. "last_checkpoint",
  51. "last_train_time",
  52. "last_perturbation_time",
  53. )
  54. )
  55. + ")"
  56. )
  57. def _explore(
  58. config: Dict,
  59. mutations: Dict,
  60. resample_probability: float,
  61. perturbation_factors: Tuple[float],
  62. custom_explore_fn: Optional[Callable],
  63. ) -> Tuple[Dict, Dict]:
  64. """Return a perturbed config and string descriptors of the operations performed
  65. on the original config to produce the new config.
  66. Args:
  67. config: Original hyperparameter configuration.
  68. mutations: Specification of mutations to perform as documented
  69. in the PopulationBasedTraining scheduler.
  70. resample_probability: Probability of allowing resampling of a
  71. particular variable.
  72. perturbation_factors: Scaling factors to choose between when mutating
  73. a continuous hyperparameter.
  74. custom_explore_fn: Custom explore function applied after built-in
  75. config perturbations.
  76. Returns:
  77. new_config: New hyperparameter configuration (after random mutations).
  78. operations: Map of hyperparams -> strings describing mutation operations
  79. performed
  80. """
  81. operations = {}
  82. new_config = copy.deepcopy(config)
  83. for key, distribution in mutations.items():
  84. if isinstance(distribution, dict):
  85. # Handle nested hyperparameter configs by recursively perturbing them
  86. nested_new_config, nested_ops = _explore(
  87. config[key],
  88. mutations[key],
  89. resample_probability,
  90. perturbation_factors,
  91. custom_explore_fn=None,
  92. )
  93. new_config.update({key: nested_new_config})
  94. operations.update({key: nested_ops})
  95. elif isinstance(distribution, (list, tuple)):
  96. # Case 1: Hyperparameter resample distribution is a list/tuple
  97. if (
  98. random.random() < resample_probability
  99. or config[key] not in distribution
  100. ):
  101. # Resample a value from the list with `resample_probability`
  102. new_config[key] = random.choice(distribution)
  103. operations[key] = "resample"
  104. else:
  105. # Otherwise, perturb by shifting to the left or right of the list
  106. shift = random.choice([-1, 1])
  107. old_idx = distribution.index(config[key])
  108. new_idx = old_idx + shift
  109. new_idx = min(max(new_idx, 0), len(distribution) - 1)
  110. new_config[key] = distribution[new_idx]
  111. operations[key] = (
  112. f"shift {'left' if shift == -1 else 'right'}"
  113. f"{' (noop)' if old_idx == new_idx else ''}"
  114. )
  115. elif isinstance(distribution, (Domain, Callable)):
  116. # Case 2: Hyperparameter resample distribution is:
  117. # 1. a function (ex: lambda: np.random.uniform(0, 1))
  118. # 2. tune search Domain (ex: tune.uniform(0, 1))
  119. if random.random() < resample_probability:
  120. # Resample a value from the function/domain with `resample_probability`
  121. new_config[key] = (
  122. distribution.sample(None)
  123. if isinstance(distribution, Domain)
  124. else distribution()
  125. )
  126. operations[key] = "resample"
  127. else:
  128. # Otherwise, perturb by multiplying the hyperparameter by one
  129. # of the `perturbation_factors`
  130. perturbation_factor = random.choice(perturbation_factors)
  131. new_config[key] = config[key] * perturbation_factor
  132. operations[key] = f"* {perturbation_factor}"
  133. if isinstance(config[key], int):
  134. # If this hyperparameter started out as an integer (ex: `batch_size`),
  135. # convert the new value back
  136. new_config[key] = int(new_config[key])
  137. else:
  138. raise ValueError(
  139. f"Unsupported hyperparameter distribution type: {type(distribution)}"
  140. )
  141. if custom_explore_fn:
  142. # The user can perform any additional hyperparameter exploration
  143. # via `custom_explore_fn`
  144. new_config = custom_explore_fn(new_config)
  145. assert new_config is not None, "Custom explore fn failed to return new config"
  146. return new_config, operations
  147. def _make_experiment_tag(orig_tag: str, config: Dict, mutations: Dict) -> str:
  148. """Appends perturbed params to the trial name to show in the console."""
  149. resolved_vars = {}
  150. for k in mutations.keys():
  151. resolved_vars[("config", k)] = config[k]
  152. return "{}@perturbed[{}]".format(orig_tag, format_vars(resolved_vars))
  153. def _fill_config(
  154. config: Dict, attr: str, search_space: Union[dict, list, tuple, Callable, Domain]
  155. ):
  156. """Add attr to config by sampling from search_space.
  157. This is a helper used to set initial hyperparameter values if the user doesn't
  158. specify them in the Tuner `param_space`.
  159. """
  160. if isinstance(search_space, Callable):
  161. config[attr] = search_space()
  162. elif isinstance(search_space, Domain):
  163. config[attr] = search_space.sample(None)
  164. elif isinstance(search_space, (list, tuple)):
  165. config[attr] = random.choice(search_space)
  166. elif isinstance(search_space, dict):
  167. config[attr] = {}
  168. for k, v in search_space.items():
  169. _fill_config(config[attr], k, v)
  170. def _filter_mutated_params_from_config(
  171. config: Dict, hyperparam_mutations: Dict
  172. ) -> Dict:
  173. """Filter out hyperparameters from a config so that only parameters specified
  174. within hyperparam_mutations remain. This recursively filters nested configs.
  175. Example:
  176. >>> config = {
  177. ... "a": {"b": 2, "c": 0, "d": {"e": 0.1}},
  178. ... "f": {"g": 0.5},
  179. ... }
  180. >>> hyperparam_mutations = {
  181. ... "a": {"b": [1, 2], "c": [-1, 0]},
  182. ... }
  183. >>> _filter_mutated_params_from_config(config, hyperparam_mutations) == {
  184. ... "a": {"b": 2, "c": 0}
  185. ... }
  186. True
  187. Args:
  188. config: The config dict that we want to filter.
  189. hyperparam_mutations: A dict containing a subset of hyperparameters from
  190. config, used to filter the config.
  191. Returns:
  192. mutated_params: A copy of config containing only params specified in
  193. hyperparam_mutations
  194. """
  195. mutated_params = {}
  196. for param_name in config:
  197. if param_name not in hyperparam_mutations:
  198. continue
  199. if isinstance(config[param_name], dict):
  200. nested_params = _filter_mutated_params_from_config(
  201. config[param_name], hyperparam_mutations[param_name]
  202. )
  203. mutated_params[param_name] = nested_params
  204. else:
  205. mutated_params[param_name] = config[param_name]
  206. return mutated_params
  207. @PublicAPI
  208. class PopulationBasedTraining(FIFOScheduler):
  209. """Implements the Population Based Training (PBT) algorithm.
  210. https://www.deepmind.com/blog/population-based-training-of-neural-networks
  211. PBT trains a group of models (or agents) in parallel. Periodically, poorly
  212. performing models clone the state of the top performers, and a random
  213. mutation is applied to their hyperparameters in the hopes of
  214. outperforming the current top models.
  215. Unlike other hyperparameter search algorithms, PBT mutates hyperparameters
  216. during training time. This enables very fast hyperparameter discovery and
  217. also automatically discovers good annealing schedules.
  218. This Tune PBT implementation considers all trials added as part of the
  219. PBT population. If the number of trials exceeds the cluster capacity,
  220. they will be time-multiplexed as to balance training progress across the
  221. population. To run multiple trials, use `tune.TuneConfig(num_samples=<int>)`.
  222. In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in
  223. `pbt_global.txt` and individual policy perturbations are recorded
  224. in pbt_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag,
  225. target trial iteration, clone trial iteration, old config, new config]
  226. on each perturbation step.
  227. Args:
  228. time_attr: The training result attr to use for comparing time.
  229. Note that you can pass in something non-temporal such as
  230. `training_iteration` as a measure of progress, the only requirement
  231. is that the attribute should increase monotonically.
  232. metric: The training result objective value attribute. Stopping
  233. procedures will use this attribute. If None but a mode was passed,
  234. the `ray.tune.result.DEFAULT_METRIC` will be used per default.
  235. mode: One of {min, max}. Determines whether objective is
  236. minimizing or maximizing the metric attribute.
  237. perturbation_interval: Models will be considered for
  238. perturbation at this interval of `time_attr`. Note that
  239. perturbation incurs checkpoint overhead, so you shouldn't set this
  240. to be too frequent.
  241. burn_in_period: Models will not be considered for
  242. perturbation before this interval of `time_attr` has passed. This
  243. guarantees that models are trained for at least a certain amount
  244. of time or timesteps before being perturbed.
  245. hyperparam_mutations: Hyperparams to mutate. The format is
  246. as follows: for each key, either a list, function,
  247. or a tune search space object (tune.loguniform, tune.uniform,
  248. etc.) can be provided. A list specifies an allowed set of
  249. categorical values. A function or tune search space object
  250. specifies the distribution of a continuous parameter. You must
  251. use tune.choice, tune.uniform, tune.loguniform, etc.. Arbitrary
  252. tune.sample_from objects are not supported.
  253. A key can also hold a dict for nested hyperparameters.
  254. You must specify at least one of `hyperparam_mutations` or
  255. `custom_explore_fn`.
  256. Tune will sample the search space provided by
  257. `hyperparam_mutations` for the initial hyperparameter values if the
  258. corresponding hyperparameters are not present in a trial's initial `config`.
  259. quantile_fraction: Parameters are transferred from the top
  260. `quantile_fraction` fraction of trials to the bottom
  261. `quantile_fraction` fraction. Needs to be between 0 and 0.5.
  262. Setting it to 0 essentially implies doing no exploitation at all.
  263. resample_probability: The probability of resampling from the
  264. original distribution when applying `hyperparam_mutations`. If not
  265. resampled, the value will be perturbed by a factor chosen from
  266. `perturbation_factors` if continuous, or changed to an adjacent value
  267. if discrete.
  268. perturbation_factors: Scaling factors to choose between when mutating
  269. a continuous hyperparameter.
  270. custom_explore_fn: You can also specify a custom exploration
  271. function. This function is invoked as `f(config)` after built-in
  272. perturbations from `hyperparam_mutations` are applied, and should
  273. return `config` updated as needed. You must specify at least one of
  274. `hyperparam_mutations` or `custom_explore_fn`.
  275. log_config: Whether to log the ray config of each model to
  276. local_dir at each exploit. Allows config schedule to be
  277. reconstructed.
  278. require_attrs: Whether to require time_attr and metric to appear
  279. in result for every iteration. If True, error will be raised
  280. if these values are not present in trial result.
  281. synch: If False, will use asynchronous implementation of
  282. PBT. Trial perturbations occur every perturbation_interval for each
  283. trial independently. If True, will use synchronous implementation
  284. of PBT. Perturbations will occur only after all trials are
  285. synced at the same time_attr every perturbation_interval.
  286. Defaults to False. See Appendix A.1 here
  287. https://arxiv.org/pdf/1711.09846.pdf.
  288. .. code-block:: python
  289. import random
  290. from ray import tune
  291. from ray.tune.schedulers import PopulationBasedTraining
  292. pbt = PopulationBasedTraining(
  293. time_attr="training_iteration",
  294. metric="episode_reward_mean",
  295. mode="max",
  296. perturbation_interval=10, # every 10 `time_attr` units
  297. # (training_iterations in this case)
  298. hyperparam_mutations={
  299. # Perturb factor1 by scaling it by 0.8 or 1.2. Resampling
  300. # resets it to a value sampled from the lambda function.
  301. "factor_1": lambda: random.uniform(0.0, 20.0),
  302. # Alternatively, use tune search space primitives.
  303. # The search space for factor_1 is equivalent to factor_2.
  304. "factor_2": tune.uniform(0.0, 20.0),
  305. # Perturb factor3 by changing it to an adjacent value, e.g.
  306. # 10 -> 1 or 10 -> 100. Resampling will choose at random.
  307. "factor_3": [1, 10, 100, 1000, 10000],
  308. # Using tune.choice is NOT equivalent to the above.
  309. # factor_4 is treated as a continuous hyperparameter.
  310. "factor_4": tune.choice([1, 10, 100, 1000, 10000]),
  311. })
  312. tuner = tune.Tuner(
  313. trainable,
  314. tune_config=tune.TuneConfig(
  315. scheduler=pbt,
  316. num_samples=8,
  317. ),
  318. )
  319. tuner.fit()
  320. """
  321. def __init__(
  322. self,
  323. time_attr: str = "time_total_s",
  324. metric: Optional[str] = None,
  325. mode: Optional[str] = None,
  326. perturbation_interval: float = 60.0,
  327. burn_in_period: float = 0.0,
  328. hyperparam_mutations: Dict[
  329. str, Union[dict, list, tuple, Callable, Domain]
  330. ] = None,
  331. quantile_fraction: float = 0.25,
  332. resample_probability: float = 0.25,
  333. perturbation_factors: Tuple[float, float] = (1.2, 0.8),
  334. custom_explore_fn: Optional[Callable] = None,
  335. log_config: bool = True,
  336. require_attrs: bool = True,
  337. synch: bool = False,
  338. ):
  339. hyperparam_mutations = hyperparam_mutations or {}
  340. for value in hyperparam_mutations.values():
  341. if not isinstance(value, (dict, list, tuple, Domain, Callable)):
  342. raise TypeError(
  343. "`hyperparam_mutation` values must be either "
  344. "a List, Tuple, Dict, a tune search space object, or "
  345. "a callable."
  346. )
  347. if isinstance(value, Function):
  348. raise ValueError(
  349. "arbitrary tune.sample_from objects are not "
  350. "supported for `hyperparam_mutation` values."
  351. "You must use other built in primitives like"
  352. "tune.uniform, tune.loguniform, etc."
  353. )
  354. if not hyperparam_mutations and not custom_explore_fn:
  355. raise TuneError(
  356. "You must specify at least one of `hyperparam_mutations` "
  357. "or `custom_explore_fn` to use PBT."
  358. )
  359. if quantile_fraction > 0.5 or quantile_fraction < 0:
  360. raise ValueError(
  361. "You must set `quantile_fraction` to a value between 0 and"
  362. "0.5. Current value: '{}'".format(quantile_fraction)
  363. )
  364. if perturbation_interval <= 0:
  365. raise ValueError(
  366. "perturbation_interval must be a positive number greater "
  367. "than 0. Current value: '{}'".format(perturbation_interval)
  368. )
  369. if mode:
  370. assert mode in ["min", "max"], "`mode` must be 'min' or 'max'."
  371. super().__init__()
  372. self._metric = metric
  373. self._mode = mode
  374. self._metric_op = None
  375. if self._mode == "max":
  376. self._metric_op = 1.0
  377. elif self._mode == "min":
  378. self._metric_op = -1.0
  379. self._time_attr = time_attr
  380. self._perturbation_interval = perturbation_interval
  381. self._burn_in_period = burn_in_period
  382. self._hyperparam_mutations = hyperparam_mutations
  383. self._quantile_fraction = quantile_fraction
  384. self._resample_probability = resample_probability
  385. self._perturbation_factors = perturbation_factors
  386. self._trial_state: dict[Trial, _PBTTrialState] = {}
  387. self._custom_explore_fn = custom_explore_fn
  388. self._log_config = log_config
  389. self._require_attrs = require_attrs
  390. self._synch = synch
  391. self._next_perturbation_sync = max(
  392. self._perturbation_interval,
  393. self._burn_in_period,
  394. )
  395. # Metrics
  396. self._num_checkpoints = 0
  397. self._num_perturbations = 0
  398. def set_search_properties(
  399. self, metric: Optional[str], mode: Optional[str], **spec
  400. ) -> bool:
  401. if self._metric and metric:
  402. return False
  403. if self._mode and mode:
  404. return False
  405. if metric:
  406. self._metric = metric
  407. if mode:
  408. self._mode = mode
  409. if self._mode == "max":
  410. self._metric_op = 1.0
  411. elif self._mode == "min":
  412. self._metric_op = -1.0
  413. if self._metric is None and self._mode:
  414. # If only a mode was passed, use anonymous metric
  415. self._metric = DEFAULT_METRIC
  416. return True
  417. def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
  418. if tune_controller.search_alg is not None and isinstance(
  419. tune_controller.search_alg, SearchGenerator
  420. ):
  421. raise ValueError(
  422. "Search algorithms cannot be used with {} "
  423. "schedulers. Please remove {}.".format(
  424. self.__class__.__name__, tune_controller.search_alg
  425. )
  426. )
  427. if not self._metric or not self._metric_op:
  428. raise ValueError(
  429. "{} has been instantiated without a valid `metric` ({}) or "
  430. "`mode` ({}) parameter. Either pass these parameters when "
  431. "instantiating the scheduler, or pass them as parameters "
  432. "to `tune.TuneConfig()`".format(
  433. self.__class__.__name__, self._metric, self._mode
  434. )
  435. )
  436. checkpoint_config = trial.run_metadata.checkpoint_manager.checkpoint_config
  437. if (
  438. checkpoint_config.num_to_keep
  439. and checkpoint_config.num_to_keep <= 2
  440. and log_once("pbt_num_to_keep")
  441. ):
  442. warnings.warn(
  443. "Using `CheckpointConfig.num_to_keep <= 2` with PBT can lead to "
  444. "restoration problems when checkpoint are deleted too early for "
  445. "other trials to exploit them. If this happens, increase the value "
  446. "of `num_to_keep`."
  447. )
  448. self._trial_state[trial] = _PBTTrialState(trial)
  449. for attr in self._hyperparam_mutations.keys():
  450. if attr not in trial.config:
  451. if log_once(attr + "-missing"):
  452. logger.debug(
  453. "Cannot find {} in config. Using search "
  454. "space provided by hyperparam_mutations."
  455. )
  456. # Add attr to trial's config by sampling search space from
  457. # hyperparam_mutations.
  458. _fill_config(trial.config, attr, self._hyperparam_mutations[attr])
  459. # Make sure this attribute is added to CLI output.
  460. trial.evaluated_params[attr] = trial.config[attr]
  461. def on_trial_result(
  462. self, tune_controller: "TuneController", trial: Trial, result: Dict
  463. ) -> str:
  464. if self._time_attr not in result:
  465. time_missing_msg = (
  466. "Cannot find time_attr {} "
  467. "in trial result {}. Make sure that this "
  468. "attribute is returned in the "
  469. "results of your Trainable.".format(self._time_attr, result)
  470. )
  471. if self._require_attrs:
  472. raise RuntimeError(
  473. time_missing_msg
  474. + "If this error is expected, you can change this to "
  475. "a warning message by "
  476. "setting PBT(require_attrs=False)"
  477. )
  478. else:
  479. if log_once("pbt-time_attr-error"):
  480. logger.warning(time_missing_msg)
  481. if self._metric not in result:
  482. metric_missing_msg = (
  483. "Cannot find metric {} in trial result {}. "
  484. "Make sure that this attribute is returned "
  485. "in the "
  486. "results of your Trainable.".format(self._metric, result)
  487. )
  488. if self._require_attrs:
  489. raise RuntimeError(
  490. metric_missing_msg + "If this error is expected, "
  491. "you can change this to a warning message by "
  492. "setting PBT(require_attrs=False)"
  493. )
  494. else:
  495. if log_once("pbt-metric-error"):
  496. logger.warning(metric_missing_msg)
  497. if self._metric not in result or self._time_attr not in result:
  498. return TrialScheduler.CONTINUE
  499. time = result[self._time_attr]
  500. state = self._trial_state[trial]
  501. # Continue training if burn-in period has not been reached, yet.
  502. if time < self._burn_in_period:
  503. logger.debug(f"Still in burn-in period: {time} < {self._burn_in_period}")
  504. return TrialScheduler.CONTINUE
  505. # Continue training if perturbation interval has not been reached, yet.
  506. time_since_perturb = time - state.last_perturbation_time
  507. if time_since_perturb < self._perturbation_interval:
  508. logger.debug(
  509. f"Perturbation interval not reached: "
  510. f"{time_since_perturb} < {self._perturbation_interval}"
  511. )
  512. return TrialScheduler.CONTINUE # avoid checkpoint overhead
  513. logger.debug(f"Updating trial state for trial {trial} at time {time}")
  514. self._save_trial_state(state, time, result, trial)
  515. if not self._synch:
  516. state.last_perturbation_time = time
  517. lower_quantile, upper_quantile = self._quantiles()
  518. decision = TrialScheduler.CONTINUE
  519. for other_trial in tune_controller.get_trials():
  520. if other_trial.status in [Trial.PENDING, Trial.PAUSED]:
  521. decision = TrialScheduler.PAUSE
  522. break
  523. self._checkpoint_or_exploit(
  524. trial, tune_controller, upper_quantile, lower_quantile
  525. )
  526. return TrialScheduler.NOOP if trial.status == Trial.PAUSED else decision
  527. else:
  528. # Synchronous mode.
  529. if any(
  530. self._trial_state[t].last_train_time < self._next_perturbation_sync
  531. and t != trial
  532. for t in tune_controller.get_live_trials()
  533. ):
  534. logger.debug(
  535. f"Sync: Other trials are not at perturb time, yet. "
  536. f"Pausing trial {trial} to wait."
  537. )
  538. else:
  539. # All trials are synced at the same timestep.
  540. logger.debug("Sync: All trials are at perturb time.")
  541. lower_quantile, upper_quantile = self._quantiles()
  542. all_trials = tune_controller.get_trials()
  543. not_in_quantile = []
  544. for t in all_trials:
  545. if t not in lower_quantile and t not in upper_quantile:
  546. not_in_quantile.append(t)
  547. logger.debug(
  548. "Trial statistics\n"
  549. f"Upper quantile: {upper_quantile}\n"
  550. f"Lower quantile: {lower_quantile}\n"
  551. f"Not in quantile: {not_in_quantile}"
  552. )
  553. # Move upper quantile trials to beginning and lower quantile
  554. # to end. This ensures that checkpointing of strong trials
  555. # occurs before exploiting of weaker ones.
  556. all_trials = upper_quantile + not_in_quantile + lower_quantile
  557. for t in all_trials:
  558. logger.debug(f"Perturbing trial {t}")
  559. self._trial_state[t].last_perturbation_time = time
  560. self._checkpoint_or_exploit(
  561. t, tune_controller, upper_quantile, lower_quantile
  562. )
  563. all_train_times = [
  564. self._trial_state[t].last_train_time
  565. for t in tune_controller.get_trials()
  566. ]
  567. max_last_train_time = max(all_train_times)
  568. self._next_perturbation_sync = max(
  569. self._next_perturbation_sync + self._perturbation_interval,
  570. max_last_train_time,
  571. )
  572. logger.debug(f"Next perturb at time {self._next_perturbation_sync}")
  573. # In sync mode we should pause all trials once result comes in.
  574. # Once a perturbation step happens for all trials, they should
  575. # still all be paused.
  576. # choose_trial_to_run will then pick the next trial to run out of
  577. # the paused trials.
  578. return (
  579. TrialScheduler.NOOP
  580. if trial.status == Trial.PAUSED
  581. else TrialScheduler.PAUSE
  582. )
  583. def _save_trial_state(
  584. self, state: _PBTTrialState, time: int, result: Dict, trial: Trial
  585. ):
  586. """Saves necessary trial information when result is received.
  587. Args:
  588. state: The state object for the trial.
  589. time: The current timestep of the trial.
  590. result: The trial's result dictionary.
  591. trial: The trial object.
  592. """
  593. # This trial has reached its perturbation interval.
  594. # Record new state in the state object.
  595. score = self._metric_op * result[self._metric]
  596. state.last_score = score
  597. state.last_train_time = time
  598. state.last_result = result
  599. return score
  600. def _checkpoint_or_exploit(
  601. self,
  602. trial: Trial,
  603. tune_controller: "TuneController",
  604. upper_quantile: List[Trial],
  605. lower_quantile: List[Trial],
  606. ):
  607. """Checkpoint if in upper quantile, exploits if in lower."""
  608. state = self._trial_state[trial]
  609. if trial in upper_quantile:
  610. # The trial last result is only updated after the scheduler
  611. # callback. So, we override with the current result.
  612. logger.debug(f"Trial {trial} is in upper quantile. Saving checkpoint.")
  613. if trial.status == Trial.PAUSED:
  614. if trial.temporary_state.saving_to and isinstance(
  615. trial.temporary_state.saving_to, _FutureTrainingResult
  616. ):
  617. logger.debug(f"Trial {trial} is still saving.")
  618. state.last_checkpoint = trial.temporary_state.saving_to
  619. else:
  620. # Paused trial will always have an in-memory checkpoint.
  621. logger.debug(
  622. f"Trial {trial} is paused. Use last available "
  623. f"checkpoint {trial.checkpoint}."
  624. )
  625. state.last_checkpoint = trial.checkpoint
  626. else:
  627. logger.debug(f"Instructing {trial} to save.")
  628. state.last_checkpoint = tune_controller._schedule_trial_save(
  629. trial, result=state.last_result
  630. )
  631. self._num_checkpoints += 1
  632. else:
  633. state.last_checkpoint = None # not a top trial
  634. if trial in lower_quantile:
  635. trial_to_clone = random.choice(upper_quantile)
  636. assert trial is not trial_to_clone
  637. clone_state = self._trial_state[trial_to_clone]
  638. last_checkpoint = clone_state.last_checkpoint
  639. logger.debug(
  640. f"Trial {trial} is in lower quantile. "
  641. f"Exploiting trial {trial_to_clone}."
  642. )
  643. if isinstance(last_checkpoint, _FutureTrainingResult):
  644. training_result = last_checkpoint.resolve()
  645. if training_result:
  646. clone_state.last_result = training_result.metrics
  647. clone_state.last_checkpoint = training_result.checkpoint
  648. last_checkpoint = clone_state.last_checkpoint
  649. else:
  650. logger.debug(
  651. "PBT-scheduled checkpoint save resolved to None. Trial "
  652. f"{trial_to_clone} didn't save any checkpoint before "
  653. f"and can't be exploited."
  654. )
  655. last_checkpoint = None
  656. if not last_checkpoint:
  657. logger.info(
  658. f"[pbt]: no checkpoint for trial {trial_to_clone}."
  659. f" Skip exploit for Trial {trial}"
  660. )
  661. return
  662. self._exploit(tune_controller, trial, trial_to_clone)
  def _log_config_on_step(
      self,
      trial_state: _PBTTrialState,
      new_state: _PBTTrialState,
      trial: Trial,
      trial_to_clone: Trial,
      new_config: Dict,
  ):
      """Logs transition during exploit/explore step.

      For each step, logs: [target trial tag, clone trial tag, target trial
      iteration, clone trial iteration, old config, new config].

      The record is appended to the experiment-wide ``pbt_global.txt`` and to
      the target trial's ``pbt_policy_<trial_id>.txt``. If the cloned trial
      already has a policy file, that file first replaces the target trial's
      file, so the target's log starts from the clone's history.

      Args:
          trial_state: PBT state of the trial being overwritten; supplies the
              target trial tag.
          new_state: PBT state of the trial being cloned; supplies the clone
              trial tag.
          trial: The trial that exploits ``trial_to_clone``.
          trial_to_clone: The trial whose config and policy log are cloned.
          new_config: The perturbed config that ``trial`` will continue with.
      """
      trial_name, trial_to_clone_name = (trial_state.orig_tag, new_state.orig_tag)
      trial_id = trial.trial_id
      trial_to_clone_id = trial_to_clone.trial_id
      trial_path = os.path.join(
          trial.local_experiment_path, "pbt_policy_" + trial_id + ".txt"
      )
      # Policy files are written under `local_experiment_path` (see
      # `trial_path` above), so the clone's file has to be looked up in the
      # same kind of directory, not under a different path attribute.
      trial_to_clone_path = os.path.join(
          trial_to_clone.local_experiment_path,
          "pbt_policy_" + trial_to_clone_id + ".txt",
      )
      policy = [
          trial_name,
          trial_to_clone_name,
          trial.last_result.get(TRAINING_ITERATION, 0),
          trial_to_clone.last_result.get(TRAINING_ITERATION, 0),
          trial_to_clone.config,
          new_config,
      ]
      # Serialize once; the same record goes to both log files.
      policy_json = json.dumps(policy, cls=SafeFallbackEncoder)

      # Log to global file.
      with open(
          os.path.join(trial.local_experiment_path, "pbt_global.txt"), "a+"
      ) as f:
          print(policy_json, file=f)

      # Overwrite state in target trial from trial_to_clone.
      if os.path.exists(trial_to_clone_path):
          shutil.copyfile(trial_to_clone_path, trial_path)

      # Log new exploit in target trial log.
      with open(trial_path, "a+") as f:
          f.write(policy_json + "\n")
  def _get_new_config(self, trial: Trial, trial_to_clone: Trial) -> Tuple[Dict, Dict]:
      """Derive a perturbed config for `trial` from `trial_to_clone`'s config.

      Args:
          trial: The trial that is exploiting `trial_to_clone`. It is not
              consulted here; only the clone's config is explored.
          trial_to_clone: The trial whose hyperparameter config is used as the
              starting point for the perturbation.

      Returns:
          Whatever `_explore` returns for the clone's config and this
          scheduler's mutation settings, annotated as a pair:
          new_config: New hyperparameter configuration (after random mutations).
          operations: Map of hyperparams -> strings describing mutation operations
              performed
      """
      source_config = trial_to_clone.config
      explore_args = (
          source_config,
          self._hyperparam_mutations,
          self._resample_probability,
          self._perturbation_factors,
          self._custom_explore_fn,
      )
      return _explore(*explore_args)
  def _summarize_hyperparam_changes(
      self,
      old_params: Dict,
      new_params: Dict,
      operations: Optional[Dict] = None,
      prefix: str = "",
  ) -> str:
      """Generates a summary of hyperparameter changes from a PBT "explore" step.

      Example:
      Given the following hyperparam_mutations:

      hyperparam_mutations = {
          "a": tune.uniform(0, 1),
          "b": list(range(5)),
          "c": {
              "d": tune.uniform(2, 3),
              "e": {"f": [-1, 0, 1]},
          },
      }

      This is an example summary output of the operations performed on old_params
      to get new_params:

      a : 0.5 --- (* 0.8) --> 0.4
      b : 2 --- (resample) --> 4
      c :
          d : 2.5 --- (* 1.2) --> 3.0
          e :
              f : 0 --- (shift right) --> 1

      The summary shows the old and new hyperparameter values, with the operation
      used to perturb labeled in between.
      If the operation for a certain hyperparameter is not provided, then the summary
      will just contain arrows without a label. (ex: a : 0.5 -----> 0.4)

      Args:
          old_params: Old values of hyperparameters that are perturbed to generate
              the new config
          new_params: The newly generated hyperparameter config from PBT exploration
          operations: Map of hyperparams -> string descriptors the operations
              performed to generate the values in `new_params`. May be omitted
              (or None), in which case every arrow is unlabeled.
          prefix: Helper argument to format nested dict hyperparam configs

      Returns:
          summary_str: The hyperparameter change summary to print/log. Empty if
              `old_params` is empty.

      Raises:
          AssertionError: If a key of `old_params` is missing from `new_params`.
      """
      if not old_params:
          return ""

      # `operations` is optional; without this guard the `.get` lookups below
      # would fail with AttributeError on the default value of None.
      operations = operations or {}

      entries = []
      for param_name, old_val in old_params.items():
          assert param_name in new_params, (
              "`old_params` and `new_params` "
              f"must both contain the key: '{param_name}'\n"
              f"old_params.keys() = {old_params.keys()}\n"
              f"new_params.keys() = {new_params.keys()}"
          )
          new_val = new_params[param_name]
          label = f"{prefix}{param_name} : "

          if isinstance(old_val, dict):
              # Handle nested hyperparameters by recursively summarizing
              nested_summary = self._summarize_hyperparam_changes(
                  old_val,
                  new_val,
                  operations=operations.get(param_name, {}),
                  prefix=prefix + " " * 4,
              )
              entries.append(f"{label}\n{nested_summary}")
          else:
              op = operations.get(param_name)
              arrow = f"--- ({op}) -->" if op else "----->"
              entries.append(f"{label}{old_val} {arrow} {new_val}\n")

      return "".join(entries)
  def _exploit(
      self,
      tune_controller: "TuneController",
      trial: Trial,
      trial_to_clone: Trial,
  ):
      """Transfers perturbed state from trial_to_clone -> trial.

      If specified, also logs the updated hyperparam state.

      Args:
          tune_controller: Controller used to pause ``trial`` if it is not
              already paused.
          trial: The trial that gets overwritten with the perturbed config
              and the clone's checkpoint.
          trial_to_clone: The trial whose config and last checkpoint are
              exploited.

      Raises:
          TuneError: If ``trial`` is already paused although the scheduler is
              not running in synchronous mode.
      """
      trial_state = self._trial_state[trial]
      new_state = self._trial_state[trial_to_clone]
      class_name = self.__class__.__name__
      logger.info(
          f"\n\n[{class_name}] [Exploit] Cloning trial "
          "{} (score = {:4f}) into trial {} (score = {:4f})\n".format(
              trial_to_clone.trial_id,
              new_state.last_score,
              trial.trial_id,
              trial_state.last_score,
          )
      )

      new_config, operations = self._get_new_config(trial, trial_to_clone)

      # Only log mutated hyperparameters and not entire config.
      old_params = _filter_mutated_params_from_config(
          trial_to_clone.config, self._hyperparam_mutations
      )
      new_params = _filter_mutated_params_from_config(
          new_config, self._hyperparam_mutations
      )
      # The trailing space after "trial" keeps the trial id from being glued
      # onto the preceding word in the log line.
      explore_info_str = (
          f"\n\n[{class_name}] [Explore] Perturbed the hyperparameter config of trial "
          f"{trial.trial_id}:\n"
      )
      explore_info_str += (
          self._summarize_hyperparam_changes(old_params, new_params, operations)
          or "No hyperparameters mutated."
      )
      logger.info(explore_info_str)

      if self._log_config:
          self._log_config_on_step(
              trial_state, new_state, trial, trial_to_clone, new_config
          )

      new_tag = _make_experiment_tag(
          trial_state.orig_tag, new_config, self._hyperparam_mutations
      )
      if trial.status == Trial.PAUSED:
          # If trial is paused we update it with a new checkpoint.
          # When the trial is started again, the new checkpoint is used.
          if not self._synch:
              raise TuneError(
                  "Trials should be paused here only if in "
                  "synchronous mode. If you encounter this error"
                  " please raise an issue on Ray Github."
              )
      else:
          tune_controller.pause_trial(trial, should_checkpoint=False)
      trial.set_experiment_tag(new_tag)
      # Clone hyperparameters from the `trial_to_clone`
      trial.set_config(new_config)

      # Resume training from a shallow copy of `trial_to_clone`'s latest
      # checkpoint
      checkpoint_to_exploit: Checkpoint = copy.copy(new_state.last_checkpoint)
      trial.run_metadata.checkpoint_manager._latest_checkpoint_result = (
          _TrainingResult(
              checkpoint=checkpoint_to_exploit, metrics=new_state.last_result
          )
      )

      self._num_perturbations += 1
      # Transfer over the last perturbation time as well
      trial_state.last_perturbation_time = new_state.last_perturbation_time
      trial_state.last_train_time = new_state.last_train_time
  def _quantiles(self) -> Tuple[List[Trial], List[Trial]]:
      """Returns trials in the lower and upper `quantile` of the population.

      Only unfinished trials that already have a score are ranked. Trials are
      sorted by ascending `last_score`, so the first returned list holds the
      lowest-scoring trials and the second the highest-scoring ones. Each
      quantile holds at most half of the ranked trials.

      If there is not enough data to compute this, returns empty lists. The
      same holds when the quantile size works out to zero trials.
      """
      trials = []
      for trial, state in self._trial_state.items():
          logger.debug("Trial {}, state {}".format(trial, state))
          if trial.is_finished():
              logger.debug("Trial {} is finished".format(trial))
              continue
          if state.last_score is not None:
              trials.append(trial)

      # last_score is by construction never None
      trials.sort(key=lambda t: self._trial_state[t].last_score)  # type: ignore[arg-type,return-value]

      if len(trials) <= 1:
          return [], []

      num_trials_in_quantile = int(math.ceil(len(trials) * self._quantile_fraction))
      if num_trials_in_quantile > len(trials) / 2:
          num_trials_in_quantile = int(math.floor(len(trials) / 2))

      # Guard the empty quantile explicitly: `trials[-0:]` is the WHOLE list,
      # which would report every trial as belonging to the upper quantile.
      if num_trials_in_quantile <= 0:
          return [], []

      return (trials[:num_trials_in_quantile], trials[-num_trials_in_quantile:])
  def choose_trial_to_run(self, tune_controller: "TuneController") -> Optional[Trial]:
      """Pick the pending/paused trial that has trained the least so far.

      Giving priority to the trial with the smallest `last_train_time` lets
      all trials get a fair share of time (as defined by time_attr), which
      enables the PBT scheduler to support a greater number of concurrent
      trials than can fit in the cluster at any given time.

      In synchronous mode, only trials whose `last_train_time` is still below
      the next synchronized perturbation point are eligible.

      Returns:
          The selected trial, or None if no trial is eligible.
      """
      runnable_statuses = [Trial.PENDING, Trial.PAUSED]

      def _is_candidate(t) -> bool:
          if t.status not in runnable_statuses:
              return False
          if not self._synch:
              return True
          return (
              self._trial_state[t].last_train_time < self._next_perturbation_sync
          )

      eligible = [t for t in tune_controller.get_trials() if _is_candidate(t)]
      if not eligible:
          return None

      # Stable sort: ties keep the controller's trial order.
      by_train_time = sorted(
          eligible, key=lambda t: self._trial_state[t].last_train_time
      )
      return by_train_time[0]
  # Unit test only. TODO(xwjiang): Remove test-specific APIs.
  def reset_stats(self):
      """Set the checkpoint and perturbation counters back to zero."""
      self._num_checkpoints = 0
      self._num_perturbations = 0
  # Unit test only. TODO(xwjiang): Remove test-specific APIs.
  def last_scores(self, trials: List[Trial]) -> List[float]:
      """Collect the last score of every given trial that is still ranked.

      A trial contributes its score only if the score is not None and the
      trial is not finished. The order of `trials` is preserved.
      """
      with_state = ((t, self._trial_state[t]) for t in trials)
      return [
          st.last_score
          for t, st in with_state
          if st.last_score is not None and not t.is_finished()
      ]
  def debug_string(self) -> str:
      """Return a one-line summary of the checkpoint and perturbation counts."""
      template = "PopulationBasedTraining: {} checkpoints, {} perturbs"
      return template.format(self._num_checkpoints, self._num_perturbations)
  @PublicAPI
  class PopulationBasedTrainingReplay(FIFOScheduler):
      """Replays a Population Based Training run.

      Population Based Training does not return a single hyperparameter
      configuration, but rather a schedule of configurations. For instance,
      PBT might discover that a larger learning rate leads to good results
      in the first training iterations, but that a smaller learning rate
      is preferable later.

      This scheduler enables replaying these parameter schedules from
      a finished PBT run. This requires that population based training has
      been run with ``log_config=True``, which is the default setting.

      The scheduler will only accept and train a single trial. It will
      start with the initial config of the existing trial and update the
      config according to the schedule.

      Args:
          policy_file: The PBT policy file. Usually this is
              stored in ``~/ray_results/experiment_name/pbt_policy_xxx.txt``
              where ``xxx`` is the trial ID.

      Example:

      .. code-block:: python

          # Replaying a result from ray.tune.examples.pbt_convnet_example
          from ray import tune

          from ray.tune.examples.pbt_convnet_example import PytorchTrainable
          from ray.tune.schedulers import PopulationBasedTrainingReplay

          replay = PopulationBasedTrainingReplay(
              "~/ray_results/pbt_test/pbt_policy_XXXXX_00001.txt")

          tuner = tune.Tuner(
              PytorchTrainable,
              run_config=tune.RunConfig(
                  stop={"training_iteration": 100}
              ),
              tune_config=tune.TuneConfig(
                  scheduler=replay,
              ),
          )
          tuner.fit()


      """

      def __init__(self, policy_file: str):
          """Load the replay schedule from ``policy_file``.

          Raises:
              ValueError: If the policy file does not exist or cannot be parsed.
          """
          policy_file = Path(policy_file).expanduser()
          if not policy_file.exists():
              raise ValueError("Policy file not found: {}".format(policy_file.as_posix()))

          self.policy_file = policy_file.as_posix()

          # Find and read pbt policy file, potentially raise error
          initial_config, self._policy = self._load_policy(self.policy_file)

          self.experiment_tag = "replay_{}".format(os.path.basename(self.policy_file))
          self.config = initial_config
          self.current_config = self.config
          self._trial = None
          self._current_step = 0
          self._num_perturbations = 0

          # `_next_policy` is the upcoming (step, config) change, or None once
          # the schedule is exhausted.
          self._policy_iter = iter(self._policy)
          self._next_policy = next(self._policy_iter, None)

      def _load_policy(self, policy_file: str) -> Tuple[Dict, List[Tuple[int, Dict]]]:
          """Parse a PBT policy file into an initial config and a schedule.

          Each line of the file must be a JSON array with six entries, unpacked
          as (old_tag, new_tag, old_step, new_step, old_conf, new_conf).

          Returns:
              A tuple ``(initial_config, policy)`` where ``policy`` is a list of
              ``(new_step, new_conf)`` pairs in file order. ``initial_config``
              is the ``old_conf`` of the earliest row that is part of the
              schedule (None if the file has no rows).

          Raises:
              ValueError: If a line of the file is not valid JSON.
          """
          raw_policy = []
          with open(policy_file, "rt") as fp:
              for row in fp:
                  try:
                      parsed_row = json.loads(row)
                  except json.JSONDecodeError:
                      raise ValueError(
                          "Could not read PBT policy file: {}.".format(policy_file)
                      ) from None
                  raw_policy.append(tuple(parsed_row))

          # Loop through policy from end to start to obtain changepoints
          policy = []
          last_new_tag = None
          last_old_conf = None
          for old_tag, new_tag, old_step, new_step, old_conf, new_conf in reversed(
              raw_policy
          ):
              if last_new_tag and old_tag != last_new_tag:
                  # Tag chain ended. This means that previous changes were
                  # overwritten by the last change and should be ignored.
                  break
              last_new_tag = new_tag
              last_old_conf = old_conf

              policy.append((new_step, new_conf))

          return last_old_conf, list(reversed(policy))

      def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
          """Register the single replay trial and assign it the initial config.

          Raises:
              ValueError: If a second trial is added, or if there is neither a
                  replay policy nor a trial config to train with.
          """
          if self._trial:
              raise ValueError(
                  "More than one trial added to PBT replay run. This "
                  "means the same schedule will be trained multiple "
                  "times. Do you want to set `num_samples=1`?"
              )
          self._trial = trial
          if self._trial.config and self._policy:
              logger.warning(
                  "Trial was initialized with a config, which was overwritten. "
                  "Did you start the PBT replay with a `config` parameter?"
              )
          elif self._trial.config and not self._policy:
              # Only train with initial policy
              self.config = self._trial.config
          elif not self._trial.config and not self._policy:
              raise ValueError(
                  "No replay policy found and trial initialized without a "
                  "valid config. Either pass a `config` argument to `tune.Tuner()` "
                  "or consider not using PBT replay for this run."
              )
          self._trial.set_config(self.config)

      def on_trial_result(
          self, tune_controller: "TuneController", trial: Trial, result: Dict
      ) -> str:
          """Apply the next scheduled config once its step has been reached.

          Returns:
              ``TrialScheduler.CONTINUE`` while no config change is due, and
              ``TrialScheduler.NOOP`` after the trial has been checkpointed,
              paused and given the next config of the schedule.
          """
          if TRAINING_ITERATION not in result:
              # No time reported
              return TrialScheduler.CONTINUE

          # Record progress before the early exit below, so `debug_string`
          # keeps reporting the current step after the last scheduled change.
          step = result[TRAINING_ITERATION]
          self._current_step = step

          if not self._next_policy:
              # No more changes in the config
              return TrialScheduler.CONTINUE

          change_at, new_config = self._next_policy

          if step < change_at:
              # Don't change the policy just yet
              return TrialScheduler.CONTINUE

          logger.info(
              "Population Based Training replay is now at step {}. "
              "Configuration will be changed to {}.".format(step, new_config)
          )

          # Save the trial and register the resolved result as its latest
          # checkpoint before the trial is paused and reconfigured.
          pending_save = tune_controller._schedule_trial_save(trial, result=result)
          training_result = pending_save.resolve()
          trial.run_metadata.checkpoint_manager._latest_checkpoint_result = (
              training_result
          )

          new_tag = _make_experiment_tag(self.experiment_tag, new_config, new_config)

          tune_controller.pause_trial(trial, should_checkpoint=False)
          trial.set_experiment_tag(new_tag)
          trial.set_config(new_config)

          self.current_config = new_config
          self._num_perturbations += 1
          self._next_policy = next(self._policy_iter, None)

          return TrialScheduler.NOOP

      def debug_string(self) -> str:
          """Return a one-line summary of the replay progress."""
          return "PopulationBasedTraining replay: Step {}, perturb {}".format(
              self._current_step, self._num_perturbations
          )