functional.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # mypy: allow-untyped-defs
  2. import torch
  3. import torch.distributed as dist
  4. from torch.autograd import Function
  5. # The two imports below are not always available depending on the
  6. # USE_DISTRIBUTED compile flag. Make sure they raise import error
  7. # if we're trying to use them.
  8. from torch.distributed import group, ReduceOp
  9. def broadcast(tensor, src, group=group.WORLD):
  10. """
  11. Broadcasts the tensor to the whole group.
  12. ``tensor`` must have the same number of elements in all processes
  13. participating in the collective.
  14. Arguments:
  15. tensor (Tensor): Data to be sent if ``src`` is the rank of current
  16. process.
  17. src (int): Source rank.
  18. group (ProcessGroup, optional): The process group to work on.
  19. Returns:
  20. Tensor: Received tensor from the broadcast op.
  21. """
  22. return _Broadcast.apply(src, group, tensor)
  23. def gather(tensor, dst=0, group=group.WORLD):
  24. """
  25. Gathers a list of tensors in a single process.
  26. Arguments:
  27. tensor (Tensor): Input tensor.
  28. dst (int, optional): Destination rank (default is 0).
  29. group (ProcessGroup, optional): The process group to work on.
  30. Returns:
  31. tuple[Tensor]: List of appropriately-sized tensors with the gathered data.
  32. """
  33. return _Gather.apply(dst, group, tensor)
  34. def scatter(tensors, src=0, group=group.WORLD):
  35. """
  36. Scatters a list of tensors to all processes in a group.
  37. Each process will receive exactly one tensor and store its data in the
  38. ``tensor`` argument.
  39. Arguments:
  40. tensors (list[Tensor]): List of tensors to scatter on the source rank.
  41. Receivers must pass ``None`.
  42. src (int, optional): Source rank (default is 0).
  43. group (ProcessGroup, optional): The process group to work on.
  44. Returns:
  45. Tensor: Output tensor from the scatter operation.
  46. """
  47. return _Scatter.apply(src, group, *tensors)
  48. def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD):
  49. """
  50. Reduces the tensor data across all machines.
  51. Only the process with rank ``dst`` is going to receive the final result.
  52. Arguments:
  53. tensor (Tensor): Input of the collective.
  54. dst (int): Destination rank.
  55. op (optional): One of the values from
  56. ``torch.distributed.ReduceOp``
  57. enum. Specifies an operation used for element-wise reductions.
  58. group (ProcessGroup, optional): The process group to work on.
  59. Returns:
  60. Tensor: Output of the collective.
  61. """
  62. return _Reduce.apply(dst, op, group, tensor)
  63. def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD):
  64. """
  65. Reduces, then scatters a list of tensors to all processes in a group.
  66. Arguments:
  67. output (Tensor): Output tensor.
  68. input_list (list[Tensor]): List of tensors to reduce and scatter.
  69. op (optional): One of the values from
  70. ``torch.distributed.ReduceOp``
  71. enum. Specifies an operation used for element-wise reductions.
  72. group (ProcessGroup, optional): The process group to work on.
  73. Returns:
  74. Tensor: Output of the collective.
  75. """
  76. return _Reduce_Scatter.apply(op, group, output, *input_list)
  77. def all_gather(tensor, group=group.WORLD):
  78. """
  79. Gathers tensors from the whole group in a list.
  80. Arguments:
  81. tensor (Tensor): Tensor to be broadcast from current process.
  82. group (ProcessGroup, optional): The process group to work on.
  83. Returns:
  84. tuple([Tensor]): Output of the collective.
  85. """
  86. return _AllGather.apply(group, tensor)
  87. def _all_gather_base(output_tensor, input_tensor, group=group.WORLD):
  88. """
  89. Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor.
  90. Args:
  91. output_tensor (Tensor): Output tensor. It should contain
  92. correctly-sized tensors to be used for output of the collective.
  93. input_tensor (Tensor): Tensor to be broadcast from current process.
  94. group (ProcessGroup, optional): The process group to work on. If None,
  95. the default process group will be used.
  96. Examples:
  97. >>> # All tensors below are of torch.int64 dtype.
  98. >>> # We have 2 process groups, 2 ranks.
  99. >>> # xdoctest: +SKIP("incorrect want text")
  100. >>> output_tensor = torch.zeros(2, dtype=torch.int64)
  101. >>> output_tensor
  102. [tensor([0, 0])] # Rank 0 and 1
  103. >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank
  104. >>> tensor
  105. tensor([1]) # Rank 0
  106. tensor([2]) # Rank 1
  107. >>> dist.all_gather_base(output_tensor, tensor)
  108. >>> output_tensor
  109. tensor([1,2]) # Rank 0
  110. tensor([1,2]) # Rank 1
  111. .. warning::
  112. `_all_gather_base` is experimental and subject to change.
  113. It is the caller's responsibility to ensure the output_tensor
  114. is correctly sized.
  115. """
  116. return _AllGatherBase.apply(output_tensor, input_tensor, group)
  117. def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD):
  118. """
  119. Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list.
  120. Arguments:
  121. output_tensor_list (list[Tensor]): list of tensors to gather one per rank.
  122. input_tensor_list (list[Tensor]): List of tensors to scatter one per rank.
  123. group (ProcessGroup, optional): The process group to work on.
  124. Returns:
  125. tuple([Tensor]): Output of the collective.
  126. """
  127. return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list)
  128. def all_to_all_single(
  129. output,
  130. input,
  131. output_split_sizes=None,
  132. input_split_sizes=None,
  133. group=group.WORLD,
  134. ):
  135. """
  136. Each process splits input tensor and then scatters the split list to all processes in a group.
  137. Then concatenate the received tensors from all the processes in the group and return single output tensor.
  138. Arguments:
  139. output (Tensor): Gathered concatenated output tensor.
  140. input (Tensor): Input tensor to scatter.
  141. output_split_sizes: (list[Int], optional): Output split sizes for dim 0
  142. if specified None or empty, dim 0 of ``output`` tensor must divide
  143. equally by ``world_size``.
  144. input_split_sizes: (list[Int], optional): Input split sizes for dim 0
  145. if specified None or empty, dim 0 of ``input`` tensor must divide
  146. equally by ``world_size``.
  147. Returns:
  148. Tensor: Output of the collective.
  149. """
  150. return _AlltoAllSingle.apply(
  151. group, output, output_split_sizes, input_split_sizes, input
  152. )
  153. def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
  154. """
  155. Reduces the tensor data across all machines in such a way that all get the final result.
  156. After the call the returned tensor is going to be bitwise
  157. identical in all processes.
  158. Arguments:
  159. tensor (Tensor): Input of the collective.
  160. op (optional): One of the values from
  161. ``torch.distributed.ReduceOp``
  162. enum. Specifies an operation used for element-wise reductions.
  163. group (ProcessGroup, optional): The process group to work on.
  164. Returns:
  165. Tensor: Output of the collective
  166. """
  167. return _AllReduce.apply(op, group, tensor)
  168. class _Broadcast(Function):
  169. @staticmethod
  170. # pyrefly: ignore [bad-override]
  171. def forward(ctx, src, group, tensor):
  172. ctx.src = src
  173. ctx.group = group
  174. ctx.rank = dist.get_rank(group=group)
  175. # torch.distributed makes all the calls in place
  176. # we allocate new tensors to avoid this
  177. tensor = tensor.clone()
  178. dist.broadcast(tensor, src, group=group)
  179. return tensor
  180. @staticmethod
  181. # pyrefly: ignore [bad-override]
  182. def backward(ctx, grad_output):
  183. gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
  184. if ctx.src != ctx.rank:
  185. gx.zero_()
  186. return (None, None, gx)
  187. class _Gather(Function):
  188. @staticmethod
  189. # pyrefly: ignore [bad-override]
  190. def forward(ctx, dst, group, tensor):
  191. ctx.dst = dst
  192. ctx.group = group
  193. # Need to create a list of tensors here to do the
  194. # aggregation, get it from the group size
  195. # tensor should be correctly sized for the method
  196. # gathering
  197. tensor_list = [
  198. torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group))
  199. ]
  200. tensor = tensor.contiguous()
  201. if dist.get_rank(group=group) == dst:
  202. dist.gather(tensor, tensor_list, dst, group=group)
  203. else:
  204. dist.gather(tensor, None, dst, group=group)
  205. return tuple(tensor_list)
  206. @staticmethod
  207. def backward(ctx, *grad_outputs):
  208. return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),)
  209. class _Scatter(Function):
  210. @staticmethod
  211. # pyrefly: ignore [bad-override]
  212. def forward(ctx, src, group, *tensors):
  213. ctx.src = src
  214. ctx.group = group
  215. assert all(t.size() == tensors[0].size() for t in tensors)
  216. output = torch.zeros_like(tensors[0])
  217. if dist.get_rank(group=group) == src:
  218. dist.scatter(output, list(tensors), src, group=group)
  219. else:
  220. dist.scatter(output, None, src, group=group)
  221. return output
  222. @staticmethod
  223. # pyrefly: ignore [bad-override]
  224. def backward(ctx, grad_output):
  225. return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)
  226. class _Reduce(Function):
  227. @staticmethod
  228. # pyrefly: ignore [bad-override]
  229. def forward(ctx, src, op, group, tensor):
  230. ctx.src = src
  231. ctx.group = group
  232. tensor = tensor.clone()
  233. dist.reduce(tensor, src, op=op, group=group)
  234. return tensor
  235. @staticmethod
  236. # pyrefly: ignore [bad-override]
  237. def backward(ctx, grad_output):
  238. return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)
  239. class _Reduce_Scatter(Function):
  240. @staticmethod
  241. # pyrefly: ignore [bad-override]
  242. def forward(ctx, op, group, tensor, *input_tensor_list):
  243. ctx.group = group
  244. # Need contiguous tensors for collectives.
  245. tensor = tensor.contiguous()
  246. input_tensor_list = tuple(t.contiguous() for t in input_tensor_list)
  247. dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group)
  248. return tensor
  249. @staticmethod
  250. # pyrefly: ignore [bad-override]
  251. def backward(ctx, grad_output):
  252. return (None, None, None) + _AllGather.apply(ctx.group, grad_output)
  253. class _AllGather(Function):
  254. @staticmethod
  255. # pyrefly: ignore [bad-override]
  256. def forward(ctx, group, tensor):
  257. # Need contiguous tensors for collectives.
  258. tensor = tensor.contiguous()
  259. ctx.group = group
  260. out_tensor_list = [
  261. torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
  262. ]
  263. dist.all_gather(out_tensor_list, tensor, group=group)
  264. return tuple(out_tensor_list)
  265. @staticmethod
  266. def backward(ctx, *grad_outputs):
  267. if dist.get_backend(group=ctx.group) in (dist.Backend.NCCL, dist.Backend.XCCL):
  268. rank = dist.get_rank(group=ctx.group)
  269. gx = torch.empty_like(grad_outputs[rank])
  270. gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs)
  271. else:
  272. # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum()
  273. # to emulate the ReduceScatter behavior
  274. tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs]
  275. gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
  276. gx = torch.sum(torch.stack(gxs), dim=0)
  277. return (None, gx)
  278. class _AllGatherBase(Function):
  279. @staticmethod
  280. # pyrefly: ignore [bad-override]
  281. def forward(ctx, output_tensor, input_tensor, group):
  282. ctx.group = group
  283. dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
  284. return output_tensor
  285. @staticmethod
  286. # pyrefly: ignore [bad-override]
  287. def backward(ctx, grad_output):
  288. if dist.get_backend(group=ctx.group) in (dist.Backend.NCCL, dist.Backend.XCCL):
  289. world_size = dist.get_world_size(group=ctx.group)
  290. out_size = list(grad_output.size())
  291. if out_size[0] % world_size != 0:
  292. raise RuntimeError(
  293. f"Tensor with dimensions: {out_size} does "
  294. f"not have first dimension divisible by world_size: {world_size}"
  295. )
  296. out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group)
  297. gx = torch.empty(
  298. out_size, device=grad_output.device, dtype=grad_output.dtype
  299. )
  300. dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
  301. else:
  302. raise RuntimeError("Backend not supported!")
  303. return (None, gx, None)
  304. class _AlltoAll(Function):
  305. @staticmethod
  306. # pyrefly: ignore [bad-override]
  307. def forward(ctx, group, out_tensor_list, *tensors):
  308. ctx.group = group
  309. ctx.input_tensor_size_list = [
  310. tensors[i].size() for i in range(dist.get_world_size(group=group))
  311. ]
  312. my_rank = dist.get_rank(group=group)
  313. tensors = tuple(t.contiguous() for t in tensors)
  314. # Implement it on means of scatter/gather, send/recv async operations have issues
  315. if dist.get_backend(group=group) is dist.Backend.GLOO:
  316. for i in range(dist.get_world_size(group=group)):
  317. to_send = None
  318. if i == my_rank:
  319. to_send = list(tensors)
  320. dist.scatter(out_tensor_list[i], to_send, i, group=group)
  321. else:
  322. dist.all_to_all(
  323. out_tensor_list,
  324. list(tensors),
  325. group=group,
  326. )
  327. return tuple(out_tensor_list)
  328. @staticmethod
  329. def backward(ctx, *grad_outputs):
  330. tensor_list = [
  331. torch.empty(
  332. size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype
  333. )
  334. for size in ctx.input_tensor_size_list
  335. ]
  336. return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs)
  337. class _AlltoAllSingle(Function):
  338. @staticmethod
  339. # pyrefly: ignore [bad-override]
  340. def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
  341. ctx.group = group
  342. ctx.input_size = input.size()
  343. ctx.output_split_sizes = input_split_sizes
  344. ctx.input_split_sizes = output_split_sizes
  345. dist.all_to_all_single(
  346. output,
  347. input,
  348. output_split_sizes=output_split_sizes,
  349. input_split_sizes=input_split_sizes,
  350. group=group,
  351. )
  352. return output
  353. @staticmethod
  354. # pyrefly: ignore [bad-override]
  355. def backward(ctx, grad_output):
  356. tensor = torch.empty(
  357. ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
  358. )
  359. return (None, None, None, None) + (
  360. _AlltoAllSingle.apply(
  361. ctx.group,
  362. tensor,
  363. ctx.output_split_sizes,
  364. ctx.input_split_sizes,
  365. grad_output.contiguous(),
  366. ),
  367. )
  368. class _AllReduce(Function):
  369. @staticmethod
  370. # pyrefly: ignore [bad-override]
  371. def forward(ctx, op, group, tensor):
  372. ctx.group = group
  373. ctx.op = op
  374. tensor = tensor.clone(memory_format=torch.contiguous_format)
  375. dist.all_reduce(tensor, op=op, group=group)
  376. return tensor
  377. @staticmethod
  378. # pyrefly: ignore [bad-override]
  379. def backward(ctx, grad_output):
  380. return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)