partitioners.py 141 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651
  1. from __future__ import annotations
  2. import copy
  3. import functools
  4. import hashlib
  5. import heapq
  6. import itertools
  7. import logging
  8. import math
  9. import operator
  10. import os
  11. import os.path
  12. import re
  13. import warnings
  14. from collections import defaultdict, deque
  15. from collections.abc import Callable
  16. from dataclasses import dataclass, replace
  17. from typing import Any, TYPE_CHECKING
  18. import torch
  19. import torch._inductor.inductor_prims
  20. import torch.distributed
  21. import torch.fx as fx
  22. import torch.utils._pytree as pytree
  23. from torch._dynamo.utils import counters, is_node_meta_valid
  24. from torch._functorch._activation_checkpointing.ac_logging_utils import (
  25. create_structured_trace_for_min_cut_info,
  26. )
  27. from torch._functorch._aot_autograd.utils import is_with_effects
  28. from torch._inductor import config as inductor_config
  29. from torch._inductor.custom_graph_pass import (
  30. CustomKnapsackSolver,
  31. CustomRuntimeEstimator,
  32. )
  33. from torch._library.fake_class_registry import FakeScriptObject
  34. from torch._library.utils import is_builtin
  35. from torch._logging import LazyString, trace_structured
  36. from torch._logging._internal import trace_log
  37. from torch._subclasses.fake_tensor import extract_tensor_metadata
  38. from torch.fx.experimental._backward_state import BackwardState
  39. from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
  40. from torch.fx.experimental.sym_node import magic_methods, method_to_operator
  41. from torch.fx.experimental.symbolic_shapes import (
  42. find_symbol_binding_fx_nodes,
  43. free_symbols,
  44. is_symbol_binding_fx_node,
  45. size_hint,
  46. statically_known_false,
  47. statically_known_true,
  48. )
  49. from torch.fx.passes import graph_drawer
  50. from torch.utils._ordered_set import OrderedSet
  51. from torch.utils.checkpoint import CheckpointPolicy
  52. from . import config
  53. from ._activation_checkpointing.graph_info_provider import GraphInfoProvider
  54. from ._activation_checkpointing.knapsack import (
  55. dp_knapsack,
  56. dp_knapsack_sliding_hirschberg,
  57. greedy_knapsack,
  58. ilp_knapsack,
  59. )
  60. from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
  61. from ._aot_autograd.descriptors import (
  62. AOTOutput,
  63. SavedForBackwardsAOTOutput,
  64. SavedForBackwardsNoVcCheckAOTOutput,
  65. )
  66. from ._aot_autograd.functional_utils import _is_functional_graph
  67. from ._aot_autograd.logging_utils import get_aot_graph_name
  68. from ._aot_autograd.utils import (
  69. _is_bwd_seed_offset,
  70. _is_fwd_seed_offset,
  71. _is_primal,
  72. _is_tangent,
  73. get_cuda_generator_meta_val,
  74. )
  75. from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
  76. if TYPE_CHECKING:
  77. import networkx as nx
  78. import sympy
  79. AOT_PARTITIONER_DEBUG: bool = config.debug_partitioner
  80. log: logging.Logger = logging.getLogger(__name__)
  81. aten = torch.ops.aten
  82. prims = torch.ops.prims
  83. @dataclass
  84. class OpTypes:
  85. """Class for keeping track of different operator categories"""
  86. fusible_ops: OrderedSet[Callable[..., Any]]
  87. compute_intensive_ops: OrderedSet[Callable[..., Any]]
  88. random_ops: OrderedSet[Callable[..., Any]]
  89. view_ops: OrderedSet[Callable[..., Any]]
  90. recomputable_ops: OrderedSet[Callable[..., Any]]
  91. def is_fusible(self, node: fx.Node) -> bool:
  92. return get_aten_target(node) in self.fusible_ops
  93. def is_compute_intensive(self, node: fx.Node) -> bool:
  94. return get_aten_target(node) in self.compute_intensive_ops
  95. def is_random(self, node: fx.Node) -> bool:
  96. return get_aten_target(node) in self.random_ops
  97. def is_view(self, node: fx.Node) -> bool:
  98. return get_aten_target(node) in self.view_ops
  99. def is_recomputable(self, node: fx.Node) -> bool:
  100. return get_aten_target(node) in self.recomputable_ops
  101. @dataclass
  102. class NodeInfo:
  103. # Be careful about iterating over these explicitly, as their order may not
  104. # be deterministic
  105. inputs: list[fx.Node]
  106. _required_fw_nodes: OrderedSet[fx.Node]
  107. required_bw_nodes: OrderedSet[fx.Node]
  108. tangents_closure: OrderedSet[fx.Node]
  109. unclaimed_nodes: OrderedSet[fx.Node]
  110. fw_order: dict[fx.Node, int]
  111. # Effectively maps to which of our primals are parameters
  112. static_lifetime_input_nodes: OrderedSet[fx.Node]
  113. @functools.cached_property
  114. def required_fw_nodes(self) -> list[fx.Node]:
  115. return sorted(
  116. (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n]
  117. )
  118. def is_required_fw(self, n: fx.Node) -> bool:
  119. return n in self._required_fw_nodes
  120. def is_required_bw(self, n: fx.Node) -> bool:
  121. return n in self.required_bw_nodes
  122. def is_unclaimed(self, n: fx.Node) -> bool:
  123. return n in self.unclaimed_nodes
  124. def get_fw_order(self, n: fx.Node) -> int:
  125. if n not in self._required_fw_nodes:
  126. raise AssertionError(f"Node {n} not in fw nodes!")
  127. return self.fw_order[n]
  128. @dataclass
  129. class MinCutOptions:
  130. ban_if_used_far_apart: bool
  131. ban_if_long_fusible_chains: bool
  132. ban_if_materialized_backward: bool
  133. ban_if_not_in_allowlist: bool
  134. ban_if_reduction: bool
  135. def must_recompute(node: fx.Node) -> bool:
  136. return node.meta.get("recompute", None) in [
  137. CheckpointPolicy.MUST_RECOMPUTE,
  138. CheckpointPolicy.PREFER_RECOMPUTE,
  139. ]
  140. def has_recomputable_ops(fx_g: fx.GraphModule) -> bool:
  141. for node in fx_g.graph.nodes:
  142. if must_recompute(node):
  143. return True
  144. return False
  145. def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool:
  146. for node in fx_g.graph.nodes:
  147. if (
  148. must_recompute(node)
  149. and hasattr(node.target, "tags")
  150. and torch.Tag.nondeterministic_seeded in node.target.tags
  151. ):
  152. return True
  153. return False
  154. def sym_node_size(node: fx.Node) -> int:
  155. if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)):
  156. return 1
  157. if not isinstance(node.meta["val"], torch.SymFloat):
  158. raise AssertionError(
  159. f"expected node.meta['val'] to be SymFloat, got {type(node.meta['val'])}"
  160. )
  161. return 4
  162. class InvalidNodeBase:
  163. def __repr__(self) -> str:
  164. return "Invalid Node"
  165. # Run DCE while overriding the definition of is_impure_node
  166. def is_not_collective(node: fx.Node) -> bool:
  167. return getattr(node.target, "namespace", None) != "_c10d_functional"
