collective.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. import threading
  2. import uuid
  3. from typing import Dict, List, Optional, Union
  4. import ray
  5. import ray.experimental.internal_kv as internal_kv
  6. from ray.experimental.collective.communicator import CommunicatorHandle
  7. from ray.util.annotations import PublicAPI
  8. from ray.util.collective.collective import get_address_and_port
  9. from ray.util.collective.collective_group.torch_gloo_collective_group import (
  10. get_master_address_metadata_key,
  11. )
  12. from ray.util.collective.types import Backend
# Process-wide RemoteCommunicatorManager instance. Starts as None and is
# created on the first call to RemoteCommunicatorManager.get().
_remote_communicator_manager: "Optional[RemoteCommunicatorManager]" = None
# Held by RemoteCommunicatorManager.get() while it checks for, and if
# necessary creates, the instance above.
_remote_communicator_manager_lock = threading.Lock()
  15. class RemoteCommunicatorManager:
  16. """Singleton class to store the mapping between actors and communicators
  17. that the actors are a part of.
  18. """
  19. def __init__(self):
  20. # Handles to communicators that we created. Key is a user-provided
  21. # name or UUID.
  22. self._remote_communicators: Dict[str, CommunicatorHandle] = {}
  23. @staticmethod
  24. def get() -> "RemoteCommunicatorManager":
  25. global _remote_communicator_manager
  26. with _remote_communicator_manager_lock:
  27. if _remote_communicator_manager is None:
  28. _remote_communicator_manager = RemoteCommunicatorManager()
  29. return _remote_communicator_manager
  30. def add_remote_communicator(self, comm_handle: CommunicatorHandle):
  31. self._remote_communicators[comm_handle.name] = comm_handle
  32. def remove_remote_communicator(self, name: str):
  33. return self._remote_communicators.pop(name, None)
  34. def get_collective_groups(
  35. self,
  36. actors: Optional[List[ray.actor.ActorHandle]] = None,
  37. backend: Optional[Backend] = None,
  38. ):
  39. """
  40. Get the collective groups that the given actors are a subset of. Filter by
  41. backend if provided.
  42. """
  43. actors = actors or []
  44. actors = set(actors)
  45. collectives = []
  46. # Find all collective groups that the given actors are a subset
  47. # of, with the matching backend if provided.
  48. for collective in self._remote_communicators.values():
  49. if actors.issubset(set(collective.actors)):
  50. if backend is None or collective.backend == backend:
  51. collectives.append(collective)
  52. return collectives
  53. @PublicAPI(stability="alpha")
  54. def get_collective_groups(
  55. actors: List[ray.actor.ActorHandle], backend: Optional[str] = None
  56. ) -> List[CommunicatorHandle]:
  57. """
  58. Get the collective groups that the given actors are a subset of. Filter by
  59. backend if provided.
  60. Args:
  61. actors: List of actors. Return handles to all collective groups that
  62. these actors are a subset of.
  63. backend: An optional backend to filter by. See
  64. ray.util.collective.types.Backend for valid backends.
  65. Returns:
  66. A list of communicator handles that the actors are a subset of.
  67. """
  68. manager = RemoteCommunicatorManager.get()
  69. backend = Backend(backend) if backend is not None else None
  70. return manager.get_collective_groups(actors, backend)
  71. @PublicAPI(stability="alpha")
  72. def create_collective_group(
  73. actors: List[ray.actor.ActorHandle],
  74. backend: str,
  75. name: Optional[str] = None,
  76. ) -> CommunicatorHandle:
  77. """Create a collective group on the given list of actors. If this function
  78. returns successfully, then the collective group has been initialized on all
  79. actors, using the given order of actors as the ranks.
  80. Currently, an actor can only participate in one collective group per
  81. backend at a time. To reuse an actor, destroy its collective group and
  82. create a new one.
  83. Args:
  84. actors: The actors to participate in the collective group.
  85. backend: The backend to use. See ray.util.collective.types.Backend for
  86. valid backends.
  87. name: A name to use for the collective group. If None is provided, a
  88. random name will be generated.
  89. Returns:
  90. Handle to the communicator.
  91. """
  92. manager = RemoteCommunicatorManager.get()
  93. if name is None:
  94. name = str(uuid.uuid4())
  95. # Validate the backend.
  96. backend = Backend(backend)
  97. world_size = len(actors)
  98. for actor in actors:
  99. if manager.get_collective_groups([actor], backend):
  100. raise RuntimeError(
  101. f"Actor {actor} already in group for backend {backend}. Actors can currently only participate in at most one group per backend."
  102. )
  103. actor_ids = [actor._ray_actor_id for actor in actors]
  104. if len(set(actor_ids)) != len(actor_ids):
  105. raise ValueError(f"All actors must be unique, got: {actors}")
  106. metadata_key = None
  107. if backend == Backend.GLOO:
  108. # Perform extra setup for torch.distributed.
  109. # torch.distributed requires a master address and port. Find a suitable
  110. # port on one of the actors.
  111. master_addr, master_port = ray.get(
  112. actors[0].__ray_call__.remote(lambda self: get_address_and_port())
  113. )
  114. # Store the metadata on a named actor that all of the other
  115. # actors can access.
  116. metadata_key = get_master_address_metadata_key(name)
  117. internal_kv._internal_kv_put(metadata_key, f"{master_addr}:{master_port}")
  118. def _do_init_collective_group(self, rank: int):
  119. ray.util.collective.init_collective_group(
  120. world_size, rank, backend, group_name=name
  121. )
  122. try:
  123. init_tasks = [
  124. actor.__ray_call__.remote(
  125. _do_init_collective_group,
  126. rank,
  127. )
  128. for rank, actor in enumerate(actors)
  129. ]
  130. ray.get(init_tasks)
  131. finally:
  132. # Clean up the metadata once collective group is initialized
  133. # (or failed to initialize).
  134. if metadata_key is not None:
  135. internal_kv._internal_kv_del(metadata_key)
  136. # Group was successfully created.
  137. comm = CommunicatorHandle(actors, name, backend)
  138. manager.add_remote_communicator(comm)
  139. return comm
  140. @PublicAPI(stability="alpha")
  141. def destroy_collective_group(group_or_name: Union[CommunicatorHandle, str]):
  142. """
  143. Destroy a collective group. If this functions returns successfully, then
  144. the actors that were in the collective can be reused to create a new
  145. collective group.
  146. Args:
  147. group_or_name: Either a communicator handle or the name of the group to
  148. destroy.
  149. """
  150. if isinstance(group_or_name, CommunicatorHandle):
  151. name = group_or_name.name
  152. elif isinstance(group_or_name, str):
  153. name = group_or_name
  154. else:
  155. raise ValueError("Expected CommunicatorHandle or str (group name).")
  156. manager = RemoteCommunicatorManager.get()
  157. group = manager.remove_remote_communicator(name)
  158. if group is not None:
  159. def _do_destroy_collective_group(self):
  160. ray.util.collective.destroy_collective_group(name)
  161. destroy_tasks = [
  162. actor.__ray_call__.options(concurrency_group="_ray_system").remote(
  163. _do_destroy_collective_group
  164. )
  165. for actor in group.actors
  166. ]
  167. try:
  168. ray.get(destroy_tasks)
  169. except ray.exceptions.ActorDiedError:
  170. pass
  171. else:
  172. raise ValueError(f"No group with name {name} found.")
  173. @PublicAPI(stability="alpha")
  174. def destroy_all_collective_groups():
  175. """
  176. Destroy all collective groups. This will destroy all collective groups that
  177. were previously created by this process. After this function returns, the
  178. actors participating in those collective groups can be reused to create a
  179. new collective group.
  180. """
  181. manager = RemoteCommunicatorManager.get()
  182. for collective in manager.get_collective_groups():
  183. destroy_collective_group(collective.name)