symbolic_convert.py 248 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268
  1. """
  2. Core module responsible for converting Python bytecode into TorchDynamo's symbolic execution format.
  3. This module implements the bytecode-level tracing system that allows TorchDynamo to analyze
  4. and transform Python code. It converts Python bytecode instructions into a symbolic format
  5. that tracks the flow of tensors and other values through the program.
  6. Key components:
  7. - InstructionTranslatorBase: Base class for converting bytecode to symbolic execution
  8. - InstructionTranslator: Main translator for function bytecode
  9. - InliningInstructionTranslator: Handles inlining of called functions
  10. - SpeculationLog: Manages state for speculative execution and rollback
  11. The symbolic conversion process handles:
  12. - Control flow (loops, conditionals, etc.)
  13. - Function inlining and call stack management
  14. - Tracking of program values and side effects
  15. - Graph breaks and resumption points
  16. - Exception handling and stack frame management
  17. This is a core part of TorchDynamo's tracing system that enables ahead-of-time
  18. optimization of PyTorch programs.
  19. """
  20. from __future__ import annotations
  21. import collections
  22. import collections.abc
  23. import contextlib
  24. import copy
  25. import dataclasses
  26. import dis
  27. import functools
  28. import importlib
  29. import inspect
  30. import itertools
  31. import linecache
  32. import logging
  33. import operator
  34. import re
  35. import sys
  36. import threading
  37. import time
  38. import traceback
  39. import types
  40. import weakref
  41. from collections import deque
  42. from typing import Any, cast, NoReturn, Optional, TYPE_CHECKING, TypeAlias, Union
  43. from typing_extensions import TypeIs
  44. import torch
  45. import torch._logging
  46. from torch._dynamo.dynamo_profiler import DynamoProfilerState, FunctionTraceTiming
  47. from torch._dynamo.exc import ObservedException, TensorifyScalarRestartAnalysis
  48. from torch._guards import InlinedCodeCache, tracing, TracingContext
  49. from torch._logging.structured import dump_file
  50. from torch.fx.experimental.symbolic_shapes import guard_bool
  51. from torch.utils._functools import cache_method
  52. from . import (
  53. config,
  54. exc,
  55. graph_break_hints,
  56. logging as torchdynamo_logging,
  57. trace_rules,
  58. variables,
  59. )
  60. from .bytecode_analysis import (
  61. get_indexof,
  62. JUMP_OPNAMES,
  63. livevars_analysis,
  64. propagate_line_nums,
  65. )
  66. from .bytecode_transformation import (
  67. cleaned_instructions,
  68. create_binary_slice,
  69. create_call_function,
  70. create_call_function_ex,
  71. create_copy,
  72. create_dup_top,
  73. create_instruction,
  74. create_jump_absolute,
  75. create_rot_n,
  76. create_swap,
  77. get_code_keys,
  78. Instruction,
  79. is_generator,
  80. is_jump_absolute,
  81. unique_id,
  82. )
  83. from .code_context import code_context
  84. from .codegen import PyCodegen
  85. from .exc import (
  86. augment_exc_message_with_hop_name,
  87. BackendCompilerFailed,
  88. collapse_resume_frames,
  89. format_frame_info,
  90. get_stack_above_dynamo,
  91. ResumePrologueTracingError,
  92. StepUnsupported,
  93. unimplemented,
  94. Unsupported,
  95. )
  96. from .funcname_cache import get_funcname
  97. from .guards import GuardBuilder, install_guard
  98. from .output_graph import GraphCompileReason, OutputGraph, StackLocalsMetadata
  99. from .polyfills import (
  100. impl_CONTAINS_OP_fallback,
  101. impl_IS_MAPPING,
  102. impl_MATCH_CLASS,
  103. impl_MATCH_KEYS,
  104. impl_MATCH_SEQUENCE,
  105. )
  106. from .replay_record import DummyModule, ExecutionRecorder
  107. from .resume_execution import (
  108. ContinueExecutionCache,
  109. IS_TRACING_RESUME_PROLOGUE_VARNAME,
  110. ReenterWith,
  111. TORCH_DYNAMO_RESUME_IN_PREFIX,
  112. )
  113. from .source import (
  114. AttrSource,
  115. DictGetItemSource,
  116. GlobalSource,
  117. GlobalWeakRefSource,
  118. LocalCellSource,
  119. LocalSource,
  120. SkipGuardSource,
  121. Source,
  122. )
  123. from .trace_rules import is_builtin_constant, is_forbidden
  124. from .utils import (
  125. _get_error_on_graph_break,
  126. counters,
  127. get_fake_value,
  128. get_instruction_source_311,
  129. get_metrics_context,
  130. graph_break_dup_warning_checker,
  131. istype,
  132. LazyString,
  133. proxy_args_kwargs,
  134. )
  135. from .variables.base import typestr, ValueMutationNew, VariableTracker
  136. from .variables.builder import FrameStateSizeEntry, VariableBuilder, wrap_fx_proxy
  137. from .variables.builtin import BuiltinVariable
  138. from .variables.constant import CONSTANT_VARIABLE_NONE, ConstantVariable
  139. from .variables.ctx_manager import (
  140. ContextWrappingVariable,
  141. GenericContextWrappingVariable,
  142. WithEnterFunctionVariable,
  143. WithExitFunctionVariable,
  144. )
  145. from .variables.dicts import ConstDictVariable, SetVariable
  146. from .variables.functions import (
  147. BaseUserFunctionVariable,
  148. LocalGeneratorFunctionVariable,
  149. LocalGeneratorObjectVariable,
  150. NestedUserFunctionVariable,
  151. SkipFunctionVariable,
  152. UserFunctionVariable,
  153. UserMethodVariable,
  154. )
  155. from .variables.iter import MAX_ITERATOR_LIMIT
  156. from .variables.lazy import LazyVariableTracker
  157. from .variables.lists import (
  158. BaseListVariable,
  159. IteratorVariable,
  160. ListIteratorVariable,
  161. ListVariable,
  162. SliceVariable,
  163. TupleVariable,
  164. )
  165. from .variables.misc import (
  166. CellVariable,
  167. ExceptionVariable,
  168. GetAttrVariable,
  169. NullVariable,
  170. PythonModuleVariable,
  171. TracebackVariable,
  172. UnknownVariable,
  173. )
  174. from .variables.nn_module import NNModuleVariable, UnspecializedNNModuleVariable
  175. from .variables.streams import SymbolicStreamState
  176. from .variables.tensor import supported_comparison_ops, SymNodeVariable
  177. from .variables.torch_function import (
  178. SymbolicTorchFunctionState,
  179. TorchFunctionModeVariable,
  180. )
  181. from .variables.user_defined import (
  182. RemovableHandleVariable,
  183. UserDefinedClassVariable,
  184. UserDefinedExceptionClassVariable,
  185. UserDefinedExceptionObjectVariable,
  186. UserDefinedObjectVariable,
  187. )
  188. if TYPE_CHECKING:
  189. from collections.abc import Callable, Generator, Sequence
  190. from torch._subclasses.fake_tensor import FakeTensorMode
  191. from .package import CompilePackage
  192. log = logging.getLogger(__name__)
  193. graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
  194. trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
  195. trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
  196. trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode")
  197. tls = threading.local()
  198. compare_op_handlers: dict[str, Any] = {
  199. k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
  200. }
  201. handle_contains = BuiltinVariable(operator.contains).call_function
  202. handle_not = BuiltinVariable(operator.not_).call_function
  203. compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
  204. tx, [*reversed(args)], {}
  205. )
  206. compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
  207. tx, [handle_contains(tx, [*reversed(args)], {})], {}
  208. )