# Shared sentinel that _extract_graph_with_inputs_outputs stores in its ``env``
# mapping for nodes that cannot be reproduced in the extracted subgraph.
InvalidNode = InvalidNodeBase()
  169. def _get_ho_op_original_input(getitem_node: fx.Node) -> fx.Node | None:
  170. """Given a getitem node, check if it extracts from a higher-order op
  171. that has kwargs mapping the key back to an original input.
  172. Returns the original input node if found, None otherwise.
  173. """
  174. if getitem_node.target != operator.getitem:
  175. return None
  176. ho_result = getitem_node.args[0]
  177. key = getitem_node.args[1]
  178. if not isinstance(ho_result, fx.Node) or ho_result.op != "call_function":
  179. return None
  180. if "kwargs" not in ho_result.kwargs:
  181. return None
  182. kwargs = ho_result.kwargs["kwargs"]
  183. # pyrefly: ignore [not-iterable, unsupported-operation]
  184. if key not in kwargs:
  185. return None
  186. # pyrefly: ignore [bad-index, unsupported-operation]
  187. original_input = kwargs[key]
  188. if isinstance(original_input, fx.Node):
  189. return original_input
  190. return None
  191. def _is_copy_node_bw_only(node: fx.Node) -> fx.Node | None:
  192. """Check if node is a view/reshape of a higher-order op output that aliases an input.
  193. Returns the original input node from the higher-order op's kwargs if the pattern
  194. matches, None otherwise.
  195. """
  196. if node.target not in (torch.ops.aten.view.default, torch.ops.aten.reshape.default):
  197. return None
  198. source = node.args[0]
  199. if not isinstance(source, fx.Node):
  200. return None
  201. return _get_ho_op_original_input(source)
  202. def _find_input_for_invalid_output(
  203. node: fx.Node,
  204. env: dict[fx.Node, Any],
  205. ) -> fx.Node | None:
  206. """Try to find a valid input replacement for an invalid forward output.
  207. This handles cases where a forward output depends on backward nodes but
  208. semantically aliases an input. For example, a view of a getitem from a
  209. triton kernel that mutates a buffer in backward, or a direct getitem from
  210. such a higher-order op. The original input may be a primal or a valid
  211. intermediate node already present in the forward graph.
  212. """
  213. # Pattern 1: view/reshape(getitem(ho_op, key)) -> ho_op.kwargs["kwargs"][key]
  214. original_input = _is_copy_node_bw_only(node)
  215. if (
  216. original_input is not None
  217. and original_input in env
  218. and not isinstance(env[original_input], InvalidNodeBase)
  219. ):
  220. return env[original_input]
  221. # Pattern 2: getitem(ho_op, key) -> ho_op.kwargs["kwargs"][key]
  222. original_input = _get_ho_op_original_input(node)
  223. if (
  224. original_input is not None
  225. and original_input in env
  226. and not isinstance(env[original_input], InvalidNodeBase)
  227. ):
  228. return env[original_input]
  229. return None
  230. def _extract_graph_with_inputs_outputs(
  231. joint_graph: fx.Graph,
  232. inputs: list[fx.Node],
  233. outputs: list[fx.Node],
  234. outputs_descs: list[AOTOutput],
  235. subgraph: str | None = None,
  236. ignore_must_be_in_fw_bw: bool = False,
  237. ) -> fx.Graph:
  238. """
  239. Given a graph, extracts out a subgraph that takes the specified nodes as
  240. inputs and returns the specified outputs.
  241. This includes specifying non-placeholder nodes as inputs.
  242. The general strategy is to initialize all inputs with proxies as we
  243. encounter them, and trace through the graph, only keeping values which take
  244. in valid proxies. Then, all dead code is eliminated.
  245. """
  246. new_graph = fx.Graph()
  247. env: dict[fx.Node, fx.Node] = {}
  248. # Add new placeholder nodes in the order specified by the inputs
  249. for node in inputs:
  250. new_node = new_graph.placeholder(node.name)
  251. # Can't use node_copy here as we may be turning previous call_function into placeholders
  252. new_node.meta = node.meta
  253. # pyrefly: ignore [unsupported-operation]
  254. env[node] = new_node
  255. for node in joint_graph.nodes:
  256. if not ignore_must_be_in_fw_bw:
  257. if (
  258. _must_be_in_backward(node)
  259. and subgraph != "backward"
  260. and node not in inputs
  261. ):
  262. env[node] = InvalidNode # type: ignore[assignment]
  263. continue
  264. if (
  265. _must_be_in_forward(node)
  266. and subgraph != "forward"
  267. and node not in inputs
  268. ):
  269. env[node] = InvalidNode # type: ignore[assignment]
  270. continue
  271. if node in env:
  272. # Node must be one of our inputs. (Any member of env which wasn't an
  273. # input to start must have been created by this loop and won't be in
  274. # joint_graph.nodes).
  275. continue
  276. elif node.op == "placeholder":
  277. env[node] = InvalidNode # type: ignore[assignment]
  278. elif node.op == "call_function":
  279. all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
  280. all_args = [
  281. isinstance(env[x], InvalidNodeBase)
  282. for x in all_args
  283. if isinstance(x, fx.Node)
  284. ]
  285. if any(all_args):
  286. env[node] = InvalidNode # type: ignore[assignment]
  287. continue
  288. # pyrefly: ignore [unsupported-operation, bad-argument-type]
  289. env[node] = new_graph.node_copy(node, lambda x: env[x])
  290. elif node.op == "get_attr":
  291. # pyrefly: ignore [unsupported-operation, bad-argument-type]
  292. env[node] = new_graph.node_copy(node, lambda x: env[x])
  293. elif node.op == "output":
  294. pass
  295. output_values = []
  296. for x, x_desc in zip(outputs, outputs_descs):
  297. if isinstance(x, fx.Node):
  298. if x not in env:
  299. raise RuntimeError(f"Node {x} couldn't be found in env")
  300. if isinstance(env[x], InvalidNodeBase):
  301. # For forward outputs that are invalid (depend on backward), try
  302. # to find a valid replacement.
  303. replacement = None
  304. # For copy_ nodes that are backward-only, use the destination
  305. # (first arg) which is the original input.
  306. if (
  307. x.target is torch.ops.aten.copy_.default
  308. and _must_be_in_backward(x)
  309. and len(x.args) >= 1
  310. and isinstance(x.args[0], fx.Node)
  311. and x.args[0] in env
  312. and not isinstance(env[x.args[0]], InvalidNodeBase)
  313. ):
  314. replacement = env[x.args[0]]
  315. # For view/reshape outputs that trace back to a getitem of a
  316. # higher-order op that mutates an input, find that input.
  317. # This handles custom_function_view outputs from triton kernels.
  318. if replacement is None:
  319. replacement = _find_input_for_invalid_output(x, env)
  320. if replacement is not None:
  321. output_values.append(replacement)
  322. continue
  323. raise AssertionError(f"Node {x} was invalid, but is output")
  324. output_values.append(env[x])
  325. else:
  326. output_values.append(x)
  327. out = new_graph.output(tuple(output_values))
  328. out.meta["desc"] = outputs_descs
  329. new_graph.eliminate_dead_code()
  330. new_graph.lint()
  331. return new_graph
  332. def is_non_builtin_to_include(node: fx.Node) -> bool:
  333. return config.is_non_builtin_to_include and (
  334. (isinstance(node.target, torch._ops.OpOverload) and not is_builtin(node.target))
  335. or node.target == torch.ops.higher_order.triton_kernel_wrapper_functional
  336. )
  337. def _is_backward_state(node: fx.Node) -> bool:
  338. return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)
  339. def _has_tag_is_backward(node: fx.Node) -> bool:
  340. return node.meta.get("partitioner_tag", None) == "is_backward"
  341. def _has_tag_is_forward(node: fx.Node) -> bool:
  342. return node.meta.get("partitioner_tag", None) == "is_forward"
  343. def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
  344. return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
  345. def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
  346. return node.meta.get("partitioner_tag", None) == "must_be_in_backward"
  347. def _must_be_in_forward(node: fx.Node) -> bool:
  348. if _has_tag_must_be_in_forward(node):
  349. return True
  350. is_mutable = (
  351. isinstance(node.target, torch._ops.OpOverload)
  352. and node.target._schema.is_mutable
  353. )
  354. return (
  355. not _has_tag_is_backward(node)
  356. and not _has_tag_must_be_in_backward(node)
  357. and is_mutable
  358. )
  359. def _must_be_in_backward(node: fx.Node) -> bool:
  360. if _has_tag_must_be_in_backward(node):
  361. return True
  362. is_mutable = (
  363. isinstance(node.target, torch._ops.OpOverload)
  364. and node.target._schema.is_mutable
  365. )
  366. return _has_tag_is_backward(node) and is_mutable
  367. def _extract_fwd_bwd_outputs(
  368. joint_module: fx.GraphModule, *, num_fwd_outputs: int
  369. ) -> tuple[list[fx.Node], list[fx.Node], list[AOTOutput], list[AOTOutput]]:
  370. outputs = pytree.arg_tree_leaves(
  371. *(node.args for node in joint_module.graph.find_nodes(op="output"))
  372. )
  373. outputs_descs = pytree.arg_tree_leaves(
  374. next(iter(joint_module.graph.find_nodes(op="output"))).meta.get(
  375. "desc", [None] * len(outputs)
  376. )
  377. )
  378. fwd_outputs = outputs[:num_fwd_outputs]
  379. bwd_outputs = outputs[num_fwd_outputs:]
  380. fwd_outputs_descs = outputs_descs[:num_fwd_outputs]
  381. bwd_outputs_descs = outputs_descs[num_fwd_outputs:]
  382. return fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs
  383. def _remove_by_name(saved_values: list[fx.Node], name: str) -> None:
  384. for saved_value in saved_values:
  385. if saved_value.name == name:
  386. saved_values.remove(saved_value)
  387. break
  388. def find_first_sym_node(
  389. fwd_module_outputs: list[fx.Node] | tuple[fx.Node, ...],
  390. ) -> int:
  391. idx = len(fwd_module_outputs)
  392. for i in range(len(fwd_module_outputs) - 1, -1, -1):
  393. if not is_sym_node(fwd_module_outputs[i]):
  394. idx = i + 1
  395. break
  396. return idx
  397. def calculate_quantization_scaling(
  398. graph: torch.fx.Graph,
  399. node: torch.fx.Node,
  400. max: float = 57344.0,
  401. min: float = 1e-12,
  402. position: int = 0,
  403. ) -> torch.fx.Node:
  404. with graph.inserting_after(node):
  405. abs_node = graph.call_function(
  406. torch.ops.aten.abs.default,
  407. args=(node,),
  408. )
  409. abs_node.meta["val"] = torch.ops.aten.abs.default(node.meta["val"])
  410. abs_node.meta["tensor_meta"] = extract_tensor_metadata(abs_node.meta["val"])
  411. with graph.inserting_after(abs_node):
  412. amax_node = graph.call_function(
  413. torch.ops.aten.amax.default,
  414. args=(abs_node, [-1], True),
  415. )
  416. amax_node.meta["val"] = torch.ops.aten.amax.default(
  417. abs_node.meta["val"], [-1], True
  418. )
  419. amax_node.meta["tensor_meta"] = extract_tensor_metadata(amax_node.meta["val"])
  420. with graph.inserting_after(amax_node):
  421. amax_64_node = graph.call_function(
  422. torch.ops.prims.convert_element_type.default,
  423. args=(amax_node, torch.float64),
  424. )
  425. amax_64_node.meta["val"] = torch.ops.prims.convert_element_type.default(
  426. amax_node.meta["val"], torch.float64
  427. )
  428. amax_64_node.meta["tensor_meta"] = extract_tensor_metadata(
  429. amax_64_node.meta["val"]
  430. )
  431. with graph.inserting_after(amax_64_node):
  432. clamp_min_node = graph.call_function(
  433. torch.ops.aten.clamp_min.default,
  434. args=(amax_64_node, min),
  435. )
  436. clamp_min_node.meta["val"] = torch.ops.aten.clamp_min.default(
  437. amax_64_node.meta["val"], min
  438. )
  439. clamp_min_node.meta["tensor_meta"] = extract_tensor_metadata(
  440. clamp_min_node.meta["val"]
  441. )
  442. with graph.inserting_after(clamp_min_node):
  443. reciprocal_node = graph.call_function(
  444. torch.ops.aten.reciprocal.default,
  445. args=(clamp_min_node,),
  446. )
  447. reciprocal_node.meta["val"] = torch.ops.aten.reciprocal.default(
  448. clamp_min_node.meta["val"]
  449. )
  450. reciprocal_node.meta["tensor_meta"] = extract_tensor_metadata(
  451. reciprocal_node.meta["val"]
  452. )
  453. with graph.inserting_after(reciprocal_node):
  454. mul_node = graph.call_function(
  455. torch.ops.aten.mul.Tensor,
  456. args=(reciprocal_node, max),
  457. )
  458. mul_node.meta["val"] = torch.ops.aten.mul.Tensor(
  459. reciprocal_node.meta["val"], max
  460. )
  461. mul_node.meta["tensor_meta"] = extract_tensor_metadata(mul_node.meta["val"])
  462. with graph.inserting_after(mul_node):
  463. scale_node = graph.call_function(
  464. torch.ops.prims.convert_element_type.default,
  465. args=(mul_node, torch.float32),
  466. name=f"fp8_scale_pos_{position}_{node.name}",
  467. )
  468. scale_node.meta["val"] = torch.ops.prims.convert_element_type.default(
  469. mul_node.meta["val"], torch.float32
  470. )
  471. scale_node.meta["tensor_meta"] = extract_tensor_metadata(scale_node.meta["val"])
  472. return scale_node
  473. def perform_quantization(
  474. graph: torch.fx.Graph,
  475. node: torch.fx.Node,
  476. scale_node: torch.fx.Node,
  477. quant_type: torch.dtype,
  478. clamp_min: float,
  479. clamp_max: float,
  480. position: int,
  481. ) -> torch.fx.Node:
  482. with graph.inserting_after(scale_node):
  483. target_node_32 = graph.call_function(
  484. torch.ops.prims.convert_element_type.default,
  485. args=(node, torch.float32),
  486. )
  487. target_node_32.meta["val"] = torch.ops.prims.convert_element_type.default(
  488. node.meta["val"], torch.float32
  489. )
  490. target_node_32.meta["tensor_meta"] = extract_tensor_metadata(
  491. target_node_32.meta["val"]
  492. )
  493. with graph.inserting_after(target_node_32):
  494. scaled_target_node = graph.call_function(
  495. torch.ops.aten.mul.Tensor,
  496. args=(target_node_32, scale_node),
  497. )
  498. scaled_target_node.meta["val"] = torch.ops.aten.mul.Tensor(
  499. target_node_32.meta["val"], scale_node.meta["val"]
  500. )
  501. scaled_target_node.meta["tensor_meta"] = extract_tensor_metadata(
  502. scaled_target_node.meta["val"]
  503. )
  504. with graph.inserting_after(scaled_target_node):
  505. clamp_min_scaled_node = graph.call_function(
  506. torch.ops.aten.clamp_min.default,
  507. args=(scaled_target_node, clamp_min),
  508. )
  509. clamp_min_scaled_node.meta["val"] = torch.ops.aten.clamp_min.default(
  510. scaled_target_node.meta["val"], clamp_min
  511. )
  512. clamp_min_scaled_node.meta["tensor_meta"] = extract_tensor_metadata(
  513. clamp_min_scaled_node.meta["val"]
  514. )
  515. with graph.inserting_after(clamp_min_scaled_node):
  516. clamp_max_scaled_node = graph.call_function(
  517. torch.ops.aten.clamp_max.default,
  518. args=(clamp_min_scaled_node, clamp_max),
  519. )
  520. clamp_max_scaled_node.meta["val"] = torch.ops.aten.clamp_max.default(
  521. clamp_min_scaled_node.meta["val"], clamp_max
  522. )
  523. clamp_max_scaled_node.meta["tensor_meta"] = extract_tensor_metadata(
  524. clamp_max_scaled_node.meta["val"]
  525. )
  526. with graph.inserting_after(clamp_max_scaled_node):
  527. quant_activation_node = graph.call_function(
  528. torch.ops.prims.convert_element_type.default,
  529. args=(clamp_max_scaled_node, quant_type),
  530. name=f"fp8_quant_pos_{position}_{node.name}",
  531. )
  532. quant_activation_node.meta["val"] = (
  533. torch.ops.prims.convert_element_type.default(
  534. clamp_max_scaled_node.meta["val"], quant_type
  535. )
  536. )
  537. quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata(
  538. quant_activation_node.meta["val"]
  539. )
  540. return quant_activation_node
  541. def calculate_tensor_size(tensor: torch.Tensor) -> float:
  542. """
  543. Calculate the size of a PyTorch tensor in megabytes (MB).
  544. Args:
  545. tensor (torch.Tensor): Input tensor
  546. Returns:
  547. float: Memory size in MB
  548. """
  549. # Get number of elements and size per element
  550. num_elements = tensor.numel()
  551. element_size = tensor.element_size()
  552. return (num_elements * element_size) / (1024 * 1024)
  553. def get_allowed_dtypes() -> list[torch.dtype]:
  554. allowed_dtypes = torch._inductor.config.post_grad_fusion_options[
  555. "activation_quantization_aten_pass"
  556. ].get("allowed_dtypes", "torch.bfloat16")
  557. allowed_dtypes = [
  558. getattr(torch, dtype.split(".")[-1]) for dtype in allowed_dtypes.split(";")
  559. ]
  560. return allowed_dtypes
  561. def should_quantize(node: torch.fx.Node) -> bool:
  562. allowed_dtypes = get_allowed_dtypes()
  563. if not is_node_meta_valid(node) or node.meta["val"].dtype not in allowed_dtypes:
  564. return False
  565. size_threshold = torch._inductor.config.post_grad_fusion_options[
  566. "activation_quantization_aten_pass"
  567. ].get("size_in_mb", 100)
  568. # calculate the size of the node
  569. size_in_mb = calculate_tensor_size(node.meta["val"])
  570. if not torch._inductor.config.post_grad_fusion_options[
  571. "activation_quantization_aten_pass"
  572. ].get("skip_dynamo_guards", False):
  573. return size_in_mb >= size_threshold
  574. else:
  575. # case 1: we always quantize tensors with dynamic shapes
  576. if torch._inductor.config.post_grad_fusion_options[
  577. "activation_quantization_aten_pass"
  578. ].get("quantize_dynamic_shape", False):
  579. return statically_known_true(
  580. size_in_mb >= size_threshold
  581. ) or not statically_known_false(size_in_mb >= size_threshold)
  582. else:
  583. # case 2: we always not quantize tensors with dynamic shapes
  584. return statically_known_true(size_in_mb >= size_threshold)
  585. def get_quant_type() -> torch.dtype:
  586. quant_type = torch._inductor.config.post_grad_fusion_options[
  587. "activation_quantization_aten_pass"
  588. ].get("quant_type", "torch.float8_e5m2")
  589. return getattr(torch, quant_type.split(".")[-1])
  590. def calculate_range(dtype: torch.dtype) -> tuple[float, float]:
  591. """
  592. Calculate the range of values for a given torch.dtype.
  593. Args:
  594. dtype (torch.dtype): The input dtype.
  595. Returns:
  596. tuple: A tuple containing the minimum and maximum values.
  597. """
  598. info = torch.finfo(dtype)
  599. return info.min, info.max
  600. def quantize_activation_fw(graph: torch.fx.Graph) -> None:
  601. output = graph.find_nodes(op="output")[0]
  602. fwd_outputs = output.args[0]
  603. quant_type = get_quant_type()
  604. clamp_min, clamp_max = calculate_range(quant_type)
  605. position_to_quant = dict()
  606. tensor_scale_nodes: list[fx.Node] = []
  607. sym_scale_nodes: list[fx.Node] = []
  608. for position, node in enumerate(fwd_outputs):
  609. # check if the activation node is the node saved for quantization
  610. if node.meta.get("saved_for_quantization", False):
  611. # case: use scaling
  612. if torch._inductor.config.post_grad_fusion_options[
  613. "activation_quantization_aten_pass"
  614. ].get("use_scaling", True):
  615. # calculating the scale
  616. scale_node = calculate_quantization_scaling(
  617. graph, node, clamp_max, 1e-12, position
  618. )
  619. # converting to fp8
  620. quant_node = perform_quantization(
  621. graph, node, scale_node, quant_type, clamp_min, clamp_max, position
  622. )
  623. if not is_sym_node(scale_node):
  624. tensor_scale_nodes.append(scale_node)
  625. else:
  626. sym_scale_nodes.append(scale_node)
  627. else:
  628. # case: do not use scaling
  629. with graph.inserting_after(node):
  630. quant_node = graph.call_function(
  631. torch.ops.prims.convert_element_type.default,
  632. args=(node, quant_type),
  633. name=f"fp8_quant_pos_{position}_{node.name}",
  634. )
  635. quant_node.meta["val"] = (
  636. torch.ops.prims.convert_element_type.default(
  637. node.meta["val"], quant_type
  638. )
  639. )
  640. quant_node.meta["tensor_meta"] = extract_tensor_metadata(
  641. quant_node.meta["val"]
  642. )
  643. position_to_quant[position] = quant_node
  644. # Use position-based lookup for building output
  645. # only update the return node args, and remain all other users unchanged
  646. output_updated_args = [
  647. position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
  648. ]
  649. # add the scale nodes to the output find the first sym_node in the output
  650. # pyrefly: ignore [bad-argument-type]
  651. idx = find_first_sym_node(output_updated_args)
  652. scale_nodes = tensor_scale_nodes + sym_scale_nodes
  653. if scale_nodes:
  654. output_updated_args = (
  655. output_updated_args[:idx] + scale_nodes + output_updated_args[idx:]
  656. )
  657. output.update_arg(0, tuple(output_updated_args))
  658. counters["inductor"]["activation_quantization_fwd_aten_pass"] += 1
  659. def quantize_activation_bw(graph: torch.fx.Graph) -> None:
  660. bw_inputs = [node for node in graph.nodes if node.op == "placeholder"]
  661. activation_node = None
  662. for node in bw_inputs:
  663. if node.meta.get("saved_for_quantization", False):
  664. node.meta.pop("saved_for_quantization")
  665. dequant_type = node.meta.pop("dequant_type")
  666. # dequantize the node
  667. if torch._inductor.config.post_grad_fusion_options[
  668. "activation_quantization_aten_pass"
  669. ].get("use_scaling", False):
  670. # case: use scaling
  671. with graph.inserting_after(node):
  672. # find corresponding scale node
  673. scale_name = "fp8_scale_" + node.name.replace("fp8_quant_", "")
  674. scale_node = next(
  675. bwd_input
  676. for bwd_input in bw_inputs
  677. if bwd_input.name == scale_name
  678. )
  679. with graph.inserting_after(scale_node):
  680. activation_node = graph.call_function(
  681. torch.ops.prims.convert_element_type.default,
  682. args=(node, dequant_type),
  683. )
  684. activation_node.meta["val"] = (
  685. torch.ops.prims.convert_element_type.default(
  686. node.meta["val"], dequant_type
  687. )
  688. )
  689. activation_node.meta["tensor_meta"] = extract_tensor_metadata(
  690. activation_node.meta["val"]
  691. )
  692. with graph.inserting_after(activation_node):
  693. divided_target_node_32 = graph.call_function(
  694. torch.ops.aten.div.Tensor,
  695. args=(activation_node, scale_node),
  696. )
  697. divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor(
  698. activation_node.meta["val"], scale_node.meta["val"]
  699. )
  700. divided_target_node_32.meta["tensor_meta"] = (
  701. extract_tensor_metadata(divided_target_node_32.meta["val"])
  702. )
  703. with graph.inserting_after(divided_target_node_32):
  704. dequant_node = graph.call_function(
  705. torch.ops.prims.convert_element_type.default,
  706. args=(divided_target_node_32, dequant_type),
  707. )
  708. dequant_node.meta["val"] = (
  709. torch.ops.prims.convert_element_type.default(
  710. divided_target_node_32.meta["val"], dequant_type
  711. )
  712. )
  713. dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
  714. dequant_node.meta["val"]
  715. )
  716. else:
  717. with graph.inserting_after(node):
  718. dequant_node = graph.call_function(
  719. torch.ops.prims.convert_element_type.default,
  720. args=(node, dequant_type),
  721. name="dequant_" + str(node.name),
  722. )
  723. dequant_node.meta["val"] = (
  724. torch.ops.prims.convert_element_type.default(
  725. node.meta["val"], dequant_type
  726. )
  727. )
  728. dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
  729. dequant_node.meta["val"]
  730. )
  731. # find the users of the node and replace them with the new node except the dequant_node
  732. for user in list(node.users.keys()):
  733. if user != dequant_node and user != activation_node:
  734. user.replace_input_with(node, dequant_node)
  735. counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1
def perform_fp8_activation_quantization(
    fwd_module: fx.GraphModule,
    bwd_module: fx.GraphModule,
    bwd_module_inputs: dict[str, fx.Node],
) -> None:
    """
    Apply activation quantization to a forward/backward module pair in place.

    Steps:
      1. ``quantize_activation_fw`` rewrites the forward outputs tagged
         ``saved_for_quantization``.
      2. For each forward output whose name contains ``fp8_quant_``, the
         backward placeholder named after the original node (the name with
         the ``fp8_quant_pos_<i>_`` prefix removed, looked up in
         ``bwd_module_inputs``) is replaced by a new placeholder. The new
         placeholder has the quantized node's name and meta, is tagged
         ``saved_for_quantization``, and carries the old ``dequant_type``.
      3. If ``use_scaling`` is enabled (default ``True`` here), a placeholder
         is added for every forward output whose name contains
         ``fp8_scale_``. These placeholders keep forward order and go after
         the last non-tangent backward placeholder, or after the last
         placeholder if all are tangents.
      4. ``quantize_activation_bw`` inserts the dequantization.

    The module before and after each pass is logged via ``trace_structured``.

    Args:
        fwd_module: forward module, modified in place.
        bwd_module: backward module, modified in place.
        bwd_module_inputs: backward placeholders keyed by name, as they were
            before this call. Replaced placeholders are erased from the
            graph but are not removed from this dict.
    """
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "before_activation_quantization_fwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: fwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )
    quantize_activation_fw(fwd_module.graph)
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "after_activation_quantization_fwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: fwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "before_activation_quantization_bwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: bwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )
    quant_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
    # update the corresponding bwd_inputs due to the fwd_outputs quantization
    for fwd_node in quant_fwd_module_outputs:
        if "fp8_quant_" in fwd_node.name:
            # Strip the "fp8_quant_pos_<i>_" prefix to recover the name of the
            # original (unquantized) saved value, i.e. the bwd placeholder name.
            bwd_input = bwd_module_inputs[
                re.sub(r"^fp8_quant_pos_\d+_", "", fwd_node.name)
            ]
            with bwd_module.graph.inserting_after(bwd_input):
                quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
            dequant_type = bwd_input.meta["dequant_type"]
            quant_bwd_input.meta.update(fwd_node.meta)
            # Re-tag so quantize_activation_bw knows to dequantize this input.
            quant_bwd_input.meta["saved_for_quantization"] = True
            quant_bwd_input.meta["dequant_type"] = dequant_type
            bwd_input.replace_all_uses_with(quant_bwd_input)
            bwd_module.graph.erase_node(bwd_input)
    # update the bwd_inputs if quantization with scaling is used
    if torch._inductor.config.post_grad_fusion_options[
        "activation_quantization_aten_pass"
    ].get("use_scaling", True):
        quant_bwd_module_inputs = list(bwd_module.graph.find_nodes(op="placeholder"))
        # update the corresponding bwd input nodes find the last non-tangent node
        bwd_input_loc = quant_bwd_module_inputs[-1]
        for bw_input in reversed(quant_bwd_module_inputs):
            if not _is_tangent(bw_input):
                bwd_input_loc = bw_input
                break
        scaled_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
        for fwd_node in scaled_fwd_module_outputs:
            if "fp8_scale_" in fwd_node.name:
                # fwd node is a scale node
                with bwd_module.graph.inserting_after(bwd_input_loc):
                    scale_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
                scale_bwd_input.meta.update(fwd_node.meta)
                # Advance the anchor so scale placeholders keep forward order.
                bwd_input_loc = scale_bwd_input
    quantize_activation_bw(bwd_module.graph)
    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "after_activation_quantization_bwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: bwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )
  817. def enable_activation_quantization(
  818. saved_values: list[fx.Node],
  819. fwd_module: fx.GraphModule,
  820. bwd_module: fx.GraphModule,
  821. static_lifetime_input_nodes: OrderedSet[fx.Node] | None = None,
  822. ) -> None:
  823. if (
  824. inductor_config.post_grad_fusion_options.get(
  825. "activation_quantization_aten_pass", None
  826. )
  827. is None
  828. ):
  829. return
  830. static_input_names: list[str] = (
  831. [node.name for node in static_lifetime_input_nodes]
  832. if static_lifetime_input_nodes
  833. else []
  834. )
  835. saved_values_names = {node.name: node for node in saved_values}
  836. if torch._inductor.config.post_grad_fusion_options[
  837. "activation_quantization_aten_pass"
  838. ].get("exclude_primals", False):
  839. saved_values_names = {
  840. node.name: node for node in saved_values if "primals" not in node.name
  841. }
  842. fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
  843. bwd_module_inputs = {
  844. node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
  845. }
  846. should_perform_fp8_quant = False
  847. for node in fwd_module_outputs:
  848. if node.name in saved_values_names and should_quantize(node):
  849. if node.name in static_input_names:
  850. log.debug("Skipping quantization of static input %s: ", node.name)
  851. continue
  852. node.meta["saved_for_quantization"] = True
  853. node.meta["dequant_type"] = node.meta["val"].dtype
  854. # some of the fwd outputs and bwd inputs are not share the same object
  855. bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
  856. bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype
  857. should_perform_fp8_quant = True
  858. if should_perform_fp8_quant:
  859. perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs)
