result.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import logging
  2. import os
  3. from dataclasses import dataclass
  4. from typing import Any, Dict, List, Optional, Tuple, Union
  5. import pandas as pd
  6. import pyarrow
  7. import ray
  8. from ray.air.result import Result as ResultV1
  9. from ray.train import Checkpoint, CheckpointConfig
  10. from ray.train.v2._internal.constants import CHECKPOINT_MANAGER_SNAPSHOT_FILENAME
  11. from ray.train.v2._internal.execution.checkpoint.checkpoint_manager import (
  12. CheckpointManager,
  13. )
  14. from ray.train.v2._internal.execution.storage import (
  15. StorageContext,
  16. _exists_at_fs_path,
  17. get_fs_and_path,
  18. )
  19. from ray.train.v2.api.exceptions import TrainingFailedError
  20. from ray.util.annotations import Deprecated, PublicAPI
  21. logger = logging.getLogger(__name__)
  22. @dataclass
  23. class Result(ResultV1):
  24. checkpoint: Optional[Checkpoint]
  25. error: Optional[TrainingFailedError]
  26. best_checkpoints: Optional[List[Tuple[Checkpoint, Dict[str, Any]]]] = None
  27. @PublicAPI(stability="alpha")
  28. def get_best_checkpoint(
  29. self, metric: str, mode: str
  30. ) -> Optional["ray.train.Checkpoint"]:
  31. return super().get_best_checkpoint(metric, mode)
  32. @classmethod
  33. def from_path(
  34. cls,
  35. path: Union[str, os.PathLike],
  36. storage_filesystem: Optional[pyarrow.fs.FileSystem] = None,
  37. ) -> "Result":
  38. """Restore a training result from a previously saved training run path.
  39. Args:
  40. path: Path to the run output directory
  41. storage_filesystem: Optional filesystem to use for accessing the path
  42. Returns:
  43. Result object with restored checkpoints and metrics
  44. """
  45. fs, fs_path = get_fs_and_path(str(path), storage_filesystem)
  46. # Validate that the experiment directory exists
  47. if not _exists_at_fs_path(fs, fs_path):
  48. raise RuntimeError(f"Experiment folder {fs_path} doesn't exist.")
  49. # Remove trailing slashes to handle paths correctly
  50. # os.path.basename() returns empty string for paths with trailing slashes
  51. fs_path = fs_path.rstrip("/")
  52. storage_path, experiment_dir_name = os.path.dirname(fs_path), os.path.basename(
  53. fs_path
  54. )
  55. storage_context = StorageContext(
  56. storage_path=storage_path,
  57. experiment_dir_name=experiment_dir_name,
  58. storage_filesystem=fs,
  59. )
  60. # Validate that the checkpoint manager snapshot file exists
  61. if not _exists_at_fs_path(
  62. storage_context.storage_filesystem,
  63. storage_context.checkpoint_manager_snapshot_path,
  64. ):
  65. raise RuntimeError(
  66. f"Failed to restore the Result object: "
  67. f"{CHECKPOINT_MANAGER_SNAPSHOT_FILENAME} doesn't exist in the "
  68. f"experiment folder. Make sure that this is an output directory created by a Ray Train run."
  69. )
  70. checkpoint_manager = CheckpointManager(
  71. storage_context=storage_context,
  72. checkpoint_config=CheckpointConfig(),
  73. )
  74. # When we build a Result object from checkpoints, the error is not loaded.
  75. return cls._from_checkpoint_manager(
  76. checkpoint_manager=checkpoint_manager,
  77. storage_context=storage_context,
  78. )
  79. @classmethod
  80. def _from_checkpoint_manager(
  81. cls,
  82. checkpoint_manager: CheckpointManager,
  83. storage_context: StorageContext,
  84. error: Optional[TrainingFailedError] = None,
  85. ) -> "Result":
  86. """Create a Result object from a CheckpointManager."""
  87. latest_checkpoint_result = checkpoint_manager.latest_checkpoint_result
  88. if latest_checkpoint_result:
  89. latest_metrics = latest_checkpoint_result.metrics
  90. latest_checkpoint = latest_checkpoint_result.checkpoint
  91. else:
  92. latest_metrics = None
  93. latest_checkpoint = None
  94. best_checkpoints = [
  95. (r.checkpoint, r.metrics)
  96. for r in checkpoint_manager.best_checkpoint_results
  97. ]
  98. # Provide the history of metrics attached to checkpoints as a dataframe.
  99. metrics_dataframe = None
  100. if best_checkpoints:
  101. metrics_dataframe = pd.DataFrame([m for _, m in best_checkpoints])
  102. return Result(
  103. metrics=latest_metrics,
  104. checkpoint=latest_checkpoint,
  105. error=error,
  106. path=storage_context.experiment_fs_path,
  107. best_checkpoints=best_checkpoints,
  108. metrics_dataframe=metrics_dataframe,
  109. _storage_filesystem=storage_context.storage_filesystem,
  110. )
  111. @property
  112. @Deprecated
  113. def config(self) -> Optional[Dict[str, Any]]:
  114. raise DeprecationWarning(
  115. "The `config` property for a `ray.train.Result` is deprecated, "
  116. "since it is only relevant in the context of Ray Tune."
  117. )