| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
- import threading
- import uuid
- from typing import Dict, List, Optional, Union
- import ray
- import ray.experimental.internal_kv as internal_kv
- from ray.experimental.collective.communicator import CommunicatorHandle
- from ray.util.annotations import PublicAPI
- from ray.util.collective.collective import get_address_and_port
- from ray.util.collective.collective_group.torch_gloo_collective_group import (
- get_master_address_metadata_key,
- )
- from ray.util.collective.types import Backend
# Lazily created process-wide singleton; stays None until
# RemoteCommunicatorManager.get() is called for the first time.
_remote_communicator_manager: "Optional[RemoteCommunicatorManager]" = None
# Held while the singleton is looked up / created in get().
_remote_communicator_manager_lock = threading.Lock()


class RemoteCommunicatorManager:
    """Singleton registry of the communicators created through this module.

    Maps each communicator name to its handle so that callers can look up
    which collective groups a set of actors belongs to.
    """

    def __init__(self):
        # Handles to communicators that we created. Key is a user-provided
        # name or UUID.
        self._remote_communicators: Dict[str, CommunicatorHandle] = {}

    @staticmethod
    def get() -> "RemoteCommunicatorManager":
        """Return the shared manager, creating it on first use."""
        global _remote_communicator_manager
        with _remote_communicator_manager_lock:
            if _remote_communicator_manager is None:
                _remote_communicator_manager = RemoteCommunicatorManager()
            return _remote_communicator_manager

    def add_remote_communicator(self, comm_handle: CommunicatorHandle):
        """Register ``comm_handle`` under its ``name``.

        An existing entry with the same name is replaced.
        """
        self._remote_communicators[comm_handle.name] = comm_handle

    def remove_remote_communicator(self, name: str):
        """Unregister and return the handle called ``name``.

        Returns:
            The removed handle, or None if no such name is registered.
        """
        return self._remote_communicators.pop(name, None)

    def get_collective_groups(
        self,
        actors: Optional[List[ray.actor.ActorHandle]] = None,
        backend: Optional[Backend] = None,
    ):
        """
        Get the collective groups that the given actors are a subset of. Filter by
        backend if provided.

        Args:
            actors: Actors that must all be members of a returned group. None
                or an empty list matches every registered group.
            backend: If given, only groups using this backend are returned.

        Returns:
            A new list of matching handles, in registration order.
        """
        required = set(actors or [])
        # Compare backends first: it is a cheap equality test, whereas the
        # membership test has to hash every actor of the candidate group.
        return [
            handle
            for handle in self._remote_communicators.values()
            if (backend is None or handle.backend == backend)
            and required.issubset(handle.actors)
        ]
@PublicAPI(stability="alpha")
def get_collective_groups(
    actors: List[ray.actor.ActorHandle], backend: Optional[str] = None
) -> List[CommunicatorHandle]:
    """Look up the collective groups that contain all of the given actors.

    Args:
        actors: List of actors. Handles are returned for every collective
            group that these actors are a subset of.
        backend: An optional backend to filter by. See
            ray.util.collective.types.Backend for valid backends.

    Returns:
        A list of communicator handles that the actors are a subset of.
    """
    manager = RemoteCommunicatorManager.get()
    # Convert the user-facing string into a Backend only when a filter was
    # actually requested; None means "any backend".
    backend_filter = None if backend is None else Backend(backend)
    return manager.get_collective_groups(actors, backend_filter)
