_functional_collectives.py 59 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import sys
  4. import warnings
  5. from typing import Any, cast, TYPE_CHECKING, Union
  6. import torch
  7. import torch.distributed as dist
  8. import torch.distributed.distributed_c10d as c10d
  9. from torch._utils import _maybe_view_chunk_cat
  10. from torch.distributed.device_mesh import DeviceMesh
  11. from torch.fx.experimental.proxy_tensor import get_proxy_mode
  12. from . import _functional_collectives_impl as fun_col_impl
  13. try:
  14. from torch.utils._cxx_pytree import tree_map_only
  15. except ImportError:
  16. from torch.utils._pytree import tree_map_only # type: ignore[no-redef]
# Import the dynamo "is compiling" probe under the name used throughout this
# file. If the import fails for any reason, warn once and install a stub.
try:
    from torch.compiler import is_dynamo_compiling as is_torchdynamo_compiling
except Exception:
    warnings.warn(
        "Unable to import torchdynamo util `is_torchdynamo_compiling`, so won't support torchdynamo correctly",
        stacklevel=2,
    )

    def is_torchdynamo_compiling():  # type: ignore[misc]
        """Fallback used when the import above fails; always returns False."""
        return False
        # The statement below can never execute; it is kept as-is together
        # with its checker suppression.
        # pyrefly: ignore [unreachable]
        return False
  28. """
  29. New traceable, functional collectives.
  30. RFC: https://github.com/pytorch/pytorch/issues/93173
  31. compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
  32. eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
  33. automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
  34. a downstream op.
  35. Issues:
  36. * Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
  37. * Proper support for eager requires inplace ops. We should explore having it as an option for the API.
  38. """
  39. """
  40. Functional collectives are asynchronous only and we perform implicit stream synchronization
  41. on behalf of the user.
  42. We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
  43. first usage of the tensor and insert cross stream sync at the right place.
  44. The above are the easy bits, the hard one is how we match the Work object returned by
  45. c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
  46. op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
  47. dispatcher which might call other implementations that are allowed to change the returned
  48. tensor - even return a tensor with a different shape (see ``torch.vmap``).
  49. This means the caller of our ops receives a Tensor that is not guaranteed to be the same
  50. allocated by our implementations and that makes pairing the AsyncTensor to the original
  51. tensor a lot harder. This pairing is needed so we can lookup the Work object to use.
  52. Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
  53. identity is not stable across dispatch, the op caller would end up with a different Tensor
  54. instance that would not match any in the dictionary.
  55. With Tensor identity out of the question, we decided to use the tensor data pointer, which
  56. should be stable across all the Tensor changes done during dispatch.
  57. We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
  58. We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
  59. Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we
  60. can clean up stale entries in the dictionary.
  61. To eliminate the possibility of races we have a global version counter that is used by the finalizer.
  62. As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
  63. """
  64. """
  65. Functional collectives can accept any of these types to describe the ranks participating in collectives.
  66. The different types will be desugared to a canonical format
  67. """