# Type alias naming the variable classes this module treats as exception
# values: variables.ExceptionVariable plus the user-defined exception class
# and user-defined exception object variables.
ExceptionVals: TypeAlias = Union[
    variables.ExceptionVariable,
    UserDefinedExceptionClassVariable,
    UserDefinedExceptionObjectVariable,
]
  214. @functools.cache
  215. def _import_module(name: str) -> types.ModuleType:
  216. """
  217. Import the named module and cache the result. importlib.import_module()
  218. seems to do some filesystem checking to validate the name so not caching
  219. this can be slow.
  220. """
  221. return importlib.import_module(name)
  222. @dataclasses.dataclass
  223. class SpeculationEntry:
  224. filename: str
  225. lineno: int
  226. instruction_pointer: int
  227. inst: Instruction # for debugging only
  228. _failed: bool = False
  229. # TX error_on_graph_break setting at the time of failure
  230. error_on_graph_break: Optional[bool] = None
  231. reason: Optional[GraphCompileReason] = None
  232. def fail_and_restart_analysis(self, error_on_graph_break: bool) -> None:
  233. """
  234. Start tracing of the current frame over again, and don't take this branch.
  235. """
  236. self._failed = True
  237. self.error_on_graph_break = error_on_graph_break
  238. if self.reason is not None:
  239. restart_reason = self.reason.reason
  240. else:
  241. restart_reason = "Unknown fail_and_restart_analysis"
  242. raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
  243. def failed(self, tx: InstructionTranslatorBase) -> bool:
  244. if self._failed:
  245. assert self.error_on_graph_break is not None
  246. tx.error_on_graph_break = self.error_on_graph_break
  247. return True
  248. return False
  249. @dataclasses.dataclass
  250. class SpeculationLog:
  251. """
  252. SpeculationLog replaces the prior copy_graphstate/restore_graphstate
  253. checkpointing. Rather than saving/restoring state, we restart the
  254. dynamo conversion process over from the beginning -- but when we
  255. hit the start of the speculation that failed, we instead generate
  256. a graph break.
  257. """
  258. entries: list[SpeculationEntry] = dataclasses.field(default_factory=list)
  259. index: int = 0
  260. # If True, graph break at autograd.grad instead of tracing it.
  261. # Set when we detect that autograd.grad consumed grad_fns that are returned.
  262. graph_break_on_autograd_grad: bool = False
  263. def restart(self) -> None:
  264. self.index = 0
  265. def clear(self) -> None:
  266. self.entries.clear()
  267. self.index = 0
  268. def next(
  269. self, filename: str, lineno: int, instruction_pointer: int, inst: Instruction
  270. ) -> SpeculationEntry:
  271. """
  272. Lookup or create a SpeculationEntry() that is shared across
  273. RestartAnalysis calls. Args are used only for debug checks.
  274. """
  275. if len(self.entries) == self.index:
  276. self.entries.append(
  277. SpeculationEntry(filename, lineno, instruction_pointer, inst)
  278. )
  279. entry = self.entries[self.index]
  280. prev_entry_msg = ""
  281. if self.index != 0:
  282. prev_entry = self.entries[self.index - 1]
  283. prev_entry_msg = (
  284. f"Previous instruction: {prev_entry.filename}:{prev_entry.lineno}"
  285. f"({prev_entry.inst.opname} @ {prev_entry.instruction_pointer})\n"
  286. )
  287. if not (
  288. entry.instruction_pointer == instruction_pointer
  289. and entry.filename == filename
  290. and entry.lineno == lineno
  291. ):
  292. raise SpeculationLogDivergence(
  293. f"""
  294. SpeculationLog diverged at index {self.index} (log had {len(self.entries)} entries):
  295. - Expected: {entry.filename}:{entry.lineno} ({entry.inst.opname} at ip={entry.instruction_pointer})
  296. - Actual: {filename}:{lineno} ({inst.opname} at ip={instruction_pointer})
  297. {prev_entry_msg}
  298. There are two usual reasons why this may have occurred:
  299. - When Dynamo analysis restarted, the second run took a different path than
  300. the first. If this occurred, the previous instruction is the critical instruction that
  301. behaved differently.
  302. - Speculation entries are only added under certain conditions (as seen in
  303. step()), e.g., there must exist operators in the graph; those conditions may
  304. have changed on restart.
  305. If this divergence was intentional, clear the speculation log before restarting (do NOT
  306. do this for graph breaks, you will infinite loop).
  307. Otherwise, please submit a bug report, ideally including the contents of TORCH_LOGS=+dynamo
  308. """
  309. )
  310. self.index += 1
  311. return entry
  312. @dataclasses.dataclass
  313. class LocalState:
  314. automatic_dynamic: dict[str, FrameStateSizeEntry] = dataclasses.field(
  315. default_factory=dict
  316. )
  317. def render(self) -> str:
  318. return "\n".join(
  319. f"{k}: {v.render()}" for k, v in self.automatic_dynamic.items()
  320. )
  321. # Mutable box that is shared across restarts
  322. @dataclasses.dataclass
  323. class DistributedState:
  324. compile_pg: Any
  325. local_state: LocalState
  326. all_states: Optional[list[LocalState]] = None
  327. class TensorifyState:
  328. # These are the set of string symfloats names (eg. "zf0") that we collect
  329. # from the tensorify_python_scalars.py joint fx pass to inform us about
  330. # which float inputs we should specialize when we restart analysis.
  331. force_specializations: set[str] = set()
  332. @classmethod
  333. def specialize(cls, index: str) -> None:
  334. cls.force_specializations.add(index)
  335. @classmethod
  336. def should_specialize(cls, index: str) -> bool:
  337. return index in cls.force_specializations
  338. @classmethod
  339. def clear(cls) -> None:
  340. cls.force_specializations.clear()
  341. @classmethod
  342. def empty(cls) -> bool:
  343. return len(cls.force_specializations) == 0
  344. @functools.cache
  345. def _step_logger() -> Callable[..., None]:
  346. return torchdynamo_logging.get_step_logger(log)
  347. @contextlib.contextmanager
  348. def save_and_restart_speculation_log(
  349. tx: InstructionTranslatorBase,
  350. ) -> Generator[None, None, None]:
  351. # When reconstructing a generator after a graph break, we advance it until
  352. # it is fully exhausted. This process adds new entries to the speculation
  353. # log that were not previously observed. Without temporarily clearing the
  354. # speculation log, this could lead to a divergence error.
  355. entries = tx.speculation_log.entries
  356. index = tx.speculation_log.index
  357. try:
  358. tx.speculation_log.entries = []
  359. tx.speculation_log.index = 0
  360. yield
  361. finally:
  362. tx.speculation_log.entries = entries
  363. tx.speculation_log.index = index
  364. @contextlib.contextmanager
  365. def temporarely_allow_writes_to_output_graph(
  366. tx: InstructionTranslatorBase,
  367. ) -> Generator[None, None, None]:
  368. try:
  369. tmp = tx.output.should_exit
  370. tx.output.should_exit = False
  371. yield
  372. finally:
  373. tx.output.should_exit = tmp
  374. @dataclasses.dataclass
  375. class BlockStackEntry:
  376. # Current instruction that pushes something to block_stack
  377. inst: Instruction
  378. target: Instruction | None
  379. stack_index: int
  380. with_context: Optional[
  381. Union[ContextWrappingVariable, GenericContextWrappingVariable]
  382. ] = None
  383. def can_restore(self) -> bool:
  384. return self.with_context is not None
  385. def resume_fn(self) -> ReenterWith:
  386. assert self.stack_index is not None
  387. if (
  388. self.with_context
  389. and hasattr(self.with_context, "target_values")
  390. and self.with_context.target_values
  391. ):
  392. return ReenterWith(
  393. self.stack_index - 1, tuple(self.with_context.target_values)
  394. )
  395. else:
  396. return ReenterWith(self.stack_index - 1)
  397. def exit(
  398. self, tx: InstructionTranslatorBase, is_graph_break: bool
  399. ) -> VariableTracker | None:
  400. assert self.with_context is not None
  401. if (
  402. is_graph_break and self.with_context.exit_on_graph_break()
  403. ) or not is_graph_break:
  404. return self.with_context.exit(tx) # type: ignore[arg-type]
  405. return None
  406. class SpeculationLogDivergence(AssertionError):
  407. pass
  408. class ReturnValueOp(Exception):
  409. pass
  410. class YieldValueOp(Exception):
  411. """
  412. Signal to the symbolic tracer to stop and return control flow to the
  413. caller
  414. """
  415. def stack_op(fn: Callable[..., object]) -> Callable[..., Any]:
  416. nargs = len(inspect.signature(fn).parameters)
  417. fn_var = BuiltinVariable(fn)
  418. @functools.wraps(fn)
  419. def impl(self: InstructionTranslator, inst: Instruction) -> None:
  420. self.push(fn_var.call_function(self, self.popn(nargs), {}))
  421. return impl
  422. def is_stdlib(mod: object) -> bool:
  423. if not isinstance(mod, types.ModuleType):
  424. return False
  425. return mod.__name__.split(".")[0] in sys.stdlib_module_names
  426. @functools.cache
  427. def get_assert_bytecode_sequence(with_msg: bool) -> list[str]:
  428. if with_msg:
  429. def fn(x: Any) -> None:
  430. assert x, "msg"
  431. else:
  432. def fn(x: Any) -> None:
  433. assert x
  434. insts = [inst.opname for inst in dis.get_instructions(fn)]
  435. # expect to find POP_JUMP_[FORWARD_]IF_TRUE
  436. begin_idx = next(i for i, inst in enumerate(insts) if inst.startswith("POP_JUMP"))
  437. end_idx = insts.index("RAISE_VARARGS")
  438. return insts[begin_idx + 1 : end_idx + 1]
  439. @functools.cache
  440. def _get_comprehension_bytecode_prefix() -> list[str]:
  441. """Get the bytecode instructions that precede BUILD_LIST in a list comprehension."""
  442. assert sys.version_info >= (3, 12)
  443. def fn() -> list[int]:
  444. return [i for i in range(1)] # noqa: C416
  445. insts = [inst.opname for inst in dis.get_instructions(fn)]
  446. start_idx = len(insts) - 1 - insts[::-1].index("LOAD_FAST_AND_CLEAR")
  447. end_idx = insts.index("BUILD_LIST")
  448. return insts[start_idx:end_idx]
  449. @functools.cache
  450. def _get_comprehension_result_patterns() -> dict[str, dict[str, Any]]:
  451. """Discover bytecode patterns for comprehension result handling.
  452. Analyzes sample functions to extract the opcode sequences that appear
  453. after END_FOR for each result disposition (stored, discarded, returned, consumed).
  454. Returns patterns with:
  455. - pre_store_ops: opcodes between END_FOR and first STORE_FAST
  456. - post_store_op: first opcode after all STORE_FASTs (for disambiguation)
  457. """
  458. assert sys.version_info >= (3, 12)
  459. def fn_stored() -> list[int]:
  460. result = [i for i in range(1)] # noqa: C416
  461. return result
  462. def fn_discarded() -> int:
  463. [i for i in range(1)] # noqa: C416
  464. return 1
  465. def fn_returned() -> list[int]:
  466. return [i for i in range(1)] # noqa: C416
  467. def fn_consumed() -> int:
  468. return sum([i for i in range(1)]) # noqa: C416
  469. def extract_pattern(fn: Callable[..., Any]) -> tuple[list[str], Optional[str]]:
  470. """Extract (pre_store_ops, post_store_op) from comprehension bytecode."""
  471. target_line = list(dis.findlinestarts(fn.__code__))[1][1]
  472. insts: list[str] = []
  473. started = False
  474. for instr in dis.get_instructions(fn):
  475. if started and instr.starts_line:
  476. break
  477. pos = instr.positions
  478. if pos and pos.lineno == target_line:
  479. started = started or bool(instr.starts_line)
  480. insts.append(instr.opname)
  481. ops = insts[insts.index("END_FOR") + 1 :]
  482. idx = 0
  483. pre_store_ops = []
  484. while idx < len(ops) and ops[idx] != "STORE_FAST":
  485. pre_store_ops.append(ops[idx])
  486. idx += 1
  487. while idx < len(ops) and ops[idx] == "STORE_FAST":
  488. idx += 1
  489. return pre_store_ops, ops[idx] if idx < len(ops) else None
  490. stored = extract_pattern(fn_stored)
  491. discarded = extract_pattern(fn_discarded)
  492. returned = extract_pattern(fn_returned)
  493. consumed = extract_pattern(fn_consumed)
  494. return {
  495. "stored": {"pre_store_ops": stored[0], "post_store_op": stored[1]},
  496. "discarded": {"pre_store_ops": discarded[0], "post_store_op": discarded[1]},
  497. "returned": {"pre_store_ops": returned[0], "post_store_op": returned[1]},
  498. "consumed": {"pre_store_ops": consumed[0], "post_store_op": []},
  499. }
  500. @dataclasses.dataclass
  501. class ComprehensionAnalysis:
  502. """Metadata about a comprehension's bytecode structure.
  503. Attributes:
  504. end_ip: Instruction pointer after all comprehension bytecode
  505. result_var: Name of result variable, or None if result stays on stack
  506. result_on_stack: True if result stays on stack (discarded, returned, or in expression)
  507. iterator_vars: Variables from LOAD_FAST_AND_CLEAR (need restoration)
  508. walrus_vars: Variables assigned via walrus operator (:=) inside comprehension
  509. captured_vars: Variables read from outer scope via LOAD_FAST inside comprehension
  510. """
  511. end_ip: int
  512. result_var: Optional[str]
  513. result_on_stack: bool
  514. iterator_vars: list[str]
  515. walrus_vars: list[str]
  516. captured_vars: list[str]
  517. def _detect_and_normalize_assert_statement(
  518. self: InstructionTranslatorBase,
  519. truth_fn: Callable[[object], bool],
  520. push: bool,
  521. ) -> bool:
  522. # Detect if this jump instruction is assert and normalize the assert
  523. # by pushing dummy error message when nothing is given.
  524. #
  525. # Python 3.9-3.13 assertion is in following format (minus small differences)
  526. # 18 POP_JUMP_IF_TRUE 28
  527. # 20 LOAD_ASSERTION_ERROR
  528. # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
  529. # 24 CALL_FUNCTION 1 -> optional instruction
  530. # 26 RAISE_VARARGS
  531. if (truth_fn is not operator.truth) or push:
  532. return False
  533. assert isinstance(self.instruction_pointer, int)
  534. current_instruction_pointer = self.instruction_pointer
  535. for with_msg in (False, True):
  536. assert_insts = get_assert_bytecode_sequence(with_msg)
  537. cur_insts = self.instructions[
  538. current_instruction_pointer : current_instruction_pointer
  539. + len(assert_insts)
  540. ]
  541. cur_insts = [inst.opname for inst in cur_insts]
  542. if cur_insts == assert_insts:
  543. if with_msg:
  544. load_const_idx = assert_insts.index("LOAD_CONST")
  545. error_msg = self.instructions[
  546. current_instruction_pointer + load_const_idx
  547. ].argval
  548. else:
  549. error_msg = "assertion error"
  550. self.push(ConstantVariable.create(error_msg))
  551. return True
  552. return False
  553. explain = False
  554. # [NOTE] graph break handling in symbolic_convert
  555. # There are 4 possible graph break cases that InstructionTranslatorBase handles:
  556. # 1. Regular graph breaks from CALL, BINARY_SUBSCR, etc. (implemented by break_graph_if_unsupported)
  557. # 2. Data-dependent condition graph breaks (implemented by generic_jump)
  558. # 4. All other unhandled graph breaks - unsupported step graph breaks (implemented in InstructionTranslatorBase.step)
  559. #
  560. # Graph breaks are handled in the following manner:
  561. # 1. The Unsupported exception is caught. If we cannot compile a partial graph (should_compile_partial_graph() is False),
  562. # then propagate the exception upward. For unsupported step graph breaks, the condition to abort partial compilation is
  563. # more restrictive (see InstructionTranslatorBase.step).
  564. # 2. If the Unsupported exception escapes symbolic_convert.py, then we are done.
  565. # Otherwise, we want to attempt partial compilation.
  566. # Log the graph break via log_graph_break. If we're handling a data-dependent graph break (type 2.), then we can immediately
  567. # codegen the compiled graph and resume function and we're done. This is because the jump instruction we graph break on is
  568. # limited in how it can manipulate Python state (say, in comparison, to CALL, which can modify Python state arbitrarily).
  569. # Otherwise, we need to restart compilation. We need to restart because by processing the unsupported instruction,
  570. # we may have modified the VariableTrackers, and we need all of our VariableTrackers to be in the state BEFORE tracing the
  571. # unsupported instruction.
  572. # 3. During the first compilation, we updated a speculation log, indicating points in the code that we can resume from.
  573. # On the second compilation, we will stop tracing at the first speculation log that fails. Then we compile the partial
  574. # graph and resume function.
  575. #
  576. # Logging invariants:
  577. # 1. No logs need to be made if Unsupported escapes symbolic_convert.py. Python's default exception printing will
  578. # print out all of the necessary information and no partial compilation will be attempted.
  579. # 2. log_graph_break should be called as soon as Unsupported is caught and we determined we want to partial compile.
  580. # This always happens on the first compilation, NOT the restart handling this graph
  581. # 3. Any compile_subgraph call should be preceded immediately by a log in the form of "... triggered compile".
  582. def generic_jump(
  583. truth_fn: Callable[[object], bool], push: bool
  584. ) -> Callable[[InstructionTranslatorBase, Instruction], None]:
  585. def raise_jump_graph_break(value: VariableTracker) -> NoReturn:
  586. unimplemented(
  587. gb_type="Data-dependent branching",
  588. context=f"attempted to jump with {value}",
  589. explanation="Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). "
  590. "Dynamo does not support tracing dynamic control flow.",
  591. hints=[
  592. *graph_break_hints.FUNDAMENTAL,
  593. "Use `torch.cond` to express dynamic control flow.",
  594. ],
  595. )
  596. def jump_graph_break(
  597. self: InstructionTranslatorBase,
  598. inst: Instruction,
  599. value: VariableTracker,
  600. extra_msg: str = "",
  601. ) -> None:
  602. assert self.should_compile_partial_graph()
  603. exc = None
  604. try:
  605. raise_jump_graph_break(value)
  606. except Unsupported as e:
  607. exc = e
  608. assert exc is not None
  609. # compile a partial subgraph prefix then skip the rest of user code
  610. if self.maybe_has_backedge():
  611. self.raise_loop_graph_break(self.f_code, exc)
  612. self.log_graph_break(
  613. self.code_options,
  614. reason=str(exc),
  615. exc=exc,
  616. )
  617. self.push(value)
  618. log.debug("generic_jump triggered compile")
  619. all_stack_locals_metadata = self.output.compile_subgraph(
  620. self,
  621. reason=GraphCompileReason(
  622. f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()]
  623. ),
  624. stack_pops=1,
  625. )
  626. self.pop()
  627. if_next = self.create_call_resume_at(
  628. self.next_instruction,
  629. all_stack_locals_metadata,
  630. )
  631. if push:
  632. self.push(value)
  633. assert inst.target is not None
  634. if_jump = self.create_call_resume_at(
  635. inst.target,
  636. all_stack_locals_metadata,
  637. )
  638. if sys.version_info >= (3, 13):
  639. # 3.13 requires stack[-1] to be bool type
  640. self.output.add_output_instructions([create_instruction("TO_BOOL")])
  641. jump_inst = create_instruction(inst.opname, target=if_jump[0])
  642. jump_inst.copy_positions(inst)
  643. self.output.add_output_instructions([jump_inst] + if_next + if_jump)
  644. def inner(self: InstructionTranslatorBase, inst: Instruction) -> None:
  645. value: VariableTracker = self.pop()
  646. if (
  647. config.rewrite_assert_with_torch_assert
  648. and _detect_and_normalize_assert_statement(self, truth_fn, push)
  649. ):
  650. error_msg: VariableTracker = self.pop()
  651. # Skip over things like `assert True`
  652. if value.is_python_constant():
  653. if bool(value.as_python_constant()):
  654. return self.jump(inst)
  655. elif self.should_compile_partial_graph():
  656. jump_graph_break(self, inst, value)
  657. else:
  658. unimplemented(
  659. gb_type="Data-dependent assertion failed (cannot compile partial graph)",
  660. context=f"value: {value}",
  661. explanation="Dynamo has determined when encountering a data-dependent assert failure "
  662. "that it should not compile the partial graph.",
  663. hints=[
  664. *graph_break_hints.FUNDAMENTAL,
  665. "Use `torch._assert()` to raise a hard AssertionError when the check fails. "
  666. "This error will propagate back the user code "
  667. "that called the compiled function (i.e. Dynamo will not trace any exception handling).",
  668. "Remove the assert statement.",
  669. "Move the assert statement outside of any context managers in order to graph break with "
  670. "partial graph compilation (if fullgraph=False).",
  671. ],
  672. )
  673. # TODO maybe should respect DtoH sync intention of users later??
  674. # Manually insert torch._assert_async instead of python assert and jump over
  675. # assert related instructions as we don't need them anymore.
  676. # if we see Tensor as assert statement, no need to call scalar_tensor
  677. if value.is_tensor():
  678. self.output.create_proxy(
  679. "call_function",
  680. torch._assert_async,
  681. *proxy_args_kwargs((value, error_msg), {}),
  682. )
  683. self.jump(inst)
  684. return
  685. if isinstance(value, SymNodeVariable):
  686. # if the assertion is normal shape expression.
  687. # just install guard and bail out.
  688. sym_expr = value.sym_num
  689. if not isinstance(sym_expr, torch.SymBool):
  690. sym_expr = sym_expr != 0
  691. result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
  692. if not result:
  693. unimplemented(
  694. gb_type="Assertion failed on symbolic shapes",
  695. context=str(sym_expr),
  696. explanation="",
  697. hints=[*graph_break_hints.USER_ERROR],
  698. )
  699. self.jump(inst)
  700. return
  701. scalar_to_tensor_proxy = self.output.create_proxy(
  702. "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
  703. )
  704. scalar_to_tensor = wrap_fx_proxy(
  705. self,
  706. scalar_to_tensor_proxy,
  707. example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
  708. )
  709. self.output.create_proxy(
  710. "call_function",
  711. torch._assert_async,
  712. *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
  713. )
  714. self.jump(inst)
  715. return
  716. if value.is_python_constant():
  717. # ConstDictVariable is optimized to be very lazy about insertion of
  718. # guards, so we have to manually insert a SEQUENCE_LENGTH guard
  719. # here.
  720. if isinstance(value, ConstDictVariable) and value.source:
  721. install_guard(value.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
  722. if truth_fn(value.as_python_constant()):
  723. if push:
  724. self.push(value)
  725. self.jump(inst)
  726. elif value.is_tensor() and self.should_compile_partial_graph():
  727. jump_graph_break(self, inst, value)
  728. elif isinstance(value, NNModuleVariable):
  729. # Equivalent of "self.nn_module is not None"
  730. mod = self.output.get_submodule(value.module_key)
  731. if truth_fn(mod):
  732. if push:
  733. self.push(value)
  734. self.jump(inst)
  735. elif isinstance(value, UserDefinedObjectVariable):
  736. try:
  737. x = value.var_getattr(self, "__bool__") # type: ignore[arg-type]
  738. except exc.ObservedAttributeError:
  739. exc.handle_observed_exception(self)
  740. # if __bool__ is missing, trying __len__ to infer a truth value.
  741. try:
  742. x = value.var_getattr(self, "__len__") # type: ignore[arg-type]
  743. except exc.ObservedAttributeError:
  744. exc.handle_observed_exception(self)
  745. x = None
  746. # __bool__ or __len__ is function
  747. if isinstance(x, UserMethodVariable):
  748. result = x.call_function(self, [], {}) # type: ignore[arg-type, assignment]
  749. method_name = getattr(getattr(x, "fn", None), "__name__", None)
  750. if result.is_python_constant():
  751. result_value = result.as_python_constant()
  752. if method_name == "__bool__" and not isinstance(result_value, bool):
  753. msg = variables.ConstantVariable.create(
  754. f"__bool__ should return bool, returned {type(result_value).__name__}"
  755. )
  756. exc.raise_observed_exception(TypeError, self, args=[msg])
  757. if isinstance(result_value, (bool, int)) and truth_fn(result_value):
  758. if push:
  759. self.push(value)
  760. self.jump(inst)
  761. elif isinstance(result, SymNodeVariable):
  762. if result.evaluate_expr():
  763. if push:
  764. self.push(value)
  765. self.jump(inst)
  766. else:
  767. unimplemented(
  768. gb_type="Data-dependent branching with non-constant __bool__",
  769. context=f"method: {x}, result: {result}",
  770. explanation="Attempted to perform data-dependent branching on a user-defined "
  771. "object with a __bool__ method that did not return a constant.",
  772. hints=[],
  773. )
  774. # __bool__ or __len__ is non-function or not existed in the user defined object
  775. else:
  776. if truth_fn(True):
  777. if push:
  778. self.push(value)
  779. self.jump(inst)
  780. elif not value.is_tensor() and value.has_unpack_var_sequence(self):
  781. if truth_fn(len(value.unpack_var_sequence(self))):
  782. if push:
  783. self.push(value)
  784. self.jump(inst)
  785. elif isinstance(value, SymNodeVariable):
  786. try:
  787. # if the user is branching on a SymBool, guard on it
  788. # if the user has code like:
  789. # if size:
  790. # ...
  791. # then they are just testing truthiness: guard that the expr != 0
  792. if isinstance(value.sym_num, torch.SymBool):
  793. eval_result = value.evaluate_expr(self.output)
  794. else:
  795. eval_result = guard_bool(value.sym_num != 0)
  796. except exc.UserError as e:
  797. if self.should_compile_partial_graph():
  798. return jump_graph_break(self, inst, value, extra_msg=f"\n{e}")
  799. raise
  800. if truth_fn(eval_result):
  801. if push:
  802. self.push(value)
  803. self.jump(inst)
  804. elif isinstance(value, variables.BackwardHookVariable):
  805. if truth_fn(True):
  806. if push:
  807. self.push(value)
  808. self.jump(inst)
  809. else:
  810. from .source import is_constant_source
  811. if value.source is not None and is_constant_source(value.source):
  812. if truth_fn(value.get_real_value()): # type: ignore[attr-defined]
  813. if push:
  814. self.push(value)
  815. self.jump(inst)
  816. else:
  817. raise_jump_graph_break(value)
  818. return inner
  819. def _reconstruct_block_stack(
  820. tx: InstructionTranslatorBase, cg: PyCodegen, cleanup: list[Instruction]
  821. ) -> None:
  822. """Generates bytecode to restore the block stack for running the unsupported instruction
  823. in the compiled bytecode."""
  824. # Reconstruct the context variable CLASS in the block stack
  825. all_txes: list[InstructionTranslatorBase] = []
  826. cur_tx: Optional[InstructionTranslatorBase] = tx
  827. while cur_tx is not None:
  828. all_txes.append(cur_tx)
  829. cur_tx = cur_tx.parent
  830. for tx in reversed(all_txes):
  831. for b in tx.block_stack:
  832. # Don't exit any modes we have entered,
  833. # output bytecode will mutate the tf mode stack accordingly
  834. if isinstance(b.with_context, TorchFunctionModeVariable):
  835. cg.extend_output(
  836. b.resume_fn().try_except_torch_function_mode(
  837. cg.code_options, cleanup
  838. )
  839. )
  840. continue
  841. assert b.with_context is not None
  842. assert isinstance(b.with_context, (ContextWrappingVariable))
  843. b.with_context.reconstruct_type(cg)
  844. cg.extend_output(b.resume_fn().try_finally(cg.code_options, cleanup))
  845. # NOTE: for the purposes of nested graph breaks, break_graph_if_unsupported only works on instructions
  846. # with 0 or 1 outputs. If you wish to support bytecodes with 2+ outputs, either rewrite the instruction
  847. # into a sequence of simpler instructions, or file an issue for consultation.
  848. # There is an additional requirement that if the instruction causes a function call, e.g. STORE_ATTR,
  849. # nothing should happen to the result of the function call.
  850. def break_graph_if_unsupported(
  851. *, push: bool, msg_prefix: str
  852. ) -> Callable[
  853. [Callable[..., None]], Callable[[InstructionTranslatorBase, Instruction], None]
  854. ]:
  855. def decorator(
  856. inner_fn: Callable[..., None],
  857. ) -> Callable[[InstructionTranslatorBase, Instruction], None]:
  858. @functools.wraps(inner_fn)
  859. def wrapper(self: InstructionTranslatorBase, inst: Instruction) -> None:
  860. prev_push = self.current_instruction_push
  861. self.current_instruction_push = push
  862. speculation = self.speculate()
  863. if speculation.failed(self):
  864. # no need to restore current_instruction_push if speculation failed
  865. assert speculation.reason is not None
  866. return handle_graph_break(self, inst, speculation.reason)
  867. try:
  868. return inner_fn(self, inst)
  869. except Unsupported as excp:
  870. if self.active_generic_context_managers:
  871. # raise original graph break if fullgraph/error_on_graph_break=True
  872. if self.one_graph or self.error_on_graph_break:
  873. raise
  874. # We don't support graph break under GenericContextWrappingVariable,
  875. # If there is, we roll back to the checkpoint and fall back.
  876. excp.remove_from_stats()
  877. unimplemented(
  878. gb_type="Graph break under GenericContextWrappingVariable",
  879. context=f"Active generic context managers: {self.active_generic_context_managers}",
  880. explanation="Attempted to graph break in an active context manager(s) that doesn't support graph breaking.",
  881. hints=[
  882. "Move the offending context manager(s) to outside the compiled region.",
  883. *graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK,
  884. ],
  885. from_exc=excp,
  886. )
  887. if excp.skip_frame:
  888. raise
  889. if not self.should_compile_partial_graph():
  890. raise
  891. if self.maybe_has_backedge():
  892. self.raise_loop_graph_break(self.f_code, excp)
  893. self.log_graph_break(
  894. self.code_options,
  895. reason=f"{msg_prefix}:\n\n{str(excp)}",
  896. exc=excp,
  897. )
  898. excp.remove_from_stats()
  899. excp.add_to_stats("graph_break")
  900. speculation.reason = GraphCompileReason(excp.msg, excp.real_stack)
  901. finally:
  902. self.current_instruction_push = prev_push
  903. speculation.fail_and_restart_analysis(self.error_on_graph_break)
  904. def handle_graph_break(
  905. self: InstructionTranslatorBase,
  906. inst: Instruction,
  907. reason: GraphCompileReason,
  908. ) -> None:
  909. if (
  910. sys.version_info >= (3, 11)
  911. and sys.version_info < (3, 12)
  912. and inst.opname == "CALL"
  913. ):
  914. # stack effect for PRECALL + CALL is split between the two instructions
  915. stack_effect = dis.stack_effect(
  916. dis.opmap["PRECALL"], inst.arg
  917. ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
  918. else:
  919. stack_effect = dis.stack_effect(inst.opcode, inst.arg)
  920. log.debug("%s triggered compile", inst.opname)
  921. all_stack_locals_metadata = self.output.compile_subgraph(
  922. self, reason=reason, stack_pops=int(push) - stack_effect
  923. )
  924. cg = PyCodegen(self.output.root_tx)
  925. cleanup: list[Instruction] = []
  926. _reconstruct_block_stack(self, cg, cleanup)
  927. self.output.add_output_instructions(cg.get_instructions())
  928. del cg
  929. if sys.version_info >= (3, 11) and inst.opname == "CALL":
  930. kw_names = (
  931. self.kw_names.as_python_constant()
  932. if self.kw_names is not None
  933. else ()
  934. )
  935. if len(kw_names) > 0:
  936. # KW_NAMES no longer used in 3.13
  937. assert sys.version_info < (3, 13)
  938. self.output.add_output_instructions(
  939. [create_instruction("KW_NAMES", argval=kw_names)]
  940. )
  941. assert inst.arg is not None
  942. call_insts = create_call_function(inst.arg, False)
  943. call_insts[-1].copy_positions(inst)
  944. self.output.add_output_instructions(call_insts)
  945. else:
  946. # copy instruction, but without exception table data
  947. assert inst.target is None
  948. inst_copy = copy.copy(inst)
  949. inst_copy.exn_tab_entry = None
  950. self.output.add_output_instructions([inst_copy])
  951. self.output.add_output_instructions(cleanup)
  952. self.popn(int(push) - stack_effect)
  953. if push:
  954. self.push(UnknownVariable())
  955. self.output.add_output_instructions(
  956. self.create_call_resume_at(
  957. self.next_instruction,
  958. all_stack_locals_metadata,
  959. )
  960. )
  961. return wrapper
  962. return decorator
  class BytecodeDispatchTableMeta(type):
      """Metaclass that attaches an opcode-indexed ``dispatch_table`` to each class.

      For every opcode name in ``dis.opmap`` the class attribute of the same
      name (e.g. ``LOAD_FAST``) becomes that opcode's handler; opcodes with no
      such attribute get a stub that reports a "Missing bytecode handler"
      graph break. The table is a plain 256-entry list so dispatch is a list
      index (``self.dispatch_table[inst.opcode]``) instead of a getattr by name.
      """

      def __init__(cls: type, name: str, bases: Any, dct: Any) -> None:
          super().__init__(name, bases, dct)  # type: ignore[misc]

          def _missing(opname: str, *args: Any) -> None:
              # Fallback handler. ``args`` is whatever the dispatcher passes
              # (the translator and the instruction).
              unimplemented(
                  gb_type="Missing bytecode handler",
                  context=f"{opname} with args {args}",
                  explanation=f"Dynamo does not know how to handle the bytecode instruction `{opname}`.",
                  hints=[
                      f"Do not trace code that produces the `{opname}` bytecode instruction "
                      "(see https://docs.python.org/3/library/dis.html for bytecode semantics).",
                      *graph_break_hints.SUPPORTABLE,
                  ],
              )

          handler_by_opcode: dict[int, Any] = {}
          for op_name, op_code in dis.opmap.items():
              fallback = functools.partial(_missing, op_name)
              handler_by_opcode[op_code] = getattr(cls, op_name, fallback)

          # One slot per possible 8-bit opcode; codes absent from dis.opmap map to None.
          # pyrefly: ignore [missing-attribute]
          cls.dispatch_table = [handler_by_opcode.get(code) for code in range(256)]
  @dataclasses.dataclass
  class ExceptionStack:
      """
      Exception stack that is shared among all InstructionTranslator instances.
      """

      # Exception handling in CPython is a bit confusing and some of the bytecode
      # have a slightly different behavior than what is documented. While reading
      # the documentation, it is important to notice that the terms "current exception"
      # and "stack" sometimes refer to a C variable with the same name and the
      # exception stack, respectively.
      #
      # The lifetime of an exception is (Python 3.11+):
      # + tx._raise_exception_variable(...) := sets the current_exception variable
      # + PUSH_EXC_INFO := pushes the current_exception to the *exception stack*
      # + POP_EXCEPT := pops TOS from the *exception stack*

      _exc_stack: list[ExceptionVals] = dataclasses.field(default_factory=list)
      _current_exception: Optional[ExceptionVals] = dataclasses.field(default=None)

      def clear_current_exception(self) -> None:
          """Forget the current exception; the stack itself is untouched."""
          self._current_exception = None

      def set_current_exception(self, val: ExceptionVals) -> None:
          """Make ``val`` the current exception.

          Before storing it, ``val.__context__`` is filled in from the exception
          stack (if not already set) and any context cycle through ``val`` is
          broken.
          """
          self._set_context_and_break_context_reference_cycle(val)
          self._current_exception = val

      def move_current_exception_to_stack(self) -> None:
          """Push the current exception onto the stack and clear it (asserts one is set)."""
          assert self._current_exception is not None
          self.append(self._current_exception)
          self.clear_current_exception()

      def get_current_exception(self) -> ExceptionVals:
          """Return the current exception (asserts one is set)."""
          assert self._current_exception is not None
          return self._current_exception

      def _set_context_recursive(
          self, val: ExceptionVals, prev_idx: int
      ) -> ExceptionVals:
          """Fill in ``val.__context__`` from ``self._exc_stack[prev_idx]``.

          An exception whose ``__context__`` is already set (i.e. not a
          ConstantVariable placeholder) is left untouched. Otherwise the stack
          entry at ``prev_idx`` is processed the same way (with ``prev_idx - 1``)
          before it becomes ``val``'s context, so the chain is filled from the
          bottom of the stack up.

          Returns ``val``.
          """
          if (ctx := val.__context__) and type(ctx) is not ConstantVariable:  # type: ignore[union-attr]
              return val
          # Stop once prev_idx walks past the bottom of the stack. Do not use
          # ``len(self._exc_stack) + prev_idx > 0`` here: with two or more
          # entries it also accepts negative indices, which wrap around to the
          # *top* of the stack and can make an older exception's context point
          # at a newer one (a context cycle that does not pass through ``val``,
          # so _break_context_reference_cycle would not remove it).
          if prev_idx >= 0:
              prev = self._exc_stack[prev_idx]
              self._set_context_recursive(prev, prev_idx - 1)
              val.set_context(prev)  # type: ignore[union-attr, arg-type]
          return val

      def _break_context_reference_cycle(self, val: ExceptionVals) -> None:
          # See test_exceptions::test_raise_does_not_create_context_chain_cycle
          # Based on https://github.com/python/cpython/blob/e635bf2e49797ecb976ce45a67fce2201a25ca68/Python/errors.c#L207-L228
          # As noted on CPython, this is O(chain length) but the context chains
          # are usually very small
          o = slow_o = val
          slow_update_toggle = False  # floyd's algorithm for detecting cycle
          while True:
              context = o.__context__  # type: ignore[union-attr]
              if type(context) is ConstantVariable:  # context not set
                  break

              if context is val:
                  # The chain leads back to val: cut it here.
                  o.set_context(CONSTANT_VARIABLE_NONE)  # type: ignore[union-attr, arg-type]
                  break

              o = context  # type: ignore[assignment]
              if o is slow_o:
                  # pre-existing cycle - all exceptions on the path were
                  # visited and checked
                  break

              if slow_update_toggle:
                  # advance the slow pointer on every other step
                  slow_o = slow_o.__context__  # type: ignore[union-attr, assignment]
              slow_update_toggle = not slow_update_toggle

      def _set_context_and_break_context_reference_cycle(
          self, val: ExceptionVals
      ) -> None:
          # set Exception.__context__, starting from the top of the stack
          self._set_context_recursive(val, len(self._exc_stack) - 1)
          self._break_context_reference_cycle(val)

      def pop(self) -> ExceptionVals:
          return self._exc_stack.pop()

      def append(self, val: ExceptionVals) -> None:
          self._exc_stack.append(val)

      def __len__(self) -> int:
          return len(self._exc_stack)

      def __getitem__(self, index: int) -> ExceptionVals:
          return self._exc_stack[index]

      def __str__(self) -> str:
          return f"{self._exc_stack=} - {self._current_exception=}"

      __repr__ = __str__
  1063. class InstructionTranslatorBase(
  1064. metaclass=BytecodeDispatchTableMeta,
  1065. ):
  1066. output: OutputGraph
  1067. symbolic_locals: dict[str, VariableTracker]
  1068. symbolic_globals: dict[str, VariableTracker]
  1069. symbolic_torch_function_state: SymbolicTorchFunctionState
  1070. symbolic_stream_state: SymbolicStreamState
  1071. post_prune_cell_and_freevars: Optional[dict[str, VariableTracker]]
  1072. stack: list[VariableTracker]
  1073. instruction_pointer: Optional[int]
  1074. current_instruction: Instruction
  1075. current_instruction_push: bool
  1076. block_stack: list[BlockStackEntry]
  1077. lineno: int
  1078. kw_names: Optional[ConstantVariable]
  1079. accept_prefix_inst: bool
  1080. prefix_insts: list[Instruction]
  1081. inline_depth: int
  1082. inconsistent_side_effects: bool
  1083. current_speculation: Optional[SpeculationEntry]
  1084. dispatch_table: list[Any]
  1085. exn_vt_stack: ExceptionStack
  1086. exec_recorder: Optional[ExecutionRecorder]
  1087. strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
  1088. start_point: Optional[int]
  1089. # Does this function make no inlined function calls?
  1090. has_no_inlined_calls: bool
  1091. parent: Optional[InstructionTranslatorBase]
  1092. # Does this tx currently have a child tx tracing?
  1093. # Used to correctly implement should_compile_partial_graph
  1094. is_child_tracer_active: bool
  1095. debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
  1096. package: Optional[CompilePackage]
  1097. latest_bytecode_queue: deque[str]
  1098. # Store the latest bytecode before graph_break() call by user
  1099. def mark_inconsistent_side_effects(self) -> None:
  1100. """
  1101. InstructionTranslator has encountered instructions which may cause
  1102. dynamo to see a different version of history from eager
  1103. See: https://github.com/pytorch/pytorch/issues/110765
  1104. """
  1105. self.inconsistent_side_effects = True
  1106. def maybe_has_backedge(self) -> bool:
  1107. # This function employs a heuristic. It does not reliably detect a backedge.
  1108. # The heuristic is straightforward: starting from the current instruction and
  1109. # continuing to the end, if any jump instruction targets an instruction before
  1110. # the current one, there might be a backedge.
  1111. # Python 3.12 introduced changes to bytecode that group common paths in
  1112. # blockstacks (with or try...else) and allow for early returns. Consequently,
  1113. # there can be multiple RETURN_VALUE instructions. Another heuristic is to
  1114. # halt detection upon encountering the first RETURN_VALUE or RETURN_CONST.
  1115. # These heuristics can result in both false positives and negatives, but
  1116. # in either case, the Dynamo code remains valid. For false positives
  1117. # (where an edge is incorrectly marked as a backedge), Dynamo will
  1118. # graph break with a frame skip instead of potentially applying optimizations. For
  1119. # false negatives (where an edge that should be marked as a backedge
  1120. # isn't), multiple graphs may be generated if there's a break in the
  1121. # graph during a for loop. In general, its better to have fewer false
  1122. # negatives so that Dynamo does not skip the whole frame.
  1123. # If any parent tx has a backedge, then return True
  1124. cur_tx: Optional[InstructionTranslatorBase] = self
  1125. while cur_tx is not None:
  1126. cur_offset = cur_tx.current_instruction.offset
  1127. assert cur_tx.instruction_pointer is not None
  1128. for inst in cur_tx.instructions[cur_tx.instruction_pointer :]:
  1129. if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
  1130. break
  1131. if inst.opname in JUMP_OPNAMES:
  1132. jump_offset = inst.argval
  1133. if jump_offset < cur_offset:
  1134. return True
  1135. cur_tx = cur_tx.parent
  1136. return False
  1137. def cellvars(self) -> list[str]:
  1138. return self.code_options["co_cellvars"]
  1139. def freevars(self) -> list[str]:
  1140. return self.code_options["co_freevars"]
  1141. def cell_and_freevars(self) -> list[str]:
  1142. if not hasattr(self, "_cell_and_freevars"):
  1143. self._cell_and_freevars = self.cellvars() + self.freevars()
  1144. return self._cell_and_freevars
  def prune_dead_locals(self) -> None:
      """Drop symbolic locals that are not read from the current instruction on.

      Before pruning, every cell/free-variable entry of ``symbolic_locals`` is
      copied into ``post_prune_cell_and_freevars`` so those references stay
      available even if they are pruned from ``symbolic_locals``.
      """
      # Build the membership set once instead of scanning the
      # cell_and_freevars() list for every local (was O(locals * cells)).
      cell_and_free = set(self.cell_and_freevars())
      # keep cell and freevar references alive
      self.post_prune_cell_and_freevars = {
          k: v for k, v in self.symbolic_locals.items() if k in cell_and_free
      }
      # Only keep the locals that must remain on the stack.
      reads = livevars_analysis(self.instructions, self.current_instruction)
      self.symbolic_locals = {
          k: v for k, v in self.symbolic_locals.items() if k in reads
      }
  def call_function(
      self,
      fn: VariableTracker,
      args: list[VariableTracker],
      kwargs: dict[str, VariableTracker],
  ) -> None:
      """Symbolically call ``fn(*args, **kwargs)`` and push the result.

      Raises:
          AssertionError: if the Python object wrapped by ``fn`` (its ``.fn``
              attribute if present, else its ``.value``) is callable and
              forbidden from tracing.
      """
      assert isinstance(fn, VariableTracker)
      assert isinstance(args, list)
      assert isinstance(kwargs, dict)
      assert all(
          isinstance(x, VariableTracker)
          for x in itertools.chain(args, kwargs.values())
      )
      inner_fn = None
      if hasattr(fn, "value"):
          inner_fn = fn.value
      if hasattr(fn, "fn"):
          inner_fn = fn.fn
      # Compare against None instead of testing truthiness: bool() on an
      # arbitrary wrapped object may run user-defined __bool__/__len__, raise
      # (e.g. ambiguous-truth arrays), or be False for a callable object,
      # which would silently skip the forbidden check.
      if inner_fn is not None and callable(inner_fn) and is_forbidden(inner_fn):
          raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
      self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
  def inline_generator_function(
      self,
      fn: BaseUserFunctionVariable,
      args: Sequence[VariableTracker],
      kwargs: dict[str, VariableTracker],
  ) -> VariableTracker:
      """Redirect the call to the generator "call_function".

      ``fn`` is wrapped in a LocalGeneratorFunctionVariable unless it already is one.
      """
      generator_fn = (
          fn
          if isinstance(fn, LocalGeneratorFunctionVariable)
          else LocalGeneratorFunctionVariable(fn)
      )
      return generator_fn.call_function(self, args, kwargs)  # type: ignore[arg-type]
  def inline_user_function_return(
      self,
      fn: BaseUserFunctionVariable,
      args: Sequence[VariableTracker],
      kwargs: dict[str, VariableTracker],
  ) -> Any:
      """Inline a call to a user-defined function.

      Generator functions go through ``inline_generator_function`` when
      ``config.enable_faithful_generator_behavior`` is on. Everything else is
      inlined with ``InliningInstructionTranslator.inline_call``.
      """
      use_generator_path = config.enable_faithful_generator_behavior and is_generator(
          fn.get_code()
      )
      if not use_generator_path:
          return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
      return self.inline_generator_function(fn, args, kwargs)
  def get_line_of_code_header(self, lineno: Optional[int] = None) -> str:
      """Describe a source location as ``<file>:<lineno> in <co_name>``.

      ``lineno`` defaults to the current line. The function name from
      ``get_funcname`` is appended in parentheses when found, followed by
      the inline depth when this translator is inlining (depth > 0).
      """
      if lineno is None:
          lineno = self.lineno
      depth_suffix = (
          "" if not self.inline_depth > 0 else f" (inline depth: {self.inline_depth})"
      )
      code = self.f_code
      funcname = get_funcname(code.co_filename, lineno)
      name_suffix = f" ({funcname})" if funcname is not None else ""
      return f"{code.co_filename}:{lineno} in {code.co_name}{name_suffix}{depth_suffix}"
  def get_log_starts_line_log_str(self) -> str:
      """Build the trace_source log message for the current line.

      The message is the location header plus the source text of that line
      (read via linecache, trailing whitespace stripped).
      """
      header = self.get_line_of_code_header()
      source_line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
      return f"TRACE starts_line {header}\n" + f" {source_line}"
  1217. def starts_line(self, lineno: int) -> None:
  1218. if self.lineno == lineno:
  1219. return
  1220. self.lineno = lineno
  1221. TracingContext.set_current_loc(
  1222. self.f_code.co_filename, lineno, self.f_code.co_name
  1223. )
  1224. if self.is_trace_source_log_enabled:
  1225. trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
  def step(self) -> bool:
      """Process exactly one instruction.

      Returns:
          False when the caller should stop stepping: there is no instruction
          pointer, the handler left ``self.output.should_exit`` set, a
          return/yield was reached, or a graph break was handled here (via
          ``step_graph_break`` or ``fail_and_restart_analysis``).
          True to keep going, including after an observed exception was
          routed to ``exception_handler``.

      Raises:
          Unsupported / StepUnsupported are re-raised when resuming from a
          checkpoint is not allowed (one_graph, error_on_graph_break, tracing
          a resume prologue, or ``skip_frame``), or when there is no checkpoint.
      """
      # Re-read the error_on_graph_break setting before every instruction.
      self.error_on_graph_break = _get_error_on_graph_break()

      ip = self.instruction_pointer
      if ip is None:
          return False
      self.current_instruction = inst = self.instructions[ip]
      self.instruction_pointer = ip + 1

      if inst.starts_line:
          self.starts_line(inst.starts_line)

      # With an empty Python stack, take a speculation checkpoint here. If the
      # speculation entry reports failure (presumably from an earlier attempt
      # that called fail_and_restart_analysis, see below), graph break at this
      # instruction instead of executing it.
      if (
          not self.stack
          and self.should_compile_partial_graph()
          and self.is_non_empty_graph()
      ):
          self.current_speculation = self.speculate()
          if self.current_speculation.failed(self):
              self.step_graph_break(inst)
              return False

      if self.is_trace_bytecode_log_enabled:
          trace_bytecode_log.debug(
              "TRACE %s %s %s", inst.opname, inst.argval, repr(self.stack)
          )

      # Record a text trace of each executed bytecode in latest_bytecode_queue.
      # (The original note mentions keeping the latest 20 entries and a 2048
      # length limit; any such bound is not visible here, presumably it is
      # the queue's maxlen, so verify before relying on it.)
      if config.verbose:
          try:
              stack_repr = repr(self.stack)
          except ValueError:
              # Handle large integers that exceed sys.int_info.str_digits_check_threshold
              stack_repr = "<self.stack repr truncated due to large integer>"
          self.latest_bytecode_queue.append(
              f"TRACE {inst.opname} {repr(inst.argval)} {stack_repr}"
          )

      self.update_block_stack(inst)

      try:
          # Dispatch to the handler method named after the opcode.
          self.dispatch_table[inst.opcode](self, inst)
          return not self.output.should_exit
      except TensorifyScalarRestartAnalysis:
          raise
      except exc.ObservedException as e:
          # An exception raised by the traced code: let the handler route it.
          self.exception_handler(e)
          return True
      except (ReturnValueOp, YieldValueOp):
          return False
      except (Unsupported, StepUnsupported) as e:
          # More restrictive condition than should_compile_partial_graph:
          # if this condition is true, then we SHOULD NOT attempt to find
          # a previous checkpoint to resume from and try to resume - we should
          # immediately error out.
          # The condition is more restrictive because it may be possible to resume significantly earlier
          # in the code (the most recent speculation point). This happens, for example, in the case
          # of a graph break in a try block.
          if (
              self.one_graph
              or self.error_on_graph_break
              or self.is_tracing_resume_prologue
              or (isinstance(e, Unsupported) and e.skip_frame)
          ):
              if isinstance(e, StepUnsupported):
                  unimplemented(
                      gb_type="cannot resume from torch._dynamo.step_unsupported()",
                      context="",
                      explanation="traced torch._dynamo.step_unsupported(), but Dynamo is instructed "
                      "to error on graph break. This graph break is used for debugging only.",
                      hints=[
                          "Remove the torch._dynamo.step_unsupported() call.",
                          "Make sure fullgraph=False and error_on_graph_break=False.",
                          *graph_break_hints.DYNAMO_BUG,
                      ],
                  )
              raise
          if self.current_speculation is None:
              # No checkpoint to resume from: mark the error as frame-skipping.
              log.debug("empty checkpoint - cannot resume from graph break")
              if isinstance(e, StepUnsupported):
                  unimplemented(
                      gb_type="torch._dynamo.step_unsupported() with empty checkpoint",
                      context="",
                      explanation="traced torch._dynamo.step_unsupported(), but there is no checkpoint "
                      "to step_graph_break from. This graph break is used for debugging only.",
                      hints=[
                          "Remove the torch._dynamo.step_unsupported() call.",
                          "Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some "
                          "line of code that is not in a try/with block, and has an empty Python stack.",
                          *graph_break_hints.DYNAMO_BUG,
                      ],
                      skip_frame=True,
                  )
              assert isinstance(e, Unsupported)
              e.skip_frame = True
              raise
          reason = (
              "Encountered graph break that we cannot resume from. "
              "Compiling up to the previous resumable state, "
              "then skipping the rest of the function. "
              f"Graph break encountered:\n\n{str(e)}"
          )
          self.log_graph_break(
              self.code_options,
              reason=reason,
              exc=e,
          )
          # Mark the checkpoint as failed and restart analysis; on the next
          # pass the failed() check above routes to step_graph_break.
          self.current_speculation.fail_and_restart_analysis(self.error_on_graph_break)
          return False
  # update_block_stack is called by step() before each instruction is dispatched.
  if sys.version_info >= (3, 11):

      def update_block_stack(self, inst: Instruction) -> None:
          """Pop entries from ``self.block_stack`` when ``inst`` has left the top with block."""
          # 3.11+ no longer uses a block stack, but we still keep track of one
          # so that we know which contexts are currently active.
          # For our purposes, all exception table entries with the same target
          # are considered to be part of the same "block".
          # NOTE: we only keep track of with blocks that are not contained in try blocks.
          # This is because we will not create continuation functions on graph breaks in try blocks,
          # but we may for with blocks. We do not push blocks here since
          # with blocks are pushed when handling BEFORE_WITH.
          entry = inst.exn_tab_entry
          if entry:
              # Detect when we have exited the top with block.
              # The with blocks on the block stack are not enclosed in try
              # blocks, so a with block's cleanup code should be in the
              # previous with block (if any).
              if (
                  len(self.block_stack) >= 2
                  and entry.target is not self.block_stack[-1].target
                  and entry.target is self.block_stack[-2].target
              ):
                  # exit the current block
                  self.block_stack.pop()
          else:
              # no longer in any block
              # It is possible for NOPs to be between two instructions
              # in the same block, but the NOPs are not covered by an
              # exception table entry. In this case, assume that we
              # are still in the same block.
              # In 3.12+, JUMP_BACKWARD might also not be covered by
              # an exception table entry, so we also assume that we
              # are still in the same block. It is probably safe to do
              # this in 3.11, even though we haven't encountered this case before.
              # In 3.14+, NOT_TAKEN might also not be covered by an exn table entry.
              if self.block_stack and inst.opname not in (
                  "NOP",
                  "JUMP_BACKWARD",
                  "NOT_TAKEN",
              ):
                  # If we really escape from a block and the current
                  # instruction is not in another block, then there
                  # should be no other nested blocks that we are in.
                  assert len(self.block_stack) == 1
                  self.block_stack.pop()

  else:

      def update_block_stack(self, inst: Instruction) -> None:
          # Before 3.11 there is no per-instruction bookkeeping to do here.
          pass
  1377. @property
  1378. def next_instruction(self) -> Instruction:
  1379. assert self.instruction_pointer is not None
  1380. return self.instructions[self.instruction_pointer]
  def step_graph_break(self, continue_inst: Instruction) -> None:
      """Compile the graph traced so far and emit bytecode that continues at ``continue_inst``.

      Called from ``step`` when the current speculation checkpoint reports
      failure. Requires an empty Python stack and no output instructions yet.

      Without a parent translator, the emitted code restores this frame's
      locals from the compiled subgraph's frame values, then jumps to
      ``continue_inst`` within this frame's original instructions, which are
      appended after the jump.

      With a parent (a nested graph break; requires
      ``config.nested_graph_breaks``), the rest of this frame runs in a resume
      function marked with ``skip_code``. Its result is pushed onto the parent
      frame's stack, and the parent frames continue through
      ``create_call_resume_at``.
      """
      # generate code from checkpoint
      assert not self.output.output_instructions
      assert self.current_speculation is not None
      # NOTE: adding an assert here since it seems like the only place
      # where we call step_graph_break right now is when the stack is empty,
      # so let's enforce that for now.
      assert not self.stack
      # NOTE: if we support non-empty self.stack in the future, the `stack_pops` argument
      # below should be set to the stack length to ensure that the stack is codegen'd
      # for the rest of the function.
      log.debug("step triggered compile")
      all_stack_locals_metadata = self.output.compile_subgraph(
          self,
          reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
      )
      # current frame state
      # cells,
      # [
      #   frame N locals,
      #   frame N-1 stack + locals,
      #   ...,
      #   frame 1 stack + locals,
      # ],
      if self.parent:
          from .eval_frame import skip_code

          # nested graph break
          assert config.nested_graph_breaks
          cg = PyCodegen(self.output.root_tx)

          # codegen cells and frame values only for frame N
          cg.extend_output(
              [
                  *create_copy(2),
                  cg.create_load_const(0),
                  cg.create_binary_subscr(),
                  create_instruction("BUILD_LIST", arg=1),
                  *create_copy(2),
                  cg.create_load_const(0),
                  cg.create_binary_subscr(),
                  create_instruction("BUILD_LIST", arg=1),
              ]
          )
          # No need to fix stack, since stack is assumed to be empty here.
          # Do NOT handle_inactive_ctx because we will be skipping this resume code.
          leaf_resume_code, leaf_resume_name = self.create_resume(
              0, continue_inst, all_stack_locals_metadata[0], [], cg, True, False
          )
          skip_code(leaf_resume_code)

          # Re-enter the parent's active blocks around the resume call;
          # `cleanup` collects the instructions emitted after the call.
          cleanup: list[Instruction] = []
          _reconstruct_block_stack(self.parent, cg, cleanup)

          # current frame state
          # cells,
          # [
          #   frame N locals,
          #   frame N-1 stack + locals,
          #   ...,
          #   frame 1 stack + locals,
          # ], [frame N cells], [frame N locals],
          self.codegen_call_resume([leaf_resume_code], [leaf_resume_name], cg)
          cg.extend_output(cleanup)

          # current frame state
          # cells,
          # [
          #   frame N locals,
          #   frame N-1 stack + locals,
          #   ...,
          #   frame 1 stack + locals,
          # ], leaf_resume result

          # pop frame N cells and locals
          cg.extend_output(
              [
                  *create_copy(2),
                  cg.create_load_const(0),
                  create_instruction("DELETE_SUBSCR"),
                  *create_copy(3),
                  cg.create_load_const(0),
                  create_instruction("DELETE_SUBSCR"),
              ]
          )

          # current frame state
          # cells, frame_values, leaf_resume result

          # extract frame N-1 stack
          num_stack = all_stack_locals_metadata[1].num_stack
          cg.extend_output(
              [
                  *create_copy(2),
                  cg.create_load_const(0),
                  cg.create_binary_subscr(),
                  *create_binary_slice(0, num_stack),
              ]
          )

          # current frame state
          # cells, frame_values, leaf_resume result, frame N-1 stack

          # add the leaf_resume result to frame N-1 stack
          cg.extend_output(
              [
                  *create_swap(2),
                  create_instruction("LIST_APPEND", arg=1),
              ]
          )
          # Keep the parent's symbolic stack and metadata in sync with the
          # extra value just appended.
          self.parent.push(UnknownVariable())
          all_stack_locals_metadata[1].num_stack += 1

          # current frame state
          # cells, frame_values, frame N-1 stack + leaf_resume result

          # remove frame N-1 stack from frame_values
          if num_stack > 0:
              cg.extend_output(
                  # frame_values[0] = frame_values[0][num_stack:]
                  [
                      *create_copy(2),
                      cg.create_load_const(0),
                      cg.create_binary_subscr(),
                      create_dup_top(),
                      *create_binary_slice(num_stack, None),
                      *create_swap(2),
                      cg.create_load_const(0),
                      create_instruction("STORE_SUBSCR"),
                  ]
              )

          # current frame state
          # cells, frame_values, frame N-1 stack + leaf_resume result

          # unpack the stack (need to unpack twice since UNPACK_SEQUENCE unpacks in reverse order)
          cg.extend_output(
              [
                  create_instruction("UNPACK_SEQUENCE", arg=num_stack + 1),
                  create_instruction("BUILD_LIST", arg=num_stack + 1),
                  create_instruction("UNPACK_SEQUENCE", arg=num_stack + 1),
              ]
          )

          # call the remaining resume functions
          # current frame state
          # [frame N-1 cells, ..., frame 1 cells],
          # [
          #   frame N-1 locals,
          #   frame N-2 stack + locals,
          #   ...,
          #   frame 1 stack + locals,
          # ], *(frame N-1 stack), leaf_resume result
          self.output.add_output_instructions(
              cg.get_instructions()
              + self.parent.create_call_resume_at(
                  self.parent.next_instruction, all_stack_locals_metadata[1:]
              )
          )
      else:
          # NOTE: if WithExitFunctionVariable is reconstructed here, then the generated bytecode will be wrong.
          # However, we don't expect this to happen since WithExitFunctionVariable can only be present on the stack,
          # which must be empty in step graph breaks.
          # If we do decide to support step graph breaks with WithExitFunctionVariable in the future, we should
          # either call a skipped resume function as in the nested step graph break case, or reconstruct the
          # proper context manager object from the class (like what we used to do historically in variables/ctx_manager.py).

          # pop cells
          self.output.add_output_instructions(
              [
                  *create_swap(2),
                  create_instruction("POP_TOP"),
              ]
          )
          # load locals from frame values
          cg = PyCodegen(self.output.root_tx)
          self.output.add_output_instructions(
              [
                  cg.create_load_const(-1),
                  cg.create_binary_subscr(),
              ]
          )
          # Store each saved local back into its variable by index.
          for local, idx in all_stack_locals_metadata[-1].locals_names.items():
              self.output.add_output_instructions(
                  [
                      create_dup_top(),
                      cg.create_load_const(idx),
                      cg.create_binary_subscr(),
                      cg.create_store(local),
                  ]
              )
          # Drop the frame-values list, then jump into the original
          # instructions (appended below) at continue_inst.
          self.output.add_output_instructions(
              [
                  create_instruction("POP_TOP"),
                  create_jump_absolute(continue_inst),
                  *self.instructions,
              ]
          )
  def run_ctx_mgr(self) -> Any:
      """Return the context manager that ``run`` enters around tracing.

      The top-level frame summary is deliberately not pushed (set_current_loc
      takes care of it). Entering TracingContext.current_frame still makes
      sure real_stack is attached to exceptions.
      """
      top_level_summary = None
      return TracingContext.current_frame(top_level_summary)
  def run(self) -> None:
      """Trace this frame by calling ``step`` until it returns False.

      This translator is registered with ``self.output`` (push_tx/pop_tx)
      while tracing. Error handling:
      - while tracing a resume-function prologue, any exception is re-raised
        as ResumePrologueTracingError;
      - for an InstructionTranslator, an Unsupported that will not become a
        hard error (neither error_on_graph_break nor one_graph) is logged as
        a graph break before re-raising;
      - errors whose ``msg`` mentions "Data-dependent" get a readable partial
        FX graph attached as ``partial_fx_graph``;
      - with an exec_recorder, the recording is attached as ``exec_record``.
      """
      with self.run_ctx_mgr():
          dump_file(self.f_code.co_filename)
          try:
              self.output.push_tx(self)
              self.start_point = self.instruction_pointer
              try:
                  while self.step():
                      pass
              except Exception as e:
                  if self.is_tracing_resume_prologue:
                      raise ResumePrologueTracingError(
                          "Error while tracing through a Dynamo-generated resume function prologue. "
                          "Errors are not allowed when tracing resume function prologues.\n"
                          f"{type(e).__qualname__}: {str(e)}"
                      ).with_traceback(e.__traceback__) from None
                  raise
          except TensorifyScalarRestartAnalysis:
              raise
          except BackendCompilerFailed:
              raise
          except RuntimeError as e:
              # If the root tx fails to handle the graph break, then the caller (convert_frame)
              # will skip the frame and fall back to eager.
              # This code path happens e.g. for bytecodes we don't support
              # or when we are unable to resume from a graph break.
              if (
                  isinstance(e, Unsupported)
                  and isinstance(self, InstructionTranslator)
                  and not self.error_on_graph_break
                  and not self.one_graph
              ):
                  # log graph break if we won't error
                  reason = (
                      "Failed to handle graph break gracefully. "
                      "Skipping the function and falling back to eager. "
                      f"Graph break encountered:\n\n{str(e)}"
                  )
                  self.log_graph_break(
                      self.code_options,
                      reason=reason,
                      exc=e,
                  )

              # Attach a printable version of the partially traced graph to
              # data-dependent errors.
              if hasattr(e, "msg") and "Data-dependent" in e.msg:
                  readable_graph = torch.fx.GraphModule(
                      self.output.nn_modules, self.output.graph
                  ).print_readable(
                      print_output=False, include_stride=True, include_device=True
                  )
                  e.partial_fx_graph = readable_graph  # type: ignore[attr-defined]
                  raise

              raise
          except Exception as e:
              if self.exec_recorder:
                  e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]

              raise
          finally:
              self.output.pop_tx()
              # Cleanup the outputGraph to delete the held tensors. We perform the
              # cleanup only for InstructionTranslator and not
              # InliningInstructionTranslator. The InliningInstructionTranslator
              # mutates the output object and is restored to original state if
              # there was an exception.
              if isinstance(self, InstructionTranslator):
                  self.output.cleanup()

                  # Note that this call may be redundant if compile_subgraph is
                  # called. This is ok, because calling exit stack close()
                  # twice is not an issue (second stop is a no op).
                  self.output.mark_bytecode_tracing_stop()
  1637. def push(self, val: Optional[VariableTracker]) -> None:
  1638. assert val is None or isinstance(val, VariableTracker), (
  1639. f"push expects VariableTracker, got {typestr(val)}"
  1640. )
  1641. self.stack.append(val) # type: ignore[arg-type]
  1642. def push_many(self, vals: list[VariableTracker]) -> None:
  1643. for val in vals:
  1644. self.push(val)
  1645. def pop(self) -> VariableTracker:
  1646. return self.stack.pop()
  1647. def popn(self, n: int) -> list[VariableTracker]:
  1648. return [*reversed([self.pop() for _ in range(n)])]
  def LOAD_FAST(self, inst: Instruction) -> None:
      """Push the symbolic value of local variable ``inst.argval``.

      Implicit comprehension names beginning with ``.`` are looked up under
      the name with ``.`` replaced by ``implicit``. Locals whose name starts
      with ``__stack`` (continuation-function stack slots) are removed from
      ``symbolic_locals`` once loaded.
      """
      name = inst.argval
      recorder = self.exec_recorder
      if recorder and name in self.f_locals:
          recorder.add_local_var(name, self.f_locals[name])

      try:
          self.push(self.symbolic_locals[name].unwrap())
      except KeyError:
          if not name.startswith("."):
              unimplemented(
                  gb_type="Attempted to read undefined local variable",
                  context=f"LOAD_FAST {name}",
                  explanation=f"Could not find a local variable with name `{name}`",
                  hints=[*graph_break_hints.USER_ERROR],
              )
          else:
              try:
                  # This happens in dict/list comprehensions
                  implicit_name = name.replace(".", "implicit")
                  self.push(self.symbolic_locals[implicit_name])
              except KeyError:
                  unimplemented(
                      gb_type="Attempted to read undefined local variable (implicit)",
                      context=f"LOAD_FAST {name}",
                      explanation=f"Could not find an implicit local variable with name `{name}`",
                      hints=[
                          "This happens in dict/list comprehensions",
                          *graph_break_hints.USER_ERROR,
                      ],
                  )

      # for continuation functions
      if name.startswith("__stack"):
          self.symbolic_locals.pop(name)
  def LOAD_DEREF(self, inst: Instruction) -> None:
      """Push the contents of the cell for cell/free variable ``inst.argval``."""
      name = inst.argval
      assert name in self.cell_and_freevars()
      cell_vt = self.symbolic_locals[name]
      self.push(self.output.side_effects.load_cell(cell_vt))
      recorder = self.exec_recorder
      if recorder and name in self.f_locals:
          recorder.add_local_var(name, self.f_locals[name])
  def STORE_FAST(self, inst: Instruction) -> None:
      """Pop the top of stack into local ``inst.argval``.

      The popped value gets the local's name as a name hint. Storing to
      IS_TRACING_RESUME_PROLOGUE_VARNAME also updates
      ``self.is_tracing_resume_prologue`` from the (bool) constant value.
      """
      name = inst.argval
      value = self.pop()
      value.set_name_hint(name)
      self.symbolic_locals[name] = value
      if name != IS_TRACING_RESUME_PROLOGUE_VARNAME:
          return
      flag = value.as_python_constant()
      assert type(flag) is bool
      self.is_tracing_resume_prologue = flag
  def DELETE_FAST(self, inst: Instruction) -> None:
      """Trace ``del x`` for a fast local by dropping it from symbolic locals."""
      local_name = inst.argval
      del self.symbolic_locals[local_name]
  def STORE_DEREF(self, inst: Instruction) -> None:  # type: ignore[override]
      """Pop TOS and store it into the cell/free variable ``inst.argval``.

      The write goes through the side-effects tracker. If the cell knows its
      local name, the stored value is tagged with it as a name hint.
      """
      var_name = inst.argval
      assert var_name in self.cell_and_freevars()
      cell_vt = self.symbolic_locals[var_name]
      new_contents = self.pop()
      self.output.side_effects.store_cell(cell_vt, new_contents)
      assert isinstance(cell_vt, CellVariable)  # tame mypy
      hint = cell_vt.local_name
      if hint is not None:
          new_contents.set_name_hint(hint)  # type: ignore[attr-defined]
  # LOAD_CLOSURE is traced with the same handler as LOAD_FAST (the cell is
  # looked up by name like any other local).
  LOAD_CLOSURE = LOAD_FAST
  def _load_const(self, inst: Instruction) -> VariableTracker:
      """Return the ConstantVariable for a LOAD_CONST-style instruction.

      When the instruction carries a constant index (``inst.arg``), the
      tracker is memoized in ``self._constants_cache[inst.arg]``, so repeated
      loads of the same constant reuse one object. Without an index, a fresh
      tracker is built every time.
      """
      index = inst.arg
      if index is None:
          return ConstantVariable.create(value=inst.argval)  # type: ignore[return-value]
      cached = self._constants_cache[index]
      # Explicit None check: an empty cache slot is None. Relying on the
      # truthiness of a VariableTracker would be fragile if a tracker type
      # ever defined __bool__/__len__.
      if cached is None:
          cached = ConstantVariable.create(value=inst.argval)
          self._constants_cache[index] = cached  # type: ignore[call-overload]
      return cached
  def LOAD_CONST(self, inst: Instruction) -> None:
      """Push the (possibly cached) ConstantVariable for this constant."""
      const_vt = self._load_const(inst)
      self.push(const_vt)
  def _load_global(self, inst: Instruction) -> None:
      """Push the global (or builtin) named ``inst.argval``.

      Resolution order:
        1. names missing from ``f_globals`` are delegated to ``load_builtin``;
        2. names written earlier in this trace (``symbolic_globals``) are read
           back through the side-effects tracker;
        3. otherwise the real value is wrapped with a ``GlobalSource``.

      When an execution recorder is active, the value is recorded first.
      """
      global_name = inst.argval
      fglobals = self.f_globals
      recorder = self.exec_recorder
      if recorder:
          if global_name not in fglobals:
              assert global_name in self.f_builtins
              recorder.builtins[global_name] = self.f_builtins[global_name]
          else:
              recorder.add_global_var(global_name, fglobals[global_name])

      if global_name not in fglobals:
          return self.load_builtin(inst)

      if global_name in self.symbolic_globals:
          side_effects = self.output.side_effects
          tracked = side_effects[self.symbolic_globals[global_name]]
          self.push(side_effects.load_global(tracked, global_name))
          return

      real_value = fglobals[global_name]
      self.push(VariableTracker.build(self, real_value, GlobalSource(global_name)))
  @functools.cached_property
  def nn_modules_globals_vt(self) -> VariableTracker:
      """VariableTracker for the ``torch.nn.modules.module`` module object.

      Computed once per translator instance. The source is the import alias
      created by ``import_source``.
      """
      mod_name = "torch.nn.modules.module"
      alias_source = self.import_source(mod_name)
      module_obj = _import_module(mod_name)
      return VariableTracker.build(self, module_obj, alias_source)
  def LOAD_GLOBAL(self, inst: Instruction) -> None:
      """Push a global, plus a NULL when the oparg's low bit requests one.

      On 3.11-3.12 the NULL goes on the stack before (below) the global.
      On 3.13+ it goes after (above) it.
      """
      assert inst.arg is not None
      wants_null = inst.arg % 2
      if wants_null and (3, 11) <= sys.version_info < (3, 13):
          self.PUSH_NULL(inst)
      self._load_global(inst)
      if wants_null and sys.version_info >= (3, 13):
          self.PUSH_NULL(inst)
  def STORE_GLOBAL(self, inst: Instruction) -> None:
      """Trace ``global x; x = TOS`` as a tracked side effect.

      The first store to a name registers a fresh sentinel object in
      ``symbolic_globals``. The side-effects tracker then records the write.
      Storing a tensor hook handle (RemovableHandleVariable) graph-breaks.
      """
      new_value = self.pop()
      global_name = inst.argval
      global_src = GlobalSource(global_name)
      if global_name not in self.symbolic_globals:
          # sentinel object
          self.symbolic_globals[global_name] = object()  # type: ignore[assignment]
      side_effects = self.output.side_effects
      tracked = side_effects.track_global_existing(
          global_src, self.symbolic_globals[global_name]
      )
      if isinstance(new_value, RemovableHandleVariable):
          unimplemented(
              gb_type="Storing Tensor hook handle in globals",
              context=global_name,
              explanation="This is not supported.",
              hints=[],
          )
      side_effects.store_global(tracked, global_name, new_value)
  # Cache note: This cache only exists for the duration of this
  # InstructionTranslator - so it should be safe to do.
  @cache_method
  def import_source(self, module_name: str) -> GlobalSource:
      """Create an alias to a module for use in guards.

      The module object is stored under a synthesized alias in the output
      graph's global scope. The alias is also registered in
      ``output.import_sources`` and ``co_names`` (and in ``self.package`` when
      one is set). Returns a ``GlobalSource`` naming that alias.
      """
      if "torch_package" not in module_name:
          module_obj = _import_module(module_name)
          alias = f"__import_{module_name.replace('.', '_dot_')}"
      else:
          imported = torch.package.package_importer._package_imported_modules
          module_obj = imported[module_name]
          alias = (
              module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
          )
      if self.package is not None:
          self.package.add_import_source(alias, module_name)
      self.output.import_sources[alias] = module_name
      scope = self.output.global_scope
      # The alias must be fresh or already bound to this exact module object.
      assert alias not in scope or scope[alias] is module_obj
      scope[alias] = module_obj
      self.output.update_co_names(alias)
      return GlobalSource(alias)
  def resolve_name(self, name: str, package: str, level: int) -> str:
      """Resolve a relative module ``name`` to an absolute dotted name.

      ``level`` is the number of leading dots in the import and ``package``
      is the importing module's package. This mirrors CPython's
      ``importlib._bootstrap._resolve_name``:
      https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902

      Raises:
          ImportError: if ``level`` climbs above the top-level package.
      """
      parts = package.rsplit(".", level - 1)
      if level > len(parts):
          raise ImportError("attempted relative import beyond top-level package")
      base_pkg = parts[0]
      if not name:
          return base_pkg
      return f"{base_pkg}.{name}"
  def calc_package(self) -> str:
      """Compute the importing module's package, as ``__import__`` does.

      Prefers ``__package__``, then ``__spec__.parent``, and finally derives
      it from ``__name__`` (dropping the last component unless the module is
      a package, i.e. has ``__path__``). Follows CPython's
      ``_calc___package__``:
      https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
      """
      pkg = self.f_globals.get("__package__")
      module_spec = self.f_globals.get("__spec__")
      if pkg is None:
          if module_spec is not None:
              return module_spec.parent
          log.warning(
              "can't resolve package from __spec__ or __package__, "
              "falling back on __name__ and __path__",
              stacklevel=3,
          )
          pkg = self.f_globals["__name__"]
          if "__path__" not in self.f_globals:
              pkg = pkg.rpartition(".")[0]
          return pkg
      if module_spec is not None and pkg != module_spec.parent:
          log.warning(
              "__package__ != __spec__.parent (%r != %r)",
              pkg,
              module_spec.parent,
              stacklevel=3,
          )
      return pkg
  def IMPORT_NAME(self, inst: Instruction) -> None:
      """Trace the IMPORT_NAME opcode (``import m`` / ``from m import ...``).

      Pops ``level`` and ``fromlist``, which must both be Python constants.
      When replaying, it reuses a module recorded in ``f_globals``; otherwise
      it performs the real ``__import__``. It then pushes a
      ``PythonModuleVariable`` whose source is a guardable import alias.

      Graph-breaks if the import raises ImportError or the result is not a
      module (``types.ModuleType``/``DummyModule``).
      """
      level, fromlist = self.popn(2)
      level = level.as_python_constant()
      fromlist = fromlist.as_python_constant()
      module_name = inst.argval

      # Are we replaying? if so, load recorded module
      recorded_name = (
          f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
      )
      if recorded_name in self.f_globals:
          value = self.f_globals[recorded_name]
          source = GlobalSource(recorded_name)
      else:
          try:
              value = __import__(
                  module_name,
                  fromlist=fromlist,
                  level=level,
                  globals=self.f_globals,
              )
          except ImportError:
              unimplemented(
                  gb_type="Import failure",
                  context=f"module_name: {module_name}, fromlist: {fromlist}, level={level}",
                  explanation="Failure when attempting to import.",
                  hints=[*graph_break_hints.USER_ERROR],
              )

          # Relative import: turn the name into an absolute one so the guard
          # alias below refers to the module that was actually imported.
          if level != 0:
              pkg = self.calc_package()
              module_name = self.resolve_name(module_name, pkg, level)

          # For __import__, when the name variable is of the form package.module,
          # normally, the top-level package (the name up till the first dot) is
          # returned, not the module named by module_name. However, when a
          # non-empty fromlist argument is given, the module named by name is
          # returned. Therefore, we set the source correctly here.
          if not fromlist:
              top_level_module_name = module_name.partition(".")[0]
              source = self.import_source(top_level_module_name)
          else:
              source = self.import_source(module_name)

      if self.exec_recorder:
          # pyrefly: ignore [unbound-name]
          self.exec_recorder.add_local_mod(recorded_name, value)

      # pyrefly: ignore [unbound-name]
      if istype(value, (types.ModuleType, DummyModule)):
          # pyrefly: ignore [unbound-name]
          self.push(PythonModuleVariable(value, source=source))
      else:
          unimplemented(
              gb_type="Bad import result",
              # pyrefly: ignore [unbound-name]
              context=typestr(value),
              explanation="Import result is not a Python module.",
              hints=[],
          )
  # fb internal 3.12 opcode: traced exactly like IMPORT_NAME (same handler).
  EAGER_IMPORT_NAME = IMPORT_NAME
  def IMPORT_FROM(self, inst: Instruction) -> None:
      """``from m import name``: keep the module on the stack and load ``name``.

      TOS (the module) is duplicated, and attribute ``inst.argval`` is loaded
      from the copy.
      """
      self.DUP_TOP(inst)
      attr_name = inst.argval
      self._load_attr(attr_name)
  # Cache note: This cache only exists for the duration of this
  # InstructionTranslator - so it should be safe to do.
  @cache_method
  def load_builtin_from_argval(self, argval: Any) -> VariableTracker:
      """Build a VariableTracker for the builtin named ``argval``.

      Callable builtins are wrapped with a source that indexes the builtins
      dict stored in ``f_globals``. Non-callables must be builtin constants
      and become ConstantVariables. Graph-breaks if ``argval`` is not in
      ``f_builtins``.
      """
      builtins_map = self.f_builtins
      if argval not in builtins_map:
          unimplemented(
              gb_type="failed to find name in frame builtins",
              context="",
              explanation=f"Failed to find name `{argval}` in frame's builtins.",
              hints=[
                  *graph_break_hints.DYNAMO_BUG,
              ],
          )
      builtin_obj = builtins_map[argval]
      if not callable(builtin_obj):
          assert is_builtin_constant(builtin_obj)
          return ConstantVariable.create(value=builtin_obj)
      dict_source = GlobalSource(self.output.name_of_builtins_dict_key_in_fglobals)
      return VariableTracker.build(
          self, builtin_obj, DictGetItemSource(dict_source, argval)
      )
  def load_builtin(self, inst: Instruction) -> None:
      """Push the VariableTracker for the builtin named ``inst.argval``."""
      builtin_vt = self.load_builtin_from_argval(inst.argval)
      self.push(builtin_vt)
  def jump(self, inst: Instruction | BlockStackEntry) -> None:
      """Transfer control to ``inst.target``.

      Before moving, the number of instructions run since the last jump
      (``instruction_pointer - start_point``) is added to the ``ir_count``
      metric. ``start_point`` is then reset to the new position.
      """
      current_ip = self.instruction_pointer
      segment_start = self.start_point
      assert current_ip is not None
      assert segment_start is not None
      assert inst.target is not None
      get_metrics_context().increment("ir_count", current_ip - segment_start)
      new_ip = self.indexof[inst.target]
      self.instruction_pointer = new_ip
      self.start_point = new_ip
  # Unconditional jumps simply move to the target.
  JUMP_FORWARD = jump
  JUMP_ABSOLUTE = jump
  # Conditional jumps are built by generic_jump (defined elsewhere). The first
  # argument is the predicate tested on TOS. The second presumably selects
  # the *_OR_POP variants, which keep the value on the stack when jumping
  # (TODO confirm against generic_jump).
  POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
  POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
  JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
  JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
  def SETUP_LOOP(self, inst: Instruction) -> None:
      """Open a loop block on the block stack (opcode only exists in python<=3.7)."""
      loop_exit = inst.target
      assert loop_exit is not None
      entry = BlockStackEntry(inst, loop_exit, len(self.stack))
      self.block_stack.append(entry)
  def SETUP_EXCEPT(self, inst: Instruction) -> None:
      """Open a try/except block on the block stack (python<=3.7 only)."""
      handler_target = inst.target
      assert handler_target is not None
      entry = BlockStackEntry(inst, handler_target, len(self.stack))
      self.block_stack.append(entry)
  def POP_BLOCK(self, inst: Instruction) -> None:
      """Close the innermost block-stack entry; the popped entry is discarded."""
      _ = self.block_stack.pop()
  def SETUP_WITH(self, inst: Instruction) -> None:
      """Legacy ``with`` entry opcode; delegates to ``setup_or_before_with``."""
      return self.setup_or_before_with(inst)
  def SETUP_FINALLY(self, inst: Instruction) -> None:
      """Open a try/finally (or handler) block recording the current stack depth."""
      handler_target = inst.target
      assert handler_target is not None
      depth = len(self.stack)
      self.block_stack.append(BlockStackEntry(inst, handler_target, depth))
  def BEGIN_FINALLY(self, inst: Instruction) -> None:
      """Push a raw ``None`` marker (not a VariableTracker) onto the stack."""
      marker = None
      self.push(marker)
  def WITH_CLEANUP_START(self, inst: Instruction) -> None:
      """Legacy ``with`` cleanup: call ``__exit__(None, None, None)``.

      Only the no-exception path is supported: the popped exception slot must
      be ``None``. That ``None`` is pushed back, followed by the result of
      the ``__exit__`` call.
      """
      # Locals renamed so they no longer shadow the builtin `exit` and the
      # `exc` module.
      exit_fn, pending_exc = self.popn(2)
      assert pending_exc is None
      self.push(pending_exc)
      none_args = [CONSTANT_VARIABLE_NONE] * 3
      self.push(exit_fn.call_function(self, none_args, {}))
  def WITH_CLEANUP_FINISH(self, inst: Instruction) -> None:
      """Legacy ``with`` cleanup finish: discard two values, push a raw ``None``."""
      _ = self.popn(2)
      self.push(None)
  def FOR_ITER(self, inst: Instruction) -> None:
      """Advance the iterator on TOS.

      On success the iterator stays on the stack and the next value is pushed
      above it. On exhaustion (a real StopIteration, or an observed
      user-level StopIteration, which is first marked handled), the stack is
      fixed up for the running Python version and control jumps to
      ``inst.target``.
      """
      it = self.pop().realize()
      self.push(it)
      try:
          val = it.next_variable(self)
          self.push(val)
      except (StopIteration, exc.ObservedUserStopIteration) as e:
          if isinstance(e, exc.ObservedUserStopIteration):
              exc.handle_observed_exception(self)

          if sys.version_info >= (3, 12):
              # CPython 3.12 actually jumps to the instruction after the END_FOR
              # and performs the action of END_FOR as part of FOR_ITER. We jump
              # to the END_FOR and run it, so we need to make sure 2 values are
              # on the stack for it to pop.
              self.push(CONSTANT_VARIABLE_NONE)
          else:
              # pop the iterator in Python < 3.12
              self.pop()
          self.jump(inst)
  def _create_exception_type(self, val: VariableTracker) -> VariableTracker:
      """Instantiate ``val`` if it is an exception class; otherwise return it as-is.

      Mirrors CPython, which calls a raised exception *type* with no
      arguments to obtain an instance:
      https://github.com/python/cpython/blob/3.11/Python/ceval.c#L6547-L6549
      """
      is_exc_class = isinstance(
          val, (variables.BuiltinVariable, UserDefinedExceptionClassVariable)
      )
      if not is_exc_class:
          return val
      return val.call_function(self, [], {})  # type: ignore[arg-type]
  def _attach_traceback_to_exception(self, exc: ExceptionVals) -> None:
      """Add a traceback entry for the current frame to ``exc.__traceback__``.

      Modelled on CPython's ``PyTraceBack_Here``. The existing traceback (a
      ConstantVariable or TracebackVariable) is combined with this frame's
      summary via ``TracebackVariable.from_frame_summary``, and the result
      is written back with ``__setattr__``.
      """
      summary = self.frame_summary()
      old_tb = exc.var_getattr(
          # pyrefly: ignore [bad-argument-type]
          self,
          "__traceback__",
      )
      # make pyrefly happy
      assert isinstance(old_tb, (ConstantVariable, TracebackVariable))
      updated_tb = TracebackVariable.from_frame_summary(summary, old_tb)
      exc.call_method(
          # pyrefly: ignore [bad-argument-type]
          self,
          "__setattr__",
          [ConstantVariable("__traceback__"), updated_tb],
          {},
      )
  def _raise_exception_variable(self, val: VariableTracker) -> NoReturn:
      """Raise ``val`` as an observed (user-visible) exception.

      ``val`` may be an exception class (instantiated with no arguments) or
      an exception instance. Per PEP 479, a StopIteration raised inside a
      generator is replaced by a RuntimeError. The exception is recorded as
      the current exception, and the matching Dynamo observed-exception type
      is raised so handler lookup can route it.

      Raises:
          The Dynamo observed-exception type for ``val``; or a graph break
          (via ``unimplemented``) if ``val`` is not an exception.
      """
      # User can raise exception in 2 ways
      # 1) raise exception type - raise NotImplementedError
      # 2) raise exception instance - raise NotImplementedError("foo")

      # 1) when user raises exception type
      val = self._create_exception_type(val)

      # Handle https://peps.python.org/pep-0479/
      # CPython 3.12+ has a specific bytecode instruction (CALL_INTRINSIC_1 3) for this
      if (
          is_generator(self.f_code)
          and isinstance(val, variables.ExceptionVariable)
          and val.exc_type is StopIteration
      ):
          val = variables.BuiltinVariable(RuntimeError).call_function(self, [], {})  # type: ignore[arg-type]

      # Capture the python_stack when the exception is first raised.
      # This preserves the original exception location even if the exception
      # is later re-raised (e.g., in context manager cleanup).
      # ExceptionVariable and UserDefinedExceptionObjectVariable both have
      # a python_stack attribute.
      if (
          self._isinstance_exception(val)
          and getattr(val, "python_stack", None) is None
      ):
          val.python_stack = torch._guards.TracingContext.extract_stack()  # type: ignore[union-attr]

      # 2) when user raises exception instance
      if self._isinstance_exception(val):
          # Save the exception in a global data structure
          self.exn_vt_stack.set_current_exception(val)  # type: ignore[arg-type]
          observed_exception_type = exc.get_dynamo_observed_exception(val.exc_type)  # type: ignore[attr-defined, union-attr]
          # Pass the stored python_stack to preserve the original exception location
          python_stack = getattr(val, "python_stack", None)
          raise observed_exception_type(
              f"raised exception {val}", real_stack=python_stack
          )
      # `exc` in this scope is the Dynamo exc *module*; the graph-break
      # context must describe the offending value instead.
      unimplemented(
          gb_type="Failed to raise exception",
          context=str(val),
          explanation="Attempted to raise a non-Exception type/value.",
          hints=[*graph_break_hints.USER_ERROR],
      )
  def RAISE_VARARGS(self, inst: Instruction) -> None:
      """Trace ``raise``, ``raise X`` and ``raise X from Y``.

      ``inst.arg`` selects the form:
        0 -- bare ``raise``: re-raise the top of the exception stack, or raise
             a RuntimeError if there is no active exception;
        1 -- ``raise X``: X is read from TOS (note: read, not popped);
        2 -- ``raise X from Y``: both are popped, and ``__cause__`` is set.
      For forms 1 and 2, this frame's traceback entry is attached to the
      raised exception in a ``finally`` block, so it happens as the observed
      exception propagates.
      """
      if inst.arg == 0:
          if not len(self.exn_vt_stack):
              msg = ConstantVariable("No active exception to reraise")
              exc.raise_observed_exception(RuntimeError, self, args=[msg])

          # re-raise the previous exception. Here CPython refers to the exception
          # on top of the exception stack
          assert len(self.exn_vt_stack)
          val = self.exn_vt_stack[-1]
          assert self._isinstance_exception(val), val
          self._raise_exception_variable(val)
      elif inst.arg == 1:
          # raise TOS
          # NOTE(review): TOS is not popped here; presumably exception_handler
          # trims the stack to the handler depth anyway -- confirm.
          val = self.stack[-1]  # type: ignore[assignment]
          try:
              self._raise_exception_variable(val)
          finally:
              # Update __traceback__ in the raised exception
              curr_exc = self.exn_vt_stack.get_current_exception()
              self._attach_traceback_to_exception(curr_exc)
      else:
          # raise .. from ...
          from_vt = self.pop()
          val = self.pop()  # type: ignore[assignment]
          try:
              self._raise_exception_variable(val)
          finally:
              # Update __cause__/__suppress_context__ in the raised exception.
              # Only __cause__ is set explicitly here; presumably call_setattr
              # also handles __suppress_context__ -- confirm in ExceptionVariable.
              curr_exc = self.exn_vt_stack.get_current_exception()
              self._attach_traceback_to_exception(curr_exc)
              cause = self._create_exception_type(from_vt)
              curr_exc.call_setattr(self, ConstantVariable("__cause__"), cause)  # type: ignore[arg-type, union-attr, assignment]
  def CLEANUP_THROW(self, inst: Instruction) -> None:
      """Trace CLEANUP_THROW (https://github.com/python/cpython/pull/96010).

      TOS must be an ExceptionVariable. A StopIteration graph-breaks; any
      other exception is re-raised via RERAISE.
      """
      top = self.stack[-1]
      assert isinstance(top, ExceptionVariable)
      if top.exc_type is not StopIteration:
          self.RERAISE(inst)
          return
      unimplemented(
          gb_type="CLEANUP_THROW with StopIteration",
          context="",
          explanation="Received StopIteration when handling generator.throw/close. This is not supported.",
          hints=[],
      )
  def RERAISE(self, inst: Instruction) -> None:
      """Re-raise the exception on the stack.

      See https://docs.python.org/3/library/dis.html#opcode-RERAISE: the
      exception on TOS is re-raised. A non-zero oparg means an extra value
      (used to set ``f_lasti``) is popped as well.

      3.11+ (supported in the narrow case of ``raise ... from None``):
        RERAISE 1 -- pop the exception and the extra value;
        RERAISE 0 -- the exception is left on the stack.
      <3.11: the top three values are popped, and the second one is raised.
      """
      if sys.version_info >= (3, 11):
          val = self.pop()
          if inst.argval:
              self.pop()  # discard the lasti value
          else:
              self.push(val)
      else:
          _exc, val, _tb = self.pop(), self.pop(), self.pop()
      self._raise_exception_variable(val)
  def _isinstance_exception(self, val: VariableTracker) -> TypeIs[ExceptionVals]:
      """Type guard: True iff ``val`` is one of the ``ExceptionVals`` trackers."""
      is_exc = isinstance(val, ExceptionVals)
      return is_exc
  def WITH_EXCEPT_START(self, inst: Instruction) -> None:
      """Call a context manager's ``__exit__`` with the in-flight exception.

      Locates ``__exit__`` and the exception at version-specific stack
      offsets (nothing is popped), then calls
      ``__exit__(type(exc), exc, exc.__traceback__)`` through
      ``self.call_function``. On 3.14+ a non-NULL value four slots below TOS
      is passed as the first argument.
      """
      args: list[VariableTracker] = []
      if sys.version_info >= (3, 11):
          fn_loc = 4 if sys.version_info < (3, 14) else 5
          # At the top of the stack are 4 values:
          #    - TOP = exc_info()
          #    - SECOND = previous exception
          #    - THIRD: lasti of exception in exc_info()
          #    - FOURTH: the context.__exit__ bound method
          #    We call FOURTH(type(TOP), TOP, GetTraceback(TOP)).
          #    Then we push the __exit__ return value.
          # In Python 3.14+, there is a NULL placed between the context.__exit__ bound method and the lasti,
          # that is, fn is now the 5th from TOS.
          assert len(self.stack) >= fn_loc
          fn = self.stack[-fn_loc]
          val = self.stack[-1]
          assert self._isinstance_exception(val)
          typ = BuiltinVariable(val.exc_type)  # type: ignore[attr-defined, union-attr]
          tb = val.var_getattr(
              # pyrefly: ignore[bad-argument-type]
              self,
              "__traceback__",
          )
          if sys.version_info >= (3, 14):
              # The slot between __exit__ and lasti may hold a value to pass
              # as the first argument; a NULL means there is none.
              if not isinstance(self.stack[-4], NullVariable):
                  args.append(self.stack[-4])
      else:
          # Pre-3.11 layout: __exit__ is 7th from TOS; the exception value is
          # second from TOS.
          assert len(self.stack) >= 7
          fn = self.stack[-7]
          val = self.stack[-2]
          assert self._isinstance_exception(val)
          typ = BuiltinVariable(val.exc_type)  # type: ignore[attr-defined]
          tb = val.var_getattr(self, "__traceback__")

      args += [typ, val, tb]
      self.call_function(fn, args, {})
  def exception_handler(self, raised_exception: ObservedException) -> None:
      """Route an observed exception to the handler in this frame, if any.

      3.11+: uses the current instruction's exception-table entry. The stack
      is trimmed to the handler depth, ``lasti`` is pushed if requested, the
      exception is pushed, and control jumps to the handler.
      <3.11: emulates CPython 3.10's block stack. Stale EXCEPT_HANDLER blocks
      are unwound, a fresh EXCEPT_HANDLER block is pushed, then
      (tb, value, type) for the previous and new exceptions are pushed before
      jumping.

      If no handler exists, the stack is cleared and ``raised_exception`` is
      re-raised to the parent translator. At the top-level
      ``InstructionTranslator``, a graph break is emitted instead.
      """
      observed_exn_gb_explanation = (
          "Dynamo found no exception handler at the top-level compiled function "
          "when encountering an exception. Exception will propagate outside the compiled region."
      )

      def bubble_exception_to_interpreter() -> None:
          # Bubble the exception to the interpreter
          curr_exc = self.exn_vt_stack.get_current_exception()
          dynamo_exc = exc.get_dynamo_observed_exception(curr_exc.python_type())
          assert isinstance(raised_exception, dynamo_exc)  # sanity check
          unimplemented(
              gb_type="Observed exception",
              context=f"raised exception {curr_exc.python_type_name()}({curr_exc.args})",  # type: ignore[union-attr]
              explanation=observed_exn_gb_explanation,
              hints=[
                  *graph_break_hints.USER_ERROR,
                  *graph_break_hints.SUPPORTABLE,
              ],
              from_exc=raised_exception,
          )

      if sys.version_info >= (3, 11):
          exn_tab_entry = self.current_instruction.exn_tab_entry
          if exn_tab_entry:
              # Implementation is based on https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt

              # 1) pop values from the stack until it matches the stack depth
              # for the handler
              while len(self.stack) > exn_tab_entry.depth:
                  self.pop()

              # 2) if 'lasti' is true, then push the offset that the exception was raised at
              if exn_tab_entry.lasti:
                  self.push(
                      variables.ConstantVariable(self.current_instruction.offset)
                  )

              # 3) push the exception to the stack
              self.push(self.exn_vt_stack.get_current_exception())

              # 4) jump to the handler
              self.jump(exn_tab_entry)  # type: ignore[arg-type]
          else:
              # No handler found. Bubble the exception to the parent
              # instruction translator. We use special exception for this.
              self.stack.clear()

              # attach traceback to the exception and set it as current exception
              curr_exc = self.exn_vt_stack.get_current_exception()
              self._attach_traceback_to_exception(curr_exc)

              if type(self) is InstructionTranslator:
                  bubble_exception_to_interpreter()
              raise raised_exception
      else:
          if len(self.block_stack):
              # base implementation - https://github.com/python/cpython/blob/3.10/Python/ceval.c#L4455

              block_stack_entry = self.block_stack.pop()

              # Unwind handlers of exceptions that are already being handled:
              # each EXCEPT_HANDLER block owns 3 stack values and one
              # exception-stack entry.
              while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
                  # TODO(anijain2305) - This is not tested .. unable to create a testcase
                  # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
                  self.popn(3)
                  self.exn_vt_stack.pop()
                  if len(self.block_stack) == 0:
                      # No handler found in this frame. Bubble the exception to the parent
                      # instruction translator.
                      self.stack.clear()
                      if type(self) is InstructionTranslator:
                          unimplemented(
                              gb_type="Observed exception (EXCEPT_HANDLER)",
                              context=str(raised_exception),
                              explanation=observed_exn_gb_explanation
                              + " This graph break is unexpected.",
                              hints=[*graph_break_hints.DYNAMO_BUG],
                              from_exc=raised_exception,
                          )

                      raise raised_exception
                  block_stack_entry = self.block_stack.pop()

              exception_var = self.exn_vt_stack.get_current_exception()
              self.exn_vt_stack.move_current_exception_to_stack()

              # 1) pop values from the stack until it matches the stack depth
              # for the handler
              while len(self.stack) > block_stack_entry.stack_index:
                  self.pop()

              # Push a dummy block stack entry of EXCEPT_HANDLER
              # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L1456
              except_handler_inst = Instruction(int(1e6), "EXCEPT_HANDLER", None, 0)
              self.block_stack.append(
                  BlockStackEntry(except_handler_inst, None, len(self.stack))
              )

              # Push old exception
              if len(self.exn_vt_stack) >= 2:
                  old_exception = self.exn_vt_stack[-2]

                  # Push the old exception on to stack - tb, value, type
                  # Traceback is currently mapped to UnknownVariable
                  self.push(variables.UnknownVariable())
                  self.push(old_exception)
                  self.push(variables.BuiltinVariable(old_exception.exc_type))
              else:
                  # Push empty exception tb, value, type
                  self.push(variables.CONSTANT_VARIABLE_NONE)
                  self.push(variables.CONSTANT_VARIABLE_NONE)
                  self.push(variables.CONSTANT_VARIABLE_NONE)

              # Push new exception - tb, val, type
              # Traceback is currently mapped to UnknownVariable
              self.push(variables.UnknownVariable())
              self.push(exception_var)
              self.push(variables.BuiltinVariable(exception_var.exc_type))

              # Jump to target
              self.jump(block_stack_entry)
          else:
              # No handler found. Bubble the exception to the parent
              # instruction translator. We use special exception for this.
              self.stack.clear()
              if type(self) is InstructionTranslator:
                  bubble_exception_to_interpreter()
              raise raised_exception
  def PUSH_EXC_INFO(self, inst: Instruction) -> None:
      """Trace PUSH_EXC_INFO (https://docs.python.org/3/library/dis.html#opcode-PUSH_EXC_INFO).

      Slips the previous exception (top of the exception stack, or a None
      constant if that stack is empty) *underneath* TOS. It then moves the
      "current exception" onto the exception stack, which matches what
      CPython actually does rather than the literal docs.

      Example -- before:
          stack = [..., ConstantVariable(1), ConstantVariable(2)]
          current_exception = TypeError
          exception_stack = [ValueError]
      after:
          stack = [..., ConstantVariable(1), ValueError, ConstantVariable(2)]
          current_exception = None
          exception_stack = [ValueError, TypeError]
      """
      top = self.pop()
      exn_stack = self.exn_vt_stack
      previous: VariableTracker = (
          exn_stack[-1] if len(exn_stack) != 0 else CONSTANT_VARIABLE_NONE
      )
      self.push(previous)
      self.push(top)
      exn_stack.move_current_exception_to_stack()
  def POP_EXCEPT(self, inst: Instruction) -> None:
      """Finish an ``except`` block: drop the handled exception.

      3.11+: pops one value from the value stack (the previous exception
      saved by PUSH_EXC_INFO).
      <3.11: pops the EXCEPT_HANDLER block pushed by ``exception_handler``
      and three stack values.
      In both cases, the exception stack is then popped.

      Raises:
          AssertionError: if, before 3.11, the top of the block stack is not
              an EXCEPT_HANDLER block (indicates a Dynamo bug).
      """
      if sys.version_info >= (3, 11):
          _ = self.pop()
      else:
          assert len(self.block_stack) > 0
          if self.block_stack[-1].inst.opname != "EXCEPT_HANDLER":
              # The two literals are concatenated; the trailing space keeps
              # the sentences apart ("...handling. Top of...").
              raise AssertionError(
                  "Bug in Dynamo tracing of exception handling. "
                  "Top of the block stack is not EXCEPT_HANDLER."
              )
          self.block_stack.pop()
          self.popn(3)
      # This exception is handled and therefore we can clear the error indicator
      assert len(self.exn_vt_stack)
      self.exn_vt_stack.pop()
  def check_if_exc_matches(self) -> bool:
      """Return whether the in-flight exception matches an ``except`` clause.

      Pops the expected type(s) from TOS. On 3.11+ the exception stays on the
      stack (CHECK_EXC_MATCH semantics). Before 3.11 it is popped as well.
      The expected value may be a single class or a tuple of classes.
      Graph-breaks on unsupported expected types or on a non-exception value
      (3.11+).
      """
      assert len(self.stack) >= 2
      expected_exc_types = self.pop()
      if sys.version_info >= (3, 11):
          # CHECK_EXC_MATCH (which is used from 3.11 onwards) does not pop.
          # This is the description from the disassembly doc
          #
          # Performs exception matching for ``except``. Tests whether the ``STACK[-2]``
          # is an exception matching ``STACK[-1]``. Pops ``STACK[-1]`` and pushes the boolean
          # result of the test.
          exc_instance = self.stack[-1]
      else:
          # This is used prior to 3.11 via opcode JUMP_IF_NOT_EXC_MATCH
          # There is no documentation but here is the code pointer that does 2 pops
          # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L3650-L3665
          exc_instance = self.stack.pop()

      # Users can check exception in 3 ways
      # 1) except NotImplementedError --> BuiltinVariable
      # 2) except CustomException --> UserDefinedExceptionClassVariable
      # 3) except (NotImplementedError, AttributeError) -> TupleVariable

      if not isinstance(
          expected_exc_types,
          (
              BuiltinVariable,
              TupleVariable,
              UserDefinedExceptionClassVariable,
              UserDefinedExceptionObjectVariable,
          ),
      ):
          unimplemented(
              gb_type="Exception with bad expected type",
              context=str(expected_exc_types),
              explanation=f"`except ...` has unsupported type {expected_exc_types}.",
              hints=[*graph_break_hints.USER_ERROR],
          )

      if sys.version_info >= (3, 11):
          if not self._isinstance_exception(exc_instance):
              unimplemented(
                  gb_type="Caught non-Exception value",
                  context=str(exc_instance),
                  explanation=f"Except expects to receive an object of Exception type but received {exc_instance}.",
                  hints=[*graph_break_hints.USER_ERROR],
              )

      # Normalize to a list of candidate types.
      if isinstance(expected_exc_types, TupleVariable):
          expected_types = expected_exc_types.items
      else:
          expected_types = [
              expected_exc_types,
          ]

      for expected_type in expected_types:
          if not isinstance(
              expected_type,
              (
                  BuiltinVariable,
                  UserDefinedExceptionObjectVariable,
                  UserDefinedExceptionClassVariable,
              ),
          ):
              # NOTE(review): the explanation reads "expects a non-type", but
              # this fires when the except clause *received* a non-type entry.
              unimplemented(
                  gb_type="Exception with non-type expectation",
                  context=str(expected_type),
                  explanation=f"`except ...` expects a non-type: {expected_type}.",
                  hints=[*graph_break_hints.USER_ERROR],
              )
          # Exception instance: compare its exc_type against the candidate.
          if self._isinstance_exception(exc_instance) and issubclass(
              exc_instance.exc_type,  # type: ignore[union-attr]
              expected_type.fn,  # type: ignore[attr-defined]
          ):
              return True
          # Exception class (possible before 3.11): compare the class itself.
          elif isinstance(exc_instance, variables.BuiltinVariable) and issubclass(
              exc_instance.fn,
              # pyrefly: ignore [invalid-argument, missing-attribute]
              expected_type.fn,
          ):
              return True

      return False
  def CHECK_EXC_MATCH(self, inst: Instruction) -> None:
      """Push a ConstantVariable bool: does the exception match the except clause?"""
      matched = self.check_if_exc_matches()
      self.push(variables.ConstantVariable(matched))
  def JUMP_IF_NOT_EXC_MATCH(self, inst: Instruction) -> None:
      """Pre-3.11 exception matching: jump to the target unless the exception matches."""
      if self.check_if_exc_matches():
          return
      self.jump(inst)
  def COMPARE_OP(self, inst: Instruction) -> None:
      """Binary comparison; the legacy ``"exception match"`` op goes to CHECK_EXC_MATCH."""
      op_name = inst.argval
      if op_name == "exception match":
          self.CHECK_EXC_MATCH(inst)
          return
      # Look the handler up before popping, as the original expression did.
      handler = compare_op_handlers[op_name]
      self.push(handler(self, self.popn(2), {}))
  def GET_ITER(self, inst: Instruction) -> None:
      """Trace ``iter(TOS)`` as a call to the ``iter`` builtin on the popped value."""
      iterable = self.pop()
      self.call_function(BuiltinVariable(iter), [iterable], {})
  @break_graph_if_unsupported(
      push=True,
      msg_prefix="Encountered graph break when attempting to trace CALL_FUNCTION: a call to a regular function, e.g. f(x, y)",
  )
  def CALL_FUNCTION(self, inst: Instruction) -> None:
      """Legacy CALL_FUNCTION: ``inst.argval`` positional args sit above the callable."""
      positional = self.popn(inst.argval)
      callee = self.pop()
      self.call_function(callee, positional, {})
  @break_graph_if_unsupported(
      push=True,
      msg_prefix="Encountered graph break when attempting to trace CALL_FUNCTION_EX: a variadic function call, e.g. f(*args)",
  )
  def CALL_FUNCTION_EX(self, inst: Instruction) -> None:
      """Trace a variadic call ``f(*args)`` / ``f(*args, **kwargs)``.

      Pops the optional kwargs mapping, the args sequence, the callable, and
      (on 3.11+) a NULL, whose position depends on the version. Args that
      are not already a list/tuple tracker are force-unpacked into a tuple;
      a user-defined mapping used as ``**kwargs`` is converted to a dict.
      Graph-breaks on bad flags or unsupported args/kwargs types.
      """
      kwargsvars: VariableTracker
      if inst.argval == 0:
          kwargsvars = ConstDictVariable({})
          argsvars = self.pop()
      elif inst.argval == 1 or sys.version_info >= (3, 14):
          # Python 3.14+ removed the argval and replaced it with a possibly NULL kwargs
          kwargsvars = self.pop()
          if isinstance(kwargsvars, NullVariable):
              kwargsvars = ConstDictVariable({})
          argsvars = self.pop()
      else:
          unimplemented(
              gb_type="Variadic function call with bad flags",
              context=f"flags: {inst.argval}",
              explanation=f"Attempted to call a variadic function (CALL_FUNCTION_EX) with bad flags {inst.argval}",
              hints=[*graph_break_hints.DYNAMO_BUG],
          )

      if sys.version_info >= (3, 13):
          # 3.13 swapped null and callable
          null = self.pop()
          assert isinstance(null, NullVariable)

      fn = self.pop()

      if sys.version_info >= (3, 11) and sys.version_info < (3, 13):
          # 3.11-3.12: the NULL sits below the callable.
          null = self.pop()
          assert isinstance(null, NullVariable)

      # Materialize non-list iterables (e.g. f(*gen)) into a tuple of trackers.
      if not isinstance(
          # pyrefly: ignore [unbound-name]
          argsvars,
          BaseListVariable,
          # pyrefly: ignore [unbound-name]
      ) and argsvars.has_force_unpack_var_sequence(self):
          # pyrefly: ignore [unbound-name]
          argsvars = TupleVariable(argsvars.force_unpack_var_sequence(self))

      # Unpack for cases like fn(**obj) where obj is a map
      # pyrefly: ignore [unbound-name]
      if isinstance(kwargsvars, UserDefinedObjectVariable):
          kwargsvars = BuiltinVariable.call_custom_dict(self, dict, kwargsvars)  # type: ignore[arg-type]

      # pyrefly: ignore [unbound-name]
      if not isinstance(argsvars, BaseListVariable) or not isinstance(
          # pyrefly: ignore [unbound-name]
          kwargsvars,
          ConstDictVariable,
      ):
          unimplemented(
              gb_type="Variadic function call with bad args/kwargs type",
              # pyrefly: ignore [unbound-name]
              context=f"args type: {typestr(argsvars)}, kwargs type: {typestr(kwargsvars)}",
              explanation="Expected args to be a list and kwargs to be a dict",
              hints=[*graph_break_hints.USER_ERROR],
          )

      # Map to a dictionary of str -> VariableTracker
      # pyrefly: ignore [bad-assignment, missing-attribute, unbound-name]
      kwargsvars = kwargsvars.keys_as_python_constant()
      # pyrefly: ignore [bad-argument-type, missing-attribute, unbound-name]
      self.call_function(fn, argsvars.items, kwargsvars)
  2453. @break_graph_if_unsupported(
  2454. push=True,
  2455. msg_prefix="Encountered graph break when attempting to trace CALL_FUNCTION_KW: "
  2456. "a function call with keyword arguments, e.g. f(x=True)",
  2457. )
  2458. def CALL_FUNCTION_KW(self, inst: Instruction) -> None:
  2459. argnames = self.pop()
  2460. args = self.popn(inst.argval)
  2461. fn = self.pop()
  2462. assert isinstance(argnames, TupleVariable) and argnames.is_python_constant()
  2463. argnames = argnames.as_python_constant()
  2464. args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
  2465. kwargs = dict(zip(argnames, kwargs_list))
  2466. assert len(kwargs) == len(argnames)
  2467. self.call_function(fn, args, kwargs)
    def LOAD_METHOD_SUPER(self, inst: Instruction) -> None:
        """Build a super object and load a method from it.

        First traces a 2-argument call of the callable below the top two stack
        values (presumably ``super(cls, self)``, per LOAD_SUPER_ATTR's stack
        layout; confirm against the dispatching handler). Then loads the name
        ``co_names[inst.argval[0]]`` from the result: as a plain attribute
        before 3.11, and through ``LOAD_METHOD`` (which adds the NULL slot) on
        3.11+.
        """
        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
        # inst.argval is indexed like a sequence here; its first entry is a co_names index
        arg = inst.argval[0]
        argval = self.code_options["co_names"][arg]
        if sys.version_info < (3, 11):
            self._load_attr(argval)
        else:
            self.LOAD_METHOD(dataclasses.replace(inst, argval=argval))
  2476. def LOAD_ATTR_SUPER(self, inst: Instruction) -> None:
  2477. self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
  2478. arg = inst.argval[0]
  2479. argval = self.code_options["co_names"][arg]
  2480. self._load_attr(argval)
    def LOAD_METHOD(self, inst: Instruction) -> None:
        """Load ``TOS.<inst.argval>`` and lay it out for a following call.

        The attribute is loaded with ``_load_attr`` (so a method is already
        bound). A placeholder slot is then arranged around it according to the
        interpreter version:
          - 3.13+: attribute, then NULL on top;
          - 3.11/3.12: NULL, then attribute on top;
          - < 3.11: attribute, then a ``None`` sentinel on top (checked by
            ``CALL_METHOD``).
        """
        self._load_attr(inst.argval)
        obj = self.pop()
        if sys.version_info >= (3, 13):
            self.push(obj)
            self.PUSH_NULL(inst)
        elif sys.version_info >= (3, 11):
            # always follow the NULL + fn convention, since if obj
            # is actually a method, self is already bound to it, so it
            # doesn't need to be passed in as an arg.
            self.PUSH_NULL(inst)
            self.push(obj)
        else:
            self.push(obj)
            self.push(None)
  2496. def CALL_METHOD(self, inst: Instruction) -> None:
  2497. args = self.popn(inst.argval)
  2498. dummy = self.pop()
  2499. assert dummy is None
  2500. fn = self.pop()
  2501. self.call_function(fn, args, {})
  2502. def _load_attr(self, attr: Any) -> None:
  2503. obj = self.pop()
  2504. result = BuiltinVariable(getattr).call_function(
  2505. self, # type: ignore[arg-type]
  2506. [obj, ConstantVariable.create(attr)],
  2507. {},
  2508. )
  2509. self.push(result)
  2510. def LOAD_ATTR(self, inst: Instruction) -> None:
  2511. if sys.version_info >= (3, 12):
  2512. # pyrefly: ignore [unsupported-operation]
  2513. if inst.arg % 2:
  2514. self.LOAD_METHOD(inst)
  2515. return
  2516. self._load_attr(inst.argval)
  2517. @break_graph_if_unsupported(
  2518. push=False,
  2519. msg_prefix="Encountered graph break when attempting to trace STORE_ATTR: storing an object's attribute, e.g. x.attr = y",
  2520. )
  2521. def STORE_ATTR(self, inst: Instruction) -> None:
  2522. val, obj = self.popn(2)
  2523. BuiltinVariable(setattr).call_function(
  2524. self, # type: ignore[arg-type]
  2525. [obj, ConstantVariable.create(inst.argval), val],
  2526. {},
  2527. )
  2528. def DELETE_ATTR(self, inst: Instruction) -> None:
  2529. obj = self.pop()
  2530. BuiltinVariable(delattr).call_function(
  2531. self, # type: ignore[arg-type]
  2532. [obj, ConstantVariable.create(inst.argval)],
  2533. {},
  2534. )
    @staticmethod
    def codegen_return_with_pops(
        inst: Instruction, num_stack: int
    ) -> list[Instruction]:
        """
        Emit a return that first discards the cells and frame values left on the stack.

        Debug CPython expects the stack to be empty after the return, and
        compile_subgraph leaves cells and frame values below the current frame's
        stack. Expects the stack to be:
            cells, frame values, current frame stack (0 or 1 values)
        Pops cells and frame values, leaving the current frame stack as TOS,
        then returns.

        Args:
        - inst: the original return instruction. RETURN_VALUE is re-emitted as
          RETURN_VALUE; anything else is emitted as RETURN_CONST with
          ``inst.argval``.
        - num_stack: number of current-frame stack values above the frame values
          (must be 0 or 1).

        Returns: the instruction list, ending with the return instruction.
        """
        insts: list[Instruction] = []
        # NOTE: Debug CPython expects the stack to be empty after the return.
        # Expect the current stack to be in the state
        # cells, frame values, current frame stack (0 or 1 values)
        assert num_stack <= 1
        if num_stack == 1:
            # Move the value to be returned beneath cells/frame values so that
            # the two POP_TOPs below remove exactly those two entries.
            insts.extend(create_swap(3))
        return_inst = (
            create_instruction("RETURN_VALUE")
            if inst.opname == "RETURN_VALUE"
            else create_instruction("RETURN_CONST", argval=inst.argval)
        )
        insts.extend(
            [create_instruction("POP_TOP"), create_instruction("POP_TOP"), return_inst]
        )
        return insts
    def create_resume(
        self,
        idx: int,
        resume_inst: Instruction,
        meta: StackLocalsMetadata,
        resume_codes: list[types.CodeType],
        cg: PyCodegen,
        is_leaf: bool,
        handle_inactive_ctx: bool,
    ) -> tuple[types.CodeType, str]:
        """
        Creates the resume function for the frame corresponding to `self`.

        Expects the TOS to be:
            [frame N cells, ..., frame 1 cells],
            [
                frame N stack + locals,
                ...,
                frame 1 stack + locals
            ]

        Some additional codegen may happen to prepare the frame stack + locals values for the generated resume function:
        - inactive context variables in the stack and locals will be replaced by their types
        - if the frame is a leaf frame, prune dead locals

        Regardless of codegen, the stack will be left in the same state as before.

        Args:
        - idx: depth of this frame: 0 corresponds to the leaf frame (frame N), N-1 to the root frame (frame 1).
        - resume_inst: the instruction that this frame should resume at
        - meta: metadata for this frame returned from OutputGraph.compile_subgraph
        - resume_codes: nested resume code objects generated from previous create_resume calls.
        - cg: codegen object to output to
        - is_leaf: True if `self` corresponds to the leaf frame.
        - handle_inactive_ctx: If True, handles inactive context variables as described above. This is necessary
            iff the resume function is traced

        Returns:
        - (new_code, resume_name): the resume code object and the unique global name it
          was installed under via ``self.output.install_global_unsafe``.
        """
        # Handle inactive context variables.
        # The resume function assumes that context variables are the class, NOT the object.
        # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
        # NOTE: if the unsupported instruction modifies the inactive context variable, it may
        # result in silent incorrectness!
        if handle_inactive_ctx:
            for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
                # Replace the stack var with the context class
                ctx = cast(ContextWrappingVariable, self.stack[j_orig])
                # frames[idx][j] = reconstructed_ctx
                cg.append_output(create_dup_top())
                ctx.reconstruct_type(cg)
                cg.extend_output(
                    [
                        *create_swap(2),
                        cg.create_load_const(idx),
                        cg.create_binary_subscr(),
                        cg.create_load_const(j),
                        create_instruction("STORE_SUBSCR"),
                    ]
                )

            for name, _ in meta.locals_ctx_args:
                # Replace the local with the context class
                ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
                # frames[idx][meta.num_stack + meta.locals_names[name]] = reconstructed_ctx
                cg.append_output(create_dup_top())
                ctx.reconstruct_type(cg)
                cg.extend_output(
                    [
                        *create_swap(2),
                        cg.create_load_const(idx),
                        cg.create_binary_subscr(),
                        cg.create_load_const(meta.num_stack + meta.locals_names[name]),
                        create_instruction("STORE_SUBSCR"),
                    ]
                )

        # If the resume instruction is a jump absolute, then resume
        # at the target instead. This handles the case where we
        # graph break again in a nested function before jump-resuming
        # this frame.
        if is_jump_absolute(resume_inst):
            assert resume_inst.target
            resume_inst = resume_inst.target

        resume_name = unique_id(f"__resume_at_{resume_inst.offset}")

        # More locals may have been pruned in the current/leaf frame
        # after the unsupported instruction (e.g. branch).
        # There should not be any pruning in the other frames since
        # the current instruction there should be a CALL.
        if is_leaf:
            # Keep only locals that are read from resume_inst onwards and are not cells/freevars
            reads = livevars_analysis(self.instructions, resume_inst)
            all_argnames = tuple(
                k
                for k in self.symbolic_locals
                if k in reads and k not in self.cell_and_freevars()
            )
            argnames_null_set = set(meta.locals_null_keys)
            argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
            argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)

            # codegen filter for current frame's locals
            # current stack state: frames
            cg.extend_output(
                [
                    create_dup_top(),
                    cg.create_load_const(idx),
                    cg.create_binary_subscr(),
                    create_dup_top(),
                ]
            )
            for arg in argnames:
                # current stack state: frames, frames[i], *(prev locals), frames[i]
                cg.extend_output(
                    [
                        create_dup_top(),
                        cg.create_load_const(meta.num_stack + meta.locals_names[arg]),
                        cg.create_binary_subscr(),
                        *create_swap(2),
                    ],
                )
            # current stack state: frames, frames[i], *(frame i live locals), frames[i]
            cg.extend_output(
                [
                    create_instruction("POP_TOP"),
                    create_instruction("BUILD_LIST", arg=len(argnames)),
                    *create_swap(2),
                    # frames, frames i live locals, frames[i]
                    *create_binary_slice(meta.num_stack, None, True),
                    # frames[i][num_stack:] = frame i live locals
                ]
            )
            # current stack state: frames
        else:
            # Non-leaf frames keep all locals recorded by compile_subgraph
            argnames = tuple(meta.locals_names.keys())
            argnames_null = tuple(meta.locals_null_keys)

        if sys.version_info < (3, 12):
            assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

        # compile_subgraph did not codegen any NULLs,
        # so we should not count NullVariables
        stack_len = len(self.stack) - len(meta.stack_null_idxes)

        assert self.current_instruction.offset is not None
        new_code: types.CodeType = ContinueExecutionCache.lookup(
            self.f_code,
            self.lineno,
            self.current_instruction.offset,
            resume_inst.offset,
            # pyre: ignore[missing-attribute]
            tuple(b.target.offset for b in self.block_stack),
            stack_len,
            argnames,
            argnames_null,
            tuple(b.resume_fn() for b in self.block_stack),
            handle_inactive_ctx,
            tuple(meta.stack_ctx_args),
            tuple(meta.locals_ctx_args),
            tuple(meta.stack_null_idxes),
            tuple(resume_codes),
            not self.current_instruction_push,
        )

        # Add original GraphModule context to the resume function to handle
        # the case of a graph break while tracing a GraphModule
        orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
            "orig_graphmodule", lambda: None
        )()
        if orig_graphmodule_maybe is not None:
            code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
                orig_graphmodule_maybe
            )

        # add resume function to the global scope
        if new_code.co_freevars:
            # expose code object for debugging purposes
            self.output.install_global_unsafe(resume_name, new_code)
            package_name = None
        else:
            # This is safe: we pre-generate a unique name
            self.output.install_global_unsafe(
                resume_name,
                types.FunctionType(new_code, self.f_globals, resume_name),
            )
            package_name = resume_name

        if self.package is not None:
            self.package.add_resume_function(
                new_code, self.f_globals["__name__"], package_name
            )

        counters["resumes"][new_code.co_name] += 1

        return new_code, resume_name
    def create_call_resume_at(
        self,
        inst: Instruction,
        all_stack_locals_metadata: list[StackLocalsMetadata],
    ) -> list[Instruction]:
        """
        Codegen all resume function(s) from the frame stack starting at `self`, call them,
        and return the result.
        Assumes that the unsupported instruction has already been run.

        Expects the TOS to be:
            [frame N cells, ..., frame 1 cells],
            [
                frame N locals,
                frame N-1 stack + locals,
                ...,
                frame 1 stack + locals
            ], *(frame N stack (post-unsupported instruction))

        Leaves the result of calling the resume functions on the stack and returns it
        (empty stack after return).

        Side effects: clears ``self.instruction_pointer``, may update
        ``num_stack`` on entries of ``all_stack_locals_metadata``, and pushes an
        ``UnknownVariable`` onto the deepest non-returning tx when frame N itself
        is returning.

        Args:
        - inst: the instruction of the current (deepest) frame to resume at
        - all_stack_locals_metadata: metadata returned from OutputGraph.compile_subgraph - contains
            metadata such as local names, NULL positions, stack length, etc.

        Returns: the generated instructions, which always end with RETURN_VALUE.
        """

        self.instruction_pointer = None

        cg = PyCodegen(self.output.root_tx)

        # NOTE: We do not need to codegen frames whose resume instruction is RETURN_VALUE
        # We could also do something similar for RETURN_CONST, but a lot more code is necessary
        # since we would need to track RETURN_CONST values and inject the constant in the right places.

        # Filter out tx'es that are resuming on RETURN_*.
        # idxes records each kept tx's original depth (0 == self/frame N).
        txes: list[InstructionTranslatorBase] = []
        idxes: list[int] = []
        resume_insts: list[Instruction] = []
        cur_tx: Optional[InstructionTranslatorBase] = self
        idx = 0
        while cur_tx is not None:
            if cur_tx is self:
                resume_inst = inst
            else:
                resume_inst = cur_tx.next_instruction
            if resume_inst.opname != "RETURN_VALUE":
                txes.append(cur_tx)
                idxes.append(idx)
                resume_insts.append(resume_inst)

            cur_tx = cur_tx.parent
            idx += 1

        # NULLs were not codegen'd by compile_subgraph, so exclude them from the count
        current_num_stack = len(self.stack) - len(
            all_stack_locals_metadata[0].stack_null_idxes
        )

        # Every tx is returning - no need to call a resume function.
        if not txes:
            # Pop everything but TOS, then return the TOS.
            # Frame N's stack must have length >= 1 since it's about to RETURN_VALUE.
            # Frame N actually should have stack length == 1, because debug CPython expects
            # empty stacks after return, but there is no guarantee written down anywhere.
            assert current_num_stack >= 1
            # swap TOS into the cells' slot, then drop everything above it
            cg.extend_output(create_swap(current_num_stack + 2))
            for _ in range(current_num_stack + 1):
                cg.append_output(create_instruction("POP_TOP"))
            cg.append_output(create_instruction("RETURN_VALUE"))
            return cg.get_instructions()

        # Let frame k be the deepest frame where the resume function is not RETURN_VALUE
        # - If k == N, then the frame N stack is prepended to the frame N locals.
        # - If k != N, then frame N's TOS is added to frame k's stack.

        # Rearrange the TOS to be compatible with create_resume and codegen_call_resume:
        #     [
        #         frame N stack + locals,
        #         ...,
        #         frame 1 stack + locals
        #     ]

        # create the stack values that should be moved
        if txes[0] is self:
            # Frame N is non-returning, pack all of frame N's stack to
            # be moved to frame N's frame values
            cg.append_output(create_instruction("BUILD_LIST", arg=current_num_stack))
            # frame N stack is not yet on the frame N's frame values
            stack_insert_idx = 0
            all_stack_locals_metadata[0].num_stack = current_num_stack
        else:
            # Frame N is returning. Let frame k be the deepest non-returning frame.
            # Add frame N's TOS to frame k's stack.
            # pop frame N stack except TOS
            cg.extend_output(create_swap(current_num_stack))
            for _ in range(current_num_stack - 1):
                cg.append_output(create_instruction("POP_TOP"))
            cg.append_output(create_instruction("BUILD_LIST", arg=1))
            # frame k stack is already on frame k's frame values
            stack_insert_idx = all_stack_locals_metadata[idxes[0]].num_stack
            all_stack_locals_metadata[idxes[0]].num_stack += 1
            # keep frame k's symbolic stack length in sync with the extra value
            txes[0].push(UnknownVariable())

        # move the predetermined stack value(s) to the deepest non-returning frame
        cg.extend_output(
            [
                *create_copy(2),
                # frame_values, return_const, frame_values
                cg.create_load_const(idxes[0]),
                cg.create_binary_subscr(),
                *create_binary_slice(stack_insert_idx, stack_insert_idx, True),
                # frame_values[idxes[0]][stack_insert_idx:stack_insert_idx] = frame N stack/[return_const/TOS]
                # frame_values left on top of stack
            ]
        )

        # filter out frame values of skipped tx'es
        # (note: `idx` is reused here as the loop variable over kept depths)
        filter_insts: list[Instruction] = []
        for idx in idxes:
            filter_insts.extend(
                [
                    create_dup_top(),
                    cg.create_load_const(idx),
                    cg.create_binary_subscr(),
                    *create_swap(2),
                ]
            )
        # TOS: cells, frame_values[idxes[0]], ..., frame_values[idxes[...]], frame_values
        filter_insts.extend(
            [
                create_instruction("POP_TOP"),
                create_instruction("BUILD_LIST", arg=len(idxes)),
            ]
        )
        # TOS: cells, filtered frame_values

        cg.extend_output(filter_insts)
        # filter out cells of skipped tx'es using the same instructions in filter_insts,
        # but with cells as TOS instead of frame values
        cg.extend_output(
            [
                *create_swap(2),
                *copy.deepcopy(filter_insts),
                *create_swap(2),
            ]
        )
        # TOS: filtered cells, filtered frame_values

        resume_codes: list[types.CodeType] = []
        resume_names: list[str] = []
        for i, cur_tx in enumerate(txes):
            # `i` (not the original depth) indexes the *filtered* frame values/cells
            resume_code, resume_name = cur_tx.create_resume(
                i,
                resume_insts[i],
                all_stack_locals_metadata[idxes[i]],
                resume_codes,
                cg,
                cur_tx is self,
                True,
            )
            resume_codes.append(resume_code)
            resume_names.append(resume_name)

        self.codegen_call_resume(resume_codes, resume_names, cg)
        cg.append_output(create_instruction("RETURN_VALUE"))
        return cg.get_instructions()
    @staticmethod
    def codegen_call_resume(
        resume_codes: list[types.CodeType], resume_names: list[str], cg: PyCodegen
    ) -> None:
        """
        Calls the provided resume functions.

        Expects the TOS to be in the state:
            [frame N cells, ..., frame 1 cells],
            [
                frame N stack + locals,
                frame N-1 stack + locals,
                ...,
                frame 1 stack + locals
            ]

        Pops the cells and frame values, leaving the result of calling the resume functions on TOS.

        Only the root (frame 1) resume function is called directly. It receives,
        as positional args, the list of the remaining resume functions, the
        remaining frame values, and the unpacked frame 1 stack + locals.

        Args:
        - resume_codes: list of resume function code objects to call, ordered like the
            cells/frame values (deepest frame first, root frame last)
        - resume_names: list of the corresponding names of the resume functions
        - cg: PyCodegen object to output instructions to
        """

        # NOTE: We will load cells as we load resume functions

        # load resume functions except the root's
        cg.extend_output(create_copy(2))
        for i, (name, code) in enumerate(zip(resume_names, resume_codes)):
            if i == len(resume_names) - 1:
                break
            # stack: cells, frames, *(resume 1, ...), cells
            if code.co_freevars:
                # closure-bearing resume functions are rebuilt from cells[i]
                cg.extend_output(
                    [
                        create_dup_top(),
                        cg.create_load_const(i),
                        cg.create_binary_subscr(),
                    ]
                )
                cg.make_function_with_closure(name, code)
            else:
                cg.extend_output(cg.load_function_name(name, False, 0))
            # keep the cells copy on top for the next iteration
            cg.extend_output(create_swap(2))
        cg.extend_output(
            [
                create_instruction("POP_TOP"),
                create_instruction("BUILD_LIST", arg=len(resume_codes) - 1),
            ]
        )

        # stack: cells, frames, [resume 1, ..., resume N - 1]
        # (list order follows resume_codes: resume fn of frame N first, of frame 2 last)

        # load root resume function
        cg.extend_output(create_swap(3))
        if resume_codes[-1].co_freevars:
            cg.extend_output(
                [
                    cg.create_load_const(-1),
                    cg.create_binary_subscr(),
                ]
            )
            cg.make_function_with_closure(resume_names[-1], resume_codes[-1])
            cg.extend_output(
                [
                    *create_rot_n(3),
                ]
            )
        else:
            cg.extend_output(
                [
                    create_instruction("POP_TOP"),
                    *cg.load_function_name(resume_names[-1], False),
                    *create_rot_n(3),
                ]
            )

        # resume 1, [resume N, ..., resume 2], frames
        # (here "resume k" is the resume function of frame k; resume 1 is the root)

        # load top level-frame; final stack state should be:
        # first resume function (+ NULL),
        # [
        #     [resume N, ..., resume 2],
        #     [
        #         frame N stack + locals,
        #         ...,
        #         frame 2 stack + locals,
        #     ], *(frame 1 stack + locals)
        # ]
        cg.extend_output(
            [
                create_dup_top(),
                create_dup_top(),
                # frames, frames, frames
                cg.create_load_const(-1),
                cg.create_binary_subscr(),
                # frames, frames, frames[-1]
                *create_swap(2),
                # frames, frames[-1], frames
                cg.create_load_const(-1),
                create_instruction("DELETE_SUBSCR"),
            ]
        )

        # TOS: resume 1, remaining resumes, frames (popped), frame 1 stack + locals
        cg.extend_output(
            [
                *create_rot_n(3),
                create_instruction("BUILD_LIST", arg=2),
                *create_swap(2),
                # [resumes, frames (popped)], frame 1 stack + locals
                create_instruction("LIST_EXTEND", arg=1),
            ]
        )

        # TOS: resume 1, [remaining resumes, frames, *(frame 1 stack + locals)]
        cg.extend_output(create_call_function_ex(False, True))
  2995. def should_compile_partial_graph(self) -> bool:
  2996. if sys.version_info >= (3, 11):
  2997. # Do not compile if current instruction's block is not the top with block
  2998. entry = self.current_instruction.exn_tab_entry
  2999. if entry and (
  3000. not self.block_stack or entry.target is not self.block_stack[-1].target
  3001. ):
  3002. return False
  3003. return (
  3004. all(b.can_restore() for b in self.block_stack)
  3005. and not self.one_graph
  3006. # Only the leaf tracer's error_on_graph_break should be used
  3007. and (self.is_child_tracer_active or not self.error_on_graph_break)
  3008. and not self.is_tracing_resume_prologue
  3009. and not self.active_generic_context_managers
  3010. # Do not allow nested graph breaks in HOPs
  3011. and self.output.current_tracer.parent is None
  3012. )
  3013. @break_graph_if_unsupported(
  3014. push=False,
  3015. msg_prefix="Encountered graph break when attempting to trace STORE_SUBSCR: trying to store subscript, e.g. x[key] = y",
  3016. )
  3017. def STORE_SUBSCR(self, inst: Instruction) -> None:
  3018. val, obj, key = self.popn(3)
  3019. obj.call_method(self, "__setitem__", [key, val], {})
  3020. def DELETE_SUBSCR(self, inst: Instruction) -> None:
  3021. obj, key = self.popn(2)
  3022. obj.call_method(self, "__delitem__", [key], {})
  3023. def BUILD_TUPLE(self, inst: Instruction) -> None:
  3024. items = self.popn(inst.argval)
  3025. self.push(TupleVariable(items))
  3026. def BUILD_SLICE(self, inst: Instruction) -> None:
  3027. items = self.popn(inst.argval)
  3028. self.push(SliceVariable(items, tx=self)) # type: ignore[arg-type]
  3029. def _can_speculate_comprehension_nested(self) -> bool:
  3030. """Check if comprehension speculation is allowed in nested context.
  3031. For the base class (non-inlined), this always returns False.
  3032. """
  3033. return False
    def _maybe_setup_comprehension_speculation(self, inst: Instruction) -> bool:
        """
        Handle comprehension start for Python 3.12+ BUILD_LIST/BUILD_MAP with argval 0.

        Only acts when ``self._is_comprehension_start()`` reports a comprehension.
        Speculation is set up only for the outermost comprehension
        (``self._comprehension_depth == 0``) and only when the same kind of
        conditions that permit partial-graph compilation hold. When this tracer has a parent,
        ``self._can_speculate_comprehension_nested()`` decides. The comprehension's
        end-of-loop instruction pointer is recorded and the depth counter is
        incremented.

        Returns True if a graph break was triggered and the caller should return early.
        """
        if not (sys.version_info >= (3, 12) and inst.argval == 0):
            return False

        if not self._is_comprehension_start():
            return False

        # NOTE(review): similar to should_compile_partial_graph, but checks
        # error_on_graph_break without the is_child_tracer_active exemption and
        # skips the exception-table check; confirm this difference is intended.
        can_speculate = (
            all(b.can_restore() for b in self.block_stack)
            and not self.one_graph
            and not self.error_on_graph_break
            and not self.is_tracing_resume_prologue
            and not self.active_generic_context_managers
            and self.output.current_tracer.parent is None
        )
        if can_speculate and self.parent is not None:
            can_speculate = self._can_speculate_comprehension_nested()

        # Only set up speculation at depth 0 (outermost comprehension)
        if can_speculate and self._comprehension_depth == 0:
            speculation = self.speculate()
            if speculation.failed(self):
                # A previous attempt at this point failed: graph break instead
                self._handle_comprehension_graph_break(inst)
                return True
            self.current_speculation = speculation

        # Record where this comprehension's loop ends so depth can be unwound there
        end_for_ip = self._find_comprehension_end_for_ip()
        assert end_for_ip >= 0
        self._comprehension_end_for_ips.add(end_for_ip)
        self._comprehension_depth += 1

        return False
  3065. def BUILD_LIST(self, inst: Instruction) -> None:
  3066. if self._maybe_setup_comprehension_speculation(inst):
  3067. return
  3068. items = self.popn(inst.argval)
  3069. self.push(ListVariable(items, mutation_type=ValueMutationNew()))
  3070. def BUILD_SET(self, inst: Instruction) -> None:
  3071. if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
  3072. unimplemented(
  3073. gb_type="missing BUILD_SET handler",
  3074. context="",
  3075. explanation="Missing BUILD_SET bytecode handler (for testing purposes).",
  3076. hints=[],
  3077. )
  3078. items = self.popn(inst.argval)
  3079. new_set = SetVariable(items, mutation_type=ValueMutationNew())
  3080. self.push(new_set)
  3081. def BUILD_LIST_UNPACK(self, inst: Instruction, cls: type = ListVariable) -> None:
  3082. seqs = self.popn(inst.argval)
  3083. items = []
  3084. for seq in seqs:
  3085. try:
  3086. items.extend(seq.force_unpack_var_sequence(self))
  3087. except NotImplementedError:
  3088. unimplemented(
  3089. gb_type="Failed to unpack object for BUILD_LIST_UNPACK",
  3090. context=str(seq),
  3091. explanation=f"{seq} cannot be unpacked into a list for the BUILD_LIST_UNPACK "
  3092. "bytecode (`[*x, *y, ...]`).",
  3093. hints=[*graph_break_hints.USER_ERROR],
  3094. )
  3095. self.push(cls(items, mutation_type=ValueMutationNew()))
  3096. def BUILD_TUPLE_UNPACK(self, inst: Instruction) -> None:
  3097. self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)
  3098. BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
  3099. def BUILD_MAP(self, inst: Instruction) -> None:
  3100. if self._maybe_setup_comprehension_speculation(inst):
  3101. return
  3102. items = self.popn(inst.argval * 2)
  3103. d = dict(zip(items[::2], items[1::2]))
  3104. self.push(ConstDictVariable(d, mutation_type=ValueMutationNew()))
  3105. def BUILD_MAP_UNPACK(self, inst: Instruction) -> None:
  3106. items = self.popn(inst.argval)
  3107. # ensure everything is a dict
  3108. items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items] # type: ignore[arg-type]
  3109. result: dict[Any, Any] = {}
  3110. for x in items:
  3111. assert isinstance(x, ConstDictVariable)
  3112. result.update(x.items)
  3113. self.push(
  3114. ConstDictVariable(
  3115. result,
  3116. mutation_type=ValueMutationNew(),
  3117. )
  3118. )
  3119. BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK
  3120. def BUILD_CONST_KEY_MAP(self, inst: Instruction) -> None:
  3121. keys = self.pop()
  3122. values = self.popn(inst.argval)
  3123. assert isinstance(keys, TupleVariable)
  3124. assert keys.is_python_constant()
  3125. keys = keys.force_unpack_var_sequence(self)
  3126. assert len(keys) == len(values)
  3127. self.push(
  3128. ConstDictVariable(
  3129. dict(zip(keys, values)),
  3130. mutation_type=ValueMutationNew(),
  3131. )
  3132. )
  3133. def MAP_ADD(self, inst: Instruction) -> None:
  3134. k, v = self.popn(2)
  3135. assert inst.argval > 0
  3136. assert inst.arg is not None
  3137. obj = self.stack[-inst.arg].realize()
  3138. assert isinstance(obj, ConstDictVariable)
  3139. obj.call_method(self, "__setitem__", (k, v), {}) # type: ignore[arg-type]
  3140. def SET_ADD(self, inst: Instruction) -> None:
  3141. v = self.pop()
  3142. assert inst.argval > 0
  3143. assert inst.arg is not None
  3144. obj = self.stack[-inst.arg]
  3145. assert isinstance(obj, SetVariable)
  3146. assert obj.is_mutable()
  3147. obj.call_method(self, "add", [v], {}) # type: ignore[arg-type]
  3148. def SET_UPDATE(self, inst: Instruction) -> None:
  3149. v = self.pop()
  3150. assert inst.argval > 0
  3151. assert inst.arg is not None
  3152. obj = self.stack[-inst.arg]
  3153. assert isinstance(obj, SetVariable)
  3154. assert obj.is_mutable()
  3155. obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
  3156. def LIST_APPEND(self, inst: Instruction) -> None:
  3157. v = self.pop()
  3158. assert inst.argval > 0
  3159. assert inst.arg is not None
  3160. obj = self.stack[-inst.arg].realize()
  3161. assert isinstance(obj, ListVariable)
  3162. assert obj.is_mutable()
  3163. self.output.side_effects.mutation(obj)
  3164. obj.items.append(v)
    def MAKE_FUNCTION(self, inst: Instruction) -> None:
        """Create a ``NestedUserFunctionVariable`` from the code object on the stack.

        Before 3.11 the qualified name is on TOS with the code object below it.
        From 3.11 only the code object is popped and the name comes from its
        ``co_qualname``. Before 3.13 the flag bits in ``inst.arg`` select extra
        operands popped beneath the code object, in this order: closure (0x08),
        annotations (0x04), kw-only defaults (0x02), positional defaults (0x01).
        From 3.13 those are attached later by SET_FUNCTION_ATTRIBUTE.
        """
        flags = inst.arg
        if sys.version_info < (3, 11):
            fn_name = self.pop()
        code = self.pop()
        if sys.version_info >= (3, 11):
            # MAKE_FUNCTION behavior actually changed in 3.11, see
            # https://github.com/python/cpython/pull/93189/
            assert hasattr(code.value, "co_qualname")  # type: ignore[attr-defined]
            fn_name = ConstantVariable.create(value=code.value.co_qualname)  # type: ignore[attr-defined]
        defaults = None
        closure = None
        annotations = None
        kwdefaults = None

        if sys.version_info < (3, 13):
            # in 3.13, this is handled in SET_FUNCTION_ATTRIBUTE
            if flags is not None:
                # Pop order mirrors the push order: highest flag bit is nearest TOS
                if flags & 0x08:
                    closure = self.pop()
                if flags & 0x04:
                    annotations = self.pop()
                if flags & 0x02:
                    kwdefaults = self.pop()
                if flags & 0x01:
                    defaults = self.pop()

        self.push(
            NestedUserFunctionVariable(
                fn_name,
                code,
                self.f_globals,
                defaults,
                kwdefaults,
                annotations,
                closure,
            )
        )
