checkpoint_manager.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import logging
  2. import numbers
  3. from typing import Any, Callable, Dict, List, Optional, Tuple
  4. from ray._private import ray_constants
  5. from ray._private.dict import flatten_dict
  6. from ray.air._internal.util import is_nan
  7. from ray.air.config import MAX
  8. from ray.train import Checkpoint, CheckpointConfig
  9. from ray.train._internal.session import _TrainingResult
  10. from ray.train._internal.storage import _delete_fs_path
  11. from ray.train.constants import TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE
  12. logger = logging.getLogger(__name__)
  13. def _insert_into_sorted_list(
  14. list: List[_TrainingResult],
  15. item: _TrainingResult,
  16. key: Callable[[_TrainingResult], Any],
  17. checkpoint_to_report_index: Optional[Dict[Checkpoint, int]] = None,
  18. ):
  19. """Insert an item into a sorted list with a custom key function.
  20. Args:
  21. list: The list to insert the item into.
  22. item: The item to insert.
  23. key: The key function to use to sort the list.
  24. checkpoint_to_report_index: A dictionary mapping checkpoints to report indices.
  25. Used to break ties when scores are equal.
  26. """
  27. checkpoint_to_report_index = checkpoint_to_report_index or {}
  28. # TODO: optimize this with sortedlist, batching, etc
  29. i = 0
  30. while i < len(list):
  31. # When scores are equal, later checkpoints are later in the list.
  32. list_item_key, item_key = key(list[i]), key(item)
  33. if list_item_key > item_key or (
  34. list_item_key == item_key
  35. and checkpoint_to_report_index.get(list[i].checkpoint, 0)
  36. > checkpoint_to_report_index.get(item.checkpoint, 0)
  37. ):
  38. break
  39. i += 1
  40. list.insert(i, item)
  41. class _CheckpointManager:
  42. """Checkpoint manager that handles checkpoint book-keeping for a trial.
  43. The main purpose of this abstraction is to keep the top K checkpoints based on
  44. recency/a user-provided metric.
  45. NOTE: This class interacts with `_TrainingResult` objects, which are
  46. (checkpoint, metrics) pairs. This is to order checkpoints by metrics.
  47. Args:
  48. checkpoint_config: Defines how many and which checkpoints to keep.
  49. """
  50. def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
  51. self._checkpoint_config = checkpoint_config or CheckpointConfig()
  52. # List of checkpoints ordered by ascending score.
  53. self._checkpoint_results: List[_TrainingResult] = []
  54. # The latest registered checkpoint.
  55. # This should never be immediately deleted upon registration,
  56. # even if it's not in the top K checkpoints, based on score.
  57. self._latest_checkpoint_result: Optional[_TrainingResult] = None
  58. if (
  59. self._checkpoint_config.num_to_keep is not None
  60. and self._checkpoint_config.num_to_keep <= 0
  61. ):
  62. raise ValueError(
  63. f"`num_to_keep` must >= 1, got: "
  64. f"{self._checkpoint_config.num_to_keep}"
  65. )
  66. @property
  67. def checkpoint_config(self):
  68. return self._checkpoint_config
  69. def register_checkpoint(self, checkpoint_result: _TrainingResult):
  70. """Register new checkpoint and add to bookkeeping.
  71. This method will register a new checkpoint and add it to the internal
  72. bookkeeping logic. This means the checkpoint manager will decide if
  73. this checkpoint should be kept, and if older or worse performing
  74. checkpoints should be deleted.
  75. Args:
  76. checkpoint: Tracked checkpoint object to add to bookkeeping.
  77. """
  78. self._latest_checkpoint_result = checkpoint_result
  79. score_attr = self._checkpoint_config.checkpoint_score_attribute
  80. if ray_constants.env_bool(TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE, False):
  81. metrics = (
  82. {score_attr: checkpoint_result.metrics[score_attr]}
  83. if score_attr in checkpoint_result.metrics
  84. else {}
  85. )
  86. checkpoint_result = _TrainingResult(
  87. checkpoint=checkpoint_result.checkpoint,
  88. metrics=metrics,
  89. )
  90. if score_attr is not None and score_attr in checkpoint_result.metrics:
  91. # If we're ordering by a score, insert the checkpoint
  92. # so that the list remains sorted.
  93. _insert_into_sorted_list(
  94. self._checkpoint_results,
  95. checkpoint_result,
  96. key=self._get_checkpoint_score,
  97. )
  98. else:
  99. # If no metric is provided, just append (ordering by time of registration).
  100. self._checkpoint_results.append(checkpoint_result)
  101. if self._checkpoint_config.num_to_keep is not None:
  102. # Delete the bottom (N - K) checkpoints
  103. worst_results = set(
  104. self._checkpoint_results[: -self._checkpoint_config.num_to_keep]
  105. )
  106. # Except for the latest checkpoint.
  107. results_to_delete = worst_results - {self._latest_checkpoint_result}
  108. # Update internal state before actually deleting them.
  109. self._checkpoint_results = [
  110. checkpoint_result
  111. for checkpoint_result in self._checkpoint_results
  112. if checkpoint_result not in results_to_delete
  113. ]
  114. for checkpoint_result in results_to_delete:
  115. checkpoint = checkpoint_result.checkpoint
  116. logger.debug("Deleting checkpoint: ", checkpoint)
  117. _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)
  118. def _get_checkpoint_score(
  119. self, checkpoint: _TrainingResult
  120. ) -> Tuple[bool, numbers.Number]:
  121. """Get the score for a checkpoint, according to checkpoint config.
  122. If `mode="min"`, the metric is negated so that the lowest score is
  123. treated as the best.
  124. Returns:
  125. Tuple: A tuple of (not_is_nan: bool, score: numbers.Number).
  126. This score orders: nan values < float("-inf") < valid numeric metrics
  127. """
  128. checkpoint_score_attribute = self._checkpoint_config.checkpoint_score_attribute
  129. if checkpoint_score_attribute:
  130. flat_metrics = flatten_dict(checkpoint.metrics)
  131. try:
  132. checkpoint_result = flat_metrics[checkpoint_score_attribute]
  133. except KeyError:
  134. valid_keys = list(flat_metrics.keys())
  135. logger.error(
  136. f"Result dict has no key: {checkpoint_score_attribute}. "
  137. f"checkpoint_score_attr must be set to a key in the "
  138. f"result dict. Valid keys are: {valid_keys}"
  139. )
  140. checkpoint_result = float("-inf")
  141. else:
  142. checkpoint_result = float("-inf")
  143. checkpoint_score_order = self._checkpoint_config.checkpoint_score_order
  144. order_factor = 1.0 if checkpoint_score_order == MAX else -1.0
  145. checkpoint_score = order_factor * checkpoint_result
  146. if not isinstance(checkpoint_score, numbers.Number):
  147. raise ValueError(
  148. f"Unable to persist checkpoint for "
  149. f"checkpoint_score_attribute: "
  150. f"{checkpoint_score_attribute} with value "
  151. f"{checkpoint_score}. "
  152. f"This attribute must be numerical."
  153. )
  154. return (
  155. (not is_nan(checkpoint_score), checkpoint_score)
  156. if not is_nan(checkpoint_score)
  157. else (False, float("-inf"))
  158. )
  159. @property
  160. def best_checkpoint_result(self) -> Optional[_TrainingResult]:
  161. return self._checkpoint_results[-1] if self._checkpoint_results else None
  162. @property
  163. def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
  164. return self._latest_checkpoint_result
  165. @property
  166. def best_checkpoint_results(self) -> List[_TrainingResult]:
  167. if self._checkpoint_config.num_to_keep is None:
  168. return self._checkpoint_results
  169. return self._checkpoint_results[-self._checkpoint_config.num_to_keep :]