| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- import logging
- import numbers
- from typing import Any, Callable, Dict, List, Optional, Tuple
- from ray._private import ray_constants
- from ray._private.dict import flatten_dict
- from ray.air._internal.util import is_nan
- from ray.air.config import MAX
- from ray.train import Checkpoint, CheckpointConfig
- from ray.train._internal.session import _TrainingResult
- from ray.train._internal.storage import _delete_fs_path
- from ray.train.constants import TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE
logger = logging.getLogger(__name__)


def _insert_into_sorted_list(
    list: List[_TrainingResult],
    item: _TrainingResult,
    key: Callable[[_TrainingResult], Any],
    checkpoint_to_report_index: Optional[Dict[Checkpoint, int]] = None,
):
    """Insert an item into a sorted list with a custom key function.

    ``list`` is assumed to already be sorted in ascending order of ``key``.
    The item is inserted in place, immediately before the first element that
    ranks strictly after it, so the list stays sorted.

    Args:
        list: The list to insert the item into. It is mutated in place. (The
            name shadows the builtin but is kept so keyword callers still work.)
        item: The item to insert.
        key: The key function to use to sort the list.
        checkpoint_to_report_index: A dictionary mapping checkpoints to report indices.
            Used to break ties when scores are equal. Checkpoints missing from
            the mapping are treated as report index 0.
    """
    # Nothing to compare against: skip evaluating ``key`` entirely.
    if not list:
        list.append(item)
        return

    checkpoint_to_report_index = checkpoint_to_report_index or {}

    # The new item's key and tie-break index do not change while scanning, so
    # compute them once. ``key`` may be costly, or may log on missing metrics.
    item_key = key(item)
    item_report_index = checkpoint_to_report_index.get(item.checkpoint, 0)

    # TODO: optimize this with sortedlist, batching, etc
    insert_at = len(list)
    for i, existing in enumerate(list):
        existing_key = key(existing)
        # When scores are equal, later checkpoints are later in the list.
        if existing_key > item_key or (
            existing_key == item_key
            and checkpoint_to_report_index.get(existing.checkpoint, 0)
            > item_report_index
        ):
            insert_at = i
            break
    list.insert(insert_at, item)