def UNPACK_SEQUENCE(self, inst: Instruction) -> None:
    """Handle UNPACK_SEQUENCE (``a, b, c = d``): pop a sequence, push ``inst.argval`` items.

    Tensors are unpacked along their leading index, tensor attributes such as
    ``a.shape`` are unpacked by indexing the FX proxy, and anything else must
    support force_unpack_var_sequence. Items are pushed in reverse so the
    first element ends up on top of the stack. Unsupported objects and length
    mismatches produce graph breaks via ``unimplemented``.
    """
    seq = self.pop()
    if seq.is_tensor():
        val = seq.unpack_var_sequence(self, idxes=range(inst.argval))  # type: ignore[arg-type]
    elif isinstance(seq, GetAttrVariable) and seq.obj.is_tensor():
        # x, y = a.shape
        proxy = getattr(seq.obj.as_proxy(), seq.name)
        val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
    elif seq.has_force_unpack_var_sequence(self):
        val = seq.force_unpack_var_sequence(self)
    else:
        # unimplemented() does not return, so `val` is bound below.
        unimplemented(
            gb_type="Failed to unpack object for UNPACK_SEQUENCE",
            context=str(seq),
            explanation=f"{seq} cannot be unpacked into a list for the UNPACK_SEQUENCE bytecode "
            "(i.e. `a, b, c = d`).",
            hints=[*graph_break_hints.USER_ERROR],
        )
    # pyrefly: ignore [unbound-name]
    if len(val) != inst.argval:
        unimplemented(
            gb_type="Length mismatch when unpacking object for UNPACK_SEQUENCE",
            # pyrefly: ignore [unbound-name]
            context=f"expected length: {inst.argval}, actual: {len(val)}",
            explanation=f"{seq} unpacked to a list for the UNPACK_SEQUENCE bytecode "
            "(i.e. `a, b, c = d`) with unexpected length.",
            hints=[*graph_break_hints.DYNAMO_BUG],
        )
    # Reverse order so the first target is assigned from the top of stack.
    # pyrefly: ignore [unbound-name]
    for i in reversed(val):
        self.push(i)
  3232. def UNPACK_EX(self, inst: Instruction) -> None:
  3233. assert 0 <= inst.argval <= 0xFFFF
  3234. prefix = inst.argval & 0xFF # low byte
  3235. suffix = inst.argval >> 8 # high byte
  3236. seq = self.pop()
  3237. if seq.has_force_unpack_var_sequence(self):
  3238. vals = list(seq.force_unpack_var_sequence(self))
  3239. assert len(vals) >= prefix + suffix
  3240. vals_prefix = vals[:prefix]
  3241. vals_list = vals[prefix : len(vals) - suffix]
  3242. vals_suffix = vals[len(vals) - suffix :]
  3243. for item in reversed(vals_suffix):
  3244. self.push(item)
  3245. self.push(TupleVariable(vals_list))
  3246. for item in reversed(vals_prefix):
  3247. self.push(item)
  3248. else:
  3249. unimplemented(
  3250. gb_type="Failed to unpack object for UNPACK_EX",
  3251. context=str(seq),
  3252. explanation=f"{seq} cannot be unpacked into a list for the UNPACK_EX bytecode.",
  3253. hints=[*graph_break_hints.USER_ERROR],
  3254. )
  3255. @break_graph_if_unsupported(
  3256. push=False, msg_prefix="Encountered intentional debugging graph break"
  3257. )
  3258. def graph_break_on_leaf_function(self, inst: Instruction) -> None:
  3259. if self.has_no_inlined_calls:
  3260. unimplemented(
  3261. gb_type="Forced graph break on leaf function",
  3262. context="",
  3263. explanation="Forced graph break on non-inlining function for "
  3264. "nested graph break testing purposes",
  3265. hints=[
  3266. "Set torch._dynamo.config.debug_force_graph_break_on_leaf_return = False",
  3267. ],
  3268. )
  3269. def NOP(self, inst: Instruction) -> None:
  3270. # Dynamo-specific testing behavior
  3271. if inst.argval == "GRAPH_BREAK_IF_LEAF":
  3272. self.graph_break_on_leaf_function(inst)
  3273. def POP_TOP(self, inst: Instruction) -> None:
  3274. self.pop()
  3275. def ROT_TWO(self, inst: Instruction) -> None:
  3276. a = self.pop()
  3277. b = self.pop()
  3278. self.push(a)
  3279. self.push(b)
  3280. def ROT_THREE(self, inst: Instruction) -> None:
  3281. a = self.pop()
  3282. b = self.pop()
  3283. c = self.pop()
  3284. self.push(a)
  3285. self.push(c)
  3286. self.push(b)
  3287. def ROT_FOUR(self, inst: Instruction) -> None:
  3288. a = self.pop()
  3289. b = self.pop()
  3290. c = self.pop()
  3291. d = self.pop()
  3292. self.push(a)
  3293. self.push(d)
  3294. self.push(c)
  3295. self.push(b)
  3296. def DUP_TOP(self, inst: Instruction) -> None:
  3297. a = self.pop()
  3298. self.push(a)
  3299. self.push(a)
  3300. def DUP_TOP_TWO(self, inst: Instruction) -> None:
  3301. a = self.pop()
  3302. b = self.pop()
  3303. self.push(b)
  3304. self.push(a)
  3305. self.push(b)
  3306. self.push(a)
  3307. def _convert_value(self, value: VariableTracker, flag: int) -> VariableTracker:
  3308. if flag == 1:
  3309. return BuiltinVariable(str).call_function(self, [value], {}) # type: ignore[arg-type]
  3310. elif flag == 2:
  3311. return BuiltinVariable(repr).call_function(self, [value], {}) # type: ignore[arg-type]
  3312. elif flag == 3:
  3313. return BuiltinVariable(ascii).call_function(self, [value], {}) # type: ignore[arg-type]
  3314. return value
  3315. def _format_value(self, fmt_spec: VariableTracker, flags: int) -> None:
  3316. value = self.pop()
  3317. if isinstance(value, SymNodeVariable):
  3318. from torch._dynamo.variables.lazy import (
  3319. LazySymNodeFormatString,
  3320. LazyVariableTracker,
  3321. )
  3322. value = LazyVariableTracker.create(
  3323. LazySymNodeFormatString(value, fmt_spec), source=value.source
  3324. )
  3325. self.push(value)
  3326. return
  3327. value = self._convert_value(value, flags & 0x03)
  3328. fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")
  3329. self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
  3330. def FORMAT_VALUE(self, inst: Instruction) -> None:
  3331. flags = inst.arg
  3332. assert flags is not None
  3333. if (flags & 0x04) == 0x04:
  3334. fmt_spec = self.pop()
  3335. else:
  3336. fmt_spec = ConstantVariable.create("")
  3337. return self._format_value(fmt_spec, flags)
