trainable_fn_utils.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from typing import Dict, Optional
  2. from ray.train._checkpoint import Checkpoint as TrainCheckpoint
  3. from ray.train._internal.session import _warn_session_misuse, get_session
  4. from ray.train.constants import (
  5. V2_MIGRATION_GUIDE_MESSAGE,
  6. _v2_migration_warnings_enabled,
  7. )
  8. from ray.train.utils import _copy_doc, _log_deprecation_warning
  9. from ray.util.annotations import PublicAPI
  10. @_copy_doc(TrainCheckpoint)
  11. class Checkpoint(TrainCheckpoint):
  12. # NOTE: This is just a pass-through wrapper around `ray.train.Checkpoint`
  13. # in order to detect whether the import module was correct `ray.tune.Checkpoint`.
  14. pass
  15. @PublicAPI(stability="stable")
  16. @_warn_session_misuse()
  17. def report(metrics: Dict, *, checkpoint: Optional[Checkpoint] = None) -> None:
  18. """Report metrics and optionally save and register a checkpoint to Ray Tune.
  19. If a checkpoint is provided, it will be
  20. :ref:`persisted to storage <persistent-storage-guide>`.
  21. .. note::
  22. Each invocation of this method will automatically increment the underlying
  23. ``training_iteration`` number. The physical meaning of this "iteration" is
  24. defined by user depending on how often they call ``report``.
  25. It does not necessarily map to one epoch.
  26. Args:
  27. metrics: The metrics you want to report.
  28. checkpoint: The optional checkpoint you want to report.
  29. """
  30. if checkpoint and not isinstance(checkpoint, Checkpoint):
  31. if _v2_migration_warnings_enabled():
  32. _log_deprecation_warning(
  33. "The `Checkpoint` class should be imported from `ray.tune` "
  34. "when passing it to `ray.tune.report` in a Tune function. "
  35. "Please update your imports. "
  36. f"{V2_MIGRATION_GUIDE_MESSAGE}"
  37. )
  38. get_session().report(metrics, checkpoint=checkpoint)
  39. @PublicAPI(stability="stable")
  40. @_warn_session_misuse()
  41. def get_checkpoint() -> Optional[Checkpoint]:
  42. """Access the latest reported checkpoint to resume from if one exists."""
  43. return get_session().loaded_checkpoint
  44. def _in_tune_session() -> bool:
  45. return get_session() and get_session().world_rank is None