def _extract_fwd_bwd_modules(
    joint_module: fx.GraphModule,
    saved_values: list[fx.Node],
    saved_sym_nodes: list[fx.Node],
    *,
    num_fwd_outputs: int,
    static_lifetime_input_nodes: OrderedSet[fx.Node] | None = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """
    Build the forward and backward graph modules from ``joint_module``.

    ``saved_values`` and ``saved_sym_nodes`` are the candidate values to save
    for backward. Both lists are MUTATED in place, and callers observe the
    final contents:

    * entries that the backward does not use, or uses only through
      ``wait_tensor`` calls that have no users, are removed from both lists;
      backward-state placeholders are removed from ``saved_values``;
    * ``saved_sym_nodes`` is reordered so that symbol-binding nodes come
      first, and extra binding nodes are added for symbols referenced by
      backward inputs;
    * ``saved_values`` keeps only non-opaque entries, with those flagged
      ``saved_tensor_with_no_vc_check`` moved to the end. Opaque
      ``FakeScriptObject`` values are passed on separately.

    Args:
        joint_module: joint forward+backward graph module.
        saved_values: candidate saved values (see above).
        saved_sym_nodes: candidate saved symbolic nodes (see above).
        num_fwd_outputs: number of user-visible forward outputs.
        static_lifetime_input_nodes: forwarded to
            ``enable_activation_quantization``.

    Returns:
        ``(fwd_module, bwd_module)``. The forward outputs are the original
        forward outputs followed by saved tensors, opaque objects and sym
        nodes. The backward inputs are sym nodes, saved tensors, opaque
        objects, tangents, backward seed/offset inputs and backward-state
        inputs. ``enable_activation_quantization`` is applied before
        returning.
    """
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    placeholders = joint_module.graph.find_nodes(op="placeholder")
    primal_inputs = [*filter(_is_primal, placeholders)]
    tangent_inputs = [*filter(_is_tangent, placeholders)]
    fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
    bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
    backward_state_inputs = [*filter(_is_backward_state, placeholders)]
    # Provisional backward graph, used only to find which candidate saved
    # values the backward really consumes; it is rebuilt further below.
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
        bwd_outputs,
        bwd_outputs_descs,
        "backward",
    )
    distributed_enabled = torch.distributed.is_available()
    for node in bwd_graph.find_nodes(op="placeholder"):
        # This is to filter out saved values that don't actually end up being used by the backwards pass
        if not node.users:
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)
        # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
        # but this dead activation is actually a collective,
        # then the collective will generally by followed by a wait_tensor() call.
        # we need to peak one node further to see if this wait_tensor is dead as well.
        elif distributed_enabled and all(
            n.target is torch.ops._c10d_functional.wait_tensor.default
            and len(n.users) == 0
            for n in node.users
        ):
            _remove_by_name(saved_values, node.name)
            _remove_by_name(saved_sym_nodes, node.name)
        elif _is_backward_state(node):
            # BackwardState is saved directly
            _remove_by_name(saved_values, node.name)
            if not backward_state_inputs:
                raise AssertionError("backward_state_inputs must not be empty")
    # Now that we have the finalized list of saved values, we need to ensure
    # we propagate all symbols which are referenced by backwards inputs.
    # These are not directly used in the graph but are required for downstream
    # sizevar assignment
    saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
    saved_sym_nodes_binding = []
    saved_sym_nodes_derived = []
    # Some symbols may already be bound in the directly saved_sym_nodes,
    # keep track of them so we don't re-bind them
    for node in saved_sym_nodes:
        symbol = is_symbol_binding_fx_node(node)
        if symbol:
            saved_symbols.add(symbol)
            saved_sym_nodes_binding.append(node)
        else:
            saved_sym_nodes_derived.append(node)
    # Now go through all of the prospective backward inputs and track any
    # other symbols we need to bind
    symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
    for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs):
        if "val" not in node.meta:
            continue
        new_symbols = free_symbols(node.meta["val"]) - saved_symbols
        # NB: Deterministic order please!
        for s in sorted(new_symbols, key=lambda s: s.name):
            # NB: For well formed graphs, the symbol should always be present,
            # but we also have ways to produce ill-formed graphs, e.g., direct
            # make_fx usages, so don't choke in this case
            if s not in symbol_bindings:
                continue
            saved_sym_nodes_binding.append(symbol_bindings[s])
        saved_symbols |= new_symbols
    # Update saved_sym_nodes that are now reordered to have all bindings at
    # front. This can also be used later on to figure out the position of saved
    # sym nodes in the output of fwd graph.
    saved_sym_nodes.clear()
    saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
    # See Note [Activations with no version counter checks in eager]
    # Sort saved_values so that tensors with saved_tensor_with_no_vc_check=True
    # are at the end. This allows us to have two consecutive slices:
    # 1. tensors_saved_with_vc_check_slice - tensors saved via save_for_backward
    # 2. tensors_saved_with_no_vc_check_slice - tensors stashed on ctx without save_for_backward
    # The sort is stable, so the relative order within each group is preserved.
    #
    # Additionally, separate out opaque objects (FakeScriptObject) from tensors.
    # Opaque objects should be placed after tensors in the forward outputs.
    saved_values_with_vc_check = []
    saved_values_no_vc_check = []
    saved_opaque_objects = []
    for node in saved_values:
        # Check if this is an opaque object
        if isinstance(node.meta.get("val"), FakeScriptObject):
            saved_opaque_objects.append(node)
        elif node.meta.get("saved_tensor_with_no_vc_check", False):
            saved_values_no_vc_check.append(node)
        else:
            saved_values_with_vc_check.append(node)
    # In-place update: callers holding this list see the reordered,
    # opaque-free contents.
    saved_values.clear()
    saved_values.extend(saved_values_with_vc_check + saved_values_no_vc_check)
    no_vc_check_start_idx = len(saved_values_with_vc_check)
    # debug assert: given saved_values where the last k of them are expected to not
    # require VC checks, they should all have node metadata indicating so.
    for i, node in enumerate(saved_values):
        if i >= no_vc_check_start_idx:
            if not node.meta.get("saved_tensor_with_no_vc_check", False):
                raise AssertionError(
                    f"i={i}, no_vc_check_start_idx={no_vc_check_start_idx}, len(saved_values)={len(saved_values)}"
                )
    # Now, we re-generate the fwd/bwd graphs.
    # NB: This might increase compilation time, but I doubt it matters
    # Convention for saved acts is (tensors_with_vc_check, tensors_no_vc_check, opaque_objects, symints)
    fwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        primal_inputs + fwd_seed_offset_inputs,
        fwd_outputs + saved_values + saved_opaque_objects + saved_sym_nodes,
        fwd_outputs_descs
        + [
            # Only tensor slots at or past no_vc_check_start_idx skip the
            # version-counter check; opaque objects and sym nodes do not.
            SavedForBackwardsNoVcCheckAOTOutput(i)
            if i >= no_vc_check_start_idx and i < len(saved_values)
            else SavedForBackwardsAOTOutput(i)
            for i in range(
                len(saved_values) + len(saved_opaque_objects) + len(saved_sym_nodes)
            )
        ],
        "forward",
    )
    bwd_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph,
        saved_sym_nodes
        + saved_values
        + saved_opaque_objects
        + tangent_inputs
        + bwd_seed_offset_inputs
        + backward_state_inputs,
        bwd_outputs,
        bwd_outputs_descs,
        "backward",
    )
    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
    # No-op unless the activation_quantization_aten_pass option is configured.
    enable_activation_quantization(
        saved_values, fwd_module, bwd_module, static_lifetime_input_nodes
    )
    return fwd_module, bwd_module