def BUILD_STRING(self, inst: Instruction) -> None:
    """Handle BUILD_STRING: concatenate ``inst.arg`` popped parts into one format string.

    Constant parts contribute a ``"{}"`` placeholder plus the part as a
    positional argument. StringFormatVariable parts contribute their own
    format string, positional ``sym_args`` and keyword ``sym_kwargs``. The
    result is pushed as a single StringFormatVariable. Keyword collisions
    between parts, or parts of any other type, produce a graph break.
    """
    format_string_parts: list[str] = []
    args: list[VariableTracker] = []
    kwargs: dict[str, VariableTracker] = {}
    assert inst.arg is not None
    for part in self.popn(inst.arg):
        if part.is_python_constant():
            format_string_parts.append("{}")
            args.append(part)
        elif isinstance(part, variables.StringFormatVariable):
            format_string_parts.append(part.format_string)
            args.extend(part.sym_args)
            # Merging kwargs with a shared key would silently drop a value.
            if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
                unimplemented(
                    gb_type="BUILD_STRING key conflict",
                    context=f"format_string_parts: {format_string_parts}, kwargs: {kwargs}, part.sym_kwargs: {part.sym_kwargs}",
                    explanation="Failed to build format string due to key conflict",
                    hints=[*graph_break_hints.USER_ERROR],
                )
            kwargs.update(part.sym_kwargs)
        else:
            unimplemented(
                gb_type="BUILD_STRING type error",
                context=str(part),
                explanation="Format string part type is not correct - expected constant or format string.",
                hints=[*graph_break_hints.USER_ERROR],
            )
    self.push(
        variables.StringFormatVariable.create(
            "".join(format_string_parts), args, kwargs
        )
    )
  3370. def IS_OP(self, inst: Instruction) -> None:
  3371. assert inst.argval == 0 or inst.argval == 1
  3372. if inst.argval == 0:
  3373. new_argval = "is"
  3374. else:
  3375. new_argval = "is not"
  3376. new_inst = create_instruction("COMPARE_OP", argval=new_argval)
  3377. self.COMPARE_OP(new_inst)