@PublicAPI(stability="alpha")
def create_collective_group(
    actors: List[ray.actor.ActorHandle],
    backend: str,
    name: Optional[str] = None,
) -> CommunicatorHandle:
    """Create a collective group on the given list of actors.

    If this function returns successfully, then the collective group has been
    initialized on all actors, using the given order of actors as the ranks.

    Currently, an actor can only participate in one collective group per
    backend at a time. To reuse an actor, destroy its collective group and
    create a new one.

    Args:
        actors: The actors to participate in the collective group. Must be
            non-empty and must not contain the same actor twice.
        backend: The backend to use. See ray.util.collective.types.Backend for
            valid backends.
        name: A name to use for the collective group. If None is provided, a
            random name will be generated.

    Returns:
        Handle to the communicator.

    Raises:
        ValueError: If ``actors`` is empty or contains duplicates, or if a
            group called ``name`` is already tracked by this process.
        RuntimeError: If one of the actors already belongs to a group that
            uses ``backend``.
    """
    manager = RemoteCommunicatorManager.get()

    if name is None:
        name = str(uuid.uuid4())

    # Validate the backend.
    backend = Backend(backend)

    # An empty group is meaningless, and the GLOO setup below indexes
    # actors[0]; fail with a clear message instead of an IndexError.
    if not actors:
        raise ValueError(
            "At least one actor is required to create a collective group."
        )

    # Registering a second group under an existing name would overwrite the
    # first handle in the manager, leaving that group impossible to destroy.
    if any(existing.name == name for existing in manager.get_collective_groups()):
        raise ValueError(
            f"A collective group named {name!r} already exists. Destroy it "
            "first or choose a different name."
        )

    world_size = len(actors)

    for actor in actors:
        if manager.get_collective_groups([actor], backend):
            raise RuntimeError(
                f"Actor {actor} already in group for backend {backend}. "
                "Actors can currently only participate in at most one group "
                "per backend."
            )

    actor_ids = [actor._ray_actor_id for actor in actors]
    if len(set(actor_ids)) != len(actor_ids):
        raise ValueError(f"All actors must be unique, got: {actors}")

    metadata_key = None
    if backend == Backend.GLOO:
        # Perform extra setup for torch.distributed.
        # torch.distributed requires a master address and port. Find a suitable
        # port on one of the actors.
        master_addr, master_port = ray.get(
            actors[0].__ray_call__.remote(lambda self: get_address_and_port())
        )

        # Publish "<addr>:<port>" in Ray's internal KV store under a key
        # derived from the group name.
        metadata_key = get_master_address_metadata_key(name)
        internal_kv._internal_kv_put(metadata_key, f"{master_addr}:{master_port}")

    def _do_init_collective_group(self, rank: int):
        # Runs on each actor via __ray_call__; `self` is the actor instance.
        ray.util.collective.init_collective_group(
            world_size, rank, backend, group_name=name
        )

    try:
        # Position in `actors` determines the rank.
        init_tasks = [
            actor.__ray_call__.remote(
                _do_init_collective_group,
                rank,
            )
            for rank, actor in enumerate(actors)
        ]
        ray.get(init_tasks)
    finally:
        # Clean up the metadata once collective group is initialized
        # (or failed to initialize).
        if metadata_key is not None:
            internal_kv._internal_kv_del(metadata_key)

    # Group was successfully created.
    comm = CommunicatorHandle(actors, name, backend)
    manager.add_remote_communicator(comm)
    return comm
@PublicAPI(stability="alpha")
def destroy_collective_group(group_or_name: Union[CommunicatorHandle, str]):
    """Destroy a collective group.

    If this function returns successfully, then the actors that were in the
    collective can be reused to create a new collective group.

    Args:
        group_or_name: Either a communicator handle or the name of the group to
            destroy.

    Raises:
        ValueError: If ``group_or_name`` is neither a CommunicatorHandle nor a
            str, or if no group with that name is tracked.
    """
    if not isinstance(group_or_name, (CommunicatorHandle, str)):
        raise ValueError("Expected CommunicatorHandle or str (group name).")

    if isinstance(group_or_name, CommunicatorHandle):
        name = group_or_name.name
    else:
        name = group_or_name

    # The group is dropped from the registry before the actors are contacted.
    group = RemoteCommunicatorManager.get().remove_remote_communicator(name)
    if group is None:
        raise ValueError(f"No group with name {name} found.")

    def _do_destroy_collective_group(self):
        # Runs on each member actor via __ray_call__.
        ray.util.collective.destroy_collective_group(name)

    destroy_tasks = []
    for actor in group.actors:
        system_call = actor.__ray_call__.options(concurrency_group="_ray_system")
        destroy_tasks.append(system_call.remote(_do_destroy_collective_group))

    try:
        ray.get(destroy_tasks)
    except ray.exceptions.ActorDiedError:
        # Best effort: a member that has already died is ignored.
        pass
@PublicAPI(stability="alpha")
def destroy_all_collective_groups():
    """Destroy every collective group previously created by this process.

    After this function returns, the actors participating in those collective
    groups can be reused to create a new collective group.
    """
    # With no arguments the manager returns a list of all tracked groups.
    tracked_groups = RemoteCommunicatorManager.get().get_collective_groups()
    for handle in tracked_groups:
        destroy_collective_group(handle.name)
|