pb2.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. import logging
  2. from copy import deepcopy
  3. from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Union
  4. import numpy as np
  5. import pandas as pd
  6. from ray.tune import TuneError
  7. from ray.tune.experiment import Trial
  8. from ray.tune.schedulers import PopulationBasedTraining
  9. from ray.tune.schedulers.pbt import _PBTTrialState
  10. from ray.tune.utils.util import flatten_dict, unflatten_dict
  11. from ray.util.debug import log_once
  12. if TYPE_CHECKING:
  13. from ray.tune.execution.tune_controller import TuneController
  def import_pb2_dependencies():
      """Probe for scikit-learn, the optional dependency used by PB2.

      Returns:
          The imported ``sklearn`` module when the import succeeds,
          otherwise ``None``.
      """
      try:
          import sklearn
      except ImportError:
          # Signal "not installed" to the caller instead of raising here.
          return None
      return sklearn
  # Holds the sklearn module when scikit-learn is importable, otherwise None.
  has_sklearn = import_pb2_dependencies()

  # The GP regressor and the PB2 helper utilities are only imported when
  # scikit-learn is available, so importing this module does not fail without it.
  if has_sklearn:
      from sklearn.gaussian_process import GaussianProcessRegressor

      from ray.tune.schedulers.pb2_utils import (
          UCB,
          TV_SquaredExp,
          normalize,
          optimize_acq,
          select_length,
          standardize,
      )

  # Module-level logger shared by the helpers and the scheduler below.
  logger = logging.getLogger(__name__)
  def _fill_config(
      config: Dict, hyperparam_bounds: Dict[str, Union[dict, list, tuple]]
  ) -> Dict:
      """Populate hyperparameters absent from `config` with uniform samples.

      Walks `hyperparam_bounds`; dict-valued entries are handled recursively
      (creating the nested dict in `config` if needed), while list/tuple entries
      that are not yet present in `config` are drawn uniformly from the bounds.
      `config` is modified in place.

      This is a helper used to set initial hyperparameter values if the user
      doesn't specify them in the Tuner `param_space`.

      Returns:
          A dict, nested like `hyperparam_bounds`, holding only the values that
          were sampled by this call.
      """
      sampled = {}
      for name, spec in hyperparam_bounds.items():
          if isinstance(spec, dict):
              # Nested search space: descend, creating the sub-config on demand.
              sub_config = config.setdefault(name, {})
              sampled[name] = _fill_config(sub_config, spec)
              continue

          if name in config or not isinstance(spec, (list, tuple)):
              # Either the user already supplied a value or there is nothing
              # this helper knows how to sample from.
              continue

          if log_once(name + "-missing"):
              logger.debug(
                  f"Cannot find {name} in config. Initializing by "
                  "sampling uniformly from the provided `hyperparam_bounds`."
              )
          assert len(spec) == 2
          lower, upper = spec
          value = np.random.uniform(lower, upper)
          config[name] = value
          sampled[name] = value
      return sampled
  def _select_config(
      Xraw: np.ndarray,
      yraw: np.ndarray,
      current: Optional[np.ndarray],
      newpoint: np.ndarray,
      bounds: dict,
      num_f: int,
  ) -> np.ndarray:
      """Selects the next hyperparameter config to try.

      This function takes the formatted data, fits the GP model and optimizes the
      UCB acquisition function to select the next point.

      Args:
          Xraw: The un-normalized array of hyperparams, Time and Reward. The
              first `num_f` columns are the fixed (non-hyperparameter) features.
          yraw: The un-normalized vector of reward changes.
          current: The hyperparams of trials currently running, one row per
              trial, or None if there are none. This is important so we do not
              select the same config twice. If there is data here then we fit a
              second GP including it (with fake y labels). The GP variance
              doesn't depend on the y labels so it is ok.
          newpoint: The Reward and Time for the new point.
              We cannot change these as they are based on the *new weights*.
          bounds: Bounds for the hyperparameters. Used to normalize.
          num_f: The number of fixed params. Almost always 2 (reward+time).

      Returns:
          xt: A float32 vector of new hyperparameters, rescaled back into the
          ranges given by `bounds`.
      """
      # Restrict the dataset to the most recent `length` rows, as chosen by the
      # `select_length` helper.
      length = select_length(Xraw, yraw, bounds, num_f)
      Xraw = Xraw[-length:, :]
      yraw = yraw[-length:]

      # Row 0 / row 1 of `base_vals` hold the first / second bound of each
      # hyperparameter.
      base_vals = np.array(list(bounds.values())).T
      oldpoints = Xraw[:, :num_f]
      old_lims = np.concatenate(
          (np.max(oldpoints, axis=0), np.min(oldpoints, axis=0))
      ).reshape(2, oldpoints.shape[1])
      limits = np.concatenate((old_lims, base_vals), axis=1)

      X = normalize(Xraw, limits)
      y = standardize(yraw).reshape(yraw.size, 1)

      fixed = normalize(newpoint, oldpoints)

      kernel = TV_SquaredExp(variance=1.0, lengthscale=1.0, epsilon=0.1)

      try:
          m = GaussianProcessRegressor(
              kernel=kernel, optimizer="fmin_l_bfgs_b", alpha=1e-10
          )
          m.fit(X, y)
      except np.linalg.LinAlgError:
          # The kernel matrix was not positive definite. Refit with a larger
          # `alpha`, which scikit-learn adds to the diagonal of the kernel matrix.
          # (Adding an identity matrix to `X` itself, as was done previously,
          # is a shape mismatch for non-square `X` and perturbs the inputs
          # rather than the kernel.)
          m = GaussianProcessRegressor(
              kernel=kernel, optimizer="fmin_l_bfgs_b", alpha=1e-3
          )
          m.fit(X, y)

      if current is None:
          m1 = deepcopy(m)
      else:
          # Add the currently running trials to the dataset: each one is paired
          # with the same fixed features and a placeholder target of zero.
          padding = np.array([fixed for _ in range(current.shape[0])])
          current = normalize(current, base_vals)
          current = np.hstack((padding, current))

          Xnew = np.vstack((X, current))
          ypad = np.zeros(current.shape[0]).reshape(-1, 1)
          ynew = np.vstack((y, ypad))

          kernel1 = TV_SquaredExp(variance=1.0, lengthscale=1.0, epsilon=0.1)
          m1 = GaussianProcessRegressor(
              kernel=kernel1, optimizer="fmin_l_bfgs_b", alpha=1e-10
          )
          m1.fit(Xnew, ynew)

      xt = optimize_acq(UCB, m, m1, fixed, num_f)

      # Convert back from the normalized space to the original bounds.
      lower = np.min(base_vals, axis=0)
      upper = np.max(base_vals, axis=0)
      xt = xt * (upper - lower) + lower

      return xt.astype(np.float32)
  def _explore(
      data: pd.DataFrame,
      bounds: Dict[str, Tuple[float, float]],
      current: list,
      base: Trial,
      old: Trial,
      config: Dict[str, Tuple[float, float]],
  ) -> Tuple[Dict, pd.DataFrame]:
      """Returns next hyperparameter configuration to use.

      The recorded trial data is turned into per-step reward changes, the next
      config is requested from `_select_config`, and a row for the new config is
      appended to the dataframe so that the reward change can later be computed
      using the new weights.

      If no usable rows exist for `base`, a copy of `config` is returned and
      `data` is handed back untouched.

      Returns:
          A tuple of (new config, dataframe including the new entry).
      """
      hparam_names = list(bounds.keys())
      group_keys = ["Trial"] + hparam_names

      df = data.sort_values(by="Time").reset_index(drop=True)

      # Change in reward and in time within each (trial, hyperparams) group.
      df["y"] = df.groupby(group_keys)["Reward"].diff()
      df["t_change"] = df.groupby(group_keys)["Time"].diff()

      # Keep only rows where time actually advanced.
      df = df[df["t_change"] > 0].reset_index(drop=True)
      df["R_before"] = df.Reward - df.y

      # Normalize the reward change by the update size, e.g. when trials took
      # different lengths of time.
      df["y"] = df.y / df.t_change
      df = df[~df.y.isna()].reset_index(drop=True)
      df = df.sort_values(by="Time").reset_index(drop=True)

      # Only use the last 1k datapoints, so the GP is not too slow.
      df = df.iloc[-1000:, :].reset_index(drop=True)

      # Rows of the trial being cloned; they supply T and Reward for the weights.
      dfnewpoint = df[df["Trial"] == str(base)]
      if dfnewpoint.empty:
          return config.copy(), data

      # Now specify the dataset for the GP.
      y = np.array(df.y.values)
      # Meta data we keep -> episodes and reward.
      # (TODO: convert to curve)
      t_r = df[["Time", "R_before"]]
      hparams = df[hparam_names]
      X = pd.concat([t_r, hparams], axis=1).values

      latest = dfnewpoint.iloc[-1, :]
      newpoint = latest[["Time", "R_before"]].values
      new = _select_config(X, y, current, newpoint, bounds, num_f=len(t_r.columns))

      new_config = config.copy()
      values = []
      for idx, col in enumerate(hparams.columns):
          # Use the type from the old config, so that types are passed on from
          # the first config downwards.
          caster = type(config[col])
          cast_value = caster(new[idx])
          new_config[col] = cast_value
          values.append(cast_value)

      # Create an entry for the new config, with the reward from the copied
      # agent.
      row = [str(old), latest["Time"], *values, latest.Reward]
      cols = ["Trial", "Time"] + list(bounds) + ["Reward"]
      new_entry = pd.DataFrame([row], columns=cols)
      data = pd.concat([data, new_entry]).reset_index(drop=True)

      return new_config, data
  class PB2(PopulationBasedTraining):
      """Implements the Population Based Bandit (PB2) algorithm.

      PB2 trains a group of models (or agents) in parallel. Periodically, poorly
      performing models clone the state of the top performers, and the hyper-
      parameters are re-selected using GP-bandit optimization. The GP model is
      trained to predict the improvement in the next training period.

      Like PBT, PB2 adapts hyperparameters during training time. This enables
      very fast hyperparameter discovery and also automatically discovers
      schedules.

      This Tune PB2 implementation is built on top of Tune's PBT implementation.
      It considers all trials added as part of the PB2 population. If the number
      of trials exceeds the cluster capacity, they will be time-multiplexed as to
      balance training progress across the population. To run multiple trials,
      use `tune.TuneConfig(num_samples=<int>)`.

      In {LOG_DIR}/{MY_EXPERIMENT_NAME}/, all mutations are logged in
      `pb2_global.txt` and individual policy perturbations are recorded
      in pb2_policy_{i}.txt. Tune logs: [target trial tag, clone trial tag,
      target trial iteration, clone trial iteration, old config, new config]
      on each perturbation step.

      Args:
          time_attr: The training result attr to use for comparing time.
              Note that you can pass in something non-temporal such as
              `training_iteration` as a measure of progress, the only requirement
              is that the attribute should increase monotonically.
          metric: The training result objective value attribute. Stopping
              procedures will use this attribute.
          mode: One of {min, max}. Determines whether objective is
              minimizing or maximizing the metric attribute.
          perturbation_interval: Models will be considered for
              perturbation at this interval of `time_attr`. Note that
              perturbation incurs checkpoint overhead, so you shouldn't set this
              to be too frequent.
          hyperparam_bounds: Hyperparameters to mutate. The format is
              as follows: for each key, enter a list of the form [min, max]
              representing the minimum and maximum possible hyperparam values.
              A key can also hold a dict for nested hyperparameters.
              Tune will sample uniformly between the bounds provided by
              `hyperparam_bounds` for the initial hyperparameter values if the
              corresponding hyperparameters are not present in a trial's initial
              `config`.
          quantile_fraction: Parameters are transferred from the top
              `quantile_fraction` fraction of trials to the bottom
              `quantile_fraction` fraction. Needs to be between 0 and 0.5.
              Setting it to 0 essentially implies doing no exploitation at all.
          custom_explore_fn: You can also specify a custom exploration
              function. This function is invoked as `f(config)`, where the input
              is the new config generated by Bayesian Optimization. This function
              should return the `config` updated as needed.
          log_config: Whether to log the ray config of each model to
              local_dir at each exploit. Allows config schedule to be
              reconstructed.
          require_attrs: Whether to require time_attr and metric to appear
              in result for every iteration. If True, error will be raised
              if these values are not present in trial result.
          synch: If False, will use asynchronous implementation of
              PBT. Trial perturbations occur every perturbation_interval for each
              trial independently. If True, will use synchronous implementation
              of PBT. Perturbations will occur only after all trials are
              synced at the same time_attr every perturbation_interval.
              Defaults to False. See Appendix A.1 here
              https://arxiv.org/pdf/1711.09846.pdf.

      Example:

          .. code-block:: python

              from ray import tune
              from ray.tune.schedulers.pb2 import PB2
              from ray.tune.examples.pbt_function import pbt_function

              pb2 = PB2(
                  metric="mean_accuracy",
                  mode="max",
                  perturbation_interval=20,
                  hyperparam_bounds={"lr": [0.0001, 0.1]},
              )
              tuner = tune.Tuner(
                  pbt_function,
                  tune_config=tune.TuneConfig(
                      scheduler=pb2,
                      num_samples=8,
                  ),
                  param_space={"lr": 0.0001},
              )
              tuner.fit()
      """

      def __init__(
          self,
          time_attr: str = "time_total_s",
          metric: Optional[str] = None,
          mode: Optional[str] = None,
          perturbation_interval: float = 60.0,
          hyperparam_bounds: Optional[Dict[str, Union[dict, list, tuple]]] = None,
          quantile_fraction: float = 0.25,
          log_config: bool = True,
          require_attrs: bool = True,
          synch: bool = False,
          custom_explore_fn: Optional[Callable[[dict], dict]] = None,
      ):
          """Set up the scheduler; see the class docstring for the arguments.

          Raises:
              RuntimeError: If scikit-learn cannot be imported.
              TuneError: If `hyperparam_bounds` is missing or empty.
              ValueError: If a (flattened) bound is not a [low, high] pair.
          """
          sklearn_available = import_pb2_dependencies()
          if not sklearn_available:
              raise RuntimeError("Please install scikit-learn to use PB2.")

          hyperparam_bounds = hyperparam_bounds or {}
          if not hyperparam_bounds:
              raise TuneError(
                  "`hyperparam_bounds` must be specified to use PB2 scheduler."
              )

          super().__init__(
              time_attr=time_attr,
              metric=metric,
              mode=mode,
              perturbation_interval=perturbation_interval,
              hyperparam_mutations=hyperparam_bounds,
              quantile_fraction=quantile_fraction,
              resample_probability=0,
              custom_explore_fn=custom_explore_fn,
              log_config=log_config,
              require_attrs=require_attrs,
              synch=synch,
          )

          self.last_exploration_time = 0  # when we last explored
          self.data = pd.DataFrame()

          self._hyperparam_bounds = hyperparam_bounds
          self._hyperparam_bounds_flat = flatten_dict(
              hyperparam_bounds, prevent_delimiter=True
          )
          self._validate_hyperparam_bounds(self._hyperparam_bounds_flat)

          # Current = trials running that have already re-started after reaching
          # the checkpoint. When exploring we care if these trials
          # are already in or scheduled to be in the next round.
          self.current = None

      def on_trial_add(self, tune_controller: "TuneController", trial: Trial):
          """Fill in missing hyperparameters of `trial`, then defer to the parent."""
          filled_hyperparams = _fill_config(trial.config, self._hyperparam_bounds)
          # Make sure that the params we sampled show up in the CLI output
          trial.evaluated_params.update(flatten_dict(filled_hyperparams))
          super().on_trial_add(tune_controller, trial)

      def _validate_hyperparam_bounds(self, hyperparam_bounds: dict):
          """Check that each hyperparam bound is of the form [low, high].

          Args:
              hyperparam_bounds: Flat mapping of hyperparameter name to bounds.

          Raises:
              ValueError: if any of the hyperparam bounds are of an invalid format.
          """
          for key, value in hyperparam_bounds.items():
              if not isinstance(value, (list, tuple)) or len(value) != 2:
                  raise ValueError(
                      "`hyperparam_bounds` values must either be "
                      f"a list or tuple of size 2, but got {value} "
                      f"instead for the param '{key}'"
                  )
              low, high = value
              if low > high:
                  raise ValueError(
                      "`hyperparam_bounds` values must be of the form [low, high] "
                      f"where low <= high, but got {value} instead for param '{key}'."
                  )

      def _save_trial_state(
          self, state: _PBTTrialState, time: int, result: Dict, trial: Trial
      ):
          """Save the trial state via the parent and log a row of PB2 data.

          In addition to the parent's bookkeeping, the trial's current values for
          every bounded hyperparameter, its time attribute and its score are
          appended to `self.data`.

          Returns:
              The score produced by the parent implementation, so that this
              override keeps the parent's return value.
          """
          score = super()._save_trial_state(state, time, result, trial)

          # Data logging for PB2.

          # Collect hyperparams names and current values for this trial.
          names = list(self._hyperparam_bounds_flat.keys())
          flattened_config = flatten_dict(trial.config)
          values = [flattened_config[key] for key in names]

          # Store trial state and hyperparams in dataframe.
          # this needs to be made more general.
          lst = [[trial, result[self._time_attr]] + values + [score]]
          cols = ["Trial", "Time"] + names + ["Reward"]
          entry = pd.DataFrame(lst, columns=cols)

          self.data = pd.concat([self.data, entry]).reset_index(drop=True)
          # Trials are compared by their string form elsewhere in this module.
          self.data.Trial = self.data.Trial.astype("str")

          return score

      def _get_new_config(self, trial: Trial, trial_to_clone: Trial) -> Tuple[Dict, Dict]:
          """Gets new config for trial by exploring trial_to_clone's config using
          Bayesian Optimization (BO) to choose the hyperparameter values to explore.

          Overrides `PopulationBasedTraining._get_new_config`.

          Args:
              trial: The current trial that decided to exploit trial_to_clone.
              trial_to_clone: The top-performing trial with a hyperparameter config
                  that the current trial will explore.

          Returns:
              new_config: New hyperparameter configuration (after BO).
              operations: Empty dict since PB2 doesn't explore in easily labeled ways
                  like PBT does.
          """
          # If we are at a new timestep, we dont want to penalise for trials
          # still going.
          if self.data["Time"].max() > self.last_exploration_time:
              self.current = None

          new_config_flat, data = _explore(
              self.data,
              self._hyperparam_bounds_flat,
              self.current,
              trial_to_clone,
              trial,
              flatten_dict(trial_to_clone.config),
          )

          # Important to replace the old values, since we are copying across
          self.data = data.copy()

          # If the current guy being selecting is at a point that is already
          # done, then append the data to the "current" which contains the
          # points in the current batch.
          new = np.array(
              [new_config_flat[key] for key in self._hyperparam_bounds_flat]
          )
          new = new.reshape(1, new.size)

          latest_time = self.data["Time"].max()
          if latest_time > self.last_exploration_time:
              self.last_exploration_time = latest_time
              self.current = new.copy()
          elif self.current is None:
              # No batch has been started yet (no time value has exceeded
              # `last_exploration_time`), so there is nothing to concatenate onto.
              self.current = new.copy()
          else:
              self.current = np.concatenate((self.current, new), axis=0)
              logger.debug(self.current)

          new_config = unflatten_dict(new_config_flat)

          if self._custom_explore_fn:
              new_config = self._custom_explore_fn(new_config)
              assert (
                  new_config is not None
              ), "Custom explore function failed to return a new config"

          return new_config, {}