comms.py 103 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722
  1. # mypy: allow-untyped-defs
  2. # pyre-strict
  3. from __future__ import annotations
  4. import heapq
  5. import importlib
  6. import itertools
  7. import logging
  8. import operator
  9. import sys
  10. from collections import defaultdict
  11. from dataclasses import dataclass
  12. from typing import Any, Optional, TYPE_CHECKING, Union
  13. import torch
  14. from torch._logging import trace_structured
  15. from torch.multiprocessing.reductions import StorageWeakRef
  16. from torch.utils._ordered_set import OrderedSet
  17. from . import config, config_comms, ir
  18. from .dependencies import WeakDep
  19. if TYPE_CHECKING:
  20. from .ir import IRNode, Operation
  21. from .memory import (
  22. estimate_peak_memory_allocfree,
  23. FreeableInputBuffer,
  24. get_freeable_input_buf,
  25. SNodeMemory,
  26. )
  27. from .utils import (
  28. contains_collective,
  29. contains_wait,
  30. find_recursive_deps_of_node,
  31. find_recursive_users_of_node,
  32. is_collective,
  33. is_fallback_op,
  34. is_wait,
  35. )
  36. from .virtualized import V
  37. log = logging.getLogger(__name__)
  38. overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
  39. if TYPE_CHECKING:
  40. from torch._inductor.scheduler import BaseSchedulerNode
  41. def align_runtime_estimations_across_all_distributed_ranks(
  42. snodes: list[BaseSchedulerNode],
  43. ):
  44. from torch._inductor.scheduler import _get_mm_like_fn
  45. runtime_estimations = {}
  46. runtime_estimations_for_mms = {}
  47. for snode in snodes:
  48. runtime_estimations[snode] = snode.get_estimated_runtime()
  49. if _get_mm_like_fn(snode) is not None:
  50. runtime_estimations_for_mms[snode] = runtime_estimations[snode]
  51. import torch.distributed as dist
  52. from torch.distributed.distributed_c10d import _get_default_group
  53. world_size = dist.get_world_size()
  54. pg = _get_default_group()
  55. gathered_runtime_estimations_for_mms: list[list[float]] = [
  56. [] for _ in range(world_size)
  57. ]
  58. dist.all_gather_object(
  59. gathered_runtime_estimations_for_mms,
  60. list(runtime_estimations_for_mms.values()),
  61. pg,
  62. )
  63. median_runtime_estimations_for_mms = torch.median(
  64. torch.tensor(gathered_runtime_estimations_for_mms), dim=0
  65. ).values.tolist()
  66. for idx, snode in enumerate(runtime_estimations_for_mms.keys()):
  67. runtime_estimations_for_mms[snode] = median_runtime_estimations_for_mms[idx]
  68. for snode in snodes:
  69. if snode in runtime_estimations_for_mms:
  70. runtime_estimations[snode] = runtime_estimations_for_mms[snode]
  71. snode.override_estimated_runtime = runtime_estimations[snode]
  72. def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
  73. """
  74. Greedily schedules waits as late as possible.
  75. """
  76. return _schedule_for_comm(
  77. snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
  78. )
  79. def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
  80. """
  81. Greedily schedules comms as early as possible.
  82. """
  83. return _schedule_for_comm(
  84. snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
  85. )
  86. def reorder_compute_for_overlap(
  87. snodes: list[BaseSchedulerNode],
  88. ) -> list[BaseSchedulerNode]:
  89. """
  90. This achieves the following overall scheduling procedure:
  91. Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
  92. that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
  93. Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
  94. Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
  95. We prioritize compute nodes that are needed sooner.
  96. Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
  97. Step 4: We schedule comm N + 1.
  98. Repeat this for subsequent comm nodes.
  99. """
  100. return _schedule_for_comm(
  101. snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
  102. )
  103. def reorder_communication_preserving_peak_memory(
  104. snodes: list[BaseSchedulerNode],
  105. ) -> list[BaseSchedulerNode]:
  106. """
  107. Reorders communication ops relative to computation ops to improve communication-compute overlapping and hide comm
  108. latency. Stops moving a particular op if it reaches a point that would have increased the peak memory footprint.
  109. Currently, follows these heuristics (subject to change or tune):
  110. - never reorders collectives relative to one another, for SPMD safety
  111. - has an option for per-collective prefetch limit, but does not enable it by default
  112. - limits the total number of reorder steps to some factor of the graph size to prevent worst-case quadratic
  113. performance
  114. Prerequisite: sink_comms_and_waits - ensure comm and wait nodes are scheduled as late as possible, respecting data
  115. dependencies. That allows reorder_communication_preserving_peak_memory to take a best case peak-memory snapshot,
  116. and then monotonically improve latency by moving collectives backward in time.
  117. Peak memory impact is computed in an iterative fashion. First, memory use at each timestep is computed, and global
  118. peak memory is computed as a max over timesteps. Then, when swapping any two adjacent nodes, only the curr-memory
  119. for the earlier of the nodes after the swap is affected. This enables checking step by step whether a swap is
  120. peak-memory-safe, and bailing out if not. Example:
  121. 0 n0 C0
  122. 1 n1 C0 + Allocs(n1) - Frees(n1)
  123. 2 n2 C0 + Allocs(n1) - Frees(n1) + Allocs(n2) - Frees(n2)
  124. 0 n0 C0
  125. 1 n2 C0 + Allocs(n2) - Frees(n2) <-- After moving n2 to Time 1, only time1 memory changes
  126. 2 n1 C0 + Allocs(n2) - Frees(n2) + Allocs(n1) - Frees(n1)
  127. """
  128. reordered_snodes, node_stats = (
  129. _reorder_communication_preserving_peak_memory_internal(snodes)
  130. )
  131. return reordered_snodes
  132. @dataclass
  133. class ReorderInfo:
  134. """
  135. Debug info describing how an individual snode was reordered
  136. """
  137. limiting_factor: str = "None"
  138. moves: int = 0
  139. grouped: int = 0
  140. grouped_info: str = ""
  141. comm_time: float = -1.0
  142. comp_time: float = -1.0
  143. initial_exposed: float = -1.0
  144. final_exposed: float = -1.0
  145. overlap_info: str = "None"
  146. @property
  147. def improvement(self):
  148. return self.initial_exposed - self.final_exposed
  149. def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
  150. if node is None:
  151. return False
  152. if is_fallback_op(
  153. node, # type: ignore[arg-type]
  154. torch.ops.aten._scaled_dot_product_flash_attention.default,
  155. ):
  156. return True
  157. if (
  158. python_kernel_name := getattr(node, "python_kernel_name", None)
  159. ) and "extern_kernels" in python_kernel_name:
  160. return True
  161. return False
  162. def contains_gemm_like(snode: BaseSchedulerNode) -> bool:
  163. from torch._inductor.scheduler import GroupedSchedulerNode
  164. if isinstance(snode, GroupedSchedulerNode):
  165. return any(contains_gemm_like(x) for x in snode.snodes)
  166. else:
  167. return is_gemm_like(snode.node)
  168. def _temp_group_visit_leaves(snode: BaseSchedulerNode, fn):
  169. from torch._inductor.scheduler import GroupedSchedulerNode
  170. if isinstance(snode, GroupedSchedulerNode) and snode.temp_grouping:
  171. for _snode in snode.snodes:
  172. fn(_snode)
  173. else:
  174. fn(snode)
  175. def _group_name(snode, with_bufs=False) -> str:
  176. ret = ""
  177. for n in snode.snodes:
  178. if ret:
  179. ret += "_"
  180. ret += n.get_name()
  181. if with_bufs:
  182. ret += f"{list(snode.get_buffer_names())}"
  183. return ret
  184. def _is_fake_dep(d):
  185. return isinstance(d, WeakDep) and d.is_fake
  186. def _group_names(gns: list[BaseSchedulerNode]) -> str:
  187. return "~".join([gn.get_name() for gn in gns])
  188. def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
  189. """Initialize memory tracking data structures"""
  190. name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
  191. peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
  192. estimate_peak_memory_allocfree(
  193. snodes, name_to_freeable_input_buf, graph_outputs
  194. )
  195. )
  196. _curr_memory = dict(zip(snodes, snodes_curr_memory))
  197. # pyrefly: ignore [unsupported-operation]
  198. _curr_memory[None] = (0, 0)
  199. # Build candidate buffer map for optimization
  200. candidate_buffer_map = _build_candidate_buffer_map(buf_to_snode_last_use)
  201. return (
  202. peak_memory,
  203. _curr_memory,
  204. snodes_allocfree,
  205. buf_to_snode_last_use,
  206. name_to_freeable_input_buf,
  207. candidate_buffer_map,
  208. )
  209. def _initialize_double_linked_list(
  210. snodes: list[BaseSchedulerNode],
  211. ) -> tuple[
  212. dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  213. dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  214. BaseSchedulerNode,
  215. ]:
  216. """Create double-linked list structure from snodes"""
  217. _prev = {}
  218. _next = {}
  219. for i, snode in enumerate(snodes):
  220. _prev[snode] = snodes[i - 1] if i > 0 else None
  221. _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
  222. _head = snodes[0]
  223. return _prev, _next, _head
  224. def _build_candidate_buffer_map(
  225. buf_to_snode_last_use: dict,
  226. ) -> dict[BaseSchedulerNode, OrderedSet]:
  227. """
  228. Build inverted index: node -> set of buffers where node appears in successors.
  229. This optimization reduces buffer iteration from O(total_buffers) to O(buffers_per_node).
  230. Since buffer successors are immutable during reordering, this map doesn't need updates.
  231. Returns:
  232. dict mapping each node to the set of buffers that have this node in their successors
  233. """
  234. node_to_candidate_bufs: dict[BaseSchedulerNode, OrderedSet] = defaultdict(
  235. OrderedSet
  236. )
  237. for buf in buf_to_snode_last_use:
  238. # Add to every successor node's buffer set
  239. for succ_node in buf.mpi_buffer.succ_nodes:
  240. node_to_candidate_bufs[succ_node].add(buf)
  241. return dict(node_to_candidate_bufs)
  242. def _precompute_node_output_sets(
  243. snodes: list[BaseSchedulerNode],
  244. ) -> dict[BaseSchedulerNode, OrderedSet[str]]:
  245. """
  246. Pre-compute output name sets for all nodes.
  247. This optimization avoids creating OrderedSet objects repeatedly during
  248. exposed time calculations.
  249. Returns:
  250. dict mapping each node to a set of its output names
  251. """
  252. return {
  253. snode: OrderedSet(o.get_name() for o in snode.get_outputs()) for snode in snodes
  254. }
  255. def _op_runtime_estimate_mult(snode):
  256. # Apply multipliers for faster experimentation.
  257. # TODO(ivankobzarev): Remove after confirmation that runtime estimations are correct.
  258. if contains_collective(snode):
  259. return config_comms.reorder_sink_runtime_estimations_comm_mult
  260. return config_comms.reorder_sink_runtime_estimations_non_comm_mult
  261. def is_async_collective(snode):
  262. """
  263. Filtering out ops that contain Collective and Wait inside and considered as Collectives.
  264. See contains_collective function.
  265. If the op contains Wait inside - consider as Synchronous compute.
  266. """
  267. if python_kernel_name := getattr(snode.node, "python_kernel_name", None):
  268. if "torch.ops._dtensor.shard_dim_alltoall.default" in python_kernel_name:
  269. return False
  270. return True
  271. def contains_async_collective(snode):
  272. return contains_collective(snode, is_async_collective)
  273. def _group_nodes_from_linked_list(
  274. head: Optional[BaseSchedulerNode],
  275. tail: Optional[BaseSchedulerNode],
  276. next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  277. ) -> list[BaseSchedulerNode]:
  278. """
  279. Traverse doubly-linked list from head to tail and return nodes as a list.
  280. Args:
  281. head: Starting node of the segment
  282. tail: Ending node of the segment (inclusive)
  283. next_dict: Dictionary mapping each node to its next node
  284. Returns:
  285. List of nodes from head to tail (inclusive)
  286. """
  287. ret = []
  288. n = head
  289. while True:
  290. if n is not None:
  291. ret.append(n)
  292. if n == tail:
  293. break
  294. n = next_dict[n] # type: ignore[index]
  295. return ret
  296. def _is_corresponding_collective_wait(
  297. collective_snode: BaseSchedulerNode,
  298. wait_snode: BaseSchedulerNode,
  299. node_output_sets: dict[BaseSchedulerNode, frozenset[str]],
  300. node_dep_sets: dict[BaseSchedulerNode, frozenset[str]],
  301. ) -> bool:
  302. """
  303. Check if a wait node corresponds to a given collective node.
  304. Uses pre-computed sets for O(1) lookup.
  305. """
  306. collective_outs = node_output_sets[collective_snode]
  307. unmet_deps = node_dep_sets[wait_snode]
  308. return bool(unmet_deps & collective_outs)
  309. def _coll_exposed_communication_time(
  310. collective_snode: BaseSchedulerNode,
  311. next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  312. runtimes: dict[BaseSchedulerNode, float],
  313. node_output_sets: dict[BaseSchedulerNode, frozenset[str]],
  314. node_dep_sets: dict[BaseSchedulerNode, frozenset[str]],
  315. ) -> tuple[float, float, str]:
  316. """
  317. Calculate exposed communication time by iterating directly over linked list.
  318. Avoids O(N) list construction for each call.
  319. The collective_snode is the starting point, iteration continues via next_dict.
  320. """
  321. comm_time = runtimes[collective_snode]
  322. comp_time = 0.0
  323. collective_outs = node_output_sets[collective_snode]
  324. overlap_info = ""
  325. collectives_found: list[BaseSchedulerNode] = []
  326. snode = next_dict[collective_snode]
  327. while snode is not None:
  328. unmet_deps = node_dep_sets[snode]
  329. if unmet_deps & collective_outs:
  330. overlap_info += f"->W[{snode.get_name()}]"
  331. break
  332. if contains_collective(snode):
  333. if not contains_async_collective(snode):
  334. break
  335. else:
  336. collectives_found.append(snode)
  337. snode = next_dict[snode]
  338. continue
  339. if contains_wait(snode):
  340. has_wait_for_collectives_found = False
  341. for _coll in collectives_found:
  342. if _is_corresponding_collective_wait(
  343. collective_snode, snode, node_output_sets, node_dep_sets
  344. ):
  345. has_wait_for_collectives_found = True
  346. break
  347. if has_wait_for_collectives_found:
  348. break
  349. comp_time_before = comp_time
  350. def accumulate_time(_snode: BaseSchedulerNode) -> None:
  351. nonlocal comp_time
  352. comp_time += runtimes[_snode]
  353. _temp_group_visit_leaves(snode, accumulate_time)
  354. comp_time_after = comp_time
  355. overlap_info += f"+{snode.get_name()}[{comp_time_after - comp_time_before}]"
  356. snode = next_dict[snode]
  357. return comm_time, comp_time, overlap_info
  358. def _wait_exposed_communication_time(
  359. wait_snode: BaseSchedulerNode,
  360. head: BaseSchedulerNode,
  361. prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  362. runtimes: dict[BaseSchedulerNode, float],
  363. node_output_sets: dict[BaseSchedulerNode, frozenset[str]],
  364. node_dep_sets: dict[BaseSchedulerNode, frozenset[str]],
  365. ) -> tuple[float, float, str]:
  366. """
  367. Calculate exposed communication time for a wait operation by iterating
  368. directly over linked list backwards. Avoids O(N) list construction.
  369. Iterates from wait_snode backwards using prev_dict to find corresponding collective.
  370. """
  371. comm_time = 0.0
  372. comp_time = 0.0
  373. overlap_info = ""
  374. waits_found: list[BaseSchedulerNode] = []
  375. snode = prev_dict[wait_snode]
  376. while snode is not None:
  377. if contains_wait(snode):
  378. waits_found.append(snode)
  379. if contains_collective(snode):
  380. if _is_corresponding_collective_wait(
  381. snode, wait_snode, node_output_sets, node_dep_sets
  382. ):
  383. comm_time = runtimes[snode]
  384. overlap_info += f"->C[{snode.get_name()}]"
  385. break
  386. if not contains_async_collective(snode):
  387. comp_time = 0.0
  388. snode = prev_dict[snode]
  389. continue
  390. else:
  391. for w in waits_found:
  392. if _is_corresponding_collective_wait(
  393. snode, w, node_output_sets, node_dep_sets
  394. ):
  395. comp_time = 0.0
  396. break # inner loop break
  397. snode = prev_dict[snode]
  398. continue
  399. comp_time_before = comp_time
  400. def accumulate_time(_snode: BaseSchedulerNode) -> None:
  401. nonlocal comp_time
  402. comp_time += runtimes[_snode]
  403. _temp_group_visit_leaves(snode, accumulate_time)
  404. comp_time_after = comp_time
  405. overlap_info += f"+{snode.get_name()}[{comp_time_after - comp_time_before}]"
  406. snode = prev_dict[snode]
  407. return comm_time, comp_time, overlap_info
  408. def _perform_double_linked_list_swap(
  409. candidate: BaseSchedulerNode,
  410. group_head: BaseSchedulerNode,
  411. group_tail: BaseSchedulerNode,
  412. prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  413. next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  414. head: BaseSchedulerNode,
  415. ) -> BaseSchedulerNode:
  416. """
  417. Swap positions of candidate and group in doubly-linked list.
  418. Transforms:
  419. candidate_prev -> candidate -> group_head...group_tail -> group_tail_next
  420. Into:
  421. candidate_prev -> group_head...group_tail -> candidate -> group_tail_next
  422. Args:
  423. candidate: Node to swap with group
  424. group_head: First node of group
  425. group_tail: Last node of group
  426. prev_dict: Dictionary mapping nodes to their previous nodes
  427. next_dict: Dictionary mapping nodes to their next nodes
  428. head: Current head of the linked list
  429. Returns:
  430. New head of the linked list (may change if candidate was the head)
  431. """
  432. # 0: Update candidate's previous node
  433. candidate_prev = prev_dict[candidate]
  434. if candidate_prev:
  435. next_dict[candidate_prev] = group_head
  436. prev_dict[group_head] = candidate_prev
  437. # 2: Update group_tail's next node
  438. group_tail_next = next_dict[group_tail]
  439. if group_tail_next:
  440. prev_dict[group_tail_next] = candidate
  441. next_dict[candidate] = group_tail_next
  442. # 1: Link group_tail to candidate
  443. prev_dict[candidate] = group_tail
  444. next_dict[group_tail] = candidate
  445. # Update head if candidate was the head
  446. if head == candidate:
  447. return group_head
  448. return head
  449. def _calculate_potential_peak_memory_reorder(
  450. candidate: BaseSchedulerNode,
  451. gns: list[BaseSchedulerNode],
  452. group_tail: BaseSchedulerNode,
  453. group_peak_memory: int,
  454. candidate_delta_mem: int,
  455. candidate_allocfree: SNodeMemory,
  456. group_n_to_bufs_after_swap_dealloc_by_candidate: dict,
  457. curr_memory: dict,
  458. ) -> tuple[int, dict[BaseSchedulerNode, int]]:
  459. """
  460. Calculate potential peak memory after swapping candidate with group (reorder version).
  461. Computes new memory levels for all affected nodes and returns the potential
  462. peak memory along with cached post-allocation memory values for each node.
  463. Args:
  464. candidate: Node being moved
  465. gns: Group nodes
  466. group_tail: Last node of group
  467. group_peak_memory: Current peak memory within the group
  468. candidate_delta_mem: Net memory change from candidate (alloc - free)
  469. candidate_allocfree: Candidate's allocation/free info
  470. group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate
  471. curr_memory: Current memory state dict
  472. Returns:
  473. Tuple of (potential_peak_memory, post_alloc_update_dict)
  474. """
  475. # Caching calculations of memory for group nodes and candidate,
  476. # to apply without recalculation after swap.
  477. _post_alloc_update: dict[BaseSchedulerNode, int] = {}
  478. potential_peak: int = 0
  479. if not group_n_to_bufs_after_swap_dealloc_by_candidate:
  480. # Not accounting for buffers last use change
  481. potential_peak = max(
  482. group_peak_memory - candidate_delta_mem,
  483. curr_memory[group_tail][1]
  484. - candidate_delta_mem
  485. + candidate_allocfree.size_alloc,
  486. )
  487. return potential_peak, _post_alloc_update
  488. # If candidate will be after group, the starting memory level of group nodes
  489. # changes to the -(candidate.size_alloc - candidate.size_free)
  490. mem_after_reorder_delta: int = -candidate_delta_mem
  491. for gn in gns:
  492. gn_post_alloc_mem = curr_memory[gn][0] + mem_after_reorder_delta
  493. _post_alloc_update[gn] = gn_post_alloc_mem
  494. potential_peak = max(potential_peak, gn_post_alloc_mem)
  495. bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn)
  496. if bufs is not None:
  497. for buf in bufs:
  498. # Candidate will deallocate those buffers
  499. mem_after_reorder_delta += buf.mpi_buffer.size_free
  500. candidate_mem_post_alloc = (
  501. curr_memory[group_tail][1]
  502. + mem_after_reorder_delta
  503. + candidate_allocfree.size_alloc
  504. )
  505. _post_alloc_update[candidate] = candidate_mem_post_alloc
  506. potential_peak = max(potential_peak, candidate_mem_post_alloc)
  507. return potential_peak, _post_alloc_update
  508. def _update_memory_tracking_after_swap_reorder(
  509. candidate: BaseSchedulerNode,
  510. gns: list[BaseSchedulerNode],
  511. group_tail: BaseSchedulerNode,
  512. candidate_delta_mem: int,
  513. candidate_allocfree: SNodeMemory,
  514. group_n_to_bufs_after_swap_dealloc_by_candidate: dict,
  515. post_alloc_update: dict[BaseSchedulerNode, int],
  516. curr_memory: dict,
  517. buf_to_snode_last_use: dict,
  518. snodes_allocfree: dict,
  519. ) -> None:
  520. """
  521. Update memory tracking structures after swap (reorder version).
  522. Updates curr_memory, buf_to_snode_last_use, and snodes_allocfree dictionaries
  523. to reflect the new memory state after swapping candidate with group.
  524. Args:
  525. candidate: Node that was moved
  526. gns: Group nodes
  527. group_tail: Last node of group
  528. candidate_delta_mem: Net memory change from candidate (alloc - free)
  529. candidate_allocfree: Candidate's allocation/free info
  530. group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate
  531. post_alloc_update: Cached post-allocation memory values
  532. curr_memory: Current memory state dict (mutated)
  533. buf_to_snode_last_use: Buffer to last-use node mapping (mutated)
  534. snodes_allocfree: Node allocation/free info dict (mutated)
  535. """
  536. if not group_n_to_bufs_after_swap_dealloc_by_candidate:
  537. for gn in gns:
  538. cm = curr_memory[gn]
  539. curr_memory[gn] = (
  540. cm[0] - candidate_delta_mem,
  541. cm[1] - candidate_delta_mem,
  542. )
  543. _candidate_post_alloc_mem = (
  544. curr_memory[group_tail][1] + candidate_allocfree.size_alloc
  545. )
  546. _candidate_post_free_mem = (
  547. _candidate_post_alloc_mem - candidate_allocfree.size_free
  548. )
  549. curr_memory[candidate] = (
  550. _candidate_post_alloc_mem,
  551. _candidate_post_free_mem,
  552. )
  553. return
  554. # Candidate becomes last use of some bufs
  555. for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values():
  556. for buf in bufs:
  557. buf_to_snode_last_use[buf] = candidate
  558. size_free_to_move_to_candidate_sum: int = 0
  559. for n in gns:
  560. _gn_post_alloc_mem: int = post_alloc_update[n]
  561. size_free_to_move_to_candidate: int = sum(
  562. buf.mpi_buffer.size_free
  563. for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
  564. )
  565. size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
  566. # group node does not deallocate this after swap
  567. snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
  568. gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
  569. curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
  570. _candidate_post_alloc_mem = post_alloc_update[candidate]
  571. snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
  572. candidate_post_free_mem = (
  573. _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
  574. )
  575. curr_memory[candidate] = (
  576. _candidate_post_alloc_mem,
  577. candidate_post_free_mem,
  578. )
  579. def _find_buffers_with_changed_last_use(
  580. candidate: BaseSchedulerNode,
  581. gns: list[BaseSchedulerNode],
  582. buf_to_snode_last_use: dict,
  583. candidate_buffer_map: dict[BaseSchedulerNode, OrderedSet],
  584. ) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]:
  585. """
  586. Find buffers whose last use will change after swapping candidate with group.
  587. When we swap [candidate [group]] to [[group] candidate], some buffers that
  588. were last used by a group node will now be last used by candidate instead.
  589. This affects memory deallocation timing.
  590. Args:
  591. candidate: The node being moved
  592. gns: Group nodes being swapped with candidate
  593. buf_to_snode_last_use: Mapping of buffers to their current last-use nodes
  594. candidate_buffer_map: Pre-computed map of node -> buffers using that node
  595. Returns:
  596. Dict mapping group nodes to buffers that will change their last-use node
  597. """
  598. group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
  599. BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
  600. ] = defaultdict(list)
  601. # Optimization: only check buffers where candidate is a successor
  602. # Reduces from O(all_buffers) to O(buffers_per_candidate)
  603. candidate_bufs = candidate_buffer_map.get(candidate, OrderedSet())
  604. gns_set = OrderedSet(gns) # O(1) membership testing
  605. for buf in candidate_bufs:
  606. snode_last_use = buf_to_snode_last_use[buf]
  607. if snode_last_use in gns_set:
  608. group_n_to_bufs_after_swap_dealloc_by_candidate[snode_last_use].append(buf)
  609. return group_n_to_bufs_after_swap_dealloc_by_candidate
  610. def _is_node_groupable_for_reorder(
  611. candidate: BaseSchedulerNode,
  612. ) -> tuple[bool, Optional[str]]:
  613. """
  614. Check if a candidate node can be grouped with collective during reordering.
  615. This pass processes collectives left to right, so we avoid grouping with
  616. already-processed collectives based on configuration.
  617. Args:
  618. candidate: Node to check for groupability
  619. Returns:
  620. Tuple of (is_groupable, reason_if_not_groupable)
  621. """
  622. # This pass processes collectives left to right,
  623. # Do not group with processed collectives.
  624. # Leaving config for experimentation in 2D
  625. if not config_comms.reorder_iterative_group_with_collectives:
  626. if contains_async_collective(candidate):
  627. return (
  628. False,
  629. f"candidate contains_collective {candidate.get_name()}",
  630. )
  631. if not config_comms.reorder_iterative_use_runtime_estimations:
  632. if contains_gemm_like(candidate):
  633. return False, "contains_gemm_like"
  634. return True, None
  635. def _format_and_log_reordering_stats(
  636. stats: dict[BaseSchedulerNode, ReorderInfo],
  637. head: BaseSchedulerNode,
  638. next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  639. original_snodes_num: int,
  640. peak_memory: int,
  641. name_to_freeable_input_buf: dict,
  642. graph_outputs: OrderedSet[str],
  643. ) -> list[BaseSchedulerNode]:
  644. """
  645. Format reordering statistics, log them, and return final node list.
  646. Computes improvement metrics, creates a formatted table (using tabulate if
  647. available), validates the reordered node count, recalculates peak memory,
  648. and logs all information.
  649. Args:
  650. stats: Per-node reordering statistics
  651. head: Head of the reordered linked list
  652. next_dict: Linked list next pointers
  653. original_snodes_num: Original number of nodes (for validation)
  654. peak_memory: Initial peak memory before reordering
  655. name_to_freeable_input_buf: Buffer memory tracking info
  656. graph_outputs: Graph output names
  657. Returns:
  658. Final reordered list of scheduler nodes
  659. """
  660. node_stats = stats
  661. improvement = {snode: node_stats[snode].improvement for snode in node_stats}
  662. total_improvement = sum([improvement[snode] for snode in improvement])
  663. total_moves = sum([node_stats[snode].moves for snode in node_stats])
  664. reorder_log_str = (
  665. f"reorder_communication_preserving_peak_memory improved overlap by {total_improvement} ns"
  666. f" after {total_moves} reorders.\n"
  667. )
  668. headers = [
  669. "Collective node",
  670. "comm_time(us)",
  671. "comp_time(us)",
  672. "initial exposed(us)",
  673. "final exposed(us)",
  674. "improvement(us)",
  675. "limiting factor",
  676. "moves",
  677. "grouped",
  678. "grouped_info",
  679. "overlap_info",
  680. ]
  681. rows = [
  682. [
  683. node_summary(snode),
  684. node_info.comm_time / 1e3,
  685. node_info.comp_time / 1e3,
  686. node_info.initial_exposed / 1e3,
  687. node_info.final_exposed / 1e3,
  688. node_info.improvement / 1e3,
  689. node_info.limiting_factor,
  690. node_info.moves,
  691. node_info.grouped,
  692. node_info.grouped_info,
  693. node_info.overlap_info,
  694. ]
  695. for snode, node_info in node_stats.items()
  696. ]
  697. if importlib.util.find_spec("tabulate"):
  698. from tabulate import tabulate
  699. reorder_log_str += tabulate(
  700. rows,
  701. headers=headers,
  702. )
  703. else:
  704. reorder_log_str += (
  705. "Please `pip install tabulate` to nicely render overlap stats.\n"
  706. )
  707. reorder_log_str += str(headers) + "\n"
  708. reorder_log_str += "\n".join(map(str, rows))
  709. new_snodes = _group_nodes_from_linked_list(head, None, next_dict)
  710. assert len(new_snodes) == original_snodes_num
  711. new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
  712. new_snodes, name_to_freeable_input_buf, graph_outputs
  713. )
  714. reorder_log_str += f"\n peak_memory_before:{peak_memory}"
  715. reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
  716. overlap_log.info(reorder_log_str)
  717. trace_structured(
  718. "artifact",
  719. metadata_fn=lambda: {
  720. "name": "reorder_communication_preserving_peak_memory",
  721. "encoding": "string",
  722. },
  723. payload_fn=lambda: reorder_log_str,
  724. )
  725. return new_snodes
  726. def _reorder_communication_preserving_peak_memory_internal(
  727. snodes: list[BaseSchedulerNode],
  728. ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
  729. """
  730. Internal testing helper that also returns debug info.
  731. Returns:
  732. - reordered snodes list
  733. - dict {snode: ReorderInfo}
  734. """
  735. has_collectives = False
  736. for snode in snodes:
  737. if contains_collective(snode):
  738. has_collectives = True
  739. break
  740. if not has_collectives:
  741. return snodes, {}
  742. original_snodes_num = len(snodes)
  743. # heuristic to avoid degenerating to quadratic time
  744. graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
  745. graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
  746. (
  747. peak_memory,
  748. _curr_memory,
  749. snodes_allocfree,
  750. buf_to_snode_last_use,
  751. name_to_freeable_input_buf,
  752. candidate_buffer_map,
  753. ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
  754. runtimes: dict[BaseSchedulerNode, float] = {
  755. snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode)
  756. for snode in snodes
  757. }
  758. # Pre-compute output and dependency sets for O(1) lookup instead of O(N) creation per iteration
  759. node_output_sets: dict[BaseSchedulerNode, frozenset[str]] = {
  760. snode: frozenset(o.get_name() for o in snode.get_outputs()) for snode in snodes
  761. }
  762. node_dep_sets: dict[BaseSchedulerNode, frozenset[str]] = {
  763. snode: frozenset(
  764. d.name for d in snode.unmet_dependencies if not _is_fake_dep(d)
  765. )
  766. for snode in snodes
  767. }
  768. # debug stats
  769. stats: dict[BaseSchedulerNode, ReorderInfo] = {}
  770. total_moves = 0
  771. _prev, _next, _head = _initialize_double_linked_list(snodes)
  772. debug_num_collectives_to_reorder: Optional[int] = (
  773. config_comms.reorder_iterative_debug_limit_to_reorder
  774. )
  775. num_processed_collectives: int = 0
  776. curr: Optional[BaseSchedulerNode] = _head
  777. debug_iterative_memory_recompute = (
  778. config_comms.reorder_iterative_debug_memory_recompute
  779. )
  780. iterative_recompute_error = False
  781. while curr is not None and _next[curr] is not None:
  782. _next_curr = _next[curr]
  783. if iterative_recompute_error:
  784. break
  785. if not contains_async_collective(curr):
  786. curr = _next_curr
  787. continue
  788. if debug_num_collectives_to_reorder is not None and (
  789. num_processed_collectives >= debug_num_collectives_to_reorder
  790. ):
  791. break
  792. num_processed_collectives += 1
  793. info = stats[curr] = ReorderInfo()
  794. comm_time, comp_time, overlap_info = _coll_exposed_communication_time(
  795. curr, _next, runtimes, node_output_sets, node_dep_sets
  796. )
  797. info.comm_time = comm_time
  798. info.comp_time = comp_time
  799. info.initial_exposed = info.final_exposed = comm_time - comp_time
  800. info.overlap_info = overlap_info
  801. candidate = _prev[curr]
  802. group_head = curr
  803. group_tail = curr
  804. group_waits = {}
  805. group_runtime = 0.0
  806. group_peak_memory = _curr_memory[curr][0] # post_alloc memory
  807. # Track group dependencies incrementally - initialize from pre-computed sets
  808. group_unmet_deps_names = OrderedSet(node_dep_sets[curr])
  809. group_output_names = OrderedSet(node_output_sets[curr])
  810. while candidate is not None:
  811. if config_comms.reorder_iterative_use_runtime_estimations and (
  812. info.final_exposed
  813. < -config_comms.reorder_iterative_extra_comm_comp_overlap
  814. * info.comm_time
  815. ):
  816. info.limiting_factor = "unexposed by runtime estimations"
  817. break
  818. if (
  819. not config_comms.reorder_iterative_unsafe_collectives_reorder
  820. and contains_collective(candidate)
  821. ):
  822. info.limiting_factor = "collective ordering"
  823. break
  824. # Early exit: if group has no unmet dependencies, candidate can't have data dependency
  825. data_deps_names = group_unmet_deps_names - group_output_names
  826. if not data_deps_names:
  827. data_dep = False
  828. else:
  829. # Calculate effective dependencies (not satisfied within group)
  830. # Use pre-computed set for O(1) lookup
  831. candidate_out_names = node_output_sets[candidate]
  832. data_dep = bool(candidate_out_names & data_deps_names)
  833. if data_dep:
  834. is_groupable_result, grouping_reason = _is_node_groupable_for_reorder(
  835. candidate
  836. )
  837. if is_groupable_result:
  838. group_head = candidate
  839. # Update incremental dependency tracking using pre-computed sets
  840. group_unmet_deps_names.update(node_dep_sets[candidate])
  841. group_output_names.update(node_output_sets[candidate])
  842. if config_comms.reorder_iterative_use_runtime_estimations:
  843. if contains_wait(candidate):
  844. comm_time, comp_time, _ = _wait_exposed_communication_time(
  845. candidate,
  846. _head,
  847. _prev,
  848. runtimes,
  849. node_output_sets,
  850. node_dep_sets,
  851. )
  852. group_waits[candidate] = comm_time, comp_time
  853. if not contains_async_collective(candidate):
  854. group_runtime += runtimes[candidate]
  855. group_peak_memory = max(
  856. group_peak_memory, _curr_memory[candidate][0]
  857. )
  858. info.grouped += 1
  859. candidate = _prev[candidate]
  860. continue
  861. else:
  862. msg = (
  863. f"data dependency detected"
  864. f"\n candidate:{candidate.get_name()}(outs:{[o.get_name() for o in candidate.get_outputs()]})"
  865. f"\n non_group_reason:{grouping_reason}"
  866. )
  867. info.limiting_factor = msg
  868. break
  869. if config_comms.reorder_iterative_use_runtime_estimations:
  870. # Check if candidate has sync runtime
  871. if not contains_async_collective(candidate):
  872. c_runtime = runtimes[candidate]
  873. if c_runtime > 0 and len(group_waits) > 0:
  874. # pyrefly: ignore[no-matching-overload]
  875. exposed_before = max(0, info.comm_time - info.comp_time)
  876. # pyrefly: ignore[no-matching-overload]
  877. exposed_after = max(
  878. 0, info.comm_time - info.comp_time - c_runtime
  879. )
  880. exposed_delta = exposed_after - exposed_before
  881. for gw_comm_time, gw_comp_time in group_waits.values():
  882. # pyrefly: ignore [no-matching-overload]
  883. gw_exposed_before = max(0, gw_comm_time - gw_comp_time)
  884. # pyrefly: ignore [no-matching-overload]
  885. gw_exposed_after = max(
  886. 0, gw_comm_time - gw_comp_time + c_runtime
  887. )
  888. exposed_delta += gw_exposed_after - gw_exposed_before
  889. if exposed_delta > 0:
  890. info.limiting_factor = (
  891. f"candidate has compute {c_runtime},"
  892. f" group contains waits, total_exposed_delta {exposed_delta}"
  893. )
  894. break
  895. else:
  896. # Update all group_colls comm_time, comp_time
  897. for gw, (
  898. gw_comm_time,
  899. gw_comp_time,
  900. ) in group_waits.items():
  901. group_waits[gw] = (
  902. gw_comm_time,
  903. gw_comp_time - c_runtime,
  904. )
  905. else:
  906. # Candidate is async_collective
  907. # Unsafe collectives reordering
  908. # Cj -> [...group_runtime..., Ci] -> Wj
  909. # Checking that we are not increasing exposed time of Cj
  910. if group_runtime > 0:
  911. comm_time, comp_time, _ = _coll_exposed_communication_time(
  912. candidate, _next, runtimes, node_output_sets, node_dep_sets
  913. )
  914. # pyrefly: ignore[no-matching-overload]
  915. exposed_before = max(0, comm_time - comp_time)
  916. # pyrefly: ignore[no-matching-overload]
  917. exposed_after = max(0, comm_time - comp_time + group_runtime)
  918. exposed_delta = exposed_after - exposed_before
  919. if exposed_delta > 0:
  920. info.limiting_factor = (
  921. f"candidate {candidate.get_name()} is collective,"
  922. f" group_runtime:{group_runtime},"
  923. f" exposed_delta:{exposed_delta} c_comm_time:{comm_time} c_comp_time:{comp_time}"
  924. )
  925. break
  926. # Create group nodes list once for swap operations
  927. gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list(
  928. group_head, group_tail, _next
  929. )
  930. candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
  931. candidate_delta_mem: int = (
  932. candidate_allocfree.size_alloc - candidate_allocfree.size_free
  933. )
  934. # candidate and one of group nodes are successors of the same buffer
  935. # and last use of the buffer happen in group nodes.
  936. # This last use deallocates it.
  937. # If we swap [candidate [group]] to [[group] candidate],
  938. # candidate becomes the last use
  939. # and deallocated this buffer instead of group node.
  940. # we need to update size_free accordingly to group_node and candidate,
  941. # and recalculate post_alloc, post_free for them.
  942. #
  943. # Buf that changes its last use snode,
  944. # after swap will be deallocated only by candidate,
  945. # while before it was deallocated by group node.
  946. group_n_to_bufs_after_swap_dealloc_by_candidate = (
  947. _find_buffers_with_changed_last_use(
  948. candidate, gns, buf_to_snode_last_use, candidate_buffer_map
  949. )
  950. )
  951. potential_peak, _post_alloc_update = (
  952. _calculate_potential_peak_memory_reorder(
  953. candidate,
  954. gns,
  955. group_tail,
  956. group_peak_memory,
  957. candidate_delta_mem,
  958. candidate_allocfree,
  959. group_n_to_bufs_after_swap_dealloc_by_candidate,
  960. _curr_memory,
  961. )
  962. )
  963. if (
  964. potential_peak - peak_memory
  965. > peak_memory * config_comms.reorder_iterative_peak_memory_budget
  966. ):
  967. info.limiting_factor = (
  968. f"peak memory new:{potential_peak} vs base:{peak_memory}"
  969. )
  970. break
  971. info.moves += 1
  972. total_moves += 1
  973. _head = _perform_double_linked_list_swap(
  974. candidate, group_head, group_tail, _prev, _next, _head
  975. )
  976. comm_time, comp_time, overlap_info = _coll_exposed_communication_time(
  977. curr, _next, runtimes, node_output_sets, node_dep_sets
  978. )
  979. info.comm_time = comm_time
  980. info.comp_time = comp_time
  981. info.overlap_info = overlap_info
  982. info.final_exposed = comm_time - comp_time
  983. _update_memory_tracking_after_swap_reorder(
  984. candidate,
  985. gns,
  986. group_tail,
  987. candidate_delta_mem,
  988. candidate_allocfree,
  989. group_n_to_bufs_after_swap_dealloc_by_candidate,
  990. _post_alloc_update,
  991. _curr_memory,
  992. buf_to_snode_last_use,
  993. snodes_allocfree,
  994. )
  995. if debug_iterative_memory_recompute:
  996. # Compare iteratively recomputed memory data
  997. # with full run of estimate_peak_memory
  998. from .comms_debug import _debug_iterative_memory_recompute
  999. iterative_recompute_error = _debug_iterative_memory_recompute(
  1000. candidate,
  1001. gns,
  1002. _group_names(gns),
  1003. _group_nodes_from_linked_list(_head, None, _next),
  1004. name_to_freeable_input_buf,
  1005. graph_outputs,
  1006. peak_memory,
  1007. _curr_memory,
  1008. snodes_allocfree,
  1009. "reorder_communication_preserving_peak_memory",
  1010. group_n_to_bufs_after_swap_dealloc_by_candidate,
  1011. )
  1012. if iterative_recompute_error:
  1013. break
  1014. candidate = _prev[group_head]
  1015. curr = _next_curr
  1016. if not config_comms.reorder_sink_verbose_logging:
  1017. new_snodes = _group_nodes_from_linked_list(_head, None, _next)
  1018. return new_snodes, stats
  1019. new_snodes = _format_and_log_reordering_stats(
  1020. stats,
  1021. _head,
  1022. _next,
  1023. original_snodes_num,
  1024. peak_memory,
  1025. name_to_freeable_input_buf,
  1026. graph_outputs,
  1027. )
  1028. return new_snodes, stats
  1029. def _schedule_for_comm(
  1030. snodes: list[BaseSchedulerNode],
  1031. raise_comms: bool,
  1032. sink_waits: bool,
  1033. reorder_for_overlap: bool,
  1034. ) -> list[BaseSchedulerNode]:
  1035. """
  1036. Schedule `snodes` for various comm optimization objectives.
  1037. Args:
  1038. snodes: the nodes to be scheduled.
  1039. raise_comms: whether to greedily schedule collectives as early as possible
  1040. sink_wait: whether to greedily schedule waits as late as possible
  1041. reorder_compute_for_overlap: whether to reorder compute nodes to
  1042. optimize for compute/communication overlapping.
  1043. Returns:
  1044. The new schedule order.
  1045. Some notes on the synergy between different options:
  1046. - `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`.
  1047. - When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized.
  1048. """
  1049. # We assign each node a tuple of scores (score_0, score_1, score_2),
  1050. # decreasing in importance, with a lower value indicating a higher ranking:
  1051. #
  1052. # - score_0: the lowest comm_idx among the comm nodes that the node blocks.
  1053. # If a node doesn't block any comm nodes, its score_0 is set to
  1054. # sys.maxsize. This score ensures that comm nodes get scheduled as early as
  1055. # possible.
  1056. # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
  1057. # that wait nodes are deferred as late as possible.
  1058. # - score_2: the index of the node in the original topological order. This
  1059. # score provides stability in case of ties.
  1060. #
  1061. # When only raise_comms is True, only score_0 and score_2 are considered.
  1062. # When only sink_waits is True, only score_1 and score_2 are considered.
  1063. # When neither is True, the original order is yielded.
  1064. buf_name_to_snode = {}
  1065. name_to_fused_node = {}
  1066. scores_0, scores_1, scores_2 = {}, {}, {}
  1067. for idx, snode in enumerate(snodes):
  1068. for buf_name in snode.get_buffer_names():
  1069. buf_name_to_snode[buf_name] = snode
  1070. for op_name in snode.get_operation_names():
  1071. name_to_fused_node[op_name] = snode
  1072. name_to_fused_node[snode.get_name()] = snode
  1073. node_name = snode.get_name()
  1074. scores_0[node_name] = sys.maxsize
  1075. scores_1[node_name] = 0
  1076. scores_2[node_name] = idx
  1077. comm_idx = 0
  1078. for snode in snodes:
  1079. if raise_comms and contains_collective(snode):
  1080. scores_0[snode.get_name()] = comm_idx
  1081. for ancestor in snode.ancestors:
  1082. anc_fused_name = name_to_fused_node[ancestor].get_name()
  1083. scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
  1084. comm_idx += 1
  1085. elif sink_waits and contains_wait(snode):
  1086. scores_1[snode.get_name()] = 1
  1087. class Runnable:
  1088. def __init__(self, snode) -> None:
  1089. self.snode = snode
  1090. name = next(iter(snode.get_operation_names()))
  1091. fused_name = name_to_fused_node[name].get_name()
  1092. self.score = (
  1093. scores_0[fused_name],
  1094. scores_1[fused_name],
  1095. scores_2[fused_name],
  1096. )
  1097. def __lt__(self, other):
  1098. return self.score < other.score
  1099. unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
  1100. snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
  1101. for snode in snodes
  1102. }
  1103. ready: list[Runnable] = []
  1104. buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
  1105. snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
  1106. for snode, deps in unmet_deps.items():
  1107. if len(deps) == 0:
  1108. heapq.heappush(ready, Runnable(snode))
  1109. for dep in deps:
  1110. buffer_users[dep].add(snode)
  1111. scheduled = []
  1112. def schedule(snode):
  1113. """
  1114. Schedules `snode` and put all unblocked nodes onto the ready queue.
  1115. """
  1116. scheduled.append(snode)
  1117. for buf_name in snode.get_buffer_names():
  1118. for snode in buffer_users[buf_name]:
  1119. unmet_deps[snode].remove(buf_name)
  1120. if len(unmet_deps[snode]) == 0:
  1121. heapq.heappush(ready, Runnable(snode))
  1122. def get_overlapping_candidate():
  1123. """
  1124. Return the next node in the ready queue that's neither a collective or
  1125. a wait.
  1126. """
  1127. candidates = [
  1128. x
  1129. for x in ready
  1130. if not contains_collective(x.snode) and not contains_wait(x.snode)
  1131. ]
  1132. if len(candidates) == 0:
  1133. return None
  1134. return min(candidates, key=lambda x: x.score)
  1135. def schedule_collective_for_overlap(snode):
  1136. """
  1137. Schedules collective node `snode`, along with one or more compute nodes
  1138. to overlap with it. The strategy is described in the comment of
  1139. `reorder_compute_for_overlap`.
  1140. """
  1141. assert contains_collective(snode)
  1142. schedule(snode)
  1143. collective_cost = snode_to_cost[snode]
  1144. while (
  1145. collective_cost > 0
  1146. and (candidate := get_overlapping_candidate()) is not None
  1147. ):
  1148. ready.remove(candidate)
  1149. schedule(candidate.snode)
  1150. collective_cost -= snode_to_cost[candidate.snode]
  1151. heapq.heapify(ready)
  1152. while ready:
  1153. snode = heapq.heappop(ready).snode
  1154. if reorder_for_overlap and contains_collective(snode):
  1155. schedule_collective_for_overlap(snode)
  1156. else:
  1157. schedule(snode)
  1158. for deps in unmet_deps.values():
  1159. assert len(deps) == 0, (
  1160. f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
  1161. )
  1162. return scheduled
  1163. def decide_global_ordering_of_comms(
  1164. nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
  1165. ) -> list[BaseSchedulerNode]:
  1166. """
  1167. Decide global ordering of comms, by just enforcing the ordering that's in the input graph
  1168. (might not be the same ordering as the eager mode program).
  1169. TODO: Come up with a better approach
  1170. """
  1171. if not torch.distributed.is_available():
  1172. return nodes
  1173. comm_nodes = [n for n in nodes if contains_collective(n)]
  1174. for i in range(1, len(comm_nodes)):
  1175. # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
  1176. mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
  1177. for buf in comm_nodes[i - 1].get_buffer_names():
  1178. comm_nodes[i].add_fake_dep(
  1179. WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
  1180. )
  1181. return nodes
  1182. @dataclass
  1183. class SinkWaitInfo:
  1184. grouped: int = 0
  1185. grouped_info: str = ""
  1186. moves: int = 0
  1187. moves_info: str = ""
  1188. limiting_factor: str = "None"
  1189. comm_time: float = -1.0
  1190. comp_time: float = -1.0
  1191. initial_exposed: float = -1.0
  1192. final_exposed: float = -1.0
  1193. overlap_info: str = "None"
  1194. @property
  1195. def improvement(self):
  1196. return self.initial_exposed - self.final_exposed
  1197. def _is_node_groupable_for_sink_waits(
  1198. candidate: BaseSchedulerNode,
  1199. ) -> tuple[bool, Optional[str]]:
  1200. """
  1201. Check if a candidate node can be grouped during sink_waits pass.
  1202. Sink Waits traverses waits right to left, so we don't group with
  1203. processed waits on the right or with async collectives.
  1204. Args:
  1205. candidate: Node to check for groupability
  1206. Returns:
  1207. Tuple of (is_groupable, reason_if_not_groupable)
  1208. """
  1209. # Sink Waits traverse Waits right to left,
  1210. # => we do not group with processed Waits on the right.
  1211. if contains_wait(candidate):
  1212. return False, f"candidate contains wait {candidate.get_name()}"
  1213. if contains_async_collective(candidate):
  1214. return (
  1215. False,
  1216. f"candidate contains_async_collective {candidate.get_name()}",
  1217. )
  1218. if not config_comms.sink_iterative_use_runtime_estimations:
  1219. # Heuristics pre-use_runtime_estimations:
  1220. # TODO(ivankobzarev): Remove them after confirming,
  1221. # that using runtime estimations always give better results.
  1222. # We do not want to group with collectives to not reorder them forward.
  1223. if contains_collective(candidate):
  1224. return (
  1225. False,
  1226. f"candidate contains collective {candidate.get_name()}",
  1227. )
  1228. if contains_gemm_like(candidate):
  1229. return (
  1230. False,
  1231. f"candidate contains gemm_like {candidate.get_name()}",
  1232. )
  1233. return True, None
  1234. def _update_memory_tracking_after_swap_sink_waits(
  1235. candidate: BaseSchedulerNode,
  1236. gns: list[BaseSchedulerNode],
  1237. candidate_delta_mem: int,
  1238. candidate_allocfree: SNodeMemory,
  1239. group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict,
  1240. post_alloc_update: dict[BaseSchedulerNode, int],
  1241. size_free_delta_update: dict[BaseSchedulerNode, int],
  1242. curr_memory: dict,
  1243. snodes_allocfree: dict,
  1244. ) -> None:
  1245. """
  1246. Update memory tracking structures after swap (sink_waits version).
  1247. Updates curr_memory and snodes_allocfree dictionaries to reflect the new
  1248. memory state after swapping candidate with group.
  1249. Args:
  1250. candidate: Node that was moved
  1251. gns: Group nodes
  1252. candidate_delta_mem: Net memory change from candidate (alloc - free)
  1253. candidate_allocfree: Candidate's allocation/free info
  1254. group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group
  1255. post_alloc_update: Cached post-allocation memory values
  1256. size_free_delta_update: Cached size-free delta values
  1257. curr_memory: Current memory state dict (mutated)
  1258. snodes_allocfree: Node allocation/free info dict (mutated)
  1259. """
  1260. group_head = gns[0]
  1261. pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
  1262. if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
  1263. candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
  1264. curr_memory[candidate] = (
  1265. candidate_post_alloc,
  1266. candidate_post_alloc - candidate_allocfree.size_free,
  1267. )
  1268. for gn in gns:
  1269. cm = curr_memory[gn]
  1270. curr_memory[gn] = (
  1271. cm[0] + candidate_delta_mem,
  1272. cm[1] + candidate_delta_mem,
  1273. )
  1274. return
  1275. for n in [candidate, *gns]:
  1276. post_alloc = post_alloc_update[n]
  1277. snodes_allocfree[n].size_free += size_free_delta_update.get(n, 0)
  1278. curr_memory[n] = (
  1279. post_alloc,
  1280. post_alloc - snodes_allocfree[n].size_free,
  1281. )
  1282. def _calculate_potential_peak_memory_sink_waits(
  1283. candidate: BaseSchedulerNode,
  1284. gns: list[BaseSchedulerNode],
  1285. group_head: BaseSchedulerNode,
  1286. group_peak_memory: int,
  1287. candidate_delta_mem: int,
  1288. candidate_allocfree: SNodeMemory,
  1289. group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict,
  1290. curr_memory: dict,
  1291. snodes_allocfree: dict,
  1292. ) -> tuple[int, dict[BaseSchedulerNode, int], dict[BaseSchedulerNode, int]]:
  1293. """
  1294. Calculate potential peak memory after swapping candidate with group (sink_waits version).
  1295. Computes new memory levels for all affected nodes and returns the potential
  1296. peak memory along with cached post-allocation and size-free delta values.
  1297. Args:
  1298. candidate: Node being moved
  1299. gns: Group nodes
  1300. group_head: First node of group
  1301. group_peak_memory: Current peak memory within the group
  1302. candidate_delta_mem: Net memory change from candidate (alloc - free)
  1303. candidate_allocfree: Candidate's allocation/free info
  1304. group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group
  1305. curr_memory: Current memory state dict
  1306. snodes_allocfree: Allocation/free info for all nodes
  1307. Returns:
  1308. Tuple of (potential_peak_memory, post_alloc_update_dict, size_free_delta_update_dict)
  1309. """
  1310. pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
  1311. # Stash memory tracing updates to not recompute them after swap
  1312. _post_alloc_update: dict[BaseSchedulerNode, int] = {}
  1313. _size_free_delta_update: dict[BaseSchedulerNode, int] = {}
  1314. potential_peak = 0
  1315. if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
  1316. # Not accounting for buffers liveliness change
  1317. potential_peak = max(
  1318. group_peak_memory + candidate_delta_mem,
  1319. pre_group_mem + candidate_allocfree.size_alloc,
  1320. )
  1321. return potential_peak, _post_alloc_update, _size_free_delta_update
  1322. candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
  1323. _post_alloc_update[candidate] = candidate_post_alloc
  1324. potential_peak = candidate_post_alloc
  1325. candidate_size_free_to_move = sum(
  1326. buf.mpi_buffer.size_free # type: ignore[attr-defined]
  1327. for buf in itertools.chain.from_iterable(
  1328. group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
  1329. )
  1330. )
  1331. _size_free_delta_update[candidate] = -candidate_size_free_to_move
  1332. delta_mem = candidate_delta_mem + candidate_size_free_to_move
  1333. for gn in gns:
  1334. gn_post_alloc = curr_memory[gn][0] + delta_mem
  1335. _post_alloc_update[gn] = gn_post_alloc
  1336. potential_peak = max(potential_peak, gn_post_alloc)
  1337. gn_size_free_to_add = 0
  1338. if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
  1339. bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
  1340. for buf in bufs:
  1341. gn_size_free_to_add += buf.mpi_buffer.size_free
  1342. _size_free_delta_update[gn] = gn_size_free_to_add
  1343. delta_mem -= gn_size_free_to_add
  1344. return potential_peak, _post_alloc_update, _size_free_delta_update
  1345. def _perform_double_linked_list_swap_sink_waits(
  1346. candidate: BaseSchedulerNode,
  1347. group_head: BaseSchedulerNode,
  1348. group_tail: BaseSchedulerNode,
  1349. prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  1350. next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  1351. head: BaseSchedulerNode,
  1352. ) -> BaseSchedulerNode:
  1353. """
  1354. Swap positions of candidate and group in doubly-linked list (sink_waits version).
  1355. Transforms (moves candidate to the left):
  1356. group_head_prev -> group_head...group_tail -> candidate -> candidate_next
  1357. Into:
  1358. group_head_prev -> candidate -> group_head...group_tail -> candidate_next
  1359. Args:
  1360. candidate: Node to swap with group
  1361. group_head: First node of group
  1362. group_tail: Last node of group
  1363. prev_dict: Dictionary mapping nodes to their previous nodes
  1364. next_dict: Dictionary mapping nodes to their next nodes
  1365. head: Current head of the linked list
  1366. Returns:
  1367. New head of the linked list (may change if group_head was the head)
  1368. """
  1369. # 0: Update group_head's previous node
  1370. group_head_prev = prev_dict[group_head]
  1371. if group_head_prev:
  1372. next_dict[group_head_prev] = candidate
  1373. prev_dict[candidate] = group_head_prev
  1374. # 2: Update candidate's next node
  1375. candidate_next = next_dict[candidate]
  1376. if candidate_next:
  1377. prev_dict[candidate_next] = group_tail
  1378. next_dict[group_tail] = candidate_next
  1379. # 1: Link candidate to group_head
  1380. prev_dict[group_head] = candidate
  1381. next_dict[candidate] = group_head
  1382. # Update head if group_head was the head
  1383. if group_head == head:
  1384. return candidate
  1385. return head
  1386. def _format_and_log_sink_waits_stats(
  1387. stats: dict[BaseSchedulerNode, SinkWaitInfo],
  1388. head: BaseSchedulerNode,
  1389. next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  1390. original_snodes_num: int,
  1391. peak_memory: int,
  1392. name_to_freeable_input_buf: dict,
  1393. graph_outputs: OrderedSet[str],
  1394. ) -> list[BaseSchedulerNode]:
  1395. """
  1396. Format sink_waits statistics, log them, and return final node list.
  1397. Computes improvement metrics, creates a formatted table (using tabulate if
  1398. available), validates the reordered node count, recalculates peak memory,
  1399. and logs all information.
  1400. Args:
  1401. stats: Per-node sink_waits statistics
  1402. head: Head of the reordered linked list
  1403. next_dict: Linked list next pointers
  1404. original_snodes_num: Original number of nodes (for validation)
  1405. peak_memory: Initial peak memory before reordering
  1406. name_to_freeable_input_buf: Buffer memory tracking info
  1407. graph_outputs: Graph output names
  1408. Returns:
  1409. Final reordered list of scheduler nodes
  1410. """
  1411. headers = [
  1412. "Wait node",
  1413. "comm_time(us)",
  1414. "comp_time(us)",
  1415. "initial exposed(us)",
  1416. "final exposed(us)",
  1417. "improvement(us)",
  1418. "limiting factor",
  1419. "grouped",
  1420. "grouped_info",
  1421. "moves",
  1422. "moves_info",
  1423. "overlap_info",
  1424. ]
  1425. rows = [
  1426. [
  1427. node_summary(snode),
  1428. info.comm_time / 1e3,
  1429. info.comp_time / 1e3,
  1430. info.initial_exposed / 1e3,
  1431. info.final_exposed / 1e3,
  1432. info.improvement / 1e3,
  1433. info.limiting_factor,
  1434. info.grouped,
  1435. info.grouped_info,
  1436. info.moves,
  1437. info.moves_info,
  1438. info.overlap_info,
  1439. ]
  1440. for snode, info in stats.items()
  1441. ]
  1442. log_str = ""
  1443. if importlib.util.find_spec("tabulate"):
  1444. from tabulate import tabulate
  1445. log_str += tabulate(
  1446. rows,
  1447. headers=headers,
  1448. )
  1449. else:
  1450. log_str += "Please `pip install tabulate` to nicely render overlap stats.\n"
  1451. log_str += str(headers) + "\n"
  1452. log_str += "\n".join(map(str, rows))
  1453. overlap_log.info(log_str)
  1454. new_snodes = _group_nodes_from_linked_list(head, None, next_dict)
  1455. assert len(new_snodes) == original_snodes_num
  1456. new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
  1457. new_snodes, name_to_freeable_input_buf, graph_outputs
  1458. )
  1459. log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
  1460. log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
  1461. trace_structured(
  1462. "artifact",
  1463. metadata_fn=lambda: {
  1464. "name": "sink_waits_iterative_info",
  1465. "encoding": "string",
  1466. },
  1467. payload_fn=lambda: log_str,
  1468. )
  1469. return new_snodes
  1470. def _find_buffers_with_changed_last_use_sink_waits(
  1471. candidate: BaseSchedulerNode,
  1472. gns: list[BaseSchedulerNode],
  1473. buf_to_snode_last_use: dict,
  1474. candidate_buffer_map: dict[BaseSchedulerNode, OrderedSet],
  1475. ) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]:
  1476. """
  1477. Find buffers whose last use will change after swapping in sink_waits pass.
  1478. When we swap [group] candidate to candidate [group], some buffers that
  1479. were last used by candidate will now be last used by a group node instead.
  1480. This is the opposite direction from the reorder version.
  1481. Args:
  1482. candidate: The node being moved (currently last use)
  1483. gns: Group nodes being swapped with candidate
  1484. buf_to_snode_last_use: Mapping of buffers to their current last-use nodes
  1485. candidate_buffer_map: Pre-computed map of node -> buffers using that node
  1486. Returns:
  1487. Dict mapping group nodes to buffers that will change their last-use node
  1488. """
  1489. group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
  1490. BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
  1491. ] = defaultdict(list)
  1492. # Optimization: only check buffers where candidate is a successor
  1493. # Reduces from O(all_buffers) to O(buffers_per_candidate)
  1494. candidate_bufs = candidate_buffer_map.get(candidate, OrderedSet())
  1495. for buf in candidate_bufs:
  1496. snode_last_use = buf_to_snode_last_use[buf]
  1497. if snode_last_use != candidate: # noqa: E711
  1498. continue
  1499. # candidate is last use of buf
  1500. # Find last group node in successors (maintains order)
  1501. succ_nodes = buf.mpi_buffer.succ_nodes
  1502. last_succ_gn = None
  1503. for gn in gns:
  1504. if gn in succ_nodes:
  1505. last_succ_gn = gn
  1506. if last_succ_gn is None:
  1507. continue
  1508. # gn has successors of buf that after potential swap will become
  1509. # last use of buf and start deallocating buf instead of candidate
  1510. group_n_to_bufs_after_swap_dealloc_instead_of_candidate[last_succ_gn].append(
  1511. buf
  1512. )
  1513. return group_n_to_bufs_after_swap_dealloc_instead_of_candidate
  1514. def _sink_waits_iterative_internal(
  1515. snodes: list[BaseSchedulerNode],
  1516. ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
  1517. original_snodes_num = len(snodes)
  1518. if original_snodes_num == 0:
  1519. return snodes, {}
  1520. graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
  1521. graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
  1522. (
  1523. peak_memory,
  1524. _curr_memory,
  1525. snodes_allocfree,
  1526. buf_to_snode_last_use,
  1527. name_to_freeable_input_buf,
  1528. candidate_buffer_map,
  1529. ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
  1530. _prev, _next, _head = _initialize_double_linked_list(snodes)
  1531. stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
  1532. runtimes: dict[BaseSchedulerNode, float] = {
  1533. snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode)
  1534. for snode in snodes
  1535. }
  1536. # Pre-compute output and dependency sets for O(1) lookup instead of O(N) creation per iteration
  1537. node_output_sets: dict[BaseSchedulerNode, frozenset[str]] = {
  1538. snode: frozenset(o.get_name() for o in snode.get_outputs()) for snode in snodes
  1539. }
  1540. node_dep_sets: dict[BaseSchedulerNode, frozenset[str]] = {
  1541. snode: frozenset(
  1542. d.name for d in snode.unmet_dependencies if not _is_fake_dep(d)
  1543. )
  1544. for snode in snodes
  1545. }
  1546. curr: Optional[BaseSchedulerNode] = snodes[-1]
  1547. processed_waits = OrderedSet() # type: ignore[var-annotated]
  1548. debug_iterative_memory_recompute = (
  1549. config_comms.reorder_iterative_debug_memory_recompute
  1550. )
  1551. debug_num_sink_waits_to_reorder: Optional[int] = (
  1552. config_comms.sink_waits_iterative_debug_limit_to_sink
  1553. )
  1554. iterative_recompute_error = False
  1555. while curr is not None and _prev[curr] is not None:
  1556. _prev_curr = _prev[curr]
  1557. if iterative_recompute_error:
  1558. break
  1559. if (
  1560. debug_num_sink_waits_to_reorder is not None
  1561. and len(processed_waits) >= debug_num_sink_waits_to_reorder
  1562. ):
  1563. break
  1564. if not (contains_wait(curr) and curr not in processed_waits):
  1565. curr = _prev_curr
  1566. continue
  1567. processed_waits.add(curr)
  1568. info = stats[curr] = SinkWaitInfo()
  1569. comm_time, comp_time, overlap_info = _wait_exposed_communication_time(
  1570. curr, _head, _prev, runtimes, node_output_sets, node_dep_sets
  1571. )
  1572. info.initial_exposed = info.final_exposed = comm_time - comp_time
  1573. info.comm_time = comm_time
  1574. info.comp_time = comp_time
  1575. info.overlap_info = overlap_info
  1576. candidate = _next[curr]
  1577. group_head = curr
  1578. group_tail = curr
  1579. group_colls = {}
  1580. group_runtime = 0.0
  1581. group_peak_memory = _curr_memory[curr][0]
  1582. # Track group outputs and check collective status incrementally - initialize from pre-computed set
  1583. group_output_names = OrderedSet(node_output_sets[curr])
  1584. group_contains_collective = contains_collective(curr)
  1585. while candidate is not None:
  1586. if config_comms.sink_iterative_use_runtime_estimations and (
  1587. info.final_exposed
  1588. < -config_comms.sink_iterative_extra_comm_comp_overlap * info.comm_time
  1589. ):
  1590. info.limiting_factor = "unexposed by runtime estimations"
  1591. break
  1592. # Early exit: if group has no outputs, candidate can't depend on it
  1593. if not group_output_names:
  1594. data_dep = False
  1595. else:
  1596. # Calculate candidate dependencies using pre-computed set
  1597. candidate_dep_names = node_dep_sets[candidate]
  1598. data_dep = bool(candidate_dep_names & group_output_names)
  1599. # Conservative sink wait, limiting by space before next collective.
  1600. # The global strategy is that bucketing should create space.
  1601. # For 2D we can experiment with allowing to sink Wait beyond non current group collective.
  1602. if not config_comms.sink_waits_iterative_swap_with_collectives:
  1603. if contains_async_collective(candidate):
  1604. info.limiting_factor = (
  1605. f"candidate contains_async_collective {candidate.get_name()}"
  1606. )
  1607. break
  1608. # 1. If we have data_dep - we can not swap => trying to group
  1609. # 2. If swap candidate and current node both contain collectives => trying to group
  1610. both_contain_comms = group_contains_collective and contains_collective(
  1611. candidate
  1612. )
  1613. if data_dep or both_contain_comms:
  1614. _is_groupable, groupable_reason = _is_node_groupable_for_sink_waits(
  1615. candidate
  1616. )
  1617. if _is_groupable:
  1618. group_tail = candidate
  1619. # Update incremental tracking using pre-computed set
  1620. group_output_names.update(node_output_sets[candidate])
  1621. group_contains_collective = (
  1622. group_contains_collective or contains_collective(candidate)
  1623. )
  1624. if (
  1625. config_comms.sink_iterative_use_runtime_estimations
  1626. and contains_collective(candidate)
  1627. ):
  1628. comm_time, comp_time, _ = _coll_exposed_communication_time(
  1629. candidate, _next, runtimes, node_output_sets, node_dep_sets
  1630. )
  1631. group_colls[candidate] = (comm_time, comp_time)
  1632. if not contains_async_collective(candidate):
  1633. group_runtime += runtimes[candidate]
  1634. group_peak_memory = max(
  1635. group_peak_memory, _curr_memory[candidate][0]
  1636. )
  1637. info.grouped += 1
  1638. candidate = _next[candidate]
  1639. continue
  1640. elif not data_dep:
  1641. if (
  1642. not config_comms.sink_waits_iterative_unsafe_collectives_reorder
  1643. and both_contain_comms
  1644. ):
  1645. info.limiting_factor = (
  1646. f"collective ordering"
  1647. f"\n with candidate:{candidate.get_name()}"
  1648. )
  1649. break
  1650. else:
  1651. info.limiting_factor = (
  1652. f"data dependency detected"
  1653. f"\n candidate:{candidate.get_name()}"
  1654. f"\n non_group_reason:{groupable_reason}"
  1655. )
  1656. break
  1657. if config_comms.sink_iterative_use_runtime_estimations:
  1658. if is_wait(candidate.node):
  1659. # Corresponding collective is before the group,
  1660. # Swap can increase exposed time of corresponding collective
  1661. comm_time, comp_time, _ = _wait_exposed_communication_time(
  1662. candidate,
  1663. _head,
  1664. _prev,
  1665. runtimes,
  1666. node_output_sets,
  1667. node_dep_sets,
  1668. )
  1669. # pyrefly: ignore[no-matching-overload]
  1670. exposed_before = max(0, comm_time - comp_time)
  1671. # pyrefly: ignore[no-matching-overload]
  1672. exposed_after = max(0, comm_time - comp_time + group_runtime)
  1673. # We do not know how much we can sink more after this swap,
  1674. # Just comparing advantage at the moment for now.
  1675. if exposed_after > exposed_before:
  1676. info.limiting_factor = (
  1677. "candidate is wait,"
  1678. f" exposed_before:{exposed_before} vs exposed_after:{exposed_after}"
  1679. )
  1680. break
  1681. # Check if candidate has sync runtime
  1682. if not contains_async_collective(candidate):
  1683. # If candidate has sync runtime,
  1684. # Waits of gorup_colls are on the right from group.
  1685. # Swap can increase their exposed time.
  1686. c_runtime = runtimes[candidate]
  1687. if c_runtime > 0 and len(group_colls) > 0:
  1688. # Advantage for current Wait to do the Swap
  1689. # pyrefly: ignore[no-matching-overload]
  1690. exposed_delta = max(
  1691. 0,
  1692. info.comm_time - info.comp_time,
  1693. )
  1694. # pyrefly: ignore[no-matching-overload]
  1695. -max(0, info.comm_time - info.comp_time - c_runtime)
  1696. for gc_comm_time, gc_comp_time in group_colls.values():
  1697. # pyrefly: ignore [no-matching-overload]
  1698. exposed_delta += max(0, gc_comm_time - gc_comp_time) - max(
  1699. 0, gc_comm_time - gc_comp_time + c_runtime
  1700. )
  1701. if exposed_delta > 0:
  1702. info.limiting_factor = (
  1703. f"candidate has compute {c_runtime}, group contains collectives,"
  1704. f" total_exposed_delta {exposed_delta}"
  1705. )
  1706. break
  1707. else:
  1708. # Update all group_colls comm_time, comp_time
  1709. for gc, (
  1710. gc_comm_time,
  1711. gc_comp_time,
  1712. ) in group_colls.items():
  1713. group_colls[gc] = (
  1714. gc_comm_time,
  1715. gc_comp_time - c_runtime,
  1716. )
  1717. # Create group nodes list once for swap operations
  1718. gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list(
  1719. group_head, group_tail, _next
  1720. )
  1721. candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
  1722. candidate_delta_mem = (
  1723. candidate_allocfree.size_alloc - candidate_allocfree.size_free
  1724. )
  1725. # [group] candidate -> candidate [group]
  1726. # Check for buffers with successors in group and candidate last successor
  1727. #
  1728. # Buf that changes its last use snode,
  1729. # It was deallocated by candidate,
  1730. # but after swap it will be deallocated by group node.
  1731. group_n_to_bufs_after_swap_dealloc_instead_of_candidate = (
  1732. _find_buffers_with_changed_last_use_sink_waits(
  1733. candidate, gns, buf_to_snode_last_use, candidate_buffer_map
  1734. )
  1735. )
  1736. potential_peak, _post_alloc_update, _size_free_delta_update = (
  1737. _calculate_potential_peak_memory_sink_waits(
  1738. candidate,
  1739. gns,
  1740. group_head,
  1741. group_peak_memory,
  1742. candidate_delta_mem,
  1743. candidate_allocfree,
  1744. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  1745. _curr_memory,
  1746. snodes_allocfree,
  1747. )
  1748. )
  1749. if (
  1750. potential_peak - peak_memory
  1751. > peak_memory * config_comms.sink_iterative_peak_memory_budget
  1752. ):
  1753. info.limiting_factor = (
  1754. f"peak memory new:{potential_peak} vs base:{peak_memory}"
  1755. )
  1756. break
  1757. info.moves += 1
  1758. info.moves_info += f"+{candidate.get_name()}"
  1759. _head = _perform_double_linked_list_swap_sink_waits(
  1760. candidate, group_head, group_tail, _prev, _next, _head
  1761. )
  1762. comm_time, comp_time, overlap_info = _wait_exposed_communication_time(
  1763. curr, _head, _prev, runtimes, node_output_sets, node_dep_sets
  1764. )
  1765. info.comm_time = comm_time
  1766. info.comp_time = comp_time
  1767. info.final_exposed = comm_time - comp_time
  1768. info.overlap_info = overlap_info
  1769. _update_memory_tracking_after_swap_sink_waits(
  1770. candidate,
  1771. gns,
  1772. candidate_delta_mem,
  1773. candidate_allocfree,
  1774. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  1775. _post_alloc_update,
  1776. _size_free_delta_update,
  1777. _curr_memory,
  1778. snodes_allocfree,
  1779. )
  1780. if debug_iterative_memory_recompute:
  1781. from .comms_debug import _debug_iterative_memory_recompute
  1782. iterative_recompute_error = _debug_iterative_memory_recompute(
  1783. candidate,
  1784. gns,
  1785. _group_names(gns),
  1786. _group_nodes_from_linked_list(_head, None, _next),
  1787. name_to_freeable_input_buf,
  1788. graph_outputs,
  1789. peak_memory,
  1790. _curr_memory,
  1791. snodes_allocfree,
  1792. "sink_waits_iterative",
  1793. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  1794. )
  1795. if iterative_recompute_error:
  1796. break
  1797. candidate = _next[group_tail]
  1798. curr = _prev_curr
  1799. if not config_comms.reorder_sink_verbose_logging:
  1800. new_snodes = _group_nodes_from_linked_list(_head, None, _next)
  1801. return new_snodes, stats
  1802. new_snodes = _format_and_log_sink_waits_stats(
  1803. stats,
  1804. _head,
  1805. _next,
  1806. original_snodes_num,
  1807. peak_memory,
  1808. name_to_freeable_input_buf,
  1809. graph_outputs,
  1810. )
  1811. return new_snodes, stats
  1812. def sink_waits_iterative(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
  1813. """
  1814. Similarly to reorder_communication_preserving_peak_memory this pass will try to iteratively
  1815. push Wait nodes later, recomputing estimated peak memory before each swap,
  1816. and preventing peak memory regressions.
  1817. Pass will be applied to every Wait node. If there are immediate dependencies with next node,
  1818. pass will try to group them together and on the next step to swap the group with next candidate.
  1819. If _inductor.config_comms.sink_iterative_use_runtime_estimations is set True,
  1820. pass will stop reordering of Wait once corresponding Collective is unexposed,
  1821. based on runtime estimations.
  1822. inductor.config_comms.sink_iterative_peak_memory_budget allows to tune how much pass
  1823. can regress initial peak memory.
  1824. E.g.:
  1825. sink_iterative_peak_memory_budget == 0.0 - No regression of initial peak memory is allowed
  1826. sink_iterative_peak_memory_budget == 0.2 - Pass can improve comm-compute overlap, sacrificing
  1827. 20% of initial peak memory value.
  1828. inductor.config_comms.sink_iterative_extra_comm_comp_overlap config allows to more aggressively
  1829. sink waits, stopping only when overlap_compute >= (1 + extra_comm_comp_overlap) * comm_time
  1830. """
  1831. return _sink_waits_iterative_internal(snodes)[0]
  1832. def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
  1833. """
  1834. Returns estimated op runtime in milliseconds (ms)
  1835. """
  1836. if config.estimate_op_runtime == "default":
  1837. runtime = snode.get_estimated_runtime()
  1838. else:
  1839. assert callable(config.estimate_op_runtime)
  1840. runtime = config.estimate_op_runtime(snode)
  1841. return runtime
  1842. def node_summary(snode):
  1843. snodes = snode.get_nodes()
  1844. if len(snodes) == 1:
  1845. detail = ""
  1846. if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
  1847. outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
  1848. ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
  1849. detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}({ins_str})"
  1850. layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
  1851. out_tensor_info = ",".join(
  1852. [
  1853. f" (size={layout.size}, stride={layout.stride})"
  1854. if isinstance(layout, ir.Layout)
  1855. else ""
  1856. for layout in layouts
  1857. ]
  1858. )
  1859. try:
  1860. node_name = snode.node.maybe_get_name()
  1861. except AttributeError:
  1862. # TODO: node_summary was written without FusedSchedulerNode in mind, generally needs to be hardened
  1863. node_name = ""
  1864. return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name} ({snode.get_estimated_runtime():.0f} ns)"
  1865. # Flatten the summaries for Fused/Foreach/Grouped nodes
  1866. summaries = []
  1867. for child_snode in snodes:
  1868. summaries.append(node_summary(child_snode))
  1869. return f"{snode.__class__.__name__}: {', '.join(summaries)}"
  1870. def visualize_overlap(order):
  1871. # TODO - this function probably doesn't do a very good job estimating the runtime because it doesn't carefully model
  1872. # streams and overlap. For now its mostly useful as a debug visualization.
  1873. total_est_runtime: float = 0.0
  1874. cur_comm_node = None
  1875. def step_log(step, msg):
  1876. overlap_log.debug(f"{step:>6}: {msg}") # noqa: G004
  1877. for step, snode in enumerate(order):
  1878. if cur_comm_node is None:
  1879. if contains_collective(snode):
  1880. total_est_runtime += estimate_op_runtime(snode)
  1881. cur_comm_node = snode.node
  1882. elif is_wait(snode.node):
  1883. # raise AssertionError(
  1884. # "Wait is not expected when there is no collective running"
  1885. # )
  1886. pass
  1887. else: # exposed compute op
  1888. total_est_runtime += estimate_op_runtime(snode)
  1889. step_log(step, f"{node_summary(snode)}")
  1890. else: # cur_comm_node is not None
  1891. if contains_collective(snode):
  1892. total_est_runtime += estimate_op_runtime(snode)
  1893. cur_comm_node = snode.node
  1894. step_log(step, f"{node_summary(snode)}") # noqa: G004
  1895. elif is_wait(snode.node): # end of this comm op
  1896. step_log(step, f"{node_summary(snode)}")
  1897. cur_comm_node = None
  1898. else: # overlapped compute op
  1899. step_log(step, f"| {node_summary(snode)}")
  1900. overlap_log.debug(
  1901. f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004
  1902. )
  1903. def reorder_compute_and_comm_for_overlap(
  1904. snodes: list[BaseSchedulerNode],
  1905. ) -> list[BaseSchedulerNode]:
  1906. order = snodes
  1907. # pyrefly: ignore [bad-assignment]
  1908. for p in config.reorder_for_compute_comm_overlap_passes:
  1909. if isinstance(p, str) and p in globals():
  1910. p = globals()[p] # it is a builtin pass
  1911. assert callable(p), (
  1912. f"Invalid reorder_compute_and_comm_for_overlap pass: {p} is not callable"
  1913. )
  1914. order = p(order) # type: ignore[operator]
  1915. # pyrefly: ignore [bad-return]
  1916. return order
  1917. def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
  1918. """
  1919. This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding
  1920. graph intermediates that were fsdp.copy_ into the unsharded params in the original graph.
  1921. NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern
  1922. (or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case
  1923. where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't
  1924. remove these resize and copy ops and thus we will have worse performance there.
  1925. In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param"
  1926. is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern
  1927. (in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed.
  1928. """
  1929. node_list = list(graph.nodes)
  1930. # Find all graph inputs and their resize counts
  1931. graph_input_to_resized_to_full_node_idxes = defaultdict(list)
  1932. graph_input_to_resized_to_0_node_idxes = defaultdict(list)
  1933. for idx, node in enumerate(node_list):
  1934. if (
  1935. node.op == "call_function"
  1936. and node.target is torch.ops.inductor.resize_storage_bytes_.default
  1937. ):
  1938. assert node.args[0].op == "placeholder", f"""\
  1939. Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
  1940. """
  1941. graph_input = node.args[0]
  1942. new_size = node.args[1]
  1943. if new_size > 0:
  1944. graph_input_to_resized_to_full_node_idxes[graph_input].append(idx)
  1945. else:
  1946. graph_input_to_resized_to_0_node_idxes[graph_input].append(idx)
  1947. def check_resize_pattern(graph_input):
  1948. # Check the number of resize-to-full and resize-to-0 nodes are equal,
  1949. # and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node
  1950. # always happens before the resize-to-0 node.
  1951. # This is the precondition for being able to remove all the resize and copy nodes
  1952. # for this specific unsharded param.
  1953. resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get(
  1954. graph_input, []
  1955. )
  1956. resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])
  1957. if len(resized_to_full_idxes) != len(resized_to_0_idxes):
  1958. log.warning(
  1959. f"""
  1960. Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:
  1961. {len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}.
  1962. Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass.
  1963. """ # noqa: G004
  1964. )
  1965. return False
  1966. # Check the sequence: (resize_to_full -> resize_to_0)+
  1967. for resize_to_full_idx, resize_to_0_idx in zip(
  1968. resized_to_full_idxes, resized_to_0_idxes
  1969. ):
  1970. if resize_to_full_idx >= resize_to_0_idx:
  1971. log.warning(
  1972. f"""
  1973. For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx}
  1974. happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}.
  1975. Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param.
  1976. """ # noqa: G004
  1977. )
  1978. return False
  1979. return True
  1980. # Find all eligible unsharded params and their corresponding graph intermediates.
  1981. unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
  1982. for idx, node in enumerate(node_list):
  1983. if node.op == "call_function" and node.target is torch.ops.fsdp.copy_.default:
  1984. fsdp_copy_node = node
  1985. unsharded_param = node.args[0]
  1986. assert unsharded_param.op == "placeholder", f"""
  1987. Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
  1988. Offending node: {unsharded_param}. Graph: {graph}
  1989. """
  1990. if check_resize_pattern(unsharded_param):
  1991. unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx)
  1992. def is_allowed_mutation(node):
  1993. return (
  1994. node.target is torch.ops.fsdp.copy_.default
  1995. or node.target is torch.ops.inductor.resize_storage_bytes_.default
  1996. )
  1997. def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
  1998. # Check whether the node is mutating any of the unsharded params or their aliases.
  1999. mutated_arg_idxes = (
  2000. [
  2001. i
  2002. for i, x in enumerate(node.target._schema.arguments)
  2003. if x.alias_info is not None and x.alias_info.is_write
  2004. ]
  2005. if isinstance(node.target, torch._ops.OpOverload)
  2006. else []
  2007. )
  2008. mutated_node_arg_storages = OrderedSet(
  2009. [
  2010. StorageWeakRef(node.args[i].meta["val"].untyped_storage())
  2011. for i in mutated_arg_idxes
  2012. ]
  2013. )
  2014. storages_of_unsharded_params = OrderedSet(
  2015. [
  2016. StorageWeakRef(unsharded_param.meta["val"].untyped_storage())
  2017. for unsharded_param in unsharded_params
  2018. ]
  2019. )
  2020. return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0
  2021. # Check no user mutation on any unsharded_param
  2022. for node in node_list:
  2023. if (
  2024. node.op == "call_function"
  2025. and isinstance(node.target, torch._ops.OpOverload)
  2026. and node.target._schema.is_mutable
  2027. and not is_allowed_mutation(node)
  2028. ):
  2029. assert not is_node_mutating_unsharded_param_or_its_alias(
  2030. node, unsharded_param_to_fsdp_copy_node_idxes.keys()
  2031. ), f"""\
  2032. User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node}
  2033. """
  2034. # For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`.
  2035. #
  2036. # NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input.
  2037. # e.g.
  2038. # ```
  2039. # fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1)
  2040. # ... (use of unsharded_param_1) -> Subgraph 1
  2041. # fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2)
  2042. # ... (use of unsharded_param_1) -> Subgraph 2
  2043. # fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3)
  2044. # ... (use of unsharded_param_1) -> Subgraph 3
  2045. # ```
  2046. # We must do the replacement only within each subgraph.
  2047. for (
  2048. unsharded_param,
  2049. fsdp_copy_node_idxes,
  2050. ) in unsharded_param_to_fsdp_copy_node_idxes.items():
  2051. for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
  2052. fsdp_copy_node = node_list[fsdp_copy_node_idx]
  2053. assert fsdp_copy_node.args[0] is unsharded_param
  2054. _, replacement = fsdp_copy_node.args
  2055. # subgraph_start_idx is exclusive
  2056. subgraph_start_idx = fsdp_copy_node_idx + 1
  2057. # subgraph_end_idx is exclusive (also intentionally don't replace args in return op)
  2058. subgraph_end_idx = (
  2059. fsdp_copy_node_idxes[i + 1]
  2060. if i < len(fsdp_copy_node_idxes) - 1
  2061. else len(node_list) - 1
  2062. )
  2063. subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx]
  2064. assert not any(
  2065. is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param])
  2066. for node in subgraph_nodes
  2067. ), f"""\
  2068. Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true!
  2069. Graph: {graph}
  2070. """
  2071. for node in subgraph_nodes:
  2072. if (
  2073. node.op == "call_function"
  2074. and unsharded_param in node.args
  2075. and node.target != torch.ops.inductor.resize_storage_bytes_.default
  2076. ): # TODO(yf225): implement replacement in kwargs
  2077. new_args = tuple(
  2078. replacement if arg is unsharded_param else arg
  2079. for arg in node.args
  2080. )
  2081. node.args = new_args
  2082. # Delete `fsdp.copy_(unsharded_param, Y)` nodes
  2083. for fsdp_copy_node_idxes in unsharded_param_to_fsdp_copy_node_idxes.values():
  2084. for fsdp_copy_node_idx in fsdp_copy_node_idxes:
  2085. fsdp_copy_node = node_list[fsdp_copy_node_idx]
  2086. graph.erase_node(fsdp_copy_node)
  2087. # Delete `resize_(unsharded_param, ...)` nodes
  2088. for node in node_list:
  2089. if (
  2090. node.op == "call_function"
  2091. and node.target is torch.ops.inductor.resize_storage_bytes_.default
  2092. and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
  2093. ):
  2094. graph.erase_node(node)
  2095. def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
  2096. try:
  2097. import torch.distributed.fsdp._fully_shard._fsdp_collectives
  2098. assert torch.distributed.is_available()
  2099. # Assert existence of these ops
  2100. assert (
  2101. torch.ops._c10d_functional.all_gather_into_tensor
  2102. and torch.ops._c10d_functional.all_gather_into_tensor_out
  2103. )
  2104. except (ImportError, AttributeError, AssertionError):
  2105. return
  2106. from .pattern_matcher import (
  2107. CallFunction,
  2108. KeywordArg,
  2109. Match,
  2110. PatternMatcherPass,
  2111. register_graph_pattern,
  2112. )
  2113. """
  2114. all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
  2115. getitem = all_gather_copy_in[0];
  2116. (getitem_1 = all_gather_copy_in[1];) # optional
  2117. all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);
  2118. ->
  2119. all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
  2120. getitem = all_gather_copy_in[0];
  2121. getitem_1 = all_gather_copy_in[1];
  2122. all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
  2123. """
  2124. def remove_unused_getitem(g):
  2125. # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
  2126. node_list = list(g.nodes)
  2127. for n in node_list:
  2128. if (
  2129. n.target is operator.getitem
  2130. and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
  2131. and n.args[1] == 1
  2132. ):
  2133. g.erase_node(n)
  2134. graph_pass = PatternMatcherPass()
  2135. @register_graph_pattern(
  2136. CallFunction(
  2137. torch.ops._c10d_functional.all_gather_into_tensor.default,
  2138. CallFunction(
  2139. operator.getitem,
  2140. CallFunction(
  2141. torch.ops.fsdp.all_gather_copy_in.default,
  2142. KeywordArg("all_gather_inputs"),
  2143. KeywordArg("all_gather_output"),
  2144. KeywordArg("inp_split_sizes"),
  2145. KeywordArg("all_gather_input_numel"),
  2146. KeywordArg("rank"),
  2147. ),
  2148. KeywordArg("item_idx"),
  2149. ),
  2150. KeywordArg("group_size"),
  2151. KeywordArg("group_name"),
  2152. ),
  2153. # pyrefly: ignore [bad-argument-type]
  2154. pass_dict=graph_pass,
  2155. extra_check=lambda match: match.kwargs["item_idx"] == 0,
  2156. )
  2157. def reinplace_all_gather(match: Match, *args, **kwargs):
  2158. def repl(
  2159. *args,
  2160. ):
  2161. copy_in_args = args[:-2]
  2162. group_size = args[-2]
  2163. group_name = args[-1]
  2164. all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
  2165. *copy_in_args
  2166. )
  2167. getitem = all_gather_copy_in[0]
  2168. getitem_1 = all_gather_copy_in[1]
  2169. all_gather_into_tensor = (
  2170. torch.ops._c10d_functional.all_gather_into_tensor_out.default(
  2171. getitem, group_size, group_name, out=getitem_1
  2172. )
  2173. )
  2174. return all_gather_into_tensor
  2175. match.replace_by_example(
  2176. # pyrefly: ignore [bad-argument-type]
  2177. repl,
  2178. [
  2179. kwargs["all_gather_inputs"],
  2180. kwargs["all_gather_output"],
  2181. kwargs["inp_split_sizes"],
  2182. kwargs["all_gather_input_numel"],
  2183. kwargs["rank"],
  2184. kwargs["group_size"],
  2185. kwargs["group_name"],
  2186. ],
  2187. )
  2188. remove_unused_getitem(graph)
  2189. graph_pass.apply(graph) # type: ignore[arg-type]
  2190. def get_op_idx(snode):
  2191. assert not isinstance(
  2192. snode,
  2193. (
  2194. torch._inductor.scheduler.FusedSchedulerNode,
  2195. torch._inductor.scheduler.GroupedSchedulerNode,
  2196. ),
  2197. )
  2198. return int(snode.get_name()[2:])
