| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243 |
- import uuid
- from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Type
- import torch
- import ray
- from ray.experimental.channel.common import ChannelContext
- from ray.experimental.channel.communicator import (
- Communicator,
- ReduceOp,
- TorchTensorAllocator,
- )
class AbstractNcclGroup(Communicator):
    """
    A dummy NCCL group for testing.

    It only tracks group membership (the actor handles) and the rank of the
    current process. Every data-movement operation (send/recv and the
    collectives) raises ``NotImplementedError``; streams are ``None`` and
    ``destroy`` is a no-op.
    """

    def __init__(self, actor_handles: List[ray.actor.ActorHandle]):
        # An actor's rank is its position in this list (see ``get_rank``).
        self._actor_handles = actor_handles
        # Filled in by ``initialize``; stays ``None`` until then.
        self._rank = None

    def initialize(self, rank: int) -> None:
        """Record the rank of the current process within the group."""
        self._rank = rank

    def get_rank(self, actor: ray.actor.ActorHandle) -> int:
        """Return ``actor``'s index in the group's actor list.

        Raises:
            ValueError: if ``actor`` is not a member of the group.
        """
        return self._actor_handles.index(actor)

    def get_world_size(self) -> int:
        """Return the number of actors in the group."""
        return len(self._actor_handles)

    def get_self_rank(self) -> Optional[int]:
        """Return the rank passed to ``initialize``, or ``None`` if not yet initialized."""
        return self._rank

    def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
        """Return the actor handles that make up the group."""
        return self._actor_handles

    def send(self, value: "torch.Tensor", peer_rank: int) -> None:
        raise NotImplementedError

    def recv(
        self,
        shape: Tuple[int],
        dtype: "torch.dtype",
        peer_rank: int,
        allocator: Optional[TorchTensorAllocator] = None,
    ) -> "torch.Tensor":
        raise NotImplementedError

    def allgather(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
    ) -> None:
        raise NotImplementedError

    def allreduce(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ) -> None:
        raise NotImplementedError

    def reducescatter(
        self,
        send_buf: "torch.Tensor",
        recv_buf: "torch.Tensor",
        op: ReduceOp = ReduceOp.SUM,
    ) -> None:
        raise NotImplementedError

    @property
    def recv_stream(self):
        # The mock has no communication streams.
        return None

    @property
    def send_stream(self):
        # The mock has no communication streams.
        return None

    def destroy(self) -> None:
        # Nothing to release in the mock.
        pass

    def get_transport_name(self) -> str:
        return "accelerator"

    @classmethod
    def generate_communicator_id(cls) -> str:
        """Return a fresh, unique communicator ID.

        The previous stub fell through with ``pass`` and returned ``None``
        despite the ``str`` annotation; a random UUID string honours the
        declared contract.
        """
        return str(uuid.uuid4())
