guards.py 193 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
7477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855485648574858485948604861486248634864486548664867486848694870487148724873487448754876487748784879488048814882488348844885488648874888488948904891
  1. """
  2. Core guard system for Dynamo that detects when compiled code needs to be recompiled due to
  3. changes in program state. Guards are conditions that must remain true for previously-compiled
  4. code to be valid for reuse.
  5. This module provides the infrastructure for creating, managing and checking guards, including:
  6. - Guard creation and composition
  7. - Guard state management and invalidation
  8. - Guard checking and failure handling
  9. - Utilities for guard optimization and debugging
  10. - Integration with Dynamo's compilation caching
  11. The guard system is critical for Dynamo's ability to efficiently reuse compiled code while
  12. maintaining correctness by detecting when recompilation is necessary due to changes in
  13. program state, tensor properties, or control flow.
  14. """
  15. from __future__ import annotations
  16. import ast
  17. import builtins
  18. import collections
  19. import dataclasses
  20. import enum
  21. import functools
  22. import importlib
  23. import inspect
  24. import io
  25. import logging
  26. import math
  27. import pickle
  28. import sys
  29. import textwrap
  30. import traceback
  31. import types
  32. import warnings
  33. import weakref
  34. from contextlib import contextmanager
  35. from copy import deepcopy
  36. from inspect import currentframe
  37. from typing import Any, NamedTuple, NoReturn, Optional, TYPE_CHECKING, Union
  38. from typing_extensions import LiteralString, TypeAliasType, TypeVar
  39. from weakref import ReferenceType
  40. import torch
  41. import torch.overrides
  42. import torch.utils._device
  43. from torch._C._dynamo.eval_frame import code_framelocals_names
  44. from torch._C._dynamo.guards import (
  45. check_obj_id,
  46. check_type_id,
  47. ClosureGuardAccessor,
  48. CodeGuardAccessor,
  49. dict_version,
  50. DictGetItemGuardAccessor,
  51. DictGuardManager,
  52. FuncDefaultsGuardAccessor,
  53. FuncKwDefaultsGuardAccessor,
  54. GetAttrGuardAccessor,
  55. GetGenericDictGuardAccessor,
  56. GuardAccessor,
  57. GuardDebugInfo,
  58. GuardManager,
  59. install_no_tensor_aliasing_guard,
  60. install_object_aliasing_guard,
  61. install_storage_overlapping_guard,
  62. install_symbolic_shape_guard,
  63. LeafGuard,
  64. profile_guard_manager,
  65. RelationalGuard,
  66. RootGuardManager,
  67. TupleGetItemGuardAccessor,
  68. TypeDictGuardAccessor,
  69. TypeGuardAccessor,
  70. TypeMROGuardAccessor,
  71. )
  72. from torch._dynamo.source import (
  73. get_global_source_name,
  74. get_local_source_name,
  75. IndexedSource,
  76. is_from_flatten_script_object_source,
  77. is_from_local_source,
  78. is_from_optimizer_source,
  79. is_from_skip_guard_source,
  80. is_from_unspecialized_builtin_nn_module_source,
  81. TensorProperty,
  82. TensorPropertySource,
  83. )
  84. from torch._dynamo.utils import CompileEventLogger, get_metrics_context
  85. from torch._guards import (
  86. CompileContext,
  87. CompileId,
  88. DuplicateInputs,
  89. Guard,
  90. GuardBuilderBase,
  91. GuardEnvExpr,
  92. GuardSource,
  93. Source,
  94. StorageOverlap,
  95. )
  96. from torch._inductor.utils import IndentedBuffer
  97. from torch._library.opaque_object import get_opaque_obj_info, is_opaque_value_type
  98. from torch._logging import structured
  99. from torch._utils_internal import justknobs_check
  100. from torch.fx.experimental.symbolic_shapes import (
  101. _CppShapeGuardsHelper,
  102. _ShapeGuardsHelper,
  103. EqualityConstraint,
  104. is_symbolic,
  105. SYMPY_INTERP,
  106. )
  107. from torch.utils import _pytree as pytree
  108. from torch.utils._ordered_set import OrderedSet
  109. from torch.utils._traceback import format_frame, report_compile_source_on_error
  110. from torch.utils.weak import TensorWeakRef
  111. from . import config, convert_frame, exc
  112. from .eval_frame import set_guard_error_hook
  113. from .source import (
  114. AttrProxySource,
  115. AttrSource,
  116. CallFunctionNoArgsSource,
  117. CallMethodItemSource,
  118. CellContentsSource,
  119. ChainedSource,
  120. ClosureSource,
  121. CodeSource,
  122. ConstantSource,
  123. ConstDictKeySource,
  124. CurrentStreamSource,
  125. DataclassFieldsSource,
  126. DefaultsSource,
  127. DictGetItemSource,
  128. DictSubclassGetItemSource,
  129. DynamicScalarSource,
  130. FlattenScriptObjectSource,
  131. FloatTensorSource,
  132. FSDPNNModuleSource,
  133. GenericAttrSource,
  134. GetItemSource,
  135. GlobalSource,
  136. GlobalStateSource,
  137. GlobalWeakRefSource,
  138. GradSource,
  139. ImportSource,
  140. ListGetItemSource,
  141. LocalSource,
  142. NamedTupleFieldsSource,
  143. NNModuleSource,
  144. NonSerializableSetGetItemSource,
  145. NumpyTensorSource,
  146. OptimizerSource,
  147. ScriptObjectQualifiedNameSource,
  148. ShapeEnvSource,
  149. SubclassAttrListSource,
  150. TorchFunctionModeStackSource,
  151. TupleIteratorGetItemSource,
  152. TypeDictSource,
  153. TypeMROSource,
  154. TypeSource,
  155. UnspecializedBuiltinNNModuleSource,
  156. UnspecializedNNModuleSource,
  157. UnspecializedParamBufferSource,
  158. WeakRefCallSource,
  159. )
  160. from .types import ( # noqa: F401
  161. CacheEntry,
  162. DynamoFrameType,
  163. ExtraState,
  164. GuardedCode,
  165. GuardFail,
  166. GuardFilterEntry,
  167. GuardFn,
  168. )
  169. from .utils import (
  170. builtin_dict_keys,
  171. common_constant_types,
  172. dataclass_fields,
  173. dict_keys,
  174. get_current_stream,
  175. get_custom_getattr,
  176. get_torch_function_mode_stack,
  177. get_torch_function_mode_stack_at,
  178. guard_failures,
  179. istype,
  180. key_is_id,
  181. key_to_id,
  182. normalize_range_iter,
  183. orig_code_map,
  184. tensor_always_has_static_shape,
  185. tuple_iterator_getitem,
  186. tuple_iterator_len,
  187. unpatched_nn_module_getattr,
  188. verify_guard_fn_signature,
  189. )
  190. if TYPE_CHECKING:
  191. from collections.abc import Callable
  192. guard_manager_testing_hook_fn: Optional[Callable[[Any, Any, Any], Any]] = None
  193. try:
  194. import numpy as np
  195. except ModuleNotFoundError:
  196. np = None # type: ignore[assignment]
  197. if TYPE_CHECKING:
  198. from collections.abc import Generator, KeysView, Sequence
  199. from sympy import Symbol
  200. from torch._C import DispatchKeySet
  201. from torch._dynamo.output_graph import OutputGraphCommon, OutputGraphGuardsState
  202. from torch._dynamo.package import SerializedCode
  203. T = TypeVar("T")
  204. log = logging.getLogger(__name__)
  205. guards_log = torch._logging.getArtifactLogger(__name__, "guards")
  206. recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
  207. recompiles_verbose_log = torch._logging.getArtifactLogger(
  208. __name__, "recompiles_verbose"
  209. )
  210. verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
  211. dunder_attrs_assumed_constants = (
  212. "__defaults__",
  213. "__kwdefaults__",
  214. "__code__",
  215. "__closure__",
  216. "__annotations__",
  217. "__func__",
  218. "__mro__",
  219. )
  def get_framelocals_idx(code: types.CodeType, var_name: str) -> int:
      """Return the slot index of ``var_name`` in the frame's localsplus for ``code``.

      The index refers to the frame's localsplus directly. The name order of a
      code object never changes, so the result is stable for a given ``code``.

      When the name occurs more than once, the LAST occurrence is returned. On
      Python <= 3.10 a name that is both a local and a cell takes up two
      localsplus slots, and the cell, which has the higher index, is the one
      that must be referenced.

      Raises:
          ValueError: (from ``.index``) if ``var_name`` is not one of the
              code's frame-local names.
      """
      # The cached helper yields the names in reverse order. The first hit in
      # that sequence is therefore the last occurrence in forward order.
      names_reversed = code_framelocals_names_reversed_cached(code)
      offset_from_end = names_reversed.index(var_name)
      return len(names_reversed) - 1 - offset_from_end
  class IndentedBufferWithPrefix(IndentedBuffer):
      """An ``IndentedBuffer`` that renders its contents as an ASCII tree.

      ``prefix()`` draws indentation as repeated ``"| "`` rails instead of
      spaces. This assumes the base class prepends ``prefix()`` to each line
      it writes. Every line written through ``writeline`` also gets a
      ``"+- "`` branch marker, unless ``skip_prefix`` is true.
      """

      def prefix(self) -> str:
          # One "| " rail per indentation column (indent level * tabwidth).
          columns = self._indent * self.tabwidth
          return "| " * columns

      def writeline(self, line: str, skip_prefix: bool = False) -> None:  # type: ignore[override]
          # Plain concatenation, not an f-string, so non-str input behaves
          # exactly as before. Only the branch marker is decided here.
          marked = line if skip_prefix else "+- " + line
          super().writeline(marked)
  240. class GuardManagerWrapper:
  241. """
  242. A helper class that contains the root guard manager. An instance of this
  243. class is stored in the Dynamo cache entry, so that the cache entry can
  244. access the RootGuardManager stored in the "root" attribute and directly call
  245. the check_nopybind from C++.
  246. """
  247. def __init__(self, root: Optional[RootGuardManager] = None) -> None:
  248. if root is None:
  249. self.root = RootGuardManager()
  250. else:
  251. self.root = root
  252. self.diff_guard_root: Optional[RootGuardManager] = None
  253. self.closure_vars: Optional[dict[str, Any]] = None
  254. self.args: Optional[list[str]] = None
  255. self.code_parts: list[str] = []
  256. self.verbose_code_parts: Optional[list[str]] = None
  257. self.global_scope: Optional[dict[str, Any]] = None
  258. self.guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
  259. self.cache_entry: Optional[CacheEntry] = None
  260. self.extra_state: Optional[ExtraState] = None
  261. self.id_matched_objs: dict[str, ReferenceType[object]] = {}
  262. self.no_tensor_aliasing_sources: list[str] = []
  263. self.printed_relational_guards: set[RelationalGuard] = set()
  264. self.diff_guard_sources: OrderedSet[str] = OrderedSet()
  265. @contextmanager
  266. def _preserve_printed_relational_guards(self) -> Generator[None, None, None]:
  267. self.printed_relational_guards = set()
  268. try:
  269. yield
  270. finally:
  271. self.printed_relational_guards = set()
  272. # TODO: clarify what fn and attributes guard manager has to get the right things here
  273. def collect_diff_guard_sources(self) -> OrderedSet[str]:
  274. # At the time of finalize, we have only marked guard managers with
  275. # TENSOR_MATCH guards as diff guard managers. So, we do a tree traversal
  276. # and collect all the nodes in the tree (branches) that lead to tensor
  277. # guards.
  278. # After a recompilation, some of guard managers will have a fail_count >
  279. # 0, so we collect them as well. Later on, we accumulate the diff guard
  280. # sources for all the guard managers.
  281. def visit_dict_manager(node: DictGuardManager) -> bool:
  282. is_diff_guard_node = (
  283. node.get_source() in self.diff_guard_sources or node.fail_count() > 0
  284. )
  285. for _idx, (key_mgr, val_mgr) in sorted(
  286. node.get_key_value_managers().items()
  287. ):
  288. is_diff_guard_node |= visit(key_mgr) | visit(val_mgr)
  289. if is_diff_guard_node:
  290. self.diff_guard_sources.add(node.get_source())
  291. return is_diff_guard_node
  292. def visit_manager(node: GuardManager) -> bool:
  293. assert not isinstance(node, DictGuardManager)
  294. is_diff_guard_node = (
  295. node.get_source() in self.diff_guard_sources or node.fail_count() > 0
  296. )
  297. for child_mgr in node.get_child_managers():
  298. is_diff_guard_node |= visit(child_mgr)
  299. if is_diff_guard_node:
  300. self.diff_guard_sources.add(node.get_source())
  301. return is_diff_guard_node
  302. def visit(node: GuardManager) -> bool:
  303. if node is None:
  304. return False
  305. if isinstance(node, DictGuardManager):
  306. return visit_dict_manager(node)
  307. return visit_manager(node)
  308. visit(self.root)
  309. return self.diff_guard_sources
  def finalize(self) -> None:
      """Post-process the guard tree. Presumably called once the tree is fully built.

      Tag-safe roots are searched only when recursive dict-tag guards are
      enabled. That requires both ``config.use_recursive_dict_tags_for_guards``
      and the ``pytorch/compiler:use_recursive_dict_tags_for_guards`` justknob
      (the knob is only queried when the config flag is truthy). The diff
      guard manager is always prepared afterwards.
      """
      recursive_dict_tags_enabled = (
          config.use_recursive_dict_tags_for_guards
          and justknobs_check("pytorch/compiler:use_recursive_dict_tags_for_guards")
      )
      if recursive_dict_tags_enabled:
          self.find_tag_safe_roots()
      self.prepare_diff_guard_manager()
  def prepare_diff_guard_manager(self) -> None:
      """Set up the diff guard manager in two ordered steps.

      First record the relevant sources in ``self.diff_guard_sources``. Then
      call ``populate_diff_guard_manager`` (defined elsewhere in the class),
      which presumably builds the diff guard tree from those sources.
      """
      self.collect_diff_guard_sources()
      self.populate_diff_guard_manager()
  319. def find_tag_safe_roots(self) -> None:
  320. """
  321. Identify ``tag safe nodes`` and ``tag safe roots`` within a guard tree.
  322. -----------------------------------------------------------------------
  323. tag safe node
  324. -----------------------------------------------------------------------
  325. A *tag safe node* is a ``GuardManager`` whose guarded value satisfies one
  326. of the following conditions:
  327. 1. Immutable value - The value is intrinsically immutable according to
  328. ``is_immutable_object``. Tensors are considered immutable. To ensure
  329. that symbolic guards run, we also check that the GuardManager has no
  330. accessors.
  331. 2. Nested tag safe dictionary - The value is a ``dict`` whose keys and
  332. values are all tag safe nodes (checked recursively). Such dictionaries
  333. allow entire nested structures to be skipped once their identity tag
  334. matches.
  335. 3. Pure ``nn.Module`` - The value is an ``nn.Module`` whose sole
  336. accessor is ``GetGenericDictGuardAccessor``—i.e., it only exposes its
  337. ``__dict__`` and nothing else that could mutate between runs.
  338. For every tag safe node, verifying the identity/tag of just the top-level
  339. dictionary is enough to guarantee the entire subtree is unchanged, enabling
  340. a *fast-path* guard check.
  341. -----------------------------------------------------------------------
  342. tag safe root
  343. -----------------------------------------------------------------------
  344. A ``tag safe root`` is a tag safe node whose parent is not tag safe.
  345. These boundary nodes mark the points where guard evaluation can safely
  346. prune traversal: if a tag-safe root's dictionary tag matches, the entire
  347. subtree beneath it is skipped.
  348. One strong requirement for tag safe root is for the guarded object to
  349. support weakref. Refer to more details in the Recursive dict tag
  350. matching note. In short, we need to save the weakref of the object on
  351. first invocation, and check if it is still valid in later iterations, to
  352. apply recursive dict tag optimizations. `dict` objects do NOT support
  353. weakref. Therefore, as of now, we only mark nn module related guard
  354. managers as tag safe roots.
  355. Algorithm
  356. ---------
  357. The search runs in post-order traversal
  358. 1. Visit leaves and classify them as tag safe or not.
  359. 2. Propagate tag-safety upward: a parent dictionary becomes tag safe only if
  360. all of its children are already tag-safe.
  361. 3. Propagate tag-safe-rootness upward: if the whole subtree is tag safe,
  362. the current node becomes the new tag safe root, otherwise propagate the
  363. subtree tag safe roots.
  364. 4. Collect every tag safe node and, by inspecting parent tags, label the
  365. subset that are tag safe roots.
  366. """
  367. def check_tag_safety(
  368. node: GuardManager, accepted_accessors: tuple[type[GuardAccessor], ...]
  369. ) -> bool:
  370. accessors = node.get_accessors()
  371. child_mgrs = node.get_child_managers()
  372. return all(
  373. isinstance(accessor, accepted_accessors) and mgr.is_tag_safe()
  374. for accessor, mgr in zip(accessors, child_mgrs)
  375. )
  376. def visit_dict_manager(node: DictGuardManager) -> list[GuardManager]:
  377. # Just recurse through the key and value dict managers and check if
  378. # all of them are tag safe nodes.
  379. assert issubclass(node.get_type_of_guarded_value(), dict)
  380. tag_safe_roots = []
  381. is_subtree_tag_safe = True
  382. # Recurse to get the tag safe roots from subtree.
  383. for _idx, (key_mgr, val_mgr) in sorted(
  384. node.get_key_value_managers().items()
  385. ):
  386. if key_mgr is not None:
  387. visit(key_mgr)
  388. if val_mgr is not None:
  389. tag_safe_roots.extend(visit(val_mgr))
  390. for key_mgr, val_mgr in node.get_key_value_managers().values():
  391. if key_mgr:
  392. is_subtree_tag_safe &= key_mgr.is_tag_safe()
  393. if val_mgr:
  394. is_subtree_tag_safe &= val_mgr.is_tag_safe()
  395. if is_subtree_tag_safe:
  396. node.mark_tag_safe()
  397. return tag_safe_roots
  398. def visit_manager(node: GuardManager) -> list[GuardManager]:
  399. assert not isinstance(node, DictGuardManager)
  400. # Collect the subtree tag safe roots
  401. tag_safe_roots = []
  402. for child_mgr in node.get_child_managers():
  403. tag_safe_roots.extend(visit(child_mgr))
  404. if node.is_guarded_value_immutable():
  405. # If the node guards a tensor, mark it tag safe only if there
  406. # are no accessors. Presence of accessors means presence of
  407. # symbolic shape guards.
  408. if issubclass(node.get_type_of_guarded_value(), torch.Tensor):
  409. if node.has_no_accessors() and not node.has_object_aliasing_guard():
  410. node.mark_tag_safe()
  411. else:
  412. node.mark_tag_safe()
  413. elif issubclass(node.get_type_of_guarded_value(), dict):
  414. accessors = node.get_accessors()
  415. child_mgrs = node.get_child_managers()
  416. is_subtree_tag_safe = all(
  417. isinstance(accessor, DictGetItemGuardAccessor) and mgr.is_tag_safe()
  418. for accessor, mgr in zip(accessors, child_mgrs)
  419. )
  420. if is_subtree_tag_safe:
  421. node.mark_tag_safe()
  422. elif issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
  423. is_subtree_tag_safe = check_tag_safety(
  424. node, (GetGenericDictGuardAccessor, TypeGuardAccessor)
  425. )
  426. if is_subtree_tag_safe:
  427. node.mark_tag_safe()
  428. # Return the current node as tag safe root, discarding the
  429. # subtree tag safe roots.
  430. return [
  431. node,
  432. ]
  433. elif (
  434. node.get_type_of_guarded_value()
  435. in (
  436. types.FunctionType,
  437. types.MethodType,
  438. staticmethod,
  439. classmethod,
  440. )
  441. and config.assume_dunder_attributes_remain_unchanged
  442. ):
  443. # Assumption: callers will not reassignthe attributes
  444. # func.__code__, func.__closure__, func.__defaults__, or func.__kwdefaults__.
  445. # Mutating the objects those attributes point to is fine;
  446. # rebinding the attribute itself is not.
  447. # Example ─ allowed: foo.__defaults__[0].bar = 99
  448. # forbidden: foo.__defaults__ = (3, 4)
  449. is_subtree_tag_safe = check_tag_safety(
  450. node,
  451. (
  452. CodeGuardAccessor,
  453. ClosureGuardAccessor,
  454. FuncDefaultsGuardAccessor,
  455. FuncKwDefaultsGuardAccessor,
  456. GetAttrGuardAccessor,
  457. ),
  458. )
  459. for accessor in node.get_accessors():
  460. if isinstance(accessor, GetAttrGuardAccessor):
  461. is_subtree_tag_safe &= (
  462. accessor.get_attr_name() in dunder_attrs_assumed_constants
  463. )
  464. if is_subtree_tag_safe:
  465. node.mark_tag_safe()
  466. elif issubclass(node.get_type_of_guarded_value(), types.CellType):
  467. is_subtree_tag_safe = check_tag_safety(node, (GetAttrGuardAccessor,))
  468. is_subtree_tag_safe &= all(
  469. isinstance(accessor, GetAttrGuardAccessor)
  470. and accessor.get_attr_name() == "cell_contents"
  471. for accessor in node.get_accessors()
  472. )
  473. if is_subtree_tag_safe:
  474. node.mark_tag_safe()
  475. elif (
  476. issubclass(node.get_type_of_guarded_value(), tuple)
  477. and node.get_source().endswith(dunder_attrs_assumed_constants)
  478. and config.assume_dunder_attributes_remain_unchanged
  479. ):
  480. # We trust tuples obtained from a function's __closure__ or
  481. # __defaults__. Any *other* tuple-valued attribute can be
  482. # silently replaced—for example:
  483. #
  484. # foo.bar = (1, 2) # original
  485. # foo.bar = (3, 4) # rebinding that our dict-tag optimisation won't see
  486. #
  487. # Therefore only tuples from __closure__ / __defaults__ participate in the
  488. # recursive-dict-tag optimization; all others are ignored.
  489. is_subtree_tag_safe = check_tag_safety(
  490. node, (TupleGetItemGuardAccessor,)
  491. )
  492. if is_subtree_tag_safe:
  493. node.mark_tag_safe()
  494. elif issubclass(node.get_type_of_guarded_value(), type):
  495. is_subtree_tag_safe = check_tag_safety(
  496. node, (TypeDictGuardAccessor, TypeMROGuardAccessor)
  497. )
  498. if is_subtree_tag_safe:
  499. node.mark_tag_safe()
  500. return tag_safe_roots
  501. def visit(node: GuardManager) -> list[GuardManager]:
  502. if node is None:
  503. return []
  504. if isinstance(node, DictGuardManager):
  505. return visit_dict_manager(node)
  506. return visit_manager(node)
  507. tag_safe_roots = visit(self.root)
  508. for node in tag_safe_roots:
  509. if issubclass(node.get_type_of_guarded_value(), torch.nn.Module):
  510. node.mark_tag_safe_root()
  511. def populate_diff_guard_manager(self) -> None:
  512. self.diff_guard_root = self.clone_with_chosen_sources(self.diff_guard_sources)
  513. # Ensure that that C++ side points to the updated diff guard manager.
  514. # When a new GuardManagerWrapper is created, it does not have a
  515. # cache_entry attribute, so it relies on the CacheEntry constructor to
  516. # set the diff_guard_root in C++. But once it is saved in the Dynamo
  517. # cache, C++ side adds a cache_entry attribute. On recompiles, this
  518. # cache_entry is visible, so we update the C++ side to point to the
  519. # update guard manager.
  520. if self.cache_entry:
  521. self.cache_entry.update_diff_guard_root_manager()
  522. def clone_with_chosen_sources(
  523. self, chosen_sources: OrderedSet[str]
  524. ) -> RootGuardManager:
  525. def filter_fn(node_mgr: GuardManager) -> bool:
  526. return node_mgr.get_source() in chosen_sources
  527. return self.root.clone_manager(filter_fn)
  528. def get_guard_lines(self, guard: LeafGuard) -> list[str]:
  529. guard_name = guard.__class__.__name__
  530. parts = guard.verbose_code_parts()
  531. parts = [guard_name + ": " + part for part in parts]
  532. return parts
  533. def get_manager_line(
  534. self, guard_manager: GuardManager, accessor_str: Optional[str] = None
  535. ) -> str:
  536. source = guard_manager.get_source()
  537. t = guard_manager.__class__.__name__
  538. s = t + ": source=" + source
  539. if accessor_str:
  540. s += ", " + accessor_str
  541. s += f", type={guard_manager.get_type_of_guarded_value()}"
  542. s += f", tag_safe=({guard_manager.is_tag_safe()}, {guard_manager.is_tag_safe_root()})"
  543. return s
  544. def construct_dict_manager_string(
  545. self, mgr: DictGuardManager, body: IndentedBufferWithPrefix
  546. ) -> None:
  547. for idx, (key_mgr, val_mgr) in sorted(mgr.get_key_value_managers().items()):
  548. body.writeline(f"KeyValueManager pair at index={idx}")
  549. with body.indent():
  550. if key_mgr:
  551. body.writeline(f"KeyManager: {self.get_manager_line(key_mgr)}")
  552. self.construct_manager_string(key_mgr, body)
  553. if val_mgr:
  554. body.writeline(f"ValueManager: {self.get_manager_line(val_mgr)}")
  555. self.construct_manager_string(val_mgr, body)
  556. def construct_manager_string(
  557. self, mgr: GuardManager, body: IndentedBufferWithPrefix
  558. ) -> None:
  559. with body.indent():
  560. for guard in mgr.get_leaf_guards():
  561. if isinstance(guard, RelationalGuard):
  562. if guard not in self.printed_relational_guards:
  563. self.printed_relational_guards.add(guard)
  564. body.writelines(self.get_guard_lines(guard))
  565. else:
  566. body.writelines(
  567. [
  568. guard.__class__.__name__,
  569. ]
  570. )
  571. else:
  572. body.writelines(self.get_guard_lines(guard))
  573. # This works for both DictGuardManager and SubclassedDictGuardManager
  574. if isinstance(mgr, DictGuardManager):
  575. self.construct_dict_manager_string(mgr, body)
  576. # General case of GuardManager/RootGuardManager
  577. for accessor, child_mgr in zip(
  578. mgr.get_accessors(), mgr.get_child_managers()
  579. ):
  580. body.writeline(
  581. self.get_manager_line(child_mgr, f"accessed_by={accessor.repr()}")
  582. )
  583. self.construct_manager_string(child_mgr, body)
  584. def __str__(self) -> str:
  585. with self._preserve_printed_relational_guards():
  586. body = IndentedBufferWithPrefix()
  587. body.tabwidth = 1
  588. body.writeline("", skip_prefix=True)
  589. body.writeline("TREE_GUARD_MANAGER:", skip_prefix=True)
  590. body.writeline("RootGuardManager")
  591. self.construct_manager_string(self.root, body)
  592. if hasattr(self.root, "get_epilogue_lambda_guards"):
  593. for guard in self.root.get_epilogue_lambda_guards():
  594. body.writelines(self.get_guard_lines(guard))
  595. return body.getvalue()
  596. def check(self, x: Any) -> bool:
  597. # Only needed for debugging purposes.
  598. return self.root.check(x)
  599. def check_verbose(self, x: Any) -> GuardDebugInfo:
  600. # Only needed for debugging purposes.
  601. return self.root.check_verbose(x)
  602. def populate_code_parts_for_debugging(self) -> None:
  603. # This should be called when the guard manager is fully populated
  604. relational_guards_seen = set()
  605. def get_code_parts(leaf_guard: LeafGuard) -> list[str]:
  606. code_parts = []
  607. for verbose_code_part in leaf_guard.verbose_code_parts():
  608. code_part = verbose_code_part.split("#")[0].rstrip()
  609. code_parts.append(code_part)
  610. return code_parts
  611. def visit(mgr: GuardManager) -> None:
  612. nonlocal relational_guards_seen
  613. for guard in mgr.get_leaf_guards():
  614. if isinstance(guard, RelationalGuard):
  615. if guard not in relational_guards_seen:
  616. self.code_parts.extend(get_code_parts(guard))
  617. relational_guards_seen.add(guard)
  618. else:
  619. self.code_parts.extend(get_code_parts(guard))
  620. for child_mgr in mgr.get_child_managers():
  621. visit(child_mgr)
  622. visit(self.root)
  623. def from_numpy(a: Any) -> torch.Tensor:
  624. # If not numpy array, piggy back on e.g. tensor guards to check type
  625. # Re-enable torch function since we disable it on leaf guards
  626. # we need it to properly construct the tensor if a default device is set
  627. with torch.overrides._enable_torch_function():
  628. # pyrefly: ignore [missing-attribute]
  629. return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a
  630. # For user stack printing
  631. @functools.cache
  632. def uninteresting_files() -> set[str]:
  633. import torch._dynamo.external_utils
  634. import torch._dynamo.polyfills
  635. mods = [torch._dynamo.external_utils, torch._dynamo.polyfills]
  636. from torch._dynamo.polyfills.loader import POLYFILLED_MODULES
  637. # pyrefly: ignore [bad-argument-type]
  638. mods.extend(POLYFILLED_MODULES)
  639. return {inspect.getfile(m) for m in mods}
  640. _CLOSURE_VARS: Optional[dict[str, object]] = None
  641. def _get_closure_vars() -> dict[str, object]:
  642. global _CLOSURE_VARS
  643. if _CLOSURE_VARS is None:
  644. _CLOSURE_VARS = {
  645. "___check_type_id": check_type_id,
  646. "___check_obj_id": check_obj_id,
  647. "___odict_getitem": collections.OrderedDict.__getitem__,
  648. "___key_to_id": key_to_id,
  649. "___dict_version": dict_version,
  650. "___dict_contains": lambda a, b: dict.__contains__(b, a),
  651. "___tuple_iterator_len": tuple_iterator_len,
  652. "___normalize_range_iter": normalize_range_iter,
  653. "___tuple_iterator_getitem": tuple_iterator_getitem,
  654. "___dataclass_fields": dataclass_fields,
  655. "___namedtuple_fields": lambda x: x._fields,
  656. "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
  657. "___get_current_stream": get_current_stream,
  658. "__math_isnan": math.isnan,
  659. "__numpy_isnan": None if np is None else np.isnan,
  660. "inf": float("inf"),
  661. "__load_module": importlib.import_module,
  662. "utils_device": torch.utils._device,
  663. "device": torch.device,
  664. "___from_numpy": from_numpy,
  665. "___as_tensor": torch._as_tensor_fullprec,
  666. "torch": torch,
  667. "inspect": inspect,
  668. }
  669. return _CLOSURE_VARS
  670. def _ast_unparse(node: ast.AST) -> str:
  671. return ast.unparse(node).replace("\n", "")