def enforce_comm_ordering_for_fsdp(
    snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
    """Group FSDP2 all-gather / reduce-scatter blocks and chain them in order.

    Two kinds of code blocks are recognized while walking ``snodes``:

    * All-gather: an ``_c10d_functional.all_gather_into_tensor_out`` collective
      that has an ``fsdp.all_gather_copy_in`` fallback op among its ancestors.
      Its dependencies plus selected users (wait / ``split_with_sizes_copy``
      nodes) are collected.
    * Reduce-scatter: an ``fsdp.chunk_cat`` fallback op together with its
      recursive users.

    Each collected block is sorted by ``get_op_idx`` and split at the first
    ``ir._WaitKernel`` (searching from index 1) into two
    ``GroupedSchedulerNode``s: a "copy-in + collective" group and a
    "wait + follow-up" group.

    Fake ``WeakDep``s are then added so that each all-gather (respectively
    reduce-scatter) "copy-in + collective" group depends on every output of
    the previous all-gather (respectively reduce-scatter) wait group.

    Args:
        snodes: Scheduler nodes in their current order.
        name_to_buf: Buffer-name lookup forwarded to
            ``find_recursive_deps_of_node`` / ``find_recursive_users_of_node``.
        name_to_fused_node: Maps node names (e.g. entries of
            ``snode.ancestors``) to their scheduler node; also forwarded to
            the recursive search helpers.

    Returns:
        A new ordering of ``snodes``. Every grouped member is replaced by its
        group node, placed where the first member appears in ``snodes``.
        Ungrouped nodes keep their relative order, and each node appears at
        most once.

    Raises:
        AssertionError: If no group was created at all. Also raised if a
            collected block has no wait node after its first element.
    """
    from . import scheduler

    new_order: list[BaseSchedulerNode] = []
    scheduled = OrderedSet[Any]()
    ag_exists = False
    rs_exists = False
    # Insertion-ordered: "copy-in + collective" group -> its "wait" group.
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    # Maps member names AND the group's own name to the group node, so the
    # rebuild loop below can look up any of them.
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        for snode in snodes_to_group:
            snode_name_to_final_snode[snode.get_name()] = group_node
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_snode = snode
            ag_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out" code block
            allowed_ops = OrderedSet(
                [
                    torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                    torch.ops._c10d_functional.wait_tensor.default,
                    torch.ops.fsdp.split_with_sizes_copy.default,
                ]
            )
            # NOTE(review): the callback returns True for nodes that are neither
            # no-op kernels nor extern kernels in `allowed_ops`; presumably such
            # nodes are excluded from the collected users. Confirm against
            # find_recursive_users_of_node.
            find_recursive_users_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(
                ag_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
            # The block is cut just before the second `split_with_sizes_copy`
            # (copy-out) node, keeping only the first copy-out.
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for i in range(len(ag_related_snodes)):
                cur_snode = ag_related_snodes[i]
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = i
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
            # The search starts at index 1, so the first group is never empty.
            wait_node_idx = None
            for i in range(len(ag_related_snodes) - 1):
                if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

            # Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode
            ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True
            rs_snode = snode

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
            find_recursive_users_of_node(
                rs_snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(
                rs_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
            # As above, the search starts at index 1 so the first group is non-empty.
            wait_node_idx = None
            for i in range(len(rs_related_snodes) - 1):
                if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])

            # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
            rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])

            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    # Sanity checks: at least one block must have been grouped. Presumably
    # callers only invoke this pass when FSDP2 collectives are present.
    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    # A group is emitted the first time any of its members is reached; later
    # members are skipped via `scheduled`.
    for snode in snodes:
        if snode.get_name() in snode_name_to_final_snode:
            snode = snode_name_to_final_snode[snode.get_name()]
        if snode in scheduled:
            continue
        new_order.append(snode)
        scheduled.add(snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    # The first buffer name of the dependent group node is passed as
    # `mutating_buf` for each fake dependency.
    prev_ag_wait = None
    for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
        if prev_ag_wait is not None:
            mutating_buf = next(iter(ag_group_node.get_buffer_names()))
            for o in prev_ag_wait.get_outputs():
                ag_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                )
        prev_ag_wait = wait_group_node

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    prev_rs_wait = None
    for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
        if prev_rs_wait is not None:
            mutating_buf = next(iter(rs_group_node.get_buffer_names()))
            for o in prev_rs_wait.get_outputs():
                rs_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                )
        prev_rs_wait = wait_group_node

    return new_order  # type: ignore[return-value]