output_graph.py 176 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175
  1. """
  2. Core graph building functionality for PyTorch's Dynamo system. This module contains
  3. the essential components for constructing and managing FX graphs during compilation:
  4. - OutputGraph: Manages the overall graph construction and compilation process. It owns
  5. a SubgraphTracer and handles graph compilation, execution, and state management.
  6. OutputGraph also manages features like graph deduplication, symbolic shape handling,
  7. and tracking of side effects.
  8. - SubgraphTracer: Handles the actual FX graph construction by tracing Python code.
  9. It supports advanced features like higher-order operators through nested tracers,
  10. lifting of free variables, and handling of symbolic shapes.
  11. The module supports key Dynamo features including:
  12. - Higher-order operators through nested SubgraphTracers
  13. - Graph deduplication for optimization
  14. - Symbolic shape handling and propagation
  15. - Side effect tracking and management
  16. - Guard insertion and management
  17. """
  18. import collections
  19. import contextlib
  20. import copy
  21. import dataclasses
  22. import functools
  23. import inspect
  24. import itertools
  25. import logging
  26. import operator
  27. import re
  28. import sys
  29. import time
  30. import traceback
  31. import warnings
  32. import weakref
  33. from collections.abc import Callable, Generator, Sequence
  34. from dataclasses import dataclass, field as dc_field
  35. from types import CodeType
  36. from typing import Any, cast, Optional, TYPE_CHECKING, Union
  37. from typing_extensions import ParamSpec, TypeVar
  38. import sympy
  39. import torch._guards
  40. import torch._logging
  41. import torch.distributed as dist
  42. import torch.nn
  43. import torch.utils._pytree as pytree
  44. from torch import fx, Tensor
  45. from torch._C._dynamo import guards
  46. from torch._dynamo.exc import ShortenTraceback, TensorifyScalarRestartAnalysis
  47. from torch._guards import (
  48. CompileContext,
  49. CompileId,
  50. GlobalContextCheckpointState,
  51. Source,
  52. tracing,
  53. TracingContext,
  54. )
  55. from torch._library.fake_class_registry import FakeScriptObject
  56. from torch._library.opaque_object import is_opaque_type
  57. from torch._subclasses.fake_tensor import FakeTensor
  58. from torch._utils_internal import signpost_event
  59. from torch.export.dynamic_shapes import _ConstraintTarget
  60. from torch.fx._lazy_graph_module import _make_graph_module # type: ignore[attr-defined]
  61. from torch.fx.experimental._backward_state import BackwardState
  62. from torch.fx.experimental.symbolic_shapes import (
  63. free_symbols,
  64. guard_scalar,
  65. is_symbolic,
  66. ShapeEnv,
  67. Specialization,
  68. uninteresting_files,
  69. )
  70. from torch.fx.node import Target
  71. from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
  72. from torch.utils._ordered_set import OrderedSet
  73. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  74. from . import config, exc, logging as torchdynamo_logging, variables
  75. from .backends.registry import CompiledFn, CompilerFn
  76. from .bytecode_transformation import (
  77. create_binary_slice,
  78. create_binary_subscr,
  79. create_build_tuple,
  80. create_call_function,
  81. create_dup_top,
  82. create_instruction,
  83. create_load_const,
  84. create_rot_n,
  85. create_swap,
  86. Instruction,
  87. unique_id,
  88. )
  89. from .code_context import code_context
  90. from .codegen import PyCodegen
  91. from .current_scope_id import enter_new_scope
  92. from .device_interface import get_interface_for_device
  93. from .exc import (
  94. BackendCompilerFailed,
  95. exceptions_allowed_to_be_fallback,
  96. SkipFrame,
  97. unimplemented,
  98. unimplemented_with_warning,
  99. )
  100. from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
  101. from .graph_deduplication import apply_graph_deduplication
  102. from .graph_id_filter import (
  103. get_backend_override_for_compile_id,
  104. get_inductor_config_override_for_compile_id,
  105. )
  106. from .graph_region_tracker import GraphRegionTracker
  107. from .guards import GuardBuilder, install_guard
  108. from .mutation_guard import is_dynamic_nn_module
  109. from .side_effects import AttributeMutationExisting, SideEffects, ValueMutationExisting
  110. from .source import (
  111. _get_source_debug_name,
  112. AttrSource,
  113. BackwardStateSource,
  114. ConstantSource,
  115. GetItemSource,
  116. GlobalStateSource,
  117. is_constant_source,
  118. is_from_local_source,
  119. LocalSource,
  120. NumpyTensorSource,
  121. ParamBufferSource,
  122. ShapeEnvSource,
  123. SyntheticLocalSource,
  124. TensorProperty,
  125. TensorPropertySource,
  126. )
  127. from .utils import (
  128. _extract_tensor_dict,
  129. checkpoint_params,
  130. CleanupHook,
  131. clone_inputs,
  132. compilation_time_metrics,
  133. count_calls,
  134. counters,
  135. dynamo_timed,
  136. get_chromium_event_logger,
  137. get_instruction_source_311,
  138. get_locals_to_steal,
  139. get_static_address_type,
  140. get_unique_name_wrt,
  141. graph_break_reasons,
  142. increment_op_count,
  143. istype,
  144. lazy_format_graph_code,
  145. LazyString,
  146. nn_module_proxy,
  147. same,
  148. set_example_value,
  149. )
  150. from .variables.builder import (
  151. BackwardStateGraphArg,
  152. GraphArg,
  153. TrackedFake,
  154. wrap_fx_proxy,
  155. )
  156. from .variables.ctx_manager import ContextWrappingVariable
  157. from .variables.functions import ClosureConversionError, VariableTracker
  158. from .variables.lists import BaseListVariable
  159. from .variables.misc import NullVariable
  160. from .variables.nn_module import NNModuleVariable
  161. from .variables.tensor import (
  162. NumpyNdarrayVariable,
  163. SymNodeVariable,
  164. UnspecializedPythonVariable,
  165. )
  166. from .variables.torch_function import TensorWithTFOverrideVariable
  167. from .variables.user_defined import UserDefinedDictVariable
  168. if TYPE_CHECKING:
  169. from torch._dynamo.dynamo_profiler import DynamoProfilerState
  170. from torch._dynamo.package import CompilePackage
  171. from torch._dynamo.symbolic_convert import InstructionTranslatorBase
  172. from torch.multiprocessing.reductions import StorageWeakRef
  173. log = logging.getLogger(__name__)
  174. graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
  175. graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
  176. graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
  177. trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
  178. RootGuardManager = guards.RootGuardManager
  179. # Capture fn pointer at import time
  180. # This is to guard against trying to mark the iterated tensors
  181. # as static in case user overrides fn ptr
  182. og_module_named_buffers_fn_ptr = torch.nn.Module.named_buffers
  183. og_module_named_parameters_fn_ptr = torch.nn.Module.named_parameters
  184. def _wrap_with_inductor_config(
  185. compiler_fn: Any, config_patches: dict[str, Any]
  186. ) -> Callable[..., Any]:
  187. """
  188. Wrap a compiler function to apply inductor config patches during compilation.
  189. """
  190. from torch._inductor import config as inductor_config
  191. def wrapped(gm: Any, example_inputs: Any) -> Any:
  192. with inductor_config.patch(config_patches):
  193. return compiler_fn(gm, example_inputs)
  194. # Preserve function metadata for logging
  195. wrapped.__name__ = getattr(compiler_fn, "__name__", "<wrapped>")
  196. wrapped.__wrapped__ = compiler_fn # type: ignore[attr-defined]
  197. return wrapped
  198. @dataclass(frozen=True)
  199. class AliasingInfo:
  200. has_aliasing: bool
  201. msg: str
  202. @dataclass(frozen=True)
  203. class MutationInfo:
  204. has_mutation: bool
  205. msg: str
  206. def collect_reachable_grad_fns(
  207. tensors_with_sources: list[tuple[torch.Tensor, str | None]],
  208. stop_at: set[torch.autograd.graph.Node] | None = None,
  209. ) -> set[torch.autograd.graph.Node]:
  210. """Collect all grad_fns reachable from tensors' autograd graphs.
  211. Performs a DFS traversal and collects all visited grad_fns.
  212. Optionally stops traversal nodes in stop_at set. This signals the
  213. autograd.grad boundary.
  214. Args:
  215. tensors_with_sources: List of (tensor, source_name) tuples to start search from.
  216. stop_at: Optional set of grad_fns where traversal should stop (excluded from result).
  217. Returns:
  218. Set of all reachable grad_fns.
  219. """
  220. if stop_at is None:
  221. stop_at = set()
  222. visited: set[torch.autograd.graph.Node] = set()
  223. stack: list[torch.autograd.graph.Node] = []
  224. for tensor, _ in tensors_with_sources:
  225. if isinstance(tensor, torch.Tensor):
  226. grad_fn = tensor.grad_fn
  227. if grad_fn is not None:
  228. stack.append(grad_fn)
  229. while stack:
  230. node = stack.pop()
  231. if node in visited:
  232. continue
  233. # Stop traversal at stop_at nodes and don't include them
  234. # in consumed grad_fn list.
  235. if node in stop_at:
  236. continue
  237. visited.add(node)
  238. for next_fn, _ in node.next_functions:
  239. if next_fn is not None:
  240. stack.append(next_fn)
  241. return visited
  242. @functools.cache
  243. def _step_logger() -> Any:
  244. return torchdynamo_logging.get_step_logger(log)
  245. @dataclass
  246. class GraphCompileReason:
  247. """Stores why a given output graph was compiled; i.e. what caused the graph break."""
  248. reason: str
  249. user_stack: list[traceback.FrameSummary]
  250. # Indicates if this was a graph break reason due to graph break.
  251. graph_break: bool = True
  252. def __post_init__(self) -> None:
  253. if self.graph_break:
  254. graph_break_reasons.append(self)
  255. def _get_gen_rand_values_fn(random_calls: Any) -> Callable[[], list[Any]]:
  256. def _gen_rand_values() -> list[Any]:
  257. return [fn(*args, **kwargs) for fn, args, kwargs in random_calls]
  258. return _gen_rand_values
  259. class FakeRootModule(torch.nn.Module):
  260. """Trick the constructor of fx.GraphModule"""
  261. def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
  262. super().__init__()
  263. for k, v in nn_modules.items():
  264. setattr(self, k, v)
  265. def __repr__(self) -> str:
  266. return "FakeRootModule(...)"
  267. def add_nn_modules(self, nn_modules: dict[str, torch.nn.Module]) -> None:
  268. for k, v in nn_modules.items():
  269. setattr(self, k, v)
  270. class WrapperBackend:
  271. def __init__(self, backend: CompilerFn) -> None:
  272. self.backend: CompilerFn = backend
  273. def __call__(
  274. self, gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
  275. ) -> CompiledFn:
  276. self.restore = checkpoint_params(gm)
  277. self.gm = gm
  278. copy_gm = copy.deepcopy(self.gm)
  279. self.candidate = self.backend(copy_gm, example_inputs)
  280. if self.candidate is None or self.candidate is self.gm.forward:
  281. return self.gm.forward
  282. if not config.verify_correctness:
  283. return self.candidate
  284. # if verify_correctness=True
  285. try:
  286. correct = self.gm.forward(*clone_inputs(example_inputs))
  287. result = self.candidate(*clone_inputs(example_inputs))
  288. # TODO: replace `same` function with the one in testing
  289. if same(correct, result):
  290. return self.candidate
  291. raise RuntimeError(f"incorrect results of backend {self}")
  292. except Exception:
  293. log.exception("error in verify_correctness")
  294. raise
  295. finally:
  296. self.restore()
  297. Scope = dict[str, object]
  298. @dataclass
  299. class BytecodeTracingTimings:
  300. """Accumulated wall-clock time (ns) for major components during Dynamo
  301. bytecode tracing that are not related to variable trackers. Each field
  302. is an int accumulator that gets bumped via ``time.time_ns()`` in the
  303. corresponding hot-path wrapper. To add a new timer, add a field here
  304. and wire up the wrapper in the relevant function."""
  305. get_fake_value_ns: int = 0
  306. create_proxy_ns: int = 0
  307. wrap_to_fake_tensor_and_record_ns: int = 0
  308. variable_builder_call_ns: int = 0
  309. def report_and_reset(self) -> None:
  310. """Flush accumulated timings to the bytecode_tracing chromium event
  311. and to compilation_time_metrics, then reset all counters."""
  312. chromium_log = get_chromium_event_logger()
  313. # pyrefly: ignore [implicit-any]
  314. event_data = {}
  315. for f in dataclasses.fields(self):
  316. ns_val = getattr(self, f.name)
  317. if ns_val > 0:
  318. key = f.name.removesuffix("_ns")
  319. event_data[f"{key}_time_s"] = ns_val / 1e9
  320. compilation_time_metrics.setdefault(key, []).append(ns_val / 1e9)
  321. setattr(self, f.name, 0)
  322. if event_data:
  323. chromium_log.try_add_event_data("bytecode_tracing", **event_data)
  324. @dataclass
  325. class OutputGraphGuardsState:
  326. """
  327. A base class containing fields that are considered "persistent" when we
  328. want to save all the important state for reconstrucing guards in a different
  329. process. Normally we don't need to add states here, but we may have to when
  330. the information is needed to serialize the guards, so the fields here are
  331. supposed to be serializable as a requirement.
  332. """
  333. local_scope: Scope
  334. global_scope: Scope
  335. # This records the initial torch function mode stack for guarding
  336. torch_function_mode_stack: list[torch.overrides.TorchFunctionMode]
  337. guard_on_key_order: set[Source]
  338. # Map from graph input's `Source` to sizes / strides metadata
  339. input_source_to_sizes_strides: dict[Source, dict[str, Any]]
  340. dual_level: int
  341. functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
  342. current_device: Optional[torch.device]
  343. global_state_guard: torch._C._dynamo.guards.GlobalStateGuard
  344. _guards: torch._guards.GuardsSet
  345. _aotautograd_guards: list[torch._guards.GuardEnvExpr]
  346. # Whether or not the guards should be checked for correctness
  347. export: bool = False
  348. skip_guards_check: bool = False
  349. export_constraints: bool = False
  350. name_of_builtins_dict_key_in_fglobals: Optional[str] = None
  351. @property
  352. def shape_env(self) -> ShapeEnv:
  353. raise AssertionError(f"shape_env shouldn't be accessed from {type(self)}")
  354. @property
  355. def guards(self) -> torch._guards.GuardsSet:
  356. return self._guards
  357. @property
  358. def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]:
  359. return self._aotautograd_guards
  360. def dump_guards_state(self) -> "OutputGraphGuardsState":
  361. # Dump a serializable version of self without extras
  362. return OutputGraphGuardsState(
  363. local_scope=self.local_scope,
  364. global_scope=self.global_scope,
  365. torch_function_mode_stack=self.torch_function_mode_stack,
  366. guard_on_key_order=self.guard_on_key_order,
  367. input_source_to_sizes_strides=self.input_source_to_sizes_strides,
  368. dual_level=self.dual_level,
  369. functorch_layers=self.functorch_layers,
  370. current_device=self.current_device,
  371. global_state_guard=self.global_state_guard,
  372. name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals,
  373. export=self.export,
  374. export_constraints=self.export_constraints,
  375. _guards=self.guards,
  376. _aotautograd_guards=self.aotautograd_guards,
  377. skip_guards_check=self.skip_guards_check,
  378. )
  379. @dataclass
  380. class StackLocalsMetadata:
  381. """
  382. Stores metadata for a frame's stack and locals for the purposes of building resume functions
  383. """
  384. num_stack: int = 0 # number of stack elements, minus removed NULLs
  385. locals_names: dict[str, int] = dc_field(
  386. default_factory=dict
  387. ) # order of locals codegen'd to the stack
  388. stack_null_idxes: list[int] = dc_field(default_factory=list)
  389. locals_null_keys: list[str] = dc_field(default_factory=list)
  390. stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
  391. stack_ctx_idxes_orig: list[int] = dc_field(default_factory=list)
  392. locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)
  393. # TODO we should expand this to make it work for atribtrary in/out
  394. @dataclass
  395. class ExportMetaData:
  396. # maps graph input index to its' source which is later
  397. # used in export to map to correct user input. In its' flat form,
  398. # just looks like GetItem(base=LocalSource("foo", idx=0))
  399. graph_input_idx_to_local_source: dict[int, Source | None] = dc_field(
  400. default_factory=dict
  401. )
  402. # maps user output idx to what type of output it is. There are 3 options:
  403. # 1) graph out
  404. # 2) user input
  405. # 3) constants
  406. output_return_type: dict[int, tuple[str, Any]] = dc_field(default_factory=dict)
  407. # output spec of the traced function
  408. out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
  409. torch.utils._pytree._LEAF_SPEC
  410. )
  411. module_call_spec: dict[
  412. str,
  413. dict[str, Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec]],
  414. ] = dc_field(default_factory=dict)
  415. def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
  416. # f_globals["__builtins__"] can be a dict or a module. This is an
  417. # implementation detail -
  418. # https://docs.python.org/3/library/builtins.html.
  419. # This makes guarding on any builtin messy because the guard check_fn
  420. # has to check if the __builtins__ is a module or dict, and then access
  421. # by either using getattr or getitem respectively.
  422. # To solve this problem, we insert a new entry in f_globals which points
  423. # to the builtins __dict__ and then we guard any builtin on this dict.
  424. # To avoid any collision with the pre-existing keys, we use the
  425. # install_global to give us a unique dict key.
  426. f_builtins = global_scope["__builtins__"]
  427. if not isinstance(f_builtins, dict):
  428. f_builtins = f_builtins.__dict__
  429. return f_builtins
  430. class OutputGraphCommon(OutputGraphGuardsState):
  431. """
  432. A minimal interface for full graph capture. It is intended to be
  433. the target of any tracer that feeds into backends.
  434. Currently dynamo's OutputGraph is the only known implementation
  435. of this interface, used by (aot) precompile and (strict) export.
  436. Importantly, that implementation also contains many other fields
  437. that are using during tracing but not included in this interface
  438. because they are not used once tracing is complete.
  439. It should be safe to assume that (caching) precompile also uses
  440. this interface.
  441. In the future, we want make_fx, used by (non-strict) export, to
  442. also implement this interface.
  443. The serializable part of this interface is OutputGraphGuardsState.
  444. We do not need to serialize other parts; however it will pay to
  445. be disciplined about what those other parts are, especially since
  446. we want other tracers to be able to meaningfully implement them,
  447. and we should generally try to cut them down when possible.
  448. """
  449. def __init__(
  450. self,
  451. output_graph_guards_state: OutputGraphGuardsState,
  452. import_sources: Optional[dict[str, str]] = None,
  453. shape_env: Optional[ShapeEnv] = None,
  454. export_metadata: Optional[ExportMetaData] = None,
  455. tracked_fakes_id_to_source: Optional[dict[int, list[Source]]] = None,
  456. ) -> None:
  457. super().__init__(
  458. output_graph_guards_state.local_scope,
  459. output_graph_guards_state.global_scope,
  460. output_graph_guards_state.torch_function_mode_stack,
  461. output_graph_guards_state.guard_on_key_order,
  462. output_graph_guards_state.input_source_to_sizes_strides,
  463. output_graph_guards_state.dual_level,
  464. output_graph_guards_state.functorch_layers,
  465. output_graph_guards_state.current_device,
  466. output_graph_guards_state.global_state_guard,
  467. output_graph_guards_state._guards,
  468. output_graph_guards_state._aotautograd_guards,
  469. output_graph_guards_state.export,
  470. output_graph_guards_state.skip_guards_check,
  471. output_graph_guards_state.export_constraints,
  472. output_graph_guards_state.name_of_builtins_dict_key_in_fglobals,
  473. )
  474. self.import_sources = import_sources or {}
  475. # The following fields are currently known to be used by clients.
  476. # In particular, we need:
  477. # - shape_env, for building guards
  478. # - export_metadata, for un/flattening inputs and outputs
  479. # - tracked_fakes_id_to_source, for processing tensor dim constraints
  480. self._shape_env = shape_env or ShapeEnv() # private for inheritance
  481. self.export_metadata = export_metadata or ExportMetaData()
  482. self.tracked_fakes_id_to_source: dict[int, list[Source]] = (
  483. tracked_fakes_id_to_source or {}
  484. )
  485. @property
  486. def shape_env(self) -> ShapeEnv:
  487. return self._shape_env
  488. def bypass_package(self, reason: str = "", **kwargs: Any) -> None:
  489. # NOTE: currently there are no tests for this but it is reachable
  490. # when building guards, so technically necessary to include here.
  491. # It is unclear whether we should include packaging altogether.
  492. raise NotImplementedError
  493. class OutputGraph(OutputGraphCommon):
  494. """
  495. Wrapper class to hold outputs of InstructionTranslator. Mainly the
  496. generated fx.Graph.
  497. OutputGraph is 1:1 with a frame being processed. Each frame is associated
  498. with some root InstructionTranslator. When user code calls a function,
  499. we construct a InliningInstructionTranslator that continues to write into
  500. the root InstructionTranslator's OutputGraph.
  501. """
  502. side_effects: SideEffects
  503. def __init__(
  504. self,
  505. code_options: dict[str, Any],
  506. compiler_fn: Optional[CompilerFn],
  507. root_tx: "InstructionTranslatorBase",
  508. export: bool,
  509. export_constraints: Sequence[_ConstraintTarget],
  510. frame_state: Any,
  511. local_scope: Scope,
  512. global_scope: Scope,
  513. f_code: CodeType,
  514. torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
  515. package: Optional["CompilePackage"],
  516. one_graph: bool = False,
  517. ) -> None:
  518. OutputGraphGuardsState.__init__(
  519. self,
  520. local_scope,
  521. global_scope,
  522. torch_function_mode_stack,
  523. guard_on_key_order=set(),
  524. input_source_to_sizes_strides={},
  525. dual_level=torch.autograd.forward_ad._current_level,
  526. functorch_layers=torch._functorch.pyfunctorch.retrieve_all_functorch_interpreters(),
  527. current_device=torch.utils._device.CURRENT_DEVICE,
  528. # initial_global_state is only None during NopTest.
  529. global_state_guard=torch._dynamo.convert_frame.initial_global_state
  530. or torch._C._dynamo.guards.GlobalStateGuard(),
  531. # These are set by @property instead, just initialize them as blank
  532. _guards=torch._guards.GuardsSet(),
  533. _aotautograd_guards=[],
  534. )
  535. self.tracers = [SubgraphTracer(self, is_export=export)]
  536. # Map from graph input's `Source` to its `VariableTracker` to
  537. # de-duplicate graph inputs by source and reuse the tracker
  538. self.input_source_to_var: dict[Source, VariableTracker] = {}
  539. # List of TensorVariables that are leaf tensors created in-graph
  540. # (e.g., nn.Parameter via tracable_create_parameter). These need to be
  541. # tracked separately from input_source_to_var for backward() auto-detection.
  542. self.leaf_var_creation_order: list[VariableTracker] = []
  543. self.export = export
  544. self.export_constraints = export_constraints # type: ignore[assignment]
  545. self.frame_state = frame_state
  546. self.cleanup_hooks: list[Callable[[], Any]] = []
  547. # compile_id is an id number for the current torch.compile
  548. self.compile_id: int = next(_compile_id_counter)
  549. # Set of globals installed via install_global* APIs
  550. self.installed_globals: set[str] = set()
  551. # TODO: maybe should just pass the entire f_code in here? Not
  552. # sure...
  553. self.co_fields = {
  554. "co_name": f_code.co_name,
  555. "co_filename": f_code.co_filename,
  556. "co_firstlineno": f_code.co_firstlineno,
  557. }
  558. self.region_tracker = GraphRegionTracker()
  559. # tracked_fakes says where any tensor that was wrapped to fake came
  560. # from. It is similar to GraphArg, in that all GraphArgs will get
  561. # will get added to TrackedFakes, but TrackedFakes also contains
  562. # GraphArgs that got pruned, and things like Tensor attributes which
  563. # aren't explicit graph inputs. Used by shape guard
  564. self.tracked_fakes: list[TrackedFake] = []
  565. shape_env = ShapeEnv(
  566. # Reference Cycle!
  567. # Share a reference to the list of TrackedFake.
  568. #
  569. # ShapeEnv needs this in order to be able to reproduce the call
  570. # to produce_guards at an arbitrary time point. That is because
  571. # TrackedFake instances may have its metadata changed throughout
  572. # the program execution.
  573. tracked_fakes=self.tracked_fakes,
  574. # We want to allow capture scalar outputs and allow_dynamic_output_shape_ops when fullgraph=True
  575. allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
  576. allow_dynamic_output_shape_ops=one_graph
  577. or config.capture_dynamic_output_shape_ops,
  578. prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
  579. co_fields=self.co_fields,
  580. )
  581. # In export mode, we force the shape_env to strictly disallow any constraining
  582. # of the user marked dynamic dims
  583. import torch._functorch.config as _config
  584. with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
  585. fake_mode = torch._subclasses.FakeTensorMode(
  586. shape_env=shape_env,
  587. # TODO (tmanlaibaatar) Remove this once we always lift params and buffers
  588. allow_non_fake_inputs=bool(self.export),
  589. export=self.export,
  590. )
  591. self.tracing_context: TracingContext = TracingContext(fake_mode)
  592. self.tracing_context.traced_code.append(f_code)
  593. self.traced_code = self.tracing_context.traced_code
  594. self.dynamo_compile_id: Optional[CompileId] = (
  595. CompileContext.current_compile_id()
  596. )
  597. self.init_ambient_guards()
  598. # Map each tensor id to a list of sources. This is necessary because
  599. # tensor ids cannot be recovered from tracked fakes (in general).
  600. # We use this map to interpret (i.e., check for violations of) constraints,
  601. # specifically equality constraints, which have shared tensor ids in them.
  602. # This map should also be generally useful, e.g., for (de)serialization.
  603. self.tracked_fakes_id_to_source: dict[int, list[Source]] = (
  604. collections.defaultdict(list)
  605. )
  606. # Stores the full fqn of a param or buffer to the relevant source.
  607. self.param_name_to_source: Optional[dict[str, Source]] = {}
  608. self.side_effects = SideEffects(self)
  609. # Cached variable trackers. This makes symbolic analysis of LOAD_GLOBAL
  610. # and LOAD_ATTR for same python objects free.
  611. self.variable_tracker_cache: dict[Source, VariableTracker] = {}
  612. # Cache for inspect.signature results: function -> VariableTracker
  613. self.signature_cache: dict[Any, VariableTracker] = {}
  614. self.unique_var_id = itertools.count()
  615. self.code_options: dict[str, Any] = dict(code_options)
  616. self.output_instructions: list[Instruction] = []
  617. # used to track nodes that are added between calls of copy_graphstate
  618. # and restore_graphstate
  619. self.timestamp = 0
  620. # A list of register_finalizer_fns to apply to the output graph module
  621. self.register_finalizer_fns: list[Callable[[fx.GraphModule], None]] = []
  622. # Not checkpointed
  623. self.compiler_fn: Optional[CompilerFn] = compiler_fn
  624. self.root_tx = root_tx
  625. # Profiler state for tracking function trace timings
  626. self.profiler_state: Optional[DynamoProfilerState] = None
  627. self.package = package
  628. # Given a source, what are the user stacks of all locations that
  629. # accessed it?
  630. #
  631. # For efficiency, we only populate this:
  632. # - During export, and
  633. # - If the source could potentially lead to a spurious export input
  634. #
  635. # Feel free to populate this more frequently if other use-cases arise,
  636. # but be aware that we have to generate full stacks for each
  637. # recording!
  638. self.source_to_user_stacks: dict[Source, list[traceback.StackSummary]] = {}
  639. self._current_tx: list[InstructionTranslatorBase] = []
  640. self.cleanups: list[CleanupHook] = []
  641. self.should_exit = False
  642. self.unspec_variable_map: dict[str, UnspecializedPythonVariable] = {}
  643. # This returns false if TF Overall (both mode and subclass) is disabled OR that TF Mode stack is empty
  644. self.torch_function_mode_enabled = torch._C._is_torch_function_mode_enabled()
  645. # Tracks if the output graph has a user defined allowed function in the
  646. # graph. This is used later to determine if we should fallback to eager
  647. # for certain exceptions. THe idea is that if the user has applied
  648. # allow_in_graph, they would like to see the error instead of falling
  649. # back for backend errors.
  650. self.has_user_defined_allowed_in_graph = False
  651. # Tracks ALL grad_fn nodes that are consumed by torch.autograd.grad(outputs, inputs).
  652. # This is the set of all nodes reachable from outputs' grad_fns, excluding inputs' grad_fns
  653. # (since autograd.grad stops at inputs without consuming them).
  654. # Used to detect returning tensors connected to consumed grad_fns (would cause
  655. # "backward through graph a second time" error in aot_autograd).
  656. self.autograd_grad_consumed_grad_fns: set[torch.autograd.graph.Node] = set()
  657. # Tracks a list of called ops that were not tagged with "pt2_compliant_tag".
  658. # This information is useful for logging.
  659. self.non_compliant_ops: set[torch._ops.OpOverload] = set({})
  660. # Tracks a list of called custom ops that were tagged with "pt2_compliant_tag".
  661. # This information is useful for logging.
  662. self.compliant_custom_ops: set[torch._ops.OpOverload] = set({})
  663. # We save the global torch state here to be restored in case of graph
  664. # breaks. The relevant issue is seen here
  665. # https://github.com/pytorch/pytorch/pull/100570#issuecomment-1543427086
  666. # where inlining of a function changes the global state (because of the
  667. # presence of torch.no_grad) and there is a graph break.
  668. self.save_global_state()
  669. # Tracks the original FQNs of the constant tensors from the original graph,
  670. # i.e. buffers and parameters.
  671. self.dynamo_flat_name_to_original_fqn: dict[str, str] = {}
  672. # All calls to random() are replaced with a single call to __gen_rand_values
  673. # functions that returns a tuple of random values for each original call.
  674. # random_calls tracks calls to random() and random_values_var stores the name of
  675. # the variable that stores __gen_rand_values results.
  676. self.random_calls: list[
  677. tuple[Callable[..., object], tuple[object, ...], dict[str, object]]
  678. ] = []
  679. self.random_values_var: Any = None
  680. # Bytecode to insert right before we call the graph
  681. self.pregraph_bytecode: list[Instruction] = []
  682. # Use to pass values to backward hooks when using compiled autograd
  683. self.backward_state: dict[str, VariableTracker] = {}
  684. self.backward_state_proxy: Optional[torch.fx.Proxy] = None
  685. self.backward_state_var: Optional[str] = None
  686. # pyrefly: ignore [bad-override]
  687. self.name_of_builtins_dict_key_in_fglobals: str = (
  688. self.install_builtins_dict_in_fglobals()
  689. )
  690. self.compiler_trace_stack = contextlib.ExitStack()
  691. self.bytecode_tracing_timings = BytecodeTracingTimings()
  692. # These are the ambient, currently-global saved_tensor_hooks stashed in autograd,
  693. # that are set for the entire duration of the compiled region.
  694. # This is an invariant today because we graph break on the saved_tensor_hook
  695. # context manager inside a compiled region
  696. self.saved_tensors_hooks_subgraph_names: Optional[list[str]] = (
  697. self.maybe_install_saved_tensors_hooks_subgraphs()
  698. )
  699. # mangled alias -> module fqn name
  700. self.import_sources: dict[str, str] = {}
  701. self.export_metadata = ExportMetaData()
  702. # Set of inlined unspecialized modules names to generate the
  703. # dynamo_flat_name_to_original_fqn mapping.
  704. self.used_inlined_inbuilt_modules_names: OrderedSet[str] = OrderedSet()
  705. self.attr_source_cache: dict[tuple[Source, str], AttrSource] = {}
  706. def get_chained_attr_source(self, base: Source, path: str) -> AttrSource:
  707. parts = path.split(".")
  708. key = (base, parts[0])
  709. if key not in self.attr_source_cache:
  710. self.attr_source_cache[key] = AttrSource(base, parts[0])
  711. result = self.attr_source_cache[key]
  712. for part in parts[1:]:
  713. key = (result, part)
  714. if key not in self.attr_source_cache:
  715. self.attr_source_cache[key] = AttrSource(result, part)
  716. result = self.attr_source_cache[key]
  717. return result
  718. def get_chained_param_buffer_source(
  719. self, base: Source, path: str
  720. ) -> "ParamBufferSource":
  721. parts = path.rsplit(".", 1)
  722. if len(parts) == 1:
  723. return ParamBufferSource(base, path)
  724. intermediate_base = self.get_chained_attr_source(base, parts[0])
  725. return ParamBufferSource(intermediate_base, parts[1])
  726. def mark_bytecode_tracing_start(self) -> None:
  727. self.compiler_trace_stack.enter_context(
  728. dynamo_timed(
  729. "bytecode_tracing",
  730. log_pt2_compile_event=True,
  731. )
  732. )
  733. # Start profiler timing for the root function
  734. if config.dynamo_profiler:
  735. from torch._dynamo.dynamo_profiler import DynamoProfilerState
  736. if self.profiler_state is None:
  737. self.profiler_state = DynamoProfilerState()
  738. code = self.root_tx.f_code
  739. self.profiler_state.push(
  740. code.co_name,
  741. code.co_filename,
  742. code.co_firstlineno,
  743. time.time_ns(),
  744. )
  745. def mark_bytecode_tracing_stop(self) -> None:
  746. self.bytecode_tracing_timings.report_and_reset()
  747. self.compiler_trace_stack.close()
  748. # Record profiler timing for the root function and dump stats
  749. if config.dynamo_profiler and self.profiler_state is not None:
  750. from torch._dynamo.dynamo_profiler import FunctionTraceTiming
  751. stack_entry = self.profiler_state.pop()
  752. trace_end_ns = time.time_ns()
  753. if stack_entry is not None:
  754. cumtime_ns = trace_end_ns - stack_entry.start_time_ns
  755. tottime_ns = cumtime_ns - stack_entry.child_time_ns
  756. timing = FunctionTraceTiming(
  757. func_name=stack_entry.func_name,
  758. filename=stack_entry.filename,
  759. firstlineno=stack_entry.firstlineno,
  760. cumtime_ns=cumtime_ns,
  761. tottime_ns=tottime_ns,
  762. bytecode_count=len(self.root_tx.f_code.co_code),
  763. inline_depth=0,
  764. caller_func_name=None,
  765. caller_filename=None,
  766. caller_firstlineno=None,
  767. is_primitive_call=stack_entry.is_primitive_call,
  768. call_stack=(),
  769. )
  770. self.profiler_state.record_timing(timing)
  771. # Dump profiler stats
  772. output_file = None
  773. if isinstance(config.dynamo_profiler, str):
  774. output_file = config.dynamo_profiler
  775. self.profiler_state.dump_stats(output_file)
  776. def install_builtins_dict_in_fglobals(self) -> str:
  777. f_builtins = get_builtins_dict(self.global_scope)
  778. return self.install_global("__builtins_dict__", f_builtins)
  779. def add_backward_state_hook(
  780. self, hook: VariableTracker, prefix: str = "hook"
  781. ) -> tuple[str, torch.fx.Proxy]:
  782. name = f"{prefix}{len(self.backward_state)}"
  783. assert name not in self.backward_state
  784. self.backward_state[name] = hook
  785. return name, self.get_backward_state_proxy()
  786. def get_backward_state_proxy(self) -> torch.fx.Proxy:
  787. if self.backward_state_proxy is None:
  788. if self.export:
  789. unimplemented(
  790. gb_type="backward_state does not support export",
  791. context="",
  792. explanation="Compiled autograd doesn't work with `torch.export`.",
  793. hints=[],
  794. )
  795. example_value = BackwardState()
  796. self.backward_state_proxy = self.root_tracer.create_graph_input(
  797. "dynamo_backward_state",
  798. type(example_value),
  799. example_value,
  800. source=BackwardStateSource(),
  801. )
  802. self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg()
  803. self.backward_state_var = self.new_var()
  804. return self.backward_state_proxy
  805. # This gets its own helper function so guards DEBUG logs are more informative
  806. def init_ambient_guards(self) -> None:
  807. # Register a SHAPE_ENV guard to make sure we setup shape guards
  808. # that show up in ShapeEnv
  809. self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
  810. self.guards.add(
  811. GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
  812. )
  813. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
  814. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
  815. self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GLOBAL_STATE))
  816. self.guards.add(
  817. GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
  818. )
  819. ci = torch._C._functorch.peek_interpreter_stack()
  820. if ci is not None:
  821. self.guards.add(
  822. GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH)
  823. )
  824. if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
  825. self.guards.add(
  826. GlobalStateSource().make_guard(
  827. GuardBuilder.AUTOGRAD_SAVED_TENSORS_HOOKS
  828. )
  829. )
  830. def maybe_install_saved_tensors_hooks_subgraphs(self) -> Optional[list[str]]:
  831. if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
  832. return None
  833. get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
  834. are_inline_hooks = (
  835. torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
  836. )
  837. hooks = get_hooks()
  838. if not are_inline_hooks(hooks):
  839. return None
  840. # If GraphModule provided by user contains fx.wrap,
  841. # We can only rely on user provided cache hash in this case.
  842. # If user did not provide cache hash - then we always bypass cache.
  843. pack_gm, unpack_gm = hooks
  844. pack_subgraph_name = self.install_subgraph(
  845. "saved_tensors_hooks_pack",
  846. torch.fx.GraphModule(self.nn_modules, pack_gm.graph),
  847. )
  848. unpack_subgraph_name = self.install_subgraph(
  849. "saved_tensors_hooks_unpack",
  850. torch.fx.GraphModule(self.nn_modules, unpack_gm.graph),
  851. )
  852. assert pack_subgraph_name == "saved_tensors_hooks_pack_0"
  853. assert unpack_subgraph_name == "saved_tensors_hooks_unpack_0"
  854. return [pack_subgraph_name, unpack_subgraph_name]
  855. def synthetic_graph_input(
  856. self, fn: Callable[..., Any], args: tuple[Any, ...]
  857. ) -> VariableTracker:
  858. """
  859. call fn(*args) before the graph runs and turn the result into a fake input.
  860. """
  861. example_value = fn(*args)
  862. varname = self.new_var()
  863. cg = PyCodegen(self.root_tx)
  864. cg.add_push_null(
  865. lambda: cg.load_import_from(
  866. fn.__module__,
  867. fn.__name__,
  868. )
  869. )
  870. cg.foreach(map(variables.ConstantVariable.create, args))
  871. cg.call_function(len(args), False)
  872. cg.store(varname)
  873. self.pregraph_bytecode.extend(cg.get_instructions())
  874. source = SyntheticLocalSource(varname)
  875. result = VariableTracker.build(self.root_tx, example_value, source)
  876. # Realize the VT because we will delete the guards on it in the next line.
  877. result = result.realize()
  878. TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source(
  879. source
  880. )
  881. return result
  882. def add_cleanup_hook(self, fn: Callable[[], Any]) -> None:
  883. self.cleanup_hooks.append(fn)
  884. def call_cleanup_hooks(self) -> None:
  885. for hook in reversed(self.cleanup_hooks):
  886. hook()
  887. self.cleanup_hooks.clear()
  888. @property
  889. def root_tracer(self) -> "SubgraphTracer":
  890. return self.tracers[0]
  891. @property
  892. def current_tracer(self) -> "SubgraphTracer":
  893. return self.tracers[-1]
  894. def is_root_tracer(self) -> bool:
  895. # Helper to tell if we are inside the higher order operator tracing.
  896. return len(self.tracers) == 1
  897. @property
  898. def graph(self) -> torch.fx.Graph:
  899. return self.current_tracer.graph
  900. # TODO(rzou): can delete after we refactor speculate_subgraph to use nested GraphTracer.
  901. @graph.setter
  902. def graph(self, value: torch.fx.Graph) -> None:
  903. self.current_tracer.graph = value
  904. @property
  905. def input_name_to_proxy(self) -> dict[str, fx.Proxy]:
  906. return self.current_tracer.input_name_to_proxy
  907. @property
  908. def real_value_cache(self) -> dict[fx.Node, torch.Tensor]:
  909. return self.current_tracer.real_value_cache
  910. @property
  911. def bound_symbols(self) -> dict[sympy.Symbol, Union[torch.fx.Proxy, "LazyProxy"]]:
  912. return self.current_tracer.bound_symbols
  913. # If you are here, and you're looking for create_graph_input,
  914. # to avoid ambiguity, please call one of the following:
  915. # - self.current_tracer.create_graph_input
  916. # - self.root_tracer.create_graph_input
  917. # See NOTE [HigherOrderOperator tracing design] for more context.
  918. def create_proxy(self, *args: Any, **kwargs: Any) -> torch.fx.Proxy:
  919. return self.current_tracer.create_proxy(*args, **kwargs)
  920. def create_node(self, *args: Any, **kwargs: Any) -> torch.fx.Node:
  921. return self.current_tracer.create_node(*args, **kwargs)
  922. def remove_node(self, *args: Any, **kwargs: Any) -> None:
  923. return self.current_tracer.remove_node(*args, **kwargs)
  924. @contextlib.contextmanager
  925. def subtracer(
  926. self,
  927. source_target: Optional[Target],
  928. prior_tracer: Optional["SubgraphTracer"],
  929. description: Optional[str] = None,
  930. ) -> Generator["SubgraphTracer", None, None]:
  931. new_scope_ctx = enter_new_scope()
  932. try:
  933. if prior_tracer:
  934. # Lineage MUST stay preserved
  935. assert prior_tracer.parent is self.current_tracer
  936. new_scope_ctx.__enter__()
  937. tracer = (
  938. prior_tracer
  939. if prior_tracer
  940. else SubgraphTracer(
  941. self,
  942. parent=self.current_tracer,
  943. source_target=source_target,
  944. is_export=self.current_tracer.is_export,
  945. description=description,
  946. )
  947. )
  948. self.tracers.append(tracer)
  949. yield tracer
  950. finally:
  951. new_scope_ctx.__exit__(None, None, None)
  952. self.tracers.pop()
  953. @property
  954. def output(self) -> "OutputGraph":
  955. return self
  956. @property
  957. def fake_mode(self) -> torch._subclasses.FakeTensorMode:
  958. assert self.tracing_context.fake_mode is not None
  959. return self.tracing_context.fake_mode
  960. @property
  961. def shape_env(self) -> ShapeEnv:
  962. assert self.tracing_context.fake_mode is not None
  963. assert self.tracing_context.fake_mode.shape_env is not None
  964. return self.tracing_context.fake_mode.shape_env
  965. @property
  966. def guards(self) -> torch._guards.GuardsSet:
  967. return self.tracing_context.guards_context.dynamo_guards
  968. @property
  969. def nn_modules(self) -> dict[str, Any]:
  970. return self.tracing_context.module_context.nn_modules
  971. @property
  972. def aotautograd_guards(self) -> list[torch._guards.GuardEnvExpr]:
  973. return self.tracing_context.guards_context.aotautograd_guards
  974. def save_global_state(
  975. self, out: Optional[dict[str, tuple[Callable[..., Any], bool]]] = None
  976. ) -> None:
  977. """
  978. Saves to out if it is provided. Else saves to the tracing context's global_state.
  979. """
  980. global_state = cast(
  981. dict[str, tuple[Callable[..., Any], bool]],
  982. (
  983. out
  984. if out is not None
  985. else self.tracing_context.global_context.global_state
  986. ),
  987. )
  988. global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
  989. global_state["autocast_enabled"] = (
  990. functools.partial(torch.set_autocast_enabled, "cuda"),
  991. torch.is_autocast_enabled("cuda"),
  992. )
  993. global_state["autocast_cpu_enabled"] = (
  994. functools.partial(torch.set_autocast_enabled, "cpu"),
  995. torch.is_autocast_enabled("cpu"),
  996. )
  997. global_state["autocast_gpu_dtype"] = ( # type:ignore[assignment]
  998. functools.partial(torch.set_autocast_dtype, "cuda"),
  999. torch.get_autocast_dtype("cuda"),
  1000. )
  1001. global_state["autocast_cpu_dtype"] = ( # type:ignore[assignment]
  1002. functools.partial(torch.set_autocast_dtype, "cpu"),
  1003. torch.get_autocast_dtype("cpu"),
  1004. )
  1005. global_state["autocast_cache_enabled"] = (
  1006. torch.set_autocast_cache_enabled,
  1007. torch.is_autocast_cache_enabled(),
  1008. )
  1009. def push_tx(self, tx: "InstructionTranslatorBase") -> None:
  1010. self._current_tx.append(tx)
  1011. def pop_tx(self) -> "InstructionTranslatorBase":
  1012. return self._current_tx.pop()
  1013. @property
  1014. def current_tx(self) -> "InstructionTranslatorBase":
  1015. return self.root_tx if not self._current_tx else self._current_tx[-1]
  1016. def count_calls(self) -> int:
  1017. return count_calls(self.graph)
  1018. def is_empty_graph(self) -> bool:
  1019. return len(list(self.graph.nodes)) == 0
  1020. def has_outputs(self) -> bool:
  1021. return len([x for x in self.graph.nodes if x.op == "output"]) > 0
  1022. def get_submodule(self, keys: str) -> Union[torch.nn.Module, Any]:
  1023. assert keys
  1024. obj: Union[torch.nn.Module, dict[str, torch.nn.Module]] = self.nn_modules
  1025. for k in keys.split("."):
  1026. if isinstance(obj, dict):
  1027. obj = obj[k]
  1028. else:
  1029. obj = getattr(obj, k)
  1030. return obj
  1031. def new_var(self, name: str = "tmp") -> str:
  1032. existing = set(self.code_options["co_varnames"])
  1033. # In common case, this will be O(1)
  1034. while True:
  1035. var = f"{name}_{next(self.unique_var_id)}"
  1036. if var not in existing:
  1037. self.code_options["co_varnames"] += (var,)
  1038. return var
  1039. def update_co_names(self, name: str) -> None:
  1040. """Ensure self.code_options.co_names contains name"""
  1041. if name not in self.code_options["co_names"]:
  1042. self.code_options["co_names"] += (name,)
  1043. @staticmethod
  1044. def module_key_name(*names: Any) -> str:
  1045. # create a new unique name
  1046. name = "_".join(map(str, names))
  1047. # Strip _buffers[..]/_parameters[..]/_modules[..] names
  1048. name = re.sub(
  1049. r"\._(?:modules|parameters|buffers)\[(['\"])([^'\"\]]+)\1\]", r".\2", name
  1050. )
  1051. # Replace getattr(a, b) with a.b
  1052. name = re.sub(
  1053. r"getattr\(\s*([^,]+?)\s*,\s*(['\"])([^'\"]+)\2\s*\)", r"\1.\3", name
  1054. )
  1055. # Strip the guard lookup L/G access
  1056. name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name)
  1057. # e.g. replace abc.xyz[123].qkv with abc.xyz_123.qkv
  1058. name = re.sub(r"\[(\d+)\]", r"_\g<1>", name)
  1059. # e.g. replace abc.xyz_123.qkv with abc_xyz_123_qkv
  1060. name = re.sub(r"[^a-zA-Z0-9]", "_", name)
  1061. if not name or not name[0].isalpha():
  1062. name = "sub" + name
  1063. return name
  1064. def register_static_attr_and_return_proxy(
  1065. self, attr_prefix: str, attr_value: Any
  1066. ) -> fx.Proxy:
  1067. # Check if the module already exists, if it does, return the already
  1068. # added proxy. This is important for executorch tests.
  1069. if isinstance(attr_value, torch.nn.Module):
  1070. for name, mod in self.nn_modules.items():
  1071. if mod is attr_value:
  1072. proxy = self.create_proxy("get_attr", name, (), {})
  1073. return proxy
  1074. attr_name = get_unique_name_wrt(attr_prefix, self.nn_modules)
  1075. # TODO `nn_modules` has been historically overloaded to store a lot more
  1076. # than just nn module objects, fix that.
  1077. self.nn_modules[attr_name] = attr_value
  1078. proxy = self.create_proxy("get_attr", attr_name, (), {})
  1079. set_example_value(proxy.node, attr_value)
  1080. return proxy
  1081. def register_attr_or_module(
  1082. self,
  1083. target: Union[torch.nn.Module, torch.Tensor, Any],
  1084. *names: Any,
  1085. **options: Any,
  1086. ) -> VariableTracker:
  1087. if is_dynamic_nn_module(target, self.export):
  1088. # Instead of returning UnspecializedNNModuleVariable, call
  1089. # VariableTracker.build so that it is tracked for mutation.
  1090. return VariableTracker.build(self.current_tx, target, **options)
  1091. options = dict(options)
  1092. assert "source" in options
  1093. source = options["source"]
  1094. assert not isinstance(source, ParamBufferSource)
  1095. if isinstance(target, torch.Tensor):
  1096. tracer = self.current_tracer
  1097. if not self.is_root_tracer():
  1098. # For higher order ops, we don't want to insert the get_attr in
  1099. # innermost graph. Instead, we want to raise the params/buffers
  1100. # as inputs to the higher-order graph, and register them as
  1101. # get_attrs in the root tracer.
  1102. # Note that Dynamo will still call lift_tracked_freevar_to_input
  1103. # when these inputs are encountered for the inner graph. The
  1104. # only difference is what happens at the root tracer for
  1105. # nn.Parameters vs free inputs. The free inputs are registered
  1106. # as placeholders in the root graph, whereas the nn.Parameters
  1107. # are registered as get_attr nodes in the root graph.
  1108. tracer = self.root_tracer
  1109. def wrap_name(module_key: str) -> VariableTracker:
  1110. assert self.param_name_to_source is not None
  1111. self.param_name_to_source[module_key] = source
  1112. # Check if the attr has already been registered. This can happen
  1113. # when two different sources point to the same tensor.
  1114. assert self.root_tx is not None
  1115. if target in self.root_tx.output.side_effects:
  1116. return self.root_tx.output.side_effects[target]
  1117. if get_static_address_type(target) == "guarded" and not isinstance(
  1118. source, NumpyTensorSource
  1119. ):
  1120. install_guard(source.make_guard(GuardBuilder.ID_MATCH))
  1121. elif not is_constant_source(source):
  1122. install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH))
  1123. vt = wrap_fx_proxy(
  1124. self.root_tx,
  1125. tracer.create_proxy("get_attr", module_key, (), {}),
  1126. example_value=target,
  1127. **options,
  1128. )
  1129. # Track the object so to avoid duplicate registration in case of
  1130. # different sources pointing to the same tensor object.
  1131. vt = self.root_tx.output.side_effects.track_object_existing(target, vt)
  1132. assert "tensor_dict" not in vt.as_proxy().node.meta
  1133. # pyrefly: ignore [bad-argument-type]
  1134. vt.as_proxy().node.meta["tensor_dict"] = _extract_tensor_dict(target)
  1135. return vt
  1136. elif isinstance(target, torch.nn.Module):
  1137. assert isinstance(target, torch.nn.Module)
  1138. if source:
  1139. install_guard(source.make_guard(GuardBuilder.NN_MODULE))
  1140. def wrap_name(module_key: str) -> VariableTracker:
  1141. # pyrefly: ignore [bad-argument-type]
  1142. return NNModuleVariable(type(target), module_key, target, **options)
  1143. else:
  1144. # This is Dynamo created graph module, e.g., graph module coming
  1145. # from higher order ops. NNModuleVariable tracker can't be
  1146. # sourceless, so let's return a unspecializedNNModule variable
  1147. # tracker.
  1148. def wrap_name(module_key: str) -> VariableTracker:
  1149. # pyrefly: ignore[bad-argument-type]
  1150. return variables.UnspecializedNNModuleVariable(target, **options)
  1151. elif isinstance(target, (torch.SymInt, torch.SymFloat)):
  1152. # HACKY CODE REGION BEGIN
  1153. # WE ARE PIGGYBACKING ON EXISTING INFRA TO REGISTER ATTRS
  1154. # This ultimately gets written to self.nn_modules, which is unfortunate
  1155. # Attrs that are tenors and symints and such need to be migrated to have their
  1156. # own storage
  1157. # alas, this is like this for now
  1158. def wrap_name(module_key: str) -> VariableTracker:
  1159. return SymNodeVariable.create(
  1160. self.root_tx,
  1161. self.create_proxy("get_attr", module_key, (), {}),
  1162. sym_num=target,
  1163. **options,
  1164. )
  1165. # HACKY CODE REGION END
  1166. else:
  1167. def wrap_name(module_key: str) -> VariableTracker:
  1168. self.output.update_co_names(module_key)
  1169. self.global_scope[module_key] = target
  1170. return VariableTracker.build(
  1171. self, # type: ignore[arg-type]
  1172. target,
  1173. ConstantSource(source_name=module_key),
  1174. )
  1175. for k, v in self.nn_modules.items():
  1176. if v is target:
  1177. # it already exists
  1178. return wrap_name(k)
  1179. name = OutputGraph.module_key_name(*names)
  1180. name = get_unique_name_wrt(name, self.nn_modules, self.global_scope)
  1181. self.nn_modules[name] = target
  1182. if isinstance(target, torch.nn.Module):
  1183. def register_leaf_name(leaf_name: str) -> None:
  1184. assert self.param_name_to_source is not None
  1185. new_source = self.get_chained_param_buffer_source(source, leaf_name)
  1186. new_name = f"{name}.{leaf_name}"
  1187. self.param_name_to_source[new_name] = new_source
  1188. if isinstance(source, LocalSource):
  1189. self.dynamo_flat_name_to_original_fqn[
  1190. OutputGraph.module_key_name(new_source.name)
  1191. ] = leaf_name
  1192. # annoying, but there are cases when we do not have parameters
  1193. # see test_nn_moduledict_contains
  1194. if hasattr(target, "_parameters"):
  1195. for leaf_name, _ in target.named_parameters():
  1196. register_leaf_name(leaf_name)
  1197. if hasattr(target, "_buffers"):
  1198. for leaf_name, _ in target.named_buffers():
  1199. register_leaf_name(leaf_name)
  1200. return wrap_name(name)
  1201. def handle_aliases_for_stolen_lists(
  1202. self, tx: "InstructionTranslatorBase"
  1203. ) -> tuple[list[Instruction], dict[Source, Source]]:
  1204. # If list inputs are stolen, but still needed after the function call, create aliases to keep them alive
  1205. maybe_gm = self.local_scope.get("self")
  1206. stolen_list_names = get_locals_to_steal(maybe_gm)
  1207. if not stolen_list_names:
  1208. return [], {}
  1209. alias_insts = []
  1210. needs_alias: dict[str, list[VariableTracker]] = {}
  1211. queue = [
  1212. *tx.stack,
  1213. *tx.symbolic_locals.values(),
  1214. *self.side_effects.store_attr_mutations.keys(),
  1215. ]
  1216. while queue:
  1217. x = queue.pop()
  1218. if isinstance(x, BaseListVariable):
  1219. assert isinstance(x.items, list)
  1220. queue += x.items
  1221. continue
  1222. if not (
  1223. (
  1224. x not in self.side_effects.store_attr_mutations
  1225. or isinstance(x.mutation_type, AttributeMutationExisting)
  1226. )
  1227. and isinstance(x.source, GetItemSource)
  1228. and isinstance(x.source.base, LocalSource)
  1229. and x.source.base.local_name in stolen_list_names
  1230. ):
  1231. continue
  1232. stolen_name = x.source.base.local_name
  1233. if stolen_name not in needs_alias:
  1234. needs_alias[stolen_name] = []
  1235. needs_alias[stolen_name].append(x)
  1236. # pyrefly: ignore [implicit-any]
  1237. visited = {}
  1238. overridden_sources: dict[Source, Source] = {}
  1239. for arg in self.graphargs:
  1240. if not (
  1241. isinstance(arg._example, list)
  1242. and isinstance(arg.source, LocalSource)
  1243. and arg.source.local_name in needs_alias
  1244. ):
  1245. continue
  1246. # arg is a list that will be cleared by the compiled function
  1247. list_name = arg.source.local_name
  1248. assert list_name in self.code_options["co_varnames"]
  1249. for x in needs_alias[list_name]:
  1250. # Skip if already handled.
  1251. if x.source in overridden_sources:
  1252. continue
  1253. # A small codegen optimization because we might have different
  1254. # VariableTrackers that share the same source.
  1255. assert x.source is not None
  1256. list_idx = x.source.index # type: ignore[attr-defined]
  1257. if list_idx not in visited:
  1258. alias_name = self.new_var(
  1259. f"{list_name}_ref"
  1260. ) # self.new_var already adds unique id suffix
  1261. visited[list_idx] = alias_name
  1262. # bytecode of `alias_name = list_name[list_idx]`
  1263. alias_insts.extend(
  1264. [
  1265. create_instruction("LOAD_FAST", argval=list_name),
  1266. create_load_const(list_idx),
  1267. create_binary_subscr(),
  1268. create_instruction("STORE_FAST", argval=alias_name),
  1269. ]
  1270. )
  1271. # operate on alias, handled by suffix codegen
  1272. assert x.source is not None
  1273. old_source = x.source
  1274. overridden_sources[old_source] = LocalSource(visited[list_idx])
  1275. # NOTE: we need `overridden_sources` because (1) we want to codegen for
  1276. # these list items to use the new local source, but (2) we want to avoid
  1277. # updating `source` in place because that might break invariants in
  1278. # other parts of Dynamo like guards.
  1279. return alias_insts, overridden_sources
  1280. def _get_stack_values_to_restore(
  1281. self, tx: "InstructionTranslatorBase", stack_pops: int
  1282. ) -> tuple[list[VariableTracker], StackLocalsMetadata]:
  1283. """
  1284. Gets the stack + locals values belonging to tx that need to be restored.
  1285. Also prunes dead tx locals and realizes all VTs in the tx's stack.
  1286. NullVariables in stack/locals will NOT be restored, unless they are the top `stack_pops`
  1287. elements of the stack - it is expected that the next instruction to run will pop the top
  1288. `stack_pops` elements of the stack, so we should codegen NULLs.
  1289. Returns:
  1290. - stack_values: stack and locals values that need to be restored
  1291. - meta: locations of NULLs and ContextWrappingVariables in the stack/locals
  1292. (ignores the top `stack_pops` values on the stack)
  1293. """
  1294. tx.prune_dead_locals()
  1295. stack_values = []
  1296. meta = StackLocalsMetadata()
  1297. def ctx_exit_check(var: VariableTracker) -> None:
  1298. if type.__instancecheck__(variables.WithExitFunctionVariable, var):
  1299. raise AssertionError(
  1300. "Attempted to reconstruct WithExitFunctionVariable outside the stack"
  1301. )
  1302. # realize any unrealized tensor VTs in case they
  1303. # need to be added to self.nn_modules as attributes
  1304. for i, value in enumerate(tx.stack):
  1305. # Allow lazy constants through for values being returned (top of stack)
  1306. allow_lazy_constant = len(tx.stack) - i <= stack_pops
  1307. variables.LazyVariableTracker.realize_all(
  1308. value, allow_lazy_constant=allow_lazy_constant
  1309. )
  1310. # Do not allow non-stack WithExitFunctionVariable reconstructions
  1311. if not isinstance(value, variables.WithExitFunctionVariable):
  1312. VariableTracker.visit(ctx_exit_check, value)
  1313. # ignore top `stack_pops` values on the stack
  1314. if allow_lazy_constant:
  1315. stack_values.append(value)
  1316. continue
  1317. if isinstance(value, NullVariable):
  1318. meta.stack_null_idxes.append(i)
  1319. else:
  1320. stack_values.append(value)
  1321. if isinstance(value, ContextWrappingVariable):
  1322. target_values = (
  1323. () if value.target_values is None else tuple(value.target_values)
  1324. )
  1325. # NOTE: track index in stack after NULLs have been removed
  1326. meta.stack_ctx_args.append((len(stack_values) - 1, target_values))
  1327. meta.stack_ctx_idxes_orig.append(i)
  1328. meta.num_stack = len(stack_values)
  1329. cell_and_freevars = set(tx.cellvars() + tx.freevars())
  1330. # NB: Typically (i.e., for graph compile from RETURN_VALUE),
  1331. # symbolic_locals will be empty at this point, as prune_dead_locals
  1332. # will clear out all of symbolic_locals because RETURN_VALUE is the
  1333. # last instruction and no more locals are used. The fanciness here
  1334. # is only needed for partial graphs.
  1335. # NOTE: All cell and free variables are represented as CellVariable,
  1336. # so checks for NULLs and context managers in the case of codegen'ing resume
  1337. # functions will not be performed on them. This is expected behavior.
  1338. for k, v in tx.symbolic_locals.items():
  1339. # Do not reconstruct WithExitFunctionVariable!
  1340. VariableTracker.visit(ctx_exit_check, v)
  1341. # Note! this explicitly uses .local_name for matching
  1342. # Failure to do so will cause spurious registrations in val_to_names.
  1343. # This will in turn result in spurious variables showing up in the graph.
  1344. # This was very tricky to debug. For an example, dump the graph at call_user_compiler
  1345. # while running test_subgraphs.py
  1346. # Do not include top-frame unmodified locals here - otherwise, the compiled graph may
  1347. # erroneously include them as part of the return. We manually codegen them afterward.
  1348. if (
  1349. isinstance(v.source, LocalSource)
  1350. and v.source.local_name == k
  1351. and tx is self.root_tx
  1352. ):
  1353. continue
  1354. # Do not load cell/free vars
  1355. if k in cell_and_freevars:
  1356. continue
  1357. # Do not load variable if it is NULL.
  1358. if sys.version_info >= (3, 12):
  1359. # NOTE: do not use isinstance, since it realizes lazy VT's
  1360. # Continuation function will load the NULL for v.
  1361. if type.__instancecheck__(NullVariable, v):
  1362. meta.locals_null_keys.append(k)
  1363. continue
  1364. else:
  1365. # A variable should never be NULL in < 3.12
  1366. assert not type.__instancecheck__(NullVariable, v)
  1367. meta.locals_names[k] = len(meta.locals_names)
  1368. if isinstance(v, ContextWrappingVariable):
  1369. target_values = (
  1370. () if v.target_values is None else tuple(v.target_values)
  1371. )
  1372. meta.locals_ctx_args.append((k, target_values))
  1373. stack_values.append(v)
  1374. return stack_values, meta
  1375. def compile_subgraph(
  1376. self,
  1377. tx: "InstructionTranslatorBase",
  1378. reason: GraphCompileReason,
  1379. stack_pops: int = 0,
  1380. ) -> list[StackLocalsMetadata]:
  1381. """
  1382. Compiles the current subgraph, with inputs w.r.t. self.root_tx, and codegens:
  1383. - Call the compiled subgraph
  1384. - Apply side effects
  1385. - Codegen stack and locals
  1386. - Store the locals
  1387. Python does not allow NULL to be an arg to a function, so we do not codegen NULLs on the stack,
  1388. unless the value is one of the top `stack_pops` values on the stack (these values are expected to be
  1389. popped immediately after this generated code. The prologue of the resume function is expected to restore
  1390. any dropped NULLs.
  1391. Returns stack indices and locals keys where we dropped NULLs, and where we found inactive context manager objects.
  1392. """
  1393. assert self.root_tx is not None
  1394. if not config.nested_graph_breaks:
  1395. # expect to only compile 1 frame
  1396. assert self.root_tx is tx
  1397. # bytecode tracing has finished. Pop the context manager for dynamo_timed
  1398. self.mark_bytecode_tracing_stop()
  1399. self.compile_subgraph_reason = reason
  1400. self.should_exit = True
  1401. log.debug("COMPILING GRAPH due to %s", reason)
  1402. # prefix instructions (Python 3.11+)
  1403. prefix_insts: list[Instruction] = []
  1404. if sys.version_info >= (3, 11):
  1405. for inst in self.root_tx.prefix_insts:
  1406. if inst.opname == "COPY_FREE_VARS":
  1407. prefix_insts.append(
  1408. create_instruction(
  1409. "COPY_FREE_VARS",
  1410. arg=len(self.root_tx.code_options["co_freevars"]),
  1411. )
  1412. )
  1413. else:
  1414. prefix_insts.append(copy.copy(inst))
  1415. # stack values and restore vars for each frame are pushed in reverse order
  1416. # i.e. last element corresponds to root frame (1),
  1417. # first element corresponds to current frame (N)
  1418. all_stack_values = []
  1419. all_stack_locals_metas = []
  1420. cur_tx: Optional[InstructionTranslatorBase] = tx
  1421. while cur_tx is not None:
  1422. # this should have been checked by the caller
  1423. assert all(block.can_restore() for block in cur_tx.block_stack)
  1424. stack_values, meta = self._get_stack_values_to_restore(
  1425. cur_tx, stack_pops if cur_tx is tx else 0
  1426. )
  1427. all_stack_values.append(stack_values)
  1428. all_stack_locals_metas.append(meta)
  1429. # Exit from all context manager variables to make sure global state is restored
  1430. for block in reversed(cur_tx.block_stack):
  1431. block.exit(cur_tx, is_graph_break=reason.graph_break)
  1432. cur_tx = cur_tx.parent
  1433. # "Garbage collect the heap".
  1434. self.side_effects.prune_dead_object_new(tx)
  1435. self.add_output_instructions(prefix_insts)
  1436. assert not (self.pregraph_bytecode and self.export), (
  1437. "export does not support pregraph_bytecode"
  1438. )
  1439. self.add_output_instructions(self.pregraph_bytecode)
  1440. alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(
  1441. self.root_tx
  1442. )
  1443. self.add_output_instructions(alias_insts)
  1444. self.cleanup_graph()
  1445. # Use nn.Module "proxies" in the constructed GraphModule so that
  1446. # the resulting GM does not hold additional strong references to the original modules.
  1447. # This prevents a strong ref cycle where Dynamo created code holds on to references
  1448. # to modules that also have Dynamo code cache invalidation checks.
  1449. # When cache invalidation runs, the generated GM will be invalidated, which also deletes
  1450. # the proxies.
  1451. nn_modules_proxies = {
  1452. name: nn_module_proxy(mod) for name, mod in self.nn_modules.items()
  1453. }
  1454. root = FakeRootModule(nn_modules_proxies)
  1455. from .decorators import disable
  1456. # to handle random calls
  1457. if len(self.random_calls) > 0:
  1458. random_calls_instructions = []
  1459. self.random_values_var = self.new_var("random_values")
  1460. rand_fn = disable(
  1461. _get_gen_rand_values_fn(self.random_calls),
  1462. reason="do not trace into Dynamo rng recovery function",
  1463. )
  1464. rand_fn_name = self.install_global("__gen_rand_values", rand_fn)
  1465. codegen = PyCodegen(
  1466. self.root_tx, root, overridden_sources=overridden_sources
  1467. )
  1468. random_calls_instructions.extend(
  1469. codegen.load_function_name(rand_fn_name, True)
  1470. )
  1471. random_calls_instructions.extend(create_call_function(0, False))
  1472. random_calls_instructions.append(
  1473. codegen.create_store(self.random_values_var),
  1474. )
  1475. self.add_output_instructions(random_calls_instructions)
  1476. # Codegen stack convention before the unsupported instruction
  1477. # NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
  1478. # NOTE: stack/locals/cells must be codegen'd BEFORE the unsupported instruction, since the latter
  1479. # can arbitrarily mutate the former.
  1480. # [frame N cells, .., frame 1 cells],
  1481. # [
  1482. # frame N locals,
  1483. # frame N-1 stack + locals,
  1484. # ...,
  1485. # frame 1 stack + locals,
  1486. # ], frame N stack
  1487. # see symbolic_convert.py for
  1488. # codegen stack convention after the unsupported instruction
  1489. # NOTE: cells will be loaded into continuation functions directly by symbolic_convert
  1490. # this determines the order that values are codegen'd to the stack
  1491. stack_values_flat = [val for vals in all_stack_values for val in vals]
  1492. stored_graph_output_var = False
  1493. graph_output_var = None
  1494. # call compiled fx graph and codegen all values - stack and locals
  1495. if (
  1496. self.root_tx is tx # single frame
  1497. and stack_values_flat
  1498. and all(
  1499. not isinstance(
  1500. v,
  1501. (
  1502. UnspecializedPythonVariable,
  1503. NumpyNdarrayVariable,
  1504. TensorWithTFOverrideVariable,
  1505. ),
  1506. )
  1507. and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
  1508. for v in stack_values_flat
  1509. )
  1510. and all(x.is_tensor() for x in stack_values_flat)
  1511. and len(set(stack_values_flat)) == len(stack_values_flat)
  1512. and self.side_effects.is_empty()
  1513. and not tx.debug_locals
  1514. and not self.backward_state
  1515. and not all_stack_locals_metas[-1].stack_null_idxes
  1516. and not all_stack_locals_metas[-1].locals_null_keys
  1517. ):
  1518. # optimization to generate better code in a common case
  1519. # codegen cells
  1520. # no side effects, so no new cells created - no need to call side_effects.codegen_save_tempvars
  1521. cell_cg = PyCodegen(self.root_tx)
  1522. self.codegen_cells(tx, cell_cg)
  1523. self.add_output_instructions(
  1524. [
  1525. # load in reverse since UNPACK_SEQUENCE will reverse
  1526. *self.compile_and_call_fx_graph(
  1527. tx, list(reversed(stack_values_flat)), root
  1528. ),
  1529. *cell_cg.get_instructions(),
  1530. *create_swap(2),
  1531. create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
  1532. ]
  1533. )
  1534. # function output will be moved to the correct places below
  1535. else:
  1536. graph_output_var = self.new_var("graph_out")
  1537. # load stack values in a flat manner - we will codegen bytecode to place them correctly
  1538. # according to our convention above
  1539. pass1 = PyCodegen(
  1540. self.root_tx,
  1541. root,
  1542. graph_output_var,
  1543. overridden_sources=overridden_sources,
  1544. )
  1545. self.codegen_suffix(tx, stack_values_flat, pass1, False)
  1546. # Use `pass1.uses` to selectively cache multi-user variables into a
  1547. # temporary local source. This (a). speeds up loading VTs with long
  1548. # chained source, and (b). avoids redundantly saving single-user VT
  1549. # into a temporary local.
  1550. tempvars = {} # type: ignore[var-annotated]
  1551. for val, count in pass1.uses.items():
  1552. # If it's already a local source, no need to cache it
  1553. if count > 1 and not istype(val, (SyntheticLocalSource, LocalSource)):
  1554. # pyrefly: ignore [unsupported-operation]
  1555. tempvars[val] = None
  1556. pass2 = PyCodegen(
  1557. self.root_tx,
  1558. root,
  1559. graph_output_var,
  1560. # pyrefly: ignore [bad-argument-type]
  1561. tempvars=tempvars,
  1562. overridden_sources=overridden_sources,
  1563. )
  1564. self.codegen_suffix(tx, stack_values_flat, pass2, True)
  1565. if (
  1566. torch._dynamo.config.log_graph_in_out_metadata
  1567. and stack_values_flat
  1568. and len(stack_values_flat) == 1
  1569. ):
  1570. vt = stack_values_flat[0]
  1571. if (
  1572. isinstance(vt, torch._dynamo.variables.NamedTupleVariable)
  1573. and vt.tuple_cls
  1574. is torch._dynamo.functional_export.ExportTracerOutput
  1575. ):
  1576. flat_returns = vt.items[0]
  1577. out_spec = vt.items[1]
  1578. assert isinstance(
  1579. flat_returns, torch._dynamo.variables.ListVariable
  1580. )
  1581. vt_to_graph_out_idx: dict[VariableTracker, int] = {}
  1582. for value in pass2.graph_outputs.values():
  1583. assert isinstance(value, torch._dynamo.codegen.GraphOutputEntry)
  1584. variable: VariableTracker = value.variable
  1585. vt_to_graph_out_idx[variable] = value.index
  1586. for idx, vt in enumerate(flat_returns.items):
  1587. if vt in vt_to_graph_out_idx:
  1588. self.export_metadata.output_return_type[idx] = (
  1589. "graph_out",
  1590. vt_to_graph_out_idx[vt],
  1591. )
  1592. elif (
  1593. vt.source is not None
  1594. and (source := getattr(vt.source, "base", None)) # type: ignore[assignment]
  1595. and source.is_input
  1596. ):
  1597. self.export_metadata.output_return_type[idx] = (
  1598. "input",
  1599. vt.source,
  1600. )
  1601. elif vt.is_python_constant():
  1602. self.export_metadata.output_return_type[idx] = (
  1603. "constant",
  1604. vt.as_python_constant(),
  1605. )
  1606. else:
  1607. raise AssertionError(
  1608. f"Encountered unrecognized type {vt} at output {idx}"
  1609. )
  1610. try:
  1611. self.export_metadata.out_spec = out_spec.as_python_constant()
  1612. except ClosureConversionError as e:
  1613. unimplemented(
  1614. gb_type="nested function with non-constructible closure in output",
  1615. context=f"as_python_constant for out_spec {out_spec}",
  1616. explanation=(
  1617. "Cannot return a nested function with closure from a compiled function. "
  1618. "Dynamo failed to construct the function defined in the compiled region with closure objects."
  1619. ),
  1620. hints=[
  1621. "Define the function at module scope instead of inside another function ",
  1622. "Ensure that all closure variables are constants.",
  1623. ],
  1624. from_exc=e,
  1625. )
  1626. output = []
  1627. if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
  1628. output.extend(
  1629. self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  1630. )
  1631. if len(pass2.graph_outputs) != 0:
  1632. output.append(pass2.create_store(graph_output_var))
  1633. stored_graph_output_var = True
  1634. else:
  1635. output.append(create_instruction("POP_TOP"))
  1636. else:
  1637. # NB: Important to run compiler collective even when there is
  1638. # a graph break
  1639. self.run_compiler_collective()
  1640. self.add_output_instructions(output + pass2.get_instructions())
  1641. # store all stack and locals for each frame
  1642. # current state of the stack:
  1643. # all cells,
  1644. # *(frame N stack), *(frame N locals),
  1645. # ...,
  1646. # *(frame 1 stack), *(frame 1 locals)
  1647. self.add_output_instructions(
  1648. [
  1649. create_instruction(
  1650. "BUILD_LIST",
  1651. arg=len(stack_values_flat) - all_stack_locals_metas[0].num_stack,
  1652. ),
  1653. ]
  1654. )
  1655. # current state of the stack:
  1656. # all cells,
  1657. # *(frame N stack), [
  1658. # *(frame N locals),
  1659. # *(frame N-1 stack), *(frame N-1 locals),
  1660. # ...
  1661. # *(frame 1 stack), *(frame 1 locals),
  1662. # ]
  1663. # iterate current frame (N) to root frame (1)
  1664. # sliding window over frame stack/locals
  1665. start_idx = 0
  1666. end_idx = 0
  1667. for i, meta in enumerate(all_stack_locals_metas):
  1668. # do not pack frame N's stack into the value list
  1669. n_vals = len(meta.locals_names)
  1670. if i != 0:
  1671. n_vals += meta.num_stack
  1672. if n_vals == 0:
  1673. self.add_output_instructions(
  1674. [
  1675. create_instruction("BUILD_LIST", arg=0),
  1676. *create_swap(2),
  1677. ]
  1678. )
  1679. # [], stack_values_flat
  1680. else:
  1681. end_idx += n_vals
  1682. self.add_output_instructions(
  1683. [
  1684. create_dup_top(),
  1685. *create_binary_slice(start_idx, end_idx),
  1686. *create_swap(2),
  1687. ]
  1688. )
  1689. start_idx += n_vals
  1690. # stack_values_flat[x:y], stack_values_flat
  1691. # add root frame's unmodified locals here
  1692. if i == len(all_stack_locals_metas) - 1:
  1693. root_cg = PyCodegen(self.root_tx)
  1694. unmodified_locals_names: dict[str, int] = {}
  1695. for k, v in self.root_tx.symbolic_locals.items():
  1696. if isinstance(v.source, LocalSource) and v.source.local_name == k:
  1697. root_cg.append_output(root_cg.create_load(k))
  1698. unmodified_locals_names[k] = len(meta.locals_names) + len(
  1699. unmodified_locals_names
  1700. )
  1701. self.add_output_instructions(
  1702. root_cg.get_instructions()
  1703. + [
  1704. create_instruction(
  1705. "BUILD_LIST", arg=len(unmodified_locals_names)
  1706. ),
  1707. # arg=2 because we already swapped the locals list back
  1708. create_instruction("LIST_EXTEND", arg=2),
  1709. ]
  1710. )
  1711. meta.locals_names.update(unmodified_locals_names)
  1712. # *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat
  1713. # current state of the stack:
  1714. # all cells,
  1715. # *(frame N stack),
  1716. # frame N locals,
  1717. # frame N-1 stack, frame N-1 locals,
  1718. # ...
  1719. # frame 1 stack, frame 1 locals,
  1720. # stack_values_flat
  1721. #
  1722. self.add_output_instructions(
  1723. [
  1724. create_instruction("POP_TOP"),
  1725. create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
  1726. *create_rot_n(all_stack_locals_metas[0].num_stack + 1),
  1727. ]
  1728. )
  1729. # final state of the stack before running the unsupported bytecode:
  1730. # all cells,
  1731. # [
  1732. # [frame N locals],
  1733. # [frame N-1 stack + locals],
  1734. # ...,
  1735. # [frame 1 stack + locals],
  1736. # ], *(frame N stack)
  1737. if graph_output_var and stored_graph_output_var:
  1738. self.add_output_instructions(
  1739. [create_instruction("DELETE_FAST", argval=graph_output_var)]
  1740. )
  1741. if torch._dynamo.config.side_effect_replay_policy in ["warn", "error"]:
  1742. from torch.export._trace import _ExportModuleSpecTrackerDict
  1743. potential_side_effects = []
  1744. for var in self.side_effects._get_modified_vars():
  1745. if hasattr(var, "mutation_type"):
  1746. mut_type = var.mutation_type
  1747. # Make sure to skip codegen specific mutations
  1748. if isinstance(
  1749. mut_type, (AttributeMutationExisting, ValueMutationExisting)
  1750. ):
  1751. if isinstance(var, UserDefinedDictVariable) and isinstance(
  1752. var.value, _ExportModuleSpecTrackerDict
  1753. ):
  1754. for k, v in var.items.items():
  1755. # pyrefly: ignore [implicit-any]
  1756. specs = {}
  1757. # pyrefly: ignore[missing-attribute]
  1758. for k_spec, val in v.items.items():
  1759. specs[k_spec.vt.as_python_constant()] = (
  1760. val.as_python_constant()
  1761. )
  1762. assert ["in_spec", "out_spec"] == list(specs.keys())
  1763. self.export_metadata.module_call_spec[
  1764. # pyrefly: ignore[missing-attribute]
  1765. k.vt.as_python_constant()
  1766. ] = specs
  1767. # export uses tracepoint pass to dump submodule inp/out spec
  1768. # into global state, so we filter it here
  1769. if not (
  1770. isinstance(var, UserDefinedDictVariable)
  1771. and isinstance(var.value, _ExportModuleSpecTrackerDict)
  1772. ):
  1773. potential_side_effects.append(var)
  1774. side_effect_refs = [
  1775. _get_source_debug_name(var.source) for var in potential_side_effects
  1776. ]
  1777. if side_effect_refs:
  1778. if torch._dynamo.config.side_effect_replay_policy == "warn":
  1779. warnings.warn(
  1780. f"While compiling, we found certain side effects happened in the model.forward. "
  1781. f"Here are the list of potential sources you can double check: {side_effect_refs}"
  1782. )
  1783. else:
  1784. raise RuntimeError(
  1785. f"While compiling, we found certain side effects happened in the model.forward. "
  1786. f"Here are the list of potential sources you can double check: {side_effect_refs}"
  1787. )
  1788. return all_stack_locals_metas
  1789. def codegen_cells(self, tx: "InstructionTranslatorBase", cg: PyCodegen) -> None:
  1790. # no need to codegen if reason.graph_break is False (since we won't resume)
  1791. if self.compile_subgraph_reason.graph_break:
  1792. tx_cnt = 0
  1793. cur_tx: Optional[InstructionTranslatorBase] = tx
  1794. while cur_tx is not None:
  1795. # NOTE: we generate cells in the same order as resume_execution.py: sorted freevars + cellvars
  1796. # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
  1797. # requires that in the generated bytecode, these cells would keep
  1798. # their original local names, which we ensure via
  1799. # `CellVariable.local_name`.
  1800. freevars = tuple(sorted(cur_tx.cell_and_freevars()))
  1801. for cell in freevars:
  1802. if cur_tx is self.root_tx: # root frame
  1803. cg.append_output(cg.create_load_closure(cell))
  1804. else: # nested frame
  1805. assert cur_tx.post_prune_cell_and_freevars
  1806. cg(cur_tx.post_prune_cell_and_freevars[cell])
  1807. cg.append_output(create_build_tuple(len(freevars)))
  1808. cur_tx = cur_tx.parent
  1809. tx_cnt += 1
  1810. cg.append_output(create_instruction("BUILD_LIST", arg=tx_cnt))
  1811. else:
  1812. cg.append_output(create_instruction("BUILD_LIST", arg=0))
  1813. def codegen_suffix(
  1814. self,
  1815. tx: "InstructionTranslatorBase",
  1816. stack_values: list[VariableTracker],
  1817. cg: PyCodegen,
  1818. log_side_effects: bool,
  1819. ) -> None:
  1820. # NOTE: `codegen_save_tempvars` must run first to update `source` fields
  1821. # for variables with `AttributeMutationNew`, as they don't implement
  1822. # `reconstruct` themselves.
  1823. self.side_effects.codegen_save_tempvars(cg)
  1824. if self.backward_state:
  1825. assert not self.export
  1826. for name, val in self.backward_state.items():
  1827. cg(val)
  1828. assert self.backward_state_var is not None
  1829. cg.append_output(cg.create_load(self.backward_state_var))
  1830. cg.store_attr(name)
  1831. if config.replay_side_effects:
  1832. self.side_effects.codegen_hooks(cg)
  1833. # TODO get debug_locals working for nested graph breaks
  1834. # Return variables used for logging at the end
  1835. for debug_var, args in tx.debug_locals:
  1836. cg.add_push_null(lambda: cg(debug_var))
  1837. for arg in args:
  1838. cg(arg)
  1839. cg.extend_output(create_call_function(len(args), False))
  1840. cg.extend_output([create_instruction("POP_TOP")])
  1841. # codegen cells before we apply side effects
  1842. self.codegen_cells(tx, cg)
  1843. cg.restore_stack(stack_values, value_from_source=not tx.export)
  1844. self.side_effects.codegen_update_mutated(cg, log_side_effects)
  1845. def cleanup_graph(self) -> None:
  1846. """
  1847. Remove "creation_timestamp" from node meta
  1848. Remove this pattern from the graph:
  1849. torch._C._set_grad_enabled(False)
  1850. torch._C._set_grad_enabled(True)
  1851. """
  1852. assert self.should_exit
  1853. nodes = list(self.graph.nodes)
  1854. for node in nodes:
  1855. node.meta.pop("creation_timestamp", None)
  1856. grad_enabled = torch.is_grad_enabled()
  1857. for node1, node2 in itertools.pairwise(nodes):
  1858. if (
  1859. node1.target is torch._C._set_grad_enabled
  1860. and tuple(node1.args) == (not grad_enabled,)
  1861. and not node1._erased
  1862. ):
  1863. grad_enabled = node1.args[0]
  1864. if (
  1865. node2.target is torch._C._set_grad_enabled
  1866. and tuple(node2.args) == (not grad_enabled,)
  1867. and not node2._erased
  1868. ):
  1869. grad_enabled = node2.args[0]
  1870. self.graph.erase_node(node1)
  1871. self.graph.erase_node(node2)
  1872. def bypass_package(self, reason: str = "", **kwargs: Any) -> None:
  1873. """
  1874. Do not save this output graph to the CompilePackage
  1875. """
  1876. if not self.package:
  1877. return
  1878. if torch._dynamo.config.strict_precompile:
  1879. raise torch._dynamo.exc.PackageError(
  1880. "Detected a package bypass: %s", reason
  1881. )
  1882. log.warning("Detected a package bypass: %s", reason)
  1883. torch._logging.trace_structured(
  1884. "artifact",
  1885. metadata_fn=lambda: {
  1886. "name": "precompile_cache_bypass",
  1887. "encoding": "json",
  1888. },
  1889. payload_fn=lambda: {
  1890. # precede with underscore so it always appear first in JSON in tlparse
  1891. "_reason": reason,
  1892. **kwargs,
  1893. },
  1894. )
  1895. self.package.bypass_current_entry()
  1896. self.package = None
  1897. def get_graph_sizes_structured(self) -> dict[str, list[Union[int, str]]]:
  1898. ret: dict[str, list[Union[int, str]]] = {}
  1899. for node in self.graph.nodes:
  1900. example_value = node.meta.get("example_value", None)
  1901. if isinstance(example_value, torch._subclasses.FakeTensor):
  1902. size = example_value.size()
  1903. ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size]
  1904. return ret
  1905. def get_graph_sizes(self, name: str) -> str:
  1906. graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n"
  1907. graph_sizes_str += f"===== {name} =====\n"
  1908. for node in self.graph.nodes:
  1909. example_value = node.meta.get("example_value", None)
  1910. if isinstance(example_value, torch._subclasses.FakeTensor):
  1911. size = example_value.size()
  1912. graph_sizes_str += f"{node.name}: {tuple(size)}\n"
  1913. concrete_size = []
  1914. has_symint = False
  1915. for sz in size:
  1916. if isinstance(sz, int):
  1917. concrete_size.append(sz)
  1918. elif isinstance(sz, torch.SymInt):
  1919. has_symint = True
  1920. concrete_size.append(sz.node.hint)
  1921. else:
  1922. break
  1923. else:
  1924. if has_symint:
  1925. graph_sizes_str += (
  1926. f"{node.name} (concrete): {tuple(concrete_size)}\n"
  1927. )
  1928. return graph_sizes_str
  1929. @contextlib.contextmanager
  1930. def restore_global_state(self) -> Any:
  1931. """
  1932. Momentarily restores the global state to what it was prior to tracing the current output
  1933. """
  1934. prior_global_state = self.tracing_context.global_context.copy_graphstate()
  1935. current_global_state: dict[str, tuple[Any, bool]] = {}
  1936. self.save_global_state(out=current_global_state)
  1937. try:
  1938. # Set to state prior to tracing the graph
  1939. self.tracing_context.global_context.restore_graphstate(prior_global_state)
  1940. yield
  1941. finally:
  1942. # Reset to state at the current time (e.g. before calling the user compiler)
  1943. self.tracing_context.global_context.restore_graphstate(
  1944. GlobalContextCheckpointState(current_global_state)
  1945. )
  1946. def run_compiler_collective(self) -> None:
  1947. tx = self.root_tx
  1948. assert tx is not None
  1949. if (ds := tx.distributed_state) is not None and ds.all_states is None:
  1950. compile_pg = ds.compile_pg
  1951. log.info("compiler_collective %s", ds.local_state)
  1952. torch._logging.trace_structured(
  1953. "artifact",
  1954. metadata_fn=lambda: {
  1955. "name": "compiler_collective",
  1956. "encoding": "string",
  1957. },
  1958. payload_fn=lambda: ds.local_state.render(),
  1959. )
  1960. device_types = compile_pg._device_types
  1961. assert len(device_types) == 1, (
  1962. "Expect only one device type but got {}".format("+".join(device_types))
  1963. )
  1964. with (
  1965. get_interface_for_device(device_types.pop()).device( # type: ignore[attr-defined]
  1966. compile_pg.rank() % torch.accelerator.device_count()
  1967. ),
  1968. dynamo_timed("compiler_collective", log_pt2_compile_event=True),
  1969. ):
  1970. all_states: list[Any] = [None] * compile_pg.size()
  1971. dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
  1972. ds.all_states = all_states
  1973. # Clear speculation log, because are tracing may diverge due to
  1974. # this information from the compiler collective
  1975. tx.speculation_log.clear()
  1976. raise exc.CompileCollectiveRestartAnalysis
  1977. def _validate_outputs_safe_for_autograd_nodes(
  1978. self, rv: list["VariableTracker"], tx: "InstructionTranslatorBase"
  1979. ) -> None:
  1980. """
  1981. Validate that if torch.autograd.grad is used in the graph and outputs
  1982. require grad, we trigger AutogradGradRestartAnalysis only if the output is connected
  1983. to the autograd.grad computation.
  1984. rv here refers to list of variables that are being returned from dynamo graph.
  1985. See Note [Tracing autograd.grad in dynamo]
  1986. """
  1987. if not self.autograd_grad_consumed_grad_fns:
  1988. return
  1989. from .variables.tensor import TensorVariable
  1990. for var in rv:
  1991. if not isinstance(var, TensorVariable) or not var.requires_grad:
  1992. continue
  1993. fake_tensor = var.as_proxy().node.meta.get("example_value")
  1994. assert isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor)
  1995. if fake_tensor.grad_fn is None:
  1996. continue
  1997. # Traverse the entire autograd graph of the returned tensor to check
  1998. # if any node was consumed by autograd.grad
  1999. reachable_grad_fns = collect_reachable_grad_fns([(fake_tensor, None)])
  2000. if reachable_grad_fns & self.autograd_grad_consumed_grad_fns:
  2001. # Set the flag to graph break at autograd.grad on retry
  2002. tx.speculation_log.graph_break_on_autograd_grad = True
  2003. raise exc.AutogradGradRestartAnalysis(
  2004. restart_reason="autograd.grad consumed grad_fns of returned tensors"
  2005. )
  2006. def compile_and_call_fx_graph(
  2007. self,
  2008. tx: "InstructionTranslatorBase",
  2009. rv: list[VariableTracker],
  2010. root: FakeRootModule,
  2011. ) -> list[Instruction]:
  2012. """
  2013. Generate code from self.graph and return the Instruction()s to
  2014. call that generated code.
  2015. Code is generated w.r.t. self.root_tx.
  2016. tx is only used for preserving GraphModule metadata
  2017. """
  2018. with torch._guards.TracingContext.clear_frame():
  2019. from .decorators import disable
  2020. assert self.should_exit
  2021. self.run_compiler_collective()
  2022. if count_calls(self.graph) == 0 and len(rv) == 0:
  2023. return []
  2024. name = unique_id("__compiled_fn", with_uuid=True)
  2025. assert isinstance(rv, list)
  2026. assert isinstance(root, FakeRootModule)
  2027. # Check if autograd.grad is used with outputs that require grad
  2028. # This would cause double backward issues in aot_autograd
  2029. self._validate_outputs_safe_for_autograd_nodes(rv, tx)
  2030. output_node = self.create_node(
  2031. "output",
  2032. "output",
  2033. (self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),),
  2034. {},
  2035. )
  2036. sub_gms = self.dedup_pass()
  2037. root.add_nn_modules(sub_gms) # type: ignore[arg-type]
  2038. self.current_tracer._maybe_preserve_original_meta(tx, output_node)
  2039. if not config.do_not_emit_runtime_asserts:
  2040. # There is a rare scenario where codegen_suffix adds a new entry
  2041. # to self.nn_modules while `root` knows only about the
  2042. # nn_modules at the time of its creation. This causes failures
  2043. # while creating the graph module because self.graph and root
  2044. # are out of sync. This only happens for `get_attr` nodes, so
  2045. # here we clean up the get_attr nodes that are unused.
  2046. with dynamo_timed("insert_deferred_runtime_asserts"):
  2047. for attr in dir(root):
  2048. subgraph = getattr(root, attr)
  2049. if isinstance(subgraph, fx.GraphModule):
  2050. insert_deferred_runtime_asserts(
  2051. subgraph,
  2052. self.shape_env,
  2053. name,
  2054. export=self.export,
  2055. )
  2056. self.remove_unused_get_attr_nodes()
  2057. insert_deferred_runtime_asserts(
  2058. fx.GraphModule(root, self.graph),
  2059. self.shape_env,
  2060. name,
  2061. export=self.export,
  2062. )
  2063. # NB: deferred runtime asserts can keep graphargs live, so make sure
  2064. # those are inserted before pruning
  2065. self.remove_unused_graphargs()
  2066. ncalls = count_calls(self.graph)
  2067. counters["stats"]["calls_captured"] += ncalls
  2068. self.remove_tensorify_specialized_graphargs()
  2069. # free a bit of memory
  2070. self.real_value_cache.clear()
  2071. gm = _make_graph_module(root, self.graph)
  2072. from .dce_extra_outputs import dce_hop_extra_outputs
  2073. dce_hop_extra_outputs(gm)
  2074. # Saved tensors hooks are not used by the graph.
  2075. # GraphModule by default only copies used in the graph submodules.
  2076. # Copying them into the result graph manually.
  2077. if self.saved_tensors_hooks_subgraph_names:
  2078. for subgraph_name in self.saved_tensors_hooks_subgraph_names:
  2079. setattr(gm, subgraph_name, getattr(root, subgraph_name))
  2080. for register_finalizer in self.register_finalizer_fns:
  2081. register_finalizer(gm)
  2082. if next(gm.parameters(), None) is not None:
  2083. # If dynamo produces a graph with parameters, skip package stuff
  2084. # Bypass output graph
  2085. self.bypass_package(
  2086. "Graph contains named parameters: either inline_inbuilt_nn_modules=False or there are static addresses.",
  2087. inline_builtin_nn_modules=torch._dynamo.config.inline_inbuilt_nn_modules,
  2088. gm=gm.print_readable(
  2089. print_output=False, include_stride=True, include_device=True
  2090. ),
  2091. )
  2092. if self.package is not None:
  2093. gm._backend_id = name
  2094. gm.compile_subgraph_reason = self.compile_subgraph_reason
  2095. gm.meta["dynamo_flat_name_to_original_fqn"] = (
  2096. self.dynamo_flat_name_to_original_fqn.copy()
  2097. )
  2098. gm.meta["dynamo_compile_id"] = self.dynamo_compile_id
  2099. gm.meta["backend_id"] = name
  2100. graph_code_log.debug(
  2101. "%s",
  2102. lazy_format_graph_code(
  2103. name, gm, include_stride=True, include_device=True, colored=True
  2104. ),
  2105. )
  2106. torch._logging.trace_structured(
  2107. "dynamo_output_graph",
  2108. lambda: {"sizes": self.get_graph_sizes_structured()},
  2109. payload_fn=lambda: gm.print_readable(
  2110. print_output=False, include_stride=True, include_device=True
  2111. ),
  2112. )
  2113. self.call_cleanup_hooks()
  2114. old_fake_mode = self.tracing_context.fake_mode
  2115. assert old_fake_mode is not None
  2116. # Store old_fake_mode so it can be cleared at end of compile
  2117. self._old_fake_mode = old_fake_mode
  2118. if not self.export:
  2119. import torch._functorch.config as _config
  2120. with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False):
  2121. # TODO(voz): The way export uses gm, and fake tensors, is not supported with us resetting
  2122. # Why create a new FakeTensorMode?
  2123. #
  2124. # The reason this needs to be done is because when we do Dynamo tracing, fake
  2125. # tensors can have their metadata mutated. Thus, the fake tensor we allocated
  2126. # for any given tensor may no longer be valid for the beginning trace of the
  2127. # graph. Nor is it convenient to "clone" the input tensors before mutating them,
  2128. # since you have to preserve aliasing. So we just reconstruct the FakeTensorMode
  2129. # from scratch when we go to AOTAutograd. But the ShapeEnv must be preserved as
  2130. # Dynamo made decisions about what is dynamic or not / guards from the user code
  2131. # that is not in graph.
  2132. backend_fake_mode = torch._subclasses.FakeTensorMode(
  2133. shape_env=old_fake_mode.shape_env,
  2134. )
  2135. # TODO(voz): Ostensibly, this should be scoped and
  2136. # restore back to old_fake_mode, but doing so currently violates
  2137. # a lot of fake_tensor ownership assumptions and runs afoul of detect_fake_mode
  2138. self.tracing_context.fake_mode = backend_fake_mode
  2139. with self.restore_global_state():
  2140. compiled_fn = self.call_user_compiler(gm, self.example_inputs())
  2141. from torch.fx._lazy_graph_module import _LazyGraphModule
  2142. if isinstance(compiled_fn, _LazyGraphModule) or (
  2143. isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule)
  2144. and compiled_fn.__name__ == "_lazy_forward" # type: ignore[attr-defined]
  2145. ):
  2146. # Since dynamo will run the forward method for the GraphModule shortly
  2147. # anyways, it does not hurt to do the real recompilation here if
  2148. # this is a _LazyGraphModule. This makes it easier for dynamo to
  2149. # optimize a _LazyGraphModule.
  2150. lazy_gm = (
  2151. compiled_fn
  2152. if isinstance(compiled_fn, _LazyGraphModule)
  2153. else compiled_fn.__self__ # type: ignore[attr-defined]
  2154. )
  2155. _LazyGraphModule.force_recompile(lazy_gm)
  2156. if not isinstance(compiled_fn, _LazyGraphModule):
  2157. # replace compiled_fn with the real forward method
  2158. compiled_fn = lazy_gm.forward
  2159. if self.package is not None:
  2160. self.package.add_backend_id(name, compiled_fn)
  2161. compiled_fn = disable(
  2162. compiled_fn, reason="do not trace Dynamo-compiled graph"
  2163. )
  2164. counters["stats"]["unique_graphs"] += 1
  2165. assert old_fake_mode.shape_env is not None
  2166. if specializations := old_fake_mode.shape_env.specializations:
  2167. specialization_guards = []
  2168. specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
  2169. sources = [a.source for a in self.graphargs]
  2170. for specialization in specializations:
  2171. source_index = sources.index(specialization.source)
  2172. check_fn_source = inspect.getsource(specialization.check_fn).strip()
  2173. # Required because the LABDA_GUARD API requires a root guard manager
  2174. unused_root_guard_manager = RootGuardManager()
  2175. check_fn = guards.LAMBDA_GUARD( # type: ignore[attr-defined]
  2176. unused_root_guard_manager,
  2177. specialization.check_fn,
  2178. [check_fn_source],
  2179. None, # user_stack
  2180. )
  2181. log.debug(
  2182. "Compiling backend specialized graph with specialization=%s",
  2183. check_fn_source,
  2184. )
  2185. specialization_guards.append(
  2186. (
  2187. functools.partial(
  2188. lambda idx, args, check_fn=check_fn: check_fn(
  2189. args[idx]
  2190. ),
  2191. source_index,
  2192. ),
  2193. specialization,
  2194. )
  2195. )
  2196. @torch._dynamo.disable(reason="do not trace Dynamo-compiled graph") # type: ignore[misc]
  2197. def specialized_dispatch(*args: Any, **kwargs: Any) -> Any:
  2198. for check_fn, specialization in specialization_guards:
  2199. if check_fn(args):
  2200. if specialization in specialization_cache:
  2201. return specialization_cache[specialization](
  2202. *args, **kwargs
  2203. )
  2204. with self.shape_env.patch_source_specialization(
  2205. specialization.source, specialization.check_fn
  2206. ):
  2207. # Modify gm so AOTAutogradCache key changes per specialization
  2208. gm.meta["specialization"] = specialization
  2209. example_inputs: list[Tensor] = list(args)
  2210. with tracing(self.tracing_context):
  2211. specialization_cache[specialization] = (
  2212. self.call_user_compiler(gm, example_inputs)
  2213. )
  2214. return specialization_cache[specialization](*args, **kwargs)
  2215. return compiled_fn(*args, **kwargs)
  2216. # This is safe because we pre-process name to be unique
  2217. self.install_global_unsafe(name, specialized_dispatch)
  2218. else:
  2219. # This is safe because we pre-process name to be unique
  2220. self.install_global_unsafe(name, compiled_fn)
  2221. assert self.root_tx is not None
  2222. cg = PyCodegen(self.root_tx)
  2223. if has_user_objects():
  2224. # NB: This is where we store possible user objects before running the graph
  2225. # index_to_user_object_weakref is the function used in the graph to translate
  2226. # the dynamo-generated index into the actual object passed to the compiled function.
  2227. # We generate bytecode to store all user objects at the proper index in the below
  2228. # call.
  2229. cg.add_push_null(
  2230. lambda: cg.load_import_from(
  2231. torch._dynamo.graph_bytecode_inputs.__name__,
  2232. "store_user_object_weakrefs",
  2233. )
  2234. )
  2235. tmp_vars = []
  2236. for constructor in index_to_bytecode_constructor.values():
  2237. constructor(cg)
  2238. var_name = (
  2239. self.new_var()
  2240. ) # keep alive any user objects for the rest of the frame
  2241. # TODO: we could omit this for objects we create but shouldn't be too much overhead for now
  2242. cg.store(var_name)
  2243. tmp_vars.append(var_name)
  2244. for var_name in tmp_vars:
  2245. cg.append_output(cg.create_load(var_name))
  2246. cg.call_function(len(index_to_bytecode_constructor), False)
  2247. cg.pop_top()
  2248. for idx, arg in enumerate(self.graphargs):
  2249. self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
  2250. cg.make_call_generated_code(name)
  2251. return cg.get_instructions()
  2252. @property
  2253. def placeholders(self) -> list[fx.Node]:
  2254. return self.graph.find_nodes(op="placeholder")
  2255. @property
  2256. def graphargs(self) -> list[GraphArg]:
  2257. return [node.meta["grapharg"] for node in self.placeholders]
  2258. def call_user_compiler(
  2259. self, gm: fx.GraphModule, example_inputs: list[Tensor]
  2260. ) -> CompiledFn:
  2261. with dynamo_timed(
  2262. "OutputGraph.call_user_compiler",
  2263. phase_name="backend_compile",
  2264. log_pt2_compile_event=True,
  2265. log_waitcounter=True,
  2266. waitcounter_name_override="compile_aot_autograd",
  2267. dynamo_compile_column_us="aot_autograd_cumulative_compile_time_us",
  2268. ):
  2269. return self._call_user_compiler(gm, example_inputs)
  2270. def _call_user_compiler(
  2271. self, gm: fx.GraphModule, example_inputs: list[Tensor]
  2272. ) -> CompiledFn:
  2273. assert self.compiler_fn is not None
  2274. tot = 0
  2275. placeholders = []
  2276. for node in gm.graph.nodes:
  2277. if node.op in ("call_function", "call_method", "call_module"):
  2278. tot += 1
  2279. if node.op == "placeholder":
  2280. placeholders.append(node)
  2281. increment_op_count(tot)
  2282. for pl in placeholders:
  2283. if not hasattr(pl, "_dynamo_source"):
  2284. arg = pl.meta["grapharg"]
  2285. # TODO: Why isn't this stored in meta :think:
  2286. # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
  2287. pl._dynamo_source = arg.source
  2288. # NOTE: can't move these into meta: https://github.com/pytorch/pytorch/issues/141640
  2289. gm._param_name_to_source = self.param_name_to_source # type: ignore[assignment]
  2290. gm._source_to_user_stacks = self.source_to_user_stacks # type: ignore[assignment]
  2291. # Check for per-graph backend override (for debugging/bisecting)
  2292. compiler_fn = (
  2293. get_backend_override_for_compile_id(
  2294. self.dynamo_compile_id, config.debug_backend_override
  2295. )
  2296. or self.compiler_fn
  2297. )
  2298. # Check for per-graph inductor config override (for debugging/bisecting)
  2299. inductor_config_override = get_inductor_config_override_for_compile_id(
  2300. self.dynamo_compile_id, config.debug_inductor_config_override
  2301. )
  2302. if inductor_config_override:
  2303. compiler_fn = _wrap_with_inductor_config(
  2304. compiler_fn, inductor_config_override
  2305. )
  2306. name = (
  2307. compiler_fn.__name__
  2308. if hasattr(compiler_fn, "__name__")
  2309. else "<unknown compiler_fn>"
  2310. )
  2311. try:
  2312. _step_logger()(logging.INFO, f"calling compiler function {name}")
  2313. if config.verify_correctness:
  2314. compiler_fn = WrapperBackend(compiler_fn)
  2315. compiled_fn = compiler_fn(gm, example_inputs)
  2316. _step_logger()(logging.INFO, f"done compiler function {name}")
  2317. assert callable(compiled_fn), "compiler_fn did not return callable"
  2318. except (TensorifyScalarRestartAnalysis, ShortenTraceback):
  2319. raise
  2320. except exceptions_allowed_to_be_fallback as e:
  2321. if self.has_user_defined_allowed_in_graph:
  2322. raise BackendCompilerFailed(
  2323. self.compiler_fn, e, inspect.currentframe()
  2324. ).with_traceback(e.__traceback__) from None
  2325. unimplemented_with_warning(
  2326. e,
  2327. self.root_tx.f_code,
  2328. gb_type="Backend compiler exception",
  2329. context=f"Backend: {name}\nException:{str(e)}\nTraceback:\n{self.root_tx.format_frame_summary()}",
  2330. explanation=f"Backend compiler `{name}` failed with {str(e)}. Adding a graph break.",
  2331. hints=[
  2332. "Report an issue to the backend compiler repo.",
  2333. ],
  2334. )
  2335. except SkipFrame:
  2336. # The backend compiler has requested that we skip the frame, instead of
  2337. # aborting execution.
  2338. raise
  2339. except Exception as e:
  2340. raise BackendCompilerFailed(
  2341. self.compiler_fn, e, inspect.currentframe()
  2342. ).with_traceback(e.__traceback__) from None
  2343. signpost_event(
  2344. "dynamo",
  2345. "OutputGraph.call_user_compiler",
  2346. {
  2347. **self.co_fields,
  2348. "op_count": tot,
  2349. "node_count": len(gm.graph.nodes),
  2350. "input_count": len(placeholders),
  2351. },
  2352. )
  2353. # pyrefly: ignore [unbound-name, bad-return]
  2354. return compiled_fn
  2355. def dedup_pass(self) -> dict[str, torch.fx.GraphModule]:
  2356. if torch._dynamo.config.use_graph_deduplication:
  2357. return apply_graph_deduplication(self)
  2358. else:
  2359. return {}
  2360. def install_subgraph(self, name: str, sub_gm: torch.fx.GraphModule) -> str:
  2361. next_name = get_unique_name_wrt(name, self.nn_modules, requires_suffix=True)
  2362. sub_gm.__name__ = next_name # type: ignore[assignment]
  2363. sub_gm.torchdynamo_force_dynamic = False # type: ignore[assignment]
  2364. # This graph module is not present in the user space, so it can't be
  2365. # accessed by a source. Set source=None.
  2366. self.register_attr_or_module(sub_gm, next_name, source=None)
  2367. return next_name
  2368. def example_inputs(self) -> list[torch.Tensor]:
  2369. result = [arg.example for arg in self.graphargs]
  2370. # pyrefly: ignore[bad-return]
  2371. return result
  2372. def remove_unused_get_attr_nodes(self) -> None:
  2373. for node in sorted(self.graph.find_nodes(op="get_attr"), reverse=True):
  2374. if len(list(node.users)) == 0:
  2375. self.remove_node(node)
  2376. def remove_unused_graphargs(self) -> None:
  2377. # NB: It's OK to drop GraphArg for symbols that ended up being
  2378. # specialized iff they are not used in runtime assertions. You don't
  2379. # even have to make a guard for it, because ShapeEnv produce_guards
  2380. # operates on tracked_fakes, which never gets pruned.
  2381. # That being said, you'll get marginally better generated
  2382. # guard code if you promote the guard into a Dynamo guard (since that
  2383. # allows for the guard to be done using C++ guards.) If we get
  2384. # ShapeEnv guards to go into C++ guards, this will stop being a thing
  2385. # though!
  2386. assert self.should_exit
  2387. # Miniature DCE pass, but only for obviously trivial operations
  2388. def is_static_true(b_node: fx.node.Argument) -> bool:
  2389. if b_node is True:
  2390. return True
  2391. if not isinstance(b_node, fx.Node):
  2392. return False
  2393. b = b_node.meta.get("example_value")
  2394. if b is None:
  2395. return False
  2396. if b is True:
  2397. return True
  2398. if (
  2399. isinstance(b, torch.SymBool)
  2400. and (r := b.node.maybe_as_bool()) is not None
  2401. ):
  2402. return r
  2403. # TODO: We can also technically remove all cases when the input
  2404. # doesn't have unbacked inputs, since it's all in the ShapeEnv
  2405. return False
  2406. def is_symnode_arg(a: fx.node.Argument) -> bool:
  2407. from torch.fx.experimental.sym_node import SymTypes
  2408. if isinstance(a, (int, float, bool)):
  2409. return True
  2410. if isinstance(a, fx.Node):
  2411. return isinstance(a.meta.get("example_value"), SymTypes)
  2412. return False
  2413. # NB: We assume that you cannot do mutations on int/float/bool,
  2414. # because they are immutable types, and therefore is always safe to
  2415. # DCE.
  2416. def is_symnode_compute_node(node: fx.Node) -> bool:
  2417. from torch.fx.experimental.sym_node import SymTypes
  2418. if node.op != "call_function":
  2419. return False
  2420. # TODO: I don't think it's possible to have a bare int/float here?
  2421. if not isinstance(node.meta.get("example_value"), SymTypes):
  2422. return False
  2423. # TODO: This will bail here if you ever end up with a more complicated
  2424. # computation function, like sum(list_of_ints), even though it
  2425. # should be DCE'able
  2426. if not all(is_symnode_arg(a) for a in node.args):
  2427. return False
  2428. if not all(is_symnode_arg(a) for a in node.kwargs.values()):
  2429. return False
  2430. return True
  2431. from torch.fx.experimental.symbolic_shapes import is_accessor_node
  2432. for node in reversed(list(self.graph.nodes)):
  2433. if len(list(node.users)) == 0:
  2434. if (
  2435. node.op == "get_attr"
  2436. or (node.op == "call_function" and node.target is operator.getitem)
  2437. or (
  2438. node.op == "call_function"
  2439. and node.target is torch._check
  2440. and is_static_true(node.args[0])
  2441. )
  2442. or is_symnode_compute_node(node)
  2443. or is_accessor_node(node)
  2444. ):
  2445. self.remove_node(node)
  2446. def placeholder_binds_symbol(node: fx.Node) -> Optional[sympy.Symbol]:
  2447. arg = node.meta["grapharg"]
  2448. example = arg.example
  2449. if isinstance(example, torch.SymInt) and isinstance(
  2450. example.node.expr, sympy.Symbol
  2451. ):
  2452. return example.node.expr
  2453. return None
  2454. def remove_unused(node: fx.Node) -> None:
  2455. log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name)
  2456. # I'm not really sure why you need to delete these from the
  2457. # node since the node is going to get removed
  2458. del node.meta["grapharg"]
  2459. self.remove_node(node)
  2460. self.real_value_cache.pop(node, None)
  2461. used_symbols: set[sympy.Symbol] = set()
  2462. def update_used_symbols(
  2463. used_symbols: set[sympy.Symbol], fake: Union[torch.SymInt, torch.Tensor]
  2464. ) -> None:
  2465. used_symbols |= free_symbols(fake)
  2466. recheck_placeholders = []
  2467. for node in self.placeholders:
  2468. binds_symbol = placeholder_binds_symbol(node) is not None
  2469. # Don't delete symbol bindings yet
  2470. if binds_symbol:
  2471. if not node.users:
  2472. recheck_placeholders.append(node)
  2473. else:
  2474. if not node.users and not isinstance(
  2475. node.meta["grapharg"], BackwardStateGraphArg
  2476. ):
  2477. remove_unused(node)
  2478. else:
  2479. # Register the free symbols as uses
  2480. arg = node.meta["grapharg"]
  2481. if isinstance(arg, BackwardStateGraphArg):
  2482. continue
  2483. if isinstance(node.meta["grapharg"].example, torch.ScriptObject):
  2484. real_script_obj = node.meta["grapharg"].example
  2485. fake_script_obj = node.meta["grapharg"].example_strong_ref
  2486. if not torch._library.fake_class_registry.tracing_with_real(
  2487. real_script_obj
  2488. ):
  2489. flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined]
  2490. for attr in flat_dict:
  2491. fake_attr_val = getattr(
  2492. fake_script_obj.wrapped_obj, attr
  2493. )
  2494. pytree.tree_map_only(
  2495. (torch.SymInt, torch.Tensor),
  2496. lambda t: update_used_symbols(used_symbols, t),
  2497. fake_attr_val,
  2498. )
  2499. continue
  2500. if is_opaque_type(type(node.meta["grapharg"].example)):
  2501. continue
  2502. fake = (
  2503. arg.fake_tensor if arg.fake_tensor is not None else arg.example
  2504. )
  2505. update_used_symbols(used_symbols, fake)
  2506. # After removing unused graphargs, prune unused binds_symbol
  2507. for node in recheck_placeholders:
  2508. symbol = placeholder_binds_symbol(node)
  2509. if symbol is not None:
  2510. if symbol not in used_symbols:
  2511. remove_unused(node)
  2512. else:
  2513. # Make sure we delete later occurrences of the same symbol
  2514. used_symbols.remove(symbol)
  2515. def remove_tensorify_specialized_graphargs(self) -> None:
  2516. # This is a pretty interesting function. Basically we have this problem
  2517. # where our compiler tends to choke when we have unused inputs. The way
  2518. # we support dynamic float arguments is by doing a joint fx pass and
  2519. # tensorifying away as many symfloats as we can. For the remaining symfloats
  2520. # we have no choice but to specialize... HOWEVER at that point in time
  2521. # we can no longer remove graph inputs. So our sledgehammer solution is to
  2522. # save the state of what inputs we should have specialized in dynamo and
  2523. # restart analysis. This function incorporates this "view from the future"
  2524. # state and specializes inputs that we know we won't be able to tensorify
  2525. # away in the joint pass. In principle we shouldn't choke on unused inputs
  2526. # and so this shouldn't be necessary. In practice CUDA graphs choke on
  2527. # unused inputs so we need this for now.
  2528. # Import here to prevent circular import
  2529. from torch._dynamo.symbolic_convert import TensorifyState
  2530. for node in self.graph.nodes:
  2531. example_value = node.meta.get("example_value")
  2532. if (
  2533. isinstance(example_value, FakeTensor)
  2534. and example_value.item_memo is not None
  2535. and hasattr(example_value.item_memo.node._expr, "name")
  2536. and all(u.target == "item" for u in node.users)
  2537. and TensorifyState.should_specialize(
  2538. # We use _expr instead of expr b/c we want the symbol not the replacement
  2539. example_value.item_memo.node._expr.name
  2540. )
  2541. ):
  2542. for u in list(node.users):
  2543. u.replace_all_uses_with(guard_scalar(example_value.item_memo))
  2544. self.remove_node(u)
  2545. self.remove_node(node)
  2546. def add_output_instructions(self, prefix: list[Instruction]) -> None:
  2547. """
  2548. We call this on the creation of a new compiled subgraph that is inserted
  2549. before user code.
  2550. """
  2551. self.output_instructions.extend(prefix)
  2552. self.should_exit = True
  2553. def install_global_unsafe(self, name: str, value: Any) -> None:
  2554. """
  2555. WARNING: prefer the safer `install_global_by_id/install_global`.
  2556. torch.compile instances should be independent of each other;
  2557. one footgun is to have one instance depend on the existence of
  2558. a global installed by another instance. This can happen if we mangle
  2559. a global the same way across both instances.
  2560. """
  2561. assert name not in self.installed_globals
  2562. self.installed_globals.add(name)
  2563. self.cleanups.append(CleanupHook.create(self.global_scope, name, value))
  2564. def install_global_by_id(self, prefix: str, value: Any) -> str:
  2565. """
  2566. Installs a global if it hasn't been installed already.
  2567. This is determined by (prefix, id(value)) pair.
  2568. Returns the name of the newly installed global.
  2569. """
  2570. # NB: need self.compile_id to distinguish this global
  2571. # from another global created in a different torch.compile instance
  2572. name = f"{prefix}_{id(value)}_c{self.compile_id}"
  2573. if name in self.installed_globals:
  2574. return name
  2575. self.install_global_unsafe(name, value)
  2576. return name
  2577. def install_global(self, prefix: str, value: Any) -> str:
  2578. """
  2579. Installs a global, generating a unique name for it.
  2580. Returns the name of the newly installed global.
  2581. """
  2582. # NB: unique_id is unique, even across torch.compile instances
  2583. name = unique_id(prefix)
  2584. self.install_global_unsafe(name, value)
  2585. return name
  2586. def cleanup(self) -> None:
  2587. # There is a reference cycle between tracer and OutputGraph, causing
  2588. # some of the tensor objects to be held alive for longer than necessary.
  2589. self.root_tx = None # type: ignore[assignment]
  2590. self.nn_modules.clear()
  2591. self.used_inlined_inbuilt_modules_names.clear()
  2592. self.param_name_to_source = None
  2593. for node in self.graph.nodes:
  2594. if "grapharg" in node.meta:
  2595. del node.meta["grapharg"]
  2596. self.real_value_cache.clear()
  2597. self.input_name_to_proxy.clear()
  2598. self.side_effects.clear()
  2599. self.variable_tracker_cache.clear()
  2600. self.signature_cache.clear()
  2601. self.register_finalizer_fns.clear()
  2602. self.dynamo_flat_name_to_original_fqn.clear()
  2603. self.tracing_context.clear()
  2604. self.input_source_to_var.clear()
  2605. self.leaf_var_creation_order.clear()
  2606. self.unspec_variable_map.clear()
  2607. self.backward_state.clear()
  2608. def add_graph_finalizer(
  2609. self, register_finalizer: Callable[[fx.GraphModule], None]
  2610. ) -> None:
  2611. self.register_finalizer_fns.append(register_finalizer)
  2612. def example_value_from_input_node(self, node: torch.fx.Node) -> Any:
  2613. """Extract the non-fake example tensor"""
  2614. if node.op == "placeholder":
  2615. return node.meta["grapharg"].example
  2616. assert node.op == "get_attr"
  2617. return self.nn_modules[node.target] # type: ignore[index]
  2618. def add_fqn_info_for_inlined_modules(
  2619. self, inlined_module: torch.nn.Module, source: Source
  2620. ) -> None:
  2621. name = OutputGraph.module_key_name(source.name)
  2622. name = get_unique_name_wrt(
  2623. name, self.used_inlined_inbuilt_modules_names, self.global_scope
  2624. )
  2625. self.used_inlined_inbuilt_modules_names.add(name)
  2626. def register_leaf_name(leaf_name: str) -> None:
  2627. assert self.param_name_to_source is not None
  2628. new_source = self.get_chained_param_buffer_source(source, leaf_name)
  2629. new_name = f"{name}.{leaf_name}"
  2630. self.param_name_to_source[new_name] = new_source
  2631. if isinstance(source, LocalSource):
  2632. self.dynamo_flat_name_to_original_fqn[
  2633. OutputGraph.module_key_name(new_source.name)
  2634. ] = leaf_name
  2635. # annoying, but there are cases when we do not have parameters
  2636. # see test_nn_moduledict_contains
  2637. if hasattr(inlined_module, "_parameters"):
  2638. if (
  2639. callable(inlined_module.named_parameters)
  2640. and inlined_module.named_parameters.__func__ # type: ignore[attr-defined]
  2641. is og_module_named_parameters_fn_ptr
  2642. ):
  2643. for leaf_name, _ in inlined_module.named_parameters():
  2644. register_leaf_name(leaf_name)
  2645. if hasattr(inlined_module, "_buffers"):
  2646. if (
  2647. callable(inlined_module.named_buffers)
  2648. and inlined_module.named_buffers.__func__ # type: ignore[attr-defined]
  2649. is og_module_named_buffers_fn_ptr
  2650. ):
  2651. for leaf_name, _ in inlined_module.named_buffers():
  2652. register_leaf_name(leaf_name)
  2653. class DynamoTracerOutput:
  2654. error_on_graph_break: bool
  2655. is_tracing_resume_prologue: bool
  2656. output_graph: Optional[OutputGraph]
  2657. # output_graph_for_cleanup is set even when there's an error, to allow
  2658. # cleanup of graph nodes to break reference cycles
  2659. output_graph_for_cleanup: Optional[OutputGraph]
  2660. closure: Optional[tuple[Any, ...]]
  2661. f_globals: dict[str, Any]
  2662. def __init__(
  2663. self, tracer: "InstructionTranslatorBase", error: Optional[Any] = None
  2664. ) -> None:
  2665. self.error_on_graph_break = tracer.error_on_graph_break
  2666. self.is_tracing_resume_prologue = tracer.is_tracing_resume_prologue
  2667. self.closure = tracer.closure
  2668. self.f_globals = tracer.f_globals
  2669. self.output_graph_for_cleanup = tracer.output
  2670. if error:
  2671. self.output_graph = None
  2672. else:
  2673. self.output_graph = tracer.output
  2674. def _cleanup_output_graph(self) -> None:
  2675. output_graph = self.output_graph_for_cleanup
  2676. if output_graph:
  2677. for tracer in output_graph.tracers:
  2678. tracer.graph._clear_nodes()
  2679. # Also clear tracked_fakes to break FakeTensorMode → ShapeEnv → TrackedFake → FakeTensor cycle
  2680. if (
  2681. output_graph.tracing_context.fake_mode
  2682. and output_graph.tracing_context.fake_mode.shape_env
  2683. ):
  2684. output_graph.tracing_context.fake_mode.shape_env.tracked_fakes = None
  2685. err_epilogue = (
  2686. "With the current config, we will graph break "
  2687. "(and fall back to eager-mode PyTorch) on all ops "
  2688. "that have do not have the 'pt2_compliant_tag'. "
  2689. "Please see the following doc for how to mark this op as PT2 compliant "
  2690. "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html"
  2691. )
  2692. def check_pt2_compliant_op(
  2693. output_graph: OutputGraph, kind: str, target: Any, args: Any, kwargs: Any
  2694. ) -> None:
  2695. if kind != "call_function":
  2696. return
  2697. def encountered_compliant_op(target: torch._ops.OpOverload) -> None:
  2698. if target.namespace in {"prim", "prims", "aten"}:
  2699. return
  2700. output_graph.compliant_custom_ops.add(target)
  2701. def encountered_non_compliant_op(target: torch._ops.OpOverload, msg: str) -> None:
  2702. output_graph.non_compliant_ops.add(target)
  2703. if config.only_allow_pt2_compliant_ops:
  2704. unimplemented(
  2705. gb_type="Encountered non-PT2-compliant op",
  2706. context="",
  2707. explanation=msg + " " + err_epilogue,
  2708. hints=[],
  2709. )
  2710. if isinstance(target, torch._ops.OpOverload):
  2711. if torch.Tag.pt2_compliant_tag in target.tags:
  2712. encountered_compliant_op(target)
  2713. return
  2714. encountered_non_compliant_op(
  2715. target,
  2716. f"Encountered the torch.ops.OpOverload {target} that is not PT2 compliant.",
  2717. )
  2718. return
  2719. if isinstance(target, torch._ops.OpOverloadPacket):
  2720. overloads = tuple(target.overloads())
  2721. # Optimization: Overload resolution is expensive.
  2722. # If there's only one overload, we know what it will resolve to.
  2723. if len(overloads) == 1:
  2724. op = getattr(target, overloads[0])
  2725. if torch.Tag.pt2_compliant_tag in op.tags:
  2726. encountered_compliant_op(op)
  2727. return
  2728. encountered_non_compliant_op(
  2729. op,
  2730. f"Encountered the non-overloaded "
  2731. f"torch.ops.OpOverloadPacket {target} "
  2732. f"that is not PT2 compliant. ",
  2733. )
  2734. return
  2735. args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes(
  2736. output_graph.current_tx, (args, kwargs), False
  2737. )
  2738. try:
  2739. overload = torch._C._jit_resolve_packet(
  2740. target._qualified_op_name, *args, **kwargs
  2741. )
  2742. except RuntimeError as e:
  2743. unimplemented(
  2744. gb_type="Error when attempting to resolve op packet",
  2745. context="",
  2746. explanation=str(e),
  2747. hints=[],
  2748. )
  2749. # pyrefly: ignore [unbound-name]
  2750. op = getattr(target, overload)
  2751. if torch.Tag.pt2_compliant_tag in op.tags:
  2752. encountered_compliant_op(op)
  2753. else:
  2754. encountered_non_compliant_op(
  2755. op,
  2756. f"Encountered the torch.ops.OpOverloadPacket {target} "
  2757. # pyrefly: ignore [unbound-name]
  2758. f"which resolves to the overload ({overload}) that is "
  2759. f"not PT2 compliant.",
  2760. )
  2761. _compile_id_counter = itertools.count()
  2762. P = ParamSpec("P")
  2763. R = TypeVar("R")
  2764. class LazyProxy:
  2765. def __init__(
  2766. self,
  2767. tracer: "SubgraphTracer",
  2768. fn: Callable[P, R],
  2769. *args: P.args,
  2770. **kwargs: P.kwargs,
  2771. ) -> None:
  2772. self.tracer = tracer
  2773. # pyrefly: ignore [invalid-type-var]
  2774. self.fn = fn
  2775. self.args = args
  2776. self.kwargs = kwargs
  2777. def __call__(self) -> Any:
  2778. return self.fn(*self.args, **self.kwargs)
  2779. class SubgraphTracer(fx.Tracer):
  2780. """
  2781. Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer
  2782. and the separation of responsibilities is that SubgraphTracer is
  2783. responsible for building the graph while OutputGraph is responsible for
  2784. compiling and executing the graph.
  2785. """
  2786. def __init__(
  2787. self,
  2788. output_graph: "OutputGraph",
  2789. parent: Optional["SubgraphTracer"] = None,
  2790. is_export: bool = False,
  2791. source_target: Optional[Target] = None,
  2792. description: Optional[str] = None,
  2793. ) -> None:
  2794. super().__init__()
  2795. self.output_graph = weakref.proxy(output_graph)
  2796. self.graph = torch.fx.Graph()
  2797. # See note [Export inputs must be explicitly passed in]
  2798. self.is_export = is_export
  2799. # Map from graph input name to its placeholder proxy object, where the
  2800. # map's keys give all current placeholder node names and can be used to
  2801. # create unique node names
  2802. self.input_name_to_proxy: dict[str, fx.Proxy] = {}
  2803. # Node => computed real value (see utils.get_real_value)
  2804. self.real_value_cache: dict[fx.Node, torch.Tensor] = {}
  2805. # SubgraphTracers can be nested. See NOTE [HigherOrderOperator tracing design]
  2806. self.parent = parent
  2807. self.source_target = source_target
  2808. self.description = description
  2809. # A dict mapping previously free variables (Proxy objects)
  2810. # to new Proxy objects that wrap inputs to this subgraph.
  2811. #
  2812. # This dict maps proxies in outer graphs to placeholders in current graph.
  2813. # It serves two purposes:
  2814. # - Proxies are associated with VariableTrackers. If we see
  2815. # the same VariableTracker twice (and it is a free variable),
  2816. # then we want to use the same Proxy in the current subgraph to
  2817. # record the tracing.
  2818. # - If we are tracing a HigherOrderOperator's body_fn, then we
  2819. # need to keep track of what free variables were lifted so we can
  2820. # rewrite the HigherOrderOperator call using the traced body_fn.
  2821. # Dicts maintain the order of args for the HigherOrderOperator call.
  2822. self.lifted_freevars: dict[fx.Proxy, fx.Proxy] = {}
  2823. # map basic symbols (unbacked and unbacked) to their bound proxies.
  2824. # There are only two cases where bound_symbols will be recorded:
  2825. # 1. when we create_graph_input for a backed SymInt that's basic symbol
  2826. # 2. when we track_produced_symints for intermediate results
  2827. # bound_symbols always map the symbol to the proxy whose
  2828. # tracer is the current tracer that's readily accessible in current tracer's graph.
  2829. self.bound_symbols: dict[sympy.Symbol, Union[torch.fx.Proxy, LazyProxy]] = {}
  2830. # Maps _DynamicScalar object ids to allocated SymInt nodes, for symbol reuse
  2831. self.dynamic_scalar_nodes: dict[int, torch.SymInt] = {}
  2832. self.prev_inst = None
  2833. # True if we want to allow externally visible side-effects (doesn't throw error on their existence)
  2834. # during this tracer's tracing. This is currently only used by experimental AC out-of-tree
  2835. # via torch._dynamo.utils._disable_side_effect_safety_checks_for_current_subtracer.
  2836. # Note: Externally visible side-effects are allowed if this flag OR the above flag is True.
  2837. self.unsafe_allow_externally_visible_side_effects = False
  2838. self.traced_with_externally_visible_side_effects = False
  2839. # True if we want to allow side effects by returning them as extra outputs from the subgraph.
  2840. # This is set when enable_side_effects_in_hop=True for HOPs like invoke_subgraph
  2841. # and checkpoint (when skip_fwd_side_effects_in_bwd_under_checkpoint config is True).
  2842. self.allow_side_effects_in_hop = False
  2843. # True if this tracer is currently tracing (reconstructing) into a Python generator
  2844. self.is_reconstructing_generator = False
  2845. self.debug_level: int = parent.debug_level + 1 if parent is not None else 0
  2846. self._cur_code = None
  2847. self._orig_gm_meta: Optional[list[Any]] = None
  2848. self._orig_gm_lineno_map: Optional[dict[int, Optional[int]]] = None
  2849. self._orig_gm_firstlineno: Optional[int] = None
  2850. # Each SubgraphTracer is associated with a source target, which indicates
  2851. # which operator this subgraph is attached to. We compute a source_fn_stack
  2852. # based on the source target. For the root tracer, it's set to [].
  2853. # This is useful for debugging and transforming the exported graph.
  2854. if self.parent is None:
  2855. self.source_fn_stack: list[Any] = []
  2856. else:
  2857. self.source_fn_stack = self.parent.source_fn_stack + [
  2858. (self.graph._target_to_str(source_target), source_target)
  2859. ]
  2860. # This is used to create a unique name for the placeholder
  2861. self._used_names: OrderedSet[str] = OrderedSet()
  2862. # Stores the versions of the input tensors at the time they are inserted
  2863. # as placeholders in the graph. This is used to track input mutation.
  2864. self._input_versions_at_beginning: list[int] = []
  2865. if torch.is_inference_mode_enabled():
  2866. raise RuntimeError(
  2867. "Inference mode is supposed to be disabled during compilation. Please open an issue."
  2868. )
  2869. self.tracked_tensor_or_symint_vt: OrderedSet[VariableTracker] = OrderedSet()
  2870. def record_tensor_or_symint_vt(self, vt: VariableTracker) -> None:
  2871. self.tracked_tensor_or_symint_vt.add(vt)
  2872. # preserve original meta if it is available
  2873. def _maybe_preserve_original_meta(
  2874. self, tx: "InstructionTranslatorBase", node: fx.Node
  2875. ) -> None:
  2876. if (
  2877. self._orig_gm_meta
  2878. and self._orig_gm_lineno_map
  2879. and self._orig_gm_firstlineno
  2880. ):
  2881. lineno = tx.current_instruction.starts_line
  2882. node_idx = None
  2883. if lineno is not None:
  2884. node_idx = self._orig_gm_lineno_map.get(
  2885. lineno - self._orig_gm_firstlineno, None
  2886. )
  2887. if node_idx is not None:
  2888. meta = self._orig_gm_meta[node_idx]
  2889. for field in fx.proxy._COPY_META_FIELDS:
  2890. if field in meta:
  2891. node.meta[field] = meta[field]
  2892. if "stack_trace" in meta:
  2893. node.meta["stack_trace"] = meta["stack_trace"]
  2894. def create_proxy(
  2895. self,
  2896. kind: str,
  2897. target: Any,
  2898. args: Any,
  2899. kwargs: Any,
  2900. name: Optional[str] = None,
  2901. type_expr: Optional[Any] = None,
  2902. proxy_factory_fn: Optional[Callable[[fx.Node], fx.Proxy]] = None,
  2903. ) -> fx.Proxy:
  2904. _t0 = time.time_ns()
  2905. try:
  2906. return self._create_proxy(
  2907. kind, target, args, kwargs, name, type_expr, proxy_factory_fn
  2908. )
  2909. finally:
  2910. self.output_graph.bytecode_tracing_timings.create_proxy_ns += (
  2911. time.time_ns() - _t0
  2912. )
  2913. def _create_proxy(
  2914. self,
  2915. kind: str,
  2916. target: Any,
  2917. args: Any,
  2918. kwargs: Any,
  2919. name: Optional[str] = None,
  2920. type_expr: Optional[Any] = None,
  2921. proxy_factory_fn: Optional[Callable[[fx.Node], fx.Proxy]] = None,
  2922. ) -> fx.Proxy:
  2923. # NOTE: [Nested SubgraphTracer and free_variable handling]
  2924. # --------------------------------------------------------
  2925. # Read NOTE [HigherOrderOperator tracing design] first.
  2926. #
  2927. # Let's say we're in the middle of introspecting the body of a possibly
  2928. # nested HigherOrderOperator, and we see a free variable.
  2929. #
  2930. # There are two cases:
  2931. # 1. We see a free variable that is already tracked by Dynamo.
  2932. # 2. We see a free variable that has not been tracked by Dynamo
  2933. #
  2934. # In case 1, we call `maybe_lift_tracked_freevar_to_input` (below)
  2935. # which will lift the freevar to be an input of this subgraph
  2936. # and also recursively lift it to be an input on the parent(s).
  2937. #
  2938. # In case 2, before the call to `create_proxy`, the InstructionTranslator
  2939. # will see the freevar when it gets loaded by Python bytecode.
  2940. # E.g. for Python 3.11 the bytecodes that may do this are LOAD_DEREF or
  2941. # LOAD_GLOBAL.
  2942. # There, the InstructionTranslator asks Dynamo to begin tracking the
  2943. # freevar by building a new Variable.
  2944. # Building a new Variable automatically lifts the freevar to be an
  2945. # input of the root SubgraphTracer.
  2946. #
  2947. # The implications for the code below are:
  2948. # - We will always be in Case 1 when we get to this code.
  2949. # - Any "free variable" we encounter here is guaranteed to already be
  2950. # bound, that is, it is either a graph input of the root graph, or
  2951. # some local variable of the root graph or a subgraph.
  2952. # - The additional work we need to do here is *only* that we need to
  2953. # lift this free variable into inputs (recursively) of each nested
  2954. # higher-order-op subgraph until we hit the subgraph where the free
  2955. # variable is bound
  2956. if self.parent is not None:
  2957. flat_args, tree_spec = pytree.tree_flatten((args, kwargs))
  2958. new_flat_args = []
  2959. for arg in flat_args:
  2960. maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg)
  2961. new_flat_args.append(maybe_new_arg)
  2962. args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec)
  2963. rv = super().create_proxy(
  2964. kind,
  2965. target,
  2966. args,
  2967. kwargs,
  2968. name,
  2969. type_expr,
  2970. proxy_factory_fn, # type: ignore[arg-type]
  2971. )
  2972. # append stack trace to fx node
  2973. tx = self.output_graph.current_tx
  2974. # log detailed location of line of code in 3.11
  2975. if sys.version_info >= (3, 11) and kind in (
  2976. "call_function",
  2977. "call_method",
  2978. "call_module",
  2979. ):
  2980. cur_inst = tx.current_instruction
  2981. if (
  2982. cur_inst is not self.prev_inst
  2983. and cur_inst.positions is not None
  2984. and cur_inst.positions.lineno is not None
  2985. ):
  2986. tx_code = tx.f_code
  2987. header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)
  2988. def get_trace_call_log_str() -> str:
  2989. line = get_instruction_source_311(tx_code, cur_inst).rstrip()
  2990. return f"TRACE FX call {rv.node.name} from {header}\n{line}"
  2991. trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
  2992. self.prev_inst = cur_inst
  2993. # update reference to original meta if we're tracing a new code object
  2994. is_retracing = False
  2995. if tx.f_code is not self._cur_code:
  2996. orig_graphmodule_maybe = code_context.get_context(tx.f_code).get(
  2997. "orig_graphmodule", lambda: None
  2998. )()
  2999. if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule):
  3000. is_retracing = True
  3001. self._orig_gm_meta = [
  3002. nd.meta for nd in orig_graphmodule_maybe.graph.nodes
  3003. ]
  3004. self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map
  3005. self._orig_gm_firstlineno = (
  3006. orig_graphmodule_maybe.forward.__code__.co_firstlineno
  3007. )
  3008. else:
  3009. self._orig_gm_meta = None
  3010. self._orig_gm_lineno_map = None
  3011. self._orig_gm_firstlineno = None
  3012. nn_module_stack = tx.nn_module_stack
  3013. if nn_module_stack:
  3014. rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
  3015. if kind in {"call_function", "call_method"}:
  3016. stack = (rv.node.name, target)
  3017. if nn_module_stack:
  3018. # Current codebase assumes that the nn_module_stack has the
  3019. # builtin modules in the stack.
  3020. current_nn_module = list(rv.node.meta["nn_module_stack"].values())[-1][
  3021. 1
  3022. ]
  3023. if current_nn_module.__module__.startswith(
  3024. ("torch.nn.modules", "torch.ao.")
  3025. ) and not current_nn_module.__module__.startswith(
  3026. "torch.nn.modules.container"
  3027. ):
  3028. stack = (rv.node.name, current_nn_module)
  3029. rv.node.meta["source_fn_stack"] = self.source_fn_stack + [stack]
  3030. elif kind == "call_module":
  3031. if self.parent is not None:
  3032. # TODO can remove once inline_inbuilt_nn_modules is always True
  3033. unimplemented(
  3034. gb_type="Invoking an nn.Module inside a higher order operator",
  3035. context=f"Higher order op name: {self.source_target}",
  3036. explanation="This is not supported.",
  3037. hints=[],
  3038. )
  3039. # For modules we store the class
  3040. rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
  3041. (
  3042. rv.node.name,
  3043. next(
  3044. ty
  3045. for k, (_, ty) in rv.node.meta["nn_module_stack"].items()
  3046. if k.split("@")[0] == target
  3047. ),
  3048. )
  3049. ]
  3050. self._maybe_preserve_original_meta(tx, rv.node)
  3051. if not is_retracing:
  3052. if "nn_module_stack" not in rv.node.meta:
  3053. nn_module_stack = tx.nn_module_stack
  3054. if nn_module_stack:
  3055. rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
  3056. if "source_fn_stack" not in rv.node.meta:
  3057. if kind in {"call_function", "call_method"}:
  3058. rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
  3059. (rv.node.name, target)
  3060. ]
  3061. elif kind == "call_module":
  3062. if self.parent is not None:
  3063. # TODO can remove once inline_inbuilt_nn_modules is always True
  3064. unimplemented(
  3065. gb_type="Invoking an nn.Module inside a HigherOrderOperator",
  3066. context="",
  3067. explanation="This is not supported.",
  3068. hints=[],
  3069. )
  3070. # For modules we store the class
  3071. rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
  3072. (
  3073. rv.node.name,
  3074. rv.node.meta["nn_module_stack"][target][1],
  3075. )
  3076. ]
  3077. if "stack_trace" not in rv.node.meta:
  3078. frame_summaries: list[traceback.FrameSummary] = []
  3079. while tx:
  3080. # Avoid frame summaries from inside the torch/nn/modules. This ensures that we keep the stack trace of
  3081. # the user code.
  3082. if not tx.is_co_filename_from_nn_modules():
  3083. frame_summaries.append(tx.frame_summary())
  3084. tx = getattr(tx, "parent", None)
  3085. filtered_frame_summaries = [
  3086. frame
  3087. for frame in frame_summaries
  3088. if frame.filename not in uninteresting_files()
  3089. ]
  3090. # Reverse the frame_summaries, such that the innermost frame is at the last
  3091. filtered_frame_summaries.reverse()
  3092. # official from_list stub doesn't have new-style type
  3093. msgs = traceback.StackSummary.from_list(filtered_frame_summaries).format()
  3094. rv.node.stack_trace = "".join(msgs)
  3095. if (
  3096. torch._dynamo.config.use_graph_deduplication
  3097. or torch._dynamo.config.track_nodes_for_deduplication
  3098. ):
  3099. self.output_graph.region_tracker.track_node(
  3100. self.output_graph.current_tx, rv.node
  3101. )
  3102. return rv
  3103. def create_node(
  3104. self,
  3105. kind: str,
  3106. target: Target,
  3107. args: Any = None,
  3108. kwargs: Any = None,
  3109. name: Optional[str] = None,
  3110. type_expr: Optional[Any] = None,
  3111. ) -> fx.Node:
  3112. check_pt2_compliant_op(self.output_graph, kind, target, args, kwargs)
  3113. if self.parent is not None:
  3114. flat_args = pytree.arg_tree_leaves(*args, **kwargs)
  3115. for arg in flat_args:
  3116. if not isinstance(arg, torch.fx.Node):
  3117. continue
  3118. assert arg.graph == self.graph, (
  3119. "create_node using arg not from this SubgraphTracer"
  3120. )
  3121. node = super().create_node(kind, target, args, kwargs, name, type_expr)
  3122. node.meta["creation_timestamp"] = self.output_graph.timestamp
  3123. self._used_names.add(node.name)
  3124. return node
  3125. # Note: we did not override erase_node since
  3126. # we call self.graph.erase_node elsewhere
  3127. def remove_node(self, node: fx.Node) -> None:
  3128. if len(node.users) > 0:
  3129. user_graph_nodes: list[torch.fx.Node] = []
  3130. for user in node.users:
  3131. # For the case where user.graph == self.graph, that is a real bug and will raise
  3132. # properly.
  3133. if user.graph != self.graph:
  3134. # This is a nested graph, which needs to be deleted.
  3135. # If we do not do this, we will raise on attempting to remove this.
  3136. # As we only get here during restoration cleanup, this is sound.
  3137. user_graph_nodes.extend(reversed(list(user.graph.nodes)))
  3138. for other_graph_node in user_graph_nodes:
  3139. other_graph_node.graph.erase_node(other_graph_node)
  3140. self.graph.erase_node(node)
  3141. self.input_name_to_proxy.pop(node.name, None)
  3142. # when before=True, we will insert this input before the most recent
  3143. # inserted proxy. This is a hack to get around an ordering problem,
  3144. # where we first insert a tensor argument, and then insert bindings
  3145. # for SymInts that may occur in the tensor argument.
  3146. # Remove this if https://github.com/pytorch/pytorch/issues/99007 gets
  3147. # fixed.
  3148. def create_graph_input(
  3149. self,
  3150. name: str,
  3151. type_expr: Any,
  3152. example_value: Any,
  3153. before: bool = False,
  3154. source: Optional[Source] = None,
  3155. ) -> fx.Proxy:
  3156. if isinstance(example_value, torch.Tensor):
  3157. self._input_versions_at_beginning.append(example_value._version)
  3158. log.debug(
  3159. "create_graph_input %s %s %s at debug_level %s before=%s",
  3160. name,
  3161. source.name if source is not None else "(none)",
  3162. example_value,
  3163. self.debug_level,
  3164. before,
  3165. )
  3166. if source is None:
  3167. assert self.parent is not None, (
  3168. f"you are required to provide a source for inputs {name} example_val {example_value} on the root tracer"
  3169. )
  3170. # Note [Export inputs must be explicitly passed in]
  3171. # In eager, we are generally OK with adding graph inputs whenever we
  3172. # want, because we take care of writing the bytecode that knows how
  3173. # to source all the inputs.
  3174. #
  3175. # In export, this is bad, because you want a self-contained export
  3176. # object which only depends on the inputs you explicitly passed to it.
  3177. # So we are a bit more strict about what sources can become inputs
  3178. # in export
  3179. if self.is_export and self.parent is None:
  3180. assert source is not None
  3181. if not is_from_local_source(source, only_allow_input=True):
  3182. self.output_graph.source_to_user_stacks.setdefault(source, []).append(
  3183. TracingContext.extract_stack()
  3184. )
  3185. # _used_names contains the names of all the nodes in the graph,
  3186. # including intermediates. This ensures that we do not have a name
  3187. # collision.
  3188. name = get_unique_name_wrt(name, self._used_names)
  3189. if self.input_name_to_proxy:
  3190. prev_name = next(reversed(self.input_name_to_proxy))
  3191. node = self.input_name_to_proxy[prev_name].node
  3192. if before:
  3193. ctx = self.graph.inserting_before(node)
  3194. else:
  3195. ctx = self.graph.inserting_after(node)
  3196. else:
  3197. ctx = self.graph.inserting_before(None)
  3198. with ctx:
  3199. proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr)
  3200. set_example_value(proxy.node, example_value)
  3201. if self.input_name_to_proxy and before:
  3202. k, v = self.input_name_to_proxy.popitem()
  3203. self.input_name_to_proxy[name] = proxy
  3204. self.input_name_to_proxy[k] = v
  3205. else:
  3206. self.input_name_to_proxy[name] = proxy
  3207. # For placeholder nodes, `name` is passed as a str to the target,
  3208. # and then torch.fx decides the node.name. So, record the `target`
  3209. # name as well in the _used_names to prevent any collision.
  3210. self._used_names.add(name)
  3211. # NOTE: [Auto lift basic free symbols when create_graph_input]
  3212. # There are two sources of basic symbols:
  3213. #
  3214. # - They can come from inputs, e.g. when an input tensor is specified as dynamic. We handle
  3215. # this case by intercepting at create_graph_input. Whenever we call create_graph_input, we
  3216. # try to also lift the basic symbols in example values as graph input.
  3217. #
  3218. # 1. When create_graph_input for a tensor that has symbolic shapes,
  3219. # we look for basic symbols in its size and stride, we check if the symbol is bound
  3220. # in current graph (i.e. bound_symbols), it it's not bound, we'll create a placeholder
  3221. # for it then recursively check its parent, creates ph if not bound at parent until.
  3222. # reachting the top-level, where we require a source is attached to the proxy.
  3223. #
  3224. # 2. When create_graph_input for a tensor that contains compound exprs,
  3225. # for example, if an input to subgraph takes size [s1+s2//8], we'll look for the
  3226. # the free basic symbols in the sizes and lift all of them following 1.
  3227. #
  3228. # 3. When create_graph_input for a symint. The following invariants hold:
  3229. # a. if symint's expr is a basic symbol, we only lift it once.
  3230. # b. if symint's expr is compuned, we lift the expr as a single input. We won't lift The basic symbols
  3231. # in the compuned expr are NOT lifted. Because if the basic symbols are used inside the subgraph
  3232. # they will be lifted according to 3.a
  3233. #
  3234. # - They can come from intermediate results:
  3235. # For example, data-dependent operators such as t.item(), t.nonzero(), where basic symbols
  3236. # might be created. For this purpose, we track the basic symbols of intermediate results
  3237. # immediately after they're created at wrap_fx_proxy with track_produced_symints. Notice
  3238. # that for basic symbols that're already tracked by create_graph_input, we won't track it again.
  3239. #
  3240. # Also see NOTE: [Export inputs must be explicitly passed in]
  3241. is_strict_export = self.is_export
  3242. is_non_strict_export = torch.compiler.is_compiling()
  3243. if not is_strict_export and not is_non_strict_export:
  3244. if isinstance(example_value, torch.Tensor):
  3245. self._lift_basic_symbols(example_value, source)
  3246. elif isinstance(example_value, (list, tuple)):
  3247. for i, e in enumerate(example_value):
  3248. if not isinstance(e, torch.Tensor):
  3249. continue
  3250. e_source = None
  3251. if source:
  3252. e_source = GetItemSource(
  3253. base=source,
  3254. index=i,
  3255. index_is_slice=False,
  3256. )
  3257. self._lift_basic_symbols(e, e_source)
  3258. # Bound the symbol to ph if example_value is a SymInt with basic symbol.
  3259. if isinstance(example_value, torch.SymInt) and isinstance(
  3260. example_value.node.expr, sympy.Symbol
  3261. ):
  3262. self.bound_symbols[example_value.node.expr] = proxy
  3263. return proxy
  3264. # See NOTE: [Nested SubgraphTracer and free_variable handling] for more details
  3265. def lift_tracked_freevar_to_input(
  3266. self, proxy: fx.Proxy
  3267. ) -> Union[LazyProxy, fx.Proxy]:
  3268. # You're doing something wrong if we are the root SubgraphTracer because
  3269. # Dynamo adds tensors to graph inputs before creating a proxy for them.
  3270. assert self.parent is not None, (
  3271. "lift_tracked_freevar_to_input should not be called on root SubgraphTracer"
  3272. )
  3273. example_value = proxy.node.meta["example_value"]
  3274. # To avoid lifting the same symbol twice, we check whether basic symbols has been tracked.
  3275. # For example, the basic symbols may have already been lifted for current subgraph when
  3276. # we automatically lift basic symbols in the sizes/strides of a tensor t.
  3277. # Suppose parent graph calls sz = t.size()[0], it creates
  3278. # a proxy in parent and the subgraph accesses sz via closure. sz's proxy is not tracked
  3279. # in current sub-tracer so we may lift the same symbol twice.
  3280. if (
  3281. isinstance(example_value, torch.SymInt)
  3282. and example_value.node.expr in self.bound_symbols
  3283. ):
  3284. return self.bound_symbols[example_value.node.expr]
  3285. # Proxies are associated with VariableTracker.
  3286. # It is possible that we've already lifted the Proxy to be an input.
  3287. # If that is the case, just return the already lifted Proxy.
  3288. if proxy in self.lifted_freevars:
  3289. return self.lifted_freevars[proxy]
  3290. # We first lift proxy to parent's graph then lift to current graph's input
  3291. # so that when we bind symints of the sizes in current graph, those symints
  3292. # would already be lifted as inputs to parent graph.
  3293. if proxy.tracer != self.parent:
  3294. self.parent.lift_tracked_freevar_to_input(proxy)
  3295. example_value = proxy.node.meta["example_value"]
  3296. type_expr = (
  3297. type(example_value.real_obj)
  3298. if isinstance(example_value, FakeScriptObject)
  3299. else type(example_value)
  3300. )
  3301. new_proxy = self.create_graph_input(proxy.node.name, type_expr, example_value)
  3302. self.lifted_freevars[proxy] = new_proxy
  3303. return new_proxy
  3304. def maybe_lift_tracked_freevar_to_input(self, arg: Any) -> Any:
  3305. """
  3306. If arg is a free variable, then lift it to be an input.
  3307. Returns the new lifted arg (if arg was a freevar), else the
  3308. original arg.
  3309. """
  3310. if not isinstance(arg, torch.fx.Proxy):
  3311. # Note: arg can be a python built-in slice type e.g.
  3312. # x[:max_seq] is represented as get_item(t, (slice(None, max_seq, None)))
  3313. # we need to also look into the slice variable itself to lift the
  3314. # proxies there.
  3315. if isinstance(arg, slice):
  3316. return slice(
  3317. *(
  3318. self.maybe_lift_tracked_freevar_to_input(sub_arg)
  3319. for sub_arg in (arg.start, arg.stop, arg.step)
  3320. )
  3321. )
  3322. else:
  3323. return arg
  3324. elif arg.tracer == self:
  3325. return arg
  3326. return self.lift_tracked_freevar_to_input(arg)
  3327. # See NOTE: [Auto lift basic free symbols when create_graph_input] for overall design
  3328. # You MUST call this API every time when creating a proxy in wrap_fx_proxy for a call
  3329. # that produced symints or tensors with unbacked symint shapes.
  3330. # This function is used to track the symints with its proxies created during
  3331. # dynamo tracing so that subgraph knows how to bind a symbol input with parent's proxy.
  3332. # LazyProxy are created for tensor shapes that're unbacked so that we don't create proxies
  3333. # for symbols that're not going to be used, the LazyProxy will be turned into a proxy
  3334. # when it's lifted as input to subgraph.
  3335. def track_produced_symints(
  3336. self, example_value: Any, e_proxy: Union[LazyProxy, torch.fx.Proxy]
  3337. ) -> None:
  3338. # When binding the symbols in an example_value, we bind the symbols
  3339. # to the proxy's associated Tracer instead of current tracer.
  3340. # This is because:
  3341. # 1. We may be calling wrap_tensors during speculate_subgraph because
  3342. # the variables are lazily realized. The proxy are top-level phs but
  3343. # current tracer is a subtracer.
  3344. # 2. For autograd.Function, we trace the backward graph with a new tracer
  3345. # whose parent is the forward tracer, but we're using all the proxies created
  3346. # in forward tracer to trace the backward.
  3347. # For example, forward calls save_for_backward for a input tensor t.
  3348. # Backward calls t.tolist(). In this case, all the proxies that backward tracer
  3349. # sees are from parent tracer (i.e. the forward tracer). (e.g. t[0].item())
  3350. # See test_validate_outputs_unbacked for repro on 2.
  3351. tracer = e_proxy.tracer
  3352. assert isinstance(tracer, SubgraphTracer)
  3353. def need_bind(s: Any) -> bool:
  3354. from torch.fx.experimental.symbolic_shapes import is_symbolic
  3355. return (
  3356. is_symbolic(s)
  3357. and isinstance(s.node.expr, sympy.Symbol)
  3358. and s.node.expr not in self.bound_symbols
  3359. )
  3360. def _proxy_with_example_value(
  3361. example_value: Any, *args: Any, **kwargs: Any
  3362. ) -> fx.Proxy:
  3363. # We need to insert proxy for creating sym_size/sym_stride/sym_storage right after e_proxy
  3364. nonlocal e_proxy
  3365. e_proxy = e_proxy() if isinstance(e_proxy, LazyProxy) else e_proxy
  3366. assert isinstance(e_proxy, torch.fx.Proxy)
  3367. with tracer.graph.inserting_after(e_proxy.node):
  3368. proxy = tracer.create_proxy(*args, **kwargs)
  3369. set_example_value(proxy.node, example_value)
  3370. return proxy
  3371. if isinstance(example_value, torch.Tensor):
  3372. for i, s in enumerate(example_value.size()):
  3373. if need_bind(s):
  3374. log.debug(
  3375. "track_produced_symints %s for %s.size()[%s] at debug_level %s",
  3376. s,
  3377. e_proxy,
  3378. i,
  3379. tracer.debug_level,
  3380. )
  3381. lazy_proxy = LazyProxy(
  3382. tracer,
  3383. _proxy_with_example_value,
  3384. s,
  3385. "call_function",
  3386. torch.ops.aten.sym_size.int,
  3387. (e_proxy, i),
  3388. {},
  3389. type_expr=type(s),
  3390. )
  3391. self.track_produced_symints(s, lazy_proxy)
  3392. storage_offset = example_value.storage_offset()
  3393. if need_bind(storage_offset):
  3394. log.debug(
  3395. "track_produced_symints %s for %s.storage_offset() at debug_level %s",
  3396. storage_offset,
  3397. e_proxy,
  3398. tracer.debug_level,
  3399. )
  3400. lazy_proxy = LazyProxy(
  3401. tracer,
  3402. _proxy_with_example_value,
  3403. storage_offset,
  3404. "call_function",
  3405. torch.ops.aten.sym_storage_offset,
  3406. (e_proxy,),
  3407. {},
  3408. type_expr=type(storage_offset),
  3409. )
  3410. self.track_produced_symints(storage_offset, lazy_proxy)
  3411. if example_value.layout is torch.strided:
  3412. for i, s in enumerate(example_value.stride()):
  3413. if need_bind(s):
  3414. log.debug(
  3415. "track_produced_symints %s for %s.stride()[%s] at debug_level %s",
  3416. s,
  3417. e_proxy,
  3418. i,
  3419. tracer.debug_level,
  3420. )
  3421. lazy_proxy = LazyProxy(
  3422. tracer,
  3423. _proxy_with_example_value,
  3424. s,
  3425. "call_function",
  3426. torch.ops.aten.sym_stride.int,
  3427. (e_proxy, i),
  3428. {},
  3429. type_expr=type(s),
  3430. )
  3431. self.track_produced_symints(s, lazy_proxy)
  3432. elif example_value.layout is torch.sparse_coo:
  3433. self.track_produced_symints(example_value._indices(), e_proxy)
  3434. self.track_produced_symints(example_value._values(), e_proxy)
  3435. elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
  3436. self.track_produced_symints(example_value.crow_indices(), e_proxy)
  3437. self.track_produced_symints(example_value.col_indices(), e_proxy)
  3438. elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
  3439. self.track_produced_symints(example_value.ccol_indices(), e_proxy)
  3440. self.track_produced_symints(example_value.row_indices(), e_proxy)
  3441. if is_traceable_wrapper_subclass(example_value):
  3442. attrs, ctx = example_value.__tensor_flatten__()
  3443. for attr in attrs:
  3444. inner_t = getattr(example_value, attr)
  3445. self.track_produced_symints(inner_t, getattr(e_proxy, attr))
  3446. elif isinstance(example_value, torch.SymInt):
  3447. if need_bind(example_value):
  3448. expr = example_value.node.expr
  3449. tracer.bound_symbols[expr] = e_proxy
  3450. # See Note [Auto lift basic free symbols when create_graph_input]
  3451. def _lift_basic_symbols(
  3452. self, example_value: Union[torch.SymInt, torch.Tensor], src: Optional[Source]
  3453. ) -> None:
  3454. # The before arg is for inserting symints in the sizes/strides of a tensor
  3455. # before the tensor. This ordering ensures that when we look at the tensor's
  3456. # symbols, they're already lifted/tracked. E.g. this assumption is used
  3457. # in insert_deferred_runtime_asserts.
  3458. def _lift_symbols_in_symint(
  3459. s: Union[int, torch.SymInt],
  3460. source: Optional[Source],
  3461. before: bool = False,
  3462. ) -> None:
  3463. if not is_symbolic(s):
  3464. return
  3465. assert isinstance(s, torch.SymInt)
  3466. self_to_be_bound = self.lookup_unbound_symbols(s)
  3467. if len(self_to_be_bound) == 0:
  3468. return
  3469. # For subgraph
  3470. if self.parent is not None:
  3471. # Recursively lift symbols in symint until top-level.
  3472. self.parent._lift_basic_symbols(s, source)
  3473. for s0 in self_to_be_bound:
  3474. parent_proxy = self.parent.bound_symbols[s0]
  3475. example_val = parent_proxy.node.meta["example_value"] # type: ignore[union-attr]
  3476. assert isinstance(example_val, torch.SymInt)
  3477. ph = self.create_graph_input(
  3478. str(s0),
  3479. type(example_val),
  3480. example_val,
  3481. before=before,
  3482. source=source,
  3483. )
  3484. log.debug(
  3485. "_lift_symbols_in_symint %s from %s at debug_level %s",
  3486. s0,
  3487. source.name if source is not None else "subgraph inputs",
  3488. self.debug_level,
  3489. )
  3490. self.lifted_freevars[parent_proxy] = ph # type: ignore[index]
  3491. # For root_tracer:
  3492. else:
  3493. assert len(self_to_be_bound) == 1, (
  3494. f"For root tracer, we only expect to bind basic symbols (compound symbols "
  3495. f"should be cached before) but got unbound symbols {self_to_be_bound} in {s}"
  3496. )
  3497. assert source is not None, (
  3498. f"Source of '{s}' is None when lifting it to input of top-level. If it's an unbacked symbol, "
  3499. "this could be because it's not tracked with lazy_bind_unbacked_symbols. "
  3500. f"Otherwise, should provide a source when create_graph_input for `{s}` at root tracer."
  3501. )
  3502. s0 = next(iter(self_to_be_bound))
  3503. ph = self.create_graph_input(
  3504. str(s0),
  3505. type(s),
  3506. s,
  3507. before=before,
  3508. source=source,
  3509. )
  3510. log.debug(
  3511. "_lift_symbols_in_symint %s from %s at debug_level %s",
  3512. s,
  3513. source.name if source is not None else "subgraph inputs",
  3514. self.debug_level,
  3515. )
  3516. ph.node.meta["grapharg"] = GraphArg(
  3517. source,
  3518. s,
  3519. pass_arg_as_tensor=False,
  3520. fake_tensor=None,
  3521. is_tensor=False,
  3522. )
  3523. if isinstance(example_value, torch.Tensor):
  3524. for i, s in enumerate(example_value.size()):
  3525. _lift_symbols_in_symint(
  3526. s,
  3527. (
  3528. TensorPropertySource(src, TensorProperty.SIZE, i)
  3529. if src is not None
  3530. else None
  3531. ),
  3532. before=True,
  3533. )
  3534. if example_value.layout is torch.strided:
  3535. for i, s in enumerate(example_value.stride()):
  3536. _lift_symbols_in_symint(
  3537. s,
  3538. (
  3539. TensorPropertySource(src, TensorProperty.STRIDE, i)
  3540. if src is not None
  3541. else None
  3542. ),
  3543. before=True,
  3544. )
  3545. _lift_symbols_in_symint(
  3546. example_value.storage_offset(),
  3547. (
  3548. TensorPropertySource(src, TensorProperty.STORAGE_OFFSET)
  3549. if src is not None
  3550. else None
  3551. ),
  3552. before=True,
  3553. )
  3554. elif example_value.layout is torch.sparse_coo:
  3555. self._lift_basic_symbols(example_value._indices(), src)
  3556. self._lift_basic_symbols(example_value._values(), src)
  3557. elif example_value.layout in {torch.sparse_csr, torch.sparse_bsr}:
  3558. self._lift_basic_symbols(example_value.crow_indices(), src)
  3559. self._lift_basic_symbols(example_value.col_indices(), src)
  3560. elif example_value.layout in {torch.sparse_csc, torch.sparse_bsc}:
  3561. self._lift_basic_symbols(example_value.ccol_indices(), src)
  3562. self._lift_basic_symbols(example_value.row_indices(), src)
  3563. if is_traceable_wrapper_subclass(example_value):
  3564. attrs, ctx = example_value.__tensor_flatten__()
  3565. for attr in attrs:
  3566. inner_t = getattr(example_value, attr)
  3567. self._lift_basic_symbols(
  3568. inner_t, AttrSource(src, attr) if src is not None else None
  3569. )
  3570. elif isinstance(example_value, torch.SymInt):
  3571. _lift_symbols_in_symint(
  3572. example_value,
  3573. src,
  3574. )
  3575. # Lookup the proxy in current tracer for each symbol in expressions of s,
  3576. # See Note [Auto lift basic free symbols when create_graph_input]
  3577. def lookup_unbound_symbols(self, s: torch.SymInt) -> list[sympy.Symbol]:
  3578. free_symbols = s.node.expr.free_symbols
  3579. if len(free_symbols) == 0:
  3580. return []
  3581. to_be_bound = []
  3582. for s0 in free_symbols:
  3583. if s0 not in self.bound_symbols:
  3584. to_be_bound.append(s0)
  3585. continue
  3586. proxy = self.bound_symbols[s0]
  3587. if isinstance(proxy, LazyProxy):
  3588. proxy = proxy()
  3589. self.bound_symbols[s0] = proxy
  3590. assert isinstance(proxy, torch.fx.Proxy) and proxy.tracer is self, (
  3591. f"The proxy of symbol {s0} doesn't belong to current tracer."
  3592. )
  3593. # Sort the symbols so that we can have a deterministic lifting order
  3594. return sorted(to_be_bound, key=lambda s: s.name)
  3595. def has_input_mutation(self) -> MutationInfo:
  3596. input_versions_at_beginning = self._input_versions_at_beginning
  3597. input_nodes = []
  3598. input_versions_at_end = []
  3599. for node in self.graph.nodes:
  3600. if node.op == "placeholder":
  3601. example_value = node.meta["example_value"]
  3602. if isinstance(example_value, torch.Tensor):
  3603. input_versions_at_end.append(example_value._version)
  3604. input_nodes.append(node)
  3605. else:
  3606. break
  3607. mutated_inputs = [
  3608. i
  3609. for i, (v1, v2) in enumerate(
  3610. zip(input_versions_at_beginning, input_versions_at_end)
  3611. )
  3612. if v1 != v2
  3613. ]
  3614. if mutated_inputs:
  3615. mutated_nodes = [input_nodes[i] for i in mutated_inputs]
  3616. msg = f"Input mutation detected at {mutated_nodes}"
  3617. return MutationInfo(True, msg)
  3618. return MutationInfo(False, "")
  3619. def has_aliasing(self) -> AliasingInfo:
  3620. from torch._dynamo.variables.higher_order_ops import get_tensor_storages
  3621. from torch._higher_order_ops.utils import _collect_fake_inputs
  3622. input_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
  3623. for node in self.graph.nodes:
  3624. if node.op == "placeholder":
  3625. example_value = _collect_fake_inputs([node])[0]
  3626. if isinstance(example_value, torch.Tensor):
  3627. for storage in get_tensor_storages(example_value):
  3628. if storage in input_storages:
  3629. # input-input aliasing
  3630. msg = f"Input-to-input aliasing detected at nodes {input_storages[storage]} and {node}"
  3631. return AliasingInfo(True, msg)
  3632. input_storages[storage] = node
  3633. else:
  3634. break
  3635. output_storages: dict[StorageWeakRef, torch.fx.Node] = dict()
  3636. out_nodes = self.graph.find_nodes(op="output")[0]
  3637. for out_node in pytree.tree_leaves(out_nodes.args[0]):
  3638. if out_node:
  3639. example_value = _collect_fake_inputs([out_node])[0]
  3640. assert not isinstance(example_value, list)
  3641. if isinstance(example_value, torch.Tensor):
  3642. for storage in get_tensor_storages(example_value):
  3643. if storage in output_storages:
  3644. # output-output aliasing
  3645. msg = f"Output-to-output aliasing detected at nodes {output_storages[storage]} and {out_node}"
  3646. return AliasingInfo(True, msg)
  3647. output_storages[storage] = out_node
  3648. intersected_storages = input_storages.keys() & output_storages.keys()
  3649. if len(intersected_storages) > 0:
  3650. # input-output aliasing
  3651. aliased = [
  3652. (input_storages[s], output_storages[s]) for s in intersected_storages
  3653. ]
  3654. aliased = ", ".join([f"{i} and {o}" for i, o in aliased])
  3655. msg = f"Input-to-output aliasing detected at nodes {aliased}"
  3656. return AliasingInfo(True, msg)
  3657. return AliasingInfo(False, "")
  3658. # NOTE: [HigherOrderOperator tracing design]
  3659. # Ignoring HigherOrderOperators for a moment,
  3660. # OutputGraph represents the graph being built by Dynamo that may be compiled
  3661. # and executed. It holds a root SubgraphTracer where the FX graph is built.
  3662. #
  3663. # HigherOrderOperators are operators that take functions as their arguments.
  3664. # When Dynamo encounters a HigherOrderOperator, then it attempts to introspect
  3665. # the function passed to it (call this the "body function"), capture it into a
  3666. # GraphModule, and rewrite the call to the HigherOrderOperator to use the
  3667. # GraphModule.
  3668. #
  3669. # The way we handle the capture of body functions is through having
  3670. # (possibly nested) SubgraphTracers, one per body function.
  3671. #
  3672. # Mechanically, we do the introspection by:
  3673. # - Creating a new SubgraphTracer via OutputGraph.subtracer
  3674. # - Executing the body function.
  3675. # This constructs the graph of the body function in the new SubgraphTracer
  3676. # while modifying the state of the OutputGraph. For example:
  3677. # - the OutputGraph can receive new GraphArgs (if we discover any new
  3678. # untracked Tensors)
  3679. # - side effects from the body function get accumulated into
  3680. # OutputGraph.side_effects
  3681. # - guards produced by the body function get accumulated into OutputGraph.guards
  3682. #
  3683. # The traced function has some special properties that make it easier for us
  3684. # to transform later down the line:
  3685. # - we lift all free variables to being inputs.
  3686. #
  3687. # If the introspection fails (due to the existence of graph breaks), then
  3688. # we roll back the current OutputGraph state and graph break on the
  3689. # HigherOrderOperator.