result.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import io
  2. import json
  3. import logging
  4. import os
  5. from dataclasses import dataclass
  6. from pathlib import Path
  7. from typing import Any, Dict, List, Optional, Tuple, Union
  8. import pandas as pd
  9. import pyarrow
  10. import ray
  11. from ray._private.dict import unflattened_lookup
  12. from ray.air.constants import (
  13. EXPR_ERROR_PICKLE_FILE,
  14. EXPR_PROGRESS_FILE,
  15. EXPR_RESULT_FILE,
  16. )
  17. from ray.util.annotations import PublicAPI
# Module-level logger; ``Result.from_path`` uses it to warn when a checkpoint
# directory has no matching row in the restored metrics table.
logger = logging.getLogger(__name__)
  19. @dataclass
  20. @PublicAPI(stability="stable")
  21. class Result:
  22. """The final result of a ML training run or a Tune trial.
  23. This is the output produced by ``Trainer.fit``.
  24. ``Tuner.fit`` outputs a :class:`~ray.tune.ResultGrid` that is a collection
  25. of ``Result`` objects.
  26. This API is the recommended way to access the outputs such as:
  27. - checkpoints (``Result.checkpoint``)
  28. - the history of reported metrics (``Result.metrics_dataframe``, ``Result.metrics``)
  29. - errors encountered during a training run (``Result.error``)
  30. The constructor is a private API -- use ``Result.from_path`` to create a result
  31. object from a directory.
  32. Attributes:
  33. metrics: The latest set of reported metrics.
  34. checkpoint: The latest checkpoint.
  35. error: The execution error of the Trainable run, if the trial finishes in error.
  36. path: Path pointing to the result directory on persistent storage. This can
  37. point to a remote storage location (e.g. S3) or to a local location (path
  38. on the head node). The path is accessible via the result's associated
  39. `filesystem`. For instance, for a result stored in S3 at
  40. ``s3://bucket/location``, ``path`` will have the value ``bucket/location``.
  41. metrics_dataframe: The full result dataframe of the Trainable.
  42. The dataframe is indexed by iterations and contains reported
  43. metrics. Note that the dataframe columns are indexed with the
  44. *flattened* keys of reported metrics, so the format of this dataframe
  45. may be slightly different than ``Result.metrics``, which is an unflattened
  46. dict of the latest set of reported metrics.
  47. best_checkpoints: A list of tuples of the best checkpoints and
  48. their associated metrics. The number of
  49. saved checkpoints is determined by :class:`~ray.train.CheckpointConfig`
  50. (by default, all checkpoints will be saved).
  51. """
  52. metrics: Optional[Dict[str, Any]]
  53. checkpoint: Optional["ray.tune.Checkpoint"]
  54. error: Optional[Exception]
  55. path: str
  56. metrics_dataframe: Optional["pd.DataFrame"] = None
  57. best_checkpoints: Optional[
  58. List[Tuple["ray.tune.Checkpoint", Dict[str, Any]]]
  59. ] = None
  60. _storage_filesystem: Optional[pyarrow.fs.FileSystem] = None
  61. _items_to_repr = ["error", "metrics", "path", "filesystem", "checkpoint"]
  62. @property
  63. def config(self) -> Optional[Dict[str, Any]]:
  64. """The config associated with the result."""
  65. if not self.metrics:
  66. return None
  67. return self.metrics.get("config", None)
  68. @property
  69. def filesystem(self) -> pyarrow.fs.FileSystem:
  70. """Return the filesystem that can be used to access the result path.
  71. Returns:
  72. pyarrow.fs.FileSystem implementation.
  73. """
  74. return self._storage_filesystem or pyarrow.fs.LocalFileSystem()
  75. def _repr(self, indent: int = 0) -> str:
  76. """Construct the representation with specified number of space indent."""
  77. from ray.tune.experimental.output import BLACKLISTED_KEYS
  78. from ray.tune.result import AUTO_RESULT_KEYS
  79. shown_attributes = {k: getattr(self, k) for k in self._items_to_repr}
  80. if self.error:
  81. shown_attributes["error"] = type(self.error).__name__
  82. else:
  83. shown_attributes.pop("error")
  84. shown_attributes["filesystem"] = shown_attributes["filesystem"].type_name
  85. if self.metrics:
  86. exclude = set(AUTO_RESULT_KEYS)
  87. exclude.update(BLACKLISTED_KEYS)
  88. shown_attributes["metrics"] = {
  89. k: v for k, v in self.metrics.items() if k not in exclude
  90. }
  91. cls_indent = " " * indent
  92. kws_indent = " " * (indent + 2)
  93. kws = [
  94. f"{kws_indent}{key}={value!r}" for key, value in shown_attributes.items()
  95. ]
  96. kws_repr = ",\n".join(kws)
  97. return "{0}{1}(\n{2}\n{0})".format(cls_indent, type(self).__name__, kws_repr)
  98. def __repr__(self) -> str:
  99. return self._repr(indent=0)
  100. @staticmethod
  101. def _read_file_as_str(
  102. storage_filesystem: pyarrow.fs.FileSystem,
  103. storage_path: str,
  104. ) -> str:
  105. """Opens a file as an input stream reading all byte content sequentially and
  106. decoding read bytes as utf-8 string.
  107. Args:
  108. storage_filesystem: The filesystem to use.
  109. storage_path: The source to open for reading.
  110. """
  111. with storage_filesystem.open_input_stream(storage_path) as f:
  112. return f.readall().decode()
  113. @classmethod
  114. def from_path(
  115. cls,
  116. path: Union[str, os.PathLike],
  117. storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
  118. ) -> "Result":
  119. """Restore a Result object from local or remote trial directory.
  120. Args:
  121. path: A path of a trial directory on local or remote storage
  122. (ex: s3://bucket/path or /tmp/ray_results).
  123. storage_filesystem: A custom filesystem to use. If not provided,
  124. this will be auto-resolved by pyarrow. If provided, the path
  125. is assumed to be prefix-stripped already, and must be a valid path
  126. on the filesystem.
  127. Returns:
  128. A :py:class:`Result` object of that trial.
  129. """
  130. # TODO(justinvyu): Fix circular dependency.
  131. from ray.train import Checkpoint
  132. from ray.train._internal.storage import (
  133. _exists_at_fs_path,
  134. _list_at_fs_path,
  135. get_fs_and_path,
  136. )
  137. from ray.train.constants import CHECKPOINT_DIR_NAME
  138. fs, fs_path = get_fs_and_path(path, storage_filesystem)
  139. if not _exists_at_fs_path(fs, fs_path):
  140. raise RuntimeError(f"Trial folder {fs_path} doesn't exist!")
  141. # Restore metrics from result.json
  142. result_json_file = Path(fs_path, EXPR_RESULT_FILE).as_posix()
  143. progress_csv_file = Path(fs_path, EXPR_PROGRESS_FILE).as_posix()
  144. if _exists_at_fs_path(fs, result_json_file):
  145. lines = cls._read_file_as_str(fs, result_json_file).split("\n")
  146. json_list = [json.loads(line) for line in lines if line]
  147. metrics_df = pd.json_normalize(json_list, sep="/")
  148. latest_metrics = json_list[-1] if json_list else {}
  149. # Fallback to restore from progress.csv
  150. elif _exists_at_fs_path(fs, progress_csv_file):
  151. metrics_df = pd.read_csv(
  152. io.StringIO(cls._read_file_as_str(fs, progress_csv_file))
  153. )
  154. latest_metrics = (
  155. metrics_df.iloc[-1].to_dict() if not metrics_df.empty else {}
  156. )
  157. else:
  158. raise RuntimeError(
  159. f"Failed to restore the Result object: Neither {EXPR_RESULT_FILE}"
  160. f" nor {EXPR_PROGRESS_FILE} exists in the trial folder!"
  161. )
  162. # Restore all checkpoints from the checkpoint folders
  163. checkpoint_dir_names = sorted(
  164. _list_at_fs_path(
  165. fs,
  166. fs_path,
  167. file_filter=lambda file_info: file_info.type
  168. == pyarrow.fs.FileType.Directory
  169. and file_info.base_name.startswith("checkpoint_"),
  170. )
  171. )
  172. if checkpoint_dir_names:
  173. checkpoints = [
  174. Checkpoint(
  175. path=Path(fs_path, checkpoint_dir_name).as_posix(), filesystem=fs
  176. )
  177. for checkpoint_dir_name in checkpoint_dir_names
  178. ]
  179. metrics = []
  180. for checkpoint_dir_name in checkpoint_dir_names:
  181. metrics_corresponding_to_checkpoint = metrics_df[
  182. metrics_df[CHECKPOINT_DIR_NAME] == checkpoint_dir_name
  183. ]
  184. if metrics_corresponding_to_checkpoint.empty:
  185. logger.warning(
  186. "Could not find metrics corresponding to "
  187. f"{checkpoint_dir_name}. These will default to an empty dict."
  188. )
  189. metrics.append(
  190. {}
  191. if metrics_corresponding_to_checkpoint.empty
  192. else metrics_corresponding_to_checkpoint.iloc[-1].to_dict()
  193. )
  194. latest_checkpoint = checkpoints[-1]
  195. # TODO(justinvyu): These are ordered by checkpoint index, since we don't
  196. # know the metric to order these with.
  197. best_checkpoints = list(zip(checkpoints, metrics))
  198. else:
  199. best_checkpoints = latest_checkpoint = None
  200. # Restore the trial error if it exists
  201. error = None
  202. error_file_path = Path(fs_path, EXPR_ERROR_PICKLE_FILE).as_posix()
  203. if _exists_at_fs_path(fs, error_file_path):
  204. with fs.open_input_stream(error_file_path) as f:
  205. error = ray.cloudpickle.load(f)
  206. return Result(
  207. metrics=latest_metrics,
  208. checkpoint=latest_checkpoint,
  209. path=fs_path,
  210. _storage_filesystem=fs,
  211. metrics_dataframe=metrics_df,
  212. best_checkpoints=best_checkpoints,
  213. error=error,
  214. )
  215. @PublicAPI(stability="alpha")
  216. def get_best_checkpoint(
  217. self, metric: str, mode: str
  218. ) -> Optional["ray.tune.Checkpoint"]:
  219. """Get the best checkpoint from this trial based on a specific metric.
  220. Any checkpoints without an associated metric value will be filtered out.
  221. Args:
  222. metric: The key for checkpoints to order on.
  223. mode: One of ["min", "max"].
  224. Returns:
  225. :class:`Checkpoint <ray.train.Checkpoint>` object, or None if there is
  226. no valid checkpoint associated with the metric.
  227. """
  228. if not self.best_checkpoints:
  229. raise RuntimeError("No checkpoint exists in the trial directory!")
  230. if mode not in ["max", "min"]:
  231. raise ValueError(
  232. f'Unsupported mode: {mode}. Please choose from ["min", "max"]!'
  233. )
  234. op = max if mode == "max" else min
  235. valid_checkpoints = [
  236. ckpt_info
  237. for ckpt_info in self.best_checkpoints
  238. if unflattened_lookup(metric, ckpt_info[1], default=None) is not None
  239. ]
  240. if not valid_checkpoints:
  241. raise RuntimeError(
  242. f"Invalid metric name {metric}! "
  243. f"You may choose from the following metrics: {self.metrics.keys()}."
  244. )
  245. return op(valid_checkpoints, key=lambda x: unflattened_lookup(metric, x[1]))[0]