experiment_analysis.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678
  1. import copy
  2. import io
  3. import json
  4. import logging
  5. import os
  6. from numbers import Number
  7. from pathlib import Path
  8. from typing import Any, Dict, List, Optional, Tuple, Union
  9. import pyarrow.fs
  10. from ray.air.constants import EXPR_PROGRESS_FILE, EXPR_RESULT_FILE, TRAINING_ITERATION
  11. from ray.train._internal.storage import _exists_at_fs_path, get_fs_and_path
  12. from ray.tune import Checkpoint
  13. from ray.tune.execution.experiment_state import _find_newest_experiment_checkpoint
  14. from ray.tune.execution.tune_controller import TuneController
  15. from ray.tune.experiment import Trial
  16. from ray.tune.result import CONFIG_PREFIX, DEFAULT_METRIC
  17. from ray.tune.utils import flatten_dict
  18. from ray.tune.utils.serialization import TuneFunctionDecoder
  19. from ray.tune.utils.util import is_nan, is_nan_or_inf, unflattened_lookup
  20. from ray.util.annotations import PublicAPI
  21. try:
  22. import pandas as pd
  23. from pandas import DataFrame
  24. except ImportError:
  25. pd = None
  26. DataFrame = None
  27. logger = logging.getLogger(__name__)
  28. @PublicAPI(stability="beta")
  29. class ExperimentAnalysis:
  30. """Analyze results from a Ray Train/Tune experiment.
  31. To use this class, the run must store the history of reported metrics
  32. in log files (e.g., `result.json` and `progress.csv`).
  33. This is the default behavior, unless default loggers are explicitly excluded
  34. with the `TUNE_DISABLE_AUTO_CALLBACK_LOGGERS=1` environment variable.
  35. Parameters:
  36. experiment_checkpoint_path: Path to an `experiment_state.json` file,
  37. or a directory that contains an `experiment_state.json` file.
  38. default_metric: Default metric for comparing results. Can be
  39. overwritten with the ``metric`` parameter in the respective
  40. functions.
  41. default_mode: Default mode for comparing results. Has to be one
  42. of [min, max]. Can be overwritten with the ``mode`` parameter
  43. in the respective functions.
  44. trials: List of trials that can be accessed via `analysis.trials`.
  45. """
  def __init__(
      self,
      experiment_checkpoint_path: Union[str, os.PathLike],
      *,
      storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
      trials: Optional[List[Trial]] = None,
      default_metric: Optional[str] = None,
      default_mode: Optional[str] = None,
  ):
      """Locate the experiment snapshot, then load trials and their metrics.

      Args:
          experiment_checkpoint_path: Either a path ending in ``.json``
              (used as the snapshot file directly) or an experiment
              directory that is searched for the newest snapshot.
          storage_filesystem: Filesystem to read from. When omitted, it is
              resolved from ``experiment_checkpoint_path``.
          trials: Pre-built trials. When empty or omitted, trials are
              loaded from the snapshot file.
          default_metric: Metric used when a method is called without one.
          default_mode: "min" or "max"; used when a method is called
              without a mode.

      Raises:
          ValueError: If ``default_mode`` is neither falsy nor one of
              "min"/"max", or if no snapshot file is found in the
              experiment directory.
      """
      self.default_metric = default_metric
      if default_mode and default_mode not in ("min", "max"):
          raise ValueError("`default_mode` has to be None or one of [min, max]")
      self.default_mode = default_mode

      # A mode without a metric falls back to the anonymous default metric.
      if default_mode is not None and default_metric is None:
          self.default_metric = DEFAULT_METRIC

      # Use the caller's filesystem as-is, otherwise derive one from the path.
      if storage_filesystem:
          resolved_fs, fs_path = storage_filesystem, experiment_checkpoint_path
      else:
          resolved_fs, fs_path = get_fs_and_path(experiment_checkpoint_path)
      self._fs = resolved_fs
      fs_path = str(fs_path)

      if fs_path.endswith(".json"):
          # The snapshot file itself was given; its parent is the experiment dir.
          self._experiment_fs_path = os.path.dirname(fs_path)
          self._experiment_json_fs_path = fs_path
      else:
          # A directory was given; pick the newest snapshot inside it.
          self._experiment_fs_path = fs_path
          newest_snapshot = _find_newest_experiment_checkpoint(
              experiment_path=self._experiment_fs_path, fs=self._fs
          )
          if newest_snapshot is None:
              pattern = TuneController.CKPT_FILE_TMPL.format("*")
              raise ValueError(
                  f"No experiment snapshot file of form '{pattern}' was found at: "
                  f"({self._fs.type_name}, {self._experiment_fs_path})\n"
                  "Please check if you specified the correct experiment path, "
                  "which should be a combination of the `storage_path` and `name` "
                  "specified in your run."
              )
          self._experiment_json_fs_path = newest_snapshot

      self.trials = trials or self._load_trials()
      self._trial_dataframes = self._fetch_trial_dataframes()
      self._configs = self.get_all_configs()
  def _load_trials(self) -> List[Trial]:
      """Rebuild stub trials from the experiment snapshot file.

      Each entry of the snapshot's ``"trial_data"`` list is a pair of
      (trial JSON state, runtime metadata). Every restored trial gets a
      copy of its storage context re-pointed at this analysis' filesystem
      and experiment directory.

      Returns:
          The restored trials, in snapshot order.
      """
      with self._fs.open_input_stream(self._experiment_json_fs_path) as stream:
          snapshot = json.loads(stream.readall(), cls=TuneFunctionDecoder)

      experiment_dir = Path(self._experiment_fs_path)
      storage_root = experiment_dir.parent.as_posix()
      experiment_name = experiment_dir.name

      restored = []
      for json_state, runtime_metadata in snapshot["trial_data"]:
          trial = Trial.from_json_state(json_state, stub=True)
          trial.restore_run_metadata(runtime_metadata)

          # Shallow-copy the storage context so the original is not mutated.
          storage = copy.copy(trial.storage)
          storage.storage_fs_path = storage_root
          storage.storage_filesystem = self._fs
          storage.experiment_dir_name = experiment_name
          trial.set_storage(storage)

          restored.append(trial)
      return restored
  108. def _fetch_trial_dataframe(self, trial: Trial) -> DataFrame:
  109. force_dtype = {"trial_id": str} # Never convert trial_id to float.
  110. # If there were no reported results, there will be no files into a DataFrame
  111. if trial.last_result is None:
  112. return DataFrame()
  113. json_fs_path = Path(trial.storage.trial_fs_path, EXPR_RESULT_FILE).as_posix()
  114. csv_fs_path = Path(trial.storage.trial_fs_path, EXPR_PROGRESS_FILE).as_posix()
  115. # Prefer reading the JSON if it exists.
  116. if _exists_at_fs_path(trial.storage.storage_filesystem, json_fs_path):
  117. with trial.storage.storage_filesystem.open_input_stream(json_fs_path) as f:
  118. content = f.readall().decode("utf-8").rstrip("\n")
  119. if not content:
  120. return DataFrame()
  121. json_list = [json.loads(row) for row in content.split("\n")]
  122. df = pd.json_normalize(json_list, sep="/")
  123. # Fallback to reading the CSV.
  124. elif _exists_at_fs_path(trial.storage.storage_filesystem, csv_fs_path):
  125. with trial.storage.storage_filesystem.open_input_stream(csv_fs_path) as f:
  126. csv_str = f.readall().decode("utf-8")
  127. df = pd.read_csv(io.StringIO(csv_str), dtype=force_dtype)
  128. else:
  129. raise FileNotFoundError(
  130. f"Could not fetch metrics for {trial}: both {EXPR_RESULT_FILE} and "
  131. f"{EXPR_PROGRESS_FILE} were not found at {trial.storage.trial_fs_path}"
  132. )
  133. return df
  134. def _fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
  135. """Fetches trial dataframes from files.
  136. Returns:
  137. A dictionary mapping trial_id -> pd.DataFrame
  138. """
  139. failures = []
  140. trial_dfs = {}
  141. for trial in self.trials:
  142. try:
  143. trial_dfs[trial.trial_id] = self._fetch_trial_dataframe(trial)
  144. except Exception as e:
  145. failures.append((trial, e))
  146. trial_dfs[trial.trial_id] = DataFrame()
  147. continue
  148. if failures:
  149. fail_str = "\n".join(
  150. [f"- {trial}: {repr(error)}" for trial, error in failures]
  151. )
  152. logger.warning(
  153. f"Failed to fetch metrics for {len(failures)} trial(s):\n{fail_str}"
  154. )
  155. return trial_dfs
  156. def get_all_configs(self, prefix: bool = False) -> Dict[str, Dict]:
  157. """Returns all trial hyperparameter configurations.
  158. Args:
  159. prefix: If True, flattens the config dict
  160. and prepends `config/`.
  161. Returns:
  162. Dict[str, Dict]: Mapping trial_id -> config dict
  163. """
  164. return {
  165. trial.trial_id: (
  166. flatten_dict({CONFIG_PREFIX: trial.config}) if prefix else trial.config
  167. )
  168. for trial in self.trials
  169. }
  170. @property
  171. def experiment_path(self) -> str:
  172. """Path pointing to the experiment directory on persistent storage.
  173. This can point to a remote storage location (e.g. S3) or to a local
  174. location (path on the head node)."""
  175. return self._experiment_fs_path
  176. @property
  177. def best_trial(self) -> Trial:
  178. """Get the best trial of the experiment
  179. The best trial is determined by comparing the last trial results
  180. using the `metric` and `mode` parameters passed to `tune.run()`.
  181. If you didn't pass these parameters, use
  182. `get_best_trial(metric, mode, scope)` instead.
  183. """
  184. if not self.default_metric or not self.default_mode:
  185. raise ValueError(
  186. "To fetch the `best_trial`, pass a `metric` and `mode` "
  187. "parameter to `tune.run()`. Alternatively, use the "
  188. "`get_best_trial(metric, mode)` method to set the metric "
  189. "and mode explicitly."
  190. )
  191. return self.get_best_trial(self.default_metric, self.default_mode)
  192. @property
  193. def best_config(self) -> Dict:
  194. """Get the config of the best trial of the experiment
  195. The best trial is determined by comparing the last trial results
  196. using the `metric` and `mode` parameters passed to `tune.run()`.
  197. If you didn't pass these parameters, use
  198. `get_best_config(metric, mode, scope)` instead.
  199. """
  200. if not self.default_metric or not self.default_mode:
  201. raise ValueError(
  202. "To fetch the `best_config`, pass a `metric` and `mode` "
  203. "parameter to `tune.run()`. Alternatively, use the "
  204. "`get_best_config(metric, mode)` method to set the metric "
  205. "and mode explicitly."
  206. )
  207. return self.get_best_config(self.default_metric, self.default_mode)
  208. @property
  209. def best_checkpoint(self) -> Checkpoint:
  210. """Get the checkpoint path of the best trial of the experiment
  211. The best trial is determined by comparing the last trial results
  212. using the `metric` and `mode` parameters passed to `tune.run()`.
  213. If you didn't pass these parameters, use
  214. `get_best_checkpoint(trial, metric, mode)` instead.
  215. Returns:
  216. :class:`Checkpoint <ray.tune.Checkpoint>` object.
  217. """
  218. if not self.default_metric or not self.default_mode:
  219. raise ValueError(
  220. "To fetch the `best_checkpoint`, pass a `metric` and `mode` "
  221. "parameter to `tune.run()`. Alternatively, use the "
  222. "`get_best_checkpoint(trial, metric, mode)` method to set the "
  223. "metric and mode explicitly."
  224. )
  225. best_trial = self.best_trial
  226. if not best_trial:
  227. raise ValueError(
  228. f"No best trial found. Please check if you specified the "
  229. f"correct default metric ({self.default_metric}) and mode "
  230. f"({self.default_mode})."
  231. )
  232. return self.get_best_checkpoint(
  233. best_trial, self.default_metric, self.default_mode
  234. )
  235. @property
  236. def best_dataframe(self) -> DataFrame:
  237. """Get the full result dataframe of the best trial of the experiment
  238. The best trial is determined by comparing the last trial results
  239. using the `metric` and `mode` parameters passed to `tune.run()`.
  240. If you didn't pass these parameters, use
  241. `get_best_trial(metric, mode)` and use it to look for the dataframe
  242. in the `self.trial_dataframes` dict.
  243. """
  244. if not self.default_metric or not self.default_mode:
  245. raise ValueError(
  246. "To fetch the `best_result`, pass a `metric` and `mode` "
  247. "parameter to `tune.run()`."
  248. )
  249. return self.trial_dataframes[self.best_trial.trial_id]
  250. @property
  251. def best_result(self) -> Dict:
  252. """Get the last result of the best trial of the experiment
  253. The best trial is determined by comparing the last trial results
  254. using the `metric` and `mode` parameters passed to `tune.run()`.
  255. If you didn't pass these parameters, use
  256. `get_best_trial(metric, mode, scope).last_result` instead.
  257. """
  258. if not self.default_metric or not self.default_mode:
  259. raise ValueError(
  260. "To fetch the `best_result`, pass a `metric` and `mode` "
  261. "parameter to `tune.run()`. Alternatively, use "
  262. "`get_best_trial(metric, mode).last_result` to set "
  263. "the metric and mode explicitly and fetch the last result."
  264. )
  265. return self.best_trial.last_result
  266. def _delimiter(self):
  267. return os.environ.get("TUNE_RESULT_DELIM", "/")
  @property
  def best_result_df(self) -> DataFrame:
      """Best result of the experiment as a one-row pandas dataframe.

      ``self.best_result`` is flattened with the configured delimiter and
      turned into a dataframe indexed by ``trial_id``. If no default
      metric/mode was set, use
      `get_best_trial(metric, mode, scope).last_result` instead.

      Raises:
          ValueError: If pandas is not installed.
      """
      if not pd:
          raise ValueError(
              "`best_result_df` requires pandas. Install with "
              "`pip install pandas`."
          )

      flat_result = flatten_dict(self.best_result, delimiter=self._delimiter())
      records = [flat_result]
      return pd.DataFrame.from_records(records, index="trial_id")
  283. @property
  284. def results(self) -> Dict[str, Dict]:
  285. """Get the last result of the all trials of the experiment"""
  286. return {trial.trial_id: trial.last_result for trial in self.trials}
  @property
  def results_df(self) -> DataFrame:
      """Last results of all trials as a pandas dataframe.

      Each trial's ``last_result`` is flattened with the configured
      delimiter; the resulting dataframe is indexed by ``trial_id``.

      Raises:
          ValueError: If pandas is not installed.
      """
      if not pd:
          raise ValueError(
              "`results_df` requires pandas. Install with `pip install pandas`."
          )

      records = []
      for trial in self.trials:
          flat_result = flatten_dict(
              trial.last_result, delimiter=self._delimiter()
          )
          records.append(flat_result)
      return pd.DataFrame.from_records(records, index="trial_id")
  301. @property
  302. def trial_dataframes(self) -> Dict[str, DataFrame]:
  303. """List of all dataframes of the trials.
  304. Each dataframe is indexed by iterations and contains reported
  305. metrics.
  306. """
  307. return self._trial_dataframes
  308. def dataframe(
  309. self, metric: Optional[str] = None, mode: Optional[str] = None
  310. ) -> DataFrame:
  311. """Returns a pandas.DataFrame object constructed from the trials.
  312. This function will look through all observed results of each trial
  313. and return the one corresponding to the passed ``metric`` and
  314. ``mode``: If ``mode=min``, it returns the result with the lowest
  315. *ever* observed ``metric`` for this trial (this is not necessarily
  316. the last)! For ``mode=max``, it's the highest, respectively. If
  317. ``metric=None`` or ``mode=None``, the last result will be returned.
  318. Args:
  319. metric: Key for trial info to order on. If None, uses last result.
  320. mode: One of [None, "min", "max"].
  321. Returns:
  322. pd.DataFrame: Constructed from a result dict of each trial.
  323. """
  324. # Do not validate metric/mode here or set from default metric/mode!
  325. # Otherwise we will get confusing results as the lowest ever observed
  326. # result may not be the last result.
  327. if mode and mode not in ["min", "max"]:
  328. raise ValueError("If set, `mode` has to be one of [min, max]")
  329. if mode and not metric:
  330. raise ValueError(
  331. "If a `mode` is passed to `ExperimentAnalysis.dataframe(),"
  332. " you'll also have to pass a `metric`!"
  333. )
  334. rows = self._retrieve_rows(metric=metric, mode=mode)
  335. all_configs = self.get_all_configs(prefix=True)
  336. for path, config in all_configs.items():
  337. if path in rows:
  338. rows[path].update(config)
  339. rows[path].update(logdir=path)
  340. return pd.DataFrame(list(rows.values()))
  341. def _get_trial_checkpoints_with_metric(
  342. self, trial: Trial, metric: Optional[str] = None
  343. ) -> List[Tuple[Checkpoint, Number]]:
  344. """Get all checkpoints and a specified metric of a trial.
  345. Args:
  346. trial: The log directory of a trial, or a trial instance.
  347. metric: key for trial info to return, e.g. "mean_accuracy".
  348. "training_iteration" is used by default if no value was
  349. passed to ``self.default_metric``.
  350. Returns:
  351. List of [Checkpoint, metric] for all checkpoints of the trial.
  352. """
  353. metric = metric or self.default_metric or TRAINING_ITERATION
  354. best_checkpoint_results = (
  355. trial.run_metadata.checkpoint_manager.best_checkpoint_results
  356. )
  357. best_checkpoints = [
  358. (checkpoint_result.checkpoint, checkpoint_result.metrics)
  359. for checkpoint_result in best_checkpoint_results
  360. ]
  361. # Support nested metrics given as flattened strings, e.g.
  362. # "info/learner/default_policy/policy_loss".
  363. return [
  364. (checkpoint, unflattened_lookup(metric, metrics))
  365. for checkpoint, metrics in best_checkpoints
  366. ]
  def get_best_checkpoint(
      self,
      trial: Trial,
      metric: Optional[str] = None,
      mode: Optional[str] = None,
  ) -> Optional[Checkpoint]:
      """Pick the best checkpoint of the given trial.

      Checkpoints whose metric value is ``nan`` are ignored.

      Args:
          trial: The trial whose checkpoints are ranked.
          metric: Key of the metric to rank by, e.g. "mean_accuracy".
              Falls back to ``self.default_metric`` and then to
              ``TRAINING_ITERATION``.
          mode: One of [min, max]. Defaults to ``self.default_mode``.

      Returns:
          A :class:`Checkpoint <ray.tune.Checkpoint>` object, or None when
          no checkpoint with a non-nan metric value exists.
      """
      metric = metric or self.default_metric or TRAINING_ITERATION
      mode = self._validate_mode(mode)

      # nan cannot be ordered meaningfully, so drop those entries first.
      candidates = [
          pair
          for pair in self._get_trial_checkpoints_with_metric(trial, metric)
          if not is_nan(pair[1])
      ]
      if not candidates:
          logger.error(f"No checkpoints have been found for trial {trial}.")
          return None

      # Negating the score turns "find the minimum" into "find the maximum".
      sign = 1 if mode != "min" else -1

      def signed_score(pair):
          return sign * pair[1]

      winner = max(candidates, key=signed_score)
      return winner[0]
  399. def get_best_trial(
  400. self,
  401. metric: Optional[str] = None,
  402. mode: Optional[str] = None,
  403. scope: str = "last",
  404. filter_nan_and_inf: bool = True,
  405. ) -> Optional[Trial]:
  406. """Retrieve the best trial object.
  407. Compares all trials' scores on ``metric``.
  408. If ``metric`` is not specified, ``self.default_metric`` will be used.
  409. If `mode` is not specified, ``self.default_mode`` will be used.
  410. These values are usually initialized by passing the ``metric`` and
  411. ``mode`` parameters to ``tune.run()``.
  412. Args:
  413. metric: Key for trial info to order on. Defaults to
  414. ``self.default_metric``.
  415. mode: One of [min, max]. Defaults to ``self.default_mode``.
  416. scope: One of [all, last, avg, last-5-avg, last-10-avg].
  417. If `scope=last`, only look at each trial's final step for
  418. `metric`, and compare across trials based on `mode=[min,max]`.
  419. If `scope=avg`, consider the simple average over all steps
  420. for `metric` and compare across trials based on
  421. `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
  422. consider the simple average over the last 5 or 10 steps for
  423. `metric` and compare across trials based on `mode=[min,max]`.
  424. If `scope=all`, find each trial's min/max score for `metric`
  425. based on `mode`, and compare trials based on `mode=[min,max]`.
  426. filter_nan_and_inf: If True (default), NaN or infinite
  427. values are disregarded and these trials are never selected as
  428. the best trial.
  429. Returns:
  430. The best trial for the provided metric. If no trials contain the provided
  431. metric, or if the value for the metric is NaN for all trials,
  432. then returns None.
  433. """
  434. if len(self.trials) == 1:
  435. return self.trials[0]
  436. metric = self._validate_metric(metric)
  437. mode = self._validate_mode(mode)
  438. if scope not in ["all", "last", "avg", "last-5-avg", "last-10-avg"]:
  439. raise ValueError(
  440. "ExperimentAnalysis: attempting to get best trial for "
  441. 'metric {} for scope {} not in ["all", "last", "avg", '
  442. '"last-5-avg", "last-10-avg"]. '
  443. "If you didn't pass a `metric` parameter to `tune.run()`, "
  444. "you have to pass one when fetching the best trial.".format(
  445. metric, scope
  446. )
  447. )
  448. best_trial = None
  449. best_metric_score = None
  450. for trial in self.trials:
  451. if metric not in trial.metric_analysis:
  452. continue
  453. if scope in ["last", "avg", "last-5-avg", "last-10-avg"]:
  454. metric_score = trial.metric_analysis[metric][scope]
  455. else:
  456. metric_score = trial.metric_analysis[metric][mode]
  457. if filter_nan_and_inf and is_nan_or_inf(metric_score):
  458. continue
  459. if best_metric_score is None:
  460. best_metric_score = metric_score
  461. best_trial = trial
  462. continue
  463. if (mode == "max") and (best_metric_score < metric_score):
  464. best_metric_score = metric_score
  465. best_trial = trial
  466. elif (mode == "min") and (best_metric_score > metric_score):
  467. best_metric_score = metric_score
  468. best_trial = trial
  469. if not best_trial:
  470. logger.warning(
  471. "Could not find best trial. Did you pass the correct `metric` "
  472. "parameter?"
  473. )
  474. return best_trial
  475. def get_best_config(
  476. self,
  477. metric: Optional[str] = None,
  478. mode: Optional[str] = None,
  479. scope: str = "last",
  480. ) -> Optional[Dict]:
  481. """Retrieve the best config corresponding to the trial.
  482. Compares all trials' scores on `metric`.
  483. If ``metric`` is not specified, ``self.default_metric`` will be used.
  484. If `mode` is not specified, ``self.default_mode`` will be used.
  485. These values are usually initialized by passing the ``metric`` and
  486. ``mode`` parameters to ``tune.run()``.
  487. Args:
  488. metric: Key for trial info to order on. Defaults to
  489. ``self.default_metric``.
  490. mode: One of [min, max]. Defaults to ``self.default_mode``.
  491. scope: One of [all, last, avg, last-5-avg, last-10-avg].
  492. If `scope=last`, only look at each trial's final step for
  493. `metric`, and compare across trials based on `mode=[min,max]`.
  494. If `scope=avg`, consider the simple average over all steps
  495. for `metric` and compare across trials based on
  496. `mode=[min,max]`. If `scope=last-5-avg` or `scope=last-10-avg`,
  497. consider the simple average over the last 5 or 10 steps for
  498. `metric` and compare across trials based on `mode=[min,max]`.
  499. If `scope=all`, find each trial's min/max score for `metric`
  500. based on `mode`, and compare trials based on `mode=[min,max]`.
  501. """
  502. best_trial = self.get_best_trial(metric, mode, scope)
  503. return best_trial.config if best_trial else None
  504. def get_last_checkpoint(
  505. self, trial=None, metric="training_iteration", mode="max"
  506. ) -> Optional[Checkpoint]:
  507. """Gets the last checkpoint of the provided trial,
  508. i.e., with the highest "training_iteration".
  509. If no trial is specified, it loads the best trial according to the
  510. provided metric and mode (defaults to max. training iteration).
  511. Args:
  512. trial: If None, load the best trial automatically.
  513. metric: If no trial is specified, use this metric to identify
  514. the best trial and load the last checkpoint from this trial.
  515. mode: If no trial is specified, use the metric and this mode
  516. to identify the best trial and load the last checkpoint from it.
  517. Returns:
  518. Path for last checkpoint of trial
  519. """
  520. trial = trial or self.get_best_trial(metric, mode)
  521. return self.get_best_checkpoint(trial, TRAINING_ITERATION, "max")
  522. def _validate_metric(self, metric: str) -> str:
  523. if not metric and not self.default_metric:
  524. raise ValueError(
  525. "No `metric` has been passed and `default_metric` has "
  526. "not been set. Please specify the `metric` parameter."
  527. )
  528. return metric or self.default_metric
  529. def _validate_mode(self, mode: str) -> str:
  530. if not mode and not self.default_mode:
  531. raise ValueError(
  532. "No `mode` has been passed and `default_mode` has "
  533. "not been set. Please specify the `mode` parameter."
  534. )
  535. if mode and mode not in ["min", "max"]:
  536. raise ValueError("If set, `mode` has to be one of [min, max]")
  537. return mode or self.default_mode
  538. def _retrieve_rows(
  539. self, metric: Optional[str] = None, mode: Optional[str] = None
  540. ) -> Dict[str, Any]:
  541. assert mode is None or mode in ["max", "min"]
  542. assert not mode or metric
  543. rows = {}
  544. for path, df in self.trial_dataframes.items():
  545. if df.empty:
  546. continue
  547. if metric not in df:
  548. idx = -1
  549. elif mode == "max":
  550. idx = df[metric].idxmax()
  551. elif mode == "min":
  552. idx = df[metric].idxmin()
  553. else:
  554. idx = -1
  555. try:
  556. rows[path] = df.iloc[idx].to_dict()
  557. except TypeError:
  558. # idx is nan
  559. logger.warning(
  560. "Warning: Non-numerical value(s) encountered for {}".format(path)
  561. )
  562. return rows
  def __getstate__(self) -> Dict[str, Any]:
      """Return the pickling state with every trial replaced by a stub.

      Stub trials can be loaded later without the trainable being
      registered. Trials that already are stubs are kept as they are; the
      others are replaced by a stub copy carrying the same state. The
      instance itself is not modified.
      """
      state = dict(self.__dict__)

      stubbed_trials = []
      for trial in state["trials"]:
          if not trial.stub:
              # Build a fresh stub and transfer the trial's state onto it.
              stub_trial = Trial(trial.trainable_name, stub=True)
              stub_trial.__setstate__(trial.__getstate__())
              trial = stub_trial
          stubbed_trials.append(trial)

      state["trials"] = stubbed_trials
      return state