def CONTAINS_OP(self, inst: Instruction) -> None:
    """Handle CONTAINS_OP: ``left in right`` (argval 0) or ``left not in right`` (argval 1).

    First tries ``right.__contains__(left)``. If that raises an observed
    TypeError, or Dynamo cannot trace it (Unsupported), it falls back to
    inlining impl_CONTAINS_OP_fallback. An Unsupported with ``skip_frame``
    set is re-raised rather than absorbed. For ``not in`` the pushed result
    is negated afterwards via UNARY_NOT.
    """
    assert inst.argval == 0 or inst.argval == 1
    left, right = self.popn(2)
    op = inst.argval
    try:
        self.push(right.call_method(self, "__contains__", [left], {}))
    except (
        # right.__contains__ can raise TypeError
        exc.ObservedTypeError,
        # Ideally we should only capture TypeError here but some VTs don't
        # implement hasattr(vt, "__contains__") entirely
        Unsupported,
    ) as excp:  # object doesn't support __contains__
        # Use __iter__ as fallback
        if isinstance(excp, Unsupported):
            if excp.skip_frame:
                # do not absorb graph break with skip_frame set
                raise
            # The graph break was absorbed, so do not count it in stats.
            excp.remove_from_stats()
        self.push(
            self.inline_user_function_return(
                VariableTracker.build(self, impl_CONTAINS_OP_fallback),
                [left, right],
                {},
            )
        )
    if op == 1:
        self.UNARY_NOT(inst)
  3406. def LIST_EXTEND(self, inst: Instruction) -> None:
  3407. v = self.pop()
  3408. assert inst.argval > 0
  3409. assert inst.arg is not None
  3410. obj = self.stack[-inst.arg]
  3411. assert isinstance(obj, ListVariable)
  3412. assert obj.is_mutable()
  3413. obj.call_method(self, "extend", [v], {}) # type: ignore[arg-type]
  3414. def LIST_TO_TUPLE(self, inst: Instruction) -> None:
  3415. self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {})) # type: ignore[arg-type]
  3416. def STOPITERATION_ERROR(self, inst: Instruction) -> None:
  3417. # wrap the generator body in a try: ... except StopIteration: ... which
  3418. # converts the StopIteration into a RuntimeError
  3419. # https://peps.python.org/pep-0479/
  3420. # https://github.com/python/cpython/pull/99006
  3421. # https://github.com/python/cpython/commit/28187141cc34063ef857976ddbca87ba09a882c2
  3422. val = self.stack[-1]
  3423. assert self._isinstance_exception(val)
  3424. if val.exc_type is StopIteration: # type: ignore[union-attr]
  3425. new_val = variables.BuiltinVariable(RuntimeError).call_function(
  3426. self, # type: ignore[arg-type]
  3427. [ConstantVariable("generator raised StopIteration")],
  3428. {},
  3429. )
  3430. new_val.call_setattr(self, ConstantVariable("__context__"), val) # type: ignore[attr-defined]
  3431. new_val.call_setattr(self, ConstantVariable("__cause__"), val) # type: ignore[attr-defined]
  3432. self.stack[-1] = new_val
  3433. def DICT_MERGE(self, inst: Instruction) -> None:
  3434. v = self.pop()
  3435. assert inst.argval > 0
  3436. assert inst.arg is not None
  3437. obj = self.stack[-inst.arg].realize()
  3438. assert isinstance(obj, ConstDictVariable)
  3439. assert obj.is_mutable()
  3440. obj.call_method(self, "update", [v], {}) # type: ignore[arg-type]
  3441. DICT_UPDATE = DICT_MERGE
  3442. def GEN_START(self, inst: Instruction) -> None:
  3443. self.pop()
  3444. def GET_LEN(self, inst: Instruction) -> None:
  3445. tos = self.stack[-1]
  3446. if tos.is_python_constant():
  3447. self.push(ConstantVariable.create(len(tos.as_python_constant())))
  3448. else:
  3449. self.push(tos.call_method(self, "__len__", [], {}))
  3450. def MATCH_MAPPING(self, inst: Instruction) -> None:
  3451. """
  3452. If STACK[-1] is an instance of collections.abc.Mapping, push True.
  3453. Otherwise, push False
  3454. """
  3455. tos = self.stack[-1]
  3456. self.push(
  3457. self.inline_user_function_return(
  3458. VariableTracker.build(self, impl_IS_MAPPING),
  3459. [tos],
  3460. {},
  3461. )
  3462. )
  3463. def MATCH_SEQUENCE(self, inst: Instruction) -> None:
  3464. tos = self.stack[-1]
  3465. self.push(
  3466. self.inline_user_function_return(
  3467. VariableTracker.build(self, impl_MATCH_SEQUENCE),
  3468. [tos],
  3469. {},
  3470. )
  3471. )
  3472. def MATCH_CLASS(self, inst: Instruction) -> None:
  3473. subject, cls, names = self.popn(3)
  3474. arg = ConstantVariable.create(inst.arg)
  3475. self.push(
  3476. self.inline_user_function_return(
  3477. VariableTracker.build(self, impl_MATCH_CLASS),
  3478. [subject, cls, arg, names],
  3479. {},
  3480. )
  3481. )
  3482. if sys.version_info < (3, 11):
  3483. # for versions < 3.11, also push the boolean result
  3484. tos = self.stack[-1]
  3485. self.push(ConstantVariable.create(not istype(tos, ConstantVariable)))
  3486. def MATCH_KEYS(self, inst: Instruction) -> None:
  3487. keys = self.stack[-1]
  3488. obj = self.stack[-2]
  3489. assert isinstance(keys, TupleVariable)
  3490. self.push(
  3491. self.inline_user_function_return(
  3492. VariableTracker.build(self, impl_MATCH_KEYS), [obj, keys], {}
  3493. )
  3494. )
  3495. if sys.version_info < (3, 11):
  3496. # for versions < 3.11, also push the boolean result
  3497. tos = self.stack[-1]
  3498. self.push(ConstantVariable.create(not istype(tos, ConstantVariable)))
  3499. def LOAD_ASSERTION_ERROR(self, inst: Instruction) -> None:
  3500. self.push(self.load_builtin_from_argval("AssertionError"))
  3501. def LOAD_BUILD_CLASS(self, inst: Instruction) -> None:
  3502. self.push(self.load_builtin_from_argval("__build_class__"))