# Type alias for every form the ``group`` argument of the collectives in this
# file may take: a flat rank list, a nested rank list, a ProcessGroup, a
# DeviceMesh, a (DeviceMesh, int) pair, or a group name.
RANK_TYPES = Union[
    list[int],
    list[list[int]],
    dist.ProcessGroup,
    DeviceMesh,
    tuple["dist.tensor.DeviceMesh", int],
    c10d.GroupName,
]
  76. """
  77. User facing APIs for functional collectives
  78. -------------------------------------------
  79. These apis are called by user code and expected to work both in eager execution and compilation,
  80. but there are significant differences to how the two modes are implemented underneath.
  81. Eager execution is 'optimized' using a tensor subclass that schedules the synchronization (via wait_tensor() op)
  82. just before the tensor is first used. Compiled tracing currently relies on the compiler to perform this optimization,
  83. and cannot yet correctly trace the AsyncTensor wrapper class. In the future, these paths may be unified
  84. if sufficient subclass support is added in dynamo.
  85. Example: all_reduce is an entrypoint API, and other collectives follow a similar pattern.
  86. Here's how it works under torch.compile/dynamo:
  87. all_reduce(...)
  88. |--> _expand_group(...) - desugars processgroup into canonical/traceable format
  89. |--> c10d_functional.all_reduce(...) - dynamo captures this op call, doesn't trace deeper
  90. |--> _maybe_wrap_tensor(...) - wait_tensor() op is immediately called, no AsyncTensor subclass needed
  91. And under eager execution:
  92. all_reduce(...)
  93. |--> _expand_group(...) - same as above, but less critical for eager
  94. |--> c10d_functional.all_reduce(...) - dispatches to real kernel OR records op in trace
  95. |--> _maybe_wrap_tensor(...) - AsyncTensor wrapper applied to returned tensor,
  96. which issues wait_tensor() at the time of first use
  97. """
  98. def wait_tensor(tensor):
  99. """
  100. Wait on a tensor returned by the collectives ops.
  101. Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
  102. """
  103. return torch.ops._c10d_functional.wait_tensor(tensor) # type: ignore[attr-defined]
  104. def broadcast(self: torch.Tensor, src: int, group: RANK_TYPES, tag: str = ""):
  105. """
  106. Broadcasts the tensor to all processes in the given process group.
  107. Args:
  108. src (int): Source rank
  109. group (ProcessGroup or List[int]): The process group to work on.
  110. tag (str, optional): A unique identifier for the collective. Default: empty string
  111. """
  112. group_name = _resolve_group_name(group, tag)
  113. tensor = torch.ops._c10d_functional.broadcast(self, src, group_name)
  114. return _maybe_wrap_tensor(tensor)
  115. def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
  116. """
  117. Reduces the tensor data across all machines in such a way that all get
  118. the final result.
  119. The input tensor is left unmodified.
  120. Group can be one of:
  121. List[int]: ranks participating in the collective.
  122. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  123. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  124. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  125. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  126. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  127. that information and perform collective algebraic optimization. Use other forms of input for that.
  128. """
  129. group_name = _resolve_group_name(group, tag)
  130. tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
  131. return _maybe_wrap_tensor(tensor)
  132. def all_gather_tensor(
  133. self: torch.Tensor,
  134. gather_dim: int,
  135. group: RANK_TYPES,
  136. tag: str = "",
  137. ) -> torch.Tensor:
  138. """
  139. Gather tensor data across from all machines and concatenate over ``gather_dim``.
  140. Note that it currently only supports gather_dim = 0.
  141. The input tensor is left unmodified.
  142. Group can be one of:
  143. List[int]: ranks participating in the collective.
  144. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  145. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  146. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  147. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  148. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  149. that information and perform collective algebraic optimization. Use other forms of input for that.
  150. """
  151. if not self.is_contiguous():
  152. raise AssertionError("Tensor must be contiguous for all_gather_tensor")
  153. group_name = _resolve_group_name(group, tag)
  154. group_size = c10d._get_group_size_by_name(group_name)
  155. tensor = torch.ops._c10d_functional.all_gather_into_tensor(
  156. self, group_size, group_name
  157. )
  158. res = _maybe_wrap_tensor(tensor)
  159. # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
  160. if gather_dim != 0:
  161. # torch.cat access the data so we already need to wait here, first do wait
  162. # and then chunk + cat avoid us going through ACT dispatching logic again
  163. if isinstance(res, AsyncCollectiveTensor):
  164. res = res.wait() # type: ignore[attr-defined]
  165. res = _maybe_view_chunk_cat(res, group_size, gather_dim)
  166. return res
  167. def all_gather_tensor_autograd(
  168. self: torch.Tensor,
  169. gather_dim: int,
  170. group: RANK_TYPES,
  171. tag: str = "",
  172. ):
  173. """
  174. Gather tensor data across from all machines and concatenate over ``gather_dim``.
  175. Note that it currently only supports gather_dim = 0.
  176. This function is the same as all_gather_tensor but will propagate the
  177. backwards gradient across workers.
  178. See all_gather_tensor for more details on usage.
  179. """
  180. group_name = _resolve_group_name(group, tag)
  181. group_size = c10d._get_group_size_by_name(group_name)
  182. tensor = torch.ops._c10d_functional_autograd.all_gather_into_tensor(
  183. self, group_size, group_name
  184. )
  185. res = _FromTorchTensor.apply(tensor)
  186. # TODO this should be done inside AsyncCollectiveTensor to delay the wait() call
  187. if gather_dim != 0:
  188. # torch.cat access the data so we already need to wait here, first do wait
  189. # and then chunk + cat avoid us going through ACT dispatching logic again
  190. if isinstance(res, AsyncCollectiveTensor):
  191. res = res.wait() # type: ignore[attr-defined]
  192. res = torch.cat(torch.chunk(res, group_size, dim=0), dim=gather_dim)
  193. return res
  194. def reduce_scatter_tensor(
  195. self: torch.Tensor,
  196. reduceOp: str,
  197. scatter_dim: int,
  198. group: RANK_TYPES,
  199. tag: str = "",
  200. ):
  201. """
  202. Reduces the tensor data across all machines in such a way that all get
  203. the final result, then scatter the results to corresponding ranks.
  204. The input tensor is left unmodified.
  205. Group can be one of:
  206. List[int]: ranks participating in the collective.
  207. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  208. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  209. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  210. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  211. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  212. that information and perform collective algebraic optimization. Use other forms of input for that.
  213. """
  214. group_name = _resolve_group_name(group, tag)
  215. group_size = c10d._get_group_size_by_name(group_name)
  216. if self.size(scatter_dim) % group_size != 0:
  217. raise AssertionError(
  218. f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size})"
  219. )
  220. if scatter_dim != 0:
  221. tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
  222. self = torch.cat(tensor_list)
  223. tensor = torch.ops._c10d_functional.reduce_scatter_tensor(
  224. self,
  225. reduceOp.lower(),
  226. group_size,
  227. group_name, # type: ignore[possibly-undefined]
  228. )
  229. res = _maybe_wrap_tensor(tensor)
  230. return res
  231. def reduce_scatter_tensor_autograd(
  232. self: torch.Tensor,
  233. reduceOp: str,
  234. scatter_dim: int,
  235. group: RANK_TYPES,
  236. tag: str = "",
  237. ):
  238. """
  239. Reduces the tensor data across all machines in such a way that all get
  240. the final result, then scatter the results to corresponding ranks.
  241. This function is the same as reduce_scatter_tensor but will propagate the
  242. backwards gradient across workers.
  243. Currently only the "sum" reduceOp is supported.
  244. See reduce_scatter_tensor for more details on usage.
  245. """
  246. group_name = _resolve_group_name(group, tag)
  247. group_size = c10d._get_group_size_by_name(group_name)
  248. if self.size(scatter_dim) % group_size != 0:
  249. raise AssertionError(
  250. f"input dimension 0 ({self.size(0)} must be a multiple of group_size {group_size}"
  251. )
  252. if scatter_dim != 0:
  253. tensor_list = torch.chunk(self, group_size, dim=scatter_dim)
  254. self = torch.cat(tensor_list)
  255. tensor = torch.ops._c10d_functional_autograd.reduce_scatter_tensor(
  256. self,
  257. reduceOp.lower(),
  258. group_size,
  259. group_name, # type: ignore[possibly-undefined]
  260. )
  261. res = _FromTorchTensor.apply(tensor)
  262. return res
  263. def all_reduce_coalesced(
  264. self: list[torch.Tensor], reduceOp: str, group: RANK_TYPES, tag: str = ""
  265. ) -> list[torch.Tensor]:
  266. """
  267. Reduces a list of tensors across all machines in such a way that all get
  268. the final result.
  269. The all tensors in the input list are left unmodified.
  270. Group can be one of:
  271. List[int]: ranks participating in the collective.
  272. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  273. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  274. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  275. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  276. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  277. that information and perform collective algebraic optimization. Use other forms of input for that.
  278. """
  279. group_name = _resolve_group_name(group, tag)
  280. tensor_list = torch.ops._c10d_functional.all_reduce_coalesced( # type: ignore[attr-defined]
  281. self,
  282. reduceOp.lower(),
  283. group_name,
  284. )
  285. return list(map(_maybe_wrap_tensor, tensor_list))
  286. def all_gather_into_tensor_coalesced(
  287. self: list[torch.Tensor], group: RANK_TYPES, tag: str = ""
  288. ) -> list[torch.Tensor]:
  289. """
  290. Gather a list of tensors across from all machines.
  291. Note that it currently only supports gather_dim = 0.
  292. The input tensor is left unmodified.
  293. Group can be one of:
  294. List[int]: ranks participating in the collective.
  295. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  296. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  297. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  298. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  299. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  300. that information and perform collective algebraic optimization. Use other forms of input for that.
  301. """
  302. group_name = _resolve_group_name(group, tag)
  303. group_size = c10d._get_group_size_by_name(group_name)
  304. tensor_list = torch.ops._c10d_functional.all_gather_into_tensor_coalesced( # type: ignore[attr-defined]
  305. self,
  306. group_size,
  307. group_name,
  308. )
  309. return list(map(_maybe_wrap_tensor, tensor_list))
  310. def reduce_scatter_tensor_coalesced(
  311. inputs: list[torch.Tensor],
  312. reduceOp: str,
  313. scatter_dim: list[int],
  314. group: RANK_TYPES,
  315. tag: str = "",
  316. ) -> list[torch.Tensor]:
  317. """
  318. Reduces a list of tensors across all machines in such a way that all get
  319. the final result, then scatter the results to corresponding ranks.
  320. The input tensors are left unmodified.
  321. Group can be one of:
  322. List[int]: ranks participating in the collective.
  323. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  324. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  325. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  326. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  327. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  328. that information and perform collective algebraic optimization. Use other forms of input for that.
  329. """
  330. group_name = _resolve_group_name(group, tag)
  331. group_size = c10d._get_group_size_by_name(group_name)
  332. if len(scatter_dim) != len(inputs):
  333. raise AssertionError(
  334. f"Length of scatter_dim ({len(scatter_dim)}) must equal length of inputs ({len(inputs)})"
  335. )
  336. for idx, (dim, tensor) in enumerate(zip(scatter_dim, inputs)):
  337. if tensor.size(dim) % group_size != 0:
  338. raise AssertionError(
  339. f"input dimension {dim} ({tensor.size(dim)} must be a multiple of group_size {group_size} for tensor at index {idx}"
  340. )
  341. if dim != 0:
  342. tensor_list = torch.chunk(tensor, group_size, dim=dim)
  343. inputs[idx] = torch.cat(tensor_list)
  344. tensor_list = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced( # type: ignore[attr-defined]
  345. inputs,
  346. reduceOp.lower(),
  347. group_size,
  348. group_name, # type: ignore[possibly-undefined]
  349. )
  350. return list(map(_maybe_wrap_tensor, tensor_list))
  351. # This is a bit unsafe: it checks if the first argument in the schema reports as a non-mutable alias.
  352. # Today, this maps 1:1 with "aten ops that are views".
  353. def _is_view_op(tgt):
  354. if not isinstance(tgt, torch._ops.OpOverload):
  355. raise AssertionError(f"Expected torch._ops.OpOverload, got {type(tgt)}")
  356. # Don't apply the view optimization to any `CompositeImplicitAutograd` ops.
  357. # See issue: https://github.com/pytorch/pytorch/issues/133421
  358. if torch._C._dispatch_has_kernel_for_dispatch_key(
  359. tgt.name(), torch.DispatchKey.CompositeImplicitAutograd
  360. ):
  361. return False
  362. schema = tgt._schema
  363. if len(schema.arguments) > 0:
  364. first_arg = schema.arguments[0]
  365. # check if op is a view
  366. return first_arg.alias_info is not None and not first_arg.alias_info.is_write
  367. def all_to_all_single(
  368. self: torch.Tensor,
  369. output_split_sizes: list[int] | None,
  370. input_split_sizes: list[int] | None,
  371. group: RANK_TYPES,
  372. tag: str = "",
  373. ) -> torch.Tensor:
  374. """
  375. Each process splits input tensor and then scatters the split list
  376. to all processes in a group. Then concatenate the received tensors from all
  377. the processes in the group and return single output tensor.
  378. Group can be one of:
  379. List[int]: ranks participating in the collective.
  380. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  381. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  382. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  383. (DeviceMesh, int): Do a MPMD collective over one dimension of the DeviceMesh
  384. :: N.B. If you pass a PG or a 1D list to perform a MPMD collective, the compiler won't be able to recover
  385. that information and perform collective algebraic optimization. Use other forms of input for that.
  386. """
  387. if output_split_sizes is not None:
  388. if not all(
  389. isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
  390. ):
  391. raise AssertionError(
  392. f"All output_split_sizes must be int or SymInt, got {output_split_sizes}"
  393. )
  394. if input_split_sizes is not None:
  395. if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes):
  396. raise AssertionError(
  397. f"All input_split_sizes must be int or SymInt, got {input_split_sizes}"
  398. )
  399. group_name = _resolve_group_name(group, tag)
  400. group_size = c10d._get_group_size_by_name(group_name)
  401. if output_split_sizes is None or input_split_sizes is None:
  402. if not (output_split_sizes is None and input_split_sizes is None):
  403. raise AssertionError(
  404. "output_split_sizes and input_split_sizes must either be "
  405. "specified together or both set to None"
  406. )
  407. output_split_sizes = [self.shape[0] // group_size] * group_size
  408. input_split_sizes = output_split_sizes
  409. tensor = torch.ops._c10d_functional.all_to_all_single( # type: ignore[attr-defined]
  410. self,
  411. output_split_sizes,
  412. input_split_sizes,
  413. group_name,
  414. )
  415. return _maybe_wrap_tensor(tensor)
  416. def all_to_all_single_autograd(
  417. self: torch.Tensor,
  418. output_split_sizes: list[int] | None,
  419. input_split_sizes: list[int] | None,
  420. group: RANK_TYPES,
  421. tag: str = "",
  422. ) -> torch.Tensor:
  423. """
  424. Same as all_to_all_single but supports autograd.
  425. """
  426. if output_split_sizes is not None:
  427. if not all(
  428. isinstance(size, (int, torch.SymInt)) for size in output_split_sizes
  429. ):
  430. raise AssertionError(
  431. f"All output_split_sizes must be int or SymInt, got {output_split_sizes}"
  432. )
  433. if input_split_sizes is not None:
  434. if not all(isinstance(size, (int, torch.SymInt)) for size in input_split_sizes):
  435. raise AssertionError(
  436. f"All input_split_sizes must be int or SymInt, got {input_split_sizes}"
  437. )
  438. group_name = _resolve_group_name(group, tag)
  439. group_size = c10d._get_group_size_by_name(group_name)
  440. if output_split_sizes is None or input_split_sizes is None:
  441. if not (output_split_sizes is None and input_split_sizes is None):
  442. raise AssertionError(
  443. "output_split_sizes and input_split_sizes must either be "
  444. "specified together or both set to None"
  445. )
  446. output_split_sizes = [self.shape[0] // group_size] * group_size
  447. input_split_sizes = output_split_sizes
  448. tensor = torch.ops._c10d_functional_autograd.all_to_all_single( # type: ignore[attr-defined]
  449. self,
  450. output_split_sizes,
  451. input_split_sizes,
  452. group_name,
  453. )
  454. return _FromTorchTensor.apply(tensor)
  455. # ============================================================================
  456. # Collective Autograd Functions / Custom Ops
  457. # ============================================================================
  458. def wait_tensor_backward(ctx, grad_output: torch.Tensor):
  459. """
  460. Backward for wait_tensor: identity (no-op).
  461. Wait is just a synchronization primitive, so gradient flows through unchanged.
  462. Args:
  463. ctx: Context object
  464. grad_output: Gradient from downstream operations
  465. Returns:
  466. Gradient unchanged (identity)
  467. """
  468. return grad_output
  469. def wait_tensor_setup_context(ctx, inputs, output):
  470. """
  471. Setup context for wait_tensor backward.
  472. Args:
  473. ctx: Context object to save state for backward
  474. inputs: Tuple of (tensor,)
  475. output: Output from forward pass
  476. """
  477. return
# Attach the backward and setup_context functions defined above to the
# ``_c10d_functional::wait_tensor`` op. This runs at module import time.
torch.library.register_autograd(
    "_c10d_functional::wait_tensor",
    wait_tensor_backward,
    setup_context=wait_tensor_setup_context,
)
  483. def all_reduce_backward(ctx, grad_output: torch.Tensor):
  484. """
  485. Backward for all_reduce: all_reduce with same reduce_op.
  486. Forward aggregates tensors, backward aggregates gradients.
  487. Args:
  488. ctx: Context object
  489. grad_output: Gradient from downstream operations
  490. Returns:
  491. Tuple of (grad_input, grad_group_name, grad_reduce_op)
  492. grad_group_name and grad_reduce_op are None (not differentiable)
  493. """
  494. group_name = ctx.group_name
  495. reduce_op = ctx.reduce_op
  496. if reduce_op != "sum":
  497. raise RuntimeError(
  498. f"all_reduce backward only supports 'sum' reduction, got '{reduce_op}'"
  499. )
  500. # Backward does all_reduce with the same reduce_op
  501. output = torch.ops._c10d_functional.all_reduce(
  502. grad_output.contiguous(), reduce_op, group_name
  503. )
  504. return wait_tensor(output), None, None
  505. def all_reduce_setup_context(ctx, inputs, output):
  506. """
  507. Setup context for all_reduce backward.
  508. Args:
  509. ctx: Context object to save state for backward
  510. inputs: Tuple of (input, reduce_op, group_name)
  511. output: Output from forward pass
  512. """
  513. input, reduce_op, group_name = inputs
  514. ctx.group_name = group_name
  515. ctx.reduce_op = reduce_op.lower()
# Attach the backward and setup_context functions defined above to the
# ``_c10d_functional::all_reduce`` op. This runs at module import time.
torch.library.register_autograd(
    "_c10d_functional::all_reduce",
    all_reduce_backward,
    setup_context=all_reduce_setup_context,
)
  521. def all_gather_into_tensor_backward(ctx, grad_output: torch.Tensor):
  522. """
  523. Backward for all_gather_into_tensor: reduce_scatter with sum.
  524. Forward gathers tensors from all ranks, backward scatters gradients back
  525. with sum reduction.
  526. Args:
  527. ctx: Context object with group_name and group_size
  528. grad_output: Gradient from downstream operations
  529. Returns:
  530. Tuple of (grad_input, grad_group_size, grad_group_name)
  531. grad_group_size and grad_group_name are None (not differentiable)
  532. """
  533. group_name = ctx.group_name
  534. group_size = ctx.group_size
  535. # Backward is reduce_scatter with sum
  536. output = torch.ops._c10d_functional.reduce_scatter_tensor(
  537. grad_output.contiguous(),
  538. "sum",
  539. group_size,
  540. group_name,
  541. )
  542. return wait_tensor(output), None, None
  543. def all_gather_into_tensor_setup_context(ctx, inputs, output):
  544. """
  545. Setup context for all_gather_into_tensor backward.
  546. Args:
  547. ctx: Context object to save state for backward
  548. inputs: Tuple of (input, group_size, group_name)
  549. output: Output from forward pass
  550. """
  551. input, group_size, group_name = inputs
  552. ctx.group_name = group_name
  553. ctx.group_size = group_size
# Register the backward formula and its setup_context hook for the
# functional all_gather_into_tensor op.
torch.library.register_autograd(
    "_c10d_functional::all_gather_into_tensor",
    all_gather_into_tensor_backward,
    setup_context=all_gather_into_tensor_setup_context,
)
  559. def reduce_scatter_tensor_backward(ctx, grad_output: torch.Tensor):
  560. """
  561. Backward for reduce_scatter_tensor: all_gather.
  562. Forward reduces and scatters tensors to ranks, backward gathers gradients
  563. from all ranks.
  564. Args:
  565. ctx: Context object with group_name, group_size, and reduce_op
  566. grad_output: Gradient from downstream operations
  567. Returns:
  568. Tuple of (grad_input, grad_reduce_op, grad_group_size, grad_group_name)
  569. grad_reduce_op, grad_group_size, grad_group_name are None (not differentiable)
  570. """
  571. group_name = ctx.group_name
  572. group_size = ctx.group_size
  573. reduce_op = ctx.reduce_op
  574. # Lazy validation: check reduce_op only when backward is called
  575. if reduce_op != "sum":
  576. raise RuntimeError(
  577. f"reduce_scatter_tensor backward only supports 'sum' reduction, got '{reduce_op}'"
  578. )
  579. # Backward is all_gather
  580. output = torch.ops._c10d_functional.all_gather_into_tensor(
  581. grad_output.contiguous(),
  582. group_size,
  583. group_name,
  584. )
  585. return wait_tensor(output), None, None, None
  586. def reduce_scatter_tensor_setup_context(ctx, inputs, output):
  587. """
  588. Setup context for reduce_scatter_tensor backward.
  589. Args:
  590. ctx: Context object to save state for backward
  591. inputs: Tuple of (input, reduce_op, group_size, group_name)
  592. output: Output from forward pass
  593. """
  594. input, reduce_op, group_size, group_name = inputs
  595. ctx.group_name = group_name
  596. ctx.group_size = group_size
  597. ctx.reduce_op = reduce_op.lower()
# Register the backward formula and its setup_context hook for the
# functional reduce_scatter_tensor op.
torch.library.register_autograd(
    "_c10d_functional::reduce_scatter_tensor",
    reduce_scatter_tensor_backward,
    setup_context=reduce_scatter_tensor_setup_context,
)
  603. def all_to_all_single_backward(ctx, grad_output: torch.Tensor):
  604. """
  605. Backward for all_to_all_single: all_to_all with reversed split sizes.
  606. Forward does all-to-all with specified split sizes, backward reverses them.
  607. Args:
  608. ctx: Context object with group_name, output_split_sizes, and input_split_sizes
  609. grad_output: Gradient from downstream operations
  610. Returns:
  611. Tuple of (grad_input, grad_output_split_sizes, grad_input_split_sizes, grad_group_name)
  612. All except grad_input are None (not differentiable)
  613. """
  614. group_name = ctx.group_name
  615. output_split_sizes = ctx.output_split_sizes
  616. input_split_sizes = ctx.input_split_sizes
  617. # Backward is all_to_all with reversed split sizes
  618. output = torch.ops._c10d_functional.all_to_all_single(
  619. grad_output.contiguous(),
  620. input_split_sizes, # Reversed
  621. output_split_sizes, # Reversed
  622. group_name,
  623. )
  624. return wait_tensor(output), None, None, None
  625. def all_to_all_single_setup_context(ctx, inputs, output):
  626. """
  627. Setup context for all_to_all_single backward.
  628. Args:
  629. ctx: Context object to save state for backward
  630. inputs: Tuple of (input, output_split_sizes, input_split_sizes, group_name)
  631. output: Output from forward pass
  632. """
  633. input, output_split_sizes, input_split_sizes, group_name = inputs
  634. ctx.group_name = group_name
  635. ctx.output_split_sizes = output_split_sizes
  636. ctx.input_split_sizes = input_split_sizes
# Register the backward formula and its setup_context hook for the
# functional all_to_all_single op.
torch.library.register_autograd(
    "_c10d_functional::all_to_all_single",
    all_to_all_single_backward,
    setup_context=all_to_all_single_setup_context,
)
  642. def all_reduce_coalesced_backward(ctx, grad_outputs: list[torch.Tensor]):
  643. """
  644. Backward for all_reduce_coalesced: all_reduce each gradient.
  645. Forward aggregates tensors, backward aggregates gradients.
  646. Args:
  647. ctx: Context object with group_name and reduce_op
  648. grad_outputs: Gradients from downstream operations (one per input tensor)
  649. Returns:
  650. Tuple of (grad_inputs..., grad_reduce_op, grad_group_name)
  651. grad_reduce_op and grad_group_name are None (not differentiable)
  652. """
  653. group_name = ctx.group_name
  654. reduce_op = ctx.reduce_op
  655. if reduce_op != "sum":
  656. raise RuntimeError(
  657. f"all_reduce_coalesced backward only supports 'sum' reduction, got '{reduce_op}'"
  658. )
  659. # Backward does all_reduce on list of gradients
  660. grad_inputs = torch.ops._c10d_functional.all_reduce_coalesced(
  661. [grad_output.contiguous() for grad_output in grad_outputs],
  662. reduce_op,
  663. group_name,
  664. )
  665. return (list(map(wait_tensor, grad_inputs)), None, None)
  666. def all_reduce_coalesced_setup_context(ctx, inputs, output):
  667. """
  668. Setup context for all_reduce_coalesced backward.
  669. Args:
  670. ctx: Context object to save state for backward
  671. inputs: Tuple of (tensor_list, reduce_op, group_name)
  672. output: Output from forward pass
  673. """
  674. tensor_list, reduce_op, group_name = inputs
  675. ctx.group_name = group_name
  676. ctx.reduce_op = reduce_op.lower()
# Register the backward formula and its setup_context hook for the
# functional all_reduce_coalesced op.
torch.library.register_autograd(
    "_c10d_functional::all_reduce_coalesced",
    all_reduce_coalesced_backward,
    setup_context=all_reduce_coalesced_setup_context,
)
  682. def all_gather_into_tensor_coalesced_backward(ctx, grad_outputs: list[torch.Tensor]):
  683. """
  684. Backward for all_gather_into_tensor_coalesced: reduce_scatter each gradient.
  685. Forward gathers tensors from all ranks, backward scatters gradients back
  686. with sum reduction.
  687. Args:
  688. ctx: Context object with group_name and group_size
  689. grad_outputs: Gradients from downstream operations (one per input tensor)
  690. Returns:
  691. Tuple of (grad_inputs..., grad_group_size, grad_group_name)
  692. grad_group_size and grad_group_name are None (not differentiable)
  693. """
  694. group_name = ctx.group_name
  695. group_size = ctx.group_size
  696. # Backward does reduce_scatter on list of gradients
  697. grad_inputs = torch.ops._c10d_functional.reduce_scatter_tensor_coalesced(
  698. [grad_output.contiguous() for grad_output in grad_outputs],
  699. "sum",
  700. group_size,
  701. group_name,
  702. )
  703. return (list(map(wait_tensor, grad_inputs)), None, None)
  704. def all_gather_into_tensor_coalesced_setup_context(ctx, inputs, output):
  705. """
  706. Setup context for all_gather_into_tensor_coalesced backward.
  707. Args:
  708. ctx: Context object to save state for backward
  709. inputs: Tuple of (tensor_list, group_size, group_name)
  710. output: Output from forward pass
  711. """
  712. tensor_list, group_size, group_name = inputs
  713. ctx.group_name = group_name
  714. ctx.group_size = group_size
# Register the backward formula and its setup_context hook for the
# functional all_gather_into_tensor_coalesced op.
torch.library.register_autograd(
    "_c10d_functional::all_gather_into_tensor_coalesced",
    all_gather_into_tensor_coalesced_backward,
    setup_context=all_gather_into_tensor_coalesced_setup_context,
)
  720. def reduce_scatter_tensor_coalesced_backward(ctx, grad_outputs: list[torch.Tensor]):
  721. """
  722. Backward for reduce_scatter_tensor_coalesced: all_gather each gradient.
  723. Forward reduces and scatters tensors to ranks, backward gathers gradients
  724. from all ranks.
  725. Args:
  726. ctx: Context object with group_name, group_size, and reduce_op
  727. grad_outputs: Gradients from downstream operations (one per input tensor)
  728. Returns:
  729. Tuple of (grad_inputs..., grad_reduce_op, grad_group_size, grad_group_name)
  730. grad_reduce_op, grad_group_size, grad_group_name are None (not differentiable)
  731. """
  732. group_name = ctx.group_name
  733. group_size = ctx.group_size
  734. reduce_op = ctx.reduce_op
  735. # Lazy validation: check reduce_op only when backward is called
  736. if reduce_op != "sum":
  737. raise RuntimeError(
  738. f"reduce_scatter_tensor_coalesced backward only supports 'sum' reduction, got '{reduce_op}'"
  739. )
  740. # Backward does all_gather on list of gradients
  741. grad_inputs = torch.ops._c10d_functional.all_gather_into_tensor_coalesced(
  742. [grad_output.contiguous() for grad_output in grad_outputs],
  743. group_size,
  744. group_name,
  745. )
  746. return (list(map(wait_tensor, grad_inputs)), None, None, None)
  747. def reduce_scatter_tensor_coalesced_setup_context(ctx, inputs, output):
  748. """
  749. Setup context for reduce_scatter_tensor_coalesced backward.
  750. Args:
  751. ctx: Context object to save state for backward
  752. inputs: Tuple of (tensor_list, reduce_op, group_size, group_name)
  753. output: Output from forward pass
  754. """
  755. tensor_list, reduce_op, group_size, group_name = inputs
  756. ctx.group_name = group_name
  757. ctx.group_size = group_size
  758. ctx.reduce_op = reduce_op.lower()
# Register the backward formula and its setup_context hook for the
# functional reduce_scatter_tensor_coalesced op.
torch.library.register_autograd(
    "_c10d_functional::reduce_scatter_tensor_coalesced",
    reduce_scatter_tensor_coalesced_backward,
    setup_context=reduce_scatter_tensor_coalesced_setup_context,
)
  764. def permute_tensor(
  765. self: torch.Tensor,
  766. src_dst: list[int],
  767. group: RANK_TYPES,
  768. tag: str = "",
  769. ) -> torch.Tensor:
  770. """
  771. Permutes the elements of the tensor according to the given source/destination pairs. `src_dst` should
  772. be defined such that src_dst[m] == n means m sends to n.
  773. Group can be one of:
  774. List[int]: ranks participating in the collective.
  775. List[List[int]]: 2D mesh of ranks taking part of this collective in MPMD.
  776. ProcessGroup: Will perform a collective using the ranks and tag of the PG.
  777. DeviceMesh: Do a SPMD collective over all ranks of the mesh
  778. (DeviceMesh, int): Do a MPMD collective over one
  779. """
  780. t, rankset, group_size = _expand_group(group, tag)
  781. local_pg = c10d._find_or_create_pg_by_ranks_and_tag(t, rankset, group_size)
  782. output_split_sizes = [0] * group_size
  783. input_split_sizes = [0] * group_size
  784. for src, dst in enumerate(src_dst):
  785. if src == dist.get_rank(local_pg):
  786. input_split_sizes[dst] = self.numel()
  787. if dst == dist.get_rank(local_pg):
  788. output_split_sizes[src] = self.numel()
  789. return all_to_all_single(self, output_split_sizes, input_split_sizes, group, tag)
  790. class AsyncCollectiveTensor(torch.Tensor):
  791. r"""
  792. A Tensor wrapper subclass that is used to trigger a call to wait
  793. prior to first use of the underlying tensor.
  794. Use it inside functional collective pytorch wrappers like the following:
  795. def functional_collective(self, group, tag):
  796. tag, rankset, group_size = _expand_group(group, tag)
  797. tensor = torch.ops.c10d_functional.{collective}(self, tag, rankset, group_size)
  798. return _maybe_wrap_tensor(tensor)
  799. """
  800. elem: torch.Tensor
  801. completed: bool
  802. __slots__ = ["elem", "completed"]
  803. @staticmethod
  804. def __new__(cls, elem: torch.Tensor):
  805. r = torch.Tensor._make_wrapper_subclass(
  806. cls,
  807. elem.size(),
  808. strides=elem.stride(),
  809. storage_offset=elem.storage_offset(),
  810. dtype=elem.dtype,
  811. layout=elem.layout,
  812. device=elem.device,
  813. requires_grad=elem.requires_grad,
  814. )
  815. r.elem = elem
  816. r.completed = False
  817. return r
  818. def __tensor_flatten__(self):
  819. return ["elem"], None
  820. def tolist(self):
  821. return self.trigger_wait().tolist()
  822. @staticmethod
  823. def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
  824. if meta is not None:
  825. raise AssertionError(
  826. "meta must be None for AsyncCollectiveTensor unflatten"
  827. )
  828. elem = inner_tensors["elem"]
  829. return AsyncCollectiveTensor(elem)
  830. def __coerce_same_metadata_as_tangent__(
  831. self, expected_metadata: Any, expected_type: type | None = None
  832. ):
  833. if expected_type is not torch.Tensor:
  834. return None
  835. return self.trigger_wait()
  836. def __repr__(self) -> str: # type: ignore[override]
  837. return f"AsyncCollectiveTensor({self.trigger_wait()})"
  838. def trigger_wait(self):
  839. if not self.completed:
  840. out = wait_tensor(self.elem)
  841. self.completed = True
  842. return out
  843. else:
  844. return self.elem
  845. def wait(self) -> torch.Tensor:
  846. return wait_tensor(self.elem)
  847. def _get_acs_underlying_tensor(self):
  848. """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
  849. return self.elem
  850. @classmethod
  851. def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override]
  852. if func is torch.ops.aten.view.default:
  853. # Fast handle aten.view as a lot of view related op goes to aten.view
  854. # eventually, this avoids pytree slowdown
  855. res = func(args[0].elem, args[1])
  856. wrapper_res = AsyncCollectiveTensor(res)
  857. return wrapper_res
  858. is_view_op = _is_view_op(func)
  859. def unwrap(e: AsyncCollectiveTensor):
  860. # wait_tensor is idepotent and will do stream sync only once
  861. if not is_view_op:
  862. return e.trigger_wait()
  863. return e.elem
  864. def wrap(e: torch.Tensor):
  865. # wait_tensor is idepotent and will do stream sync only once
  866. if isinstance(e, AsyncCollectiveTensor):
  867. raise AssertionError(
  868. "Cannot wrap an AsyncCollectiveTensor inside another AsyncCollectiveTensor"
  869. )
  870. res = AsyncCollectiveTensor(e)
  871. return res
  872. unwrapped_args = tree_map_only(AsyncCollectiveTensor, unwrap, args)
  873. unwrapped_kwargs = tree_map_only(AsyncCollectiveTensor, unwrap, kwargs)
  874. # we don't wrap the result as it doesn't need to be waited on.
  875. out = func(*unwrapped_args, **unwrapped_kwargs)
  876. # View ops dont require a sync, so we should re-wrap the outputs.
  877. if is_view_op:
  878. out = tree_map_only(torch.Tensor, wrap, out)
  879. return out
  880. def numpy(self): # type: ignore[override]
  881. return self.wait().numpy()
"""
Utils and infrastructure for tracing support
"""
def _expand_group(group: RANK_TYPES, tag: str = "") -> tuple[str, list[int], int]:
    """
    _expand_group desugars the different RANK_TYPES types into a canonical format that is traceable.

    By having this be part of the explicit eager codepath, we avoid having to specialize behavior inside
    torchdynamo and can still interoperate with processgroup objects or other untraceable forms.

    Args:
        group: A list of ranks, a list of equally sized rank lists, a
            ProcessGroup, a 1D DeviceMesh, or a (DeviceMesh, int) tuple.
        tag: Optional tag. For ProcessGroup / DeviceMesh inputs an empty tag
            is replaced by the tag of the resolved process group.

    Returns:
        Tuple of (tag, rankset, group_size). For a nested list, ``rankset`` is
        the concatenation of all sub-lists and ``group_size`` is the length of
        one sub-list.

    Raises:
        ValueError: If nested lists differ in length, a tuple is not
            (DeviceMesh, int), or the group type is not supported.
        AssertionError: If a bare DeviceMesh has more than one dimension.
    """
    # had to define this hack _inside_ expand_group to avoid
    # graph_break [('torch.* op returned non-Tensor int
    # caused by 'cast_*` functions being treated as 'torch.*' ops (iiuc)
    if TYPE_CHECKING:

        def cast_listlistint(x):
            return cast(list[list[int]], x)

        def cast_listint(x):
            return cast(list[int], x)

    else:
        # fake cast op for use at runtime since dynamo doesn't support real cast
        # also, dynamo didn't like encountering 'typing' objects ()
        # NotImplementedError: argument of type: <class 'typing._GenericAlias'>
        def cast_listlistint(x):
            return x

        def cast_listint(x):
            return x

    rankset: list[int]
    if isinstance(group, list):
        # NOTE(review): an empty list raises IndexError on group[0] -- confirm
        # whether callers can ever pass one.
        if isinstance(group[0], list):
            nested_list = cast_listlistint(group)
            rankset = []
            # -1 marks "no sub-list seen yet"; afterwards it holds the size
            # every sub-list must match.
            group_size = -1
            for rs in nested_list:
                rankset.extend(rs)
                if group_size != -1 and group_size != len(rs):
                    raise ValueError(
                        f"group sizes must be identical found {group_size} and {len(rs)}"
                    )
                group_size = len(rs)
        else:
            rankset = cast_listint(group)
            group_size = len(rankset)
    elif isinstance(group, dist.ProcessGroup):
        rankset = dist.get_process_group_ranks(group)
        group_size = len(rankset)
        # An explicitly passed tag wins over the group's own tag.
        tag = tag or c10d._get_group_tag(group)
    elif isinstance(group, DeviceMesh):
        if group.ndim != 1:
            raise AssertionError(
                "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
            )
        pg = group.get_group()
        rankset = dist.get_process_group_ranks(pg)
        group_size = len(rankset)
        tag = tag or c10d._get_group_tag(pg)
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            # Use the process group of the selected mesh dimension.
            pg = dmesh.get_group(dim)
            rankset = dist.get_process_group_ranks(pg)
            group_size = len(rankset)
            tag = tag or c10d._get_group_tag(pg)
        else:
            raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
    else:
        raise ValueError(
            "Invalid type for group, must be one of List, Processgroup, DeviceMesh or (DeviceMesh, int)."
        )

    return (tag, rankset, group_size)
def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> c10d.GroupName:
    """
    Given group in RANK_TYPES, return the group name.

    Args:
        group: A ProcessGroup, a group name string, a 1D DeviceMesh, a
            (DeviceMesh, int) tuple, or a (deprecated) list of ranks.
        tag: Only used together with a list of ranks.

    Returns:
        The name of the process group that ``group`` refers to.

    Raises:
        AssertionError: If a bare DeviceMesh has more than one dimension.
        ValueError: If a tuple is not (DeviceMesh, int) or the type of
            ``group`` is not supported.
    """
    # `tag` will be deprecated. See details in:
    # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208
    if isinstance(group, dist.ProcessGroup):
        return group.group_name
    elif isinstance(group, str):
        # In some cases Dynamo doesn't like tracing through NewType constructors
        # - so use a cast instead (the actual newtype representation is
        # literally the underlying type so this is fine). I haven't been able to
        # reproduce it in isolation (see T247631668).
        # pyrefly: ignore [redundant-cast]
        return cast(c10d.GroupName, group)  # c10d.GroupName(group)
    elif isinstance(group, DeviceMesh):
        if group.ndim != 1:
            raise AssertionError(
                "Only 1D mesh is supported, pass in (DeviceMesh, int) together if mesh > 1D"
            )
        return group._dim_group_names[0]
    elif isinstance(group, tuple):
        if (
            len(group) == 2
            and isinstance(group[0], DeviceMesh)
            and isinstance(group[1], int)
        ):
            dmesh = group[0]
            dim = group[1]
            # Name of the process group for the selected mesh dimension.
            return dmesh._dim_group_names[dim]
        else:
            raise ValueError("Invalid tuple for group must be (DeviceMesh, int)")
    elif isinstance(group, list):
        # The deprecation warning is skipped while Dynamo is compiling.
        if not is_torchdynamo_compiling():
            warnings.warn(
                "The combination of ranks + tag as process group "
                "identifier has been deprecated. Please switch to "
                "using ProcessGroup, DeviceMesh, or group name instead.",
                FutureWarning,
                stacklevel=3,
            )
        # pyrefly: ignore [redundant-cast]
        return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
    else:
        raise ValueError(f"Unsupported group type: {type(group)}, {group}")
  1000. class _FromTorchTensor(torch.autograd.Function):
  1001. """
  1002. _FromTorchTensor allows autograd to propagate from a normal Tensor to an
  1003. AsyncCollectiveTensor.
  1004. """
  1005. @staticmethod
  1006. def forward( # type: ignore[override]
  1007. ctx, # pyre-ignore[2]: Parameter must be annotated.
  1008. input: torch.Tensor,
  1009. ) -> torch.Tensor:
  1010. return _maybe_wrap_tensor(input)
  1011. @staticmethod
  1012. def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: # type: ignore[override]
  1013. return grad_output
  1014. @torch.library.custom_op(
  1015. "_c10d_functional::_wrap_tensor_autograd",
  1016. mutates_args=(),
  1017. schema="(Tensor input) -> Tensor",
  1018. )
  1019. def _wrap_tensor_autograd(input: torch.Tensor) -> torch.Tensor:
  1020. """
  1021. Custom op that allows autograd to propagate
  1022. from a normal Tensor to an AsyncCollectiveTensor.
  1023. This is the low-level implementation. Users should call _maybe_wrap_tensor directly.
  1024. Args:
  1025. input: Input tensor to wrap in AsyncCollectiveTensor
  1026. Returns:
  1027. AsyncCollectiveTensor wrapping the input (or wait_tensor result if tracing)
  1028. """
  1029. return AsyncCollectiveTensor(input)
  1030. @_wrap_tensor_autograd.register_fake
  1031. def _(input: torch.Tensor) -> torch.Tensor:
  1032. """
  1033. Meta kernel for _wrap_tensor_autograd.
  1034. """
  1035. return torch.empty_like(input)
  1036. def _wrap_tensor_autograd_backward(ctx, grad_output: torch.Tensor):
  1037. """
  1038. Backward for _wrap_tensor_autograd: identity (no-op).
  1039. The wrapping is just for async optimization, gradients flow through unchanged.
  1040. Args:
  1041. ctx: Context object (unused)
  1042. grad_output: Gradient from downstream operations
  1043. Returns:
  1044. Gradient unchanged (identity)
  1045. """
  1046. return grad_output
  1047. def _wrap_tensor_autograd_setup_context(ctx, inputs, output):
  1048. """
  1049. Setup context for _wrap_tensor_autograd backward.
  1050. Args:
  1051. ctx: Context object to save state for backward (nothing to save)
  1052. inputs: Tuple of (input,)
  1053. output: Output from forward pass
  1054. """
  1055. return
# Hook the identity backward and its (empty) setup_context into the
# _wrap_tensor_autograd custom op.
_wrap_tensor_autograd.register_autograd(
    _wrap_tensor_autograd_backward,
    setup_context=_wrap_tensor_autograd_setup_context,
)
  1060. def _are_we_tracing() -> bool:
  1061. if is_torchdynamo_compiling():
  1062. return True
  1063. # If fake mode is turned on, we are almost definitely compiling/tracing.
  1064. if torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is not None:
  1065. return True
  1066. # See Note [enable_python_dispatcher in dynamo]
  1067. if torch._C._dispatch_tls_is_dispatch_key_included(
  1068. torch._C.DispatchKey.PythonDispatcher
  1069. ):
  1070. return True
  1071. return get_proxy_mode() is not None
  1072. def _maybe_wrap_tensor(self) -> torch.Tensor:
  1073. if _are_we_tracing():
  1074. return wait_tensor(self)
  1075. return _wrap_tensor_autograd(self)
  1076. @contextlib.contextmanager
  1077. def allow_inflight_collective_as_graph_input_ctx(value: bool = True):
  1078. """
  1079. Context manager to temporarily set whether inflight collectives are allowed as torch.compile graph inputs.
  1080. Common use case is when the collective is issued in eager (with `async_op=True`) but waited in compiled region:
  1081. ```
  1082. def all_reduce_eager(x):
  1083. y = x * x
  1084. req = dist.all_reduce(y, op=dist.ReduceOp.SUM, async_op=True)
  1085. return y
  1086. @torch.compile(fullgraph=True)
  1087. def all_reduce_wait_compiled(y):
  1088. torch.ops.c10d_functional.wait_tensor(y)
  1089. return y * y
  1090. x = torch.ones(1280, 1280, device="cuda") + self.rank
  1091. # the context manager ensures that `wait_tensor(y)` will wait on the correct work object
  1092. with allow_inflight_collective_as_graph_input_ctx():
  1093. y = all_reduce_eager(x)
  1094. z = all_reduce_wait_compiled(y)
  1095. ```
  1096. With this context manager, when a collective is called, under the hood the work object of the collective
  1097. will be registered in the work registry, and the wait_tensor() in compiled region called on
  1098. the output tensor of the collective will wait on the correct work object.
  1099. """
  1100. previous = torch._C._distributed_c10d._allow_inflight_collective_as_graph_input()
  1101. try:
  1102. torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(value)
  1103. yield
  1104. finally:
  1105. torch._C._distributed_c10d._set_allow_inflight_collective_as_graph_input(
  1106. previous
  1107. )
  1108. def _make_all_gather_out_tensor(input, group_size):
  1109. out_size = list(input.size())
  1110. if len(out_size) == 0:
  1111. out_size.append(group_size)
  1112. else:
  1113. out_size[0] *= group_size
  1114. out_tensor = input.new_empty(out_size)
  1115. return out_tensor
  1116. def _all_gather_into_tensor_coalesced_meta(self, tag, rankset, group_size):
  1117. return [_make_all_gather_out_tensor(t, group_size) for t in self]
  1118. # We now register meta kernels to deal with tracing
  1119. def _broadcast_meta(self, *args):
  1120. return torch.empty_like(self)
  1121. def _all_reduce_meta(self, *args):
  1122. return torch.empty_like(self, memory_format=torch.contiguous_format)
  1123. def _wait_tensor_meta(self, *args):
  1124. return torch.empty_like(self)
  1125. def _all_gather_into_tensor_meta(shard, tag, rankset, group_size):
  1126. return _make_all_gather_out_tensor(shard, group_size)
  1127. def _reduce_scatter_tensor_meta(input, reduce_op, tag, rankset, group_size):
  1128. out_size = list(input.size())
  1129. out_size[0] //= group_size
  1130. return input.new_empty(out_size)
  1131. def _all_reduce_coalesced_meta(self, *args):
  1132. return [torch.empty_like(t) for t in self]
  1133. def _all_reduce__meta(inp, *args):
  1134. return inp
  1135. def _broadcast__meta(inp, *args):
  1136. return inp
  1137. def _all_reduce_coalesced__meta(inputs, *args):
  1138. return inputs
  1139. def _reduce_scatter_tensor_coalesced_meta(inputs, reduceOp, tag, rankset, group_size):
  1140. def mk_out_tensor(input):
  1141. out_size = list(input.size())
  1142. out_size[0] //= group_size
  1143. out_tensor = input.new_empty(out_size)
  1144. return out_tensor
  1145. return [mk_out_tensor(t) for t in inputs]
  1146. # NB: We often say all_to_all has dynamic output size, but this is not
  1147. # technically true: instead, what typically happens is you manually
  1148. # communicate the output_split_sizes ahead of time (which is dynamic),
  1149. # but then you pass those sizes explicitly, and the all to all itself
  1150. # isn't dynamic, it just follows the specified output splits
  1151. def _all_to_all_single_meta(
  1152. input, output_split_sizes, input_split_sizes, *args, **kwargs
  1153. ):
  1154. if output_split_sizes is None:
  1155. return input.new_empty(input.size())
  1156. else:
  1157. for s in output_split_sizes:
  1158. torch._check(s >= 0)
  1159. out_size = list(input.size())
  1160. out_size[0] = sum(output_split_sizes)
  1161. return input.new_empty(out_size)
  1162. def _all_gather_into_tensor_out_native_meta(input, group_size, group_name, *, out):
  1163. return _make_all_gather_out_tensor(input, group_size)
  1164. def _all_gather_into_tensor_native_meta(input, group_size, group_name):
  1165. return _make_all_gather_out_tensor(input, group_size)
  1166. def _all_gather_into_tensor_coalesced_native_meta(inputs, group_size, group_name):
  1167. return [
  1168. _all_gather_into_tensor_native_meta(input, group_size, group_name)
  1169. for input in inputs
  1170. ]
  1171. def _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name):
  1172. shape = list(inp.size())
  1173. shape[0] //= group_size
  1174. return inp.new_empty(shape)
  1175. def _reduce_scatter_tensor_out_native_meta(
  1176. inp, reduce_op, group_size, group_name, *, out
  1177. ):
  1178. shape = list(inp.size())
  1179. shape[0] //= group_size
  1180. return inp.new_empty(shape)
  1181. def _reduce_scatter_tensor_coalesced_native_meta(
  1182. inputs, reduce_op, group_size, group_name
  1183. ):
  1184. return [
  1185. _reduce_scatter_tensor_native_meta(inp, reduce_op, group_size, group_name)
  1186. for inp in inputs
  1187. ]
# Library MUST be defined at module scope or it doesn't work
lib_impl = torch.library.Library("_c10d_functional", "IMPL")
# Register the shape-only kernels defined above under the "Meta" dispatch key
# for every _c10d_functional op.
lib_impl.impl("all_reduce", _all_reduce_meta, "Meta")
lib_impl.impl("all_reduce_", _all_reduce__meta, "Meta")
lib_impl.impl("all_reduce_coalesced", _all_reduce_coalesced_meta, "Meta")
lib_impl.impl("all_reduce_coalesced_", _all_reduce_coalesced__meta, "Meta")
lib_impl.impl("wait_tensor", _wait_tensor_meta, "Meta")
lib_impl.impl(
    "all_gather_into_tensor_out", _all_gather_into_tensor_out_native_meta, "Meta"
)
lib_impl.impl("all_gather_into_tensor", _all_gather_into_tensor_native_meta, "Meta")
lib_impl.impl(
    "all_gather_into_tensor_coalesced",
    _all_gather_into_tensor_coalesced_native_meta,
    "Meta",
)
lib_impl.impl("reduce_scatter_tensor", _reduce_scatter_tensor_native_meta, "Meta")
lib_impl.impl(
    "reduce_scatter_tensor_out", _reduce_scatter_tensor_out_native_meta, "Meta"
)
lib_impl.impl(
    "reduce_scatter_tensor_coalesced",
    _reduce_scatter_tensor_coalesced_native_meta,
    "Meta",
)
lib_impl.impl("all_to_all_single", _all_to_all_single_meta, "Meta")
lib_impl.impl("broadcast", _broadcast_meta, "Meta")
lib_impl.impl("broadcast_", _broadcast__meta, "Meta")
# Mark these ops as having side effects so that they won't be removed by DCE
# (both the overload and the overload packet are marked).
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor.default)  # type: ignore[has-type]
torch.fx.node.has_side_effect(torch.ops._c10d_functional.wait_tensor)  # type: ignore[has-type]
# Register legacy ops for backward compatibility
# TODO(yifu): remove these in functional collective beta release
legacy_lib = torch.library.Library("c10d_functional", "DEF")
legacy_lib_impl = torch.library.Library("c10d_functional", "IMPL")
# Schemas of the legacy (tag / ranks / group_size based) ops.
ops_defs = [
    "broadcast(Tensor self, int src, str tag, int[] ranks, int group_size) -> Tensor",
    "all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
    "all_reduce_coalesced(Tensor[] self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
    "wait_tensor(Tensor self) -> Tensor",
    "all_gather_into_tensor(Tensor shard, str tag, int[] ranks, int group_size) -> Tensor",
    "all_gather_into_tensor_coalesced(Tensor[] input, str tag, int[] ranks, int group_size) -> Tensor[]",
    "reduce_scatter_tensor(Tensor input, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor",
    "reduce_scatter_tensor_coalesced(Tensor[] inputs, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor[]",
    "all_to_all_single(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor",  # noqa: B950
]

my_module = sys.modules[__name__]
for op_def in ops_defs:
    # The op name is everything before the opening parenthesis of the schema.
    op_name = op_def[0 : op_def.index("(")]
    # The implementation is looked up on fun_col_impl as "_<op_name>".
    backend_impl = getattr(fun_col_impl, f"_{op_name}")
    legacy_lib.define(op_def, tags=torch.Tag.pt2_compliant_tag)
    legacy_lib_impl.impl(op_name, backend_impl, "CompositeImplicitAutograd")
"""
Dynamo Remappings allow seamless translation from non-functional collectives of supportable form into
functional collective calls followed by inplace copy ops, allowing them to be traced into a functional graph.
We implement this by writing a decomposition and teaching dynamo how to associate it to a corresponding op via
the mapping dict below.
These schemas intentionally match torch.distributed.distributed_c10d.* ops that we are trying to remap from
"""
  1247. def all_gather_tensor_inplace(
  1248. output_tensor: torch.Tensor,
  1249. input_tensor: torch.Tensor,
  1250. group=None, # TODO add a type,
  1251. async_op: bool = False,
  1252. tag: str = "",
  1253. gather_dim: int = 0,
  1254. ):
  1255. if async_op:
  1256. raise AssertionError(
  1257. "Can't remap async version of inplace op to functional collective"
  1258. )
  1259. group = group or dist.group.WORLD
  1260. if group is None:
  1261. raise AssertionError("group cannot be None")
  1262. return output_tensor.copy_(all_gather_tensor(input_tensor, gather_dim, group, tag))
  1263. def reduce_scatter_tensor_inplace(
  1264. output: torch.Tensor,
  1265. input: torch.Tensor,
  1266. op: str = "sum", # TODO type is actually c10d ReduceOp. is this ok?
  1267. group=None, # TODO add a type
  1268. async_op: bool = False,
  1269. scatter_dim: int = 0,
  1270. tag: str = "",
  1271. ):
  1272. if async_op:
  1273. raise AssertionError(
  1274. "Can't remap async version of inplace op to functional collective"
  1275. )
  1276. group = group or dist.group.WORLD
  1277. if group is None:
  1278. raise AssertionError("group cannot be None")
  1279. return output.copy_(reduce_scatter_tensor(input, op, scatter_dim, group, tag))
  1280. REDUCE_OP_TO_STR = {
  1281. dist.ReduceOp.SUM: "sum",
  1282. dist.ReduceOp.AVG: "avg",
  1283. dist.ReduceOp.PRODUCT: "product",
  1284. dist.ReduceOp.MIN: "min",
  1285. dist.ReduceOp.MAX: "max",
  1286. dist.ReduceOp.BAND: "band",
  1287. dist.ReduceOp.BOR: "bor",
  1288. dist.ReduceOp.BXOR: "bxor",
  1289. }
  1290. def all_reduce_inplace(
  1291. tensor: torch.Tensor,
  1292. op: str = "sum",
  1293. group=None,
  1294. async_op: bool = False,
  1295. tag: str = "",
  1296. ):
  1297. if async_op:
  1298. raise AssertionError(
  1299. "Can't remap async version of inplace op to functional collective"
  1300. )
  1301. group = group or dist.group.WORLD
  1302. if group is None:
  1303. raise AssertionError("group cannot be None")
  1304. return tensor.copy_(all_reduce(tensor, op, group, tag))
  1305. def all_to_all_inplace(
  1306. output: torch.Tensor,
  1307. input: torch.Tensor,
  1308. output_split_sizes=None,
  1309. input_split_sizes=None,
  1310. group=None,
  1311. async_op=False,
  1312. tag: str = "",
  1313. ):
  1314. if async_op:
  1315. raise AssertionError(
  1316. "Can't remap async version of inplace op to functional collective"
  1317. )
  1318. group = group or dist.group.WORLD
  1319. if group is None:
  1320. raise AssertionError("group cannot be None")
  1321. return output.copy_(
  1322. all_to_all_single(
  1323. input,
  1324. output_split_sizes,
  1325. input_split_sizes,
  1326. group,
  1327. tag,
  1328. )
  1329. )
  1330. def all_gather_inplace(
  1331. tensor_list: list[torch.Tensor],
  1332. tensor: torch.Tensor,
  1333. group=None,
  1334. async_op=False,
  1335. tag: str = "",
  1336. ):
  1337. if async_op:
  1338. raise AssertionError(
  1339. "Can't remap async version of inplace op to functional collective"
  1340. )
  1341. if tensor.dim() != 0 and not all(t.size(0) == tensor.size(0) for t in tensor_list):
  1342. raise AssertionError("Remapping variable size all_gather is not yet supported")
  1343. group = group or dist.group.WORLD
  1344. if group is None:
  1345. raise AssertionError("group cannot be None")
  1346. output = all_gather_tensor(tensor, 0, group, tag)
  1347. # Use aten.slice instead of aten.split because the latter causes
  1348. # tensor.shape(0) to be unnecessarily baked in when it's a SymInt.
  1349. output_splits = []
  1350. offset = 0
  1351. for t in tensor_list:
  1352. is_scalar = t.dim() == 0
  1353. t_offset = 1 if is_scalar else t.size(0)
  1354. out = output[offset] if is_scalar else output[offset : offset + t_offset]
  1355. output_splits.append(out)
  1356. offset += t_offset
  1357. for dst, src in zip(tensor_list, output_splits):
  1358. dst.copy_(src)
  1359. return tensor_list
  1360. from torch.distributed.distributed_c10d import (
  1361. _all_gather_base as legacy_all_gather_base,
  1362. _reduce_scatter_base as legacy_reduce_scatter_base,
  1363. all_gather as legacy_all_gather,
  1364. all_gather_into_tensor as legacy_allgather,
  1365. all_reduce as legacy_allreduce,
  1366. all_to_all_single as legacy_all_to_all_single,
  1367. reduce_scatter_tensor as legacy_reducescatter,
  1368. )
  1369. # This dict should contain sets of functions that dynamo is allowed to remap.
  1370. # Functions in this set should accept the same args/kwargs 1:1 as their mapping.
  1371. traceable_collective_remaps = {
  1372. legacy_allgather: all_gather_tensor_inplace, # type: ignore[has-type]
  1373. legacy_reducescatter: reduce_scatter_tensor_inplace, # type: ignore[has-type]
  1374. legacy_allreduce: all_reduce_inplace, # type: ignore[has-type]
  1375. legacy_all_to_all_single: all_to_all_inplace, # type: ignore[has-type]
  1376. legacy_all_gather: all_gather_inplace, # type: ignore[has-type]
  1377. legacy_reduce_scatter_base: reduce_scatter_tensor_inplace, # type: ignore[has-type]
  1378. legacy_all_gather_base: all_gather_tensor_inplace, # type: ignore[has-type]
  1379. }