class _CheckpointManager:
    """Checkpoint manager that handles checkpoint book-keeping for a trial.

    The main purpose of this abstraction is to keep the top K checkpoints based on
    recency/a user-provided metric.

    NOTE: This class interacts with `_TrainingResult` objects, which are
    (checkpoint, metrics) pairs. This is to order checkpoints by metrics.

    Args:
        checkpoint_config: Defines how many and which checkpoints to keep.
            A falsy value falls back to a default ``CheckpointConfig()``.

    Raises:
        ValueError: If ``num_to_keep`` is set and is not positive.
    """

    def __init__(self, checkpoint_config: Optional[CheckpointConfig]):
        self._checkpoint_config = checkpoint_config or CheckpointConfig()

        # List of checkpoints ordered by ascending score (best result last).
        self._checkpoint_results: List[_TrainingResult] = []

        # The latest registered checkpoint.
        # This should never be immediately deleted upon registration,
        # even if it's not in the top K checkpoints, based on score.
        self._latest_checkpoint_result: Optional[_TrainingResult] = None

        if (
            self._checkpoint_config.num_to_keep is not None
            and self._checkpoint_config.num_to_keep <= 0
        ):
            raise ValueError(
                f"`num_to_keep` must >= 1, got: "
                f"{self._checkpoint_config.num_to_keep}"
            )

    @property
    def checkpoint_config(self):
        """The effective ``CheckpointConfig`` used by this manager."""
        return self._checkpoint_config

    def register_checkpoint(self, checkpoint_result: _TrainingResult):
        """Register new checkpoint and add to bookkeeping.

        This method will register a new checkpoint and add it to the internal
        bookkeeping logic. This means the checkpoint manager will decide if
        this checkpoint should be kept, and if older or worse performing
        checkpoints should be deleted.

        Args:
            checkpoint_result: Tracked (checkpoint, metrics) result to add to
                bookkeeping.
        """
        self._latest_checkpoint_result = checkpoint_result

        score_attr = self._checkpoint_config.checkpoint_score_attribute

        # The object actually stored in `_checkpoint_results`. It is a trimmed
        # copy of `checkpoint_result` when only the score attribute is kept.
        tracked_result = checkpoint_result
        if ray_constants.env_bool(TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE, False):
            metrics = (
                {score_attr: checkpoint_result.metrics[score_attr]}
                if score_attr in checkpoint_result.metrics
                else {}
            )
            tracked_result = _TrainingResult(
                checkpoint=checkpoint_result.checkpoint,
                metrics=metrics,
            )

        if score_attr is not None and score_attr in tracked_result.metrics:
            # If we're ordering by a score, insert the checkpoint
            # so that the list remains sorted.
            _insert_into_sorted_list(
                self._checkpoint_results,
                tracked_result,
                key=self._get_checkpoint_score,
            )
        else:
            # If no metric is provided, just append (ordering by time of registration).
            self._checkpoint_results.append(tracked_result)

        num_to_keep = self._checkpoint_config.num_to_keep
        if num_to_keep is None:
            return

        # Delete the bottom (N - K) checkpoints...
        worst_results = set(self._checkpoint_results[:-num_to_keep])
        # ...except for the latest checkpoint. Exclude the object that is
        # actually in the list: when metrics were trimmed above it is a copy,
        # and excluding `self._latest_checkpoint_result` would not match it,
        # causing the latest checkpoint's files to be deleted.
        results_to_delete = worst_results - {tracked_result}

        # Update internal state before actually deleting them.
        self._checkpoint_results = [
            result
            for result in self._checkpoint_results
            if result not in results_to_delete
        ]

        for result in results_to_delete:
            checkpoint = result.checkpoint
            logger.debug("Deleting checkpoint: %s", checkpoint)
            _delete_fs_path(fs=checkpoint.filesystem, fs_path=checkpoint.path)

    def _get_checkpoint_score(
        self, checkpoint: _TrainingResult
    ) -> Tuple[bool, numbers.Number]:
        """Get the score for a checkpoint, according to checkpoint config.

        If `mode="min"`, the metric is negated so that the lowest score is
        treated as the best.

        Returns:
            Tuple: A tuple of (not_is_nan: bool, score: numbers.Number).
                This score orders: nan values < float("-inf") < valid numeric metrics

        Raises:
            ValueError: If the score attribute's value is not numeric.
        """
        checkpoint_score_attribute = self._checkpoint_config.checkpoint_score_attribute
        if checkpoint_score_attribute:
            # Flattening allows nested metric keys to be addressed directly.
            flat_metrics = flatten_dict(checkpoint.metrics)
            try:
                checkpoint_result = flat_metrics[checkpoint_score_attribute]
            except KeyError:
                valid_keys = list(flat_metrics.keys())
                logger.error(
                    f"Result dict has no key: {checkpoint_score_attribute}. "
                    f"checkpoint_score_attr must be set to a key in the "
                    f"result dict. Valid keys are: {valid_keys}"
                )
                checkpoint_result = float("-inf")
        else:
            checkpoint_result = float("-inf")

        # Validate the raw value before applying the order factor: multiplying
        # a non-numeric value (e.g. a str or None) by a float raises TypeError,
        # which previously made this descriptive error unreachable.
        if not isinstance(checkpoint_result, numbers.Number):
            raise ValueError(
                f"Unable to persist checkpoint for "
                f"checkpoint_score_attribute: "
                f"{checkpoint_score_attribute} with value "
                f"{checkpoint_result}. "
                f"This attribute must be numerical."
            )

        checkpoint_score_order = self._checkpoint_config.checkpoint_score_order
        order_factor = 1.0 if checkpoint_score_order == MAX else -1.0
        # NOTE(review): the `-inf` fallback above is applied before the order
        # factor, so with mode="min" a missing attribute becomes `+inf` (best).
        # `register_checkpoint` only sorts results that contain the attribute,
        # so this mainly affects comparisons against appended results; confirm
        # the intended ranking before changing it.
        checkpoint_score = order_factor * checkpoint_result

        # NaN scores rank below every real value, including -inf.
        if is_nan(checkpoint_score):
            return False, float("-inf")
        return True, checkpoint_score

    @property
    def best_checkpoint_result(self) -> Optional[_TrainingResult]:
        """The highest-ranked tracked result, or ``None`` if none are tracked."""
        return self._checkpoint_results[-1] if self._checkpoint_results else None

    @property
    def latest_checkpoint_result(self) -> Optional[_TrainingResult]:
        """The most recently registered result (as passed by the caller)."""
        return self._latest_checkpoint_result

    @property
    def best_checkpoint_results(self) -> List[_TrainingResult]:
        """The top ``num_to_keep`` results (all if unset), in ascending rank.

        Always returns a new list so callers cannot mutate internal state.
        """
        if self._checkpoint_config.num_to_keep is None:
            return list(self._checkpoint_results)
        return self._checkpoint_results[-self._checkpoint_config.num_to_keep :]
|