# Opcode handlers generated by stack_op from the matching `operator` function.
# Unary operators.
UNARY_POSITIVE = stack_op(operator.pos)
UNARY_NEGATIVE = stack_op(operator.neg)
UNARY_NOT = stack_op(operator.not_)
UNARY_INVERT = stack_op(operator.invert)
# Binary operators. NOTE(review): the *_REMAINDER names duplicate *_MODULO,
# presumably so that 3.11+ BINARY_OP (via _binary_op_lookup) can find a
# handler under CPython's NB_REMAINDER naming - confirm.
BINARY_POWER = stack_op(operator.pow)
BINARY_MULTIPLY = stack_op(operator.mul)
BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
BINARY_MODULO = stack_op(operator.mod)
BINARY_REMAINDER = stack_op(operator.mod)
BINARY_ADD = stack_op(operator.add)
BINARY_SUBTRACT = stack_op(operator.sub)
# Subscripting is wrapped so an unsupported x[attr] becomes a graph break.
BINARY_SUBSCR = break_graph_if_unsupported(
    push=True,
    msg_prefix="Encountered graph break when attempting to trace BINARY_SUBSCR: a binary subscript, e.g. x[attr]",
)(stack_op(operator.getitem))
BINARY_LSHIFT = stack_op(operator.lshift)
BINARY_RSHIFT = stack_op(operator.rshift)
BINARY_AND = stack_op(operator.and_)
BINARY_OR = stack_op(operator.or_)
BINARY_XOR = stack_op(operator.xor)
# In-place operators map to the `operator.i*` functions.
INPLACE_POWER = stack_op(operator.ipow)
INPLACE_MULTIPLY = stack_op(operator.imul)
INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
INPLACE_MODULO = stack_op(operator.imod)
INPLACE_REMAINDER = stack_op(operator.imod)
INPLACE_ADD = stack_op(operator.iadd)
INPLACE_SUBTRACT = stack_op(operator.isub)
INPLACE_LSHIFT = stack_op(operator.ilshift)
INPLACE_RSHIFT = stack_op(operator.irshift)
INPLACE_AND = stack_op(operator.iand)
INPLACE_XOR = stack_op(operator.ixor)
INPLACE_OR = stack_op(operator.ior)
  3539. # 3.11 opcodes
  3540. def RESUME(self, inst: Instruction) -> None:
  3541. if inst.arg == 0:
  3542. self.append_prefix_inst(inst)
  3543. self.accept_prefix_inst = False
  3544. else:
  3545. assert not self.accept_prefix_inst
# BINARY_OP only exists on 3.11+, so the handler is only defined there.
if sys.version_info >= (3, 11):

    def BINARY_OP(self, inst: Instruction) -> None:
        """Dispatch the 3.11+ BINARY_OP through _binary_op_lookup, indexed by the oparg."""
        assert inst.arg is not None
        return _binary_op_lookup[inst.arg](self, inst)
  3550. def PRECALL(self, inst: Instruction) -> None:
  3551. pass
  3552. def KW_NAMES(self, inst: Instruction) -> None:
  3553. kw_names = self.code_options["co_consts"][inst.arg]
  3554. assert isinstance(kw_names, tuple)
  3555. for name in kw_names:
  3556. assert isinstance(name, str)
  3557. assert self.kw_names is None
  3558. self.kw_names = ConstantVariable.create(value=kw_names) # type: ignore[assignment]
  3559. def PUSH_NULL(self, inst: Instruction) -> None:
  3560. self.push(NullVariable())
def _call(self, inst: Instruction, call_kw: bool = False) -> None:
    """Shared implementation of CALL (3.11+) and CALL_KW (3.13+).

    Pops ``inst.arg + 2`` entries: two slots holding the callable and either
    NULL or a bound ``self``, followed by the call arguments. Keyword names
    come from the TOS tuple (CALL_KW) or from a preceding KW_NAMES. The last
    ``len(kw_names)`` arguments become keyword arguments. ``self.kw_names``
    is always reset, even if the call raises.
    """
    # see https://docs.python.org/3.11/library/dis.html#opcode-CALL
    # for convention
    if call_kw:
        # TOS is kw_names for CALL_KW instruction
        assert sys.version_info >= (3, 13)
        kw_names = self.pop()
        assert isinstance(kw_names, TupleVariable) and kw_names.is_python_constant()
        kw_names = kw_names.as_python_constant()
    else:
        kw_names = self.kw_names.value if self.kw_names else ()
    assert inst.arg is not None
    # inst.arg call arguments plus the callable and NULL/self slots.
    contents = self.popn(inst.arg + 2)
    if sys.version_info >= (3, 13):
        # NULL and callable swapped
        fn = contents[0]
        args = [] if isinstance(contents[1], NullVariable) else [contents[1]]
    else:
        if isinstance(contents[0], NullVariable):
            fn = contents[1]
            # pyrefly: ignore [implicit-any]
            args = []
        else:
            # contents[1] is the bound self and becomes the first argument.
            fn = contents[0]
            args = [contents[1]]
    if kw_names:
        args = args + contents[2 : -len(kw_names)]
        kwargs_list = contents[-len(kw_names) :]
        kwargs = dict(zip(kw_names, kwargs_list))
        assert len(kwargs) == len(kw_names)
    else:
        args = args + contents[2:]
        # pyrefly: ignore [implicit-any]
        kwargs = {}
    try:
        # if call_function fails, need to set kw_names to None, otherwise
        # a subsequent call may have self.kw_names set to an old value
        self.call_function(fn, args, kwargs)
    finally:
        self.kw_names = None
  3601. @break_graph_if_unsupported(
  3602. push=True,
  3603. msg_prefix="Encountered graph break when attempting to trace CALL: a function call, e.g. f(x, y)",
  3604. )
  3605. def CALL(self, inst: Instruction) -> None:
  3606. self._call(inst)
  3607. def COPY(self, inst: Instruction) -> None:
  3608. assert inst.arg is not None
  3609. self.push(self.stack[-inst.arg])
  3610. def SWAP(self, inst: Instruction) -> None:
  3611. assert inst.arg is not None
  3612. self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1]
# 3.11 jump opcodes. Backward jumps reuse the unconditional `jump` handler.
JUMP_BACKWARD = jump
JUMP_BACKWARD_NO_INTERRUPT = jump
# The directional POP_JUMP_* variants share generic_jump, with operator.truth
# for *_IF_TRUE and operator.not_ for *_IF_FALSE. NOTE(review): the meaning of
# the second argument (False) is defined by generic_jump, which is not visible
# here - presumably "do not push the tested value back".
POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
  3619. def CACHE(self, inst: Instruction) -> None:
  3620. pass
  3621. def BEFORE_WITH(self, inst: Instruction) -> None:
  3622. self.setup_or_before_with(inst)
def enter_ctx(
    self,
    ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable],
    inst: Instruction,
) -> VariableTracker:
    """Enter ``ctx`` and possibly record a block-stack entry for it.

    Generic context managers that do not support graph breaks are tracked in
    ``active_generic_context_managers``. A BlockStackEntry is pushed only when
    a block target is found. The target is ``inst.target`` before 3.11; from
    3.11 it comes from the next instruction's exception-table entry, and only
    when the current instruction is not nested inside a different try block.
    The entry stores ``ctx`` only for the root InstructionTranslator or when
    ``config.nested_graph_breaks`` is set. Returns the result of
    ``ctx.enter(self)``.
    """
    if (
        isinstance(ctx, GenericContextWrappingVariable)
        and not ctx.supports_graph_breaks()
    ):
        self.active_generic_context_managers.append(ctx)
    if sys.version_info >= (3, 11):
        # See update_block_stack/create_resume for block stack details.
        # Only push a block if the current instruction's block is a
        # with block that is not nested in a try block - that is, the current
        # instruction's block target is the same as the top block's target.
        if inst.exn_tab_entry and (
            not self.block_stack
            or inst.exn_tab_entry.target is not self.block_stack[-1].target
        ):
            target = None
        else:
            assert self.next_instruction.exn_tab_entry is not None
            target = self.next_instruction.exn_tab_entry.target
    else:
        target = inst.target
    if target:
        if isinstance(self, InstructionTranslator) or config.nested_graph_breaks:
            self.block_stack.append(
                BlockStackEntry(inst, target, len(self.stack), ctx)
            )
        else:
            self.block_stack.append(BlockStackEntry(inst, target, len(self.stack)))
    return ctx.enter(self)  # type: ignore[arg-type]
@staticmethod
def unsupported_ctx_graph_break(ctx: VariableTracker) -> NoReturn:
    """Raise a graph break for a ``with`` target that Dynamo cannot enter.

    Called from setup_or_before_with and LOAD_SPECIAL when ``ctx`` is not a
    ContextWrappingVariable or GenericContextWrappingVariable. It never returns.
    """
    unimplemented(
        gb_type="Unsupported context manager",
        context=f"Attempted SETUP_WITH/BEFORE_WITH/LOAD_SPECIAL on {ctx}",
        explanation=f"Dynamo does not know how to enter a `{ctx.python_type_name()}` context manager.",
        hints=[
            "Avoid using the unsupported context manager.",
            "If the context manager seems like it should be supported (e.g. torch.set_grad_enabled), then "
            "it may be the case that it was created outside the compiled region, which Dynamo does not support. "
            "Supported context managers can cross graph break boundaries only if they are local non-closure "
            "variables, or are intermediate values.",
            "File an issue to PyTorch. Simple context managers can potentially be supported, "
            "but note that context managers can't be supported in general",
        ],
    )
  3672. def setup_or_before_with(self, inst: Instruction) -> None:
  3673. ctx = self.pop()
  3674. if not isinstance(
  3675. ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
  3676. ):
  3677. self.unsupported_ctx_graph_break(ctx)
  3678. # Need this redundant check for mypy
  3679. assert isinstance(
  3680. ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
  3681. )
  3682. self.push(WithExitFunctionVariable(ctx, inst.target))
  3683. self.push(self.enter_ctx(ctx, inst))
  3684. def append_prefix_inst(self, inst: Instruction) -> None:
  3685. assert self.accept_prefix_inst
  3686. self.prefix_insts.append(inst)
  3687. def MAKE_CELL(self, inst: Instruction) -> None:
  3688. if sys.version_info >= (3, 12) and not self.accept_prefix_inst:
  3689. # In 3.12+, MAKE_CELL is not longer necessarily a prefix instruction.
  3690. # It can be generated by inlined comprehensions.
  3691. assert isinstance(self.symbolic_locals[inst.argval], NullVariable)
  3692. self.symbolic_locals[inst.argval] = (
  3693. self.output.side_effects.track_cell_new()
  3694. )
  3695. else:
  3696. self.append_prefix_inst(inst)
  3697. def COPY_FREE_VARS(self, inst: Instruction) -> None:
  3698. self.append_prefix_inst(inst)
  3699. def RETURN_GENERATOR(self, inst: Instruction) -> None:
  3700. self.append_prefix_inst(inst)
  3701. # 3.12 opcodes
  3702. # BINARY/STORE_SLICE opcodes are broken down into
  3703. # BUILD_SLICE 2 and BINARY/STORE_SUBSCR
  3704. def END_FOR(self, inst: Instruction) -> None:
  3705. if sys.version_info >= (3, 13):
  3706. self.pop()
  3707. else:
  3708. self.popn(2)
  3709. # Decrement comprehension depth if exiting a comprehension layer
  3710. if sys.version_info >= (3, 12):
  3711. current_ip = self.indexof[inst]
  3712. if current_ip in self._comprehension_end_for_ips:
  3713. self._comprehension_end_for_ips.discard(current_ip)
  3714. self._comprehension_depth -= 1
  3715. def LOAD_FAST_CHECK(self, inst: Instruction) -> None:
  3716. if istype(self.symbolic_locals.get(inst.argval, None), NullVariable):
  3717. unimplemented(
  3718. gb_type="LOAD_FAST_CHECK on uninitialized variable",
  3719. context=inst.argval,
  3720. explanation=f"Attempted to load uninitialized local variable {inst.argval}",
  3721. hints=[*graph_break_hints.USER_ERROR],
  3722. )
  3723. self.LOAD_FAST(inst)
  3724. def LOAD_FAST_AND_CLEAR(self, inst: Instruction) -> None:
  3725. if inst.argval not in self.symbolic_locals:
  3726. self.push(NullVariable())
  3727. else:
  3728. self.LOAD_FAST(inst)
  3729. self.symbolic_locals[inst.argval] = NullVariable()
  3730. def LOAD_SUPER_ATTR(self, inst: Instruction) -> None:
  3731. self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
  3732. assert inst.arg is not None
  3733. if inst.arg & 1:
  3734. self.LOAD_METHOD(inst)
  3735. else:
  3736. self._load_attr(inst.argval)
  3737. def CALL_INTRINSIC_1(self, inst: Instruction) -> None:
  3738. if inst.argval == 3:
  3739. # INTRINSIC_STOPITERATION_ERROR
  3740. self.STOPITERATION_ERROR(inst)
  3741. elif inst.argval == 5:
  3742. # INTRINSIC_UNARY_POSITIVE
  3743. self.UNARY_POSITIVE(inst)
  3744. elif inst.argval == 6:
  3745. # INTRINSIC_LIST_TO_TUPLE
  3746. self.push(TupleVariable(self.pop().force_unpack_var_sequence(self)))
  3747. else:
  3748. unimplemented(
  3749. gb_type="Missing CALL_INTRINSIC_1 handler",
  3750. context=f"CALL_INTRINSIC_1 operand: {inst.argval}",
  3751. explanation=f"No handler implemented for CALL_INTRINSIC_1 {inst.argval} instruction.",
  3752. hints=[*graph_break_hints.SUPPORTABLE],
  3753. )
  3754. def END_SEND(self, inst: Instruction) -> None:
  3755. tos = self.pop()
  3756. self.pop()
  3757. self.push(tos)
  3758. # 3.13 opcodes
  3759. # fused instructions LOAD_FAST_LOAD_FAST, STORE_FAST_STORE_FAST, STORE_FAST_LOAD_FAST
  3760. # are broken down.
  3761. @break_graph_if_unsupported(
  3762. push=True,
  3763. msg_prefix="Encountered graph break when attempting to trace CALL_KW: "
  3764. "a function call with keyword arguments, e.g. f(x=True)",
  3765. )
  3766. def CALL_KW(self, inst: Instruction) -> None:
  3767. self._call(inst, call_kw=True)
  3768. def TO_BOOL(self, inst: Instruction) -> None:
  3769. # TO_BOOL only precedes a conditional jump or UNARY_NOT (see compile.c in CPython)
  3770. # So we can skip this instruction as long as we remember to codegen a TO_BOOL
  3771. # before conditional jumps/UNARY_NOT.
  3772. assert self.next_instruction.opname in (
  3773. "POP_JUMP_IF_TRUE",
  3774. "POP_JUMP_IF_FALSE",
  3775. "UNARY_NOT",
  3776. )
def SET_FUNCTION_ATTRIBUTE(self, inst: Instruction) -> None:
    """Handle SET_FUNCTION_ATTRIBUTE (3.13+): attach one attribute to a new function.

    Pops the NestedUserFunctionVariable, then the attribute value, stores the
    value according to the oparg flag, and pushes the function back:
      * 0x10 (3.14+): annotate function. It is called with constant 1 and the
        result is stored as ``annotations`` (presumably 1 is
        annotationlib.Format.VALUE - see the link below).
      * 0x08: closure.
      * 0x04: annotations, given as a flattened (name, value, ...) tuple that
        is converted into a new dict.
      * 0x02: keyword-only defaults.
      * 0x01: positional defaults.
    """
    flags = inst.arg
    assert flags is not None
    fn = self.pop()
    assert isinstance(fn, NestedUserFunctionVariable)
    attr = self.pop()
    if flags & 0x10:
        assert sys.version_info >= (3, 14)
        # maybe use Format.VALUE_WITH_FAKE_GLOBALS instead?
        # https://docs.python.org/3/library/annotationlib.html#annotationlib.Format.VALUE_WITH_FAKE_GLOBALS
        attr = attr.call_function(self, [ConstantVariable.create(1)], {})
        fn.annotations = attr
    elif flags & 0x08:
        fn.closure = attr
    elif flags & 0x04:
        assert isinstance(attr, TupleVariable)
        # Convert the attribute to a dictionary before assigning it
        # https://github.com/python/cpython/blob/28fb13cb33d569720938258db68956b5f9c9eb40/Objects/funcobject.c#L574-L594
        items = attr.items
        # Even indices are names and odd indices are values; strict=True
        # rejects an odd-length tuple.
        ann = ConstDictVariable(
            dict(zip(items[::2], items[1::2], strict=True)),
            mutation_type=ValueMutationNew(),
        )
        fn.annotations = ann
    elif flags & 0x02:
        fn.kwdefaults = attr
    elif flags & 0x01:
        fn.defaults = attr
    self.push(fn)
  3806. def CONVERT_VALUE(self, inst: Instruction) -> None:
  3807. self.push(self._convert_value(self.pop(), inst.argval))
  3808. def FORMAT_SIMPLE(self, inst: Instruction) -> None:
  3809. self._format_value(ConstantVariable.create(""), 0)
  3810. def FORMAT_WITH_SPEC(self, inst: Instruction) -> None:
  3811. self._format_value(self.pop(), 0)
# 3.14 opcodes
# These new opcodes reuse existing handlers unchanged.
LOAD_FAST_BORROW = LOAD_FAST
NOT_TAKEN = NOP
POP_ITER = POP_TOP
# See
# https://github.com/python/cpython/blob/805e3368d6d07e58430654d1365283924fdf4143/Python/ceval.c#L559
# for the LOAD_SPECIAL table - make sure it matches for Python 3.14+
# Indexed by LOAD_SPECIAL's oparg (0 = __enter__ ... 3 = __aexit__).
_load_special_names = (
    "__enter__",
    "__exit__",
    "__aenter__",
    "__aexit__",
)
  3825. def LOAD_SPECIAL(self, inst: Instruction) -> None:
  3826. assert isinstance(inst.arg, int), "expected LOAD_SPECIAL arg to be set to int"
  3827. attr = self._load_special_names[inst.arg]
  3828. if attr in ("__enter__", "__exit__"):
  3829. ctx = self.pop()
  3830. if not isinstance(
  3831. ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
  3832. ):
  3833. self.unsupported_ctx_graph_break(ctx)
  3834. # Need this redundant check for mypy
  3835. assert isinstance(
  3836. ctx, (ContextWrappingVariable, GenericContextWrappingVariable)
  3837. )
  3838. if attr == "__enter__":
  3839. self.push(WithEnterFunctionVariable(ctx))
  3840. self.PUSH_NULL(inst)
  3841. else:
  3842. # WithExitFunctionVariable doesn't really do anything with target for 3.11+
  3843. self.push(WithExitFunctionVariable(ctx, None))
  3844. self.PUSH_NULL(inst)
  3845. else:
  3846. # Implementation is similar to LOAD_METHOD for 3.13+
  3847. self._load_attr(attr)
  3848. obj = self.pop()
  3849. self.push(obj)
  3850. self.PUSH_NULL(inst)
  3851. def LOAD_SMALL_INT(self, inst: Instruction) -> None:
  3852. self.push(ConstantVariable.create(inst.argval))