def default_partition(
    joint_module: fx.GraphModule,
    _joint_inputs: Any,
    *,
    num_fwd_outputs: int,
    static_lifetime_input_indices: list[int] | None = None,
    static_lifetime_input_nodes: OrderedSet[fx.Node] | None = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the :attr:`joint_module` in a manner that closely resembles the
    behavior observed in the original ``.forward()`` and ``.backward()`` of the
    callable, i.e., the resulting forward graph contains those operators that
    are executed in the original ``.forward()`` callable passed to
    :func:`aot_function`.
    The default partitioner collects the operators that are between the forward
    inputs and the forward outputs. This helps in finding the tensors which have
    to be stashed for the backward pass. These stashed tensors become the output
    of the generated forward graph. The remaining operators are then placed in
    the backward graph.
    If the graph carries recompute (activation checkpointing) tags but is not
    functional, this warns and delegates to
    :func:`min_cut_rematerialization_partition` instead.
    .. warning::
        This API is experimental and likely to change.
    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing.
        _joint_inputs: Only forwarded to the min-cut fallback.
        num_fwd_outputs(int): Number of user-visible forward outputs.
        static_lifetime_input_indices(list[int] | None): Indices of
            static-lifetime inputs; ``None`` means none.
        static_lifetime_input_nodes(OrderedSet[fx.Node] | None): Static
            inputs; when ``None``, taken from ``classify_nodes``.
    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    # Respect the original placement of ops rather than rely on dataflow.
    forward_nodes = []
    last_node = None
    # Locate the last node in graph order that is forward-tagged, a primal,
    # or a forward seed/offset input.
    for node in joint_module.graph.nodes:
        if _has_tag_is_forward(node) or _is_primal(node) or _is_fwd_seed_offset(node):
            last_node = node
    if last_node is None:
        raise AssertionError("last_node must not be None")
    # Every non-tangent node up to and including last_node is forward.
    for node in joint_module.graph.nodes:
        if not _is_tangent(node):
            forward_nodes.append(node)
        if node is last_node:
            break
    forward_node_names = OrderedSet(
        node.name for node in forward_nodes if node.op != "output"
    )
    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        if _is_functional_graph(joint_module.graph)[0] is not None:
            # Fall-back to previous behavior to avoid bc-breaking, although can
            # eventually flip the switch to make this a hard error.
            warnings.warn(
                "Trying to unsafely apply AC to a non-functional graph with the "
                "default partitioner. Falling back to min-cut partitioner."
            )
            return min_cut_rematerialization_partition(
                joint_module,
                _joint_inputs,
                num_fwd_outputs=num_fwd_outputs,
                static_lifetime_input_indices=static_lifetime_input_indices,
            )
        joint_module = cleanup_recompute_tags(joint_module, is_default_partition=True)
    if not config.unsafe_allow_optimization_of_collectives:
        force_save_collectives(joint_module)
    force_save_effectful_ops(joint_module)
    force_save_bw_mutation_src(joint_module)
    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(
        joint_module, static_lifetime_input_indices, num_fwd_outputs
    )
    saved_values = []
    saved_sym_nodes = []
    distributed_enabled = torch.distributed.is_available()

    def is_tensor(node: fx.Node) -> bool:
        return "tensor_meta" in node.meta or isinstance(
            node.meta.get("val"), torch._subclasses.FakeTensor
        )

    def is_multi_output(node: fx.Node) -> bool:
        # True when the node is consumed only through getitem (tuple outputs).
        return (
            all(user.target == operator.getitem for user in node.users)
            and len(node.users) > 0
        )

    def is_impure(node: fx.Node) -> bool:
        # wait tensor is an "impure" op according to DCE's definition of impure
        # (see is_impure in torch/fx/node.py), but it survives past
        # functionalization and can be safely dup'd and reordered under the
        # assumption SPMD.
        return (
            node.is_impure(impure_random=False)
            and node.op
            not in (
                "placeholder",
                "output",
            )
            and (
                not distributed_enabled
                or node.target is not torch.ops._c10d_functional.wait_tensor.default
            )
        )

    # Classify every forward node: saved tensor (saved_values), saved
    # symbolic value (saved_sym_nodes), or not saved.
    for node in joint_module.graph.nodes:
        if node.name not in forward_node_names:
            continue
        # get_attr nodes that name a submodule of joint_module are skipped.
        if node.op == "get_attr" and node.name in (
            k for k, v in joint_module.named_modules()
        ):
            continue
        if node.target in (
            torch.ops.aten._assert_scalar.default,
            # Profiler record_function ops are technically impure (they set up
            # profiling spans), but they're safe to duplicate during AC recompute.
            # We skip both enter and exit to keep profiling spans balanced.
            torch.ops.profiler._record_function_enter_new.default,
            torch.ops.profiler._record_function_enter.default,
            torch.ops.profiler._record_function_exit.default,
            torch.ops.profiler._record_function_exit._RecordFunction,
        ):
            continue
        if is_sym_node(node):
            # Symints must be kept separate from tensors so that PythonFunction only calls
            # save_for_backward on tensors and stashes symints in autograd .ctx
            saved_sym_nodes.append(node)
            continue
        if is_multi_output(node):
            # Must be ordered before MUST_SAVE tags to avoid saving tuples marked MUST_SAVE.
            continue
        if node.meta.get("recompute") == CheckpointPolicy.MUST_SAVE:
            saved_values.append(node)
            continue
        if is_impure(node):
            if graph_has_recomputable_ops:
                raise AssertionError(
                    f"Trying to apply AC on a graph with impure op: {node}, {node.target}"
                )
            saved_values.append(node)
            continue
        if not is_tensor(node) and node.op == "call_function":
            raise AssertionError(f"Expected {node} to be a tensor")
        backward_usages = [n for n in node.users if n.name not in forward_node_names]
        if all(is_sym_node(n) for n in backward_usages):
            # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
            # and not the actual tensor data,
            # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
            #
            # Note that saving the tensor could also cause compilation problems:
            # If the user mutated an input in the forward and uses its sizes/strides in the backward,
            # then we would be obligated to clone the input before saving it to appease autograd.
            # (This is how we originally found this bug).
            saved_sym_nodes.extend(backward_usages)
            continue
        if not must_recompute(node):
            saved_values.append(node)
    # De-duplicate while preserving first-seen order.
    saved_values = list(dict.fromkeys(saved_values).keys())
    saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
    if config._sync_decision_cross_ranks:
        saved_values = _sync_decision_cross_ranks(joint_module.graph, saved_values)
    if static_lifetime_input_nodes is None:
        static_lifetime_input_nodes = node_info.static_lifetime_input_nodes
    # NB: _extract_fwd_bwd_modules mutates saved_values / saved_sym_nodes in
    # place; len(saved_sym_nodes) below reflects the final list.
    fw_module, bw_module = _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
        static_lifetime_input_nodes=static_lifetime_input_nodes,
    )
    # Run DCE while overriding the definition of is_impure_node
    fw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective)
    bw_module.graph.eliminate_dead_code(is_impure_node=is_not_collective)
    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
            fw_module, bw_module = functionalize_rng_ops(
                joint_module, fw_module, bw_module, len(saved_sym_nodes)
            )
        bw_module = reordering_to_mimic_autograd_engine(bw_module)
    # pyrefly: ignore [unbound-name]
    if config.enable_activation_offloading:
        from ._activation_offloading.activation_offloading import (
            enable_activation_offloading,
        )

        enable_activation_offloading(
            fw_module,
            bw_module,
            num_fwd_outputs,
            static_lifetime_input_nodes,
        )
    # raise all getitem ops to as early as possible
    # this is helpful for memory, especially in the case of aot_eager backend
    fw_module = raise_getitems(fw_module)
    bw_module = raise_getitems(bw_module)
    fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
    if len(node_info.required_bw_nodes) > 0:
        bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
    return fw_module, bw_module
  1201. INT_INF = int(1e6)
  1202. def _tensor_nbytes(numel: int, dtype: torch.dtype) -> int:
  1203. return numel * dtype.itemsize
  1204. def _size_of(node: fx.Node) -> int:
  1205. def object_nbytes(x: object) -> int:
  1206. if not isinstance(x, torch.Tensor):
  1207. return 0
  1208. return _tensor_nbytes(size_hint(x.numel(), fallback=4096), x.dtype)
  1209. if "val" in node.meta:
  1210. val = node.meta["val"]
  1211. if isinstance(val, py_sym_types):
  1212. return 1
  1213. # NB: The fallback values here are meaningless, maybe we should respect
  1214. # torch._inductor.config.unbacked_symint_fallback (but this is a
  1215. # layering violation)
  1216. elif isinstance(val, (list, tuple)):
  1217. return sum(object_nbytes(n) for n in val)
  1218. elif isinstance(val, dict):
  1219. return sum(object_nbytes(n) for _, n in val.items())
  1220. elif isinstance(val, torch.Tensor):
  1221. return object_nbytes(val)
  1222. raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}")
  1223. if node.op == "get_attr" or node.target is torch.ops.aten._assert_scalar.default:
  1224. return 0
  1225. raise RuntimeError(
  1226. f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes."
  1227. )
  1228. # Used for some investigative purposes
  1229. def _count_ops(graph: fx.Graph) -> None:
  1230. from collections import defaultdict
  1231. cnt: dict[str, int] = defaultdict(int)
  1232. for node in graph.nodes:
  1233. if node.op == "call_function":
  1234. cnt[node.target.__name__] += 1
  1235. log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True))
  1236. @functools.cache
  1237. def pointwise_ops() -> list[torch._ops.OpOverloadPacket]:
  1238. ops: list[torch._ops.OpOverloadPacket] = []
  1239. for attr_name in dir(torch.ops.aten):
  1240. opoverloadpacket = getattr(torch.ops.aten, attr_name)
  1241. if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
  1242. continue
  1243. for overload in opoverloadpacket.overloads():
  1244. op_overload = getattr(opoverloadpacket, overload)
  1245. if torch.Tag.pointwise in op_overload.tags:
  1246. # currently aot autograd uses packet not overload
  1247. ops.append(opoverloadpacket)
  1248. break
  1249. return ops
  1250. def sort_depths(
  1251. args: tuple[Any, ...], depth_map: dict[fx.Node, int]
  1252. ) -> list[tuple[fx.Node, int]]:
  1253. arg_depths = {
  1254. arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
  1255. }
  1256. return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True)
  1257. def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
  1258. """
  1259. This pass finds the first bwd node in the graph (by looking at users of
  1260. tangents) and then reorders the graph by walking from this node to all the
  1261. way to the end of the graph. At each op in this traversal, we insert this op
  1262. in a new graph and try to bring only the relevant subgraph from the other
  1263. non-bwd edges relevant for this op. This closely mimics the behavior of
  1264. autograd engine.
  1265. Why is this pass required in the first place?
  1266. This is an artifact of how partitioners work today. The starting point of
  1267. partitioner is a joint graph, which is fwd and then bwd graph. In the case
  1268. of checkpointing, we keep portions of fwd graph in their original place in
  1269. the joint graph, while obtaining a bwd graph. As a result, the resulting bwd
  1270. graph has copies of recomputed fwd subgraphs followed by the original bwd
  1271. graph. If we run this naively, this leads to bad memory footprint, because
  1272. the fwd subgraphs are live for way longer duration than necessary. This pass
  1273. reorders the operations such that we prioritize the ops for the original bwd
  1274. graph while only realizing those ops from the fwd graph that are necessary
  1275. at any given point in the graph.
  1276. """
  1277. new_graph = fx.Graph()
  1278. env: dict[fx.Node, fx.Node] = {}
  1279. # Add new placeholder nodes in the order specified by the inputs
  1280. for node in gm.graph.find_nodes(op="placeholder"):
  1281. env[node] = new_graph.node_copy(node, lambda x: env[x])
  1282. order = {node: idx for idx, node in enumerate(gm.graph.nodes)}
  1283. def insert_node_in_graph(node: fx.Node) -> None:
  1284. cur_nodes = [node]
  1285. insertable_nodes: OrderedSet[fx.Node] = OrderedSet()
  1286. while len(cur_nodes) > 0:
  1287. node = cur_nodes.pop()
  1288. if node in insertable_nodes or node in env:
  1289. continue
  1290. insertable_nodes.add(node)
  1291. # Bias traversal towards the nodes that have higher depth - prioritizes
  1292. # critical path first.
  1293. cur_nodes += node.all_input_nodes
  1294. # pyrefly: ignore [bad-assignment]
  1295. insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
  1296. for node in insertable_nodes:
  1297. env[node] = new_graph.node_copy(node, lambda x: env[x])
  1298. # Find first bwd node in the graph
  1299. tangent_inputs = list(filter(_is_tangent, gm.graph.nodes))
  1300. first_node_in_bwd = None
  1301. minimum_order = math.inf
  1302. for tangent in tangent_inputs:
  1303. for user in tangent.users:
  1304. if order[user] < minimum_order:
  1305. minimum_order = order[user]
  1306. first_node_in_bwd = user
  1307. # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass"
  1308. if first_node_in_bwd is None:
  1309. return gm
  1310. # Build the graph op-by-op by starting from the node all the way to the end
  1311. # copy_ can be not using tangents at all, we must copy it.
  1312. for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
  1313. if node.op == "call_function" and node.target is torch.ops.aten.copy_.default:
  1314. insert_node_in_graph(node)
  1315. for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
  1316. insert_node_in_graph(node)
  1317. # The output node is already built by the traversal.
  1318. new_gm = torch.fx.GraphModule(gm, new_graph)
  1319. return new_gm
  1320. def apply_graphsafe_rng_functionalization(
  1321. fw_module: torch.fx.GraphModule,
  1322. bw_module: torch.fx.GraphModule,
  1323. fw_node: torch.fx.Node,
  1324. bw_node: torch.fx.Node,
  1325. device: torch.device,
  1326. rng_count: int,
  1327. last_fwd_input: torch.fx.Node,
  1328. last_bwd_input: torch.fx.Node,
  1329. ) -> tuple[torch.fx.Node, torch.fx.Node]:
  1330. """
  1331. Note [CUDA Graph Safe RNG Functionalization]
  1332. CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values,
  1333. while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a
  1334. CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer
  1335. to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator
  1336. (and its cuda-tensor RNG state during graph capture).
  1337. For each RNG operation's forward/backward pair:
  1338. - We create two generators initialized with identical values
  1339. - Each forward and backward call advances its respective generator equally
  1340. - This keeps generators synchronized so forward and backward operations use matching RNG values
  1341. When forward is called multiple times before backward (causing desynchronization):
  1342. - We save the forward RNG state
  1343. - We update the backward Generator's state before executing backward
  1344. Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator
  1345. changes are reflected during replay.
  1346. This function modifies both forward and backward computation graphs by:
  1347. Creating RNG state placeholders for both passes
  1348. Updating the forward node to use graph-safe RNG state
  1349. Updating the backward node to use graph-safe RNG state
  1350. For more details: https://github.com/pytorch/pytorch/issues/113541
  1351. """
  1352. device_idx = device.index
  1353. if device_idx is None:
  1354. raise AssertionError("device_idx must not be None")
  1355. fw_graph = fw_module.graph
  1356. bw_graph = bw_module.graph
  1357. graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state
  1358. # Handle forward pass
  1359. # Note: [Generator arguments in AOTDispatcher]
  1360. # Generator arguments in AOTDispatcher are added to support graphsafe rng
  1361. # functionalization. See note above [CUDA Graph Safe RNG Functionalization]
  1362. with fw_module.graph.inserting_after(last_fwd_input):
  1363. fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}")
  1364. fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
  1365. last_fwd_input = fwd_rng_state
  1366. # Handle backward pass
  1367. with bw_module.graph.inserting_after(last_bwd_input):
  1368. bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}")
  1369. # as above, clone so that meta val generator will not contain tensors
  1370. bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
  1371. last_bwd_input = bwd_rng_state
  1372. # Update forward node
  1373. fw_kwargs = dict(fw_node.kwargs)
  1374. fw_kwargs["rng_state"] = fwd_rng_state
  1375. with fw_module.graph.inserting_after(fw_node):
  1376. functional_fw_node = fw_graph.create_node(
  1377. "call_function",
  1378. graphsafe_run_with_rng_state,
  1379. args=(fw_node.target, *fw_node.args), # type: ignore[arg-type]
  1380. kwargs=fw_kwargs,
  1381. )
  1382. fw_node.replace_all_uses_with(functional_fw_node)
  1383. fw_graph.erase_node(fw_node)
  1384. # Update backward node
  1385. bwd_kwargs = dict(bw_node.kwargs)
  1386. bwd_kwargs["rng_state"] = bwd_rng_state
  1387. with bw_graph.inserting_before(bw_node):
  1388. rng_output = bw_graph.create_node(
  1389. "call_function",
  1390. graphsafe_run_with_rng_state,
  1391. args=(bw_node.target, *bw_node.args), # type: ignore[arg-type]
  1392. kwargs=bwd_kwargs,
  1393. )
  1394. bw_node.replace_all_uses_with(rng_output)
  1395. bw_graph.erase_node(bw_node)
  1396. return last_fwd_input, last_bwd_input
  1397. def functionalize_rng_ops(
  1398. joint_module: fx.GraphModule,
  1399. fw_module: fx.GraphModule,
  1400. bw_module: fx.GraphModule,
  1401. num_sym_nodes: int,
  1402. ) -> tuple[fx.GraphModule, fx.GraphModule]:
  1403. # During user-driven activation checkpointing, we have to ensure that a rng
  1404. # op in fwd yields the same output as the recomputed rng op in the bwd. To
  1405. # do this, we use functionalize wrappers to wrap the random ops and share
  1406. # rng state between the fwd and bwd graphs.
  1407. # There are 3 main steps to do this
  1408. # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
  1409. # Step 2 - Modify the fwd pass such that
  1410. # 1) Replace rand with run_and_save_rng_state wrapper
  1411. # 2) Replace the users of the original op with the output[1] of this op.
  1412. # 3) Collect all the rng_state - output[0] of each op, and make them
  1413. # output nodes. Special care needs to be taken here because fwd outputs
  1414. # has symints at the very end.
  1415. # Step 3 - Modify the bwd pass such that
  1416. # 1) Add the input nodes just before the tangents for the stashed rng states
  1417. # 2) Replace rand with run_with_save_rng_state wrappers
  1418. # 3) Use the stashed states as inputs to these ops
  1419. # Unique id to generate name
  1420. uid = itertools.count()
  1421. def get_rng_ops(gmod: fx.GraphModule) -> dict[str, fx.Node]:
  1422. random_nodes: dict[str, fx.Node] = {}
  1423. for node in gmod.graph.nodes:
  1424. if (
  1425. node.op == "call_function"
  1426. and hasattr(node.target, "tags")
  1427. and torch.Tag.nondeterministic_seeded in node.target.tags
  1428. ):
  1429. random_nodes[node.name] = node
  1430. return random_nodes
  1431. def get_device(node: fx.Node) -> torch.device | None:
  1432. """
  1433. Check the example value of the node outputs to find the device type.
  1434. """
  1435. if "val" not in node.meta:
  1436. return None
  1437. candidates = node.meta["val"]
  1438. if not isinstance(candidates, tuple):
  1439. candidates = (candidates,)
  1440. for candidate in candidates:
  1441. if isinstance(candidate, torch.Tensor):
  1442. if candidate.device.type == "cuda":
  1443. return candidate.device
  1444. return torch.device("cpu")
  1445. def get_sample_rng_state(device: torch.device | None) -> torch.Tensor:
  1446. from torch._guards import detect_fake_mode # noqa: F401
  1447. fake_mode = detect_fake_mode()
  1448. if fake_mode is None:
  1449. raise AssertionError("fake_mode must not be None")
  1450. with fake_mode:
  1451. if device is not None and device.type == "cuda":
  1452. return fake_mode.from_tensor(torch.cuda.get_rng_state())
  1453. return fake_mode.from_tensor(torch.get_rng_state())
  1454. # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
  1455. joint_graph_rng_ops = get_rng_ops(joint_module)
  1456. fw_graph_rng_ops = get_rng_ops(fw_module)
  1457. bw_graph_rng_ops = get_rng_ops(bw_module)
  1458. recomputable_rng_ops_map = {}
  1459. for node in joint_module.graph.nodes:
  1460. if (
  1461. must_recompute(node)
  1462. and hasattr(node.target, "tags")
  1463. and torch.Tag.nondeterministic_seeded in node.target.tags
  1464. ):
  1465. # Skip if the node doesn't exist in both forward and backward graphs.
  1466. # This can happen when the RNG op's output is not needed for gradient
  1467. # computation and gets eliminated by dead code elimination.
  1468. if node.name not in fw_graph_rng_ops or node.name not in bw_graph_rng_ops:
  1469. continue
  1470. base_node = joint_graph_rng_ops[node.name]
  1471. fw_node = fw_graph_rng_ops[node.name]
  1472. bw_node = bw_graph_rng_ops[node.name]
  1473. recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node}
  1474. run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
  1475. run_with_rng_state = torch._prims.rng_prims.run_with_rng_state
  1476. bw_tangent_start_node = None
  1477. for node in bw_module.graph.find_nodes(op="placeholder"):
  1478. if "tangent" in node.name:
  1479. bw_tangent_start_node = node
  1480. break
  1481. if bw_tangent_start_node is None:
  1482. raise RuntimeError(
  1483. "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this"
  1484. )
  1485. fw_rng_state_outputs: list[fx.Node] = []
  1486. last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder")))
  1487. last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder")))
  1488. devices = OrderedSet(
  1489. get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
  1490. )
  1491. # pyrefly: ignore [unbound-name]
  1492. devices.discard(torch.device("cpu"))
  1493. # multiple cuda devices won't work with cudagraphs anyway,
  1494. # fallback to non graphsafe rng checkpointing
  1495. multi_cuda_devices = len(devices) > 1
  1496. # this changes numerics, so if fallback_random is set we will not use it
  1497. # pyrefly: ignore [unbound-name]
  1498. ind_config = torch._inductor.config
  1499. use_rng_graphsafe_rng_functionalization = (
  1500. config.graphsafe_rng_functionalization
  1501. and not multi_cuda_devices
  1502. and (
  1503. not ind_config.fallback_random
  1504. or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random
  1505. )
  1506. )
  1507. for rng_count, node_pair in enumerate(recomputable_rng_ops_map.values()):
  1508. # Step 2 - Modify the fwd pass such that
  1509. fw_node = node_pair["fwd"]
  1510. bw_node = node_pair["bwd"]
  1511. device = get_device(fw_node)
  1512. fw_graph = fw_module.graph
  1513. bw_graph = bw_module.graph
  1514. if (
  1515. use_rng_graphsafe_rng_functionalization
  1516. and device is not None
  1517. and device.type == "cuda"
  1518. ):
  1519. last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization(
  1520. fw_module,
  1521. bw_module,
  1522. fw_node,
  1523. bw_node,
  1524. device,
  1525. rng_count,
  1526. last_fwd_input,
  1527. last_bwd_input,
  1528. )
  1529. else:
  1530. with fw_graph.inserting_before(fw_node):
  1531. functional_fw_node = fw_graph.create_node(
  1532. "call_function",
  1533. run_and_save_rng,
  1534. # pyrefly: ignore [bad-argument-type]
  1535. args=(
  1536. fw_node.target,
  1537. *fw_node.args,
  1538. ), # pyrefly: ignore[bad-argument-type]
  1539. kwargs=fw_node.kwargs,
  1540. )
  1541. state = fw_graph.create_node(
  1542. "call_function",
  1543. operator.getitem,
  1544. args=(functional_fw_node, 0),
  1545. kwargs={},
  1546. )
  1547. state.meta["val"] = get_sample_rng_state(device)
  1548. rng_output = fw_graph.create_node(
  1549. "call_function",
  1550. operator.getitem,
  1551. args=(
  1552. functional_fw_node,
  1553. 1,
  1554. ),
  1555. kwargs={},
  1556. )
  1557. # Copy the meta data from the original node
  1558. rng_output.meta = copy.copy(fw_node.meta)
  1559. fw_node.replace_all_uses_with(rng_output)
  1560. fw_graph.erase_node(fw_node)
  1561. fw_rng_state_outputs.append(state)
  1562. # Step 3 - Modify the bwd pass such that
  1563. with bw_graph.inserting_before(bw_tangent_start_node):
  1564. state_name = f"rng_state_output_{next(uid)}"
  1565. bw_rng_state_node = bw_graph.placeholder(state_name)
  1566. bw_rng_state_node.meta["val"] = get_sample_rng_state(device)
  1567. with bw_graph.inserting_before(bw_node):
  1568. rng_output = bw_graph.create_node(
  1569. "call_function",
  1570. run_with_rng_state,
  1571. # pyrefly: ignore [bad-argument-type]
  1572. args=(
  1573. bw_rng_state_node,
  1574. bw_node.target,
  1575. *bw_node.args,
  1576. ), # pyrefly: ignore[bad-argument-type]
  1577. kwargs=bw_node.kwargs,
  1578. )
  1579. bw_node.replace_all_uses_with(rng_output)
  1580. bw_graph.erase_node(bw_node)
  1581. # Add the rng states in the output of the fwd graph. AOT Autograd assumes
  1582. # that symints are at the end of forward graph outputs. So, insert the new
  1583. # rng states accordingly.
  1584. if fw_rng_state_outputs:
  1585. fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
  1586. fw_outputs = fw_output_node.args[0]
  1587. sym_node_start_idx = len(fw_outputs) - num_sym_nodes
  1588. outputs = (
  1589. fw_outputs[:sym_node_start_idx]
  1590. + tuple(fw_rng_state_outputs)
  1591. + fw_outputs[sym_node_start_idx:]
  1592. )
  1593. fw_module.graph.output(outputs)
  1594. fw_module.graph.erase_node(fw_output_node)
  1595. fw_module.recompile()
  1596. bw_module.recompile()
  1597. return fw_module, bw_module
  1598. def force_save_collectives(joint_module: fx.GraphModule) -> None:
  1599. """
  1600. By default, the partitioner is not allowed to recompute collectives
  1601. unless they come from a user-annotated AC region.
  1602. See Note [Recomputing collectives in the partitioner]
  1603. """
  1604. for node in joint_module.graph.nodes:
  1605. if (
  1606. isinstance(node.target, torch._ops.OpOverload)
  1607. and node.target.namespace == "_c10d_functional"
  1608. and not must_recompute(node)
  1609. ):
  1610. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1611. def force_save_effectful_ops(joint_module: fx.GraphModule) -> None:
  1612. """
  1613. Force save outputs from with_effects nodes wrapping effectful ops.
  1614. Effectful ops (registered via _register_effectful_op) should not be recomputed
  1615. because they may have arbitrary global side effects (I/O, RNG state, collectives,
  1616. etc.). We mark the tensor outputs of with_effects as MUST_SAVE to prevent
  1617. recomputation of the effectful op.
  1618. The with_effects node returns a tuple (token, result). We recursively find all
  1619. leaf outputs extracted via getitem and mark them as MUST_SAVE. Since these are
  1620. saved, the with_effects op doesn't need to be recomputed in backward.
  1621. """
  1622. def mark_getitem_outputs(node: fx.Node) -> None:
  1623. for user in node.users:
  1624. if user.target is operator.getitem:
  1625. mark_getitem_outputs(user)
  1626. if not isinstance(user.meta.get("val"), (tuple, list)):
  1627. user.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1628. for node in joint_module.graph.nodes:
  1629. if (
  1630. is_with_effects(node)
  1631. and not must_recompute(node)
  1632. and not _has_tag_is_backward(node)
  1633. ):
  1634. mark_getitem_outputs(node)
  1635. def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
  1636. # If we have mutations of the same primal in forward and backward,
  1637. # We must not recompute the source of mutation to not apply twice.
  1638. has_mutation_in_bw: OrderedSet[torch.fx.Node] = OrderedSet()
  1639. for node in reversed(joint_module.graph.nodes):
  1640. if node.op == "output":
  1641. continue
  1642. is_copy_ = node.target is torch.ops.aten.copy_.default
  1643. if is_copy_:
  1644. if _has_tag_must_be_in_backward(node):
  1645. has_mutation_in_bw.add(node.args[0])
  1646. if _has_tag_must_be_in_forward(node) and node.args[0] in has_mutation_in_bw:
  1647. node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1648. else:
  1649. # We use invariant of aotdispatch joint graph,
  1650. # That we emit copy_ only in the end of it.
  1651. # We do not want to iterate through all the joint graph,
  1652. # so break at the first non-output, non-copy_ node.
  1653. break
  1654. def is_getitem_of_multi_output(node: fx.Node) -> bool:
  1655. if node.target != operator.getitem:
  1656. return False
  1657. parent = node.args[0]
  1658. if type(parent) is not fx.Node:
  1659. raise AssertionError(f"expected parent to be fx.Node, got {type(parent)}")
  1660. return "tensor_meta" not in parent.meta and node.op == "call_function"
  1661. def cleanup_recompute_tags(
  1662. joint_module: fx.GraphModule, *, is_default_partition: bool
  1663. ) -> fx.GraphModule:
  1664. """
  1665. If there are two consecutive checkpointed blocks with no operator in
  1666. between, we would still want to stash the tensor at the boundary of
  1667. checkpointed blocks. The following pass makes the last output node
  1668. non-recomputable to allow for that.
  1669. """
  1670. for node in joint_module.graph.nodes:
  1671. if must_recompute(node):
  1672. for user in node.users:
  1673. if (
  1674. must_recompute(user)
  1675. and "ac_graph_id" in user.meta
  1676. and "ac_graph_id" in node.meta
  1677. and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
  1678. ):
  1679. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1680. if node.meta.get("has_backward_hook", False) and not any(
  1681. must_recompute(user) for user in node.users
  1682. ):
  1683. # If node is AC region output and has a backward hook on it, we intentionally choose to save it.
  1684. # This is to work around circular dependencies in Traceable FSDP2+AC.
  1685. # Example:
  1686. # ```
  1687. # out = fully_shard(utils.checkpoint(module))(x)
  1688. # norm_out = layer_norm(out)
  1689. # ```
  1690. # Here there is a circular dependency:
  1691. # 1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
  1692. # 2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights)
  1693. # in order to be recomputed.
  1694. # 3. `out`'s backward hook, as is the case for all eager backward hooks, depends on `out_grad`
  1695. # -> circular dependency with (1)!
  1696. #
  1697. # Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
  1698. # in forward graph outputs. With this, we can break the above circular dependency.
  1699. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1700. elif (
  1701. "ac_graph_id" not in node.meta
  1702. and any(must_recompute(user) for user in node.users)
  1703. and not (
  1704. # Avoid saving getitem nodes which are not labeled with "ac_graph_id"
  1705. is_getitem_of_multi_output(node) and "ac_graph_id" in node.args[0].meta
  1706. )
  1707. and is_default_partition
  1708. ):
  1709. # This node is not part of the AC region and a user is marked as recompute.
  1710. # This means it's an input to the AC region and we should save it.
  1711. # For ease of landing, gate this to default partitioner only, but we should think
  1712. # about flipping the switch in general as well.
  1713. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1714. return joint_module
  1715. def solve_min_cut(
  1716. joint_graph: fx.Graph,
  1717. node_info: NodeInfo,
  1718. min_cut_options: MinCutOptions,
  1719. dont_ban: OrderedSet[fx.Node] | None = None,
  1720. ) -> tuple[list[fx.Node], OrderedSet[fx.Node]]:
  1721. if dont_ban is None:
  1722. dont_ban = OrderedSet()
  1723. op_types = get_default_op_list()
  1724. if AOT_PARTITIONER_DEBUG:
  1725. joint_module_ops = OrderedSet(
  1726. str(node.target._overloadpacket)
  1727. for node in joint_graph.nodes
  1728. if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
  1729. )
  1730. ops_ignored = joint_module_ops - OrderedSet(
  1731. str(i) for i in op_types.recomputable_ops
  1732. )
  1733. log.info("Ops banned from re-materialization: %s", ops_ignored)
  1734. def can_fuse_into_auto_functionalized(a: fx.Node, b: fx.Node) -> bool:
  1735. if b.target != torch.ops.higher_order.auto_functionalized:
  1736. return False
  1737. mutable_op = b.args[0]
  1738. (
  1739. mutable_arg_names,
  1740. _,
  1741. ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(
  1742. # pyrefly: ignore[bad-argument-type]
  1743. mutable_op
  1744. )
  1745. for name in mutable_arg_names: # pyrefly: ignore [not-iterable]
  1746. arg = b.kwargs[name]
  1747. if a is arg:
  1748. return True
  1749. if isinstance(arg, list):
  1750. if a in arg:
  1751. return True
  1752. return False
  1753. def can_fuse_into_triton_kernel_wrapper_functional(a: fx.Node, b: fx.Node) -> bool:
  1754. if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional:
  1755. return False
  1756. mutable_arg_names = b.kwargs["tensors_to_clone"]
  1757. for name in mutable_arg_names: # pyrefly: ignore [not-iterable]
  1758. kwargs: Any = b.kwargs["kwargs"]
  1759. if kwargs is None:
  1760. raise AssertionError("kwargs must not be None")
  1761. arg = kwargs[name]
  1762. if a is arg:
  1763. return True
  1764. return False
  1765. def is_fusible(a: fx.Node, b: fx.Node) -> bool:
  1766. # We can perform "memory fusion" into a cat, but cat cannot be a
  1767. # producer to a fusion
  1768. if get_aten_target(b) == aten.cat:
  1769. return True
  1770. if can_fuse_into_auto_functionalized(a, b):
  1771. return True
  1772. if can_fuse_into_triton_kernel_wrapper_functional(a, b):
  1773. return True
  1774. if (
  1775. a.target is operator.getitem
  1776. and a.args[0].target # pyrefly: ignore [missing-attribute]
  1777. is torch.ops.higher_order.triton_kernel_wrapper_functional
  1778. ):
  1779. # if a is the output of a user triton kernel,
  1780. # then (by default) we will not be able to fuse b into it
  1781. return False
  1782. return op_types.is_fusible(a) and op_types.is_fusible(b)
  1783. try:
  1784. import networkx as nx
  1785. except ImportError as e:
  1786. raise RuntimeError(
  1787. "Need networkx installed to perform smart recomputation heuristics"
  1788. ) from e
  1789. def is_materialized_backwards(node: fx.Node) -> bool:
  1790. if op_types.is_view(node):
  1791. return False
  1792. cur_nodes = OrderedSet([node])
  1793. while len(cur_nodes) > 0:
  1794. cur = cur_nodes.pop()
  1795. for user in cur.users:
  1796. if not node_info.is_required_fw(user) and not is_fusible(cur, user):
  1797. return True
  1798. if op_types.is_view(user):
  1799. cur_nodes.add(user)
  1800. return False
  1801. def should_ban_recomputation(node: fx.Node) -> str | None:
  1802. """Returns reason string if node should be banned from recomputation, None otherwise."""
  1803. if node.op != "call_function":
  1804. return None
  1805. if node.target is operator.getitem:
  1806. return None
  1807. if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
  1808. return "marked MUST_SAVE"
  1809. if config.recompute_views and op_types.is_view(node):
  1810. return None
  1811. if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
  1812. return None
  1813. if min_cut_options.ban_if_not_in_allowlist:
  1814. if not op_types.is_recomputable(node):
  1815. return "not in recomputable allowlist"
  1816. else:
  1817. if op_types.is_random(node):
  1818. return "random op"
  1819. if op_types.is_compute_intensive(node):
  1820. return "compute intensive op"
  1821. if is_non_builtin_to_include(node):
  1822. return "non-builtin op"
  1823. # If a node *must* be materialized in the backwards pass, then we
  1824. # should never recompute it. This is a pretty subtle point. In
  1825. # general, the assumption we make is that recomputing a node in the
  1826. # backwards pass is "free". However, if a node must be materialized
  1827. # in the backwards pass, then recomputing it is never free.
  1828. if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
  1829. node
  1830. ):
  1831. log.debug("materialized backwards: %s %s", node, tuple(node.users))
  1832. return "materialized in backward"
  1833. # Arbitrary hack that sometimes seems to help things. The above
  1834. # modification appears to have made this heuristic a lot less critical
  1835. # for performance.
  1836. # NB: As of PR #121692, this hack no longer seems necessary.
  1837. if (
  1838. # pyrefly: ignore [missing-attribute]
  1839. node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw
  1840. ):
  1841. return "too far from backward"
  1842. # If the output of an op is 4x smaller (arbitrary choice),
  1843. # then we don't allow recomputation. The idea here is that for
  1844. # things like reductions, saving the output of the reduction is very
  1845. # cheap/small, and it makes sure we don't do things like recompute
  1846. # normalizations in the backwards.
  1847. if min_cut_options.ban_if_reduction:
  1848. input_tensors_size = sum(
  1849. _size_of(i) for i in node.args if isinstance(i, fx.Node)
  1850. )
  1851. output_size = _size_of(node)
  1852. if output_size * 4 < input_tensors_size:
  1853. return "reduction op"
  1854. return None
  1855. def is_materialized(node: fx.Node) -> bool:
  1856. if node.op == "placeholder":
  1857. return True
  1858. return not all(is_fusible(node, user) for user in node.users)
  1859. def get_node_weight(
  1860. node: fx.Node, static_lifetime_input_nodes: OrderedSet[fx.Node]
  1861. ) -> tuple[float, str | None]:
  1862. """Returns (weight, cannot_save_reason).
  1863. cannot_save_reason is None for finite weights, or a string explaining
  1864. why the node cannot be saved for infinite weights.
  1865. """
  1866. if (
  1867. config.treat_parameters_as_free_to_save
  1868. and node in static_lifetime_input_nodes
  1869. ):
  1870. return 0, None
  1871. mem_sz = _size_of(node)
  1872. if config.recompute_views and op_types.is_view(node):
  1873. # If `config.recompute_views=True`, we don't save views. This is generally
  1874. # a good idea since views are free to recompute, and it makes it a bit simpler
  1875. # to analyze.
  1876. # NB: If they're not free to recompute (e.g. nested tensors)... I
  1877. # think we should modify checks for view_ops to `is_view` and check
  1878. # that. Basically, with nested tensors, `aten.view` is not a "view
  1879. # op".
  1880. return math.inf, "view op (recompute_views=True)"
  1881. if isinstance(node.meta["val"], py_sym_types):
  1882. # We never want to save symfloats
  1883. if not isinstance(node.meta["val"], torch.SymInt):
  1884. return INT_INF, "SymFloat (non-SymInt symbolic value)"
  1885. # Heuristic to bias towards nodes closer to the backwards pass
  1886. # Complete guess about current value
  1887. mem_sz = int(
  1888. # pyrefly: ignore [missing-attribute]
  1889. mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1))
  1890. )
  1891. if is_materialized(node):
  1892. return mem_sz, None
  1893. else:
  1894. return mem_sz * 2, None
  1895. nx_graph = nx.DiGraph()
  1896. banned_nodes: OrderedSet[fx.Node] = OrderedSet()
  1897. def ban_recomputation_if_allowed(node: fx.Node, reason: str = "") -> bool:
  1898. if op_types.is_view(node):
  1899. return False
  1900. if node in dont_ban:
  1901. # collectives are *always* banned from recompute, overriding `dont_ban`
  1902. # (in particular, the activation memory budget logic is not allowed to recompute collectives)
  1903. is_collective = (
  1904. isinstance(node.target, torch._ops.OpOverload)
  1905. and node.target.namespace == "_c10d_functional"
  1906. )
  1907. if config.unsafe_allow_optimization_of_collectives or not is_collective:
  1908. return False
  1909. # This bans recomputation of the node unless we've been forced not to by
  1910. # user annotation
  1911. if must_recompute(node):
  1912. return False
  1913. if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
  1914. return False
  1915. banned_nodes.add(node)
  1916. # A node will only ever be recomputed if there is a path from an
  1917. # ancestor of this node to the backwards path through this node that
  1918. # doesn't go through any saved value. If this node is saved, then that
  1919. # condition is not possible.
  1920. nx_graph.add_edge(
  1921. "source",
  1922. node.name + "_in",
  1923. capacity=math.inf,
  1924. reason=f"cannot recompute: {reason}" if reason else "cannot recompute",
  1925. )
  1926. return True
  1927. for node in joint_graph.nodes:
  1928. if node.op == "output":
  1929. continue
  1930. if node in node_info.required_bw_nodes:
  1931. # See Note: [tangents_closure vs required_bw_nodes]
  1932. if node not in node_info.tangents_closure:
  1933. nx_graph.add_edge(
  1934. node.name + "_out",
  1935. "sink",
  1936. capacity=math.inf,
  1937. reason="must be available for backward: input required for gradient",
  1938. )
  1939. else:
  1940. nx_graph.add_edge(
  1941. node.name + "_in",
  1942. "sink",
  1943. capacity=math.inf,
  1944. reason="must be computed in backward: required for gradient",
  1945. )
  1946. continue
  1947. if must_recompute(node):
  1948. # If user explicitly says they want to recompute a node, we honor it
  1949. # by adding an inf-capacity edge from X_in to the sink.
  1950. # This way, X_in node is guaranteed to be part of the subgraph that contains "sink"
  1951. # after the cut, thus guaranteeing that X op will be recomputed.
  1952. nx_graph.add_edge(
  1953. node.name + "_in",
  1954. "sink",
  1955. capacity=math.inf,
  1956. reason="must recompute: marked by checkpoint policy",
  1957. )
  1958. continue
  1959. if _is_primal(node):
  1960. ban_recomputation_if_allowed(node, "primal input")
  1961. elif _is_fwd_seed_offset(node):
  1962. ban_recomputation_if_allowed(node, "forward RNG seed")
  1963. # If a node can't be recomputed (too expensive or involves randomness),
  1964. # we prevent it from being recomputed by adding an inf edge to the source
  1965. # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
  1966. ban_reason = should_ban_recomputation(node)
  1967. if node_info.is_required_fw(node) and ban_reason:
  1968. ban_recomputation_if_allowed(node, ban_reason)
  1969. # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors.
  1970. is_non_tensor_node = (
  1971. "val" not in node.meta and "tensor_meta" not in node.meta
  1972. ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor))
  1973. if is_sym_node(node):
  1974. weight = float(sym_node_size(node))
  1975. cannot_save_reason = None
  1976. elif is_non_tensor_node:
  1977. # FakeScriptObjects (opaque objects) should have weight 0.0 so they can be
  1978. # properly partitioned between forward and backward, like BackwardState.
  1979. if isinstance(node.meta.get("val"), (BackwardState, FakeScriptObject)):
  1980. weight = 0.0
  1981. cannot_save_reason = None
  1982. else:
  1983. weight = math.inf
  1984. cannot_save_reason = "non-tensor output"
  1985. else:
  1986. weight, cannot_save_reason = get_node_weight(
  1987. node, node_info.static_lifetime_input_nodes
  1988. )
  1989. # Creates the weights on the "node" edge
  1990. if cannot_save_reason and (weight == math.inf or weight == INT_INF):
  1991. nx_graph.add_edge(
  1992. node.name + "_in",
  1993. node.name + "_out",
  1994. capacity=weight,
  1995. reason=f"cannot save: {cannot_save_reason}",
  1996. )
  1997. else:
  1998. nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
  1999. for user in node.users:
  2000. nx_graph.add_edge(
  2001. node.name + "_out",
  2002. user.name + "_in",
  2003. capacity=math.inf,
  2004. reason="data dependency",
  2005. )
  2006. # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute.
  2007. # Some example models to look at where this helps perf: poolformer_m36,
  2008. # mixer_b16_224, cait_m36_384
  2009. # The "rough" idea here is that if you have some node that is used by both a
  2010. # node nearby downstream as well as a node far downstream, if we recompute
  2011. # both of the downstream nodes, we're unlikely to be able to fuse both
  2012. # downstream nodes together.
  2013. # Thus, we shouldn't aim to recompute far downstream nodes that depend on
  2014. # this node. That intuition of "far downstream" is captured by whether
  2015. # there's an unfusible op along the chain somewhere
  2016. # It could probably be improved by properly analyzing what's going on in the
  2017. # backwards pass instead of only relying on whether it's unfusible in the
  2018. # forwards.
  2019. def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
  2020. """
  2021. Finds the first unfusible node in the chain of nodes starting from
  2022. `start_nodes` and returns its position.
  2023. """
  2024. sorted_nodes: list[tuple[int, fx.Node, bool]] = []
  2025. for n in start_nodes:
  2026. heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True))
  2027. while len(sorted_nodes) > 0:
  2028. _, node, node_is_fusible = heapq.heappop(sorted_nodes)
  2029. if not node_is_fusible:
  2030. return node_info.get_fw_order(node)
  2031. for user in node.users:
  2032. if node_info.is_required_fw(user):
  2033. if node_info.get_fw_order(user) > max_range:
  2034. continue
  2035. val: tuple[int, fx.Node, bool] = (
  2036. node_info.get_fw_order(user),
  2037. user,
  2038. is_fusible(node, user),
  2039. )
  2040. if val not in sorted_nodes:
  2041. heapq.heappush(sorted_nodes, val)
  2042. return max_range
  2043. if min_cut_options.ban_if_used_far_apart:
  2044. for used_node in node_info.required_fw_nodes:
  2045. orders = [
  2046. node_info.get_fw_order(user)
  2047. for user in used_node.users
  2048. if node_info.is_required_fw(user)
  2049. ]
  2050. fw_users = [
  2051. user for user in used_node.users if node_info.is_required_fw(user)
  2052. ]
  2053. if len(orders) > 0:
  2054. first_unfusible_use = find_first_unfusible(fw_users, max(orders))
  2055. for user in tuple(used_node.users):
  2056. if (
  2057. node_info.is_required_fw(user)
  2058. and node_info.get_fw_order(user) > first_unfusible_use
  2059. and is_fusible(used_node, user)
  2060. ):
  2061. if user in banned_nodes:
  2062. continue
  2063. log.info(
  2064. "used above/below fusible %s:(%s) -> %s -> %s:(%s)",
  2065. used_node,
  2066. node_info.get_fw_order(used_node),
  2067. first_unfusible_use,
  2068. user,
  2069. node_info.get_fw_order(user),
  2070. )
  2071. ban_recomputation_if_allowed(user)
  2072. # This heuristic is fairly straightforward. The idea is that although it is
  2073. # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation
  2074. # where we have a long chain of pointwise ops from the beginning to the end
  2075. # of the model (like say, residual connections)
  2076. # todo: I'm not totally sure why this heuristic matters. It's possible that this is
  2077. # working around Inductor fusion decisions, or that it's a patch over
  2078. # suboptimal partitioning decisions
  2079. # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
  2080. if min_cut_options.ban_if_long_fusible_chains:
  2081. visited: OrderedSet[fx.Node] = OrderedSet()
  2082. for start_node in joint_graph.nodes:
  2083. if not node_info.is_required_fw(start_node):
  2084. continue
  2085. fusible: list[tuple[int, fx.Node]] = [
  2086. (node_info.get_fw_order(start_node), start_node)
  2087. ]
  2088. start_order = node_info.get_fw_order(start_node)
  2089. while len(fusible) > 0:
  2090. _, cur = heapq.heappop(fusible)
  2091. if cur in visited:
  2092. continue
  2093. visited.add(cur)
  2094. # 100 is arbitrary choice to try and prevent degenerate cases
  2095. if (
  2096. node_info.get_fw_order(cur) > start_order + 100
  2097. and len(fusible) == 0
  2098. ):
  2099. log.info(
  2100. "too long %s %s %s %s",
  2101. cur,
  2102. start_node,
  2103. node_info.get_fw_order(cur),
  2104. node_info.get_fw_order(start_node),
  2105. )
  2106. ban_recomputation_if_allowed(cur)
  2107. break
  2108. for user in cur.users:
  2109. if (
  2110. node_info.is_required_fw(user)
  2111. and is_fusible(cur, user)
  2112. and user not in banned_nodes
  2113. ):
  2114. heapq.heappush(fusible, (node_info.get_fw_order(user), user))
  2115. try:
  2116. cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
  2117. except nx.NetworkXUnbounded as unbounded_exc:
  2118. # Check if structured tracing is enabled (for production job debugging via tlparse)
  2119. structured_tracing_enabled = bool(trace_log.handlers)
  2120. # Dump the FX graph for debugging
  2121. fx_graph_file: str | None = None
  2122. fx_graph_str: str | None = None
  2123. joint_module = joint_graph.owning_module
  2124. try:
  2125. fx_graph_str = (
  2126. joint_module.print_readable(
  2127. print_output=False, include_stride=True, include_device=True
  2128. )
  2129. if joint_module
  2130. else str(joint_graph)
  2131. )
  2132. # Always log to structured trace for production debugging
  2133. trace_structured(
  2134. "artifact",
  2135. metadata_fn=lambda: {
  2136. "name": "min_cut_failed_fx_graph",
  2137. "encoding": "string",
  2138. },
  2139. payload_fn=lambda: fx_graph_str,
  2140. )
  2141. # Also write to local file for local debugging
  2142. fx_graph_file = _get_unique_path("min_cut_failed_graph", ".txt")
  2143. with open(fx_graph_file, "w") as f:
  2144. f.write(fx_graph_str)
  2145. except Exception as e:
  2146. fx_graph_file = f"(failed to write: {e})"
  2147. # Dump the min-cut edge list to structured trace
  2148. edge_list_str = "\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))
  2149. trace_structured(
  2150. "artifact",
  2151. metadata_fn=lambda: {
  2152. "name": "min_cut_failed_edge_list",
  2153. "encoding": "string",
  2154. },
  2155. payload_fn=lambda: edge_list_str,
  2156. )
  2157. # Find and report the infinite-capacity path
  2158. inf_path = _find_infinite_capacity_path(nx_graph)
  2159. if inf_path:
  2160. # Group edges by FX node and format for user understanding
  2161. # inf_path is a list of (from_node, to_node, reason) tuples
  2162. #
  2163. # Edge types and what they mean:
  2164. # - source -> X_in: X cannot be recomputed
  2165. # - X_in -> X_out (inf): X's output cannot be saved
  2166. # - X_out -> Y_in: Y depends on X (data flow)
  2167. # - X_in -> sink: X must be computed in backward
  2168. # - X_out -> sink: X's output must be available for backward
  2169. # Build a user-friendly explanation grouped by FX node
  2170. node_constraints: dict[str, list[str]] = {}
  2171. raw_path_nodes = ["source"]
  2172. def get_base_name(node_name: str) -> str:
  2173. for suffix in ("_in", "_out"):
  2174. if node_name.endswith(suffix):
  2175. return node_name[: -len(suffix)]
  2176. return node_name
  2177. for from_node, to_node, reason in inf_path:
  2178. raw_path_nodes.append(to_node)
  2179. # Skip source/sink, focus on FX nodes
  2180. if from_node == "source":
  2181. base = get_base_name(to_node)
  2182. node_constraints.setdefault(base, []).append(reason)
  2183. elif to_node == "sink":
  2184. base = get_base_name(from_node)
  2185. node_constraints.setdefault(base, []).append(reason)
  2186. elif get_base_name(from_node) == get_base_name(to_node):
  2187. # Internal edge (X_in -> X_out)
  2188. base = get_base_name(from_node)
  2189. node_constraints.setdefault(base, []).append(reason)
  2190. else:
  2191. # Data dependency edge (X_out -> Y_in)
  2192. from_base = get_base_name(from_node)
  2193. to_base = get_base_name(to_node)
  2194. node_constraints.setdefault(to_base, []).append(
  2195. f"depends on {from_base}"
  2196. )
  2197. # Format the constraints nicely
  2198. constraint_lines: list[str] = []
  2199. for node_name, constraints in node_constraints.items():
  2200. constraint_lines.append(f" {node_name}:")
  2201. for c in constraints:
  2202. constraint_lines.append(f" - {c}")
  2203. constraints_str = "\n".join(constraint_lines)
  2204. raw_path_str = " -> ".join(raw_path_nodes)
  2205. # Try to visualize (logs to structured trace and writes local file)
  2206. svg_path, svg_content = visualize_min_cut_graph(nx_graph)
  2207. if svg_content:
  2208. trace_structured(
  2209. "artifact",
  2210. metadata_fn=lambda: {
  2211. "name": "min_cut_failed_svg",
  2212. "encoding": "string",
  2213. },
  2214. payload_fn=lambda: svg_content,
  2215. )
  2216. # Build file location messages
  2217. local_files_msg = (
  2218. f"FX graph dump: {fx_graph_file}\n" if fx_graph_file else ""
  2219. )
  2220. if svg_path:
  2221. local_files_msg += f"Min-cut graph visualization: {svg_path}\n"
  2222. # Suggest tlparse if structured tracing is enabled
  2223. tlparse_msg = ""
  2224. if structured_tracing_enabled:
  2225. tlparse_msg = (
  2226. "[Production debugging: Use tlparse to extract debug artifacts "
  2227. "(min_cut_failed_fx_graph, min_cut_failed_edge_list, min_cut_failed_svg)]\n"
  2228. )
  2229. raise RuntimeError(
  2230. f"AOT Autograd failed to partition the joint forward-backward graph.\n\n"
  2231. f"The partitioner determines which intermediate values to save from the "
  2232. f"forward pass vs recompute in the backward pass. This error means a value "
  2233. f"is required for backward, but cannot be saved AND cannot be recomputed.\n\n"
  2234. f"This is a bug in PyTorch. Please file an issue at "
  2235. f"https://github.com/pytorch/pytorch/issues\n\n"
  2236. f"Nodes involved in the conflict:\n"
  2237. f"{constraints_str}\n\n"
  2238. f"[For PyTorch developers: one of the above constraints is wrong. "
  2239. f"Either the node should be recomputable, saveable, or not required for backward.]\n\n"
  2240. f"[Debug: min-cut path] {raw_path_str}\n"
  2241. f"{local_files_msg}"
  2242. f"{tlparse_msg}"
  2243. ) from unbounded_exc
  2244. # Fallback if we couldn't find the path
  2245. log.info("Failed to compute min-cut on following graph:")
  2246. log.info(
  2247. "%s",
  2248. LazyString(
  2249. lambda: "\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))
  2250. ),
  2251. )
  2252. visualize_min_cut_graph(nx_graph)
  2253. raise
  2254. except Exception:
  2255. log.info("Failed to compute min-cut on following graph:")
  2256. log.info(
  2257. "%s",
  2258. LazyString(
  2259. lambda: "\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph))
  2260. ),
  2261. )
  2262. visualize_min_cut_graph(nx_graph)
  2263. raise
  2264. reachable, non_reachable = partition
  2265. cutset: OrderedSet[tuple[str, str]] = OrderedSet()
  2266. for u, nbrs in ((n, nx_graph[n]) for n in reachable):
  2267. cutset.update((u, v) for v in nbrs if v in non_reachable)
  2268. cut_nodes: OrderedSet[str] = OrderedSet()
  2269. for node_in, node_out in cutset:
  2270. if node_in[:-3] != node_out[:-4]:
  2271. raise AssertionError(
  2272. f"node_in[:-3]={node_in[:-3]} != node_out[:-4]={node_out[:-4]}"
  2273. )
  2274. node_name = node_in[:-3]
  2275. cut_nodes.add(node_name)
  2276. name_to_node = get_name_to_node(joint_graph)
  2277. # To make this stuff deterministic
  2278. node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)}
  2279. saved_values = sorted(
  2280. (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]
  2281. )
  2282. return saved_values, banned_nodes
  2283. def _find_infinite_capacity_path(
  2284. nx_graph: nx.DiGraph[str, dict[str, Any]],
  2285. ) -> list[tuple[str, str, str]] | None:
  2286. """BFS from source to sink following only infinite-capacity edges.
  2287. Returns a list of (from_node, to_node, reason) tuples representing the path,
  2288. or None if no such path exists.
  2289. """
  2290. visited = OrderedSet(["source"])
  2291. # Each queue item: (current_node, path_of_edges)
  2292. # where path_of_edges is a list of (from_node, to_node, reason) tuples
  2293. queue: deque[tuple[str, list[tuple[str, str, str]]]] = deque([("source", [])])
  2294. while queue:
  2295. node, edge_path = queue.popleft()
  2296. for neighbor in nx_graph.successors(node):
  2297. if neighbor in visited:
  2298. continue
  2299. edge_data = nx_graph[node][neighbor]
  2300. capacity = edge_data.get("capacity", 0)
  2301. # Check for infinite capacity (either math.inf or INT_INF)
  2302. if capacity == math.inf or capacity == INT_INF:
  2303. reason = edge_data.get("reason", "unknown")
  2304. new_edge = (node, neighbor, reason)
  2305. new_path = edge_path + [new_edge]
  2306. if neighbor == "sink":
  2307. return new_path
  2308. visited.add(neighbor)
  2309. queue.append((neighbor, new_path))
  2310. return None
  2311. def _get_unique_path(base_name: str, extension: str) -> str:
  2312. """Get a unique file path, appending a counter if the file already exists.
  2313. For example, if "min_cut_failed.svg" exists, returns "min_cut_failed_1.svg".
  2314. """
  2315. path = f"{base_name}{extension}"
  2316. if not os.path.exists(path):
  2317. return path
  2318. counter = 1
  2319. while os.path.exists(f"{base_name}_{counter}{extension}"):
  2320. counter += 1
  2321. return f"{base_name}_{counter}{extension}"
  2322. def visualize_min_cut_graph(
  2323. nx_graph: nx.DiGraph[str, dict[str, Any]],
  2324. ) -> tuple[str | None, str | None]:
  2325. """Visualize the min-cut graph to an SVG file.
  2326. Returns (path_to_svg, svg_content) tuple. Both are None if pydot is unavailable.
  2327. """
  2328. import networkx as nx
  2329. try:
  2330. import pydot
  2331. except ImportError:
  2332. log.info(
  2333. "Install pydot to visualize the min-cut graph for debugging: pip install pydot",
  2334. exc_info=True,
  2335. )
  2336. return None, None
  2337. dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string()
  2338. dot_graph = pydot.graph_from_dot_data(dot_format)[0] # type: ignore[index]
  2339. for edge in dot_graph.get_edges():
  2340. weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"]
  2341. # Set edge label to weight
  2342. edge.set_label(str(weight)) # type: ignore[union-attr]
  2343. # Color edges with weight 'inf' as red
  2344. if weight == float("inf"):
  2345. edge.set_color("red") # type: ignore[union-attr]
  2346. # Generate SVG content
  2347. svg_content = dot_graph.create_svg().decode("utf-8") # type: ignore[union-attr]
  2348. # Write to local file
  2349. svg_path = _get_unique_path("min_cut_failed", ".svg")
  2350. with open(svg_path, "w") as f:
  2351. f.write(svg_content)
  2352. return svg_path, svg_content
