# context.py (4.6 KB)
  1. import threading
  2. from typing import TYPE_CHECKING, Any, Dict, Optional
  3. from ray.train._internal import session
  4. from ray.train._internal.storage import StorageContext
  5. from ray.train.constants import (
  6. V2_MIGRATION_GUIDE_MESSAGE,
  7. _v2_migration_warnings_enabled,
  8. )
  9. from ray.train.utils import _copy_doc, _log_deprecation_warning
  10. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  11. if TYPE_CHECKING:
  12. from ray.tune.execution.placement_groups import PlacementGroupFactory
  13. # The context singleton on this process.
  14. _default_context: "Optional[TrainContext]" = None
  15. _context_lock = threading.Lock()
  16. _GET_METADATA_DEPRECATION_MESSAGE = (
  17. "`get_metadata` was an experimental API that accessed the metadata passed "
  18. "to `<Framework>Trainer(metadata=...)`. This API can be replaced by passing "
  19. "the metadata directly to the training function (e.g., via `train_loop_config`). "
  20. f"{V2_MIGRATION_GUIDE_MESSAGE}"
  21. )
  22. _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = (
  23. "`{}` is deprecated because the concept of a `Trial` will "
  24. "soon be removed in Ray Train."
  25. "Ray Train will no longer assume that it's running within a Ray Tune `Trial` "
  26. "in the future. "
  27. f"{V2_MIGRATION_GUIDE_MESSAGE}"
  28. )
  29. @PublicAPI(stability="stable")
  30. class TrainContext:
  31. """Context containing metadata that can be accessed within Ray Train workers."""
  32. @_copy_doc(session.get_experiment_name)
  33. def get_experiment_name(self) -> str:
  34. return session.get_experiment_name()
  35. @_copy_doc(session.get_world_size)
  36. def get_world_size(self) -> int:
  37. return session.get_world_size()
  38. @_copy_doc(session.get_world_rank)
  39. def get_world_rank(self) -> int:
  40. return session.get_world_rank()
  41. @_copy_doc(session.get_local_rank)
  42. def get_local_rank(self) -> int:
  43. return session.get_local_rank()
  44. @_copy_doc(session.get_local_world_size)
  45. def get_local_world_size(self) -> int:
  46. return session.get_local_world_size()
  47. @_copy_doc(session.get_node_rank)
  48. def get_node_rank(self) -> int:
  49. return session.get_node_rank()
  50. @DeveloperAPI
  51. @_copy_doc(session.get_storage)
  52. def get_storage(self) -> StorageContext:
  53. return session.get_storage()
  54. # Deprecated APIs
  55. @Deprecated(
  56. message=_GET_METADATA_DEPRECATION_MESSAGE,
  57. warning=_v2_migration_warnings_enabled(),
  58. )
  59. @_copy_doc(session.get_metadata)
  60. def get_metadata(self) -> Dict[str, Any]:
  61. return session.get_metadata()
  62. @Deprecated(
  63. message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_name"),
  64. warning=_v2_migration_warnings_enabled(),
  65. )
  66. @_copy_doc(session.get_trial_name)
  67. def get_trial_name(self) -> str:
  68. return session.get_trial_name()
  69. @Deprecated(
  70. message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_id"),
  71. warning=_v2_migration_warnings_enabled(),
  72. )
  73. @_copy_doc(session.get_trial_id)
  74. def get_trial_id(self) -> str:
  75. return session.get_trial_id()
  76. @Deprecated(
  77. message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format(
  78. "get_trial_resources"
  79. ),
  80. warning=_v2_migration_warnings_enabled(),
  81. )
  82. @_copy_doc(session.get_trial_resources)
  83. def get_trial_resources(self) -> "PlacementGroupFactory":
  84. return session.get_trial_resources()
  85. @Deprecated(
  86. message=_TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_dir"),
  87. warning=_v2_migration_warnings_enabled(),
  88. )
  89. @_copy_doc(session.get_trial_dir)
  90. def get_trial_dir(self) -> str:
  91. return session.get_trial_dir()
  92. @PublicAPI(stability="stable")
  93. def get_context() -> TrainContext:
  94. """Get or create a singleton training context.
  95. The context is only available within a function passed to Ray Train.
  96. See the :class:`~ray.train.TrainContext` API reference to see available methods.
  97. """
  98. from ray.tune.trainable.trainable_fn_utils import _in_tune_session
  99. # If we are running in a Tune function, switch to Tune context.
  100. if _in_tune_session():
  101. from ray.tune import get_context as get_tune_context
  102. if _v2_migration_warnings_enabled():
  103. _log_deprecation_warning(
  104. "`ray.train.get_context()` should be switched to "
  105. "`ray.tune.get_context()` when running in a function "
  106. "passed to Ray Tune. This will be an error in the future. "
  107. f"{V2_MIGRATION_GUIDE_MESSAGE}"
  108. )
  109. return get_tune_context()
  110. global _default_context
  111. with _context_lock:
  112. if _default_context is None:
  113. _default_context = TrainContext()
  114. return _default_context