# See
# https://github.com/python/cpython/blob/7519ac294fc5c4fd7fb9cb8dc0edc960688cf887/Python/pylifecycle.c#L814
# for the common constants - make sure it matches for Python 3.14+.
# The common constants are all attributes of `builtins`.
# Indexed by LOAD_COMMON_CONSTANT's oparg; each name is resolved with
# load_builtin_from_argval.
_common_constants = (
    "AssertionError",
    "NotImplementedError",
    "tuple",
    "all",
    "any",
)
  3864. def LOAD_COMMON_CONSTANT(self, inst: Instruction) -> None:
  3865. assert isinstance(inst.arg, int), (
  3866. "expected LOAD_COMMON_CONSTANT arg to be set to int"
  3867. )
  3868. self.push(self.load_builtin_from_argval(self._common_constants[inst.arg]))
  3869. def is_non_empty_graph(self) -> bool:
  3870. if self.output.count_calls() > 1:
  3871. # perf optimization only
  3872. self.is_non_empty_graph = lambda: True # type: ignore[method-assign]
  3873. return True
  3874. return False
  3875. def format_frame_summary(
  3876. self, additional_stack_frames: Optional[list[Any]] = None
  3877. ) -> str:
  3878. if additional_stack_frames is None:
  3879. additional_stack_frames = []
  3880. return "".join(
  3881. traceback.format_list(
  3882. [self.frame_summary()] + list(reversed(additional_stack_frames))
  3883. )
  3884. )
  3885. def frame_summary(self) -> traceback.FrameSummary:
  3886. return traceback.FrameSummary(
  3887. getattr(self.f_code, "co_filename", "<unknown>"),
  3888. self.lineno,
  3889. getattr(self.f_code, "co_name", "<unknown>"),
  3890. lookup_line=False,
  3891. )
  3892. def is_co_filename_from_nn_modules(self) -> bool:
  3893. filename = getattr(self.f_code, "co_filename", "<unknown>")
  3894. nn_modules_pattern = re.compile(r".*torch/nn/modules.*")
  3895. return nn_modules_pattern.match(filename) is not None
  3896. def store_global_weakref_by_id(self, prefix: str, value: Any) -> str:
  3897. global_name = self.output.install_global_by_id(prefix, weakref.ref(value))
  3898. install_guard(
  3899. GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE)
  3900. )
  3901. return global_name
  3902. @property
  3903. def fake_mode(self) -> Optional[FakeTensorMode]:
  3904. return self.output.tracing_context.fake_mode
  3905. @contextlib.contextmanager
  3906. def strict_translation_mode(
  3907. self, check_fn: Callable[[VariableTracker], bool]
  3908. ) -> Any:
  3909. """
  3910. Strict mode is enabled on a per-VariableTracker level depending on the return value of check_fn(node).
  3911. """
  3912. prior = self.strict_checks_fn
  3913. self.strict_checks_fn = check_fn
  3914. try:
  3915. yield
  3916. finally:
  3917. self.strict_checks_fn = prior
  3918. def speculate(self) -> SpeculationEntry:
  3919. assert self.instruction_pointer is not None
  3920. assert self.instruction_pointer > 0
  3921. return self.speculation_log.next(
  3922. self.f_code.co_filename,
  3923. self.lineno,
  3924. self.instruction_pointer - 1,
  3925. self.instructions[self.instruction_pointer - 1],
  3926. )
  3927. def _is_comprehension_start(self) -> bool:
  3928. """Detect if we're at the start of a list/dict comprehension in 3.12+.
  3929. In Python 3.12+, comprehensions are inlined with a bytecode pattern that
  3930. precedes BUILD_LIST/BUILD_MAP.
  3931. """
  3932. assert sys.version_info >= (3, 12)
  3933. assert self.instruction_pointer is not None
  3934. ip = self.instruction_pointer - 1
  3935. pattern = _get_comprehension_bytecode_prefix()
  3936. prefix = [inst.opname for inst in self.instructions[ip - len(pattern) : ip]]
  3937. return prefix == pattern
  3938. def _find_comprehension_end_for_ip(self) -> int:
  3939. """Find the instruction pointer of the outermost END_FOR for current comprehension."""
  3940. assert sys.version_info >= (3, 12)
  3941. assert self.instruction_pointer is not None
  3942. nesting_depth = 0
  3943. for search_ip in range(self.instruction_pointer, len(self.instructions)):
  3944. inst = self.instructions[search_ip]
  3945. if inst.opname == "FOR_ITER":
  3946. nesting_depth += 1
  3947. elif inst.opname == "END_FOR":
  3948. nesting_depth -= 1
  3949. if nesting_depth == 0:
  3950. return search_ip
  3951. return -1
    def _analyze_comprehension(self) -> ComprehensionAnalysis:
        """Analyze comprehension bytecode to determine result handling pattern.

        Called with ``instruction_pointer - 1`` pointing at the comprehension's
        BUILD_LIST/BUILD_MAP (Python 3.12+ only). The analysis:

        * collects the iterator variables saved by the LOAD_FAST_AND_CLEAR
          run directly before BUILD_LIST/BUILD_MAP;
        * walks the loop body (first FOR_ITER .. outermost END_FOR) to find
          walrus targets (``COPY 1`` + ``STORE_FAST``) and outer variables
          read before being defined inside the loop ("captured" vars);
        * classifies what happens after the loop by comparing the opcodes
          following END_FOR against ``_get_comprehension_result_patterns()``
          ("stored", "discarded", "returned" / "consumed").

        Returns:
            ComprehensionAnalysis with ``end_ip`` (first instruction after the
            comprehension's trailing STORE_FASTs, advanced by one more
            instruction in the "discarded" case when that pattern has a
            post_store_op), ``result_var``, ``result_on_stack`` and the
            iterator / walrus / captured variable name lists.

        Raises whatever ``unimplemented`` raises when no END_FOR is found or
        the post-loop bytecode matches no known pattern.
        """
        assert sys.version_info >= (3, 12)
        assert self.instruction_pointer is not None
        patterns = _get_comprehension_result_patterns()
        start_ip = self.instruction_pointer - 1  # BUILD_LIST/BUILD_MAP
        iterator_vars: list[str] = []
        walrus_vars: list[str] = []
        captured_vars: list[str] = []
        # Names that are bound by the comprehension itself; loads of these
        # are not treated as captures.
        defined_inside: set[str] = set()
        # Collect iterator variables from LOAD_FAST_AND_CLEAR before BUILD_LIST/BUILD_MAP
        # (scanning backwards, stepping over SWAP/GET_ITER, stopping at
        # anything else). insert(0, ...) keeps them in bytecode order.
        iter_scan_ip = start_ip - 1
        while iter_scan_ip >= 0:
            inst = self.instructions[iter_scan_ip]
            if inst.opname == "LOAD_FAST_AND_CLEAR":
                iterator_vars.insert(0, inst.argval)
                iter_scan_ip -= 1
            elif inst.opname in ("SWAP", "GET_ITER"):
                iter_scan_ip -= 1
            else:
                break
        defined_inside.update(iterator_vars)
        end_for_ip = self._find_comprehension_end_for_ip()
        if end_for_ip == -1:
            unimplemented(
                gb_type="Comprehension analysis failed: No END_FOR",
                context="",
                explanation="Could not find END_FOR instruction in comprehension bytecode.",
                hints=[],
            )
        # Find first FOR_ITER to know where loop body starts
        for_iter_ip = next(
            i
            for i in range(start_ip, end_for_ip)
            if self.instructions[i].opname == "FOR_ITER"
        )
        # Single pass through loop body to detect walrus vars and captured vars
        for body_ip in range(for_iter_ip + 1, end_for_ip):
            inst = self.instructions[body_ip]
            # Detect walrus pattern: COPY 1 followed by STORE_FAST
            if inst.opname == "COPY" and inst.arg == 1 and body_ip + 1 < end_for_ip:
                next_inst = self.instructions[body_ip + 1]
                if next_inst.opname == "STORE_FAST":
                    var_name = next_inst.argval
                    if var_name not in iterator_vars and var_name not in walrus_vars:
                        walrus_vars.append(var_name)
                        defined_inside.add(var_name)
            # Track variables defined inside the loop
            if inst.opname == "STORE_FAST":
                defined_inside.add(inst.argval)
            # Detect LOAD_FAST referencing outer variables. Every LOAD_FAST*
            # variant is matched; a tuple argval (fused two-name loads) is
            # handled by iterating over its names.
            elif inst.opname.startswith("LOAD_FAST"):
                var_names = (
                    inst.argval if isinstance(inst.argval, tuple) else (inst.argval,)
                )
                for var_name in var_names:
                    if var_name not in defined_inside and var_name not in captured_vars:
                        captured_vars.append(var_name)
        # Extract pre_store_ops: all opcodes from END_FOR+1 until first STORE_FAST
        pre_store_ops: list[str] = []
        scan_ip = end_for_ip + 1
        while (
            scan_ip < len(self.instructions)
            and self.instructions[scan_ip].opname != "STORE_FAST"
        ):
            pre_store_ops.append(self.instructions[scan_ip].opname)
            scan_ip += 1
        # Index of the first STORE_FAST after the loop; in the "stored" case
        # its target is taken as the result variable.
        store_fast_ip = scan_ip
        # Skip all STORE_FASTs to find post_store_op
        while (
            scan_ip < len(self.instructions)
            and self.instructions[scan_ip].opname == "STORE_FAST"
        ):
            scan_ip += 1
        post_store_op = (
            self.instructions[scan_ip].opname
            if scan_ip < len(self.instructions)
            else None
        )

        def matches(name: str) -> bool:
            # A falsy post_store_op in the pattern acts as a wildcard.
            pat = patterns[name]
            return pre_store_ops == pat["pre_store_ops"] and (
                post_store_op == pat["post_store_op"] or not pat["post_store_op"]
            )

        result_var: Optional[str] = None
        if matches("stored"):
            result_var = self.instructions[store_fast_ip].argval
            result_on_stack = False
        elif matches("discarded"):
            result_var = None
            result_on_stack = False
            # Also consume the discarding instruction when the pattern has one.
            scan_ip = scan_ip + 1 if patterns["discarded"]["post_store_op"] else scan_ip
        elif (
            matches("returned")
            or pre_store_ops == patterns["consumed"]["pre_store_ops"]
        ):
            # The value is used directly by the following bytecode, so it must
            # stay on the operand stack.
            result_var = None
            result_on_stack = True
        else:
            unimplemented(
                gb_type="Comprehension analysis failed: No matches",
                context=f"pre_store_ops={pre_store_ops}, post_store_op={post_store_op}",
                explanation="Comprehension does not match any known bytecode pattern.",
                hints=[],
            )
        return ComprehensionAnalysis(
            end_ip=scan_ip,
            result_var=result_var,
            # pyrefly: ignore [unbound-name]
            result_on_stack=result_on_stack,
            iterator_vars=iterator_vars,
            walrus_vars=walrus_vars,
            captured_vars=captured_vars,
        )
    def _handle_comprehension_graph_break(self, inst: Instruction) -> None:
        """Handle list/dict comprehension graph break.

        Builds a synthetic function wrapping the comprehension bytecode,
        calls it via codegen_call_resume, then chains into the resume
        function for the post-comprehension code.

        Outline (matching the ``--- Step N ---`` comments below):
          1. compile the graph traced so far, popping the comprehension's
             ``stack_pops`` operands (iterator + one saved value per
             iterator variable);
          2. append the non-NULL popped operands to ``frame_values[0]``;
          3. build the synthetic comprehension function;
          4-5. call it through ``codegen_call_resume``;
          6. remove the operands appended in step 2 from ``frame_values[0]``;
          7. write walrus variables and the result variable into
             ``frame_values[0]`` (or keep the result on the stack / pop it);
          8. emit the call into the resume function that starts at
             ``analysis.end_ip`` and stop tracing this frame.

        Args:
            inst: the current instruction (not referenced in this body).

        Raises whatever ``unimplemented`` raises when this frame is itself a
        resume function and the comprehension captures outer variables.
        """
        assert sys.version_info >= (3, 12)
        analysis = self._analyze_comprehension()
        # Validate: can't handle captured vars in resume functions due to nested sources
        if self.f_code.co_name.startswith(TORCH_DYNAMO_RESUME_IN_PREFIX):
            if analysis.captured_vars:
                unimplemented(
                    gb_type="Comprehension graph break in resume function with captured variables",
                    context=str(analysis.captured_vars),
                    explanation="Cannot use comprehension optimization inside a resume "
                    "function when there are captured variables. This can cause issues "
                    "with deeply nested source chains.",
                    hints=[],
                )
        assert self.instruction_pointer is not None
        start_ip = self.instruction_pointer - 1  # BUILD_LIST/BUILD_MAP
        # The iterator plus one saved value per LOAD_FAST_AND_CLEAR'd iterator var.
        stack_pops = 1 + len(analysis.iterator_vars)
        reason = GraphCompileReason("comprehension_graph_break", [self.frame_summary()])
        log.debug("comprehension triggered compile")
        # --- Step 1: Compile the graph up to the comprehension ---
        all_stack_locals_metadata = self.output.compile_subgraph(
            self,
            reason=reason,
            stack_pops=stack_pops,
        )
        # Record which stack_pops items are NULL before popn loses the info.
        # NULLs on the CPython stack can't be passed as function arguments.
        stack_pops_null_mask = [
            isinstance(self.stack[len(self.stack) - stack_pops + i], NullVariable)
            for i in range(stack_pops)
        ]
        self.popn(stack_pops)
        # Metadata for the current (innermost) frame.
        meta = all_stack_locals_metadata[0]
        cg = PyCodegen(self.output.root_tx)
        # Runtime stack after compile_subgraph:
        # cells, [frame_values], *(non-popped items), *(stack_pops items w/ NULLs)
        # frame_values[0] = [frame N locals] (no stack items yet)
        nonnull_count = sum(1 for m in stack_pops_null_mask if not m)
        # live_stack_depth: stack items above cells/frame_values excluding NULLs
        # that compile_subgraph didn't codegen (tracked in stack_null_idxes).
        live_stack_depth = len(self.stack) - len(meta.stack_null_idxes)
        # --- Step 2: Pop stack_pops items and append non-nulls to frame_values[0] ---
        # SWAP each item to TOS then LIST_APPEND or pop_null; fv_list stays at
        # TOS throughout. Items append in TOS-first (reversed) order;
        # _build_comprehension_fn compensates by loading in reverse.
        cg.extend_output(
            [
                # frame_values[0] to TOS
                *create_copy(live_stack_depth + stack_pops + 1),
                cg.create_load_const(0),
                cg.create_binary_subscr(),
            ]
        )
        for i in reversed(range(stack_pops)):
            cg.extend_output(create_swap(2))
            if stack_pops_null_mask[i]:
                cg.extend_output(cg.pop_null())
            else:
                cg.extend_output([create_instruction("LIST_APPEND", arg=1)])
        # Drop the frame_values[0] reference that was kept at TOS.
        cg.extend_output([create_instruction("POP_TOP")])
        # Stack: cells, [frame_values], *(non-popped items)
        # --- Step 3: Build comprehension function ---
        new_code, fn_name = self._build_comprehension_fn(
            analysis,
            start_ip,
            stack_pops,
            stack_pops_null_mask,
            nonnull_count,
            meta,
        )
        # --- Step 4: Extract [cells[0]] and [frame_values[0]] for codegen_call_resume ---
        # Both COPYs use the same depth: after the first BUILD_LIST pushes one
        # item, frame_values sits exactly where cells sat before.
        cg.extend_output(
            [
                *create_copy(live_stack_depth + 2),
                cg.create_load_const(0),
                cg.create_binary_subscr(),
                create_instruction("BUILD_LIST", arg=1),
                *create_copy(live_stack_depth + 2),
                cg.create_load_const(0),
                cg.create_binary_subscr(),
                create_instruction("BUILD_LIST", arg=1),
            ]
        )
        # Stack: ..., *(non-popped), [cells[0]], [frame_values[0]]
        # --- Step 5: Call comprehension function via codegen_call_resume ---
        self.codegen_call_resume([new_code], [fn_name], cg)
        # Stack: ..., *(non-popped), comp_result
        # --- Step 6: Remove appended stack_pops items from frame_values[0] ---
        if nonnull_count > 0:
            frame_values_pos = live_stack_depth + 1 + 1  # +1 result, +1 frame_values
            cg.extend_output(
                [
                    *create_copy(frame_values_pos),
                    cg.create_load_const(0),
                    cg.create_binary_subscr(),
                    # frame_values[0] on TOS
                    create_dup_top(),
                    # frame_values[0], frame_values[0]
                    cg.create_load_const(-nonnull_count),
                    cg.create_load_const(None),
                    create_instruction("BUILD_SLICE", arg=2),
                    create_instruction("DELETE_SUBSCR"),
                    # del frame_values[0][-nonnull_count:]
                    create_instruction("POP_TOP"),
                ]
            )
        # --- Step 7: Pass comprehension outputs to frame_values[0] ---
        # Walrus vars first, then result_var.
        vars_to_pass = analysis.walrus_vars + (
            [analysis.result_var] if analysis.result_var else []
        )
        # Vars already present in frame_values[0] are overwritten in place at
        # their existing index; new ones are appended, so their locals_names
        # index is the next free slot (registered here in append order).
        existing_vars: dict[str, int] = {}
        for var_name in vars_to_pass:
            self.symbolic_locals[var_name] = UnknownVariable()
            if var_name in meta.locals_names:
                existing_vars[var_name] = meta.locals_names[var_name]
            else:
                meta.locals_names[var_name] = len(meta.locals_names)
        fv_depth = live_stack_depth + 2  # comp_result + frame_values
        # --- Walrus vars: extract from comp_result tuple ---
        if analysis.walrus_vars:
            # comp_result is (result, *walrus_vars).
            cg.extend_output(
                [
                    *create_copy(fv_depth),
                    cg.create_load_const(0),
                    cg.create_binary_subscr(),
                ]
            )
            # Stack: ..., comp_tuple, fv0
            for j, walrus_var in enumerate(analysis.walrus_vars):
                cg.extend_output(
                    [
                        *create_copy(2),
                        cg.create_load_const(j + 1),
                        cg.create_binary_subscr(),
                    ]
                )
                # Stack: ..., comp_tuple, fv0, walrus_value
                if walrus_var in existing_vars:
                    # fv0[idx] = walrus_value
                    cg.extend_output(
                        [
                            *create_copy(2),  # copy fv0
                            cg.create_load_const(existing_vars[walrus_var]),
                            create_instruction("STORE_SUBSCR"),
                        ]
                    )
                else:
                    cg.extend_output([create_instruction("LIST_APPEND", arg=1)])
                # Stack: ..., comp_tuple, fv0
            cg.extend_output(
                [
                    create_instruction("POP_TOP"),  # pop fv0
                    # Extract the result from the tuple.
                    cg.create_load_const(0),
                    cg.create_binary_subscr(),
                ]
            )
            # Stack: ..., result
        # --- Result: keep on stack, overwrite/append to fv[0], or discard ---
        if analysis.result_on_stack:
            self.push(UnknownVariable())
        elif analysis.result_var:
            cg.extend_output(
                [
                    *create_copy(fv_depth),
                    cg.create_load_const(0),
                    cg.create_binary_subscr(),
                    # Stack: ..., result, fv0
                ]
            )
            if analysis.result_var in existing_vars:
                cg.extend_output(
                    [
                        cg.create_load_const(existing_vars[analysis.result_var]),
                        create_instruction("STORE_SUBSCR"),
                        # fv0[idx] = result
                    ]
                )
            else:
                cg.extend_output(
                    [
                        *create_swap(2),
                        create_instruction("LIST_APPEND", arg=1),
                        create_instruction("POP_TOP"),
                    ]
                )
        else:
            cg.extend_output([create_instruction("POP_TOP")])
        # Stack: cells, [frame_values], *(non-popped stack)
        self.output.add_output_instructions(cg.get_instructions())
        # --- Step 8: Create resume function chain ---
        resume_inst = self.instructions[analysis.end_ip]
        self.output.add_output_instructions(
            self.create_call_resume_at(resume_inst, all_stack_locals_metadata)
        )
        # Tracing of this frame ends here.
        self.instruction_pointer = None
  4269. def _build_comprehension_fn(
  4270. self,
  4271. analysis: ComprehensionAnalysis,
  4272. start_ip: int,
  4273. stack_pops: int,
  4274. stack_pops_null_mask: list[bool],
  4275. nonnull_count: int,
  4276. meta: StackLocalsMetadata,
  4277. ) -> tuple[types.CodeType, str]:
  4278. """Build a synthetic function wrapping comprehension bytecode.
  4279. Uses the same calling convention as resume functions created by
  4280. create_resume / ContinueExecutionCache.generate: the first two args
  4281. are __nested_resume_fns and __nested_frame_values (ignored here),
  4282. followed by stack items and live locals.
  4283. Returns (code, name) where name is the global name for the function.
  4284. """
  4285. from .bytecode_transformation import transform_code_object
  4286. from .eval_frame import skip_code
  4287. from .resume_execution import CO_VARARGS, CO_VARKEYWORDS
  4288. # Args follow frame_values layout: locals first, then stack_pops items
  4289. # (appended to end of frame_values[0] by the caller).
  4290. # codegen_call_resume unpacks frame_values[0] as positional args.
  4291. argnames = tuple(
  4292. k for k in meta.locals_names if k not in self.cell_and_freevars()
  4293. )
  4294. args = (
  4295. ["__nested_resume_fns", "__nested_frame_values"]
  4296. + list(argnames)
  4297. + [f"___stack{i}" for i in range(nonnull_count)]
  4298. )
  4299. freevars = tuple(
  4300. sorted(
  4301. list(self.f_code.co_cellvars or [])
  4302. + list(self.f_code.co_freevars or [])
  4303. )
  4304. )
  4305. lineno = self.lineno if self.lineno is not None else self.f_code.co_firstlineno
  4306. fn_name = unique_id(f"__comprehension_{self.f_code.co_name}_at_{lineno}")
  4307. comprehension_body_vars = (
  4308. analysis.iterator_vars
  4309. + analysis.walrus_vars
  4310. + ([analysis.result_var] if analysis.result_var else [])
  4311. + analysis.captured_vars
  4312. )
  4313. def update(
  4314. instructions: list[Instruction], code_options: dict[str, Any]
  4315. ) -> None:
  4316. code_options["co_name"] = fn_name
  4317. if sys.version_info >= (3, 11):
  4318. code_options["co_qualname"] = fn_name
  4319. code_options["co_firstlineno"] = lineno
  4320. code_options["co_cellvars"] = ()
  4321. code_options["co_freevars"] = freevars
  4322. code_options["co_argcount"] = len(args)
  4323. code_options["co_posonlyargcount"] = 0
  4324. code_options["co_kwonlyargcount"] = 0
  4325. code_options["co_varnames"] = tuple(
  4326. args + [v for v in comprehension_body_vars if v not in args]
  4327. )
  4328. code_options["co_flags"] = code_options["co_flags"] & ~(
  4329. CO_VARARGS | CO_VARKEYWORDS
  4330. )
  4331. prefix: list[Instruction] = []
  4332. if freevars:
  4333. prefix.append(create_instruction("COPY_FREE_VARS", arg=len(freevars)))
  4334. prefix.append(create_instruction("RESUME", arg=0))
  4335. # Push stack_pops items onto operand stack so the comprehension
  4336. # bytecode finds them where it expects (iterator + saved vars).
  4337. # NULL positions get PUSH_NULL, non-null get LOAD_FAST.
  4338. # Items were appended to frame_values[0] in TOS-first order,
  4339. # so load in reverse to reconstruct the original stack layout.
  4340. nonnull_i = nonnull_count - 1
  4341. for i in range(stack_pops):
  4342. if stack_pops_null_mask[i]:
  4343. prefix.append(create_instruction("PUSH_NULL"))
  4344. else:
  4345. prefix.append(
  4346. create_instruction("LOAD_FAST", argval=f"___stack{nonnull_i}")
  4347. )
  4348. nonnull_i -= 1
  4349. comp_insts = self._copy_comprehension_bytecode(start_ip, analysis.end_ip)
  4350. # Epilogue: ensure result is on stack, pack walrus vars, return.
  4351. epilogue: list[Instruction] = []
  4352. if not analysis.result_on_stack:
  4353. if analysis.result_var:
  4354. epilogue.append(
  4355. create_instruction("LOAD_FAST", argval=analysis.result_var)
  4356. )
  4357. else:
  4358. epilogue.append(create_instruction("LOAD_CONST", argval=None))
  4359. if analysis.walrus_vars:
  4360. for var_name in analysis.walrus_vars:
  4361. epilogue.append(create_instruction("LOAD_FAST", argval=var_name))
  4362. epilogue.append(
  4363. create_instruction(
  4364. "BUILD_TUPLE",
  4365. arg=1 + len(analysis.walrus_vars),
  4366. )
  4367. )
  4368. epilogue.append(create_instruction("RETURN_VALUE"))
  4369. instructions[:] = prefix + comp_insts + epilogue
  4370. new_code, _ = transform_code_object(self.f_code, update)
  4371. skip_code(new_code)
  4372. # Install as global
  4373. if new_code.co_freevars:
  4374. self.output.install_global_unsafe(fn_name, new_code)
  4375. else:
  4376. self.output.install_global_unsafe(
  4377. fn_name,
  4378. types.FunctionType(new_code, self.f_globals, fn_name),
  4379. )
  4380. return new_code, fn_name
  4381. def _copy_comprehension_bytecode(
  4382. self, start_ip: int, end_ip: int
  4383. ) -> list[Instruction]:
  4384. """Copy comprehension bytecode instructions, updating jump targets."""
  4385. inst_map: dict[Instruction, Instruction] = {}
  4386. copied_insts: list[Instruction] = []
  4387. for ip in range(start_ip, end_ip):
  4388. original_inst = self.instructions[ip]
  4389. copied_inst = copy.copy(original_inst)
  4390. copied_inst.exn_tab_entry = None
  4391. inst_map[original_inst] = copied_inst
  4392. copied_insts.append(copied_inst)
  4393. for copied_inst in copied_insts:
  4394. if copied_inst.target is not None and copied_inst.target in inst_map:
  4395. copied_inst.target = inst_map[copied_inst.target]
  4396. return copied_insts
  4397. def _make_frame_loc(
  4398. self, filename: str, lineno: Optional[int], fallback_lineno: int
  4399. ) -> tuple[str, int]:
  4400. if lineno is None or lineno < 0:
  4401. return (filename, fallback_lineno)
  4402. return (filename, lineno)
  4403. def _get_frame_loc_chain(
  4404. self, frame_loc: tuple[str, int]
  4405. ) -> tuple[tuple[str, int], ...]:
  4406. frame_loc_chain_list: list[tuple[str, int]] = []
  4407. if config.nested_graph_breaks:
  4408. current_tx: Optional[InstructionTranslatorBase] = self.parent
  4409. while current_tx is not None:
  4410. parent_frame_loc = self._make_frame_loc(
  4411. current_tx.f_code.co_filename,
  4412. current_tx.lineno,
  4413. current_tx.f_code.co_firstlineno,
  4414. )
  4415. frame_loc_chain_list.append(parent_frame_loc)
  4416. current_tx = current_tx.parent
  4417. frame_loc_chain_list.reverse()
  4418. frame_loc_chain_list.append(frame_loc)
  4419. return tuple(frame_loc_chain_list)
    def log_graph_break(
        self,
        code_options: dict[str, Any],
        reason: str,
        exc: Unsupported | StepUnsupported,
    ) -> None:
        """Log a graph break once per exception object.

        Emits a ``dynamo_graph_break_reason`` structured-trace artifact and a
        DEBUG record on ``graph_break_log``. The full user stack is logged
        only the first time ``graph_break_dup_warning_checker`` sees a
        ``(gb_type, frame_loc_chain)`` pair and ``explain`` is off. Otherwise
        a short "user stack suppressed" line is logged. Sets ``exc.logged``
        so later calls with the same exception return immediately.

        Args:
            code_options: code options of the frame; ``co_filename`` and
                ``co_firstlineno`` are the fallback location when the user
                stack is empty.
            reason: human-readable graph-break reason (augmented with HOP
                context via ``augment_exc_message_with_hop_name``).
            exc: the exception that triggered the graph break.
        """
        if exc.logged:
            return
        user_stack = getattr(exc, "real_stack", None)
        if user_stack is None:
            user_stack = torch._guards.TracingContext.extract_stack()
        try:
            if config.nested_graph_breaks and self.parent is not None:
                # Inlined frame: report this translator's own position.
                frame_loc = self._make_frame_loc(
                    self.f_code.co_filename,
                    self.lineno,
                    self.f_code.co_firstlineno,
                )
            else:
                frame_loc = self._make_frame_loc(
                    user_stack[-1].filename,
                    user_stack[-1].lineno,
                    0,
                )
        except IndexError:
            # first instruction (empty user_stack)
            frame_loc = (
                code_options["co_filename"],
                code_options["co_firstlineno"],
            )
        frame_loc_chain = self._get_frame_loc_chain(frame_loc)
        stack_above_dynamo_formatted = ""
        if config.verbose:
            # Verbose: show the frames above Dynamo as a separate section.
            stack_above_dynamo = get_stack_above_dynamo()
            stack_above_dynamo_formatted = "".join(
                traceback.format_list(stack_above_dynamo)
            )
        else:
            # Non-verbose: prepend frames above Dynamo to the user stack and
            # collapse resume-function frames.
            user_stack = get_stack_above_dynamo() + user_stack  # type: ignore[assignment]
            user_stack = collapse_resume_frames(user_stack)
        user_stack_formatted = "".join(traceback.format_list(user_stack))
        # Add HOP context after the first line of reason if present
        if exc is not None:
            reason = augment_exc_message_with_hop_name(exc, reason)
        user_stack_trace = (
            f"Graph break in user code at {frame_loc[0]}:{frame_loc[1]}\n"
            f"Graph Break Reason: {reason}\n"
            "\nUser code traceback:\n"
        )
        if config.verbose:
            user_stack_trace += (
                f"{stack_above_dynamo_formatted}\n"
                "========== most recent `torch.compile` tracing attempt started here ==========\n\n"
                f"{user_stack_formatted}\n"
                "NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! "
                "This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another "
                "Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python "
                "function, which Dynamo intercepts as a top-level frame.\n"
            )
        else:
            user_stack_trace += str(user_stack_formatted)
        # NOTE(review): payload_fn closes over the *variable* user_stack_trace,
        # which is appended to again below in verbose mode; if the payload is
        # evaluated lazily it would include that suffix. Confirm intended.
        torch._logging.trace_structured(
            "artifact",
            metadata_fn=lambda: {
                "name": "dynamo_graph_break_reason",
                "encoding": "string",
            },
            payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}",
        )
        # torch._dynamo.explain() formats this a little nicer, and presents a slightly
        # more actionable user code pointer
        gb_type = exc.gb_type if isinstance(exc, Unsupported) else type(exc)
        if (
            graph_break_log.isEnabledFor(logging.DEBUG)
            and not explain
            and graph_break_dup_warning_checker.add((gb_type, frame_loc_chain))  # type: ignore[arg-type]
        ):
            # This log line MUST contain the string "Graph break in user code",
            # This log line is exercised from
            #   python test/dynamo/test_exc.py -k test_graph_break_log
            if config.verbose:
                user_stack_trace += (
                    "\nMost recent bytecode instructions traced (max 20):\n"
                )
                user_stack_trace += "\n".join(self.latest_bytecode_queue) + "\n"
            graph_break_log.debug(
                user_stack_trace,
            )
        else:
            # This log line MUST not contain the string "Graph break in user code",
            # exercised by
            #   python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
            graph_break_log.debug(
                "Graph break (user stack suppressed due to duplicate graph break) in user code at %s:%s\nGraph Break Reason: %s",
                frame_loc[0],
                frame_loc[1],
                reason,
            )
        exc.logged = True
  4519. @staticmethod
  4520. def raise_loop_graph_break(code: types.CodeType, exc: Unsupported) -> NoReturn:
  4521. unimplemented(
  4522. gb_type="graph break in loop",
  4523. context=f"frame skipped: {format_frame_info(code)}",
  4524. explanation="torch.compile detected a graph break in a for/while loop. "
  4525. "Skipping the frame and falling back to eager, as graph breaks in loops are not supported.",
  4526. hints=[*graph_break_hints.CAUSED_BY_EARLIER_GRAPH_BREAK],
  4527. from_exc=exc,
  4528. skip_frame=True,
  4529. )
    def __init__(
        self,
        output: OutputGraph,
        instructions: list[Instruction],
        f_locals: dict[str, Any],
        f_globals: dict[str, Any],
        f_builtins: dict[str, Any],
        code_options: dict[str, Any],
        symbolic_locals: dict[str, VariableTracker],
        symbolic_globals: dict[str, VariableTracker],
        symbolic_torch_function_state: SymbolicTorchFunctionState,
        symbolic_stream_state: SymbolicStreamState,
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
        speculation_log: SpeculationLog,
        exn_vt_stack: ExceptionStack,
        distributed_state: Optional[DistributedState],
        # A non-None closure (together with config.replay_record_enabled)
        # enables the execution recorder below.
        closure: Optional[tuple[types.CellType]] = None,
        package: Optional[CompilePackage] = None,
        # Pre-computed indexof for cache reuse
        indexof: Optional[dict[Instruction, int]] = None,
    ) -> None:
        """Initialize the tracing state for one frame.

        Stores the frame's code, globals/locals, symbolic state and shared
        tracing objects, and resets the per-frame interpreter state (empty
        operand stack, instruction_pointer 0, empty block stack, etc.).

        Args:
            indexof: instruction -> index mapping; computed with
                ``get_indexof(instructions)`` when not supplied.
            closure: when not None and ``config.replay_record_enabled`` is
                set, an ``ExecutionRecorder`` is created for this frame.
        """
        super().__init__()
        self.speculation_log = speculation_log
        self.distributed_state = distributed_state

        # Mutable state checkpointed by copy_graphstate()
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.symbolic_torch_function_state = symbolic_torch_function_state
        self.symbolic_stream_state = symbolic_stream_state
        # used to keep cell/freevars alive after pruning symbolic_locals (prune_dead_locals)
        # in order to generate any nested closures
        self.post_prune_cell_and_freevars = None
        self.stack: list[VariableTracker] = []
        self.instruction_pointer = 0
        self.start_point = None
        self.current_instruction = create_instruction("NOP")
        self.current_instruction_push = True
        self.block_stack = []
        # states before SETUP_WITH for checkpointing and fallback
        self.active_generic_context_managers: list[GenericContextWrappingVariable] = []
        self.lineno = -1
        self.kw_names = None
        self.accept_prefix_inst = True
        self.prefix_insts = []
        self.exn_vt_stack = exn_vt_stack
        # Bounded history of recently traced bytecode (last 20 entries).
        self.latest_bytecode_queue = deque(maxlen=20)
        self._comprehension_depth = 0
        self._comprehension_end_for_ips: set[int] = set()

        # Properties of the input/output code
        self.instructions: list[Instruction] = instructions
        self.indexof: dict[Instruction, int] = (
            indexof if indexof is not None else get_indexof(self.instructions)
        )
        self.f_locals: dict[str, Any] = (
            f_locals  # needed for recording accessed locals for replay
        )
        self.f_globals: dict[str, Any] = f_globals
        self.f_builtins: dict[str, Any] = f_builtins
        self.code_options: dict[str, Any] = code_options
        self.f_code: types.CodeType = f_code
        self.closure = closure

        # Execution record for replaying errors
        if closure is not None and config.replay_record_enabled:
            self.exec_recorder = ExecutionRecorder(
                code=f_code, closure=closure, code_options=code_options
            )
        else:
            self.exec_recorder = None
        # Stack of module being parsed, current nn.module is at the end of ordered dict.
        # The first field of tuple is the fully qualified name of current module
        # in original hierarchy.  The second field is the type of current nn.module
        self.nn_module_stack: dict[str, tuple[str, type[Any]]] = {}
        self.num_calls: dict[str, int] = {}
        # Flag to indicate whether tracing is used for export.
        self.export = export
        # NOTE: one_graph is used for export/fullgraph=True to always force errors on graph breaks.
        # To toggle erroring/resuming on graph breaks during fullgraph=False compile, self.error_on_graph_break
        # is used instead. Every step(), its value is updated to the global tls.error_on_graph_break.
        # We mirror this value since cleanup may (correctly) inadvertently change tls.error_on_graph_break.
        # This assumes that we cannot both trace a change to tls.error_on_graph_break and graph break on
        # the same instruction.
        self.one_graph = False
        self.error_on_graph_break = False
        # Also do not graph break when tracing resume function prologues
        self.is_tracing_resume_prologue = False

        self.current_speculation = None

        self.strict_checks_fn = None

        self.has_no_inlined_calls = True
        self.parent = None
        self.is_child_tracer_active = False

        self.debug_locals = []

        self.package = package

        from .resume_execution import (
            CO_ASYNC_GENERATOR,
            CO_COROUTINE,
            CO_GENERATOR,
            CO_ITERABLE_COROUTINE,
        )

        # Generator/coroutine frames start with one value already on the
        # stack; it is modelled as None here.
        # NOTE(review): presumably mirrors the value sent in on first resume —
        # confirm against the CPython generator entry sequence.
        if f_code.co_flags & (
            CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
        ):
            self.push(BuiltinVariable(None))

        self.inline_depth = inline_depth
        self.inconsistent_side_effects = False
        # One lazily-filled slot per entry of f_code.co_consts.
        self._constants_cache: list[
            Optional[Union[ConstantVariable, SliceVariable]]
        ] = [None] * len(f_code.co_consts)

        # Cache the log-level checks once per translator.
        self.is_trace_bytecode_log_enabled: Optional[bool] = (
            trace_bytecode_log.isEnabledFor(logging.DEBUG)
        )
        self.is_trace_source_log_enabled: Optional[bool] = (
            trace_source_log.isEnabledFor(logging.DEBUG)
        )
        # Register the frame's source lazily with linecache so lines can be
        # fetched later via the module globals' loader.
        linecache.lazycache(f_code.co_filename, f_globals)
  4648. class InstructionTranslator(InstructionTranslatorBase):
  4649. @staticmethod
  4650. def current_tx() -> InstructionTranslator:
  4651. return tls.current_tx
  4652. @contextlib.contextmanager
  4653. def set_current_tx(self) -> Any:
  4654. prior = getattr(tls, "current_tx", None)
  4655. tls.current_tx = self
  4656. try:
  4657. yield
  4658. finally:
  4659. tls.current_tx = prior
    def __init__(
        self,
        instructions: list[Instruction],
        f_code: types.CodeType,
        f_locals: dict[str, Any],
        f_globals: dict[str, Any],
        f_builtins: dict[str, Any],
        closure: Optional[tuple[Any, ...]],
        torch_function_mode_stack: Any,
        code_options: dict[str, Any],
        compiler_fn: Any,
        one_graph: bool,
        export: bool,
        export_constraints: Any,
        frame_state: Any,
        speculation_log: SpeculationLog,
        exn_vt_stack: ExceptionStack,
        distributed_state: Optional[DistributedState],
        package: Optional[CompilePackage],
    ) -> None:
        """Set up the root-frame instruction translator for ``f_code``.

        Builds the ``OutputGraph`` for this compilation and hands it to the
        base translator (``inline_depth=0``). Then, with the output graph's
        tracing context and this translator made current, it seeds
        ``symbolic_locals``:

        * every non-cell frame local becomes a lazy tracker with a
          ``LocalSource`` marked as an input;
        * every cell created by this frame is modelled as a new cell
          (the ``MAKE_CELL`` step);
        * every captured free variable is modelled as a pre-existing cell
          from ``closure`` (the ``COPY_FREE_VARS`` step).

        Args:
            instructions: Instructions of ``f_code`` to be traced.
            f_code: Code object of the frame being traced.
            f_locals: Frame locals. Used as the output graph's local scope
                and as the values behind the symbolic locals.
            f_globals: Frame globals. Used as the output graph's global scope.
            f_builtins: Builtins namespace for the frame.
            closure: Cells for the frame's free variables. Must not be None
                (asserted below); zipped positionally with ``self.freevars()``.
            torch_function_mode_stack: Passed to ``OutputGraph`` and used to
                build ``symbolic_torch_function_state``.
            code_options: Code attributes; ``co_filename`` and
                ``co_firstlineno`` are read for the start-of-tracing log line.
            compiler_fn: Backend passed to ``OutputGraph``.
            one_graph: Stored on ``self``. Export requires it to be True.
            export: When True, all lazy symbolic locals are realized eagerly.
            export_constraints: Forwarded to ``OutputGraph``.
            frame_state: Forwarded to ``OutputGraph``.
            speculation_log: Forwarded to the base translator.
            exn_vt_stack: Forwarded to the base translator.
            distributed_state: Forwarded to the base translator.
            package: Forwarded to both ``OutputGraph`` and the base translator.
        """
        _step_logger()(
            logging.INFO,
            f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
        )
        super().__init__(
            output=OutputGraph(
                code_options,
                compiler_fn,
                self,
                export,
                export_constraints,
                frame_state,
                local_scope=f_locals,
                global_scope=f_globals,
                f_code=f_code,
                torch_function_mode_stack=torch_function_mode_stack,
                one_graph=one_graph,
                package=package,
            ),
            instructions=instructions,
            f_locals=f_locals,
            f_globals=f_globals,
            f_builtins=f_builtins,
            closure=closure,
            code_options=code_options,
            symbolic_locals={},  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            symbolic_torch_function_state=None,  # type: ignore[arg-type] # set below
            symbolic_stream_state=None,  # type: ignore[arg-type] # set below
            f_code=f_code,
            export=export,
            inline_depth=0,
            speculation_log=speculation_log,
            exn_vt_stack=exn_vt_stack,
            distributed_state=distributed_state,
            package=package,
        )

        # Graph-break early if tracing started under a functorch transform
        # with a non-eager backend.
        self._throw_if_in_functorch()

        # as soon as we create the tracing context we should keep it active, so any calls
        # into dynamo apis can rely on finding it
        with tracing(self.output.tracing_context), self.set_current_tx():
            self.one_graph: bool = one_graph
            self.export = export
            if self.export:
                assert self.one_graph, (
                    "Export without one graph - something has gone wrong."
                )

            self.symbolic_locals = {}
            # Populate `symbolic_locals` with non-cell variables.
            cell_and_freevars: set[str] = set(self.cell_and_freevars())

            # Optional per-local dynamism hints attached to this code object;
            # presumably a mapping name -> {...} (TODO confirm schema). The
            # inner mapping's items are frozen into the LocalSource.
            dynamism = code_context.get_context(f_code).get("dynamism", None)
            for name, value in f_locals.items():
                if name not in cell_and_freevars:
                    local_dynamism = None
                    if dynamism:
                        local_dynamism = frozenset(dynamism.get(name, {}).items())
                    var = LazyVariableTracker.create(
                        value,
                        LocalSource(
                            name,
                            is_input=True,
                            dynamism=local_dynamism,
                        ),
                    )
                    self.symbolic_locals[name] = var

            # Populate `symbolic_locals` with cells created by this frame,
            # effectively implementing the `MAKE_CELL` instructions.
            side_effects = self.output.side_effects
            for name in self.cellvars():
                if name in f_locals:
                    # This models cells that are also function inputs.
                    value = f_locals[name]
                    # NOTE: root frame inputs that are captured by a nested
                    # function become special cell objects -- they exist in
                    # `f_locals` as contents of the cells, rather than the cells
                    # objects themselves.
                    #
                    # In Dynamo, we choose to represent such input cell objects
                    # as newly created (rather than pre-existing) cell objects,
                    # because
                    #
                    # 1. The reason for representing a pre-existing cell object
                    # is to emit guard or codegen mutations. However, local
                    # cells should never be used for guards. Moreover, at this
                    # point these input cell objects should've never been
                    # accessed by anyone else, since Dynamo intercepts the frame
                    # right after its evaluation starts, i.e., right after these
                    # cell objects are created. So they should have no external
                    # reference, meaning no mutation needs to be propagated.
                    #
                    # 2. This conveniently allows codegen to prune away
                    # mutations to these cells, unless they escape the frame.
                    contents_source = LocalSource(
                        name,
                        is_input=True,
                        is_derefed_cell_contents=True,
                    )
                    contents_var: VariableTracker = LazyVariableTracker.create(
                        value, contents_source
                    )
                    cell_var = side_effects.track_cell_new()
                    side_effects.store_cell(cell_var, contents_var)
                else:
                    # Cell not yet initialized in this frame: a fresh empty cell.
                    cell_var = side_effects.track_cell_new()
                cell_var.local_name = name  # type: ignore[attr-defined]
                self.symbolic_locals[name] = cell_var

            # Populate `symbolic_locals` with cells captured by this frame,
            # effectively implementing the `COPY_FREE_VARS` instruction.
            assert closure is not None
            for name, cell in zip(self.freevars(), closure):
                cell_source = LocalCellSource(name)
                contents_source = LocalSource(name, is_derefed_cell_contents=True)
                try:
                    contents_var = LazyVariableTracker.create(
                        cell.cell_contents, contents_source
                    )
                except ValueError:
                    # Cell has not yet been assigned
                    contents_var = variables.DeletedVariable()
                # Captured cells may be referenced outside this frame, so they
                # are tracked as existing objects (with a source), unlike the
                # cells created above.
                cell_var = side_effects.track_cell_existing(
                    cell_source, cell, contents_var
                )
                cell_var.local_name = name  # type: ignore[attr-defined]
                self.symbolic_locals[name] = cell_var

            self.symbolic_torch_function_state = SymbolicTorchFunctionState(
                torch_function_mode_stack
            )

            self.symbolic_stream_state = SymbolicStreamState()

            if export:
                # export gets confused if we never realize unused inputs
                # in export mode just eagerly realize everything
                self.symbolic_locals = variables.LazyVariableTracker.realize_all(
                    self.symbolic_locals
                )
  4815. def _throw_if_in_functorch(self) -> None:
  4816. # Fallback to eager in case of a graph break inside vmap
  4817. eager = torch._dynamo.lookup_backend("eager")
  4818. compiler_fn = inspect.getattr_static(
  4819. self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
  4820. )
  4821. ci = torch._C._functorch.peek_interpreter_stack()
  4822. forbidden_keys = (
  4823. torch._C._functorch.TransformType.Vmap,
  4824. torch._C._functorch.TransformType.Grad,
  4825. torch._C._functorch.TransformType.Jvp,
  4826. )
  4827. if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
  4828. name = ci.key().name.lower()
  4829. msg = (
  4830. "If you are reaching here, it means dynamo failed for one of the following reasons:\n"
  4831. # Calling a torch.compiled function
  4832. f"- Calling torch.func.{name}(compiled_fn) function from eager mode is not supported. "
  4833. f"Ensure that torch.func.{name} is also wrapped within a torch.compile function. "
  4834. "For more information, see PyTorch issue #128711.\n"
  4835. # if it reaches here, it means Dynamo failed to inline a functorch function
  4836. f"- torch.func.{name}(fn) requires the function to be inlined by dynamo"
  4837. )
  4838. unimplemented(
  4839. gb_type="Unsupported functorch tracing attempt",
  4840. context="",
  4841. explanation=msg,
  4842. hints=[],
  4843. )
  4844. def get_example_value(self, source: Source) -> Any:
  4845. if isinstance(source, LocalSource):
  4846. return self.f_locals[source.local_name]
  4847. if isinstance(source, GlobalSource):
  4848. return self.f_globals[source.global_name]
  4849. raise KeyError
  4850. def symbolic_locals_contain_module_class(self) -> bool:
  4851. for v in self.symbolic_locals.values():
  4852. if isinstance(v, UserDefinedClassVariable) and issubclass(
  4853. v.as_python_constant(), torch.nn.Module
  4854. ):
  4855. return True
  4856. return False
  4857. def replace_tos_if_return_is_generator(self) -> None:
  4858. if (
  4859. len(self.stack)
  4860. and (tos := self.stack[-1])
  4861. and isinstance(tos, LocalGeneratorObjectVariable)
  4862. ):
  4863. self.stack[-1] = ListIteratorVariable(
  4864. tos.force_unpack_var_sequence(self),
  4865. mutation_type=ValueMutationNew(),
  4866. )
    def _return(self, inst: Instruction) -> None:
        """Finish tracing the root frame at a return instruction.

        Never returns normally. It either raises ``exc.SkipFrame``, when no
        calls were traced and none of the listed exemptions apply, or it
        compiles the traced subgraph, appends the return codegen to the
        output instructions, and raises ``ReturnValueOp``.

        Args:
            inst: The RETURN_VALUE or RETURN_CONST instruction being handled.
        """
        # A generator object cannot be returned lazily; materialize it first.
        self.replace_tos_if_return_is_generator()
        assert self.instruction_pointer is not None
        assert self.start_point is not None
        # Number of instructions traced in this frame.
        get_metrics_context().increment(
            "ir_count", self.instruction_pointer - self.start_point
        )

        if (
            not config.allow_empty_graphs
            and self.output.count_calls() == 0
            and not self.inconsistent_side_effects
            and not self.symbolic_locals_contain_module_class()
            and not self.export
            and not self.one_graph
            and not self.error_on_graph_break
            and not self.is_tracing_resume_prologue
        ):
            # TODO graph break if one_graph is set - this might break things
            raise exc.SkipFrame(
                "No ops traced for the FX graph. `torch.compile` will skip the frame and fall back to eager.\n"
                f"Frame info: {format_frame_info(self.f_code)}"
            )
        # Stop the instruction loop.
        self.instruction_pointer = None
        _step_logger()(
            logging.INFO,
            f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
        )
        log.debug("return triggered compile")
        all_stack_locals_metadata = self.output.compile_subgraph(
            self,
            reason=GraphCompileReason(
                "return_value", [self.frame_summary()], graph_break=False
            ),
            # the value to be returned
            # (only RETURN_VALUE takes its value from the stack; RETURN_CONST
            # does not, so nothing is popped for it)
            stack_pops=1 if inst.opname == "RETURN_VALUE" else 0,
        )
        # check that our stack/locals meta are correct:
        # we should only be tracing 1 frame, and there should not be any NULLs on the stack
        assert len(all_stack_locals_metadata) == 1
        assert not all_stack_locals_metadata[0].stack_null_idxes
        self.output.add_output_instructions(
            self.codegen_return_with_pops(inst, all_stack_locals_metadata[0].num_stack)
        )
        raise ReturnValueOp
  4911. def RETURN_VALUE(self, inst: Instruction) -> None:
  4912. self._return(inst)
  4913. def RETURN_CONST(self, inst: Instruction) -> None:
  4914. self._return(inst)
  4915. if sys.version_info >= (3, 11):
  4916. _binary_op_lookup = [
  4917. getattr(
  4918. InstructionTranslator,
  4919. opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}",
  4920. )
  4921. for opname, _ in dis._nb_ops # type: ignore[attr-defined]
  4922. ]
  4923. @contextlib.contextmanager
  4924. def profile_inline_call(
  4925. output: OutputGraph,
  4926. code: types.CodeType,
  4927. get_inline_depth: Callable[[], int],
  4928. ) -> Generator[None, None, None]:
  4929. """
  4930. Context manager for profiling inline calls.
  4931. Args:
  4932. output: The OutputGraph containing profiler_state
  4933. code: The code object being inlined (for timing metadata)
  4934. get_inline_depth: Callable that returns inline_depth (called after work completes)
  4935. Yields:
  4936. None (profiling happens around the with block)
  4937. """
  4938. if not config.dynamo_profiler:
  4939. yield
  4940. return
  4941. if output.profiler_state is None:
  4942. output.profiler_state = DynamoProfilerState()
  4943. caller_info = output.profiler_state.get_current_caller()
  4944. call_stack = output.profiler_state.get_call_stack()
  4945. output.profiler_state.push(
  4946. code.co_name, code.co_filename, code.co_firstlineno, time.time_ns()
  4947. )
  4948. trace_success = False
  4949. try:
  4950. yield
  4951. trace_success = True
  4952. finally:
  4953. stack_entry = output.profiler_state.pop()
  4954. trace_end_ns = time.time_ns()
  4955. if trace_success and stack_entry is not None:
  4956. inline_depth = get_inline_depth()
  4957. cumtime_ns = trace_end_ns - stack_entry.start_time_ns
  4958. tottime_ns = cumtime_ns - stack_entry.child_time_ns
  4959. timing = FunctionTraceTiming(
  4960. func_name=stack_entry.func_name,
  4961. filename=stack_entry.filename,
  4962. firstlineno=stack_entry.firstlineno,
  4963. cumtime_ns=cumtime_ns,
  4964. tottime_ns=tottime_ns,
  4965. bytecode_count=len(code.co_code),
  4966. inline_depth=inline_depth,
  4967. caller_func_name=caller_info[0] if caller_info else None,
  4968. caller_filename=caller_info[1] if caller_info else None,
  4969. caller_firstlineno=caller_info[2] if caller_info else None,
  4970. is_primitive_call=stack_entry.is_primitive_call,
  4971. call_stack=call_stack,
  4972. )
  4973. output.profiler_state.record_timing(timing)
  4974. output.profiler_state.add_child_time(cumtime_ns)
  4975. class InliningInstructionTranslator(InstructionTranslatorBase):
  4976. """Trace and inline a called method"""
  4977. symbolic_result: Optional[VariableTracker]
  4978. # pyrefly: ignore [bad-override]
  4979. parent: InstructionTranslatorBase
  4980. @classmethod
  4981. def inline_call(
  4982. cls,
  4983. parent: Any,
  4984. func: BaseUserFunctionVariable,
  4985. args: Sequence[VariableTracker],
  4986. kwargs: dict[str, VariableTracker],
  4987. ) -> VariableTracker:
  4988. tracer = None
  4989. with profile_inline_call(
  4990. parent.output, func.get_code(), lambda: parent.inline_depth + 1
  4991. ):
  4992. tracer = cls.build_inline_tracer(parent, func, args, kwargs)
  4993. return tracer.inline_call_()
  4994. @staticmethod
  4995. def check_inlineable(
  4996. func: BaseUserFunctionVariable,
  4997. ) -> trace_rules.SkipResult:
  4998. if func.has_self():
  4999. unimplemented(
  5000. gb_type="Inline attempt with __self__",
  5001. context=str(func),
  5002. explanation="Attempted to inline a function with the `__self__` attribute. "
  5003. "Dynamo is expected to decompose method calls into function calls with a `self` argument.",
  5004. hints=[],
  5005. )
  5006. if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
  5007. func.get_function(), "_torchdynamo_disable", False
  5008. ):
  5009. msg = inspect.getattr_static(
  5010. func.get_function(), "_torchdynamo_disable_msg", None
  5011. )
  5012. unimplemented(
  5013. gb_type="Skip inlining `torch.compiler.disable()`d function",
  5014. context=str(func.get_function()),
  5015. explanation=f"Skip inlining function {func.get_function()} since it was wrapped "
  5016. f"with `torch.compiler.disable` (reason: {msg})",
  5017. hints=[
  5018. "Remove the `torch.compiler.disable` call",
  5019. ],
  5020. )
  5021. result = trace_rules.check_verbose(func, is_inlined_call=True)
  5022. if result.skipped:
  5023. from torch._dynamo.variables.misc import produce_trampoline_autograd_apply
  5024. # _origin marks this as coming from an internal dynamo known function that is safe to
  5025. # trace through.
  5026. if (
  5027. hasattr(func, "fn")
  5028. and hasattr(func.fn, "_origin")
  5029. and func.fn._origin is produce_trampoline_autograd_apply
  5030. ):
  5031. # Known sound
  5032. return trace_rules.SkipResult(
  5033. False, "allowlist in dynamo known function"
  5034. )
  5035. fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
  5036. hints = [
  5037. f"Avoid calling the function `{fn_qualname}`.",
  5038. ]
  5039. if "_dynamo" not in func.get_filename():
  5040. hints += [
  5041. f"Apply `@torch._dynamo.dont_skip_tracing` to the function `{fn_qualname}` "
  5042. "to force tracing into the function. "
  5043. "More graph breaks may occur as a result of attempting to trace into the function.",
  5044. "Please file an issue to PyTorch.",
  5045. ]
  5046. unimplemented(
  5047. gb_type="Attempted to inline function marked as skipped",
  5048. context=f"qualname: {fn_qualname}, name: {func.get_name()}, "
  5049. f"filename: `{func.get_filename()}`, skip reason: {result.reason}",
  5050. explanation=f"Dynamo developers have intentionally marked that the function `{fn_qualname}` "
  5051. "should not be traced.",
  5052. hints=hints,
  5053. )
  5054. return result
  5055. @staticmethod
  5056. def build_inline_tracer(
  5057. parent: Any,
  5058. func: BaseUserFunctionVariable,
  5059. args: Sequence[VariableTracker],
  5060. kwargs: dict[str, VariableTracker],
  5061. ) -> InliningInstructionTranslator:
  5062. assert isinstance(
  5063. func,
  5064. (
  5065. UserFunctionVariable,
  5066. NestedUserFunctionVariable,
  5067. LocalGeneratorFunctionVariable,
  5068. ),
  5069. )
  5070. code: types.CodeType = func.get_code()
  5071. result = None
  5072. tracing_ctx = parent.output.tracing_context
  5073. # Check if we have already identified this function to be inline-able.
  5074. # The exception is dont_skip_tracing flag which affects the inline
  5075. # behavior. If the flag is True, don't rely on previous results.
  5076. if not config.dont_skip_tracing and tracing_ctx:
  5077. if previous_result := tracing_ctx.previously_inlined_functions.get(
  5078. code, None
  5079. ):
  5080. result = previous_result
  5081. if result is None:
  5082. result = InliningInstructionTranslator.check_inlineable(func)
  5083. assert result.skipped is False
  5084. if not config.dont_skip_tracing and tracing_ctx:
  5085. tracing_ctx.previously_inlined_functions[code] = result
  5086. sub_locals = None
  5087. try:
  5088. sub_locals = func.bind_args(parent, args, kwargs)
  5089. except TypeError as e:
  5090. unimplemented(
  5091. gb_type="failed to bind arguments when attempting to inline",
  5092. context=f"func='{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}; "
  5093. f"args = {[arg.python_type() for arg in args]}; kwargs = {kwargs}",
  5094. explanation=f"Argument mismatch when attempting to trace function {func.get_name()}.",
  5095. hints=[
  5096. *graph_break_hints.USER_ERROR,
  5097. ],
  5098. from_exc=e,
  5099. )
  5100. assert sub_locals is not None
  5101. for v in itertools.chain(sub_locals.values()):
  5102. if not isinstance(v, VariableTracker):
  5103. unimplemented(
  5104. gb_type="Encountered unconverted argument when attempting to inline",
  5105. context=f"func: {func}, arg: {v}",
  5106. explanation="An argument to an inlined function was not successfully converted to a VariableTracker.",
  5107. hints=[*graph_break_hints.DYNAMO_BUG],
  5108. )
  5109. if code.co_name in ("__setitem__", "__setattr__") and not (
  5110. args and isinstance(args[0], variables.UserDefinedObjectVariable)
  5111. ):
  5112. unimplemented(
  5113. gb_type="Unsupported __setitem__/__setattr__ inline attempt",
  5114. context=f"code name: {code.co_name}, args: {args}",
  5115. explanation=f"Attempted to inline {code.co_name} where first argument (self) is not a user-defined object.",
  5116. hints=[],
  5117. )
  5118. suffix = ""
  5119. # TODO: mlazos, add support for enabling multiple artifact logs
  5120. # with a single alias
  5121. if torch._logging._internal.log_state.is_artifact_enabled("bytecode"):
  5122. suffix = f"\n{dis.Bytecode(code).dis()}"
  5123. if sys.version_info >= (3, 11):
  5124. cur_inst = parent.current_instruction
  5125. parent_code = parent.f_code
  5126. def get_trace_call_log_str() -> str:
  5127. header = parent.get_line_of_code_header(
  5128. lineno=cur_inst.positions.lineno
  5129. )
  5130. line = get_instruction_source_311(parent_code, cur_inst).rstrip()
  5131. return f"TRACE inlined call {code.co_name} from {header}\n{line}"
  5132. trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
  5133. log.debug("INLINING %s%s, %s", code, suffix, result.reason)
  5134. # Detect inline GraphModule calls in order to propagate node metadata,
  5135. # by checking if the first argument (self) is a variable tracking a GraphModule.
  5136. if args and isinstance(args[0], NNModuleVariable):
  5137. module = parent.output.get_submodule(args[0].module_key)
  5138. if isinstance(module, torch.fx.GraphModule):
  5139. # The inline call might not actually be a call to `forward`,
  5140. # but it is enough to add a context for `forward` in case it is called.
  5141. code_context.get_context(module.forward.__code__)[
  5142. "orig_graphmodule"
  5143. ] = weakref.ref(module)
  5144. # When we have inline_nn_module turned on, modules resolve to UnspecializedNNModuleVariable
  5145. if args and isinstance(args[0], UnspecializedNNModuleVariable):
  5146. module = args[0].value
  5147. if isinstance(module, torch.fx.GraphModule):
  5148. # The inline call might not actually be a call to `forward`,
  5149. # but it is enough to add a context for `forward` in case it is called.
  5150. code_context.get_context(module.forward.__code__)[
  5151. "orig_graphmodule"
  5152. ] = weakref.ref(module)
  5153. assert not isinstance(func, SkipFunctionVariable)
  5154. tracer: InliningInstructionTranslator
  5155. if is_generator(code):
  5156. tracer = InliningGeneratorInstructionTranslator(
  5157. parent,
  5158. code,
  5159. sub_locals,
  5160. parent.symbolic_globals,
  5161. parent.symbolic_torch_function_state,
  5162. parent.symbolic_stream_state,
  5163. func,
  5164. )
  5165. else:
  5166. tracer = InliningInstructionTranslator(
  5167. parent,
  5168. code,
  5169. sub_locals,
  5170. parent.symbolic_globals,
  5171. parent.symbolic_torch_function_state,
  5172. parent.symbolic_stream_state,
  5173. func,
  5174. )
  5175. return tracer
  5176. def inline_call_(self) -> VariableTracker:
  5177. parent = self.parent
  5178. parent.has_no_inlined_calls = False
  5179. parent.is_child_tracer_active = True
  5180. code = self.f_code
  5181. strict_ctx: Any = contextlib.nullcontext()
  5182. if parent.strict_checks_fn:
  5183. strict_ctx = self.strict_translation_mode(parent.strict_checks_fn)
  5184. try:
  5185. with strict_ctx:
  5186. self.run()
  5187. except exc.ObservedException as e:
  5188. msg = f"Observed exception DURING INLING {code} : {e}"
  5189. log.debug(msg)
  5190. # bubble up the exception to the parent frame.
  5191. raise
  5192. except Unsupported as e:
  5193. # If this graph break has skip_frame set, unset it
  5194. # since it refers to the current frame and not the parent.
  5195. e.skip_frame = False
  5196. raise
  5197. except Exception:
  5198. log.debug("FAILED INLINING %s", code)
  5199. raise
  5200. finally:
  5201. # Pass inlined tx's error_on_graph_break to parent.
  5202. # Deals with the case where the parent's error_on_graph_break is True
  5203. # while the inlined tx's error_on_graph_break was set to False.
  5204. parent.error_on_graph_break = self.error_on_graph_break
  5205. parent.is_child_tracer_active = False
  5206. if self.output.should_exit:
  5207. # graph break
  5208. return CONSTANT_VARIABLE_NONE # return dummy variable
  5209. assert self.symbolic_result is not None
  5210. if self.f_globals is parent.f_globals:
  5211. # Merge symbolic_globals back if parent and child are in the same namespace
  5212. parent.symbolic_globals.update(self.symbolic_globals)
  5213. parent.inconsistent_side_effects |= self.inconsistent_side_effects
  5214. log.debug("DONE INLINING %s", code)
  5215. self.output.tracing_context.traced_code.append(code)
  5216. if config.enable_faithful_generator_behavior or (
  5217. isinstance(self, InliningGeneratorInstructionTranslator)
  5218. and self.is_generator_from_ctx_manager
  5219. ):
  5220. if (
  5221. is_generator(code)
  5222. and isinstance(self, InliningGeneratorInstructionTranslator)
  5223. and self.generator_exhausted
  5224. ):
  5225. assert isinstance(self, InliningGeneratorInstructionTranslator)
  5226. # When the generator returns None, we raise StopIteration
  5227. # pyrefly: ignore [implicit-any]
  5228. args = []
  5229. if not self.symbolic_result.is_constant_none():
  5230. args = [self.symbolic_result]
  5231. exc.raise_observed_exception(StopIteration, self, args=args)
  5232. else:
  5233. return self.symbolic_result
  5234. else:
  5235. if is_generator(code):
  5236. assert isinstance(self, InliningGeneratorInstructionTranslator)
  5237. assert self.symbolic_result.is_constant_none()
  5238. return ListIteratorVariable(
  5239. self.generated_items,
  5240. mutation_type=ValueMutationNew(),
  5241. )
  5242. else:
  5243. return self.symbolic_result
  5244. def __init__(
  5245. self,
  5246. parent: InstructionTranslatorBase,
  5247. code: types.CodeType,
  5248. symbolic_locals: dict[str, VariableTracker],
  5249. symbolic_globals: dict[str, VariableTracker],
  5250. symbolic_torch_function_state: SymbolicTorchFunctionState,
  5251. symbolic_stream_state: SymbolicStreamState,
  5252. funcvar: BaseUserFunctionVariable | LocalGeneratorObjectVariable,
  5253. ) -> None:
  5254. f_globals = funcvar.get_globals()
  5255. f_builtins = f_globals["__builtins__"]
  5256. if not isinstance(f_builtins, dict):
  5257. f_builtins = f_builtins.__dict__
  5258. # Get the cached code data. This cache combines instructions, indexof, and
  5259. # code_options to avoid recomputing them for frequently-inlined functions.
  5260. tracing_ctx = parent.output.tracing_context
  5261. cached = tracing_ctx.inlined_code_cache.get(code) if tracing_ctx else None
  5262. if cached is not None:
  5263. instructions = cached.instructions
  5264. indexof = cached.indexof
  5265. code_options = cached.code_options
  5266. else:
  5267. instructions = cleaned_instructions(code)
  5268. propagate_line_nums(instructions)
  5269. indexof = get_indexof(instructions)
  5270. code_options = {k: getattr(code, k) for k in get_code_keys()}
  5271. if tracing_ctx:
  5272. tracing_ctx.inlined_code_cache[code] = InlinedCodeCache(
  5273. instructions=instructions,
  5274. indexof=indexof,
  5275. code_options=code_options,
  5276. )
  5277. super().__init__(
  5278. output=parent.output,
  5279. f_locals={},
  5280. f_globals=f_globals,
  5281. f_builtins=f_builtins,
  5282. symbolic_locals=symbolic_locals,
  5283. symbolic_globals=symbolic_globals,
  5284. symbolic_torch_function_state=symbolic_torch_function_state,
  5285. symbolic_stream_state=symbolic_stream_state,
  5286. instructions=instructions,
  5287. code_options=code_options,
  5288. f_code=code,
  5289. export=parent.export,
  5290. inline_depth=parent.inline_depth + 1,
  5291. speculation_log=parent.speculation_log,
  5292. exn_vt_stack=parent.exn_vt_stack,
  5293. distributed_state=parent.distributed_state,
  5294. package=parent.package,
  5295. indexof=indexof,
  5296. )
  5297. self.funcvar = funcvar
  5298. self.parent = parent
  5299. self.num_calls = parent.num_calls
  5300. self.symbolic_result = None
  5301. self.nn_module_stack = parent.nn_module_stack.copy()
  5302. self.one_graph = parent.one_graph
  5303. @property
  5304. def fake_mode(self) -> Optional[FakeTensorMode]:
  5305. return self.parent.fake_mode
  5306. def run_ctx_mgr(self) -> Any:
  5307. return TracingContext.current_frame(self.parent.frame_summary())
  5308. def _can_speculate_comprehension_nested(self) -> bool:
  5309. """Check if comprehension speculation is allowed in this inlined context.
  5310. Unlike should_compile_partial_graph(), this skips the exception table entry check.
  5311. """
  5312. if not config.nested_graph_breaks:
  5313. return False
  5314. if not self.funcvar.should_allow_nested_graph_breaks():
  5315. return False
  5316. if not self.parent.should_compile_partial_graph():
  5317. return False
  5318. return True
  5319. def should_compile_partial_graph(self) -> bool:
  5320. if config.nested_graph_breaks:
  5321. if not self.funcvar.should_allow_nested_graph_breaks():
  5322. return False
  5323. if not self.parent.should_compile_partial_graph():
  5324. return False
  5325. return super().should_compile_partial_graph()
  5326. return False # inlining functions is all-or-nothing
  5327. def create_call_resume_at(
  5328. self,
  5329. inst: Instruction,
  5330. all_stack_locals_metadata: list[StackLocalsMetadata],
  5331. ) -> list[Instruction]:
  5332. if config.nested_graph_breaks:
  5333. return super().create_call_resume_at(inst, all_stack_locals_metadata)
  5334. unimplemented(
  5335. gb_type="Graph break in inlined function",
  5336. context="",
  5337. explanation="Graph breaks in an inlined call are not supported.",
  5338. hints=[],
  5339. )
  5340. def RETURN_VALUE(self, inst: Instruction) -> None:
  5341. self.symbolic_result = self.pop()
  5342. self.instruction_pointer = None
  5343. raise ReturnValueOp
  5344. def RETURN_CONST(self, inst: Instruction) -> None:
  5345. self.symbolic_result = self._load_const(inst)
  5346. self.instruction_pointer = None
  5347. raise ReturnValueOp
  5348. def get_globals_source_and_value(
  5349. self, name: str
  5350. ) -> tuple[Any, VariableTracker, Source]:
  5351. # NamedTuple's `__new__` has a fake global scope that's not an actual
  5352. # module. TODO generalize the check for other non-importable cases.
  5353. # https://github.com/python/cpython/blob/8421b03b16a4852a527256cb7cdce2ab2d318548/Lib/collections/__init__.py#L441-L447
  5354. if "__name__" in self.f_globals and not self.f_globals["__name__"].startswith(
  5355. "namedtuple_"
  5356. ):
  5357. module_name = self.f_globals["__name__"]
  5358. module_source = self.import_source(module_name)
  5359. if "torch_package" in module_name:
  5360. fglobals_value = (
  5361. torch.package.package_importer._package_imported_modules[
  5362. module_name
  5363. ]
  5364. ) # type: ignore[assignment]
  5365. else:
  5366. fglobals_value = _import_module(module_name)
  5367. # Dont use lazy vt because we will do a setattr afterwards
  5368. # TODO: fix InstructionTranslator -> InstructionTranslatorBase
  5369. # pyrefly: ignore[bad-argument-type]
  5370. fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
  5371. global_source = AttrSource(module_source, name)
  5372. else:
  5373. globals_name = self.output.install_global_by_id(
  5374. "___unnamed_scope", self.f_globals
  5375. )
  5376. globals_source = GlobalSource(globals_name)
  5377. fglobals_value = self.f_globals # type: ignore[assignment]
  5378. # Dont use lazy vt because we will do a setattr afterwards
  5379. # pyrefly: ignore[bad-argument-type]
  5380. fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
  5381. global_source = DictGetItemSource(globals_source, name) # type: ignore[assignment]
  5382. if is_stdlib(fglobals_value):
  5383. # Users don't inplace mutate a stdlib attribute (like inspect,
  5384. # collections), skip guards that originate from the stdlib modules.
  5385. global_source = SkipGuardSource(global_source) # type: ignore[assignment]
  5386. return fglobals_value, fglobals_vt, global_source
  5387. def _load_global(self, inst: Instruction) -> None:
  5388. name = inst.argval
  5389. if name not in self.f_globals:
  5390. return self.load_builtin(inst)
  5391. if self.output.global_scope is self.f_globals:
  5392. # If the global scope matches that of the root frame, use handler in
  5393. # root frame instruction translator, to enforce consistency.
  5394. super()._load_global(inst)
  5395. else:
  5396. _, fglobals_vt, global_source = self.get_globals_source_and_value(name)
  5397. if self.output.side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
  5398. self.push(self.output.side_effects.load_attr(fglobals_vt, name))
  5399. else:
  5400. value = self.f_globals[name]
  5401. self.push(VariableTracker.build(self, value, global_source))
  5402. def STORE_GLOBAL(self, inst: Instruction) -> None:
  5403. if self.output.global_scope is self.f_globals:
  5404. # If the global scope matches that of the root frame, use handler in
  5405. # root frame instruction translator, to enforce consistency.
  5406. super().STORE_GLOBAL(inst)
  5407. else:
  5408. value = self.pop()
  5409. if isinstance(value, RemovableHandleVariable):
  5410. unimplemented(
  5411. gb_type="Storing Tensor hook handle in globals (inline call)",
  5412. context=inst.argval,
  5413. explanation="This is not supported.",
  5414. hints=[],
  5415. )
  5416. name = inst.argval
  5417. _fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
  5418. self.output.side_effects.store_attr(fglobals_vt, name, value)
  5419. class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
  5420. generated_items: list[VariableTracker]
  5421. # Flag whether or not the InlineGenerator should consume the entire iterator
  5422. def __init__(self, *args: Any, **kwargs: Any) -> None:
  5423. super().__init__(*args, **kwargs)
  5424. self.generated_items = []
  5425. self.generator_exhausted = False
  5426. self.is_generator_from_ctx_manager = False
  5427. def inline_call_(self) -> VariableTracker:
  5428. with profile_inline_call(self.output, self.f_code, lambda: self.inline_depth):
  5429. return super().inline_call_()
  5430. def should_compile_partial_graph(self) -> bool:
  5431. # resuming on graph break on inlined generator not supported
  5432. return False
  5433. def YIELD_VALUE(self, inst: Instruction) -> None:
  5434. top = self.pop()
  5435. self.generated_items.append(top)
  5436. if len(self.generated_items) > MAX_ITERATOR_LIMIT:
  5437. raise exc.InfiniteGeneratorError
  5438. self.push(CONSTANT_VARIABLE_NONE)
  5439. if (
  5440. config.enable_faithful_generator_behavior
  5441. or self.is_generator_from_ctx_manager
  5442. ):
  5443. self.symbolic_result = top
  5444. # Stop tracing
  5445. raise YieldValueOp
  5446. def GET_YIELD_FROM_ITER(self, inst: Instruction) -> None:
  5447. tos = self.stack[-1]
  5448. if not isinstance(tos, ListIteratorVariable):
  5449. self.pop()
  5450. res = BuiltinVariable(iter).call_function(self, [tos], {}) # type: ignore[arg-type]
  5451. self.push(res)
  5452. def RETURN_VALUE(self, inst: Instruction) -> None:
  5453. self.generator_exhausted = True
  5454. return super().RETURN_VALUE(inst)
  5455. def RETURN_CONST(self, inst: Instruction) -> None:
  5456. self.generator_exhausted = True
  5457. return super().RETURN_CONST(inst)
  5458. def YIELD_FROM(self, inst: Instruction) -> None:
  5459. assert len(self.stack) >= 2
  5460. val = self.pop()
  5461. tos = self.stack[-1]
  5462. if not val.is_constant_none():
  5463. # invoke send
  5464. # Unreachable code - if you hit this, you are implementing generator support and have
  5465. # lifted the `unimplemented("generator")` in frame conversion. This codepath handles
  5466. # subgenerator and lines up with this line in Python 3.10
  5467. # https://github.com/python/cpython/blob/3.10/Python/ceval.c#L2599
  5468. unimplemented(
  5469. gb_type="Unreachable sub-generator code",
  5470. context="",
  5471. explanation="Should only be encountered while implementing generator support.",
  5472. hints=[],
  5473. )
  5474. try:
  5475. val = tos.next_variable(self)
  5476. except (StopIteration, exc.ObservedUserStopIteration) as ex:
  5477. if isinstance(ex, exc.ObservedUserStopIteration):
  5478. exc.handle_observed_exception(self)
  5479. # The iterator is exhausted. Stop the loop and return.
  5480. self.pop()
  5481. self.push(ConstantVariable.create(ex.value))
  5482. else:
  5483. # Repeat the YIELD_FROM instruction in the next eval loop
  5484. assert (
  5485. isinstance(self.instruction_pointer, int)
  5486. and self.instruction_pointer > 0
  5487. )
  5488. self.instruction_pointer -= 1
  5489. self.push(val)
  5490. # Add the value to yield into generated_items and replace the top of the stack with None
  5491. self.YIELD_VALUE(inst)
  5492. def SEND(self, inst: Instruction) -> None:
  5493. assert len(self.stack) >= 2
  5494. val = self.pop()
  5495. tos = self.stack[-1]
  5496. if isinstance(tos, (IteratorVariable, LocalGeneratorObjectVariable)) or (
  5497. isinstance(tos, UserDefinedObjectVariable)
  5498. and isinstance(tos.value, collections.abc.Iterator)
  5499. ):
  5500. if val.is_constant_none():
  5501. try:
  5502. val = tos.next_variable(self) # type: ignore[arg-type]
  5503. except (StopIteration, exc.ObservedUserStopIteration) as ex:
  5504. # To implement SEND, we have to look at the implementation
  5505. # when the iterator returns StopIteration. This translates to this code
  5506. # 3.11: https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2613-L2619
  5507. # 3.12: https://github.com/python/cpython/blob/3.12/Python/bytecodes.c#L863-L866
  5508. # The implementation is different in 3.11 and 3.12. In 3.12, we rely
  5509. # on END_SEND to clean up. In 3.11, SEND does the cleanup as well.
  5510. if sys.version_info < (3, 12):
  5511. self.pop() # Python 3.12 uses new opcode END_SEND
  5512. self.push(ConstantVariable.create(ex.value))
  5513. self.jump(inst)
  5514. else:
  5515. self.push(val)
  5516. else:
  5517. # invoke send
  5518. # Unreachable code - if you hit this, you are implementing generator support and have
  5519. # lifted the `unimplemented("generator")` in frame conversion. This codepath handles
  5520. # subgenerator and lines up with this line in Python 3.11
  5521. # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
  5522. unimplemented(
  5523. gb_type="Unreachable sub-generator code",
  5524. context="",
  5525. explanation="Should only be encountered while implementing generator support.",
  5526. hints=[],
  5527. )
  5528. else:
  5529. unimplemented(
  5530. gb_type="SEND with bad type",
  5531. context=f"TOS type: {typestr(tos)}",
  5532. explanation=f"Attempted to SEND with unsupported type {typestr(tos)}.",
  5533. hints=[],
  5534. )