def get_default_op_list() -> OpTypes:
    """Build the default operator classification used by the partitioner.

    Returns an ``OpTypes`` constructed positionally from, in order:
    ``fusible_ops`` (recomputable ops union random ops), the compute-intensive
    ops, ``random_ops``, the view ops, and ``recomputable_ops``. What each
    slot means to the partitioner is defined by ``OpTypes``, not here.
    """
    # Base allowlist of ops that may be recomputed in the backward pass.
    # Duplicates across the lists below (e.g. aten._to_copy) are harmless
    # because everything is folded into an OrderedSet at the end.
    default_recomputable_ops: list[Callable[..., Any]] = [
        aten.add,
        aten.sub,
        aten.div,
        aten.atan2,
        aten.mul,
        aten.max,
        aten.min,
        aten.pow,
        aten.remainder,
        aten.fmod,
        aten.__and__,
        aten.__or__,
        aten.__xor__,
        aten.__lshift__,
        aten.__rshift__,
        aten.eq,
        aten.ne,
        aten.ge,
        aten.gt,
        aten.le,
        aten.lt,
        aten.abs,
        aten.bitwise_not,
        aten.ceil,
        aten.floor,
        aten.frac,
        aten.neg,
        aten.relu,
        aten.round,
        aten.silu,
        aten.trunc,
        aten.log,
        aten.log10,
        aten.log1p,
        aten.log2,
        aten.lgamma,
        aten.exp,
        aten.expm1,
        aten.erf,
        aten.erfc,
        aten.cos,
        aten.acos,
        aten.cosh,
        aten.sin,
        aten.asin,
        aten.sinh,
        aten.tan,
        aten.atan,
        aten.tanh,
        aten.atanh,
        aten.sqrt,
        aten.rsqrt,
        aten.reciprocal,
        aten.sigmoid,
        aten.softplus,
        aten.threshold,
        aten.threshold_backward,
        aten.clamp,
        aten.where,
        aten.lerp,
        aten.addcmul,
        aten.gelu,
        aten.gelu_backward,
        aten.sum,
        aten.mean,
        aten._grad_sum_to_size,
        aten.sum_to_size,
        aten.amax,
        aten.to,
        aten.type_as,
        operator.getitem,
        aten.squeeze,
        aten.unsqueeze,
        aten.rsub,
        aten._to_copy,
    ]  # noqa: E501,B950
    # View ops. These are also appended to the recomputable list below.
    recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
    recomputable_view_ops += [
        aten.view,
        aten.slice,
        aten.t,
        prims.broadcast_in_dim,
        aten.expand,
        aten.as_strided,
        aten.permute,
        aten.select,
        aten.split,
    ]
    # NB: `view_ops` aliases the same list object as `recomputable_view_ops`.
    view_ops = recomputable_view_ops
    # Additional recomputable ops: prims, dtype conversions, factory functions,
    # reshapes/views and a few reductions.
    default_recomputable_ops += [
        prims.div,
        prims.convert_element_type,
        aten.clone,
        aten._to_copy,
        aten.full_like,
        prims.var,
        prims.sum,
        aten.var,
        aten.std,
        prims.broadcast_in_dim,
        aten.select,
        aten._unsafe_view,
        aten.view,
        aten.expand,
        aten.slice,
        aten.reshape,
        aten.broadcast_tensors,
        aten.scalar_tensor,
        aten.ones,
        aten.new_zeros,
        aten.lift_fresh_copy,
        aten.arange,
        aten.triu,
        aten.var_mean,
        aten.isinf,
        aten.any,
        aten.full,
        aten.as_strided,
        aten.zeros,
        aten.empty,
        aten.empty_like,
        aten.argmax,
        aten.maximum,
        prims.iota,
        prims._low_memory_max_pool_offsets_to_indices,
    ]  # noqa: E501,B950
    # Natalia said that we should allow recomputing indexing :)
    default_recomputable_ops += [aten.index, aten.gather]
    default_recomputable_ops += view_ops
    # Presumably the full set of pointwise ops known to the registry
    # (pointwise_ops() is defined elsewhere).
    default_recomputable_ops += pointwise_ops()
    default_recomputable_ops += [
        aten.zeros_like,
    ]
    # Operator equivalents of the `magic_methods` list (defined elsewhere).
    default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
    recomputable_ops = OrderedSet(default_recomputable_ops)
    # Random ops: counted as fusible (below), but kept out of the recomputable set.
    random_ops = OrderedSet[Callable[..., Any]](
        [aten.native_dropout, aten.rand_like, aten.randn_like]
    )
    # Matmul/conv/attention style kernels.
    compute_intensive_ops = [
        aten.mm,
        aten.convolution,
        aten.convolution_backward,
        aten.bmm,
        aten.addmm,
        aten._scaled_dot_product_flash_attention,
        aten._scaled_dot_product_efficient_attention,
        aten._flash_attention_forward,
        aten._efficient_attention_forward,
        aten.upsample_bilinear2d,
        aten._scaled_mm,
    ]  # noqa: E501,B950
    fusible_ops = recomputable_ops | random_ops
    return OpTypes(
        fusible_ops,
        OrderedSet(compute_intensive_ops),
        random_ops,
        OrderedSet(view_ops),
        recomputable_ops,
    )
  2514. def get_name_to_node(graph: fx.Graph) -> dict[str, fx.Node]:
  2515. name_to_node: dict[str, fx.Node] = {}
  2516. for node in graph.nodes:
  2517. name_to_node[node.name] = node
  2518. return name_to_node
  2519. def _optimize_runtime_with_given_memory(
  2520. joint_graph: fx.Graph,
  2521. memory: list[float],
  2522. runtimes: list[float],
  2523. max_memory: float,
  2524. node_info: NodeInfo,
  2525. all_recomputable_banned_nodes: list[fx.Node],
  2526. ) -> tuple[float, list[int], list[int]]:
  2527. SOLVER = config.activation_memory_budget_solver
  2528. if SOLVER == "greedy":
  2529. return greedy_knapsack(memory, runtimes, max_memory)
  2530. elif SOLVER == "ilp":
  2531. return ilp_knapsack(memory, runtimes, max_memory)
  2532. elif SOLVER == "dp":
  2533. return dp_knapsack(memory, runtimes, max_memory)
  2534. elif SOLVER == "dp_knapsack_sliding_hirschberg":
  2535. return dp_knapsack_sliding_hirschberg(memory, runtimes, max_memory)
  2536. elif SOLVER == "dynamic_memory_budget_dp":
  2537. log.warning(
  2538. "dynamic_memory_budget_dp is an experimental solver. "
  2539. "It does not guarantee performance improvements. "
  2540. "Additionally, it is not guaranteed to be stable."
  2541. )
  2542. graph_info_provider = GraphInfoProvider.inialize_from_graph(
  2543. joint_graph=joint_graph,
  2544. all_recomputable_banned_nodes=all_recomputable_banned_nodes,
  2545. recorded_knapsack_input_memories=memory,
  2546. recorded_knapsack_input_runtimes=runtimes,
  2547. )
  2548. return dp_knapsack(
  2549. memory,
  2550. runtimes,
  2551. KnapsackEvaluator(
  2552. graph_info_provider=graph_info_provider,
  2553. ).get_knee_point_memory_budget(
  2554. knapsack_algo=dp_knapsack,
  2555. max_mem_budget=max_memory,
  2556. ),
  2557. )
  2558. elif isinstance(SOLVER, CustomKnapsackSolver):
  2559. saved_node_idx, recomp_node_idx = SOLVER(
  2560. memory, joint_graph, max_memory, node_info, all_recomputable_banned_nodes
  2561. )
  2562. return (0.0, saved_node_idx, recomp_node_idx)
  2563. else:
  2564. raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}")
  2565. from torch.utils._mode_utils import no_dispatch
  2566. # replace symbols in size and strides with their hints without guarding.
  2567. def _remove_symbols_without_guarding(x: torch.Tensor, fallback: int) -> torch.Tensor:
  2568. shape = list(x.shape)
  2569. def realize_symbol(d: torch.SymInt | int) -> int:
  2570. return size_hint(d, fallback=fallback)
  2571. shape = [realize_symbol(s) for s in shape]
  2572. stride = [realize_symbol(s) for s in x.stride()]
  2573. return x.new_empty_strided(shape, stride=stride)
  2574. def estimate_runtime(node: fx.Node) -> float:
  2575. RUNTIME_MODE = config.activation_memory_budget_runtime_estimator
  2576. def materialize_arg(x: Any) -> Any:
  2577. if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor):
  2578. return _remove_symbols_without_guarding(x.meta["val"], fallback=4096)
  2579. elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt):
  2580. return size_hint(x.meta["val"], fallback=4096)
  2581. elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat):
  2582. return 1.0
  2583. elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool):
  2584. return True
  2585. else:
  2586. return x
  2587. if RUNTIME_MODE == "testing":
  2588. return 1
  2589. elif RUNTIME_MODE == "profile":
  2590. with no_dispatch():
  2591. from torch._inductor.runtime.benchmarking import benchmarker
  2592. args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
  2593. # pyrefly: ignore[not-callable]
  2594. ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs))
  2595. return ms
  2596. elif RUNTIME_MODE == "flops":
  2597. # todo(chilli): Normalize this to also return ms
  2598. from torch.utils.flop_counter import FlopCounterMode
  2599. args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
  2600. with FlopCounterMode(display=False) as mode:
  2601. # pyrefly: ignore[not-callable]
  2602. node.target(*args, **kwargs)
  2603. counted_flops = mode.get_total_flops()
  2604. return max(counted_flops, 1)
  2605. elif isinstance(RUNTIME_MODE, CustomRuntimeEstimator):
  2606. return RUNTIME_MODE(node)
  2607. else:
  2608. raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}")
