context.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. import threading
  2. from typing import Any, Dict, Optional
  3. from ray.train._internal import session
  4. from ray.train.constants import (
  5. V2_MIGRATION_GUIDE_MESSAGE,
  6. _v2_migration_warnings_enabled,
  7. )
  8. from ray.train.context import TrainContext as TrainV1Context
  9. from ray.train.utils import _copy_doc
  10. from ray.tune.execution.placement_groups import PlacementGroupFactory
  11. from ray.util.annotations import Deprecated, PublicAPI
# The context singleton on this process. It starts out as None and is
# populated by `get_context()` the first time that function runs.
_tune_context: Optional["TuneContext"] = None
# Lock held by `get_context()` while it checks and assigns `_tune_context`.
_tune_context_lock = threading.Lock()

# Message template for the rank / world-size getters that are deprecated on
# `TuneContext`. The `{}` placeholder is filled in with the method name via
# `str.format`; the migration guide text is appended once, at import time.
_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE = (
    "`{}` is deprecated for Ray Tune because there is no concept of worker ranks "
    "for Ray Tune, so these methods only make sense to use in the context of "
    f"a Ray Train worker. {V2_MIGRATION_GUIDE_MESSAGE}"
)
  20. @PublicAPI(stability="beta")
  21. class TuneContext(TrainV1Context):
  22. """Context to access metadata within Ray Tune functions."""
  23. # NOTE: These methods are deprecated on the TrainContext, but are still
  24. # available on the TuneContext. Re-defining them here to avoid the
  25. # deprecation warnings.
  26. @_copy_doc(session.get_trial_name)
  27. def get_trial_name(self) -> str:
  28. return session.get_trial_name()
  29. @_copy_doc(session.get_trial_id)
  30. def get_trial_id(self) -> str:
  31. return session.get_trial_id()
  32. @_copy_doc(session.get_trial_resources)
  33. def get_trial_resources(self) -> PlacementGroupFactory:
  34. return session.get_trial_resources()
  35. @_copy_doc(session.get_trial_dir)
  36. def get_trial_dir(self) -> str:
  37. return session.get_trial_dir()
  38. # Deprecated APIs
  39. @Deprecated
  40. def get_metadata(self) -> Dict[str, Any]:
  41. raise DeprecationWarning(
  42. "`get_metadata` is deprecated for Ray Tune, as it has never been usable."
  43. )
  44. @Deprecated(
  45. message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_world_size"),
  46. warning=_v2_migration_warnings_enabled(),
  47. )
  48. @_copy_doc(TrainV1Context.get_world_size)
  49. def get_world_size(self) -> int:
  50. return session.get_world_size()
  51. @Deprecated(
  52. message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_world_rank"),
  53. warning=_v2_migration_warnings_enabled(),
  54. )
  55. @_copy_doc(TrainV1Context.get_world_rank)
  56. def get_world_rank(self) -> int:
  57. return session.get_world_rank()
  58. @Deprecated(
  59. message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_local_rank"),
  60. warning=_v2_migration_warnings_enabled(),
  61. )
  62. @_copy_doc(TrainV1Context.get_local_rank)
  63. def get_local_rank(self) -> int:
  64. return session.get_local_rank()
  65. @Deprecated(
  66. message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format(
  67. "get_local_world_size"
  68. ),
  69. warning=_v2_migration_warnings_enabled(),
  70. )
  71. @_copy_doc(TrainV1Context.get_local_world_size)
  72. def get_local_world_size(self) -> int:
  73. return session.get_local_world_size()
  74. @Deprecated(
  75. message=_TRAIN_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_node_rank"),
  76. warning=_v2_migration_warnings_enabled(),
  77. )
  78. @_copy_doc(TrainV1Context.get_node_rank)
  79. def get_node_rank(self) -> int:
  80. return session.get_node_rank()
  81. @PublicAPI(stability="beta")
  82. def get_context() -> TuneContext:
  83. """Get or create a singleton Ray Tune context.
  84. The context is only available in a tune function passed to the `ray.tune.Tuner`.
  85. See the :class:`~ray.tune.TuneContext` API reference to see available methods.
  86. """
  87. global _tune_context
  88. with _tune_context_lock:
  89. if _tune_context is None:
  90. # TODO(justinvyu): This default should be a dummy context
  91. # that is only used for testing / running outside of Tune.
  92. _tune_context = TuneContext()
  93. return _tune_context