# Module-level alias for the helper exposed by the native torch._C._dynamo module.
strip_function_call = torch._C._dynamo.strip_function_call
  673. def get_verbose_code_part(code_part: str, guard: Optional[Guard]) -> str:
  674. extra = ""
  675. if guard is not None:
  676. if guard.user_stack:
  677. for fs in reversed(guard.user_stack):
  678. if fs.filename not in uninteresting_files():
  679. extra = f" # {format_frame(fs, line=True)}"
  680. if len(extra) > 1024:
  681. # For fx graphs, the line can be very long in case of
  682. # torch.stack ops, where many inputs are set to None
  683. # after the operation. This increases the size of the
  684. # guards log file. In such cases, do not print the line
  685. # contents.
  686. extra = f" # {format_frame(fs)}"
  687. break
  688. elif guard.stack:
  689. summary = guard.stack.summary()
  690. if len(summary) > 0:
  691. extra = f" # {format_frame(summary[-1])}"
  692. else:
  693. extra = " # <unknown>"
  694. return f"{code_part:<60}{extra}"
  695. def get_verbose_code_parts(
  696. code_parts: Union[str, list[str]],
  697. guard: Optional[Guard],
  698. recompile_hint: Optional[str] = None,
  699. ) -> list[str]:
  700. if not isinstance(code_parts, list):
  701. code_parts = [code_parts]
  702. verbose_code_parts = [
  703. get_verbose_code_part(code_part, guard) for code_part in code_parts
  704. ]
  705. # For CellContentsSource (or any source with a CellContentsSource ancestor),
  706. # add a hint explaining which closure variable is being checked.
  707. # This helps users understand which closure variable caused the guard failure.
  708. if guard is not None:
  709. closure_hint = _get_closure_var_hint(guard.originating_source)
  710. if closure_hint:
  711. recompile_hint = (
  712. f"{closure_hint}, {recompile_hint}" if recompile_hint else closure_hint
  713. )
  714. if recompile_hint:
  715. verbose_code_parts = [
  716. f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
  717. ]
  718. return verbose_code_parts
  719. def _get_closure_var_hint(source: Optional[Source]) -> Optional[str]:
  720. """
  721. Walk up the source chain to find a CellContentsSource ancestor.
  722. Returns a hint like 'guard on "varname".attr' or None if not found.
  723. """
  724. if source is None:
  725. return None
  726. full_name = source.name
  727. current: Optional[Source] = source
  728. while current is not None:
  729. if isinstance(current, CellContentsSource) and current.freevar_name:
  730. # Compute the path suffix by comparing names
  731. # e.g., full_name="x.__closure__[0].cell_contents.scale"
  732. # current.name="x.__closure__[0].cell_contents"
  733. # suffix=".scale"
  734. path_suffix = full_name[len(current.name) :]
  735. return f'guard on "{current.freevar_name}"{path_suffix}'
  736. current = current.base if isinstance(current, ChainedSource) else None
  737. return None
  738. def convert_int_to_concrete_values(dim: Any) -> Optional[int]:
  739. if dim is None:
  740. return None
  741. if not is_symbolic(dim):
  742. return dim
  743. else:
  744. assert isinstance(dim, torch.SymInt)
  745. return dim.node.maybe_as_int()
  746. def convert_to_concrete_values(size_or_stride: list[Any]) -> list[Optional[int]]:
  747. return [convert_int_to_concrete_values(dim) for dim in size_or_stride]
  748. def get_tensor_guard_code_part(
  749. value: torch.Tensor,
  750. name: str,
  751. sizes: list[Optional[int]],
  752. strides: list[Optional[int]],
  753. pytype: type,
  754. dispatch_keys: DispatchKeySet,
  755. ) -> str:
  756. dispatch_key = (
  757. dispatch_keys | torch._C._dispatch_tls_local_include_set()
  758. ) - torch._C._dispatch_tls_local_exclude_set()
  759. dtype = value.dtype
  760. device_index = value.device.index
  761. requires_grad = value.requires_grad
  762. guard_str = (
  763. f"check_tensor({name}, {pytype.__qualname__}, {dispatch_key}, {dtype}, "
  764. f"device={device_index}, requires_grad={requires_grad}, size={sizes}, stride={strides})"
  765. )
  766. return guard_str
  767. def get_key_index(dct: dict[Any, Any], key: Any) -> int:
  768. # Ensure that we call dict.keys and not value.keys (which can call
  769. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  770. # to traverse the dictionary, which uses the internal data structure and
  771. # does not call the overridden keys method.
  772. return list(builtin_dict_keys(dct)).index(key)
  773. def get_key_index_source(source: Any, index: Any) -> str:
  774. return f"list(dict.keys({source}))[{index}]"
  775. def raise_local_type_error(obj: Any) -> NoReturn:
  776. raise TypeError(
  777. f"Type {type(obj)} for object {obj} cannot be saved "
  778. + "into torch.compile() package since it's defined in local scope. "
  779. + "Please define the class at global scope (top level of a module)."
  780. )
  781. def should_optimize_getattr_on_nn_module(value: Any) -> bool:
  782. # If inline_inbuilt_nn_modules flag is True, Dynamo has already traced
  783. # through the __getattr__, and therefore it is always safe to optimize
  784. # getattr on nn modules.
  785. return isinstance(value, torch.nn.Module) and (
  786. config.inline_inbuilt_nn_modules
  787. or get_custom_getattr(value) is unpatched_nn_module_getattr
  788. )
  789. @dataclasses.dataclass(frozen=True)
  790. class NNModuleAttrAccessorInfo:
  791. # Represents where is the attr name is present in the nn module attribute
  792. # access
  793. # Tells that the attribute can be accessed via __dict__
  794. present_in_generic_dict: bool = False
  795. # Either the actual name or _parameters/_buffers/_modules
  796. l1_key: Optional[str] = None
  797. # Actual parameter/buffer/submodule name
  798. l2_key: Optional[str] = None
  799. def getitem_on_dict_manager(
  800. source: Union[DictGetItemSource, DictSubclassGetItemSource],
  801. base_guard_manager: DictGuardManager,
  802. base_example_value: Any,
  803. example_value: Any,
  804. guard_manager_enum: GuardManagerType,
  805. ) -> GuardManager:
  806. base_source_name = source.base.name
  807. if isinstance(source.index, ConstDictKeySource):
  808. index = source.index.index
  809. else:
  810. assert isinstance(base_example_value, dict)
  811. index = get_key_index(base_example_value, source.index)
  812. key_source = get_key_index_source(base_source_name, index)
  813. # Ensure that we call dict.keys and not value.keys (which can call
  814. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  815. # to traverse the dictionary, which uses the internal data structure and
  816. # does not call the overridden keys method.
  817. key_example_value = list(builtin_dict_keys(base_example_value))[index]
  818. if isinstance(key_example_value, (int, str)):
  819. value_source = f"{base_source_name}[{key_example_value!r}]"
  820. else:
  821. value_source = f"{base_source_name}[{key_source}]"
  822. if not isinstance(source.index, ConstDictKeySource):
  823. # We have to insert a key manager guard here
  824. # TODO - source debug string is probably wrong here.
  825. base_guard_manager.get_key_manager(
  826. index=index,
  827. source=key_source,
  828. example_value=source.index,
  829. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  830. ).add_equals_match_guard(
  831. source.index, [f"{key_source} == {key_example_value!r}"], None
  832. )
  833. return base_guard_manager.get_value_manager(
  834. index=index,
  835. source=value_source,
  836. example_value=example_value,
  837. guard_manager_enum=guard_manager_enum,
  838. )
  839. def match_on_id_for_tensor(guard: Guard) -> bool:
  840. source = guard.originating_source
  841. # For numpy tensors, always use TENSOR_MATCH because __from_numpy leads
  842. # to a new tensor every time and therefore id differs.
  843. if isinstance(source, NumpyTensorSource):
  844. return False
  845. if guard.is_specialized_nn_module():
  846. return True
  847. return source.is_dict_key() and not isinstance(source, GradSource)
  848. # The ready to eval generated code (possibly multiple parts) for a guard, plus
  849. # the original guard object that created it for provenance
  850. @dataclasses.dataclass
  851. class GuardCodeList:
  852. code_list: list[str]
  853. guard: Guard
  854. class GuardManagerType(enum.Enum):
  855. GUARD_MANAGER = 1
  856. DICT_GUARD_MANAGER = 2
  857. @functools.cache
  858. def code_framelocals_names_reversed_cached(code: types.CodeType) -> list[str]:
  859. return list(reversed(code_framelocals_names(code)))
  860. class GuardBuilder(GuardBuilderBase):
    def __init__(
        self,
        f_code: types.CodeType,
        id_ref: Callable[[object, str], int],
        source_ref: Callable[[Source], str],
        lookup_weakrefs: Callable[[object], Optional[weakref.ref[object]]],
        local_scope: dict[str, object],
        global_scope: dict[str, object],
        guard_manager: GuardManagerWrapper,
        check_fn_manager: CheckFunctionManager,
        save_guards: bool = False,
        runtime_global_scope: Optional[dict[str, object]] = None,
        guard_filter_fn: Callable[[Sequence[GuardFilterEntry]], Sequence[bool]]
        | None = None,
    ) -> None:
        """Initialize the state used while building guards for one frame.

        Args:
            f_code: Code object of the frame being guarded.
            id_ref: Callable ``(obj, source_str) -> int``; used to obtain the id
                value installed by ID_MATCH guards.
            source_ref: Callable mapping a ``Source`` to a string.
            lookup_weakrefs: Callable returning a weakref for an object, or None.
            local_scope: Exposed as ``self.scope["L"]``.
            global_scope: Exposed as ``self.scope["G"]``; also the fallback for
                ``runtime_global_scope``.
            guard_manager: ``GuardManagerWrapper`` stored as ``self.guard_manager``.
            check_fn_manager: Owning ``CheckFunctionManager``; its ``output_graph``
                must be set (asserted below).
            save_guards: Stored as ``self.save_guards``.
            runtime_global_scope: Global scope to use at runtime; falls back to
                ``global_scope`` when falsy.
            guard_filter_fn: Optional filter over ``GuardFilterEntry`` sequences,
                stored as ``self.guard_filter_fn``.
        """
        self.f_code = f_code
        self.id_ref = id_ref
        self.source_ref = source_ref
        self.lookup_weakrefs = lookup_weakrefs
        self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
        # Cache of values resolved per Source, weakly keyed so Sources can be
        # collected (presumably filled by the value-lookup helpers).
        self.src_get_value_cache: weakref.WeakKeyDictionary[Source, object] = (
            weakref.WeakKeyDictionary()
        )
        self.runtime_global_scope = runtime_global_scope or global_scope
        # Copy, so the package-module entries added below do not leak into the
        # real builtins dict.
        self.scope["__builtins__"] = builtins.__dict__.copy()
        for (
            name,
            package_module,
        ) in torch.package.package_importer._package_imported_modules.items():
            # Turn the name into a valid identifier: '<' and '>' become '_',
            # '.' becomes '_dot_'.
            name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            # Write the package module into the scope so that we can import it
            self.scope["__builtins__"][name] = package_module
            # Also expose it directly in the scope under the same identifier.
            self.scope[name] = package_module
        self.guard_manager = guard_manager

        self.argnames: list[str] = []
        # Code is python expression strings generated for each guard
        self.code: list[GuardCodeList] = []
        # shape_env_code is only used by builder and is used for
        # shape env code. This exists only because we need to make sure
        # shape env guards get run after tensor match guards (since the
        # tensor match guards make sure we actually have tensors)
        self.shape_env_code: list[GuardCodeList] = []

        # Collect the guard managers and debug info to insert no tensor aliasing
        # guards.
        self.no_tensor_aliasing_names: list[str] = []
        self.no_tensor_aliasing_guard_managers: list[GuardManager] = []

        self.check_fn_manager: CheckFunctionManager = check_fn_manager

        self.guard_tree_values: dict[int, Any] = {}
        self.save_guards = save_guards
        self.guard_filter_fn = guard_filter_fn

        # Collect the ids of dicts which need key order guarding. source_name is
        # not sufficient because for nn modules, we can have different sources
        # to access the same object - self._module["param"] is same as
        # self.param.
        self.key_order_guarded_dict_ids: set[int] = set()
        assert self.check_fn_manager.output_graph is not None
        # self.get resolves a source to its current value; it must run after the
        # scope set up above (presumably it evaluates against self.scope).
        for source in self.check_fn_manager.output_graph.guard_on_key_order:
            dict_obj = self.get(source)
            self.key_order_guarded_dict_ids.add(id(dict_obj))

        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and guard_manager and is used to
        # limit the number of cache entries with same ID_MATCH'd object.
        self.id_matched_objs: dict[str, ReferenceType[object]] = {}

        # Save the guard managers to avoid repeatedly traversing sources.
        self._cached_guard_managers: dict[str, GuardManager] = {}
        # Presumably pairs of source names already given a duplicate-input guard.
        self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
        # Pairs of code strings for object-aliasing guards (semantics of each
        # element set by the code that appends to it — confirm there).
        self.object_aliasing_guard_codes: list[tuple[str, str]] = []
        # Guard nn modules only if both the config flag and the justknob agree.
        self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )
        # Code parts already emitted (ordered, without duplicates).
        self.already_added_code_parts: OrderedSet[str] = OrderedSet()
  933. def guard_on_dict_keys_and_ignore_order(
  934. self, example_value: dict[Any, Any], guard: Guard
  935. ) -> None:
  936. dict_mgr = self.get_guard_manager(guard)
  937. if isinstance(dict_mgr, DictGuardManager):
  938. raise NotImplementedError(
  939. "Not expecting a DictGuardManager. Seems like Dynamo incorrectly "
  940. f"added the dict to tx.output.guard_on_key_order for {guard.name}"
  941. )
  942. # Iterate over the dicts and install a dict_getitem_manager.
  943. dict_source = guard.originating_source.name
  944. # Ensure that we call dict.keys and not value.keys (which can call
  945. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  946. # to traverse the dictionary, which uses the internal data structure and
  947. # does not call the overridden keys method.
  948. for key in builtin_dict_keys(example_value):
  949. value = example_value[key]
  950. value_source = DictGetItemSource(guard.originating_source, index=key)
  951. guard_manager_enum = self.get_guard_manager_type(
  952. value_source, example_value
  953. )
  954. dict_mgr.dict_getitem_manager(
  955. key=key,
  956. source=f"{dict_source}[{key!r}]",
  957. example_value=value,
  958. guard_manager_enum=guard_manager_enum,
  959. )
  960. def guard_on_dict_keys_and_order(self, value: dict[Any, Any], guard: Guard) -> None:
  961. # Add key managers for the DictGuardManager. Then add either an
  962. # ID_MATCH or EQUALS_MATCH guard on the key.
  963. dict_mgr = self.get_guard_manager(guard)
  964. if not isinstance(dict_mgr, DictGuardManager):
  965. raise NotImplementedError(
  966. "Expecting a DictGuardManager. Seems like Dynamo forgot "
  967. f"to set the right guard manager enum for {guard.name}"
  968. )
  969. assert isinstance(dict_mgr, DictGuardManager)
  970. # Ensure that we call dict.keys and not value.keys (which can call
  971. # overridden keys method). In the C++ guards, we relied on PyDict_Next
  972. # to traverse the dictionary, which uses the internal data structure and
  973. # does not call the overridden keys method.
  974. for idx, key in enumerate(builtin_dict_keys(value)):
  975. key_source = get_key_index_source(guard.name, idx)
  976. key_manager = dict_mgr.get_key_manager(
  977. index=idx,
  978. source=key_source,
  979. example_value=key,
  980. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  981. )
  982. if key_is_id(key):
  983. # Install ID_MATCH guard
  984. id_val = self.id_ref(key, key_source)
  985. key_manager.add_id_match_guard(
  986. id_val,
  987. get_verbose_code_parts(
  988. f"__check_obj_id({key_source}, {id_val})", guard
  989. ),
  990. guard.user_stack,
  991. )
  992. else:
  993. # Install EQUALS_MATCH guard
  994. key_manager.add_equals_match_guard(
  995. key,
  996. get_verbose_code_parts(f"{key_source} == {key!r}", guard),
  997. guard.user_stack,
  998. )
  999. @staticmethod
  1000. def _get_generic_dict_manager_example_value(example_value: Any) -> Optional[Any]:
  1001. # due to a bug in 3.13.0 (introduced by https://github.com/python/cpython/pull/116115,
  1002. # reported in https://github.com/python/cpython/issues/125608,
  1003. # fixed by https://github.com/python/cpython/pull/125611), we cannot take
  1004. # advantage of __dict__ versions to speed up guard checks.
  1005. if (
  1006. config.issue_3_13_0_warning
  1007. and sys.version_info >= (3, 13)
  1008. and sys.version_info < (3, 13, 1)
  1009. ):
  1010. warnings.warn(
  1011. "Guards may run slower on Python 3.13.0. Consider upgrading to Python 3.13.1+.",
  1012. RuntimeWarning,
  1013. )
  1014. return None
  1015. return example_value
  1016. def getattr_on_nn_module(
  1017. self,
  1018. source: AttrSource,
  1019. base_guard_manager: GuardManager,
  1020. base_example_value: Any,
  1021. example_value: Any,
  1022. base_source_name: str,
  1023. source_name: str,
  1024. guard_manager_enum: GuardManagerType,
  1025. ) -> GuardManager:
  1026. """
  1027. This tries to avoid calling the expensive nn module custom getattr method by
  1028. checking if the attribute is accessible via __dict__. For attributes that
  1029. are not accessible via __dict__ (like descriptors), we fallback to
  1030. PyObject_GetAttr.
  1031. There are two cases that we optimize for
  1032. 1) attributes present directly in __dict__, e.g training.
  1033. 2) parameters/buffers/modules - they can be accessed via _parameters,
  1034. _buffers, _modules keys in __dict__. For example, mod.linear can be
  1035. accessed as mod.__dict__["_parameters"]["linear"]
  1036. The most common and expensive case for nn module guards is of type
  1037. mod.submod1.submod2.submod3.training. We avoid the python getattr of nn
  1038. modules by going through the __dict__.
  1039. """
  1040. def getitem_on_dict_mgr(
  1041. mgr: GuardManager,
  1042. key: Any,
  1043. source_name: str,
  1044. base_example_value: Any,
  1045. example_value: Any,
  1046. guard_manager_enum: GuardManagerType,
  1047. ) -> GuardManager:
  1048. if isinstance(mgr, DictGuardManager):
  1049. # Case where the user code relies on key order, e.g.,
  1050. # named_parameters
  1051. index = get_key_index(base_example_value, key)
  1052. # Install the key manager and add equals match guard
  1053. key_source = f"list(dict.keys({source_name}))[{index!r}]"
  1054. mgr.get_key_manager(
  1055. index=index,
  1056. source=key_source,
  1057. example_value=key,
  1058. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1059. ).add_equals_match_guard(key, [f"{key_source} == {key!r}"], None)
  1060. # Install the value manager
  1061. return mgr.get_value_manager(
  1062. index=index,
  1063. source=source_name,
  1064. example_value=example_value,
  1065. guard_manager_enum=guard_manager_enum,
  1066. )
  1067. else:
  1068. return mgr.dict_getitem_manager(
  1069. key=key,
  1070. source=source_name,
  1071. example_value=example_value,
  1072. guard_manager_enum=guard_manager_enum,
  1073. )
  1074. attr_name = source.member
  1075. mod_dict = base_example_value.__dict__
  1076. all_class_attribute_names: set[str] = set()
  1077. for x in inspect.getmro(base_example_value.__class__):
  1078. all_class_attribute_names.update(x.__dict__.keys())
  1079. accessor_info = NNModuleAttrAccessorInfo(False, None, None)
  1080. if attr_name in mod_dict:
  1081. accessor_info = NNModuleAttrAccessorInfo(True, attr_name, None)
  1082. elif "_parameters" in mod_dict and attr_name in mod_dict["_parameters"]:
  1083. accessor_info = NNModuleAttrAccessorInfo(True, "_parameters", attr_name)
  1084. elif "_buffers" in mod_dict and attr_name in mod_dict["_buffers"]:
  1085. accessor_info = NNModuleAttrAccessorInfo(True, "_buffers", attr_name)
  1086. elif (
  1087. attr_name not in all_class_attribute_names
  1088. and "_modules" in mod_dict
  1089. and attr_name in mod_dict["_modules"]
  1090. ):
  1091. # Check test_attr_precedence test - instance attributes always take precedence unless its an nn.Module.
  1092. accessor_info = NNModuleAttrAccessorInfo(True, "_modules", attr_name)
  1093. if not accessor_info.present_in_generic_dict:
  1094. # The attribute can be accessed by __getattribute__ call, so rely on
  1095. # PyObject_GetAttr
  1096. return base_guard_manager.getattr_manager(
  1097. attr=source.member,
  1098. source=source_name,
  1099. example_value=example_value,
  1100. guard_manager_enum=guard_manager_enum,
  1101. )
  1102. else:
  1103. assert accessor_info.l1_key
  1104. l1_key = accessor_info.l1_key
  1105. l2_key = accessor_info.l2_key
  1106. # Set source strings for debug info
  1107. mod_dict_source = f"{base_source_name}.__dict__"
  1108. l1_source_name = l2_source_name = None
  1109. l1_value = l2_value = None
  1110. l1_guard_manager_enum = l2_guard_manager_enum = None
  1111. if l2_key:
  1112. l1_source = AttrSource(source.base, l1_key)
  1113. l1_source_name = l1_source.name
  1114. l1_value = mod_dict[l1_key]
  1115. # do not guard on key order for _parameters etc unless the user code
  1116. # actually needs the key order (e.g. calling named_parameters)
  1117. l1_guard_manager_enum = self.get_guard_manager_type(l1_source, l1_value)
  1118. l2_source_name = source_name
  1119. l2_value = example_value
  1120. l2_guard_manager_enum = self.get_guard_manager_type(
  1121. source, example_value
  1122. )
  1123. else:
  1124. l1_source_name = source_name
  1125. l1_value = example_value
  1126. l1_guard_manager_enum = self.get_guard_manager_type(
  1127. source, example_value
  1128. )
  1129. # Get __dict__ accessor. No need to guard on dict key order, so use base
  1130. # Guard Manager
  1131. mod_generic_dict_manager = base_guard_manager.get_generic_dict_manager(
  1132. source=mod_dict_source,
  1133. example_value=self._get_generic_dict_manager_example_value(mod_dict),
  1134. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1135. )
  1136. l1_mgr = getitem_on_dict_mgr(
  1137. mgr=mod_generic_dict_manager,
  1138. key=l1_key,
  1139. source_name=l1_source_name,
  1140. base_example_value=mod_dict,
  1141. example_value=l1_value,
  1142. guard_manager_enum=l1_guard_manager_enum,
  1143. )
  1144. if l2_key:
  1145. assert l2_source_name is not None and l2_guard_manager_enum is not None
  1146. return getitem_on_dict_mgr(
  1147. mgr=l1_mgr,
  1148. key=l2_key,
  1149. source_name=l2_source_name,
  1150. base_example_value=l1_value,
  1151. example_value=l2_value,
  1152. guard_manager_enum=l2_guard_manager_enum,
  1153. )
  1154. return l1_mgr
  1155. def requires_key_order_guarding(self, source: Source) -> bool:
  1156. source_name = source.name
  1157. if source_name == "":
  1158. return False
  1159. obj_id = id(self.get(source))
  1160. return obj_id in self.key_order_guarded_dict_ids
  1161. def get_guard_manager_type(
  1162. self,
  1163. source: Source,
  1164. example_value: Optional[
  1165. Union[KeysView[Any], set[Any], frozenset[Any], dict[Any, Any]]
  1166. ],
  1167. ) -> GuardManagerType:
  1168. guard_manager_enum = GuardManagerType.GUARD_MANAGER
  1169. if self.requires_key_order_guarding(source):
  1170. # Fix this if condition
  1171. if isinstance(example_value, dict_keys):
  1172. guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
  1173. elif isinstance(example_value, (set, frozenset)):
  1174. # we don't need to guard on key order for set/frozenset
  1175. # but the if above will be true for these types as set is
  1176. # implemented using a dict in Dynamo
  1177. guard_manager_enum = GuardManagerType.GUARD_MANAGER
  1178. else:
  1179. assert isinstance(example_value, dict)
  1180. guard_manager_enum = GuardManagerType.DICT_GUARD_MANAGER
  1181. return guard_manager_enum
  1182. def manager_guards_on_keys(self, mgr_enum: GuardManagerType) -> bool:
  1183. return mgr_enum == GuardManagerType.DICT_GUARD_MANAGER
  1184. def get_global_guard_manager(self) -> GuardManager:
  1185. return self.guard_manager.root.globals_dict_manager(
  1186. f_globals=self.runtime_global_scope,
  1187. source="G",
  1188. example_value=self.scope["G"],
  1189. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1190. )
    def get_guard_manager_from_source(self, source: Source) -> GuardManager:
        """
        Return the guard manager that accesses the value described by ``source``,
        building it on first request.

        Managers form a tree that mirrors the source chain: for a ChainedSource
        the base manager is obtained first (recursively), and the accessor that
        matches this source's type is installed on it. Local/global/import-like
        sources hang directly off the root (or the globals) manager.

        The built manager is memoized in ``self._cached_guard_managers`` under the
        source name. GlobalStateSource and ShapeEnvSource return the root manager
        directly and are not memoized.

        Raises:
            RuntimeError: a dict getitem whose index is a ConstDictKeySource
                reaches a base manager that is not a DictGuardManager.
            AssertionError: the source type has no manager builder (and the
                various internal consistency asserts).
        """
        root_guard_manager = self.guard_manager.root

        example_value = None
        source_name = source.name

        # Reuse a previously built manager. Sources with an empty name are never
        # looked up in the cache.
        if source_name != "" and source_name in self._cached_guard_managers:
            return self._cached_guard_managers[source_name]

        if source_name != "":
            example_value = self.get(source)
            # Hold a reference to every value the guard tree is built against,
            # keyed by id. NOTE(review): presumably keeps the objects alive so the
            # ids stay unique while guards are built - confirm.
            self.guard_tree_values[id(example_value)] = example_value

        guard_manager_enum = self.get_guard_manager_type(source, example_value)

        # Get base manager related information
        base_source_name = None
        base_example_value = None
        base_guard_manager = None
        base_guard_manager_enum = GuardManagerType.GUARD_MANAGER
        if isinstance(source, ChainedSource):
            # Build (or fetch from cache) the parent manager first; the accessor
            # for this source is installed on it below.
            base_source_name = source.base.name
            base_example_value = self.get(source.base)
            base_guard_manager = self.get_guard_manager_from_source(source.base)
            base_guard_manager_enum = self.get_guard_manager_type(
                source.base, base_example_value
            )

        # Use istype instead of isinstance to check for exact type of source.
        if istype(source, LocalSource):
            # Frame locals are addressed by the local's name together with the
            # index returned by get_framelocals_idx for this code object.
            framelocals_idx = get_framelocals_idx(self.f_code, source.local_name)
            out = root_guard_manager.framelocals_manager(
                key=(source.local_name, framelocals_idx),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GlobalSource):
            # Global manager accepts a dict but it is not a DictGuardManager
            # because globals dict is big and we typically guard on a very
            # selected items on globals.
            out = self.get_global_guard_manager().dict_getitem_manager(
                key=source.global_name,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GlobalWeakRefSource):
            out = self.get_global_guard_manager().global_weakref_manager(
                global_name=source.global_name,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GlobalStateSource):
            # Don't do anything here. We guard on global state completely in
            # C++. So just return the root mgr.
            return root_guard_manager
        elif istype(source, ShapeEnvSource):
            # Like GlobalStateSource: no accessor, and not memoized.
            return root_guard_manager
        elif istype(source, TypeSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.type_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, TypeDictSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.type_dict_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, TypeMROSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.type_mro_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(
            source,
            (
                OptimizerSource,
                NNModuleSource,
                UnspecializedNNModuleSource,
                UnspecializedBuiltinNNModuleSource,
                FSDPNNModuleSource,
            ),
        ):
            # No accessor is installed for these wrapper sources: the base
            # manager is reused as-is.
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager
        elif istype(source, ImportSource):
            # The module is imported once here and captured by the lambda (bound
            # as a default argument), so the guard returns that module object.
            module = importlib.import_module(source.module_name)
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _, m=module: m,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, TorchFunctionModeStackSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_torch_function_mode_stack_at(
                    source._get_index()
                ),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, CurrentStreamSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_current_stream(source.device),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GenericAttrSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.generic_getattr_manager(
                attr=source.member,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(
            source, (AttrSource, CellContentsSource, UnspecializedParamBufferSource)
        ):
            assert base_guard_manager  # to make mypy happy
            assert isinstance(source, AttrSource)
            # Attribute reads on nn.Module bases may be routed through the
            # module's __dict__ instead of a full getattr.
            if should_optimize_getattr_on_nn_module(base_example_value):
                assert base_source_name
                out = self.getattr_on_nn_module(
                    source,
                    base_guard_manager,
                    base_example_value,
                    example_value,
                    base_source_name,
                    source_name,
                    guard_manager_enum,
                )
            else:
                out = base_guard_manager.getattr_manager(
                    attr=source.member,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, (DictGetItemSource, DictSubclassGetItemSource)):
            assert base_guard_manager  # to make mypy happy
            assert isinstance(base_example_value, (dict, collections.OrderedDict))
            assert isinstance(source, (DictGetItemSource, DictSubclassGetItemSource))
            if isinstance(base_guard_manager, DictGuardManager):
                # Key-order-aware base: delegate to the module-level helper.
                assert self.manager_guards_on_keys(base_guard_manager_enum)
                out = getitem_on_dict_manager(
                    source,
                    base_guard_manager,
                    base_example_value,
                    example_value,
                    guard_manager_enum,
                )
            else:
                # A ConstDictKeySource index only makes sense on a
                # DictGuardManager base; reaching here means the base dict was
                # not marked for key-order guarding.
                if isinstance(source.index, ConstDictKeySource):
                    raise RuntimeError(
                        "Expecting clean index here. Likely Dynamo forgot to mark"
                        " a dict as guard_on_key_order"
                    )
                out = base_guard_manager.dict_getitem_manager(
                    key=source.index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, TensorPropertySource):
            # Dispatches to tensor_property_<prop>_manager by property name.
            out = getattr(
                base_guard_manager,
                f"tensor_property_{source.prop.name.lower()}_manager",
            )(
                idx=source.idx,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, IndexedSource):
            assert base_guard_manager  # to make mypy happy

            out = base_guard_manager.indexed_manager(
                idx=source.idx,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, ListGetItemSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.list_getitem_manager(
                key=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GetItemSource):
            assert base_guard_manager  # to make mypy happy
            assert not isinstance(
                base_example_value, (dict, collections.OrderedDict)
            ), "Use DictGetItemSource"
            # Specialized accessors for non-slice list/tuple indexing; anything
            # else (including slices) uses the generic getitem accessor.
            if isinstance(base_example_value, list) and not source.index_is_slice:
                out = base_guard_manager.list_getitem_manager(
                    key=source.index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            elif isinstance(base_example_value, tuple) and not source.index_is_slice:
                out = base_guard_manager.tuple_getitem_manager(
                    key=source.index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            else:
                index = source.index
                if source.index_is_slice:
                    index = source.unpack_slice()
                out = base_guard_manager.getitem_manager(
                    key=index,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        elif istype(source, DefaultsSource):
            assert base_guard_manager  # to make mypy happy
            assert base_source_name
            assert callable(base_example_value)
            if not source.is_kw:
                # Positional defaults: index into the function's __defaults__.
                out = base_guard_manager.func_defaults_manager(
                    source=base_source_name,
                    example_value=base_example_value.__defaults__,
                    guard_manager_enum=GuardManagerType.GUARD_MANAGER,
                ).getitem_manager(
                    key=source.idx_key,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
            else:
                # kwdefauts is a dict, so use a DictGuardManager
                kwdefaults = base_example_value.__kwdefaults__
                assert base_source_name is not None
                kw_source = base_source_name + ".__kwdefaults__"

                # kwdefaults is a dict. No need to guard on dict order.
                dict_mgr = base_guard_manager.func_kwdefaults_manager(
                    source=kw_source,
                    example_value=kwdefaults,
                    guard_manager_enum=GuardManagerType.GUARD_MANAGER,
                )
                assert not isinstance(dict_mgr, DictGuardManager)

                out = dict_mgr.dict_getitem_manager(
                    key=source.idx_key,
                    source=source_name,
                    example_value=example_value,
                    guard_manager_enum=guard_manager_enum,
                )
        # The following sources compute a derived value from the base object
        # with a Python lambda at guard-evaluation time.
        elif istype(source, NumpyTensorSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=from_numpy,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, SubclassAttrListSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.__tensor_flatten__()[0],
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, FlattenScriptObjectSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.__obj_flatten__(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, ScriptObjectQualifiedNameSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x._type().qualified_name(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, AttrProxySource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.get_base(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, CallMethodItemSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x.item(),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, FloatTensorSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: torch._as_tensor_fullprec(x),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, TupleIteratorGetItemSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.tuple_iterator_getitem_manager(
                index=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif isinstance(source, ConstDictKeySource):
            # Note: isinstance, not istype, so subclasses land here too.
            if not isinstance(base_guard_manager, DictGuardManager):
                raise AssertionError(
                    "ConstDictKeySource can only work on DictGuardManager"
                )
            out = base_guard_manager.get_key_manager(
                index=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, NonSerializableSetGetItemSource):
            assert base_guard_manager
            out = base_guard_manager.set_getitem_manager(
                index=source.index,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, WeakRefCallSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.weakref_call_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, CallFunctionNoArgsSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.call_function_no_args_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, DataclassFieldsSource):
            assert base_guard_manager
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: dataclass_fields(x),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, NamedTupleFieldsSource):
            assert base_guard_manager
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: x._fields,
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, CodeSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.code_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, ClosureSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.closure_manager(
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, DynamicScalarSource):
            assert base_guard_manager
            out = base_guard_manager.lambda_manager(
                python_lambda=lambda x: int(x),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        else:
            raise AssertionError(
                f"missing guard manager builder {source} - {source.name}"
            )

        # Memoize by source name; the lookup happens at the top of this method.
        self._cached_guard_managers[source.name] = out
        return out
  1594. def get_guard_manager(self, guard: Guard) -> GuardManager:
  1595. return self.get_guard_manager_from_source(guard.originating_source)
  1596. def add_python_lambda_leaf_guard_to_root(
  1597. self,
  1598. code_parts: list[str],
  1599. verbose_code_parts: list[str],
  1600. closure_vars: Optional[dict[str, object]] = None,
  1601. is_epilogue: bool = True,
  1602. ) -> None:
  1603. if closure_vars is None:
  1604. closure_vars = _get_closure_vars()
  1605. # Adds a lambda leaf guard to the root guard manager. It wraps the
  1606. # code_parts in a function object which is then passed on to the leaf
  1607. # guard.
  1608. make_guard_fn_args = ", ".join(closure_vars.keys())
  1609. _guard_body, pycode = build_guard_function(code_parts, make_guard_fn_args)
  1610. out: dict[str, Any] = {}
  1611. globals_for_guard_fn = {"G": self.scope["G"]}
  1612. guards_log.debug("Python shape guard function:\n%s", pycode)
  1613. exec(pycode, globals_for_guard_fn, out)
  1614. guard_fn = out["___make_guard_fn"](*closure_vars.values())
  1615. if is_epilogue:
  1616. # Epilogue guards are run after all the other guards have finished.
  1617. # If epilogue guards contain a getattr or getitem access, one of the
  1618. # other guards would fail preventing the epilogue guards to run.
  1619. self.guard_manager.root.add_epilogue_lambda_guard(
  1620. guard_fn,
  1621. verbose_code_parts,
  1622. None,
  1623. )
  1624. else:
  1625. self.guard_manager.root.add_lambda_guard(guard_fn, verbose_code_parts, None)
  1626. # Warning: use this with care! This lets you access what the current
  1627. # value of the value you are guarding on is. You probably don't want
  1628. # to actually durably save this value though (because it's specific
  1629. # to this frame!) Instead, you should be reading out some property
  1630. # (like its type) which is what you permanently install into the
  1631. # guard code.
  1632. def get(
  1633. self,
  1634. guard_or_source: Guard | Source,
  1635. closure_vars: Optional[dict[str, Any]] = None,
  1636. ) -> Any:
  1637. if isinstance(guard_or_source, Source):
  1638. src = guard_or_source
  1639. else:
  1640. src = guard_or_source.originating_source
  1641. if closure_vars is None:
  1642. closure_vars = _get_closure_vars()
  1643. ret = src.get_value(self.scope, closure_vars, self.src_get_value_cache)
  1644. return ret
  1645. # Registers the usage of the source name referenced by the
  1646. # string (or stored in the Guard) as being guarded upon. It's important
  1647. # to call this before generating some code that makes use of 'guard',
  1648. # because without this call, we won't actually bind the variable
  1649. # you reference in the actual guard closure (oops!)
  1650. def arg_ref(self, guard: Union[str, Guard]) -> str:
  1651. name: str
  1652. if isinstance(guard, str):
  1653. name = guard
  1654. else:
  1655. name = guard.name
  1656. base = strip_function_call(name)
  1657. if base not in self.argnames:
  1658. is_valid = torch._C._dynamo.is_valid_var_name(base)
  1659. if is_valid:
  1660. if is_valid == 2:
  1661. log.warning("invalid var name: %s", guard)
  1662. self.argnames.append(base)
  1663. return name
  1664. def _guard_on_attribute(
  1665. self,
  1666. guard: Guard,
  1667. attr_name: str,
  1668. guard_fn: Callable[[GuardBuilderBase, Guard], Any],
  1669. ) -> None:
  1670. if attr_name == "__code__":
  1671. attr_source = CodeSource(guard.originating_source)
  1672. else:
  1673. attr_source = AttrSource(guard.originating_source, attr_name) # type: ignore[assignment]
  1674. # Copy the stack info
  1675. new_guard = Guard(
  1676. attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack
  1677. )
  1678. new_guard.create(self)
  1679. # Note: the order of the guards in this file matters since we sort guards on the same object by lineno
  1680. def HASATTR(self, guard: Guard) -> None:
  1681. source = guard.originating_source
  1682. if isinstance(source, NNModuleSource):
  1683. source = source.base
  1684. if isinstance(source, CodeSource):
  1685. # No need to guard that a function has a __code__ attribute
  1686. return
  1687. assert isinstance(source, AttrSource), f"invalid source {guard.name}"
  1688. base_source = source.base
  1689. base = base_source.name
  1690. attr = source.member
  1691. ref = self.arg_ref(base)
  1692. val = hasattr(self.get(base_source), attr)
  1693. code = None
  1694. if val:
  1695. code = f"hasattr({ref}, {attr!r})"
  1696. else:
  1697. code = f"not hasattr({ref}, {attr!r})"
  1698. if code in self.already_added_code_parts:
  1699. return
  1700. self._set_guard_export_info(
  1701. guard, [code], provided_guarded_object=self.get(base_source)
  1702. )
  1703. base_manager = self.get_guard_manager_from_source(base_source)
  1704. if val:
  1705. # Just install a getattr manager. GetAttrGuardAccessor itself
  1706. # acts as hasattr guard.
  1707. example_value = self.get(source)
  1708. base_example_value = self.get(base_source)
  1709. guard_manager_enum = self.get_guard_manager_type(source, example_value)
  1710. # if the base value is nn.Module, check if we can speedup the
  1711. # guard by going through __dict__ attrs.
  1712. if should_optimize_getattr_on_nn_module(base_example_value):
  1713. self.getattr_on_nn_module(
  1714. source,
  1715. base_manager,
  1716. base_example_value,
  1717. example_value,
  1718. base,
  1719. source.name,
  1720. guard_manager_enum,
  1721. )
  1722. else:
  1723. base_manager.getattr_manager(
  1724. attr=attr,
  1725. source=guard.name,
  1726. example_value=example_value,
  1727. guard_manager_enum=guard_manager_enum,
  1728. )
  1729. else:
  1730. base_manager.add_no_hasattr_guard(
  1731. attr, get_verbose_code_parts(code, guard), guard.user_stack
  1732. )
  1733. self.already_added_code_parts.add(code)
  1734. def NOT_PRESENT_IN_GENERIC_DICT(
  1735. self, guard: Guard, attr: Optional[Any] = None
  1736. ) -> None:
  1737. assert attr is not None
  1738. ref = self.arg_ref(guard)
  1739. val = self.get(guard)
  1740. base_manager = self.get_guard_manager(guard)
  1741. code = f"not ___dict_contains({attr!r}, {ref}.__dict__)"
  1742. if code in self.already_added_code_parts:
  1743. return
  1744. mod_dict_source = f"{guard.name}.__dict__"
  1745. mod_generic_dict_manager = base_manager.get_generic_dict_manager(
  1746. source=mod_dict_source,
  1747. example_value=self._get_generic_dict_manager_example_value(val.__dict__),
  1748. guard_manager_enum=GuardManagerType.GUARD_MANAGER,
  1749. )
  1750. mod_generic_dict_manager.add_dict_contains_guard(
  1751. False,
  1752. attr,
  1753. get_verbose_code_parts(code, guard),
  1754. guard.user_stack,
  1755. )
  1756. self.already_added_code_parts.add(code)
  1757. def TYPE_MATCH(self, guard: Guard) -> None:
  1758. # ___check_type_id is same as `id(type(x)) == y`
  1759. value = self.get(guard)
  1760. if isinstance(value, torch._subclasses.FakeTensor) and value.pytype:
  1761. t = value.pytype
  1762. else:
  1763. t = type(value)
  1764. if t.__qualname__ != t.__name__:
  1765. # Type match guards must be local scope, this is
  1766. # raised in self.serialize_guards
  1767. guard._unserializable = True
  1768. obj_id = self.id_ref(t, f"type({guard.name})")
  1769. type_repr = repr(t)
  1770. code = f"___check_type_id({self.arg_ref(guard)}, {obj_id}), type={type_repr}"
  1771. self._set_guard_export_info(guard, [code])
  1772. self.get_guard_manager(guard).add_type_match_guard(
  1773. obj_id,
  1774. get_verbose_code_parts(
  1775. code, guard, recompile_hint=f"type {t.__qualname__}"
  1776. ),
  1777. guard.user_stack,
  1778. )
  1779. def DICT_VERSION(self, guard: Guard) -> None:
  1780. # ___check_dict_version is same as `dict_version(x) == y`
  1781. ref = self.arg_ref(guard)
  1782. val = self.get(guard)
  1783. version = dict_version(self.get(guard))
  1784. code = f"___dict_version({ref}) == {version}"
  1785. self._set_guard_export_info(guard, [code])
  1786. # TODO(anijain2305) - Delete this when DictGuardManager uses tags
  1787. # for dicts.
  1788. self.get_guard_manager(guard).add_dict_version_guard(
  1789. val, get_verbose_code_parts(code, guard), guard.user_stack
  1790. )
  1791. def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool) -> None:
  1792. dict_ref = self.arg_ref(guard)
  1793. maybe_not = "not " if invert else ""
  1794. code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})"
  1795. if code in self.already_added_code_parts:
  1796. return
  1797. self._set_guard_export_info(guard, [code])
  1798. self.get_guard_manager(guard).add_dict_contains_guard(
  1799. not invert,
  1800. key,
  1801. get_verbose_code_parts(code, guard),
  1802. guard.user_stack,
  1803. )
  1804. self.already_added_code_parts.add(code)
  1805. def SET_CONTAINS(self, guard: Guard, key: Any, invert: bool) -> None:
  1806. set_ref = self.arg_ref(guard)
  1807. item = key
  1808. contains = not invert # install_dict_contains_guard inverts "contains"
  1809. code = f"set.__contains__({set_ref}, {item!r})"
  1810. if code in self.already_added_code_parts:
  1811. return
  1812. self._set_guard_export_info(guard, [code])
  1813. self.get_guard_manager(guard).add_set_contains_guard(
  1814. contains,
  1815. item,
  1816. get_verbose_code_parts(code, guard),
  1817. guard.user_stack,
  1818. )
  1819. self.already_added_code_parts.add(code)
  1820. def BOOL_MATCH(self, guard: Guard) -> None:
  1821. # checks val == True or val == False
  1822. ref = self.arg_ref(guard)
  1823. val = self.get(guard)
  1824. assert istype(val, bool)
  1825. code = [f"{ref} == {val!r}"]
  1826. self._set_guard_export_info(guard, code)
  1827. if val:
  1828. self.get_guard_manager(guard).add_true_match_guard(
  1829. get_verbose_code_parts(code, guard), guard.user_stack
  1830. )
  1831. else:
  1832. self.get_guard_manager(guard).add_false_match_guard(
  1833. get_verbose_code_parts(code, guard), guard.user_stack
  1834. )
  1835. def NONE_MATCH(self, guard: Guard) -> None:
  1836. # checks `val is None`
  1837. ref = self.arg_ref(guard)
  1838. val = self.get(guard)
  1839. assert val is None
  1840. code = [f"{ref} is None"]
  1841. self._set_guard_export_info(guard, code)
  1842. self.get_guard_manager(guard).add_none_match_guard(
  1843. get_verbose_code_parts(code, guard), guard.user_stack
  1844. )
  def ID_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
      """Identity guard on the source (``id(x) == <traced id>``).

      Delegates to ``id_match_unchecked``. Classes and modules are intended to
      use CLASS_MATCH / MODULE_MATCH instead (not yet enforced, see TODO).

      Args:
          guard: the guard to install.
          recompile_hint: optional hint forwarded to the verbose code parts.
      """
      # TODO - Run a CI with the checks below enabled to find the remaining
      # places that still pass classes/modules here:
      #   val = self.get(guard)
      #   if inspect.isclass(val):
      #       raise AssertionError(f"{guard.name} is a class, use CLASS_MATCH guard")
      #   if inspect.ismodule(val):
      #       raise AssertionError(f"{guard.name} is a module, use MODULE_MATCH guard")
      outcome = self.id_match_unchecked(guard, recompile_hint)
      return outcome
  def id_match_unchecked(
      self, guard: Guard, recompile_hint: Optional[str] = None
  ) -> None:
      """Install an identity guard without checking what kind of object it is.

      Emits ``___check_obj_id(ref, id)``, which is the same as ``id(x) == y``.
      If the guard's source is a ``TypeSource``, a TYPE_MATCH on the underlying
      base source is installed instead, which gives cleaner/faster guard code.

      For ``LocalSource`` guards on ``torch.nn.Module`` values, the result of
      ``self.lookup_weakrefs`` (when not None) is stored in
      ``self.id_matched_objs`` under the local's name. That mapping is used by
      the cache-size logic.
      """
      origin = guard.originating_source
      if isinstance(origin, TypeSource):
          # Optional optimization to produce cleaner/faster guard code.
          return self.TYPE_MATCH(
              Guard(origin.base, GuardBuilder.TYPE_MATCH)  # type: ignore[arg-type]
          )

      source_ref = self.arg_ref(guard)
      obj = self.get(guard)
      obj_id = self.id_ref(obj, guard.name)
      try:
          obj_repr = repr(obj)
      except Exception:
          # Objects may be half-constructed (e.g. mid-deepcopy reconstruction),
          # in which case repr() can fail; fall back to the type name.
          obj_repr = f"<{type(obj).__name__}>"
      code = f"___check_obj_id({source_ref}, {obj_id}), type={obj_repr}"
      self._set_guard_export_info(guard, [code], provided_func_name="ID_MATCH")
      self.get_guard_manager(guard).add_id_match_guard(
          obj_id,
          get_verbose_code_parts(code, guard, recompile_hint),
          guard.user_stack,
      )

      # Track ID_MATCH'd locals; consumed by the cache size logic.
      if not isinstance(origin, LocalSource):
          return
      # TODO(anijain2305) - Restricted to nn.Module objects because many other
      # ID_MATCH'd objects (e.g. DeviceMesh) fail. Widen the scope later.
      if not isinstance(obj, torch.nn.Module):
          return
      local_name = origin.local_name
      weak_id = self.lookup_weakrefs(obj)
      if weak_id is not None:
          self.id_matched_objs[local_name] = weak_id
  def NOT_NONE_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None:
      """Guard that the tensor-valued source is not ``None``.

      ``value`` is unused here; the value is always re-read with ``self.get``
      and must be a ``torch.Tensor``.
      """
      source_ref = self.arg_ref(guard)
      tensor = self.get(guard)
      assert isinstance(tensor, torch.Tensor)
      code = f"{source_ref} is not None"
      self._set_guard_export_info(guard, [code])
      manager = self.get_guard_manager(guard)
      manager.add_not_none_guard(get_verbose_code_parts(code, guard), guard.user_stack)
  def DISPATCH_KEY_SET_MATCH(self, guard: Guard) -> None:
      """Guard that a ``torch._C.DispatchKeySet`` source equals the traced set.

      Unlike most guards in this builder, no export info is recorded here.
      """
      source_ref = self.arg_ref(guard)
      key_set = self.get(guard)
      assert isinstance(key_set, torch._C.DispatchKeySet)
      description = f"{source_ref}.raw_repr() == {key_set!r}.raw_repr()"
      verbose = get_verbose_code_parts(description, guard)
      self.get_guard_manager(guard).add_dispatch_key_set_guard(
          key_set, verbose, guard.user_stack
      )
  def DUAL_LEVEL(self, guard: Guard) -> None:
      """Guard on the forward-AD dual level recorded in the output graph.

      Invalidates the compiled code when the current dual level differs from
      the one captured for the FX graph.
      """
      graph = self.check_fn_manager.output_graph
      assert graph is not None
      level = graph.dual_level
      code = [f"torch.autograd.forward_ad._current_level == {level}"]
      self._set_guard_export_info(guard, code)
      verbose = get_verbose_code_parts(code, guard)
      self.guard_manager.root.add_dual_level_match_guard(
          level, verbose, guard.user_stack
      )
  def FUNCTORCH_STACK_MATCH(self, guard: Guard) -> None:
      """Guard that the functorch layer states match those at FX-graph time.

      The states of ``output_graph.functorch_layers`` are snapshotted and a
      root lambda guard re-checks them with ``compare_functorch_state``.
      """
      graph = self.check_fn_manager.output_graph
      assert graph is not None
      saved_states = [layer.get_state() for layer in graph.functorch_layers]
      code = [f"torch._functorch.pyfunctorch.compare_functorch_state({saved_states})"]
      self._set_guard_export_info(guard, code)

      # TODO(anijain2305) - Consider moving this guard to C++
      compare_fn = torch._functorch.pyfunctorch.compare_functorch_state

      def check_functorch_state(x: Any) -> bool:
          # The guarded value is ignored; only the functorch state matters.
          return compare_fn(saved_states)

      self.guard_manager.root.add_lambda_guard(
          check_functorch_state, get_verbose_code_parts(code, guard), guard.user_stack
      )
  def AUTOGRAD_SAVED_TENSORS_HOOKS(self, guard: Guard) -> None:
      """Guard on the identity of the saved-tensors hooks currently on top.

      Records ``tuple(id(h) for h in hooks)`` for the hooks returned by
      ``top_saved_tensors_hooks()``. When the hooks are not inlineable it
      records ``None``. A root lambda guard recomputes the same value at
      runtime and compares.

      The diagnostic code string names the module that is actually consulted
      (``torch._functorch._aot_autograd.utils``). It previously pointed at
      ``torch._functorch.aot_autograd.utils``, which is not where these
      helpers are read from.
      """
      get_hooks = torch._functorch._aot_autograd.utils.top_saved_tensors_hooks
      are_inline_hooks = (
          torch._functorch._aot_autograd.utils.saved_tensors_hooks_are_inlineable
      )

      def hooks_ids_fn(
          hooks: tuple[Callable[[torch.Tensor], Any], Callable[[Any], torch.Tensor]],
      ) -> Optional[tuple[int, ...]]:
          # Non-inlineable hooks are summarized as None.
          if not are_inline_hooks(hooks):
              return None
          return tuple(map(id, hooks))

      guard_hooks_ids = hooks_ids_fn(get_hooks())
      code = [
          f"torch._functorch._aot_autograd.utils.top_saved_tensors_hooks ids == {guard_hooks_ids}"
      ]
      self._set_guard_export_info(guard, code)

      def fn(x: Any) -> bool:
          # The guarded value is ignored; only the global hook state matters.
          return guard_hooks_ids == hooks_ids_fn(get_hooks())

      self.guard_manager.root.add_lambda_guard(
          fn, get_verbose_code_parts(code, guard), guard.user_stack
      )
  def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard) -> None:
      """Guard on the flatten metadata of a traceable tensor subclass.

      Deep-copies the metadata element (index 1) of ``__tensor_flatten__()`` at
      trace time. Comparison at runtime works as follows:

      * If the value defines ``__metadata_guard__`` (its signature is first
        checked with ``verify_guard_fn_signature``), ``cls.__metadata_guard__``
        decides equality.
      * Otherwise plain ``==`` is used.
      """
      value = self.get(guard)
      # Reuse ``value`` instead of evaluating the guard source a second time.
      original_metadata = deepcopy(value.__tensor_flatten__()[1])
      if hasattr(value, "__metadata_guard__"):
          verify_guard_fn_signature(value)
          cls = type(value)

          def metadata_checker(x: Any) -> bool:
              return cls.__metadata_guard__(
                  original_metadata, x.__tensor_flatten__()[1]
              )

      else:

          def metadata_checker(x: Any) -> bool:
              return x.__tensor_flatten__()[1] == original_metadata

      global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}"
      self.get_guard_manager(guard).add_lambda_guard(
          metadata_checker,
          get_verbose_code_parts(global_name, guard),
          guard.user_stack,
      )
  def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
      """Guard that a DTensor spec equals a trace-time snapshot (shapes ignored).

      Logic mirrors DTensor's ``__metadata_guard__``.
      TODO - Consider moving this to C++ if stable.
      """
      snapshot = deepcopy(self.get(guard))

      def spec_matches(x: Any) -> bool:
          return x._check_equals(snapshot, skip_shapes=True)

      label = f"__dtensor_spec_{id(spec_matches)}"
      verbose = get_verbose_code_parts(label, guard)
      self.get_guard_manager(guard).add_lambda_guard(
          spec_matches, verbose, guard.user_stack
      )
  def OPAQUE_OBJ_GUARD_FN_MATCH(self, guard: Guard) -> None:
      """Guard on the values returned by the opaque object's guard_fn.

      This is a no-op when the value's type has no opaque info or no
      ``guard_fn``. Otherwise the trace-time ``guard_fn`` result is deep-copied,
      and a lambda guard compares it against a fresh call on the runtime value.
      """
      obj = self.get(guard)
      info = get_opaque_obj_info(type(obj))
      if not (info and info.guard_fn):
          return
      expected = deepcopy(info.guard_fn(obj))

      def opaque_guard_checker(x: Any) -> bool:
          # guard_fn is looked up on ``info`` at call time.
          actual = info.guard_fn(  # pyrefly: ignore[missing-attribute]
              x
          )
          return actual == expected

      label = f"___check_opaque_guard_fn_{id(opaque_guard_checker)}_c{CompileContext.current_compile_id()}"
      self.get_guard_manager(guard).add_lambda_guard(
          opaque_guard_checker,
          get_verbose_code_parts(label, guard),
          guard.user_stack,
      )
  def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
      """Guard that the source compares equal (``==``) to its trace-time value.

      Only values of the types in ``ok_types`` are accepted, plus pytree
      constant classes and opaque value types; anything else fails the
      assertion. NaN floats and NaN complex numbers get dedicated guards,
      because NaN never compares equal to itself. ``list``/``set`` values are
      deep-copied before being handed to the equals guard (see comment below).

      Args:
          guard: the guard to install.
          recompile_hint: if given, ``(HINT: <recompile_hint>)`` is appended to
              each verbose code part.
      """
      import cmath

      ref = self.arg_ref(guard)
      val = self.get(guard)
      if np:
          np_types: tuple[type[Any], ...] = (
              np.int8,
              np.int16,
              np.int32,
              np.int64,
              np.uint8,
              np.uint16,
              np.uint32,
              np.uint64,
              np.float16,
              np.float32,
              np.float64,
          )
      else:
          np_types = ()

      ok_mutable_types = (list, set)
      ok_types = tuple(
          common_constant_types
          | {
              type,
              tuple,
              frozenset,
              slice,
              range,
              dict_keys,
              torch.Size,
              torch.Stream,
              torch.cuda.streams.Stream,
              *np_types,
              *ok_mutable_types,
          }
      )
      if torch.distributed.is_available():
          from torch.distributed.device_mesh import DeviceMesh
          from torch.distributed.tensor.placement_types import (
              _StridedShard,
              Partial,
              Replicate,
              Shard,
          )

          ok_types = ok_types + (
              Shard,
              Replicate,
              Partial,
              DeviceMesh,
              _StridedShard,
          )

      from torch.export.dynamic_shapes import _IntWrapper

      ok_types = ok_types + (_IntWrapper,)

      import torch.utils._pytree as pytree

      assert (
          isinstance(val, ok_types)
          or pytree.is_constant_class(type(val))
          or is_opaque_value_type(type(val))
      ), f"Unexpected type {type(val)}"

      # Special case for nan because float("nan") == float("nan") evaluates to False
      if istype(val, float) and math.isnan(val):
          code = [f"(type({ref}) is float and __math_isnan({ref}))"]
          self._set_guard_export_info(guard, code)
          self.get_guard_manager(guard).add_float_is_nan_guard(
              get_verbose_code_parts(code, guard),
              guard.user_stack,
          )
          return

      # Complex NaN needs the same treatment. ``math.isnan`` rejects complex
      # values, but ``cmath.isnan`` accepts them (True if either part is NaN)
      # and does not depend on numpy, which may be unavailable (``np`` falsy,
      # as handled above). The runtime check itself is done by
      # ``add_complex_is_nan_guard``.
      if istype(val, complex) and cmath.isnan(val):
          code = [f"(type({ref}) is complex and __numpy_isnan({ref}))"]
          self._set_guard_export_info(guard, code)
          self.get_guard_manager(guard).add_complex_is_nan_guard(
              get_verbose_code_parts(code, guard),
              guard.user_stack,
          )
          return

      # Construct a debug string to put into the c++ equals match guard.
      code = [f"{ref} == {val!r}"]
      if istype(val, ok_mutable_types):
          # C++ guards perform a pointer equality check to speed up guards, which
          # assumes the object is immutable. For sets and lists we deepcopy so
          # that the pointer equality check purposefully fails.
          val = deepcopy(val)

      verbose_code_parts = get_verbose_code_parts(code, guard)
      if recompile_hint:
          verbose_code_parts = [
              f"{part} (HINT: {recompile_hint})" for part in verbose_code_parts
          ]

      self.get_guard_manager(guard).add_equals_match_guard(
          val, verbose_code_parts, guard.user_stack
      )
      self._set_guard_export_info(guard, code)
  def CONSTANT_MATCH(self, guard: Guard) -> None:
      """Dispatch a constant-value guard to the matching specialized builder.

      bool -> BOOL_MATCH, None -> NONE_MATCH, code objects -> ID_MATCH,
      everything else -> EQUALS_MATCH.
      """
      constant = self.get(guard)
      if istype(constant, bool):
          self.BOOL_MATCH(guard)
          return
      if constant is None:
          self.NONE_MATCH(guard)
          return
      if istype(constant, types.CodeType):
          self.ID_MATCH(guard)
          return
      self.EQUALS_MATCH(guard)
  def NN_MODULE(self, guard: Guard) -> None:
      """Guard an ``nn.Module`` by identity and, if needed, its ``training`` flag.

      Not supported in guard serialization because it relies on ID_MATCH.
      Calls ``exc.unimplemented`` when the module has no ``training``
      attribute, i.e. ``super().__init__()`` was apparently never called.
      """
      self.ID_MATCH(guard, "[inline-inbuilt-nn-modules-candidate]")
      module = self.get(guard)
      if not hasattr(module, "training"):
          exc.unimplemented(
              gb_type="Attempted to guard on uninitialized nn.Module",
              context="",
              explanation="Attempted to setup an NN_MODULE guard on uninitialized "
              f"nn.Module subclass `{type(module)}`.",
              hints=[
                  "Ensure the `nn.Module` subclass instance has called `super().__init__()`.",
              ],
          )
          return
      assert istype(module.training, bool)
      # With guard_nn_modules on, the right set of module guards is installed
      # anyway; only guard ``training`` explicitly when it is off.
      if not self.guard_nn_modules:
          self._guard_on_attribute(guard, "training", GuardBuilder.CONSTANT_MATCH)  # type: ignore[arg-type]
  def FUNCTION_MATCH(self, guard: Guard) -> None:
      """Identity guard for callables such as ``torch.add`` or user functions.

      Not supported in guard serialization because it relies on ID_MATCH.
      """
      result = self.ID_MATCH(guard)
      return result
  def CLASS_MATCH(self, guard: Guard) -> None:
      """ID_MATCH for classes; the dedicated name reads better at call sites.

      Raises:
          AssertionError: if the guarded value is not a class.
      """
      if inspect.isclass(self.get(guard)):
          self.id_match_unchecked(guard)
          return
      raise AssertionError(
          f"{guard.name} is not a class, but CLASS_MATCH is used"
      )
  def MODULE_MATCH(self, guard: Guard) -> None:
      """ID_MATCH for modules; the dedicated name reads better at call sites.

      Raises:
          AssertionError: if the guarded value is not a module.
      """
      if inspect.ismodule(self.get(guard)):
          self.id_match_unchecked(guard)
          return
      raise AssertionError(
          f"{guard.name} is not a module, but MODULE_MATCH is used"
      )
  def CLOSURE_MATCH(self, guard: Guard) -> None:
      """Match a closure by its ``__code__`` object.

      Only exact ``types.FunctionType`` instances (user-defined functions) are
      matched via ``__code__``; everything else falls back to FUNCTION_MATCH.
      Not supported in serialization because of that FUNCTION_MATCH fallback.
      """
      fn = self.get(guard)
      is_user_function = type(fn) is types.FunctionType and hasattr(fn, "__code__")
      if not is_user_function:
          self.FUNCTION_MATCH(guard)
          return
      for builder in (GuardBuilder.HASATTR, GuardBuilder.CONSTANT_MATCH):
          self._guard_on_attribute(guard, "__code__", builder)  # type: ignore[arg-type]
  def BUILTIN_MATCH(self, guard: Guard) -> None:
      """Identity guard on a builtin.

      When guards are being saved, also records which builtin variables are
      used so unused ones can be pruned later.
      """
      if self.save_guards:
          origin = guard.originating_source
          if isinstance(origin, DictGetItemSource):
              self.check_fn_manager.used_builtin_vars.add(origin.index)
      return self.id_match_unchecked(guard)
  def SEQUENCE_LENGTH(self, guard: Guard) -> None:
      """Guard on the length of a PySequence-like object (list, tuple, deque, ...) or dict.

      Non-dict values also get a TYPE_MATCH. For dicts, the dict-length guard
      performs its own type check.
      """
      ref = self.arg_ref(guard)
      value = self.get(guard)
      is_dict = isinstance(value, dict)
      if not is_dict:
          self.TYPE_MATCH(guard)

      length = len(value)
      code = [f"len({ref}) == {length}" if length else f"not {ref}"]
      self._set_guard_export_info(guard, code)

      manager = self.get_guard_manager(guard)
      add_length_guard = (
          manager.add_dict_length_check_guard
          if is_dict
          else manager.add_length_check_guard
      )
      add_length_guard(length, get_verbose_code_parts(code, guard), guard.user_stack)
  def TUPLE_ITERATOR_LEN(self, guard: Guard) -> None:
      """Guard on a tuple iterator's length and its exact type.

      The length reported by ``tuple_iterator_len`` is computed once and used
      both for the diagnostic text and for the guard. The guard also receives
      an id reference to ``type(value)``.
      """
      ref = self.arg_ref(guard)
      value = self.get(guard)
      iter_len = tuple_iterator_len(value)
      code = [f"___tuple_iterator_len({ref}) == {iter_len}"]
      self._set_guard_export_info(guard, code)

      obj_id = self.id_ref(type(value), f"type({guard.name})")
      self.get_guard_manager(guard).add_tuple_iterator_length_guard(
          iter_len,
          obj_id,
          get_verbose_code_parts(code, guard),
          guard.user_stack,
      )
  def RANGE_ITERATOR_MATCH(self, guard: Guard) -> None:
      """Guard on a range iterator's normalized (start, stop, step) and its type.

      The guard also receives an id reference to ``type(value)``.
      """
      ref = self.arg_ref(guard)
      value = self.get(guard)
      normalized_range_iter = normalize_range_iter(value)
      code = [f"___normalize_range_iter({ref}) == {normalized_range_iter}"]
      self._set_guard_export_info(guard, code)

      obj_id = self.id_ref(type(value), f"type({guard.name})")
      start, stop, step = normalized_range_iter
      self.get_guard_manager(guard).add_range_iterator_match_guard(
          start,
          stop,
          step,
          obj_id,
          get_verbose_code_parts(code, guard),
          guard.user_stack,
      )
  # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
  def DUPLICATE_INPUT(self, guard: Guard, source_b: Source) -> None:
      """Guard that two sources alias the same object (``ref_b is ref_a``).

      Behavior:

      * Skipped if either source is a skip-guard source or an optimizer
        source.
      * Each pair is installed at most once, in either order.
      * With ``config.use_lamba_guard_for_object_aliasing`` the code part is
        queued on ``self.object_aliasing_guard_codes`` for later installation.
      * Otherwise ``install_object_aliasing_guard`` is called right away.
      """
      source_a = guard.originating_source
      if is_from_skip_guard_source(source_a) or is_from_skip_guard_source(source_b):
          return

      if self.save_guards:
          # Record source_b's local/global variable name (if any) as used.
          local_name = get_local_source_name(source_b)
          if local_name:
              self.check_fn_manager.additional_used_local_vars.add(local_name)
          global_name = get_global_source_name(source_b)
          if global_name:
              self.check_fn_manager.additional_used_global_vars.add(global_name)

      ref_a = self.arg_ref(guard)
      ref_b = self.arg_ref(source_b.name)
      if is_from_optimizer_source(source_a) or is_from_optimizer_source(source_b):
          return

      # Check that the guard has not been inserted already (in either order).
      if (ref_a, ref_b) in self._cached_duplicate_input_guards:
          return
      self._cached_duplicate_input_guards.add((ref_a, ref_b))
      self._cached_duplicate_input_guards.add((ref_b, ref_a))

      code = [f"{ref_b} is {ref_a}"]
      self._set_guard_export_info(guard, code)
      if not config.use_lamba_guard_for_object_aliasing:
          install_object_aliasing_guard(
              self.get_guard_manager(guard),
              self.get_guard_manager_from_source(source_b),
              get_verbose_code_parts(code, guard),
              guard.user_stack,
          )
          return
      # Save the code part so that a lambda guard can be installed at the end.
      # Read the Note - On Lambda guarding of object aliasing - for details.
      code_part = code[0]
      verbose_code_part = get_verbose_code_parts(code_part, guard)[0]
      self.object_aliasing_guard_codes.append((code_part, verbose_code_part))
  def WEAKREF_ALIVE(self, guard: Guard) -> None:
      """Guard that the source does not evaluate to ``None``."""
      source_ref = self.arg_ref(guard)
      code = [f"{source_ref} is not None"]
      self._set_guard_export_info(guard, code)
      verbose = get_verbose_code_parts(code, guard)
      self.get_guard_manager(guard).add_not_none_guard(verbose, guard.user_stack)
  def MAPPING_KEYS_CHECK(self, guard: Guard) -> None:
      """Guard on the key order of a ``types.MappingProxyType`` object.

      The code parts are passed through ``get_verbose_code_parts``, as most
      guards in this builder do, so failure reports carry the same guard
      context as the others instead of the bare code string.
      """
      ref = self.arg_ref(guard)
      value = self.get(guard)
      code = [f"list({ref}.keys()) == {list(value.keys())}"]
      self._set_guard_export_info(guard, code)
      self.get_guard_manager(guard).add_mapping_keys_guard(
          value, get_verbose_code_parts(code, guard), guard.user_stack
      )
  def DICT_KEYS_MATCH(self, guard: Guard) -> None:
      """Insert guards checking that a dict has the same keys as at trace time.

      ``torch.utils._pytree.SUPPORTED_NODES`` is special-cased to a
      DICT_VERSION guard. Any other dict gets the following:

      * a length guard;
      * key guards that respect or ignore key order, depending on
        ``requires_key_order_guarding``.
      """
      ref = self.arg_ref(guard)
      value = self.get(guard)
      if value is torch.utils._pytree.SUPPORTED_NODES:
          # This registry can be guarded via its dictionary version (PEP 509).
          self.DICT_VERSION(guard)
          return

      self.SEQUENCE_LENGTH(guard)
      # Go through dict.keys rather than value.keys so an overridden keys()
      # is never called; the C++ side traverses via PyDict_Next, which also
      # bypasses any override.
      keys_repr = repr(list(builtin_dict_keys(value)))
      code = [f"list(dict.keys({ref})) == {keys_repr}"]
      self._set_guard_export_info(guard, code)
      install_key_guards = (
          self.guard_on_dict_keys_and_order
          if self.requires_key_order_guarding(guard.originating_source)
          else self.guard_on_dict_keys_and_ignore_order
      )
      install_key_guards(value, guard)
  def EMPTY_NN_MODULE_HOOKS_DICT(self, guard: Guard) -> None:
      """Special guard for empty nn.Module hook dicts.

      This is a no-op when ``config.skip_nnmodule_hook_guards`` is set, which
      is unsafe if hooks are later added to or removed from the module.
      Otherwise it installs a SEQUENCE_LENGTH guard.
      """
      if not config.skip_nnmodule_hook_guards:
          self.SEQUENCE_LENGTH(guard)
  def GRAD_MODE(self, guard: Guard) -> None:
      """No-op: grad mode is always guarded via ``GlobalStateGuard()``."""
  def DETERMINISTIC_ALGORITHMS(self, guard: Guard) -> None:
      """No-op: deterministic-algorithms mode is always guarded via ``GlobalStateGuard()``."""
  def FSDP_TRAINING_STATE(self, guard: Guard) -> None:
      """No-op: FSDP training state is always guarded via ``GlobalStateGuard()``."""
  def GLOBAL_STATE(self, guard: Guard) -> None:
      """Install the root global-state guard captured by the output graph.

      The captured state is also stored on ``check_fn_manager.global_state``.
      """
      graph = self.check_fn_manager.output_graph
      assert graph is not None
      state = graph.global_state_guard
      self.check_fn_manager.global_state = state
      description = [
          f"___check_global_state() against {self.check_fn_manager.global_state.__getstate__()}"
      ]
      root = self.guard_manager.root
      root.add_global_state_guard(state, description, guard.user_stack)
  def TORCH_FUNCTION_STATE(self, guard: Guard) -> None:
      """Guard on the torch function mode stack captured at trace time.

      Stores the result of ``make_torch_function_mode_stack_guard`` on
      ``check_fn_manager.torch_function_mode_stack_check_fn`` and installs the
      root torch-function-mode-stack guard.
      """
      manager = self.check_fn_manager
      mode_stack = manager.torch_function_mode_stack
      assert mode_stack is not None
      manager.torch_function_mode_stack_check_fn = make_torch_function_mode_stack_guard(
          mode_stack
      )
      self.guard_manager.root.add_torch_function_mode_stack_guard(
          mode_stack,
          ["___check_torch_function_mode_stack()"],
          guard.user_stack,
      )
  def DEFAULT_DEVICE(self, guard: Guard) -> None:
      """Guard on CURRENT_DEVICE per torch.utils._device (global guards only)."""
      assert guard.source is GuardSource.GLOBAL
      graph = self.check_fn_manager.output_graph
      assert graph is not None
      code = [f"utils_device.CURRENT_DEVICE == {graph.current_device!r}"]
      self._set_guard_export_info(guard, code)
      verbose = get_verbose_code_parts(code, guard)
      self.get_guard_manager(guard).add_default_device_guard(verbose, guard.user_stack)
  def SHAPE_ENV(self, guard: Guard) -> None:
      """Install the symbolic-shape guards produced by the ShapeEnv.

      Code parts come from one of two places:

      * ``check_fn_manager.shape_code_parts``, when already populated;
      * otherwise ``output_graph.shape_env.produce_guards_verbose``, using
        export equality constraints when present.

      Installation path:

      * C++: used when C++ code parts exist and every referenced source
        evaluates to a Python ``int`` or ``float``. The expressions are
        compiled with ``CppCodeCache`` into one C++ function and installed
        via ``install_symbolic_shape_guard``.
      * Python fallback: used otherwise, including when no valid C++ compiler
        is found. All Python expressions are installed as a single lambda
        leaf guard on the root guard manager.
      """
      from torch._dynamo.output_graph import OutputGraphCommon

      assert guard.name == ""
      output_graph = self.check_fn_manager.output_graph
      assert output_graph is not None
      if self.check_fn_manager.shape_code_parts is not None:
          # Code parts already exist (e.g. stored by the ``save_guards`` branch
          # below); reuse them instead of re-running produce_guards.
          shape_code_parts = self.check_fn_manager.shape_code_parts
          python_code_parts = shape_code_parts.python_code_parts
          verbose_code_parts = shape_code_parts.verbose_code_parts
          if shape_code_parts.cpp_code_parts is not None:
              cpp_code_parts = shape_code_parts.cpp_code_parts
          python_fallback = shape_code_parts.python_fallback
      else:
          # Let's handle ShapeEnv guards. To do this, we will resolve
          # shape variables to sources from tracked_fakes. This must happen after
          # tensor checks.
          # NB: self.output_graph can be None in the debug_nops tests
          assert isinstance(output_graph, OutputGraphCommon)
          assert output_graph.shape_env is not None
          fs = output_graph.shape_env.tracked_fakes or []
          input_contexts = [a.symbolic_context for a in fs]

          def get_sources(t_id: int, dim: int) -> list[Source]:
              # Looks up base sources mapped to a tensor id and uses them to create
              # sources for the corresponding tensor dimension.
              return [
                  TensorPropertySource(source, TensorProperty.SIZE, dim)
                  # pyrefly: ignore [missing-attribute]
                  for source in output_graph.tracked_fakes_id_to_source[t_id]
              ]

          if output_graph.export_constraints:
              # Accumulators passed to _process_equalities for each tracked
              # constraint (presumably filled in place), then bundled into an
              # EqualityConstraint for produce_guards.
              names: dict[str, tuple[int, int]] = {}
              source_pairs: list[tuple[Source, Source]] = []
              derived_equalities: list[  # type: ignore[type-arg]
                  # pyrefly: ignore [implicit-any]
                  tuple[Source, Union[Source, Symbol], Callable]
              ] = []
              phantom_symbols: dict[str, Symbol] = {}
              relaxed_sources: set[Source] = set()
              for constraint in output_graph.export_constraints:  # type: ignore[attr-defined]
                  if constraint.t_id in output_graph.tracked_fakes_id_to_source:
                      torch.export.dynamic_shapes._process_equalities(
                          constraint,
                          get_sources,
                          output_graph.shape_env,
                          names,
                          source_pairs,
                          derived_equalities,
                          phantom_symbols,
                          relaxed_sources,
                      )
                  else:
                      log.warning("Untracked tensor used in export constraints")
              equalities_inputs = EqualityConstraint(
                  source_pairs=source_pairs,
                  derived_equalities=derived_equalities,
                  phantom_symbols=list(phantom_symbols.values()),
                  relaxed_sources=relaxed_sources,
                  warn_only=False,
              )
          else:
              equalities_inputs = None

          def _get_code_parts(langs: tuple[str, ...]) -> list[_ShapeGuardsHelper]:
              # Returns one helper per requested language, in ``langs`` order
              # (callers unpack the result positionally).
              # pyrefly: ignore [missing-attribute]
              return output_graph.shape_env.produce_guards_verbose(
                  [a.fake for a in fs],  # type: ignore[misc]
                  [a.source for a in fs],
                  input_contexts=input_contexts,  # type: ignore[arg-type]
                  equalities_inputs=equalities_inputs,
                  source_ref=self.source_ref,
                  # Export keeps static.
                  # pyrefly: ignore [missing-attribute]
                  ignore_static=(not output_graph.export),
                  langs=langs,
              )

          if config.enable_cpp_symbolic_shape_guards:
              try:
                  # For exporting we need the python code parts
                  python_code_parts, verbose_code_parts, cpp_code_parts = (
                      _get_code_parts(("python", "verbose_python", "cpp"))  # type: ignore[assignment]
                  )
                  python_fallback = False
              except OverflowError:
                  # Cannot use int64_t
                  python_fallback = True
                  python_code_parts, verbose_code_parts = _get_code_parts(
                      ("python", "verbose_python")
                  )
          else:
              python_fallback = True
              python_code_parts, verbose_code_parts = _get_code_parts(
                  ("python", "verbose_python")
              )

          # When exporting, we may work with the shape constraints some more in
          # postprocessing, so don't freeze yet
          if not output_graph.export:
              output_graph.shape_env.freeze()

          if self.save_guards:
              # For SHAPE_ENV we want to skip serializing the entire ShapeEnv so instead
              # we directly serialize the generated code here.
              # ``cpp_code_parts`` is only bound when the C++ code parts were
              # produced successfully, hence the ``locals()`` lookup.
              maybe_cpp_code_parts = locals().get("cpp_code_parts")
              assert maybe_cpp_code_parts is None or isinstance(
                  maybe_cpp_code_parts, _CppShapeGuardsHelper
              )
              maybe_shape_env_sources = (
                  []
                  if maybe_cpp_code_parts is None
                  else list(maybe_cpp_code_parts.source_to_symbol.keys())
              )
              self.check_fn_manager.shape_code_parts = ShapeCodeParts(
                  python_code_parts=python_code_parts,
                  verbose_code_parts=verbose_code_parts,
                  cpp_code_parts=maybe_cpp_code_parts,
                  python_fallback=python_fallback,
                  shape_env_sources=maybe_shape_env_sources,
              )

      for code in python_code_parts.exprs:
          self._set_guard_export_info(guard, [code])

      # Make ShapeEnv guards available for testing.
      if compile_context := CompileContext.try_get():
          compile_context.shape_env_guards.extend(verbose_code_parts.exprs)

      # (source, symbol) pairs whose runtime values are passed to the C++ guard
      # through the int64_t / double arrays respectively.
      int_source_to_symbol = []
      float_source_to_symbol = []

      if not python_fallback:
          assert cpp_code_parts  # type: ignore[possibly-undefined]
          code_parts, source_to_symbol = (
              # pyrefly: ignore [unbound-name]
              cpp_code_parts.exprs,
              # pyrefly: ignore [unbound-name, missing-attribute]
              cpp_code_parts.source_to_symbol,
          )

          # No C++ guard expressions were produced: return without installing
          # any shape guard.
          if not code_parts:
              return

          # Classify each source by the value it evaluates to now. A
          # ConstantSource, or any value that is not a plain int/float, forces
          # the Python fallback.
          for source, symbol in source_to_symbol.items():
              if isinstance(source, ConstantSource):
                  python_fallback = True
              else:
                  example_value = self.get(
                      source,
                      closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
                  )
                  if isinstance(example_value, int):
                      int_source_to_symbol.append((source, symbol))
                  elif isinstance(example_value, float):
                      float_source_to_symbol.append((source, symbol))
                  else:
                      # SymInts/SymFloats go through python guard as we only support
                      # int64_t/double in C++ guards for now.
                      python_fallback = True

      if not python_fallback:
          import ctypes

          from torch._inductor.codecache import CppCodeCache

          assert cpp_code_parts  # type: ignore[possibly-undefined]
          code_parts, source_to_symbol = (
              # pyrefly: ignore [unbound-name]
              cpp_code_parts.exprs,
              # pyrefly: ignore [unbound-name, missing-attribute]
              cpp_code_parts.source_to_symbol,
          )

          # Reorder so all int sources come first, then float sources.
          # guard_managers below follows this order, and
          # install_symbolic_shape_guard is given the int and float counts.
          source_to_symbol = dict(int_source_to_symbol + float_source_to_symbol)
          try:
              guard_managers = [
                  self.get_guard_manager_from_source(IndexedSource(source, i))
                  for i, source in enumerate(source_to_symbol)
              ]

              # Bind each symbol to its slot in int_values / float_values inside
              # the generated C++ function.
              int_symbols_str = ", ".join(
                  f"{symbol} = int_values[{i}]"
                  for i, (_, symbol) in enumerate(int_source_to_symbol)
              )
              float_symbols_str = ", ".join(
                  f"{symbol} = float_values[{i}]"
                  for i, (_, symbol) in enumerate(float_source_to_symbol)
              )

              if int_symbols_str:
                  int_symbols_str = f"int64_t {int_symbols_str};"
              if float_symbols_str:
                  float_symbols_str = f"double {float_symbols_str};"

              # The generated guard returns the conjunction of all C++ exprs.
              func_str = textwrap.dedent(
                  f"""
                  #include <algorithm>
                  #include <cstdint>
                  #include <cmath>
                  #include <c10/util/generic_math.h>

                  #if defined(_MSC_VER)
                  # define EXTERN_DLL_EXPORT extern "C" __declspec(dllexport)
                  #else
                  # define EXTERN_DLL_EXPORT extern "C"
                  #endif

                  EXTERN_DLL_EXPORT int8_t guard(int64_t *int_values, double *float_values) {{
                    {int_symbols_str}
                    {float_symbols_str}
                    return ({") && (".join(code_parts)});
                  }}
                  """
              )
              guards_log.debug(
                  "C++ shape guard function: %s %s",
                  func_str,
                  verbose_code_parts.exprs,
              )
              clib = CppCodeCache.load(func_str)
              cguard = ctypes.cast(clib.guard, ctypes.c_void_p).value
              assert cguard
          except torch._inductor.exc.InvalidCxxCompiler:
              # No valid C++ compiler to compile the shape guard; fall through
              # to the Python lambda guard below.
              pass
          else:
              install_symbolic_shape_guard(
                  guard_managers,
                  len(int_source_to_symbol),
                  len(float_source_to_symbol),
                  cguard,
                  clib,
                  verbose_code_parts.exprs,
                  guard.user_stack,
              )
              return

      # Install all the symbolic guards in one python lambda guard. These are run
      # at the very end of the RootGuardManager via epilogue guards.
      # TODO(anijain2305,williamwen42) - Consider moving this to C++.
      if python_code_parts.exprs:
          self.add_python_lambda_leaf_guard_to_root(
              python_code_parts.exprs,
              verbose_code_parts.exprs,
              closure_vars={**SYMPY_INTERP, **_get_closure_vars()},
          )
def TENSOR_MATCH(self, guard: Guard, value: Optional[Any] = None) -> None:
    """Install guards on the tensor referenced by ``guard``.

    If ``match_on_id_for_tensor(guard)`` holds, only an ``ID_MATCH`` guard is
    installed. Otherwise:

    * In export mode, a ``TYPE_MATCH`` guard is added and Python source
      strings checking ``dtype``, ``device``, ``requires_grad`` and
      ``ndimension()`` are collected; these strings are only recorded as
      export info via ``_set_guard_export_info``.
    * Outside export mode, a tensor-match guard (pytype, dispatch keys and the
      concrete sizes/strides recorded on the output graph) is added to the
      tensor's guard manager, and the tensor is registered for the
      no-aliasing check unless it is exempted.

    In both modes, guards on ``_dynamo_dynamic_indices`` (for tensors without
    an always-static shape) and on ``_dynamo_shape_ids`` are added.

    Args:
        guard: the guard whose source refers to the tensor.
        value: optional pre-fetched tensor, or a ``TensorWeakRef`` to one.
            When it is ``None`` (or the weakref is dead) the tensor is looked
            up with ``self.get(guard)``.
    """
    # With this unsafe config, tensors that belong to FSDP modules get no guard.
    if config._unsafe_skip_fsdp_module_guards and guard.is_fsdp_module():
        return
    # For tensors that are part of the Dynamo extracted Fx graph module, an
    # ID_MATCH suffices. Once we turn on inline_inbuilt_nn_modules, these
    # will be lifted as inputs and have a TENSOR_MATCH guard.
    if match_on_id_for_tensor(guard):
        self.ID_MATCH(guard)
    else:
        # Dereference a weakly-held tensor; fall back to a lookup through the
        # guard's source if nothing usable was passed in.
        if isinstance(value, TensorWeakRef):
            value = value()
        value = value if value is not None else self.get(guard)
        # A FakeTensor may record the Python type / dispatch keys of the tensor
        # it stands in for; prefer those when they are set.
        pytype = type(value)
        dispatch_keys = torch._C._dispatch_keys(value)
        if isinstance(value, torch._subclasses.FakeTensor):
            if value.pytype is not None:
                pytype = value.pytype
            if value.dispatch_keys is not None:
                dispatch_keys = value.dispatch_keys
        assert isinstance(value, torch.Tensor)
        # Accumulate parameter statistics into the compilation metrics.
        if config.log_compilation_metrics and isinstance(value, torch.nn.Parameter):
            metrics_context = get_metrics_context()
            if metrics_context.in_progress():
                metrics_context.increment("param_numel", value.numel())
                metrics_context.increment("param_bytes", value.nbytes)
                metrics_context.increment("param_count", 1)
        tensor_name = self.arg_ref(guard)
        # [Note - On Export Tensor Guards]
        #
        # In eager mode, tensor guards are evaluated through C++, in guards.cpp
        # see [Note - On Eager Tensor Guards] for more info.
        #
        # In export mode, we instead maintain parallel logic between C++ and python
        # here, with an exception of checking the dispatch key - with the idea that a dispatch key
        # is an entirely runtime notion that would make no sense to keep in an exported graph.
        #
        # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although
        # not entirely true.
        # For example, suppose one of the input tensors had the negative dispatch key.
        # You should end up with a graph that is specialized for tensors that have a negative dispatch key.
        # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated.
        # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't
        # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key.
        # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported
        # subset of keys during export.
        #
        # The list of tensor fields and calls we care about can be found in `terms` below.
        # TODO(voz): We are missing storage offset in all our tensor guards?
        # Python source strings for the checks installed below; they are only
        # handed to _set_guard_export_info at the end.
        code: list[str] = []
        assert self.check_fn_manager.output_graph is not None
        if self.check_fn_manager.output_graph.export:
            self.TYPE_MATCH(guard)
            terms = [
                "dtype",
                "device",
                "requires_grad",
                "ndimension",
            ]
            for term in terms:
                term_src = AttrSource(guard.originating_source, term)
                # ndimension is a method, so guard on the result of calling it.
                if term == "ndimension":
                    term = "ndimension()"
                    term_src = CallFunctionNoArgsSource(term_src)
                real_value = self.get(term_src)
                if istype(real_value, (torch.device, torch.dtype)):
                    # copy pasted from EQUALS_MATCH
                    code.append(f"str({tensor_name}.{term}) == {str(real_value)!r}")
                else:
                    code.append(f"{tensor_name}.{term} == {real_value}")
        else:
            guard_manager = self.get_guard_manager(guard)
            # skip_no_tensor_aliasing_guards_on_parameters bring
            # unsoundness. If you compile a function with two different
            # parameters, but later on you pass on same tensor as two
            # different outputs (aliasing), Dynamo will not detect this.
            # But we deliberately take this soundness hit because this
            # usecase is quite rare and there is substantial reduction in
            # guard overhead.
            # For numpy tensors, since those are ephemeral, we don't have to
            # insert aliasing guards on them
            if not (
                config.skip_no_tensor_aliasing_guards_on_parameters
                and (
                    istype(value, torch.nn.Parameter)
                    or is_from_unspecialized_builtin_nn_module_source(
                        guard.originating_source
                    )
                )
            ) and not isinstance(guard.originating_source, NumpyTensorSource):
                # Keep track of all the tensor guard managers to insert
                # NoAliasing check at the end.
                self.no_tensor_aliasing_names.append(tensor_name)
                self.no_tensor_aliasing_guard_managers.append(guard_manager)
            # Sizes/strides to guard on come from the metadata the output graph
            # recorded for this input source, converted to concrete values.
            output_graph = self.check_fn_manager.output_graph
            metadata = output_graph.input_source_to_sizes_strides[
                guard.originating_source
            ]
            size = convert_to_concrete_values(metadata["size"])
            stride = convert_to_concrete_values(metadata["stride"])
            verbose_code_parts = get_verbose_code_parts(
                get_tensor_guard_code_part(
                    value,
                    tensor_name,
                    size,
                    stride,
                    pytype,
                    dispatch_keys,
                ),
                guard,
            )
            user_stack = guard.user_stack
            guard_manager.add_tensor_match_guard(
                value,
                size,  # type: ignore[arg-type]
                stride,  # type: ignore[arg-type]
                tensor_name,
                verbose_code_parts,
                user_stack,
                pytype,
                dispatch_keys,
            )
            # We consider TENSOR_MATCH guard to be important enough to be
            # included in diff guard manager by default.
            if not isinstance(value, torch.nn.Parameter):
                self.guard_manager.diff_guard_sources.add(guard.name)
        # A frame is valid for reuse with dynamic dimensions if the new
        # (user-requested) dynamic dimensions are a subset of the old
        # (already compiled) dynamic dimensions.
        #
        # It's a little non-obvious why you'd want this: in particular,
        # if an already compiled frame matches all of the guards, why
        # not just use it, why force a recompile?
        #
        # We force it for two reasons:
        #
        # - The user *required* us to compile with a new dynamic dimension,
        #   we should not ignore that and serve up the old, specialized
        #   frame. Listen to the user!
        #
        # - In fact, we are obligated to *raise an error* if we fail to
        #   make the requested dimension dynamic. If we don't
        #   recompile, we can't tell if that dimension can actually be
        #   made dynamic.
        #
        # If the new dynamic dims are a subset of the old, we already know
        # we can make them dynamic (since we made them dynamic in old).
        # This is slightly unsound, because maybe your input size is
        # [s0, s0, s1] and so you can do it dynamic if you say dynamic
        # dims {0, 1, 2} but you can't if you only do {0, 2} (because now
        # the second s0 is specialized). But we're not entirely sure if
        # this is a good idea anyway lol... (if you want to try removing
        # this logic, be my guest! -- ezyang 2024)
        #
        assert guard.source is not None
        static, _reason = tensor_always_has_static_shape(
            value, is_tensor=True, tensor_source=guard.originating_source
        )
        if not static:
            if hasattr(value, "_dynamo_dynamic_indices"):
                dynamic_indices = value._dynamo_dynamic_indices
                code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"  # noqa: B950
                code.append(code_part)
                self.get_guard_manager(guard).add_dynamic_indices_guard(
                    dynamic_indices,
                    get_verbose_code_parts(code_part, guard),
                    guard.user_stack,
                )
            # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of
            # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled.
            else:
                code_part = (
                    f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
                )
                code.append(code_part)
                self.get_guard_manager(guard).add_no_hasattr_guard(
                    "_dynamo_dynamic_indices",
                    get_verbose_code_parts(code_part, guard),
                    guard.user_stack,
                )
        # Guard on shape_ids when tensor has unbacked indices.
        # shape_id is only set via mark_unbacked, which sets _dynamo_unbacked_indices.
        # Empty dict is treated the same as not having the attribute.
        if shape_ids := getattr(value, "_dynamo_shape_ids", None):
            code_part = f"((getattr({tensor_name}, '_dynamo_shape_ids', None) == {shape_ids!r}) if hasattr({tensor_name}, '_dynamo_unbacked_indices') else True)"  # noqa: B950
            code.append(code_part)
            # The default argument binds the expected ids at definition time.
            self.get_guard_manager(guard).add_lambda_guard(
                lambda x, expected=shape_ids: (
                    getattr(x, "_dynamo_shape_ids", None) == expected
                    if hasattr(x, "_dynamo_unbacked_indices")
                    else True
                ),
                get_verbose_code_parts(code_part, guard),
                guard.user_stack,
            )
        # TODO we dont have guards on _dynamo_unbacked_indices like those of _dynamo_dynamic_indices this seems wrong!!
        if len(code) > 0:
            self._set_guard_export_info(guard, code)
  2768. # A util that in the case of export, adds data onto guards
  2769. def _set_guard_export_info(
  2770. self,
  2771. guard: Guard,
  2772. code_list: list[str],
  2773. provided_guarded_object: Optional[Any] = None,
  2774. provided_func_name: Optional[str] = None,
  2775. ) -> None:
  2776. # WARNING: It is important that cur_frame/caller do NOT stay in
  2777. # the current frame, because they will keep things live longer
  2778. # than they should. See TestMisc.test_release_module_memory
  2779. cur_frame = currentframe()
  2780. assert cur_frame is not None
  2781. caller = cur_frame.f_back
  2782. del cur_frame
  2783. assert caller is not None
  2784. func_name = provided_func_name or caller.f_code.co_name
  2785. del caller
  2786. # We use func_name for export, so might as well get a nice defensive check out of it
  2787. assert func_name in self.__class__.__dict__, (
  2788. f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}"
  2789. )
  2790. # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD)
  2791. if provided_guarded_object is None:
  2792. name = guard.name
  2793. guarded_object = None if not name else self.get(guard)
  2794. else:
  2795. guarded_object = provided_guarded_object
  2796. guarded_object_type = (
  2797. weakref.ref(type(guarded_object)) if guarded_object is not None else None
  2798. )
  2799. obj_ref = None
  2800. # Not necessary to have weakref for Enum type, but there is a bug that
  2801. # makes hasattr(guarded_object.__class__, "__weakref__") return True.
  2802. supports_weakref = (
  2803. getattr(guarded_object.__class__, "__weakrefoffset__", 0) != 0
  2804. )
  2805. # See D64140537 for why we are checking for tuple.
  2806. if supports_weakref and not isinstance(
  2807. guarded_object, (enum.Enum, tuple, weakref.ProxyTypes)
  2808. ):
  2809. obj_ref = weakref.ref(guarded_object)
  2810. guard.set_export_info(
  2811. func_name,
  2812. guarded_object_type,
  2813. code_list,
  2814. obj_ref,
  2815. )
  2816. # Common Sub-Expression Elimination for Python expressions.
  2817. #
  2818. # There are 2 steps to this pass:
  2819. # 1. Count the frequency of each sub-expression (i.e. inner
  2820. # node in the AST tree)
  2821. #
  2822. # 2. Replace those that occur more than once by a fresh variable 'v'.
  2823. # 'v' will be defined in the 'preface' list (output argument to
  2824. # 'NodeTransformer')
  2825. #
  2826. # NB: the use of 'ast.unparse' while visiting the nodes makes this pass
  2827. # quadratic on the depth of the tree.
  2828. #
  2829. # NB: this pass creates a new variable for each AST node that is repeated
  2830. # more than 'USE_THRESHOLD'. e.g. if 'a.b.c.d' is used 10 times, 'a.b.c'
  2831. # and 'a.b' are also used 10 times. So, there will be a new variable for
  2832. # each of them.
  2833. class PyExprCSEPass:
  2834. # Maximum number of times a given expression can be used without being
  2835. # replaced by a fresh variable.
  2836. USE_THRESHOLD = 1
  2837. # Ad-Hoc: AST nodes this pass focuses on.
  2838. ALLOWED_NODE_TYPES = (ast.Attribute, ast.Call, ast.Subscript)
  2839. @dataclasses.dataclass
  2840. class Config:
  2841. expr_count: dict[str, int]
  2842. expr_to_name: dict[str, str]
  2843. class ExprCounter(ast.NodeVisitor):
  2844. def __init__(self, config: PyExprCSEPass.Config) -> None:
  2845. self._config = config
  2846. def visit(self, node: ast.AST) -> None:
  2847. if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
  2848. self._config.expr_count[_ast_unparse(node)] += 1
  2849. super().visit(node)
  2850. class Replacer(ast.NodeTransformer):
  2851. def __init__(
  2852. self,
  2853. config: PyExprCSEPass.Config,
  2854. gen_name: Callable[[], str],
  2855. ) -> None:
  2856. super().__init__()
  2857. self._config = config
  2858. self._gen_name = gen_name
  2859. self.preface: list[str] = []
  2860. def visit(self, node: ast.AST) -> Any:
  2861. if isinstance(node, PyExprCSEPass.ALLOWED_NODE_TYPES):
  2862. expr = _ast_unparse(node)
  2863. # Replacement only occurs if a given expression is used more
  2864. # than once.
  2865. if self._config.expr_count[expr] > PyExprCSEPass.USE_THRESHOLD:
  2866. if expr not in self._config.expr_to_name:
  2867. # Parent 'visit' is called so that we CSE the inner expressions first.
  2868. #
  2869. # The resulting expression is used as right-hand-side of the variable
  2870. # assignment. i.e. we are CSE-ing the children before the parents.
  2871. #
  2872. # Indexing still uses the old 'node', since that's what was counted
  2873. # by the 'NodeVisitor'.
  2874. node_ = super().visit(node)
  2875. expr_ = _ast_unparse(node_)
  2876. var_name = self._gen_name()
  2877. self.preface.append(f"{var_name} = {expr_}")
  2878. self._config.expr_to_name[expr] = var_name
  2879. else:
  2880. var_name = self._config.expr_to_name[expr]
  2881. return ast.Name(var_name, ast.Load())
  2882. return super().visit(node)
  2883. def __init__(self) -> None:
  2884. self._counter = 0
  2885. self._config = self.Config(
  2886. expr_count=collections.defaultdict(lambda: 0), expr_to_name={}
  2887. )
  2888. def _new_var(self, prefix: str = "_var") -> str:
  2889. name = f"{prefix}{self._counter}"
  2890. self._counter += 1
  2891. return name
  2892. def count(self, exprs: list[str]) -> None:
  2893. counter = self.ExprCounter(self._config)
  2894. for e in exprs:
  2895. try:
  2896. counter.visit(ast.parse(e))
  2897. except SyntaxError as ex:
  2898. log.exception("Failed to visit expr at line %s.\n%s", ex.lineno, e)
  2899. raise
  2900. def replace(self, expr: str) -> tuple[list[str], str]:
  2901. replacer = self.Replacer(self._config, self._new_var)
  2902. new_node = replacer.visit(ast.parse(expr))
  2903. return replacer.preface, _ast_unparse(new_node)
  2904. def must_add_nn_module_guards(guard: Guard) -> bool:
  2905. # For config.guard_nn_modules=False, we can skip all the guards that
  2906. # originate from inside of nn module except for a few categories.
  2907. return (
  2908. # Guard for defaults
  2909. isinstance(guard.originating_source, DefaultsSource)
  2910. # Guard using dict tags if the config flag is set
  2911. or (
  2912. config.guard_nn_modules_using_dict_tags
  2913. and guard.create_fn is GuardBuilder.NN_MODULE
  2914. )
  2915. )
class DeletedGuardManagerWrapper(GuardManagerWrapper):
    """Guard manager wrapper that only records why its guards were invalidated.

    NOTE(review): presumably installed in place of a real guard manager once
    a cache entry is invalidated -- confirm at the construction site.
    """

    def __init__(self, reason: str) -> None:
        super().__init__()
        # Human-readable explanation of why the guards were invalidated.
        self.invalidation_reason = reason

    def populate_diff_guard_manager(self) -> None:
        # No diff guard tree is built for this wrapper.
        self.diff_guard_root = None
@dataclasses.dataclass
class ShapeCodeParts:
    """Symbolic-shape guard code gathered while building guards.

    ``CheckFunctionManager`` receives this as ``shape_code_parts`` and keeps
    it only for serialization; it is also stored in
    ``GuardsState.shape_code_parts``.
    """

    # Python source of the shape guard expressions (presumably what is run as
    # the Python lambda shape guard) -- TODO confirm.
    python_code_parts: _ShapeGuardsHelper
    # Human-readable versions of the shape guard expressions, presumably for
    # logs and guard-failure messages -- TODO confirm.
    verbose_code_parts: _ShapeGuardsHelper
    # C++ code for the shape guard; presumably None when no C++ version was
    # generated -- TODO confirm.
    cpp_code_parts: Optional[_CppShapeGuardsHelper]
    # NOTE(review): presumably True when the shape guards have to be evaluated
    # in Python rather than C++ -- confirm against the SHAPE_ENV builder.
    python_fallback: bool
    # Sources associated with the shape environment's symbols (presumably the
    # inputs the guards read from) -- TODO confirm.
    shape_env_sources: list[Source]
@dataclasses.dataclass
class GuardsState:
    """State serialized by ``pickle_guards_state``."""

    # Guard-relevant output graph state (guards, local/global scope, ...).
    output_graph: OutputGraphGuardsState
    # Symbolic-shape guard code, if any was produced.
    shape_code_parts: Optional[ShapeCodeParts]
  2933. class _Missing:
  2934. def __init__(self, reason: Optional[str] = None) -> None:
  2935. self._reason = reason
  2936. def __repr__(self) -> str:
  2937. return f"_Missing({self._reason})"
  2938. def __str__(self) -> str:
  2939. return f"_Missing({self._reason})"
  2940. # Sometimes _Missing object is used as the callable with functools.partial,
  2941. # so we add a dummy __call__ here to bypass TypeError from partial().
  2942. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  2943. return _Missing()
  2944. @functools.cache
  2945. def _get_unsupported_types() -> tuple[type, ...]:
  2946. # We only do ID_MATCH on C objects which is already banned from guards serialization.
  2947. ret: tuple[type, ...] = (
  2948. torch._C.Stream,
  2949. weakref.ReferenceType,
  2950. )
  2951. try:
  2952. ret += (torch._C._distributed_c10d.ProcessGroup,)
  2953. except AttributeError:
  2954. pass
  2955. return ret
  2956. class GuardsStatePickler(pickle.Pickler):
  2957. def __init__(
  2958. self,
  2959. guard_tree_values: dict[int, Any],
  2960. empty_values: dict[int, Any],
  2961. missing_values: dict[int, Any],
  2962. *args: Any,
  2963. **kwargs: Any,
  2964. ) -> None:
  2965. super().__init__(*args, **kwargs)
  2966. self.fake_mode = torch._subclasses.FakeTensorMode()
  2967. self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
  2968. self.guard_tree_values = guard_tree_values
  2969. self.empty_values = empty_values
  2970. self.missing_values = missing_values
  2971. @classmethod
  2972. def _unpickle_module(cls, state: Any) -> torch.nn.Module:
  2973. mod = torch.nn.Module()
  2974. mod.__setstate__(state)
  2975. return mod
  2976. @classmethod
  2977. def _unpickle_tensor(
  2978. cls,
  2979. meta_tensor: torch.Tensor,
  2980. device: torch.device,
  2981. pytype: type,
  2982. dispatch_keys_raw: int,
  2983. grad: torch.Tensor,
  2984. ) -> torch.Tensor:
  2985. fake_mode = torch._subclasses.FakeTensorMode()
  2986. tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
  2987. ret = tensor_converter.from_meta_and_device(
  2988. fake_mode,
  2989. meta_tensor,
  2990. device,
  2991. pytype,
  2992. torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw),
  2993. )
  2994. ret.grad = grad
  2995. return ret
  2996. @classmethod
  2997. def _unpickle_traceable_wrapper_subclass(
  2998. cls,
  2999. meta_tensor: torch.Tensor,
  3000. device: torch.device,
  3001. pytype: type,
  3002. dispatch_keys_raw: int,
  3003. ctx: Any,
  3004. inner_data: list[tuple[str, Callable[..., Any], tuple[Any, ...]]],
  3005. ) -> torch.Tensor:
  3006. # Unpickle the inner tensor components. These could also be subclass instances.
  3007. inner_tensors = {}
  3008. for attr, unpickle_func, unpickle_func_args in inner_data:
  3009. inner_tensors[attr] = unpickle_func(*unpickle_func_args)
  3010. outer_size, outer_stride = meta_tensor.shape, meta_tensor.stride()
  3011. out = type(meta_tensor).__tensor_unflatten__( # type: ignore[attr-defined]
  3012. inner_tensors, ctx, outer_size, outer_stride
  3013. )
  3014. out.pytype = pytype
  3015. out.dispatch_keys = torch._C.DispatchKeySet.from_raw_repr(dispatch_keys_raw)
  3016. return out
  3017. @classmethod
  3018. def _unpickle_python_module(cls, alias: str) -> types.ModuleType:
  3019. return importlib.import_module(alias)
  3020. @classmethod
  3021. def _unpickle_dispatch_key_set(cls, raw_repr: int) -> torch._C.DispatchKeySet:
  3022. return torch._C.DispatchKeySet.from_raw_repr(raw_repr)
  3023. @classmethod
  3024. def _unpickle_functorch_interpreter(
  3025. cls, json: bytes
  3026. ) -> torch._C._functorch.CInterpreter:
  3027. return torch._C._functorch.CInterpreter.deserialize(json)
  3028. @classmethod
  3029. def _unpickle_mapping_proxy(
  3030. cls, d: dict[Any, Any]
  3031. ) -> types.MappingProxyType[Any, Any]:
  3032. return types.MappingProxyType(d)
  3033. @classmethod
  3034. def _unpickle_dict_keys(cls, elems: list[Any]) -> Any:
  3035. return dict.fromkeys(elems).keys()
  3036. @classmethod
  3037. def _unpickle_fsdp_module_type(
  3038. cls, original_type: type[torch.nn.Module]
  3039. ) -> type[torch.nn.Module]:
  3040. return torch.distributed.fsdp._fully_shard._fully_shard.get_cls_to_fsdp_cls()[
  3041. original_type
  3042. ]
  3043. @classmethod
  3044. def _unpickle_ddp_module(
  3045. cls, state: dict[str, Any]
  3046. ) -> torch.nn.parallel.DistributedDataParallel:
  3047. ty = torch.nn.parallel.DistributedDataParallel
  3048. ddp = ty.__new__(ty)
  3049. torch.nn.Module.__setstate__(ddp, state)
  3050. return ddp
  3051. @classmethod
  3052. def _unpickle_c_op(cls, name: str) -> Any:
  3053. return getattr(torch.ops._C, name)
  3054. @classmethod
  3055. def _unpickle_op(cls, namespace: str, opname: str, overloadname: str) -> Any:
  3056. return getattr(getattr(getattr(torch.ops, namespace), opname), overloadname)
  3057. @classmethod
  3058. def _unpickle_bound_method(cls, func: Any, base: Any) -> Any:
  3059. return types.MethodType(func, base)
  3060. @staticmethod
  3061. def _unpickle_sdp_backend(name: str) -> torch.nn.attention.SDPBackend:
  3062. # Reconstruct from the Python-facing enum namespace
  3063. return getattr(torch.nn.attention.SDPBackend, name)
  3064. @classmethod
  3065. def _unpickle_cell(cls, val: Any) -> Any:
  3066. def _() -> Any:
  3067. return val
  3068. assert _.__closure__ is not None
  3069. return _.__closure__[0]
  3070. @classmethod
  3071. def _unpickle_named_tuple_type(
  3072. cls, name: str, fields: tuple[str, ...]
  3073. ) -> type[NamedTuple]:
  3074. # pyrefly: ignore [bad-return]
  3075. return collections.namedtuple(name, fields)
  3076. @classmethod
  3077. def _unpickle_code(cls, serialized_code: SerializedCode) -> types.CodeType:
  3078. from torch._dynamo.package import SerializedCode
  3079. return SerializedCode.to_code_object(serialized_code)
  3080. @classmethod
  3081. def _unpickle_nested_function(
  3082. cls,
  3083. code: types.CodeType,
  3084. module: str,
  3085. qualname: str,
  3086. argdefs: tuple[object, ...] | None,
  3087. closure: tuple[types.CellType, ...] | None,
  3088. ) -> types.FunctionType:
  3089. f_globals = importlib.import_module(module).__dict__
  3090. return types.FunctionType(code, f_globals, qualname, argdefs, closure)
  3091. # pyrefly: ignore [bad-override]
  3092. def reducer_override(
  3093. self, obj: Any
  3094. ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
  3095. import sympy
  3096. if id(obj) in self.empty_values:
  3097. return type(obj).__new__, (type(obj),)
  3098. if inspect.iscode(obj):
  3099. from torch._dynamo.package import SerializedCode
  3100. return type(self)._unpickle_code, (SerializedCode.from_code_object(obj),)
  3101. if id(obj) in self.missing_values:
  3102. return _Missing, ("missing values",)
  3103. if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
  3104. from torch.utils._python_dispatch import is_traceable_wrapper_subclass
  3105. if id(obj) not in self.guard_tree_values:
  3106. return _Missing, ("tensor guard tree",)
  3107. if is_traceable_wrapper_subclass(obj):
  3108. # inner_data is a list of tuples of:
  3109. # (inner attr name, unpickle func, tuple of func inputs)
  3110. # This supports traceable wrapper subclass inner tensors.
  3111. inner_data = []
  3112. attrs, ctx = obj.__tensor_flatten__()
  3113. # recursively call for inner tensor components
  3114. for attr in attrs:
  3115. inner = getattr(obj, attr)
  3116. if isinstance(inner, torch.Tensor):
  3117. self.guard_tree_values[id(inner)] = inner
  3118. func, args_tuple = self.reducer_override(inner)
  3119. inner_data.append((attr, func, args_tuple))
  3120. return type(self)._unpickle_traceable_wrapper_subclass, (
  3121. torch.empty_like(obj, device="meta"),
  3122. obj.device,
  3123. type(obj),
  3124. torch._C._dispatch_keys(obj).raw_repr(),
  3125. ctx,
  3126. inner_data,
  3127. )
  3128. # For FakeTensors, use pytype if set, otherwise default to
  3129. # torch.Tensor. This is important for cross-compilation where
  3130. # we compile with fake tensors but run with real tensors.
  3131. pytype = type(obj)
  3132. if isinstance(obj, torch._subclasses.FakeTensor):
  3133. pytype = obj.pytype if obj.pytype is not None else torch.Tensor
  3134. return type(self)._unpickle_tensor, (
  3135. torch.empty_like(obj, device="meta", requires_grad=obj.requires_grad),
  3136. obj.device,
  3137. pytype,
  3138. torch._C._dispatch_keys(obj).raw_repr(),
  3139. obj.grad,
  3140. )
  3141. elif isinstance(obj, torch.nn.Module):
  3142. if id(obj) not in self.guard_tree_values:
  3143. return _Missing, ("module guard tree",)
  3144. for attr in obj.__dict__.values():
  3145. if isinstance(attr, (torch.Tensor, torch.nn.Module)):
  3146. continue
  3147. if id(attr) in self.guard_tree_values:
  3148. continue
  3149. if callable(attr):
  3150. continue
  3151. self.missing_values[id(attr)] = attr
  3152. # DDP module is a special case because it tries to restore unneeded
  3153. # data in custom __setstate__. We cannot skip ddp module because it
  3154. # is often a toplevel module.
  3155. if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
  3156. return type(self)._unpickle_ddp_module, (obj.__getstate__(),)
  3157. if type(obj).__qualname__ == type(obj).__name__:
  3158. return NotImplemented
  3159. if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
  3160. return type(self)._unpickle_module, (obj.__getstate__(),)
  3161. elif inspect.ismodule(obj):
  3162. return type(self)._unpickle_python_module, (obj.__name__,)
  3163. elif isinstance(obj, torch._C.DispatchKeySet):
  3164. return type(self)._unpickle_dispatch_key_set, (obj.raw_repr(),)
  3165. elif isinstance(obj, torch._C._functorch.CInterpreter):
  3166. return type(self)._unpickle_functorch_interpreter, (obj.serialize(),)
  3167. elif (
  3168. inspect.isclass(obj)
  3169. and issubclass(obj, sympy.Function)
  3170. and hasattr(obj, "_torch_handler_name")
  3171. ):
  3172. assert hasattr(obj, "_torch_unpickler")
  3173. return obj._torch_unpickler, (obj._torch_handler_name,)
  3174. elif (
  3175. inspect.isclass(obj)
  3176. and issubclass(obj, tuple)
  3177. and hasattr(obj, "_fields")
  3178. and obj.__qualname__ != obj.__name__
  3179. ):
  3180. return type(self)._unpickle_named_tuple_type, (obj.__name__, obj._fields)
  3181. elif isinstance(obj, torch.SymInt):
  3182. raise RuntimeError(f"Cannot serialize SymInt {obj} (node: {obj.node})")
  3183. elif isinstance(obj, types.MappingProxyType):
  3184. return type(self)._unpickle_mapping_proxy, (obj.copy(),)
  3185. elif isinstance(obj, torch._dynamo.utils.dict_keys):
  3186. return type(self)._unpickle_dict_keys, (list(obj),)
  3187. elif isinstance(
  3188. obj, torch._ops.OpOverloadPacket
  3189. ) and obj._qualified_op_name.startswith("_C::"):
  3190. return type(self)._unpickle_c_op, (obj.__name__,)
  3191. elif isinstance(obj, torch._ops.OpOverload):
  3192. return type(self)._unpickle_op, (
  3193. obj.namespace,
  3194. obj._opname,
  3195. obj._overloadname,
  3196. )
  3197. elif (
  3198. obj.__class__.__module__ == "builtins"
  3199. and obj.__class__.__name__ == "PyCapsule"
  3200. ):
  3201. # Skipping PyCapsule since there isn't much to be guarded about them.
  3202. return _Missing, ("capsule",)
  3203. elif isinstance(obj, _get_unsupported_types()):
  3204. return _Missing, ("unsupported",)
  3205. elif inspect.isfunction(obj):
  3206. if "<locals>" in obj.__qualname__:
  3207. return type(self)._unpickle_nested_function, (
  3208. obj.__code__,
  3209. obj.__module__,
  3210. obj.__qualname__,
  3211. obj.__defaults__,
  3212. obj.__closure__,
  3213. )
  3214. if obj.__module__ in sys.modules:
  3215. f = sys.modules[obj.__module__]
  3216. for name in obj.__qualname__.split("."):
  3217. f = getattr(f, name, None) # type: ignore[assignment]
  3218. if f is not obj:
  3219. return _Missing, ("fqn mismatch",)
  3220. elif inspect.ismethod(obj):
  3221. func = obj.__func__
  3222. method_self = obj.__self__
  3223. inner_func = getattr(method_self, func.__name__)
  3224. if inspect.ismethod(inner_func):
  3225. inner_func = inner_func.__func__
  3226. if func is not inner_func:
  3227. return type(self)._unpickle_bound_method, (func, method_self)
  3228. elif isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
  3229. return type(self)._unpickle_cell, (obj.cell_contents,)
  3230. if hasattr(torch.distributed, "distributed_c10d") and isinstance(
  3231. obj, torch.distributed.distributed_c10d.Work
  3232. ):
  3233. if id(obj) not in self.guard_tree_values:
  3234. return _Missing, ("distributed_c10d.Work",)
  3235. if isinstance(obj, torch.nn.attention.SDPBackend):
  3236. return type(self)._unpickle_sdp_backend, (obj.name,)
  3237. if type(obj).__qualname__ != type(obj).__name__ and not isinstance(obj, tuple):
  3238. raise torch._dynamo.exc.PackageError(
  3239. f"Type {type(obj)} for object {obj} cannot be saved "
  3240. + "into torch.compile() package since it's defined in local scope. "
  3241. + "Please define the class at global scope (top level of a module)."
  3242. )
  3243. if (
  3244. inspect.isclass(obj)
  3245. and hasattr(torch.distributed, "fsdp")
  3246. and issubclass(obj, torch.distributed.fsdp._fully_shard.FSDPModule)
  3247. ):
  3248. if obj is not torch.distributed.fsdp._fully_shard.FSDPModule:
  3249. original_type = obj.__mro__[2]
  3250. assert issubclass(original_type, torch.nn.Module)
  3251. assert (
  3252. original_type
  3253. in torch.distributed.fsdp._fully_shard._fully_shard.get_cls_to_fsdp_cls()
  3254. )
  3255. return type(self)._unpickle_fsdp_module_type, (original_type,)
  3256. return NotImplemented
  3257. def make_guard_filter_entry(guard: Guard, builder: GuardBuilder) -> GuardFilterEntry:
  3258. MISSING = object()
  3259. name = strip_local_scope(guard.name)
  3260. if name == "":
  3261. has_value = False
  3262. value = MISSING
  3263. else:
  3264. try:
  3265. # Guard evaluation is expected to fail when we guard on
  3266. # things like "not hasattr(x, 'foo')". In cases like this,
  3267. # we don't have a well defined value because such thing
  3268. # doesn't exist.
  3269. value = builder.get(guard)
  3270. has_value = True
  3271. except: # noqa: B001,E722
  3272. value = MISSING
  3273. has_value = False
  3274. is_global = get_global_source_name(guard.originating_source) is not None
  3275. return GuardFilterEntry(
  3276. name=name,
  3277. has_value=has_value,
  3278. value=value,
  3279. guard_type=guard.create_fn_name(),
  3280. derived_guard_types=(tuple(guard.guard_types) if guard.guard_types else ()),
  3281. is_global=is_global,
  3282. orig_guard=guard,
  3283. )
  3284. def pickle_guards_state(
  3285. state: GuardsState,
  3286. builder: GuardBuilder,
  3287. ) -> bytes:
  3288. buf = io.BytesIO()
  3289. empty_values = {}
  3290. missing_values = {}
  3291. guard_tree_values = builder.guard_tree_values
  3292. leaves = pytree.tree_leaves(state.output_graph.local_scope)
  3293. for leaf in leaves:
  3294. if inspect.ismethod(leaf) and hasattr(leaf, "__self__"):
  3295. base = leaf.__self__
  3296. if id(base) not in guard_tree_values:
  3297. try:
  3298. type(base).__new__(type(base))
  3299. empty_values[id(base)] = base
  3300. except: # noqa: E722, B001
  3301. pass
  3302. elif id(leaf) not in guard_tree_values:
  3303. # TODO See if we have lift this branch as the first one.
  3304. # Prune more objects in pytree hierarchy.
  3305. missing_values[id(leaf)] = leaf
  3306. pickler = GuardsStatePickler(guard_tree_values, empty_values, missing_values, buf)
  3307. if all(
  3308. torch.compiler.keep_portable_guards_unsafe(
  3309. [
  3310. make_guard_filter_entry(guard, builder)
  3311. for guard in state.output_graph.guards
  3312. ]
  3313. )
  3314. ):
  3315. # Prune more values in AOT precompile when complex pickling structure is not needed.
  3316. state.output_graph.guard_on_key_order = set()
  3317. state.output_graph.global_scope = {}
  3318. try:
  3319. pickler.dump(state)
  3320. except AttributeError as e:
  3321. raise torch._dynamo.exc.PackageError(str(e)) from e
  3322. return buf.getvalue()
  3323. # NB: Naively, you'd expect this to only be a function that produces
  3324. # the callable that constitutes the guard. However, there is some
  3325. # delicate handling for invalidating this check function when the
  3326. # locals/globals get invalidated, so there's some extra state
  3327. # we have to hold in this manager class.
  3328. class CheckFunctionManager:
  3329. def __init__(
  3330. self,
  3331. f_code: types.CodeType,
  3332. output_graph: OutputGraphCommon,
  3333. cache_entry: Optional[CacheEntry] = None,
  3334. guard_fail_fn: Optional[Callable[[GuardFail], None]] = None,
  3335. guard_filter_fn: Callable[[Sequence[GuardFilterEntry]], Sequence[bool]]
  3336. | None = None,
  3337. shape_code_parts: Optional[ShapeCodeParts] = None,
  3338. runtime_global_scope: Optional[dict[str, Any]] = None,
  3339. save_guards: bool = False,
  3340. strict_error: bool = False,
  3341. ) -> None:
  3342. guards = output_graph.guards if output_graph else None
  3343. self._weakrefs: dict[int, ReferenceType[object]] = {}
  3344. existing_diff_guard_sources = (
  3345. update_diff_guard_managers_for_existing_cache_entries(cache_entry)
  3346. )
  3347. self.output_graph: Optional[OutputGraphCommon] = output_graph
  3348. assert self.output_graph is not None
  3349. # Only used for serialization.
  3350. self.shape_code_parts = shape_code_parts
  3351. # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
  3352. # in case a set default device call was made in the graph.
  3353. self.torch_function_mode_stack = (
  3354. output_graph.torch_function_mode_stack if output_graph else None
  3355. )
  3356. self.used_builtin_vars: OrderedSet[str] = OrderedSet()
  3357. self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
  3358. self.additional_used_global_vars: OrderedSet[str] = OrderedSet()
  3359. self.runtime_global_scope = runtime_global_scope
  3360. self.global_state: Optional[torch._C._dynamo.guards.GlobalStateGuard] = None
  3361. self.torch_function_mode_stack_check_fn: Optional[Callable[[], bool]] = None
  3362. if not justknobs_check("pytorch/compiler:guard_nn_modules"):
  3363. log.warning("guard_nn_modules is turned off using justknobs killswitch")
  3364. # TODO Be more explicit about the behavior for the users.
  3365. if torch._dynamo.config.caching_precompile:
  3366. _guard_filter_fn = guard_filter_fn or (lambda gs: [True for g in gs])
  3367. def guard_filter_fn(guards: Sequence[GuardFilterEntry]) -> Sequence[bool]:
  3368. ret = []
  3369. for keep, g in zip(_guard_filter_fn(guards), guards):
  3370. if not keep:
  3371. ret.append(False)
  3372. elif (
  3373. g.guard_type
  3374. in (
  3375. "ID_MATCH",
  3376. "CLOSURE_MATCH",
  3377. "WEAKREF_ALIVE",
  3378. "DICT_VERSION",
  3379. )
  3380. or "ID_MATCH" in g.derived_guard_types
  3381. or "DICT_VERSION" in g.derived_guard_types
  3382. ):
  3383. log.warning(
  3384. "%s guard on %s is dropped with caching_precompile=True.",
  3385. g.guard_type,
  3386. g.orig_guard.name,
  3387. )
  3388. ret.append(False)
  3389. else:
  3390. ret.append(True)
  3391. return ret
  3392. sorted_guards = sorted(guards or (), key=Guard.sort_key)
  3393. if guard_filter_fn:
  3394. # If we're filtering guards, we need to build it an extra time first
  3395. # because filtering depends on the builder/guard_manager results
  3396. builder, guard_manager = self.build_guards(
  3397. sorted_guards,
  3398. existing_diff_guard_sources,
  3399. f_code,
  3400. output_graph,
  3401. False,
  3402. )
  3403. filter_results = guard_filter_fn(
  3404. [make_guard_filter_entry(guard, builder) for guard in sorted_guards]
  3405. )
  3406. assert len(filter_results) == len(sorted_guards)
  3407. assert all(type(x) is bool for x in filter_results)
  3408. sorted_guards = [
  3409. guard for i, guard in enumerate(sorted_guards) if filter_results[i]
  3410. ]
  3411. # Redo the guards because filtering relies on the results from the last guard builder.
  3412. builder, guard_manager = self.build_guards(
  3413. sorted_guards,
  3414. existing_diff_guard_sources,
  3415. f_code,
  3416. output_graph,
  3417. save_guards,
  3418. guard_filter_fn=guard_filter_fn,
  3419. )
  3420. self.guard_manager = guard_manager
  3421. self.compile_check_fn(builder, sorted_guards, guard_fail_fn)
  3422. # Keep track of weak references of objects with ID_MATCH guard. This
  3423. # info is stored alongside optimized_code and guard_manager and is used to
  3424. # limit the number of cache entries with same ID_MATCH'd object.
  3425. # TODO(anijain2305) - Currently this information is stored as an attr on
  3426. # the guard_manager itself to avoid changing CacheEntry data structure in
  3427. # eval_frame.c. In future, we should probably replace guard_manager with a
  3428. # queryable data structure such that this information is already present
  3429. # in some form.
  3430. self.guard_manager.id_matched_objs = builder.id_matched_objs
  3431. guards_log.debug("%s", self.guard_manager)
  3432. self.guard_manager.id_matched_objs = builder.id_matched_objs
  3433. # Check that the guard returns True. False means that we will always
  3434. # recompile.
  3435. # TODO(anijain2305, ydwu4) - Skipping export because of following test
  3436. # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs
  3437. latency = 0.0
  3438. if not output_graph.skip_guards_check and not output_graph.export:
  3439. if not self.guard_manager.check(output_graph.local_scope):
  3440. reasons = get_guard_fail_reason_helper(
  3441. self.guard_manager,
  3442. output_graph.local_scope,
  3443. CompileContext.current_compile_id(),
  3444. backend=None, # no need to set this because we are trying to find the offending guard entry
  3445. )
  3446. raise AssertionError(
  3447. "Guard failed on the same frame it was created. This is a bug - please create an issue."
  3448. f"Guard fail reason: {reasons}"
  3449. )
  3450. if guard_manager_testing_hook_fn is not None:
  3451. guard_manager_testing_hook_fn(
  3452. self.guard_manager, output_graph.local_scope, builder
  3453. )
  3454. # NB for developers: n_iters is chosen to be 1 to prevent excessive
  3455. # increase in compile time. We first do a cache flush to measure the
  3456. # guard latency more accurately. This cache flush is expensive.
  3457. # Note - If you are working on a guard optimization, it might be a
  3458. # good idea to increase this number for more stability during
  3459. # development.
  3460. latency = profile_guard_manager(
  3461. self.guard_manager.root, output_graph.local_scope, 1
  3462. )
  3463. guards_log.debug("Guard eval latency = %s us", f"{latency:.2f}")
  3464. # Note: We use `increment_toplevel` instead of `compilation_metric`
  3465. # here. This is because, in scenarios where `torch._dynamo.reset`
  3466. # is invoked, the same frame ID and compile ID may be reused during
  3467. # a new compilation cycle. This behavior causes issues with
  3468. # `compilation_metric`, as it expects the metric field to be empty.
  3469. # Ideally, we would overwrite the existing entry in such cases, but
  3470. # we currently lack an API to support overwriting metrics. However,
  3471. # since these situations are rare and typically impractical to
  3472. # account for, we simply increment at the toplevel instead.
  3473. CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))
  3474. self.guards_state: Optional[bytes] = None
  3475. if save_guards:
  3476. from torch._dynamo.output_graph import OutputGraphCommon
  3477. assert isinstance(self.output_graph, OutputGraphCommon)
  3478. try:
  3479. self.guards_state = self.serialize_guards(
  3480. builder, sorted_guards, self.output_graph
  3481. )
  3482. except exc.PackageError as e:
  3483. if torch._dynamo.config.strict_precompile or strict_error:
  3484. raise e
  3485. self.output_graph.bypass_package(
  3486. f"Guard evaluation failed: {str(e)}",
  3487. traceback=traceback.format_exc().split("\n"),
  3488. )
  3489. # TODO: don't do the string rep, do something more structured here
  3490. torch._logging.trace_structured(
  3491. "dynamo_cpp_guards_str",
  3492. payload_fn=lambda: f"{self.guard_manager}\nGuard latency = {latency:.2f} us",
  3493. )
  3494. # NB - We have to very careful of cleaning up here. Because of the
  3495. # invalidate function, we can create a weakref finalizer that keeps
  3496. # `self` alive for very long. Sometimes by mistake, we can run
  3497. # invalidate for a type/object (check id_ref method) that Python can
  3498. # leak by design, preventing us from calling the finalizer. In that
  3499. # case, the `self` will be alive even though the cache entry will be
  3500. # deleted (check invalidate method), which can cause a memory leak,
  3501. # e.g., not setting output_graph = None can keep hold of nn_modules.
  3502. self._weakrefs.clear()
  3503. self.output_graph = None
  3504. UNSUPPORTED_SERIALIZATION_GUARD_TYPES: tuple[LiteralString, ...] = (
  3505. "DICT_VERSION",
  3506. "NN_MODULE",
  3507. "ID_MATCH",
  3508. "FUNCTION_MATCH",
  3509. "CLASS_MATCH",
  3510. "MODULE_MATCH",
  3511. "CLOSURE_MATCH",
  3512. "WEAKREF_ALIVE",
  3513. )
  3514. def serialize_guards(
  3515. self,
  3516. builder: GuardBuilder,
  3517. sorted_guards: list[Guard],
  3518. output_graph: OutputGraphCommon,
  3519. ) -> bytes:
  3520. # We check whether our list of guards are serializable here
  3521. for guard in sorted_guards:
  3522. guard_type = guard.create_fn_name()
  3523. derived_guard_types = tuple(guard.guard_types) if guard.guard_types else ()
  3524. # BUILTIN_MATCH calls TYPE_MATCH sometimes, so we need to check both for
  3525. # a chance that the guard is unserializable
  3526. if guard_type in ("TYPE_MATCH", "BUILTIN_MATCH"):
  3527. if guard._unserializable:
  3528. # Only call builder.get again if we know we're going to throw
  3529. obj = builder.get(guard)
  3530. raise_local_type_error(obj)
  3531. elif (
  3532. guard_type in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
  3533. ):
  3534. raise torch._dynamo.exc.PackageError(
  3535. f"{guard_type} guard cannot be serialized."
  3536. )
  3537. elif failed := next(
  3538. (
  3539. i
  3540. for i in derived_guard_types
  3541. if i in CheckFunctionManager.UNSUPPORTED_SERIALIZATION_GUARD_TYPES
  3542. ),
  3543. None,
  3544. ):
  3545. # Just raise the first failed guard name
  3546. raise torch._dynamo.exc.PackageError(
  3547. f"{failed} guard cannot be serialized."
  3548. )
  3549. builtins_dict_name = output_graph.name_of_builtins_dict_key_in_fglobals or ""
  3550. used_global_vars = set()
  3551. used_local_vars = set()
  3552. def prune_variable(source: Source) -> None:
  3553. if name := get_global_source_name(source):
  3554. assert isinstance(name, str)
  3555. # Leave out the builtins dict key, as we will special handle
  3556. # it later because the guarded code rarely use the entire
  3557. # builtin dict in the common case.
  3558. if name != builtins_dict_name:
  3559. used_global_vars.add(name)
  3560. elif name := get_local_source_name(source):
  3561. assert isinstance(name, str)
  3562. used_local_vars.add(name)
  3563. output_graph_guards_state = output_graph.dump_guards_state()
  3564. # Only serialize the global variables that are actually used in guards.
  3565. for guard in sorted_guards:
  3566. if isinstance(guard.originating_source, ShapeEnvSource):
  3567. assert self.shape_code_parts
  3568. for source in self.shape_code_parts.shape_env_sources:
  3569. prune_variable(source)
  3570. else:
  3571. prune_variable(guard.originating_source)
  3572. for source in output_graph.guard_on_key_order:
  3573. prune_variable(source)
  3574. def normalize_create_fn(x: Callable[..., None]) -> Callable[..., None]:
  3575. if isinstance(x, functools.partial):
  3576. def _ref(x: Any) -> Any:
  3577. if isinstance(x, (TensorWeakRef, weakref.ref)):
  3578. return x()
  3579. return x
  3580. new_args = tuple(_ref(a) for a in x.args)
  3581. new_keywords = {k: _ref(v) for k, v in x.keywords.items()}
  3582. return functools.partial(x.func, *new_args, **new_keywords)
  3583. return x
  3584. global_scope_state = {
  3585. k: v
  3586. for k, v in output_graph_guards_state.global_scope.items()
  3587. if k in used_global_vars or k in self.additional_used_global_vars
  3588. }
  3589. global_scope_state[builtins_dict_name] = {
  3590. k: v
  3591. # pyrefly: ignore [missing-attribute]
  3592. for k, v in output_graph_guards_state.global_scope[
  3593. builtins_dict_name
  3594. ].items() # type: ignore[attr-defined]
  3595. if k in self.used_builtin_vars
  3596. }
  3597. output_graph_guards_state = dataclasses.replace(
  3598. output_graph_guards_state,
  3599. local_scope={
  3600. k: v
  3601. for k, v in output_graph_guards_state.local_scope.items()
  3602. if k in used_local_vars or k in self.additional_used_local_vars
  3603. },
  3604. global_scope=global_scope_state,
  3605. _guards=torch._guards.GuardsSet(
  3606. OrderedSet(
  3607. dataclasses.replace(
  3608. guard,
  3609. obj_weakref=None,
  3610. guarded_class_weakref=None,
  3611. create_fn=normalize_create_fn(guard.create_fn),
  3612. )
  3613. for guard in sorted_guards
  3614. )
  3615. ),
  3616. input_source_to_sizes_strides=pytree.tree_map(
  3617. convert_int_to_concrete_values,
  3618. output_graph_guards_state.input_source_to_sizes_strides,
  3619. ),
  3620. skip_guards_check=True,
  3621. )
  3622. guards_state = GuardsState(
  3623. output_graph=output_graph_guards_state,
  3624. shape_code_parts=self.shape_code_parts,
  3625. )
  3626. return pickle_guards_state(guards_state, builder)
  3627. def build_guards(
  3628. self,
  3629. sorted_guards: list[Guard],
  3630. existing_diff_guard_sources: OrderedSet[str],
  3631. f_code: types.CodeType,
  3632. output_graph: OutputGraphGuardsState,
  3633. save_guards: bool,
  3634. guard_filter_fn: Callable[[Sequence[GuardFilterEntry]], Sequence[bool]]
  3635. | None = None,
  3636. ) -> tuple[GuardBuilder, GuardManagerWrapper]:
  3637. guard_manager = GuardManagerWrapper()
  3638. guard_manager.diff_guard_sources = existing_diff_guard_sources
  3639. w_builder = None
  3640. def source_ref(source: Source) -> str:
  3641. guard_source = source.guard_source
  3642. if guard_source is GuardSource.CONSTANT:
  3643. # No need to track constants
  3644. return source.name
  3645. assert w_builder
  3646. r_builder = w_builder()
  3647. assert r_builder is not None
  3648. return r_builder.arg_ref(source.name)
  3649. builder = GuardBuilder(
  3650. f_code,
  3651. self.id_ref,
  3652. source_ref,
  3653. self.lookup_weakrefs,
  3654. output_graph.local_scope,
  3655. output_graph.global_scope,
  3656. guard_manager,
  3657. self,
  3658. save_guards,
  3659. runtime_global_scope=self.runtime_global_scope,
  3660. guard_filter_fn=guard_filter_fn,
  3661. )
  3662. # Break retain cycle. See test_release_scope_memory
  3663. def cleanup_builder(weak_b: weakref.ref[GuardBuilder]) -> None:
  3664. b = weak_b()
  3665. if b:
  3666. b.scope = None # type: ignore[assignment]
  3667. # Break retain cycle. See test_release_input_memory
  3668. w_builder = weakref.ref(builder, cleanup_builder)
  3669. guard_on_nn_modules = config.guard_nn_modules and justknobs_check(
  3670. "pytorch/compiler:guard_nn_modules"
  3671. )
  3672. for guard in sorted_guards:
  3673. if (
  3674. not guard_on_nn_modules
  3675. and guard.is_specialized_nn_module()
  3676. # Default func args must be guarded on.
  3677. # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
  3678. and "__defaults__" not in guard.name
  3679. and "__kwdefaults__" not in guard.name
  3680. and (config.skip_nnmodule_hook_guards or "hooks" not in guard.name)
  3681. ):
  3682. continue
  3683. guard.create(builder)
  3684. return builder, guard_manager
  3685. def compile_check_fn(
  3686. self,
  3687. builder: GuardBuilder,
  3688. guards_out: list[Guard],
  3689. guard_fail_fn: Optional[Callable[[GuardFail], None]],
  3690. ) -> None:
  3691. # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
  3692. largs = builder.argnames
  3693. largs += ["**___kwargs_ignored"]
  3694. guards_log.debug("GUARDS:")
  3695. # pyrefly: ignore [implicit-any]
  3696. code_parts = []
  3697. verbose_code_parts = []
  3698. structured_guard_fns: list[Callable[[], dict[str, Any]]] = []
  3699. # Add compile id info in the guard manager for debugging purpose
  3700. self.guard_manager.root.attach_compile_id(
  3701. str(CompileContext.current_compile_id())
  3702. )
  3703. # Clear references to torch_function modes held in the list
  3704. self.torch_function_mode_stack = None
  3705. def add_code_part(
  3706. code_part: str, guard: Optional[Guard], log_only: bool = False
  3707. ) -> None:
  3708. verbose_code_part = get_verbose_code_part(code_part, guard)
  3709. guards_log.debug("%s", verbose_code_part)
  3710. structured_guard_fns.append(
  3711. lambda: {
  3712. "code": code_part,
  3713. "stack": (
  3714. structured.from_traceback(guard.stack.summary())
  3715. if guard and guard.stack
  3716. else None
  3717. ),
  3718. "user_stack": (
  3719. structured.from_traceback(guard.user_stack)
  3720. if guard and guard.user_stack
  3721. else None
  3722. ),
  3723. }
  3724. )
  3725. if verbose_guards_log.isEnabledFor(logging.DEBUG):
  3726. maybe_stack = ""
  3727. maybe_user_stack = ""
  3728. if guard is not None:
  3729. if guard.stack:
  3730. maybe_stack = f"\nStack:\n{''.join(guard.stack.format())}"
  3731. if guard.user_stack:
  3732. maybe_user_stack = (
  3733. f"\nUser stack:\n{''.join(guard.user_stack.format())}"
  3734. )
  3735. verbose_guards_log.debug(
  3736. "Guard: %s%s%s",
  3737. code_part,
  3738. maybe_stack,
  3739. maybe_user_stack,
  3740. )
  3741. if not log_only:
  3742. code_parts.append(code_part)
  3743. verbose_code_parts.append(verbose_code_part)
  3744. seen = set()
  3745. for gcl in builder.code:
  3746. for code in gcl.code_list:
  3747. if code not in seen:
  3748. # If Cpp guard manager is enabled, we don't need to add to
  3749. # code_parts.
  3750. add_code_part(code, gcl.guard, True)
  3751. seen.add(code)
  3752. no_tensor_aliasing_names = builder.no_tensor_aliasing_names
  3753. check_tensors_fn = None
  3754. check_tensors_verbose_fn = None
  3755. if len(no_tensor_aliasing_names) > 1:
  3756. # Install tensor aliasing guard. TENSOR_MATCH guards are already
  3757. # installed for cpp guard manager.
  3758. install_no_tensor_aliasing_guard(
  3759. builder.no_tensor_aliasing_guard_managers,
  3760. no_tensor_aliasing_names,
  3761. ["check_no_aliasing(" + ", ".join(no_tensor_aliasing_names) + ")"],
  3762. None,
  3763. )
  3764. # Note - On Lambda guarding of object aliasing
  3765. # We previously installed object-aliasing guards as relational guards,
  3766. # but that undermined the recursive-dict guard optimization: placing the
  3767. # aliasing guard at a leaf prevented the parent dict node from
  3768. # qualifying as a recursive-dict guard root. Because aliasing guards are
  3769. # rare, we now emit them as epilogue guards via a small Python lambda.
  3770. # This repeats the access in Python—adding a bit of work—but the
  3771. # overhead is outweighed by the gains from enabling recursive-dict guard
  3772. # optimization.
  3773. if (
  3774. config.use_lamba_guard_for_object_aliasing
  3775. and builder.object_aliasing_guard_codes
  3776. ):
  3777. aliasing_code_parts, aliasing_verbose_code_parts = map(
  3778. list, zip(*builder.object_aliasing_guard_codes)
  3779. )
  3780. builder.add_python_lambda_leaf_guard_to_root(
  3781. aliasing_code_parts, aliasing_verbose_code_parts
  3782. )
  3783. aotautograd_guards: list[GuardEnvExpr] = (
  3784. self.output_graph.aotautograd_guards if self.output_graph else []
  3785. )
  3786. # TODO(anijain2305) - There is a duplicate logic in Dynamo to find
  3787. # aliased input tensors. So most probably we don't need this here.
  3788. # Revisit.
  3789. for guard in aotautograd_guards:
  3790. if isinstance(guard, DuplicateInputs):
  3791. source_a = guard.input_source_a
  3792. source_b = guard.input_source_b
  3793. code_part = f"{source_a.name} is {source_b.name}"
  3794. install_object_aliasing_guard(
  3795. builder.get_guard_manager_from_source(source_a),
  3796. builder.get_guard_manager_from_source(source_b),
  3797. [code_part],
  3798. None,
  3799. )
  3800. add_code_part(code_part, None, True)
  3801. elif isinstance(guard, StorageOverlap):
  3802. overlapping_guard_managers = [
  3803. builder.get_guard_manager_from_source(s)
  3804. for s in guard.overlapping_sources
  3805. ]
  3806. non_overlapping_guard_managers = [
  3807. builder.get_guard_manager_from_source(s)
  3808. for s in guard.non_overlapping_sources
  3809. ]
  3810. code_part = (
  3811. """check_overlapping("""
  3812. f"""overlapping=[{", ".join(s.name for s in guard.overlapping_sources)}], """
  3813. f"""non_overlapping=[{", ".join(s.name for s in guard.non_overlapping_sources)}])"""
  3814. )
  3815. install_storage_overlapping_guard(
  3816. overlapping_guard_managers,
  3817. non_overlapping_guard_managers,
  3818. [code_part],
  3819. None,
  3820. )
  3821. add_code_part(code_part, None, True)
  3822. else:
  3823. raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")
  3824. # TODO: the "guard" here is actually just the top level SHAPE_ENV
  3825. # which is useless. Get ShapeEnv to pass in more provenance.
  3826. for gcl in builder.shape_env_code:
  3827. for code in gcl.code_list:
  3828. # Shape env guards are already added for CPP guard manager in
  3829. # SHAPE_ENV implementation.
  3830. add_code_part(code, gcl.guard, True)
  3831. # OK, all done generating guards
  3832. if structured_guard_fns:
  3833. torch._logging.trace_structured(
  3834. "dynamo_guards", payload_fn=lambda: [f() for f in structured_guard_fns]
  3835. )
  3836. if convert_frame.initial_global_state is None:
  3837. # we should only hit this case in NopTests()
  3838. check_global_state = convert_frame.GlobalStateGuard().check
  3839. else:
  3840. check_global_state = getattr(self.global_state, "check", None)
  3841. closure_vars = {
  3842. "___check_tensors": check_tensors_fn,
  3843. "___check_tensors_verbose": check_tensors_verbose_fn,
  3844. "___check_global_state": check_global_state,
  3845. "___check_torch_function_mode_stack": self.torch_function_mode_stack_check_fn,
  3846. **SYMPY_INTERP,
  3847. **_get_closure_vars(),
  3848. }
  3849. self.guard_manager.finalize()
  3850. globals_for_guard_fn = {"G": builder.scope["G"]}
  3851. # Guard manager construction is complete. Ensure we did not miss to
  3852. # insert a guard in cpp guard manager.
  3853. assert len(code_parts) == 0
  3854. self.guard_manager.closure_vars = closure_vars
  3855. self.guard_manager.args = largs
  3856. self.guard_manager.populate_code_parts_for_debugging()
  3857. self.guard_manager.verbose_code_parts = verbose_code_parts
  3858. # Grab only G, but preserve "G" because guards access it as "G"
  3859. self.guard_manager.global_scope = globals_for_guard_fn
  3860. self.guard_manager.guard_fail_fn = guard_fail_fn
  3861. # will be populated by a non-owning reference to CacheEntry/ExtraState
  3862. # when the CacheEntry is constructed
  3863. self.guard_manager.cache_entry = None
  3864. self.guard_manager.extra_state = None
  3865. self.guard_manager.no_tensor_aliasing_sources = no_tensor_aliasing_names
  3866. def invalidate(self, obj_str: str) -> None:
  3867. # Some tests reveal that CheckFunctionManager has no attribute
  3868. # guard_manager, but this case should not be of any concern.
  3869. # This case doesn't seem easy to repro.
  3870. if (
  3871. hasattr(self, "guard_manager")
  3872. and not isinstance(self.guard_manager, DeletedGuardManagerWrapper)
  3873. and (cache_entry := self.guard_manager.cache_entry) is not None
  3874. and (extra_state := self.guard_manager.extra_state) is not None
  3875. ):
  3876. assert isinstance(cache_entry, CacheEntry)
  3877. assert isinstance(extra_state, ExtraState)
  3878. reason = f"Cache line invalidated because {obj_str} got deallocated"
  3879. deleted_guard_manager = DeletedGuardManagerWrapper(reason)
  3880. extra_state.invalidate(cache_entry, deleted_guard_manager)
  3881. self.guard_manager = deleted_guard_manager
  3882. def id_ref(self, obj: object, obj_str: str) -> int:
  3883. """add a weakref, return the id"""
  3884. try:
  3885. if id(obj) not in self._weakrefs:
  3886. # We will clear the _weakrefs dict at the end of __init__
  3887. # function, which will delete the callbacks as well. Therefore,
  3888. # we are using a finalizer which is kept alive.
  3889. self._weakrefs[id(obj)] = weakref.ref(obj)
  3890. weakref.finalize(
  3891. obj, functools.partial(self.invalidate, obj_str=obj_str)
  3892. )
  3893. except TypeError:
  3894. pass # cannot weakref bool object
  3895. return id(obj)
  3896. def lookup_weakrefs(self, obj: object) -> Optional[weakref.ref[object]]:
  3897. """Lookup the _weakrefs created in id_ref function for ID_MATCH'd objects"""
  3898. if id(obj) in self._weakrefs:
  3899. return self._weakrefs[id(obj)]
  3900. return None
  3901. def build_guard_function(code_parts: list[str], closure_args: str) -> tuple[str, str]:
  3902. from torch._inductor.utils import IndentedBuffer
  3903. csepass = PyExprCSEPass()
  3904. try:
  3905. csepass.count(code_parts)
  3906. def replace(expr: str) -> tuple[list[str], str]:
  3907. return csepass.replace(expr)
  3908. except RecursionError:
  3909. # If we hit recursion limits during CSE analysis, fall back to a no-op replace function
  3910. # This can happen with extremely complex guard expressions
  3911. def replace(expr: str) -> tuple[list[str], str]:
  3912. return [], expr
  3913. # Generate the inner body of the guard function.
  3914. # i.e. if-chain of the guard expressions.
  3915. guard_body = IndentedBuffer()
  3916. for expr in code_parts:
  3917. preface, expr = replace(expr)
  3918. guard_body.writelines(preface)
  3919. guard_body.writeline(f"if not ({expr}):")
  3920. with guard_body.indent():
  3921. guard_body.writeline("return False")
  3922. # Wrap the inner body into the actual guard function.
  3923. guard = IndentedBuffer()
  3924. guard.writeline("def guard(L):")
  3925. with guard.indent():
  3926. guard.splice(guard_body)
  3927. guard.writeline("return True")
  3928. # Wrap the whole guard function into another function
  3929. # with the closure variables.
  3930. make_guard_fn = IndentedBuffer()
  3931. make_guard_fn.writeline(f"def ___make_guard_fn({closure_args}):")
  3932. with make_guard_fn.indent():
  3933. make_guard_fn.splice(guard)
  3934. make_guard_fn.writeline("return guard")
  3935. return guard_body.getvalue(), make_guard_fn.getvalue()
  3936. def is_recompiles_enabled() -> bool:
  3937. return torch._logging._internal.log_state.is_artifact_enabled("recompiles")
  3938. def is_recompiles_verbose_enabled() -> bool:
  3939. return torch._logging._internal.log_state.is_artifact_enabled("recompiles_verbose")
  3940. # this will only be used if cpp guards are disabled
  3941. def make_torch_function_mode_stack_guard(
  3942. initial_stack: list[torch.overrides.TorchFunctionMode],
  3943. ) -> Callable[[], bool]:
  3944. types = [type(x) for x in initial_stack]
  3945. def check_torch_function_mode_stack() -> bool:
  3946. cur_stack = get_torch_function_mode_stack()
  3947. if len(cur_stack) != len(types):
  3948. return False
  3949. for ty, mode in zip(types, cur_stack):
  3950. if ty is not type(mode):
  3951. return False
  3952. return True
  3953. return check_torch_function_mode_stack
# Type alias for a plain name -> object namespace dict, e.g. the ``scope``
# mapping that recompilation_reason_for_no_tensor_aliasing_guard passes to
# eval() as its locals.
Scope = TypeAliasType("Scope", dict[str, object])
  3955. def recompilation_reason_for_no_tensor_aliasing_guard(
  3956. guard_manager: GuardManagerWrapper, scope: Scope
  3957. ) -> list[str]:
  3958. assert guard_manager.global_scope is not None
  3959. global_scope = dict(guard_manager.global_scope)
  3960. ids_to_source = collections.defaultdict(list)
  3961. for tensor_source in guard_manager.no_tensor_aliasing_sources:
  3962. global_scope["__compile_source__"] = tensor_source
  3963. tensor_id = id(eval(tensor_source, global_scope, scope))
  3964. ids_to_source[tensor_id].append(tensor_source)
  3965. duplicate_tensors = [
  3966. f"{ids_to_source[key]}" for key in ids_to_source if len(ids_to_source[key]) > 1
  3967. ]
  3968. reason = ", ".join(duplicate_tensors)
  3969. return [f"Duplicate tensors found: {reason}"]
  3970. def strip_local_scope(s: str) -> str:
  3971. """
  3972. Replace occurrences of L[...] with just the inner content.
  3973. Handles both single and double quotes.
  3974. This is to generate user friendly recompilation messages.
  3975. """
  3976. import re
  3977. pattern = r"L\[\s*['\"](.*?)['\"]\s*\]"
  3978. return re.sub(pattern, r"\1", s)
  3979. def format_user_stack_trace(
  3980. user_stack: traceback.StackSummary | None,
  3981. ) -> str:
  3982. """
  3983. Format the user stack trace for display in guard failure messages.
  3984. Returns a formatted string representation of the stack trace,
  3985. or an empty string if no user stack is available.
  3986. """
  3987. if user_stack is None or len(user_stack) == 0:
  3988. return ""
  3989. lines: list[str] = []
  3990. for frame in user_stack:
  3991. filename = frame.filename
  3992. lineno = frame.lineno
  3993. name = frame.name
  3994. source_line = frame.line.strip() if frame.line else ""
  3995. lines.append(f' File "{filename}", line {lineno}, in {name}')
  3996. if source_line:
  3997. lines.append(f" {source_line}")
  3998. return "\n".join(lines)
  3999. def get_guard_fail_reason_helper(
  4000. guard_manager: GuardManagerWrapper,
  4001. f_locals: dict[str, object],
  4002. compile_id: Optional[CompileId],
  4003. # pyrefly: ignore [implicit-any]
  4004. backend: Optional[Callable],
  4005. ) -> str:
  4006. """
  4007. Return the reason why `guard_manager` failed.
  4008. Updates `guard_failures` with the generated reason.
  4009. Only the first failed check of guard_manager is reported.
  4010. """
  4011. assert guard_manager.global_scope is not None
  4012. assert guard_manager.closure_vars is not None
  4013. scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
  4014. scope.update(guard_manager.closure_vars)
  4015. reasons: list[str] = []
  4016. cache_entry_backend = None
  4017. if guard_manager.cache_entry:
  4018. cache_entry_backend = guard_manager.cache_entry.backend
  4019. no_tensor_aliasing_check_failed = False
  4020. verbose_code_parts: list[str] = []
  4021. guard_debug_info = guard_manager.check_verbose(f_locals)
  4022. user_stack_str = ""
  4023. # For test_export_with_map_cond, the check_verbose fail even without the
  4024. # C++ guard manager. We need to fix the issue to remove the comment.
  4025. # assert not guard_debug_info.result
  4026. if not guard_debug_info.result:
  4027. verbose_code_parts = guard_debug_info.verbose_code_parts
  4028. # verbose_code_parts is either the actual reason (e.g. in case of
  4029. # TENSOR_MATCH) or it could be a list of verbose_code_part that we
  4030. # passed to the leaf guard at construction time. If its a list, we
  4031. # walk through this list and find the guard that failed. This is
  4032. # very important for symbolic shape guards which are currently
  4033. # installed as a lambda guard and can encompass a long list of code_parts.
  4034. if len(verbose_code_parts) == 1:
  4035. if "Duplicate tensor found" in verbose_code_parts[0]:
  4036. no_tensor_aliasing_check_failed = True
  4037. else:
  4038. reasons = verbose_code_parts
  4039. verbose_code_parts = []
  4040. # Format user stack trace if available and recompile logging is enabled
  4041. if guard_debug_info.user_stack:
  4042. user_stack_str = format_user_stack_trace(guard_debug_info.user_stack)
  4043. elif cache_entry_backend != backend:
  4044. # None of the guard entries failed - a backend match issue
  4045. reason = (
  4046. "BACKEND_MATCH failure: torch.compile detected different backend callables."
  4047. " If this is unexpected, wrap your backend in functools.partial (or reuse the"
  4048. " same cached backend) to avoid creating a new backend function each time."
  4049. " More details: https://github.com/pytorch/pytorch/issues/168373"
  4050. )
  4051. reasons.append(reason)
  4052. else:
  4053. # Unexpected recompilation - points to a bug
  4054. reason = (
  4055. "Unexpected recompilation: runtime guards failed even though they passed"
  4056. " during recompilation-reason analysis."
  4057. " Please open an issue with a minimal repro:"
  4058. " https://github.com/pytorch/pytorch"
  4059. )
  4060. reasons.append(reason)
  4061. if no_tensor_aliasing_check_failed:
  4062. reasons = recompilation_reason_for_no_tensor_aliasing_guard(
  4063. guard_manager, scope
  4064. )
  4065. else:
  4066. for part in verbose_code_parts:
  4067. global_scope = dict(guard_manager.global_scope)
  4068. global_scope["__compile_source__"] = part
  4069. with report_compile_source_on_error():
  4070. try:
  4071. fail_reason = eval(part, global_scope, scope)
  4072. except Exception:
  4073. if is_recompiles_verbose_enabled():
  4074. continue
  4075. else:
  4076. raise
  4077. # Only ___check_tensors knows how to return a fancy fail reason;
  4078. # for everything else we just report the code that failed
  4079. if isinstance(fail_reason, bool) and not fail_reason:
  4080. fail_reason = part
  4081. if isinstance(fail_reason, str):
  4082. reasons.append(fail_reason)
  4083. if not is_recompiles_verbose_enabled():
  4084. break
  4085. # Build reason string - simple format for normal logging
  4086. # Use singular "reason" when there's only one, plural "reasons" for multiple
  4087. if len(reasons) == 1:
  4088. reason_str = f"{compile_id}: {reasons[0]}"
  4089. else:
  4090. reason_str = f"{compile_id}: " + "; ".join(reasons)
  4091. if user_stack_str:
  4092. reason_str += f"\nUser stack trace:\n{user_stack_str}"
  4093. return strip_local_scope(reason_str)
  4094. def get_guard_fail_reason(
  4095. guard_manager: GuardManagerWrapper,
  4096. code: types.CodeType,
  4097. f_locals: dict[str, object],
  4098. compile_id: CompileId,
  4099. # pyrefly: ignore [implicit-any]
  4100. backend: Callable,
  4101. skip_logging: bool = False,
  4102. ) -> str:
  4103. if isinstance(guard_manager, DeletedGuardManagerWrapper):
  4104. return f"{compile_id}: {guard_manager.invalidation_reason}"
  4105. reason_str = get_guard_fail_reason_helper(
  4106. guard_manager, f_locals, compile_id, backend
  4107. )
  4108. if skip_logging:
  4109. return reason_str
  4110. guard_failures[orig_code_map[code]].append(reason_str)
  4111. try:
  4112. if guard_manager.guard_fail_fn is not None:
  4113. guard_manager.guard_fail_fn(
  4114. GuardFail(reason_str or "unknown reason", orig_code_map[code])
  4115. )
  4116. except Exception:
  4117. log.exception(
  4118. "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
  4119. )
  4120. return reason_str
  4121. def get_and_maybe_log_recompilation_reasons(
  4122. cache_entry: Optional[CacheEntry],
  4123. frame: DynamoFrameType,
  4124. # pyrefly: ignore [implicit-any]
  4125. backend: Callable,
  4126. skip_logging: bool = False,
  4127. ) -> list[str]:
  4128. """
  4129. Return the list of guard failure reasons using cache_entry.
  4130. Logs the recompilation reason if `recompiles` logging is enabled.
  4131. Raises a RecompileError if `config.error_on_recompile` is enabled.
  4132. """
  4133. # pyrefly: ignore [implicit-any]
  4134. reasons = []
  4135. while cache_entry is not None:
  4136. reason = get_guard_fail_reason(
  4137. cache_entry.guard_manager,
  4138. cache_entry.code,
  4139. frame.f_locals,
  4140. cache_entry.compile_id,
  4141. backend,
  4142. skip_logging,
  4143. )
  4144. if reason:
  4145. reasons.append(reason)
  4146. cache_entry = cache_entry.next
  4147. code = frame.f_code
  4148. if skip_logging:
  4149. return reasons
  4150. # at least one of "recompiles" or "recompiles_verbose" is enabled
  4151. do_recompiles_log = is_recompiles_enabled() or is_recompiles_verbose_enabled()
  4152. if do_recompiles_log or config.error_on_recompile:
  4153. if is_recompiles_verbose_enabled():
  4154. failures = "\n\n".join(
  4155. f"guard {i} failures:\n" + textwrap.indent(reason, "- ")
  4156. for i, reason in enumerate(reasons)
  4157. )
  4158. else:
  4159. failures = textwrap.indent("\n".join(reasons), "- ")
  4160. guard_failure_details = (
  4161. f"triggered by the following guard failure(s):\n{failures}"
  4162. )
  4163. message = (
  4164. f"Recompiling function {code.co_name} in {code.co_filename}:{code.co_firstlineno}\n"
  4165. f"{textwrap.indent(guard_failure_details, ' ')}"
  4166. )
  4167. if do_recompiles_log:
  4168. if is_recompiles_verbose_enabled():
  4169. recompiles_verbose_log.debug(message)
  4170. else:
  4171. recompiles_log.debug(message)
  4172. if config.error_on_recompile:
  4173. raise exc.RecompileError(message)
  4174. torch._logging.trace_structured(
  4175. "artifact",
  4176. metadata_fn=lambda: {
  4177. "name": "recompile_reasons",
  4178. "encoding": "json",
  4179. },
  4180. payload_fn=lambda: reasons[0] if len(reasons) == 1 else reasons,
  4181. )
  4182. return reasons
  4183. def update_diff_guard_managers_for_existing_cache_entries(
  4184. cache_entry: Optional[CacheEntry],
  4185. ) -> OrderedSet[str]:
  4186. first_cache_entry = cache_entry
  4187. # On the first pass, go through the cache entries and accumulate the diff
  4188. # guard sources. Different guard managers can fail with different sources.
  4189. # So, we collect all of them first.
  4190. acc_diff_guard_sources: OrderedSet[str] = OrderedSet()
  4191. while cache_entry is not None:
  4192. acc_diff_guard_sources.update(
  4193. cache_entry.guard_manager.collect_diff_guard_sources()
  4194. )
  4195. cache_entry = cache_entry.next # type: ignore[assignment]
  4196. # On the second pass, set the diff_guard_sources for each cache line to the
  4197. # accumulated value. And the re-populate the diff guard manager.
  4198. cache_entry = first_cache_entry
  4199. while cache_entry is not None:
  4200. cache_entry.guard_manager.diff_guard_sources = acc_diff_guard_sources
  4201. cache_entry.guard_manager.populate_diff_guard_manager()
  4202. cache_entry = cache_entry.next # type: ignore[assignment]
  4203. # return the accumulated sources to set up the new cache line.
  4204. return acc_diff_guard_sources
  4205. def guard_error_hook(
  4206. guard_manager: GuardFn,
  4207. code: types.CodeType,
  4208. f_locals: dict[str, object],
  4209. index: int,
  4210. last: bool,
  4211. ) -> None:
  4212. print(
  4213. f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
  4214. )
  4215. print("lambda " + ", ".join(guard_manager.args) + ":")
  4216. print(" ", " and\n ".join(guard_manager.code_parts))
  4217. print(guard_manager)
  4218. local_scope = {"L": f_locals, **guard_manager.closure_vars}
  4219. for guard in guard_manager.code_parts:
  4220. try:
  4221. eval(guard, guard_manager.global_scope, local_scope)
  4222. except: # noqa: B001,E722
  4223. print(f"Malformed guard:\n{guard}")
# Register `guard_error_hook` through `set_guard_error_hook`.
# NOTE(review): presumably invoked by the guard-evaluation machinery when
# running guards raises (the hook prints "ERROR RUNNING GUARDS ..."); confirm
# against set_guard_error_hook.
set_guard_error_hook(guard_error_hook)
  4225. def unique(seq: Sequence[T]) -> Generator[T, None, None]:
  4226. seen = set()
  4227. for x in seq:
  4228. if x not in seen:
  4229. yield x
  4230. seen.add(x)
  4231. def make_dupe_guard(
  4232. obj_source: Source, dupe_source: Source | None
  4233. ) -> Optional[functools.partial[Any]]:
  4234. # Note - we may end up in a situation where we invoke something like
  4235. # def fn(x, y)
  4236. # with fn(x, x)
  4237. # Prior to the addition of tracking to all relevant objects, we would handle this just fine by
  4238. # eagerly re-entering VB and rewrapping inputs, correctly creating graphargs and placeholders. However,
  4239. # with tracking on inputs, duplicate inputs or aliased relationships may end up getting erased here -
  4240. # In the fn(x, x) example call above look like a graph with a single input.
  4241. # In order to ensure that we do not reuse fn(x, x) for fn(x, y), we create a duplicate input guard.
  4242. # Note - we may not have a source, that is fine, it just means we had an object that is safe to have
  4243. # leave unsourced - like a local list created and discharged entirely within a local scope.
  4244. if dupe_source and dupe_source != obj_source:
  4245. ser_source_is_local = is_from_local_source(dupe_source)
  4246. source_is_local = is_from_local_source(obj_source)
  4247. if is_from_flatten_script_object_source(
  4248. dupe_source
  4249. ) or is_from_flatten_script_object_source(obj_source):
  4250. raise exc.UnsafeScriptObjectError(
  4251. f"{obj_source.name} is aliasing {dupe_source.name}. This is not supported."
  4252. f" Please do a clone for corresponding input."
  4253. )
  4254. # Note - both must be local, or global, or we will run afoul of a lack of merging in how we currently
  4255. # reconcile guards builder scopes in compile_check_fn. This technically means we miss a guard here,
  4256. # so maybe we should do this refactor before we land this...
  4257. # TODO(voz): Combine local and global guard builders.
  4258. if ser_source_is_local == source_is_local:
  4259. # Note - this is a little aggressive - these being duplicate input does not always matter.
  4260. # However, this should always be a sound guard to add here.
  4261. return functools.partial(GuardBuilder.DUPLICATE_INPUT, source_b=dupe_source)
  4262. return None
  4263. def install_guard(*guards: Guard, skip: int = 0) -> None:
  4264. """
  4265. Add dynamo guards to the current tracing context.
  4266. Args:
  4267. guards: guard(s) to add
  4268. skip: number of stack frames to ignore for debug stack trace
  4269. """
  4270. from torch._guards import TracingContext
  4271. collect_debug_stack = guards_log.isEnabledFor(
  4272. logging.DEBUG
  4273. ) or verbose_guards_log.isEnabledFor(logging.DEBUG)
  4274. add = TracingContext.get().guards_context.dynamo_guards.add
  4275. for guard in guards:
  4276. assert isinstance(guard, Guard)
  4277. if is_from_skip_guard_source(guard.originating_source):
  4278. continue
  4279. add(guard, collect_debug_stack=collect_debug_stack, skip=skip + 1)