def choose_saved_values_set(
    joint_graph: fx.Graph,
    node_info: NodeInfo,
    memory_budget: float = 1,
) -> list[fx.Node]:
    """Pick the joint-graph nodes to save for the backward pass under a budget.

    ``memory_budget`` is compared against a normalized activation size: 0
    corresponds to saving only ``node_info.inputs`` and 1 corresponds to the
    runtime-optimized min-cut solution built from the config-derived
    ``MinCutOptions``. For budgets strictly between 0 and 1, progressively
    more permissive min-cut options are tried first. If none of them fits, a
    knapsack solve (``_optimize_runtime_with_given_memory``) decides which
    banned-from-recompute nodes may be recomputed, and min-cut is re-run with
    those nodes un-banned.

    Relies on ``node.dist_from_bw``, which ``min_cut_rematerialization_partition``
    assigns to every node before calling this function.

    Args:
        joint_graph: the joint forward/backward graph being partitioned.
        node_info: node classification (inputs, required fw/bw nodes, ...)
            for ``joint_graph``.
        memory_budget: fraction in ``[0, 1]`` of the activation-size range
            (between the inputs-only and runtime-optimized solutions) allowed.

    Returns:
        The list of nodes to save for the backward pass.

    Raises:
        RuntimeError: if ``memory_budget`` is outside ``[0, 1]``.
    """
    if memory_budget > 1 or memory_budget < 0:
        raise RuntimeError(
            f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}"
        )
    # Baseline recompute-ban rules come straight from the functorch config.
    min_cut_options = MinCutOptions(
        ban_if_used_far_apart=config.ban_recompute_used_far_apart,
        ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains,
        ban_if_materialized_backward=config.ban_recompute_materialized_backward,
        ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist,
        ban_if_reduction=config.ban_recompute_reductions,
    )
    if config.aggressive_recomputation:
        # Aggressive mode lifts every ban except the reduction ban.
        min_cut_options = replace(
            min_cut_options,
            ban_if_used_far_apart=False,
            ban_if_long_fusible_chains=False,
            ban_if_materialized_backward=False,
            ban_if_not_in_allowlist=False,
        )
    # Budget 0: save nothing beyond the graph inputs.
    if memory_budget == 0:
        return node_info.inputs
    runtime_optimized_saved_values, _ = solve_min_cut(
        joint_graph,
        node_info,
        min_cut_options,
    )
    # return runtime_optimized_saved_values
    # Budget 1: the runtime-optimized min-cut answer is used as-is.
    if memory_budget == 1:
        return runtime_optimized_saved_values

    def estimate_activations_size(saved_values: list[fx.Node]) -> float:
        # Sum of _size_of over the nodes, divided by 1e9 (GB, assuming
        # _size_of reports bytes -- NOTE(review): confirm units).
        return sum(map(_size_of, saved_values)) / 1e9

    # The budget is interpreted relative to these two extremes.
    min_act_size = estimate_activations_size(node_info.inputs)
    max_act_size = estimate_activations_size(runtime_optimized_saved_values)
    # The optimized choice is smaller than the inputs anyways
    if max_act_size <= min_act_size:
        return runtime_optimized_saved_values

    def get_normalized_size(sz: float) -> float:
        # Size of a single node as a fraction of the (max - min) activation range.
        return (sz / 1e9) / (max_act_size - min_act_size)

    def get_mem_ratio(activations: list[fx.Node]) -> float:
        # 0 when `activations` is as small as the inputs, 1 when it is as large
        # as the runtime-optimized solution; this is what memory_budget is
        # compared against. The denominator is > 0 thanks to the early return.
        return (estimate_activations_size(activations) - min_act_size) / (
            max_act_size - min_act_size
        )

    # Step 1: drop the heuristic bans but keep the op allowlist ban.
    more_aggressive_options = replace(
        min_cut_options,
        ban_if_used_far_apart=False,
        ban_if_long_fusible_chains=False,
        ban_if_materialized_backward=False,
    )
    more_aggressive_saved_values, _ = solve_min_cut(
        joint_graph, node_info, more_aggressive_options
    )
    if get_mem_ratio(more_aggressive_saved_values) < memory_budget:
        return more_aggressive_saved_values

    # Step 2: also allow recomputing ops outside the allowlist. The nodes that
    # are still banned are the candidates for the knapsack below.
    aggressive_options = replace(
        more_aggressive_options,
        ban_if_not_in_allowlist=False,
    )
    aggressive_recomputation_saved_values, banned_nodes = solve_min_cut(
        joint_graph, node_info, aggressive_options
    )
    if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget:
        return aggressive_recomputation_saved_values

    from torch._inductor.fx_utils import get_node_storage

    # Storages backing the graph inputs; nodes sharing them are excluded from
    # the recompute candidates below unless is_non_builtin_to_include says otherwise.
    input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs)

    def get_recomputable_banned_nodes(
        banned_nodes: OrderedSet[fx.Node],
    ) -> list[fx.Node]:
        return [
            i
            for i in banned_nodes
            if (
                # Only allow recomputing nodes that are actually required for BW
                i.dist_from_bw < int(1e9)  # type: ignore[attr-defined]
                and (
                    get_node_storage(i) not in input_storages
                    or is_non_builtin_to_include(i)
                )
            )
        ]

    recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes)
    # Nodes tagged MUST_SAVE are never handed to the knapsack; they are always saved.
    must_save_nodes = [
        i
        for i in recomputable_banned_nodes
        if i.meta.get("recompute", False) == CheckpointPolicy.MUST_SAVE
    ]
    recomputable_banned_nodes = [
        i for i in recomputable_banned_nodes if i not in must_save_nodes
    ]

    # default: runtime_optimized_saved_values
    # more aggressive: more_aggressive_saved_values
    # full aggressive: aggressive_recomputation_saved_values

    # Knapsack candidates, largest first; indices into this list are what the
    # solver returns in recomputable_node_idxs.
    all_recomputable_banned_nodes = sorted(
        recomputable_banned_nodes, key=_size_of, reverse=True
    )
    if len(all_recomputable_banned_nodes) == 0:
        return node_info.inputs + must_save_nodes
    memories_banned_nodes = [
        get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes
    ]
    runtimes_banned_nodes = [
        estimate_runtime(node) for node in all_recomputable_banned_nodes
    ]
    from torch.utils._mode_utils import no_dispatch

    def get_saved_values_knapsack(
        memory_budget: float, node_info: NodeInfo, joint_graph: fx.Graph
    ) -> tuple[list[fx.Node], float]:
        # Solve the knapsack for this budget, un-ban the nodes it chose to
        # recompute, and re-run min-cut. Returns (saved_values, expected_runtime).
        with no_dispatch():
            (
                expected_runtime,
                saved_node_idxs,
                recomputable_node_idxs,
            ) = _optimize_runtime_with_given_memory(
                joint_graph,
                memories_banned_nodes,
                runtimes_banned_nodes,
                max(memory_budget, 0),
                node_info,
                all_recomputable_banned_nodes,
            )
        dont_ban: OrderedSet[fx.Node] = OrderedSet()
        for idx in recomputable_node_idxs:
            # if idx in all_recomputable_banned_nodes:
            # NOTE(review): catching BaseException also swallows
            # KeyboardInterrupt/SystemExit; presumably only an out-of-range
            # index (IndexError) is expected here -- confirm before narrowing.
            try:
                dont_ban.add(all_recomputable_banned_nodes[idx])
            except BaseException:  # noqa: B036
                pass
        if not dont_ban.issubset(all_recomputable_banned_nodes):
            raise AssertionError(
                "dont_ban must be a subset of all_recomputable_banned_nodes"
            )
        saved_values, _ = solve_min_cut(
            joint_graph,
            node_info,
            aggressive_options,
            dont_ban,
        )
        if AOT_PARTITIONER_DEBUG:
            create_structured_trace_for_min_cut_info(
                joint_graph=joint_graph,
                all_recomputable_banned_nodes=all_recomputable_banned_nodes,
                saved_node_idxs=saved_node_idxs,
                recomputable_node_idxs=recomputable_node_idxs,
                expected_runtime=expected_runtime,
                memories_banned_nodes=[
                    _size_of(i) for i in all_recomputable_banned_nodes
                ],
                normalized_memories_banned_nodes=memories_banned_nodes,
                runtimes_banned_nodes=runtimes_banned_nodes,
                min_cut_saved_values=saved_values,
            )
        return saved_values, expected_runtime

    if config.visualize_memory_budget_pareto:
        # Debug-only: bisect the budget range, plot estimated recompute runtime
        # against memory ratio, and save the figure as an SVG.

        def estimate_for_budget(b: float) -> tuple[float, float, float]:
            # (budget, estimated runtime of recomputed nodes, resulting mem ratio)
            saved_values, expected_runtime = get_saved_values_knapsack(
                b, node_info=node_info, joint_graph=joint_graph
            )
            return (
                b,
                sum(runtimes_banned_nodes) - expected_runtime,
                get_mem_ratio(saved_values),
            )

        options = [estimate_for_budget(0.0), estimate_for_budget(1.0)]
        if options[0][1:] != options[1][1:]:
            # Recursively split intervals whose endpoints differ until they are
            # narrower than 1e-3, collecting the endpoints as plot points.
            bisects = [(options[0], options[1])]
            while bisects:
                lhs, rhs = bisects.pop()
                if rhs[0] - lhs[0] < 1e-3:
                    options.append(lhs)
                    options.append(rhs)
                    continue
                mid = estimate_for_budget((lhs[0] + rhs[0]) / 2)
                if mid[1:] != lhs[1:]:
                    bisects.append((lhs, mid))
                if mid[1:] != rhs[1:]:
                    bisects.append((mid, rhs))
        options.sort()

        import matplotlib.pyplot as plt

        x_values = [item[2] for item in options]
        y_values = [item[1] for item in options]
        # Plotting the values with updated axis labels and chart title
        plt.figure(figsize=(10, 6))
        plt.plot(x_values, y_values, marker="o")
        # Adding labels for each point
        for i, txt in enumerate(x_values):
            plt.annotate(
                f"{txt:.4f}",
                (txt, y_values[i]),
                textcoords="offset points",
                xytext=(0, 10),
                ha="center",
            )
        plt.xlabel("Memory Budget")
        plt.ylabel("Runtime of Recomputed Components")
        plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime")
        plt.grid(True)
        fig = plt.gcf()
        plt.show()
        fig_dir = os.getcwd()
        if config.memory_budget_pareto_dir is not None:
            fig_dir = config.memory_budget_pareto_dir
            os.makedirs(fig_dir, exist_ok=True)
        # Per-rank file names so distributed runs don't overwrite each other.
        rank_suffix = ""
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            rank_suffix = f"_rank_{torch.distributed.get_rank()}"
        fig_name = os.path.join(
            fig_dir, f"memory_budget_pareto{rank_suffix}_{get_aot_graph_name()}.svg"
        )
        fig.savefig(fig_name)
        log.warning("Generated Pareto frontier curve at %s", fig_name)

    # todo(chilli): Estimated doesn't align exactly with actual - actual is
    # usually less memory than estimated. i'm guessing (actually quite
    # unsure about this) that's because estimated is just only including
    # tensors we actually banned from recompute, but there may be other
    # tensors that we choose to save.
    return get_saved_values_knapsack(
        memory_budget=memory_budget, node_info=node_info, joint_graph=joint_graph
    )[0]
  2831. def _sync_decision_cross_ranks(
  2832. joint_graph: torch.fx.Graph, saved_values: list[torch.fx.Node]
  2833. ) -> list[torch.fx.Node]:
  2834. # use the same policy across different GPUs
  2835. from torch._subclasses.fake_tensor import unset_fake_temporarily
  2836. def has_collectives(joint_graph: torch.fx.Graph) -> bool:
  2837. for node in joint_graph.nodes:
  2838. if isinstance(
  2839. node.target, torch._ops.OpOverload
  2840. ) and node.target.namespace in {"_c10d_functional", "c10d_functional"}:
  2841. return True
  2842. return False
  2843. def has_same_nodes(joint_graph: torch.fx.Graph) -> bool:
  2844. # proxy to check if the graph is the same across different GPUs.
  2845. # We only consider the name and order of nodes. A more robust way
  2846. # would be to check the hash of the whole graph (disregarding input shapes),
  2847. # this is a reasonable first-order approximation.
  2848. node_str = "/".join(x.name for x in joint_graph.nodes)
  2849. inputs = hashlib.sha256(node_str.encode("utf-8")).hexdigest()
  2850. all_inputs = [None for _ in range(torch.distributed.get_world_size())]
  2851. with no_dispatch(), unset_fake_temporarily():
  2852. # TODO: maybe use a different process group?
  2853. torch.distributed.all_gather_object(all_inputs, inputs)
  2854. return all(all_inputs[0] == x for x in all_inputs)
  2855. if (
  2856. torch.distributed.is_available()
  2857. and torch.distributed.is_initialized()
  2858. and torch.distributed.get_world_size() > 1
  2859. and has_collectives(joint_graph)
  2860. and has_same_nodes(joint_graph)
  2861. ):
  2862. with no_dispatch(), unset_fake_temporarily():
  2863. objects = [[x.name for x in saved_values]]
  2864. saved_ops_names_all_ranks: list[list[str]] = [
  2865. [] for _ in range(torch.distributed.get_world_size())
  2866. ]
  2867. torch.distributed.all_gather_object(saved_ops_names_all_ranks, objects[0])
  2868. name_to_node = get_name_to_node(joint_graph)
  2869. saved_sizes: list[int] = []
  2870. saved_ops_with_sizes: dict[str, int] = {}
  2871. for idx, saved_ops_names in enumerate(saved_ops_names_all_ranks):
  2872. saved_nodes = [name_to_node[op_name] for op_name in saved_ops_names]
  2873. saved_size = 0
  2874. for node in saved_nodes:
  2875. size_of_node = _size_of(node)
  2876. saved_size += size_of_node
  2877. if idx == torch.distributed.get_rank():
  2878. saved_ops_with_sizes[node.name] = size_of_node
  2879. saved_ops_with_sizes["total size"] = saved_size
  2880. saved_sizes.append(saved_size)
  2881. saved_sizes_tensor = torch.tensor(
  2882. saved_sizes,
  2883. device=torch.distributed.distributed_c10d._get_object_coll_device(),
  2884. )
  2885. torch.distributed.all_reduce(
  2886. saved_sizes_tensor, op=torch.distributed.distributed_c10d.ReduceOp.MAX
  2887. )
  2888. picked_rank_idx = int(torch.argmin(saved_sizes_tensor).item())
  2889. sync_decision_cross_ranks_str = f"picked_rank_idx={picked_rank_idx}, saved_nodes of current rank={saved_ops_with_sizes}"
  2890. trace_structured(
  2891. "artifact",
  2892. metadata_fn=lambda: {
  2893. "name": "aot_joint_graph_sync_decision_cross_ranks",
  2894. "encoding": "string",
  2895. },
  2896. payload_fn=lambda: sync_decision_cross_ranks_str,
  2897. )
  2898. saved_values = [
  2899. name_to_node[n] for n in saved_ops_names_all_ranks[picked_rank_idx]
  2900. ]
  2901. return saved_values
  2902. def thread_graphsafe_rng_from_hops(
  2903. module: fx.GraphModule, is_backward: bool
  2904. ) -> fx.GraphModule:
  2905. """
  2906. Graph-safe RNG lets torch.compile use CUDA Graphs for graphs with RNG ops.
  2907. For graphs without HOPs, the partitioner adds placeholder nodes
  2908. fwd_rng_state_* and bw_rng_state_* to the forward and backward graphs. At
  2909. runtime, the AOTDispatcher retrieves these RNG states and passes them to the
  2910. compiled graphs.
  2911. This works well for no-HOP graphs. With HOPs, the partitioner runs
  2912. recursively: it first partitions the HOP (producing forward/backward HOP
  2913. subgraphs) and then stitches them back into the outer joint graph. For HOPs
  2914. that contain RNG ops, the outer joint graph now includes HOP subgraph
  2915. modules with extra RNG placeholders. We must thread these placeholders
  2916. through the outer module partitioned forward and backward graphs—this
  2917. function does exactly that. It collects the RNG placeholder nodes from the
  2918. HOPs and creates corresponding placeholders in the outer forward and
  2919. backward graphs.
  2920. There is a catch: for a short period, the joint graph is in a “bad” state.
  2921. The HOP subgraphs expect additional inputs (because of the new
  2922. placeholders), but the outer graph call sites don't yet provide them. We
  2923. can't fix this in the joint graph because the joint graph's input signature
  2924. is fixed (primals, tangents). As a compromise, we keep the joint graph in
  2925. somewhat of a bad state for some time and, once the outer forward and
  2926. backward graphs are partitioned, insert the corresponding RNG placeholders
  2927. and wire up the calls.
  2928. """
  2929. rng_count = 0
  2930. rng_string = "bwd_rng_state" if is_backward else "fwd_rng_state"
  2931. last_input = next(reversed(module.graph.find_nodes(op="placeholder")))
  2932. for hop_node in module.graph.find_nodes(
  2933. op="call_function", target=torch.ops.higher_order.invoke_subgraph
  2934. ):
  2935. subgraph = getattr(module, hop_node.args[0].target)
  2936. if isinstance(subgraph, fx.GraphModule):
  2937. new_rng_inputs: list[fx.Node] = []
  2938. for placeholder_node in subgraph.graph.find_nodes(op="placeholder"):
  2939. if rng_string in placeholder_node.name:
  2940. # Found a rng state placeholder in the hop graph, lets add
  2941. # the corresponding node in the outer graph
  2942. with module.graph.inserting_after(last_input):
  2943. rng_state = module.graph.placeholder(
  2944. f"{rng_string}_{rng_count}"
  2945. )
  2946. rng_count += 1
  2947. rng_state.meta["val"] = placeholder_node.meta["val"]
  2948. last_input = rng_state
  2949. new_rng_inputs.append(rng_state)
  2950. if new_rng_inputs:
  2951. # Pass on the new args that include the new_rng_inputs
  2952. with module.graph.inserting_after(hop_node):
  2953. new_hop_node_with_fixed_args = module.graph.create_node(
  2954. "call_function",
  2955. torch.ops.higher_order.invoke_subgraph,
  2956. (*hop_node.args, *new_rng_inputs), # type: ignore[arg-type]
  2957. {},
  2958. )
  2959. hop_node.replace_all_uses_with(
  2960. new_hop_node_with_fixed_args, propagate_meta=True
  2961. )
  2962. # Setup the eager_input_vals
  2963. eager_vals = hop_node.meta.get("eager_input_vals")
  2964. if eager_vals:
  2965. eager_args, eager_kwargs = eager_vals
  2966. new_eager_args = (
  2967. *eager_args,
  2968. *[inp.meta["val"] for inp in new_rng_inputs],
  2969. )
  2970. new_hop_node_with_fixed_args.meta["eager_input_vals"] = (
  2971. new_eager_args,
  2972. eager_kwargs,
  2973. )
  2974. module.graph.erase_node(hop_node)
  2975. return module
  2976. def classify_nodes(
  2977. joint_module: fx.GraphModule,
  2978. static_lifetime_input_indices: list[int],
  2979. num_fwd_outputs: int,
  2980. ) -> NodeInfo:
  2981. name_to_node = get_name_to_node(joint_module.graph)
  2982. required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
  2983. for node in joint_module.graph.nodes:
  2984. if node.op == "placeholder" and "tangents" in node.target:
  2985. required_bw_nodes.add(node)
  2986. elif _must_be_in_backward(node):
  2987. required_bw_nodes.add(node)
  2988. if node in required_bw_nodes:
  2989. required_bw_nodes.update(node.users)
  2990. primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
  2991. fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
  2992. inputs = primal_inputs + fwd_seed_offset_inputs
  2993. fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
  2994. _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
  2995. )
  2996. # Note: [tangents_closure vs required_bw_nodes]
  2997. #
  2998. # required_bw_nodes is used to determine which nodes need edges to
  2999. # the sink. It is important to also track tangents closure because
  3000. # that determines whether you can save that tensor, i.e., whether you
  3001. # want to connect x_in or x_out to the sink.
  3002. tangents_closure = required_bw_nodes.copy()
  3003. required_bw_nodes.update(
  3004. o for o in bwd_outputs if o is not None and o.op != "output"
  3005. )
  3006. forward_only_graph = _extract_graph_with_inputs_outputs(
  3007. joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
  3008. )
  3009. required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
  3010. name_to_node[node.name]
  3011. for node in forward_only_graph.nodes
  3012. if node.op != "output"
  3013. )
  3014. unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
  3015. node
  3016. for node in joint_module.graph.nodes
  3017. if node not in required_fw_nodes and node not in required_bw_nodes
  3018. )
  3019. static_lifetime_input_nodes = OrderedSet(
  3020. p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
  3021. )
  3022. fw_cnt = 0
  3023. fw_order = {}
  3024. for node in joint_module.graph.nodes:
  3025. if node in required_fw_nodes:
  3026. fw_order[node] = fw_cnt
  3027. fw_cnt += 1
  3028. return NodeInfo(
  3029. inputs,
  3030. required_fw_nodes,
  3031. required_bw_nodes,
  3032. tangents_closure,
  3033. unclaimed_nodes,
  3034. fw_order,
  3035. static_lifetime_input_nodes,
  3036. )
