comm.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. import torch
  4. from torch._utils import (
  5. _flatten_dense_tensors,
  6. _get_device_index,
  7. _handle_complex,
  8. _reorder_tensors_as,
  9. _take_tensors,
  10. _unflatten_dense_tensors,
  11. )
  12. from torch.cuda import nccl
  13. def broadcast(tensor, devices=None, *, out=None):
  14. r"""Broadcasts a tensor to specified GPU devices.
  15. Args:
  16. tensor (Tensor): tensor to broadcast. Can be on CPU or GPU.
  17. devices (Iterable[torch.device, str or int], optional): an iterable of
  18. GPU devices, among which to broadcast.
  19. out (Sequence[Tensor], optional, keyword-only): the GPU tensors to
  20. store output results.
  21. .. note::
  22. Exactly one of :attr:`devices` and :attr:`out` must be specified.
  23. Returns:
  24. - If :attr:`devices` is specified,
  25. a tuple containing copies of :attr:`tensor`, placed on
  26. :attr:`devices`.
  27. - If :attr:`out` is specified,
  28. a tuple containing :attr:`out` tensors, each containing a copy of
  29. :attr:`tensor`.
  30. """
  31. tensor = _handle_complex(tensor)
  32. if not ((devices is None) ^ (out is None)):
  33. raise RuntimeError(
  34. f"Exactly one of 'devices' and 'out' must be specified, but got devices={devices} and out={out}"
  35. )
  36. if devices is not None:
  37. devices = [_get_device_index(d) for d in devices]
  38. return torch._C._broadcast(tensor, devices)
  39. else:
  40. # pyrefly: ignore [bad-argument-type]
  41. return torch._C._broadcast_out(tensor, out)
  42. def broadcast_coalesced(tensors, devices, buffer_size=10485760):
  43. """Broadcast a sequence of tensors to the specified GPUs.
  44. Small tensors are first coalesced into a buffer to reduce the number of synchronizations.
  45. Args:
  46. tensors (sequence): tensors to broadcast. Must be on the same device,
  47. either CPU or GPU.
  48. devices (Iterable[torch.device, str or int]): an iterable of GPU
  49. devices, among which to broadcast.
  50. buffer_size (int): maximum size of the buffer used for coalescing
  51. Returns:
  52. A tuple containing copies of :attr:`tensor`, placed on :attr:`devices`.
  53. """
  54. devices = [_get_device_index(d) for d in devices]
  55. tensors = [_handle_complex(t) for t in tensors]
  56. return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
  57. def reduce_add(inputs, destination=None):
  58. """Sum tensors from multiple GPUs.
  59. All inputs should have matching shapes, dtype, and layout. The output tensor
  60. will be of the same shape, dtype, and layout.
  61. Args:
  62. inputs (Iterable[Tensor]): an iterable of tensors to add.
  63. destination (int, optional): a device on which the output will be
  64. placed (default: current device).
  65. Returns:
  66. A tensor containing an elementwise sum of all inputs, placed on the
  67. :attr:`destination` device.
  68. """
  69. destination = _get_device_index(destination, optional=True)
  70. input_size = inputs[0].size()
  71. root_index = None # index of input tensor that already is on the correct device
  72. for i, inp in enumerate(inputs):
  73. if inp.device.type == "cpu":
  74. raise AssertionError(
  75. f"reduce_add expects all inputs to be on GPUs, but input {i} is on CPU"
  76. )
  77. if inp.get_device() == destination:
  78. root_index = i
  79. if inp.size() != input_size:
  80. got = "x".join(str(x) for x in inp.size())
  81. expected = "x".join(str(x) for x in input_size)
  82. raise ValueError(
  83. f"input {i} has invalid size: got {got}, but expected {expected}"
  84. )
  85. if root_index is None:
  86. raise RuntimeError(
  87. "reduce_add expects destination to be on the same GPU with one of the tensors"
  88. )
  89. if len(inputs) == 1:
  90. return inputs[0]
  91. if nccl.is_available(inputs):
  92. result = torch.empty_like(inputs[root_index])
  93. nccl.reduce(inputs, output=result, root=root_index)
  94. else:
  95. destination_device = torch.device(inputs[root_index].device.type, destination)
  96. nonroot = [t for i, t in enumerate(inputs) if i != root_index]
  97. # make a new tensor w/o clone
  98. result = inputs[root_index] + nonroot[0].to(
  99. device=destination_device, non_blocking=True
  100. )
  101. for other in nonroot[1:]:
  102. result.add_(other.to(device=destination_device, non_blocking=True))
  103. return result
  104. def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
  105. """Sum tensors from multiple GPUs.
  106. Small tensors are first coalesced into a buffer to reduce the number
  107. of synchronizations.
  108. Args:
  109. inputs (Iterable[Iterable[Tensor]]): iterable of iterables that
  110. contain tensors from a single device.
  111. destination (int, optional): a device on which the output will be
  112. placed (default: current device).
  113. buffer_size (int): maximum size of the buffer used for coalescing
  114. Returns:
  115. A tuple of tensors containing an elementwise sum of each group of
  116. inputs, placed on the ``destination`` device.
  117. """
  118. # TODO: When `len(inputs) == 1` and all inputs are on `destination`, just
  119. # return `inputs`.
  120. dense_tensors: list[list] = [[] for _ in inputs] # shape (num_gpus, num_tensors)
  121. output = []
  122. ref_order = []
  123. # process sparse ones first since they may have different sizes on different gpus
  124. for tensor_at_gpus in zip(*inputs, strict=True):
  125. if all(t.is_sparse for t in tensor_at_gpus):
  126. result = reduce_add(tensor_at_gpus, destination) # this will be sparse too
  127. output.append(result)
  128. ref_order.append(tensor_at_gpus[0])
  129. else:
  130. for coll, t in zip(dense_tensors, tensor_at_gpus, strict=True):
  131. coll.append(t.to_dense() if t.is_sparse else t)
  132. ref_order.append(dense_tensors[0][-1])
  133. itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
  134. # now the dense ones, which have consistent sizes
  135. for chunks in zip(*itrs, strict=True):
  136. flat_tensors = [
  137. _flatten_dense_tensors(chunk) for chunk in chunks
  138. ] # (num_gpus,)
  139. flat_result = reduce_add(flat_tensors, destination)
  140. for t in _unflatten_dense_tensors(flat_result, chunks[0]):
  141. # The unflattened tensors do not share storage, and we don't expose
  142. # base flat tensor anyways, so give them different version counters.
  143. # See NOTE [ Version Counter in comm.*_coalesced ]
  144. output.append(t.data)
  145. return tuple(_reorder_tensors_as(output, ref_order))
  146. def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=None):
  147. """Scatters tensor across multiple GPUs.
  148. Args:
  149. tensor (Tensor): tensor to scatter. Can be on CPU or GPU.
  150. devices (Iterable[torch.device, str or int], optional): an iterable of
  151. GPU devices, among which to scatter.
  152. chunk_sizes (Iterable[int], optional): sizes of chunks to be placed on
  153. each device. It should match :attr:`devices` in length and sums to
  154. ``tensor.size(dim)``. If not specified, :attr:`tensor` will be divided
  155. into equal chunks.
  156. dim (int, optional): A dimension along which to chunk :attr:`tensor`.
  157. Default: ``0``.
  158. streams (Iterable[torch.cuda.Stream], optional): an iterable of Streams, among
  159. which to execute the scatter. If not specified, the default stream will
  160. be utilized.
  161. out (Sequence[Tensor], optional, keyword-only): the GPU tensors to
  162. store output results. Sizes of these tensors must match that of
  163. :attr:`tensor`, except for :attr:`dim`, where the total size must
  164. sum to ``tensor.size(dim)``.
  165. .. note::
  166. Exactly one of :attr:`devices` and :attr:`out` must be specified. When
  167. :attr:`out` is specified, :attr:`chunk_sizes` must not be specified and
  168. will be inferred from sizes of :attr:`out`.
  169. Returns:
  170. - If :attr:`devices` is specified,
  171. a tuple containing chunks of :attr:`tensor`, placed on
  172. :attr:`devices`.
  173. - If :attr:`out` is specified,
  174. a tuple containing :attr:`out` tensors, each containing a chunk of
  175. :attr:`tensor`.
  176. """
  177. tensor = _handle_complex(tensor)
  178. if out is None:
  179. # pyrefly: ignore [not-iterable]
  180. devices = [_get_device_index(d) for d in devices]
  181. return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
  182. else:
  183. if devices is not None:
  184. raise RuntimeError(
  185. f"'devices' must not be specified when 'out' is specified, but got devices={devices}"
  186. )
  187. if chunk_sizes is not None:
  188. raise RuntimeError(
  189. f"'chunk_sizes' must not be specified when 'out' is specified, but got chunk_sizes={chunk_sizes}"
  190. )
  191. return tuple(torch._C._scatter_out(tensor, out, dim, streams))
  192. def gather(tensors, dim=0, destination=None, *, out=None):
  193. r"""Gathers tensors from multiple GPU devices.
  194. Args:
  195. tensors (Iterable[Tensor]): an iterable of tensors to gather.
  196. Tensor sizes in all dimensions other than :attr:`dim` have to match.
  197. dim (int, optional): a dimension along which the tensors will be
  198. concatenated. Default: ``0``.
  199. destination (torch.device, str, or int, optional): the output device.
  200. Can be CPU or CUDA. Default: the current CUDA device.
  201. out (Tensor, optional, keyword-only): the tensor to store gather result.
  202. Its sizes must match those of :attr:`tensors`, except for :attr:`dim`,
  203. where the size must equal ``sum(tensor.size(dim) for tensor in tensors)``.
  204. Can be on CPU or CUDA.
  205. .. note::
  206. :attr:`destination` must not be specified when :attr:`out` is specified.
  207. Returns:
  208. - If :attr:`destination` is specified,
  209. a tensor located on :attr:`destination` device, that is a result of
  210. concatenating :attr:`tensors` along :attr:`dim`.
  211. - If :attr:`out` is specified,
  212. the :attr:`out` tensor, now containing results of concatenating
  213. :attr:`tensors` along :attr:`dim`.
  214. """
  215. tensors = [_handle_complex(t) for t in tensors]
  216. if out is None:
  217. if destination == -1:
  218. warnings.warn(
  219. "Using -1 to represent CPU tensor is deprecated. Please use a "
  220. 'device object or string instead, e.g., "cpu".',
  221. FutureWarning,
  222. stacklevel=2,
  223. )
  224. destination = _get_device_index(destination, allow_cpu=True, optional=True)
  225. return torch._C._gather(tensors, dim, destination)
  226. else:
  227. if destination is not None:
  228. raise RuntimeError(
  229. f"'destination' must not be specified when 'out' is specified, but got destination={destination}"
  230. )
  231. return torch._C._gather_out(tensors, out, dim)