class MockNcclGroupSet:
    """Test double for creating and destroying NCCL communicator groups.

    An instance is monkeypatched in as ``_init_communicator`` (via
    ``__call__``), and its bound ``mock_destroy_nccl_group`` as
    ``_destroy_communicator``. It records every group it creates so tests can
    assert which actor sets were grouped, and later that they were torn down.
    """

    def __init__(self):
        # Maps a NCCL group ID to (frozenset of member actors, the custom
        # communicator that was passed in, or None if the default mock group
        # was used).
        self.ids_to_actors_and_custom_comms: Dict[
            str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
        ] = {}

    def __call__(
        self,
        actors: List["ray.actor.ActorHandle"],
        custom_nccl_group: Optional[Communicator] = None,
        use_communication_streams: bool = False,
        accelerator_module_name: Optional[str] = None,
        accelerator_communicator_cls: Optional[Type[Communicator]] = None,
    ) -> str:
        """Create a mock NCCL group over ``actors`` and return its new ID.

        ``use_communication_streams``, ``accelerator_module_name`` and
        ``accelerator_communicator_cls`` are accepted (presumably to mirror
        the patched function's signature) but ignored.

        Each actor registers a communicator in its own ``ChannelContext`` at
        its rank: its index in ``actors``, or ``custom_nccl_group.get_rank``
        when a custom group is given. The driver registers
        ``custom_nccl_group`` itself, or an uninitialized
        ``AbstractNcclGroup`` (whose ``get_self_rank()`` is ``None``).
        """
        group_id = str(uuid.uuid4())
        self.ids_to_actors_and_custom_comms[group_id] = (
            frozenset(actors),
            custom_nccl_group,
        )

        if custom_nccl_group is None:
            ranks = list(range(len(actors)))
        else:
            ranks = [custom_nccl_group.get_rank(actor) for actor in actors]

        init_tasks = [
            actor.__ray_call__.remote(
                mock_do_init_nccl_group,
                group_id,
                rank,
                actors,
                custom_nccl_group,
            )
            for rank, actor in zip(ranks, actors)
        ]
        # Block (up to 30 s) until every actor has registered its communicator.
        ray.get(init_tasks, timeout=30)

        ctx = ChannelContext.get_current()
        if custom_nccl_group is not None:
            ctx.communicators[group_id] = custom_nccl_group
        else:
            ctx.communicators[group_id] = AbstractNcclGroup(actors)
        return group_id

    def mock_destroy_nccl_group(self, group_id: str) -> None:
        """Tear down ``group_id`` on every member actor and on the driver.

        Does nothing if the driver has no communicator registered under
        ``group_id``. Actor-side teardown is best-effort: the call waits up
        to 30 s for all actors but does not raise on timeout.
        """
        ctx = ChannelContext.get_current()
        if group_id not in ctx.communicators:
            return

        # pop() both fetches and forgets the record; tolerate a group that
        # this instance did not create instead of raising KeyError.
        entry = self.ids_to_actors_and_custom_comms.pop(group_id, None)
        if entry is not None:
            actors, _ = entry
            destroy_tasks = [
                actor.__ray_call__.remote(
                    mock_do_destroy_nccl_group,
                    group_id,
                )
                for actor in actors
            ]
            if destroy_tasks:
                # ray.wait returns once `num_returns` refs are ready (default
                # 1); wait for all of them so every actor has torn down.
                ray.wait(
                    destroy_tasks,
                    num_returns=len(destroy_tasks),
                    timeout=30,
                )

        ctx.communicators[group_id].destroy()
        del ctx.communicators[group_id]

    def check_teardown(self, nccl_group_ids: List[str]) -> None:
        """Assert that none of ``nccl_group_ids`` is still tracked here or
        registered in the driver's ``ChannelContext``."""
        ctx = ChannelContext.get_current()
        for nccl_group_id in nccl_group_ids:
            assert nccl_group_id not in self.ids_to_actors_and_custom_comms
            assert nccl_group_id not in ctx.communicators
@ray.remote
class CPUTorchTensorWorker:
    """Ray actor that creates and receives torch tensors on the CPU."""

    def __init__(self):
        # Device name used for every tensor this worker creates or expects.
        self.device = "cpu"

    def return_tensor(
        self, size: int, dtype: Optional[torch.dtype] = None
    ) -> torch.Tensor:
        """Return a tensor of ones with shape ``size`` on this worker's device."""
        return torch.ones(size, dtype=dtype, device=self.device)

    def recv(self, tensor: torch.Tensor) -> Tuple[torch.Size, torch.Tensor]:
        """Check that ``tensor`` is on this worker's device.

        Returns:
            ``(tensor.shape, tensor[0])``. For a 1-D tensor the second item
            is a 0-d tensor holding the first element.
        """
        # Compare torch.device to torch.device rather than relying on
        # torch.device-vs-str equality semantics.
        assert tensor.device == torch.device(self.device)
        return tensor.shape, tensor[0]

    def recv_tensors(self, *tensors) -> Tuple[torch.Tensor, ...]:
        """Return the received tensors unchanged, as a tuple."""
        return tuple(tensors)
