context.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from abc import ABC, abstractmethod
  2. from typing import Any, Dict
  3. from ray.train.v2._internal.execution.context import (
  4. get_train_context as get_internal_train_context,
  5. )
  6. from ray.util.annotations import Deprecated, DeveloperAPI, PublicAPI
  7. @PublicAPI(stability="stable")
  8. class TrainContext(ABC):
  9. """Abstract interface for training context."""
  10. @Deprecated
  11. def get_metadata(self) -> Dict[str, Any]:
  12. """[Deprecated] User metadata dict passed to the Trainer constructor."""
  13. from ray.train.context import _GET_METADATA_DEPRECATION_MESSAGE
  14. raise DeprecationWarning(_GET_METADATA_DEPRECATION_MESSAGE)
  15. @Deprecated
  16. def get_trial_name(self) -> str:
  17. """[Deprecated] Trial name for the corresponding trial."""
  18. from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE
  19. raise DeprecationWarning(
  20. _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_name")
  21. )
  22. @Deprecated
  23. def get_trial_id(self) -> str:
  24. """[Deprecated] Trial id for the corresponding trial."""
  25. from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE
  26. raise DeprecationWarning(
  27. _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_id")
  28. )
  29. @Deprecated
  30. def get_trial_resources(self):
  31. """[Deprecated] Trial resources for the corresponding trial."""
  32. from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE
  33. raise DeprecationWarning(
  34. _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_resources")
  35. )
  36. @Deprecated
  37. def get_trial_dir(self) -> str:
  38. """[Deprecated] Log directory corresponding to the trial directory for a Tune session.
  39. This is deprecated for Ray Train and should no longer be called in Ray Train workers.
  40. If this directory is needed, please pass it into the `train_loop_config` directly.
  41. """
  42. from ray.train.context import _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE
  43. raise DeprecationWarning(
  44. _TUNE_SPECIFIC_CONTEXT_DEPRECATION_MESSAGE.format("get_trial_dir")
  45. )
  46. @abstractmethod
  47. def get_experiment_name(self) -> str:
  48. """Experiment name for the corresponding trial."""
  49. pass
  50. @abstractmethod
  51. def get_world_size(self) -> int:
  52. """Get the current world size (i.e. total number of workers) for this run.
  53. .. testcode::
  54. import ray.train
  55. from ray.train.torch import TorchTrainer
  56. NUM_WORKERS = 2
  57. def train_fn_per_worker(config):
  58. assert ray.train.get_context().get_world_size() == NUM_WORKERS
  59. trainer = TorchTrainer(
  60. train_fn_per_worker,
  61. scaling_config=ray.train.ScalingConfig(num_workers=NUM_WORKERS),
  62. )
  63. trainer.fit()
  64. """
  65. pass
  66. @abstractmethod
  67. def get_world_rank(self) -> int:
  68. """Get the world rank of this worker.
  69. .. testcode::
  70. import ray.train
  71. from ray.train.torch import TorchTrainer
  72. def train_fn_per_worker(config):
  73. if ray.train.get_context().get_world_rank() == 0:
  74. print("Worker 0")
  75. trainer = TorchTrainer(
  76. train_fn_per_worker,
  77. scaling_config=ray.train.ScalingConfig(num_workers=2),
  78. )
  79. trainer.fit()
  80. """
  81. pass
  82. @abstractmethod
  83. def get_local_rank(self) -> int:
  84. """Get the local rank of this worker (rank of the worker on its node).
  85. .. testcode::
  86. import ray.train
  87. from ray.train.torch import TorchTrainer
  88. def train_fn_per_worker(config):
  89. if ray.train.get_context().get_local_rank() == 0:
  90. print("Local rank 0 worker")
  91. trainer = TorchTrainer(
  92. train_fn_per_worker,
  93. scaling_config=ray.train.ScalingConfig(num_workers=2),
  94. )
  95. trainer.fit()
  96. """
  97. pass
  98. @abstractmethod
  99. def get_local_world_size(self) -> int:
  100. """Get the local world size of this node (i.e. number of workers on this node).
  101. Example:
  102. .. testcode::
  103. import ray.train
  104. from ray.train.torch import TorchTrainer
  105. def train_fn_per_worker():
  106. print(ray.train.get_context().get_local_world_size())
  107. trainer = TorchTrainer(
  108. train_fn_per_worker,
  109. scaling_config=ray.train.ScalingConfig(num_workers=2),
  110. )
  111. trainer.fit()
  112. """
  113. pass
  114. @abstractmethod
  115. def get_node_rank(self) -> int:
  116. """Get the rank of this node.
  117. Example:
  118. .. testcode::
  119. import ray.train
  120. from ray.train.torch import TorchTrainer
  121. def train_fn_per_worker():
  122. print(ray.train.get_context().get_node_rank())
  123. trainer = TorchTrainer(
  124. train_fn_per_worker,
  125. scaling_config=ray.train.ScalingConfig(num_workers=1),
  126. )
  127. trainer.fit()
  128. """
  129. pass
  130. @DeveloperAPI
  131. @abstractmethod
  132. def get_storage(self):
  133. """Returns the :class:`~ray.train._internal.storage.StorageContext` storage
  134. context which gives advanced access to the filesystem and paths
  135. configured through `RunConfig`.
  136. NOTE: This is a DeveloperAPI, and the `StorageContext` interface may change
  137. without notice between minor versions.
  138. """
  139. pass
  140. @DeveloperAPI
  141. class DistributedTrainContext(TrainContext):
  142. """Implementation of TrainContext for distributed mode."""
  143. def get_experiment_name(self) -> str:
  144. return get_internal_train_context().get_experiment_name()
  145. def get_world_size(self) -> int:
  146. return get_internal_train_context().get_world_size()
  147. def get_world_rank(self) -> int:
  148. return get_internal_train_context().get_world_rank()
  149. def get_local_rank(self) -> int:
  150. return get_internal_train_context().get_local_rank()
  151. def get_local_world_size(self) -> int:
  152. return get_internal_train_context().get_local_world_size()
  153. def get_node_rank(self) -> int:
  154. return get_internal_train_context().get_node_rank()
  155. def get_storage(self):
  156. return get_internal_train_context().get_storage()
  157. @DeveloperAPI
  158. class LocalTrainContext(TrainContext):
  159. """Implementation of TrainContext for local mode."""
  160. def __init__(
  161. self,
  162. experiment_name: str,
  163. world_size: int = 1,
  164. world_rank: int = 0,
  165. local_rank: int = 0,
  166. local_world_size: int = 1,
  167. node_rank: int = 0,
  168. ):
  169. self.experiment_name = experiment_name
  170. self.world_size = world_size
  171. self.world_rank = world_rank
  172. self.local_rank = local_rank
  173. self.local_world_size = local_world_size
  174. self.node_rank = node_rank
  175. def get_experiment_name(self) -> str:
  176. return self.experiment_name
  177. def get_world_size(self) -> int:
  178. return self.world_size
  179. def get_world_rank(self) -> int:
  180. return self.world_rank
  181. def get_local_rank(self) -> int:
  182. return self.local_rank
  183. def get_local_world_size(self) -> int:
  184. return self.local_world_size
  185. def get_node_rank(self) -> int:
  186. return self.node_rank
  187. def get_storage(self):
  188. raise NotImplementedError("Local storage context not yet implemented. ")