accelerator_context.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. import importlib
  2. import threading
  3. from contextlib import nullcontext
  4. from typing import TYPE_CHECKING, ContextManager, List, Optional, Type
  5. import ray
  6. from ray._private.accelerators import get_accelerator_manager_for_resource
  7. from ray.experimental.channel.communicator import Communicator
  8. if TYPE_CHECKING:
  9. import torch
  10. # The accelerator context singleton on this process.
  11. _accelerator_context_lock = threading.Lock()
  12. _default_accelerator_context: Optional["AcceleratorContext"] = None
  13. _global_custom_context: Optional["AcceleratorContext"] = None
  14. class AcceleratorContext:
  15. """
  16. Provides a unified interface for managing different accelerator backends
  17. This includes stream management, event creation, device context control,
  18. and communicator support for distributed communication.
  19. """
  20. def __init__(self, torch_module_name: str, communicator_cls: Type[Communicator]):
  21. """
  22. Initializes an accelerator context with the specified torch device module
  23. and communicator class.
  24. Args:
  25. torch_module_name: Name of the torch device module (e.g., "cuda", "cpu").
  26. communicator_cls: Class used to handle communication.
  27. """
  28. # The name of the torch module (e.g., 'cuda', 'npu')
  29. self._torch_module_name: str = torch_module_name
  30. # The Communicator class used to manage communication
  31. self._communicator_cls: Type[Communicator] = communicator_cls
  32. # Import the torch backend module (e.g., torch.cuda) if the device is not 'cpu'.
  33. if torch_module_name != "cpu":
  34. self._torch_mod = importlib.import_module(f"torch.{torch_module_name}")
  35. @staticmethod
  36. def get() -> "AcceleratorContext":
  37. """
  38. Returns the singleton instance of the accelerator context.
  39. If a custom accelerator has been registered, initializes the context
  40. based on the registration. Otherwise, selects an appropriate runtime
  41. based on the available device (CUDA or CPU) and registers the
  42. corresponding default communicator.
  43. Returns:
  44. AcceleratorContext: A singleton instance of the appropriate
  45. runtime context.
  46. """
  47. global _default_accelerator_context, _global_custom_context
  48. with _accelerator_context_lock:
  49. if _global_custom_context is not None:
  50. return _global_custom_context
  51. if _default_accelerator_context is None:
  52. if len(ray.get_gpu_ids()) > 0:
  53. from ray.experimental.channel.nccl_group import _NcclGroup
  54. _default_accelerator_context = AcceleratorContext(
  55. "cuda", _NcclGroup
  56. )
  57. else:
  58. from ray.experimental.channel.cpu_communicator import (
  59. CPUCommunicator,
  60. )
  61. _default_accelerator_context = AcceleratorContext(
  62. "cpu", CPUCommunicator
  63. )
  64. return _default_accelerator_context
  65. @staticmethod
  66. def set(accelerator_context: "AcceleratorContext") -> None:
  67. """
  68. Overwrites the default accelerator context.
  69. Args:
  70. accelerator_context: The context to register.
  71. """
  72. global _global_custom_context
  73. # Accelerator context is registered.
  74. _global_custom_context = accelerator_context
  75. def get_accelerator_devices(self) -> List["torch.device"]:
  76. """
  77. Gets the torch device list configured for this process.
  78. Returns:
  79. List[torch.device]: The torch device list.
  80. """
  81. import torch
  82. if self._torch_module_name == "cpu":
  83. return [torch.device("cpu")]
  84. if self._torch_module_name == "cuda":
  85. accelerator_ids = [str(id) for id in ray.get_gpu_ids()]
  86. accelerator_manager = get_accelerator_manager_for_resource("GPU")
  87. else:
  88. accelerator_ids = [
  89. str(id)
  90. for id in ray.get_runtime_context().get_accelerator_ids()[
  91. self._torch_module_name.upper()
  92. ]
  93. ]
  94. accelerator_manager = get_accelerator_manager_for_resource(
  95. self._torch_module_name.upper()
  96. )
  97. device_ids = []
  98. if len(accelerator_ids) > 0:
  99. accelerator_visible_list = (
  100. accelerator_manager.get_current_process_visible_accelerator_ids()
  101. )
  102. if accelerator_visible_list is None:
  103. accelerator_visible_list = []
  104. # If there are multiple Accelerators, return a list of devices.
  105. # If using fractional Accelerators, these IDs are not guaranteed
  106. # to be unique across different processes.
  107. for accelerator_id in accelerator_ids:
  108. try:
  109. device_ids.append(accelerator_visible_list.index(accelerator_id))
  110. except ValueError:
  111. raise RuntimeError(
  112. f"{accelerator_manager.get_visible_accelerator_ids_env_var()} set incorrectly. "
  113. f"expected to include {accelerator_id}. "
  114. "Did you override this environment"
  115. " variable? If not, please help file an issue on Github."
  116. )
  117. else:
  118. # If called on the driver or outside of Ray Train, return the
  119. # 0th device.
  120. device_ids.append(0)
  121. return [
  122. torch.device(f"{self._torch_module_name}:{device_id}")
  123. for device_id in device_ids
  124. ]
  125. def get_device_context(self, device: "torch.device") -> ContextManager:
  126. """
  127. Retrieves the context manager for the specified accelerator device.
  128. There is no device context for CPU, returning a nullcontext.
  129. Args:
  130. device: The target device for which the context manager is required.
  131. Returns:
  132. ContextManager: A context manager specific to the device type.
  133. """
  134. if device.type == "cpu":
  135. return nullcontext()
  136. return self._torch_mod.device(device)
  137. def current_stream(self):
  138. """
  139. Retrieves the current execution stream for the accelerator device.
  140. """
  141. return self._torch_mod.current_stream()
  142. def create_event(self):
  143. """
  144. Creates an event object for the accelerator device.
  145. """
  146. return self._torch_mod.Event()
  147. def generate_communicator_id(self) -> str:
  148. """
  149. Generates a communication identifier for communication group.
  150. """
  151. return self._communicator_cls.generate_communicator_id()
  152. def create_communicator(self, *args, **kwargs) -> Communicator:
  153. """
  154. Creates a communication group for collective operations.
  155. """
  156. return self._communicator_cls(*args, **kwargs)
  157. @property
  158. def module_name(self) -> str:
  159. """
  160. Gets the name of the torch module backing the accelerator.
  161. """
  162. return self._torch_module_name
  163. @property
  164. def communicator_cls(self) -> Optional[Type[Communicator]]:
  165. """
  166. Returns the communicator class.
  167. """
  168. return self._communicator_cls
  169. @property
  170. def accelerator_count(self) -> int:
  171. """
  172. Returns the number of accelerators assigned by ray.
  173. """
  174. if self._torch_module_name == "cuda":
  175. return len(ray.get_gpu_ids())
  176. else:
  177. accelerator_ids = ray.get_runtime_context().get_accelerator_ids()
  178. return len(accelerator_ids.get(self._torch_module_name.upper(), []))
  179. def register_accelerator_context(
  180. torch_module_name: str, communicator_cls: Type[Communicator]
  181. ):
  182. """
  183. Registers the accelerator context with the specified device type and communicator.
  184. Args:
  185. torch_module_name: The name of the device module under torch.
  186. communicator_cls: The communicator class associated with the device.
  187. """
  188. accelerator_context = AcceleratorContext(torch_module_name, communicator_cls)
  189. AcceleratorContext.set(accelerator_context)
  190. def is_accelerator_context_registered():
  191. """
  192. Checks whether a custom accelerator context has been registered.
  193. Returns:
  194. bool: True if a custom accelerator context is registered
  195. (_global_custom_context is not None), False otherwise.
  196. """
  197. if _global_custom_context is not None:
  198. return True
  199. return False