def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule,
    _joint_inputs: Any,
    compiler: str = "inductor",
    *,
    num_fwd_outputs: int,
    static_lifetime_input_indices: list[int] | None = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graph, we copy the joint graph, manually set the
    outputs to just original forward or backward outputs. And then we run the
    resulting graphs through dead code elimination.

    The set of values saved for backward is chosen by
    ``choose_saved_values_set`` using ``config.activation_memory_budget``,
    unless a node carries a float ``memory_budget`` in its meta, in which case
    the first such value (in graph order) is used instead.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing. It is mutated in place
            (dead-code elimination, optional CSE, recompute-tag cleanup).
        _joint_inputs: The inputs to the joint graph. Only forwarded to
            ``default_partition`` when the graph has no required backward nodes.
        compiler: Not referenced by this implementation; presumably kept for
            signature compatibility with other partitioners.
        num_fwd_outputs: The number of outputs from the forward graph.
        static_lifetime_input_indices: Indices, within the primal inputs, of
            inputs treated as static-lifetime. ``None`` means no such inputs.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """
    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()
    fx_g = joint_module.graph
    # add the CSE pass
    if config.cse:
        cse_graph = fx_graph_cse(fx_g)
        joint_module.graph = cse_graph
    joint_graph = joint_module.graph
    # Both checks run before cleanup_recompute_tags below.
    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        joint_module = cleanup_recompute_tags(joint_module, is_default_partition=False)
    if not config.unsafe_allow_optimization_of_collectives:
        force_save_collectives(joint_module)
    force_save_effectful_ops(joint_module)
    force_save_bw_mutation_src(joint_module)
    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(
        joint_module, static_lifetime_input_indices, num_fwd_outputs
    )
    # networkx blows up on graphs with no required backward nodes
    # Since there's nothing to partition anyway, and the default partitioner can "handle"
    # this case, send our graph over to the default partitioner.
    if len(node_info.required_bw_nodes) == 0:
        return default_partition(
            joint_module,
            _joint_inputs,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
            static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
        )
    # Annotate every node with dist_from_bw: 0 for nodes that are not required
    # forward nodes, otherwise 1 + the minimum over its users (int(1e9) when no
    # user is closer). Walking in reverse visits users before their producers.
    # choose_saved_values_set reads this attribute.
    for node in reversed(joint_module.graph.nodes):
        if node.op == "output":
            node.dist_from_bw = int(1e9)
        elif not node_info.is_required_fw(node):
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)
    # A float "memory_budget" in node meta overrides the config-wide budget.
    memory_budget = config.activation_memory_budget
    for node in joint_graph.nodes:
        if isinstance(node.meta.get("memory_budget", None), float):
            memory_budget = node.meta["memory_budget"]
            break
    saved_values = choose_saved_values_set(
        joint_graph,
        node_info,
        memory_budget=memory_budget,
    )
    # pyrefly: ignore [unbound-name]
    if config._sync_decision_cross_ranks:
        saved_values = _sync_decision_cross_ranks(joint_graph, saved_values)
    # save_for_backward on tensors and stashes symints in autograd .ctx
    saved_sym_nodes = list(filter(is_sym_node, saved_values))
    saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))
    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
    fw_module, bw_module = _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        # pyrefly: ignore [bad-argument-type]
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
        static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
    )
    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
            fw_module, bw_module = functionalize_rng_ops(
                joint_module, fw_module, bw_module, len(saved_sym_nodes)
            )
    bw_module = reordering_to_mimic_autograd_engine(bw_module)
    # pyrefly: ignore [unbound-name]
    if config.enable_activation_offloading:
        from ._activation_offloading.activation_offloading import (
            enable_activation_offloading,
        )

        enable_activation_offloading(
            fw_module,
            bw_module,
            num_fwd_outputs,
            node_info.static_lifetime_input_nodes,
        )
    # raise all getitem ops to as early as possible
    # this is helpful for memory, especially in the case of aot_eager backend
    fw_module = raise_getitems(fw_module)
    bw_module = raise_getitems(bw_module)
    # Wire RNG-state placeholders of invoke_subgraph HOPs into the outer graphs.
    fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
    bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)
    if AOT_PARTITIONER_DEBUG:
        # Calculate sorted sizes of saved values
        sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])
        # Log total theoretical activations stored
        total_activations_size_gb = sum(_size_of(i) for i in saved_values) / 1e9
        log.info("Theoretical Activations Stored: %.2f GB", total_activations_size_gb)
        # Log theoretical per activation storage sizes
        log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
        # A call_function name present in both graphs counts as rematerialized.
        fw_module_nodes = OrderedSet(
            node.name for node in fw_module.graph.nodes if node.op == "call_function"
        )
        bw_module_nodes = OrderedSet(
            node.name for node in bw_module.graph.nodes if node.op == "call_function"
        )
        remat_nodes = fw_module_nodes & bw_module_nodes
        counts: dict[str, int] = defaultdict(int)
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
                counts[str(node.target._overloadpacket)] += 1
        log.info(
            "# remat/fw/bw: %d/%d/%d",
            len(remat_nodes),
            len(fw_module_nodes),
            len(bw_module_nodes),
        )
        rematerialized_ops = sorted(
            counts.items(), key=operator.itemgetter(1), reverse=True
        )
        log.info("Count of Ops Rematerialized: %s", rematerialized_ops)
    return fw_module, bw_module
  3185. def draw_graph(
  3186. traced: torch.fx.GraphModule,
  3187. fname: str,
  3188. figname: str = "fx_graph",
  3189. clear_meta: bool = True,
  3190. prog: str | list[str] | None = None,
  3191. parse_stack_trace: bool = False,
  3192. dot_graph_shape: str | None = None,
  3193. ) -> None:
  3194. if clear_meta:
  3195. new_graph = copy.deepcopy(traced.graph)
  3196. traced = fx.GraphModule(traced, new_graph)
  3197. for node in traced.graph.nodes:
  3198. node.meta = {} # pyrefly: ignore[implicit-any]
  3199. base, ext = os.path.splitext(fname)
  3200. if not ext:
  3201. ext = "." + config.torch_compile_graph_format
  3202. log.info("Writing FX graph to file: %s%s", base, ext)
  3203. g = graph_drawer.FxGraphDrawer(
  3204. traced,
  3205. figname,
  3206. parse_stack_trace=parse_stack_trace,
  3207. dot_graph_shape=dot_graph_shape,
  3208. )
  3209. x = g.get_main_dot_graph()
  3210. write_method = getattr(x, "write_" + ext.lstrip("."))
  3211. fname = f"{base}{ext}"
  3212. if prog is None:
  3213. write_method(fname)
  3214. else:
  3215. write_method(fname, prog=prog)