conftest.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. import uuid
  2. from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Type
  3. import torch
  4. import ray
  5. from ray.experimental.channel.common import ChannelContext
  6. from ray.experimental.channel.communicator import (
  7. Communicator,
  8. ReduceOp,
  9. TorchTensorAllocator,
  10. )
  11. class AbstractNcclGroup(Communicator):
  12. """
  13. A dummy NCCL group for testing.
  14. """
  15. def __init__(self, actor_handles: List[ray.actor.ActorHandle]):
  16. self._actor_handles = actor_handles
  17. self._rank = None
  18. def initialize(self, rank: int) -> None:
  19. self._rank = rank
  20. def get_rank(self, actor: ray.actor.ActorHandle) -> int:
  21. return self._actor_handles.index(actor)
  22. def get_world_size(self) -> int:
  23. return len(self._actor_handles)
  24. def get_self_rank(self) -> Optional[int]:
  25. return self._rank
  26. def get_actor_handles(self) -> List["ray.actor.ActorHandle"]:
  27. return self._actor_handles
  28. def send(self, value: "torch.Tensor", peer_rank: int) -> None:
  29. raise NotImplementedError
  30. def recv(
  31. self,
  32. shape: Tuple[int],
  33. dtype: "torch.dtype",
  34. peer_rank: int,
  35. allocator: Optional[TorchTensorAllocator] = None,
  36. ) -> "torch.Tensor":
  37. raise NotImplementedError
  38. def allgather(
  39. self,
  40. send_buf: "torch.Tensor",
  41. recv_buf: "torch.Tensor",
  42. ) -> None:
  43. raise NotImplementedError
  44. def allreduce(
  45. self,
  46. send_buf: "torch.Tensor",
  47. recv_buf: "torch.Tensor",
  48. op: ReduceOp = ReduceOp.SUM,
  49. ) -> None:
  50. raise NotImplementedError
  51. def reducescatter(
  52. self,
  53. send_buf: "torch.Tensor",
  54. recv_buf: "torch.Tensor",
  55. op: ReduceOp = ReduceOp.SUM,
  56. ) -> None:
  57. raise NotImplementedError
  58. @property
  59. def recv_stream(self):
  60. return None
  61. @property
  62. def send_stream(self):
  63. return None
  64. def destroy(self) -> None:
  65. pass
  66. def get_transport_name(self) -> str:
  67. return "accelerator"
  68. @classmethod
  69. def generate_communicator_id(cls) -> str:
  70. pass
  71. class MockNcclGroupSet:
  72. def __init__(self):
  73. # Represents a mapping from a NCCL group ID to a set of actors and a custom
  74. # NCCL group.
  75. self.ids_to_actors_and_custom_comms: Dict[
  76. str, Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
  77. ] = {}
  78. def __call__(
  79. self,
  80. actors: List["ray.actor.ActorHandle"],
  81. custom_nccl_group: Optional[Communicator] = None,
  82. use_communication_streams: bool = False,
  83. accelerator_module_name: Optional[str] = None,
  84. accelerator_communicator_cls: Optional[Type[Communicator]] = None,
  85. ) -> str:
  86. group_id = str(uuid.uuid4())
  87. self.ids_to_actors_and_custom_comms[group_id] = (
  88. frozenset(actors),
  89. custom_nccl_group,
  90. )
  91. if custom_nccl_group is None:
  92. ranks = list(range(len(actors)))
  93. else:
  94. ranks = [custom_nccl_group.get_rank(actor) for actor in actors]
  95. init_tasks = [
  96. actor.__ray_call__.remote(
  97. mock_do_init_nccl_group,
  98. group_id,
  99. rank,
  100. actors,
  101. custom_nccl_group,
  102. )
  103. for rank, actor in zip(ranks, actors)
  104. ]
  105. ray.get(init_tasks, timeout=30)
  106. ctx = ChannelContext.get_current()
  107. if custom_nccl_group is not None:
  108. ctx.communicators[group_id] = custom_nccl_group
  109. else:
  110. ctx.communicators[group_id] = AbstractNcclGroup(actors)
  111. return group_id
  112. def mock_destroy_nccl_group(self, group_id: str) -> None:
  113. ctx = ChannelContext.get_current()
  114. if group_id not in ctx.communicators:
  115. return
  116. actors, _ = self.ids_to_actors_and_custom_comms[group_id]
  117. destroy_tasks = [
  118. actor.__ray_call__.remote(
  119. mock_do_destroy_nccl_group,
  120. group_id,
  121. )
  122. for actor in actors
  123. ]
  124. ray.wait(destroy_tasks, timeout=30)
  125. if group_id in self.ids_to_actors_and_custom_comms:
  126. del self.ids_to_actors_and_custom_comms[group_id]
  127. ctx.communicators[group_id].destroy()
  128. del ctx.communicators[group_id]
  129. def check_teardown(self, nccl_group_ids: List[str]) -> None:
  130. ctx = ChannelContext.get_current()
  131. for nccl_group_id in nccl_group_ids:
  132. assert nccl_group_id not in self.ids_to_actors_and_custom_comms
  133. assert nccl_group_id not in ctx.communicators
  134. @ray.remote
  135. class CPUTorchTensorWorker:
  136. def __init__(self):
  137. self.device = "cpu"
  138. def return_tensor(
  139. self, size: int, dtype: Optional[torch.dtype] = None
  140. ) -> torch.Tensor:
  141. return torch.ones(size, dtype=dtype, device=self.device)
  142. def recv(self, tensor: torch.Tensor) -> Tuple[int, int]:
  143. assert tensor.device == self.device
  144. return tensor.shape, tensor[0]
  145. def recv_tensors(self, *tensors) -> Tuple[torch.Tensor, ...]:
  146. return tuple(tensors)
  147. def mock_do_init_nccl_group(
  148. self,
  149. group_id: str,
  150. rank: int,
  151. actors: List[ray.actor.ActorHandle],
  152. custom_nccl_group: Optional[Communicator],
  153. ) -> None:
  154. ctx = ChannelContext.get_current()
  155. if custom_nccl_group is None:
  156. nccl_group = AbstractNcclGroup(actors)
  157. nccl_group.initialize(rank)
  158. ctx.communicators[group_id] = nccl_group
  159. else:
  160. custom_nccl_group.initialize(rank)
  161. ctx.communicators[group_id] = custom_nccl_group
  162. def mock_do_destroy_nccl_group(self, group_id: str) -> None:
  163. ctx = ChannelContext.get_current()
  164. if group_id not in ctx.communicators:
  165. return
  166. ctx.communicators[group_id].destroy()
  167. del ctx.communicators[group_id]
  168. def check_nccl_group_init(
  169. monkeypatch,
  170. dag: "ray.dag.DAGNode",
  171. actors_and_custom_comms: Set[
  172. Tuple[FrozenSet["ray.actor.ActorHandle"], Optional[Communicator]]
  173. ],
  174. ) -> "ray.dag.CompiledDAG":
  175. mock_nccl_group_set = MockNcclGroupSet()
  176. monkeypatch.setattr(
  177. "ray.dag.compiled_dag_node._init_communicator",
  178. mock_nccl_group_set,
  179. )
  180. compiled_dag = dag.experimental_compile()
  181. assert (
  182. set(mock_nccl_group_set.ids_to_actors_and_custom_comms.values())
  183. == actors_and_custom_comms
  184. )
  185. return compiled_dag, mock_nccl_group_set
  186. def check_nccl_group_teardown(
  187. monkeypatch,
  188. compiled_dag: "ray.dag.CompiledDAG",
  189. mock_nccl_group_set: MockNcclGroupSet,
  190. ):
  191. monkeypatch.setattr(
  192. "ray.dag.compiled_dag_node._destroy_communicator",
  193. mock_nccl_group_set.mock_destroy_nccl_group,
  194. )
  195. created_communicator_ids = compiled_dag._actors_to_created_communicator_id.values()
  196. compiled_dag.teardown()
  197. mock_nccl_group_set.check_teardown(created_communicator_ids)