def mock_do_init_nccl_group(
    self,
    group_id: str,
    rank: int,
    actors: List[ray.actor.ActorHandle],
    custom_nccl_group: Optional[Communicator],
) -> None:
    """Actor-side setup step for a mock NCCL group.

    Invoked through ``actor.__ray_call__``, so ``self`` is the actor
    instance (unused here). Initializes either ``custom_nccl_group`` or a
    fresh ``AbstractNcclGroup(actors)`` at ``rank``, then registers it in
    this process's ``ChannelContext`` under ``group_id``.
    """
    context = ChannelContext.get_current()
    communicator = (
        AbstractNcclGroup(actors)
        if custom_nccl_group is None
        else custom_nccl_group
    )
    communicator.initialize(rank)
    context.communicators[group_id] = communicator
def mock_do_destroy_nccl_group(self, group_id: str) -> None:
    """Actor-side teardown step for a mock NCCL group.

    Invoked through ``actor.__ray_call__`` (``self`` is the actor instance,
    unused). If ``group_id`` is registered in this process's
    ``ChannelContext``, destroys that communicator and unregisters it;
    otherwise does nothing.
    """
    registered = ChannelContext.get_current().communicators
    if group_id in registered:
        registered[group_id].destroy()
        del registered[group_id]
def check_nccl_group_init(
    monkeypatch,
    dag: "ray.dag.DAGNode",
    actors_and_custom_comms: Set[
        Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
    ],
) -> Tuple["ray.dag.CompiledDAG", "MockNcclGroupSet"]:
    """Compile ``dag`` with NCCL group creation mocked out and check the groups.

    Replaces ``ray.dag.compiled_dag_node._init_communicator`` with a fresh
    ``MockNcclGroupSet``, compiles the DAG, and asserts that the recorded
    ``(actor set, custom communicator)`` pairs equal
    ``actors_and_custom_comms`` exactly.

    Returns:
        ``(compiled_dag, mock_nccl_group_set)``. The mock is returned so the
        caller can later verify teardown with ``check_nccl_group_teardown``.
        (The previous annotation claimed a bare ``CompiledDAG``.)
    """
    mock_nccl_group_set = MockNcclGroupSet()
    monkeypatch.setattr(
        "ray.dag.compiled_dag_node._init_communicator",
        mock_nccl_group_set,
    )

    compiled_dag = dag.experimental_compile()
    assert (
        set(mock_nccl_group_set.ids_to_actors_and_custom_comms.values())
        == actors_and_custom_comms
    )

    return compiled_dag, mock_nccl_group_set
def check_nccl_group_teardown(
    monkeypatch,
    compiled_dag: "ray.dag.CompiledDAG",
    mock_nccl_group_set: MockNcclGroupSet,
):
    """Tear down ``compiled_dag`` with group destruction mocked, then verify it.

    Replaces ``ray.dag.compiled_dag_node._destroy_communicator`` with
    ``mock_nccl_group_set.mock_destroy_nccl_group``, tears the DAG down, and
    asserts that every communicator the DAG created is gone from both the
    mock's records and the driver's ``ChannelContext``.
    """
    monkeypatch.setattr(
        "ray.dag.compiled_dag_node._destroy_communicator",
        mock_nccl_group_set.mock_destroy_nccl_group,
    )

    # Snapshot the IDs *before* teardown: ``dict.values()`` is a live view,
    # so if teardown mutates or clears the mapping the check below would
    # iterate the post-teardown contents and could pass vacuously.
    created_communicator_ids = list(
        compiled_dag._actors_to_created_communicator_id.values()
    )
    compiled_dag.teardown()
    mock_nccl_group_set.check_teardown(created_communicator_ids)
|