# conftest.py (5.9 KB)

  1. import asyncio
  2. from collections import defaultdict
  3. from typing import Optional, Tuple
  4. from unittest import mock
  5. import torch
  6. import ray
  7. import ray.dag
  8. import ray.experimental.channel as ray_channel
  9. from ray.experimental.channel import nccl_group
  10. from ray.experimental.channel.communicator import TorchTensorAllocator
  11. from ray.experimental.util.types import Device
  12. @ray.remote(num_cpus=0)
  13. class Barrier:
  14. """
  15. Barrier that blocks the given number of actors until all actors have
  16. reached the barrier. This is used to mock out blocking NCCL ops.
  17. """
  18. def __init__(self, num_actors=2):
  19. self.num_actors = num_actors
  20. self.condition = asyncio.Condition()
  21. # Buffer for the data that is "sent" between the actors, each entry is
  22. # one p2p op.
  23. self.data = {}
  24. # Buffer for the number of actors seen, each entry is one p2p op.
  25. self.num_actors_seen = defaultdict(int)
  26. # Add a new mock for the TorchTensorType.device property
  27. device_property_patcher = mock.patch(
  28. "ray.experimental.channel.torch_tensor_type.TorchTensorType.device",
  29. new_callable=mock.PropertyMock,
  30. return_value=Device.CPU,
  31. )
  32. device_property_patcher.start()
  33. async def wait(self, idx: int, data=None):
  34. """
  35. Wait at barrier until all actors have sent `idx`. One actor should
  36. provide `data`, and this value will be returned by this method for all
  37. other actors.
  38. """
  39. async with self.condition:
  40. if data is not None:
  41. assert idx not in self.data, (self.data, self.num_actors_seen)
  42. self.data[idx] = data
  43. self.num_actors_seen[idx] += 1
  44. if self.num_actors_seen[idx] == self.num_actors:
  45. # Wake up all tasks waiting on this condition.
  46. self.condition.notify_all()
  47. else:
  48. await self.condition.wait_for(
  49. lambda: self.num_actors_seen[idx] == self.num_actors
  50. )
  51. if data is None:
  52. data = self.data[idx]
  53. return data
  54. class MockCudaStream:
  55. def __init__(self):
  56. self.cuda_stream = 0
  57. def synchronize(self):
  58. pass
  59. class MockNcclGroup(nccl_group._NcclGroup):
  60. """
  61. Mock the internal _NcclGroup to use a barrier actor instead of a NCCL group
  62. for communication.
  63. """
  64. def __init__(self, *args, **kwargs):
  65. super().__init__(*args, **kwargs)
  66. # We use the op index to synchronize the sender and receiver at the
  67. # barrier.
  68. self.num_ops = defaultdict(int)
  69. self.barriers = set()
  70. def send(self, tensor: torch.Tensor, peer_rank: int):
  71. # "Send" the tensor to the barrier actor.
  72. barrier_key = sorted([self.get_self_rank(), peer_rank])
  73. barrier_key = f"barrier-{barrier_key[0]}-{barrier_key[1]}"
  74. barrier = ray.get_actor(name=barrier_key)
  75. self.barriers.add(barrier)
  76. ray.get(barrier.wait.remote(self.num_ops[barrier_key], tensor))
  77. self.num_ops[barrier_key] += 1
  78. def recv(
  79. self,
  80. shape: Tuple[int],
  81. dtype: torch.dtype,
  82. peer_rank: int,
  83. allocator: Optional[TorchTensorAllocator] = None,
  84. ):
  85. # "Receive" the tensor from the barrier actor.
  86. barrier_key = sorted([self.get_self_rank(), peer_rank])
  87. barrier_key = f"barrier-{barrier_key[0]}-{barrier_key[1]}"
  88. barrier = ray.get_actor(name=barrier_key)
  89. self.barriers.add(barrier)
  90. received_tensor = ray.get(barrier.wait.remote(self.num_ops[barrier_key]))
  91. assert (
  92. allocator is not None
  93. ), "torch tensor allocator is required for MockNcclGroup"
  94. buf = allocator(shape, dtype)
  95. buf[:] = received_tensor[:]
  96. self.num_ops[barrier_key] += 1
  97. return buf
  98. def destroy(self) -> None:
  99. for barrier in self.barriers:
  100. ray.kill(barrier)
  101. def start_nccl_mock():
  102. """
  103. Patch methods that require CUDA.
  104. """
  105. # Mock cupy dependencies.
  106. nccl_mock = mock.MagicMock()
  107. nccl_mock.nccl.get_unique_id.return_value = 0
  108. cp_patcher = mock.patch.dict(
  109. "sys.modules",
  110. {
  111. "cupy.cuda": nccl_mock,
  112. "cupy": mock.MagicMock(),
  113. "ray.util.collective.collective_group": mock.MagicMock(),
  114. },
  115. )
  116. cp_patcher.start()
  117. # Mock send/recv ops to use an actor instead of NCCL.
  118. ray.experimental.channel.nccl_group._NcclGroup = MockNcclGroup
  119. # PyTorch mocks.
  120. stream_patcher = mock.patch(
  121. "torch.cuda.current_stream", new_callable=lambda: MockCudaStream
  122. )
  123. stream_patcher.start()
  124. new_stream_patcher = mock.patch(
  125. "torch.cuda.Stream", new_callable=lambda: MockCudaStream
  126. )
  127. new_stream_patcher.start()
  128. tensor_patcher = mock.patch("torch.Tensor.device", torch.device("cuda"))
  129. tensor_patcher.start()
  130. tensor_patcher = mock.patch("torch.Tensor.is_cuda", True)
  131. tensor_patcher.start()
  132. tensor_allocator_patcher = mock.patch(
  133. "ray.experimental.channel.torch_tensor_accelerator_channel._torch_tensor_allocator",
  134. lambda shape, dtype: torch.empty(shape, dtype=dtype),
  135. )
  136. tensor_allocator_patcher.start()
  137. # Add a new mock for the TorchTensorType.device property
  138. device_property_patcher = mock.patch(
  139. "ray.experimental.channel.torch_tensor_type.TorchTensorType.device",
  140. new_callable=mock.PropertyMock,
  141. return_value=Device.CPU,
  142. )
  143. device_property_patcher.start()
  144. ctx = ray_channel.ChannelContext.get_current()
  145. ctx.set_torch_device(torch.device("cuda"))
  146. class TracedChannel(ray_channel.shared_memory_channel.Channel):
  147. """
  148. Patched Channel that records all write ops for testing.
  149. """
  150. def __init__(self, *args, **kwargs):
  151. super().__init__(*args, **kwargs)
  152. self.ops = []
  153. def write(self, *args, **kwargs):
  154. self.ops.append((args, kwargs))
  155. return super().write(*args, **kwargs)