ir.py 351 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927
79278927992809281928292839284928592869287928892899290929192929293929492959296929792989299930093019302930393049305930693079308930993109311931293139314931593169317931893199320932193229323932493259326932793289329933093319332933393349335933693379338933993409341934293439344934593469347934893499350935193529353935493559356935793589359936093619362936393649365936693679368936993709371937293739374937593769377937893799380938193829383938493859386938793889389939093919392939393949395939693979398939994009401940294039404940594069407940894099410941194129413941494159416941794189419942094219422942394249425942694279428942994309431943294339434943594369437943894399440944194429443944494459446944794489449945094519452945394549455945694579458945994609461946294639464946594669467946894699470947194729473947494759476947794789479948094819482948394849485948694879488948994909491949294939494949594969497949894999500950195029503950495059506950795089509951095119512951395149515951695179518951995209521952295239524952595269527952895299530953195329533953495359536953795389539954095419542954395449545954695479548954995509551955295539554955595569557955895599560956195629563956495659566956795689569957095719572957395749575957695779578957995809581958295839584958595869587958895899590959195929593959495959596959795989599960096019602960396049605960696079608960996109611961296139614961596169617961896199620962196229623962496259626962796289629963096319632963396349635963696379638963996409641964296439644964596469647964896499650965196529653965496559656965796589659966096619662966396649665966696679668966996709671967296739674967596769677967896799680968196829683968496859686968796889689969096919692969396949695969696979698969997009701970297039704970597069707970897099710971197129713971497159716971797189719972097219722972397249725972697279728972997309731973297339734973597369737973897399740974197429743974497459746974797489749975097519752975397549755975697579758975997609761976297639764976597669767976897699770977197729773977497759776977
7977897799780978197829783978497859786978797889789979097919792979397949795979697979798979998009801980298039804980598069807980898099810981198129813981498159816981798189819982098219822982398249825982698279828982998309831983298339834983598369837983898399840984198429843984498459846984798489849985098519852985398549855985698579858985998609861986298639864986598669867986898699870987198729873987498759876987798789879988098819882988398849885988698879888988998909891989298939894989598969897989898999900990199029903990499059906990799089909991099119912991399149915991699179918991999209921992299239924992599269927992899299930993199329933993499359936993799389939994099419942994399449945994699479948994999509951995299539954995599569957995899599960996199629963996499659966
  1. from __future__ import annotations
  2. import contextlib
  3. import copy
  4. import dataclasses
  5. import functools
  6. import itertools
  7. import logging
  8. import operator
  9. import textwrap
  10. import traceback
  11. from collections.abc import Callable, Generator, Iterable, Iterator, Sequence
  12. from contextlib import AbstractContextManager, nullcontext
  13. from enum import Enum
  14. from functools import partial
  15. from typing import (
  16. Any,
  17. cast,
  18. ClassVar,
  19. Literal,
  20. Optional,
  21. overload,
  22. SupportsFloat,
  23. SupportsInt,
  24. TYPE_CHECKING,
  25. TypeAlias,
  26. TypeVar,
  27. Union,
  28. )
  29. from typing_extensions import assert_never, Never, override, ParamSpec, Self, TypeIs
  30. from unittest.mock import patch
  31. import sympy
  32. from sympy import Expr, Integer, Symbol
  33. import torch._export.serde.schema as export_schema
  34. import torch._library.utils as library_utils
  35. import torch._logging
  36. import torch.fx
  37. import torch.utils._pytree as pytree
  38. from torch._dynamo.utils import identity
  39. from torch._export.serde.serialize import GraphModuleSerializer
  40. from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
  41. from torch._inductor import metrics
  42. from torch._inductor.utils import get_free_symbols
  43. from torch._library.opaque_object import is_opaque_type
  44. from torch._prims_common import (
  45. compute_required_storage_length,
  46. is_boolean_dtype,
  47. is_float_dtype,
  48. make_channels_last_strides_for,
  49. StrideType,
  50. )
  51. from torch.fx.experimental.symbolic_shapes import (
  52. _remove_effect_token_unbacked_bindings,
  53. compute_unbacked_bindings,
  54. free_symbols,
  55. free_unbacked_symbols,
  56. IterateExprs,
  57. rebind_unbacked,
  58. resolve_unbacked_bindings,
  59. ShapeEnv,
  60. SymTypes,
  61. )
  62. from torch.fx.node import Node
  63. from torch.utils._ordered_set import OrderedSet
  64. from torch.utils._python_dispatch import _disable_current_modes
  65. from torch.utils._sympy.functions import CleanDiv, FloorDiv, Mod, ModularIndexing
  66. from torch.utils._sympy.symbol import SymT
  67. from . import config, dependencies
  68. from .codegen.common import (
  69. BackendFeature,
  70. CodegenSymbol,
  71. get_scheduling_for_device,
  72. index_prevent_reordering,
  73. Kernel,
  74. )
  75. from .dependencies import (
  76. Dep,
  77. extract_free_symbols,
  78. extract_input_node_reduction_ranges,
  79. extract_read_writes,
  80. var_builder,
  81. )
  82. from .loop_body import LoopBody
  83. from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
  84. from .runtime.benchmarking import benchmarker
  85. from .runtime.hints import DeviceProperties, ReductionHint
  86. from .utils import (
  87. argsort,
  88. argsort_sym,
  89. cache_on_self,
  90. cache_on_self_and_args,
  91. ceildiv,
  92. convert_shape_to_inductor,
  93. convert_shape_to_symint,
  94. developer_warning,
  95. do_bench_using_profiling,
  96. dtype_from_size,
  97. get_dtype_size,
  98. get_kernel_metadata,
  99. GPU_ALIGN_BYTES,
  100. ir_dataclass,
  101. is_dynamic,
  102. is_gpu,
  103. sympy_dot,
  104. sympy_index_symbol,
  105. sympy_index_symbol_with_prefix,
  106. sympy_product,
  107. sympy_subs,
  108. tensor_is_aligned,
  109. )
  110. from .virtualized import ops, OpsValue, V
  111. if TYPE_CHECKING:
  112. from torch._library.fake_class_registry import FakeScriptObject
  113. from torch.fx.experimental.symbolic_shapes import SympyBoolean
  114. from torch.fx.node import Argument
  115. from .codegen.cutlass.template import CUTLASSTemplate
  116. from .codegen.wrapper import PythonWrapperCodegen
  117. from .graph import GraphLowering
  118. from .utils import IndentedBuffer
  119. else:
  120. CUTLASSTemplate: TypeAlias = object
  121. try:
  122. import triton
  123. triton_version = triton.__version__
  124. has_triton = True
  125. except ImportError:
  126. triton_version = None
  127. has_triton = False
# Generic type variables; _T/_U/_V parameterize the reindexing helpers below.
_P = ParamSpec("_P")
_T = TypeVar("_T")
_U = TypeVar("_U")
_V = TypeVar("_V")
# Values that may be plain Python numbers or sympy expressions.
_IntLike: TypeAlias = Union[int, Expr]
_NumLike: TypeAlias = Union[int, float, Expr]
# Either an aten-style OpOverload or a HigherOrderOperator.
_OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
log = logging.getLogger(__name__)
# textwrap.indent with a fixed prefix; used by IRNode.str_helper for nested reprs.
indent = functools.partial(textwrap.indent, prefix=" ")
aten = torch.ops.aten
  138. """ [Note: Inductor IR]
  139. Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
  140. lowering is registered to a particular aten operator, and expects inputs that
  141. correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
  142. expect Inductor TensorBox inputs.
  143. TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
  144. storage, and sometimes views of another Tensor's storage. Mutating tensor operations
  145. (such as add_()) affect the underlying storage and any associated views. Other operations
  146. (such as .t_()) update metadata about the current view but don't modify the underlying storage.
  147. To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
  148. TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
  149. output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
  150. reference View IR or directly reference StorageBox IRs.
  151. Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
  152. may take an existing TensorBox and point it to a new underlying View IR.
  153. Tensors that directly own storage are represented as a chain of:
  154. TensorBox -> StorageBox -> Buffer
  155. where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
  156. If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
  157. (leaving the old buffer unmodified and functionalizing the operation).
  158. Tensors backed by views add one more indirection to the IR.
  159. TensorBox -> View -> StorageBox -> Buffer
  160. In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
  161. Computation is represented by Operation nodes, with each operation producing 1
or more output Buffers. In the case of mutations, these will be new Buffers that have the
mutated buffer listed in their get_mutation_names().
  164. It is also possible to have an InputBuffer for which there is no corresponding Operation,
  165. e.g. it may be a graph input or compile time constant.
  166. """
# Shape of a lowering result as checked by validate_ir (which additionally
# accepts None): a single node/scalar, a dict of TensorBoxes, or a sequence of
# those where individual entries may be None.
_NodeOrNodes: TypeAlias = Union[
    int,
    "TensorBox",
    dict[str, "TensorBox"],
    "Symbol",
    "IRNode",
    Sequence[
        Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
    ],
]
  177. def _is_static(x: object) -> TypeIs[Union[int, Integer]]:
  178. return isinstance(x, (int, Integer))
@dataclasses.dataclass(frozen=True)
class GraphPartitionSignature:
    """Immutable description of a graph partition's inputs, outputs and constants."""

    # symbol inputs that are necessary for codegen
    symbol_inputs: OrderedSet[sympy.Symbol]
    # mapping from partition input name to IRNode or Expr. Need the name str since
    # we cannot get name from Expr.
    input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]]
    # IR nodes produced as the partition's outputs
    output_nodes: list[IRNode]
    # mapping from partition input name to a boolean for whether deallocating it
    # in the partition function
    input_deallocation: dict[str, bool]
    # NOTE(review): presumably marks partitions that must not be CUDA-graphed; confirm at use sites
    skip_cudagraph: bool
    # name of constants read/written by the graph partition
    constant_names: list[str]
  193. def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
  194. def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None:
  195. # Could expand this to check deeper properties
  196. # (e.g. TensorBox points to View or StorageBox)
  197. if nodes is None:
  198. pass
  199. elif isinstance(nodes, (list, tuple)):
  200. for node in nodes:
  201. _check_tensorbox(node)
  202. elif isinstance(nodes, dict):
  203. for node in nodes.values():
  204. _check_tensorbox(node)
  205. else:
  206. assert isinstance(
  207. nodes,
  208. (
  209. ExpandView,
  210. DynamicScalar,
  211. AssertScalar,
  212. TensorBox,
  213. sympy.logic.boolalg.Boolean,
  214. Expr,
  215. int,
  216. EffectfulKernel,
  217. ShapeAsConstantBuffer,
  218. ),
  219. ), (
  220. f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
  221. )
  222. # Be picky about the accepted data structure (don't use pytree here)
  223. _check_tensorbox(node_or_nodes)
  224. def ops_wrapper(name: str) -> Callable[..., OpsValue]:
  225. assert isinstance(name, str), type(name)
  226. def fn(*args: object, **kwargs: object) -> OpsValue:
  227. return getattr(ops, name)(*args, **kwargs)
  228. return fn
  229. def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]:
  230. inv_order = dict(zip(order, range(len(order))))
  231. def reindex(index: Sequence[_T]) -> Sequence[_T]:
  232. assert len(index) == len(inv_order)
  233. return [index[inv_order[i]] for i in range(len(index))]
  234. return reindex
  235. def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]:
  236. def reindex(index: Sequence[_T]) -> Sequence[_T]:
  237. assert len(index) == len(order)
  238. return [index[order[i]] for i in range(len(index))]
  239. return reindex
  240. def fuse_reindexing(
  241. reindex1: Callable[[Sequence[_U]], Sequence[_V]],
  242. reindex2: Callable[[Sequence[_T]], Sequence[_U]],
  243. ) -> Callable[[Sequence[_T]], Sequence[_V]]:
  244. def reindex(index: Sequence[_T]) -> Sequence[_V]:
  245. return reindex1(reindex2(index))
  246. return reindex
# Stride orders (as produced by get_stride_order) for channels-last tensors whose
# logical dim order is N, C, H, W (4-D) or N, C, D, H, W (5-D); the 4-D value
# matches the channels-last example in stride_order2fill_order's docstring.
NHWC_STRIDE_ORDER = [3, 0, 2, 1]
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
  249. def get_fill_order(
  250. seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None
  251. ) -> Sequence[int]:
  252. """
  253. Convert strides to fill order (argsort)
  254. """
  255. if shape_env is None or all(isinstance(s, (int, sympy.Integer)) for s in seq):
  256. sorted_idx: Sequence[int] = argsort(seq)
  257. else:
  258. # argsort_sym handles unbacked symints (with the help of the shape_env)
  259. sorted_idx = argsort_sym(shape_env, seq)
  260. return sorted_idx
  261. def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]:
  262. """
  263. Convert stride order to fill order
  264. For channel last format,
  265. stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
  266. """
  267. lookup = {pos: idx for idx, pos in enumerate(order)}
  268. fill_order = [lookup[i] for i in range(len(order))]
  269. return fill_order
  270. def get_stride_order(
  271. seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None
  272. ) -> Sequence[int]:
  273. """
  274. Convert strides to stride order
  275. """
  276. sorted_idx: Sequence[int] = get_fill_order(seq, shape_env)
  277. out = [0 for _ in range(len(seq))]
  278. for i, elem in enumerate(sorted_idx):
  279. out[elem] = i
  280. return out
  281. @overload
  282. def ir_node_to_tensor(x: None, guard_shape: bool = True) -> None: ...
  283. @overload
  284. def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...
  285. def ir_node_to_tensor(
  286. x: Optional[IRNode], guard_shape: bool = True
  287. ) -> Optional[torch.Tensor]:
  288. if x is None:
  289. return None
  290. shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]]
  291. if not guard_shape:
  292. shape_fn = V.graph.sizevars.size_hint
  293. else:
  294. shape_fn = identity
  295. size = [shape_fn(s) for s in x.get_size()]
  296. stride: StrideType
  297. if is_storage_and_layout(x):
  298. stride = [shape_fn(s) for s in x.get_layout().stride]
  299. else:
  300. stride = FlexibleLayout.contiguous_strides(size)
  301. dtype = x.get_dtype()
  302. device = x.get_device()
  303. size = convert_shape_to_symint(size)
  304. # pyrefly: ignore [bad-assignment]
  305. stride = convert_shape_to_symint(stride)
  306. with V.graph.sizevars.shape_env.suppress_guards():
  307. t = torch.empty_strided(
  308. size=size, stride=stride, dtype=dtype, device=device
  309. ).zero_()
  310. return t
  311. def may_convert_to_optional(
  312. value: Optional[Sequence[_T]],
  313. ) -> Optional[Sequence[Optional[_T]]]:
  314. if isinstance(value, list) and not value:
  315. # [None] makes sure the cpp wrapper codegen will generate something like
  316. # {std::nullopt} instead of {}
  317. return [None]
  318. return value
  319. def get_device_type(
  320. x: Union[IRNode, OutputSpec, torch.device, None, str],
  321. ) -> Optional[str]:
  322. if isinstance(x, str) or x is None:
  323. return x
  324. elif isinstance(x, torch.device):
  325. return x.type
  326. elif isinstance(x, (IRNode, OutputSpec)):
  327. return get_device_type(x.get_device())
  328. # pyrefly: ignore [bad-argument-type]
  329. assert_never(f"get_device_type({x}: {type(x).__name__})")
  330. def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool:
  331. device = get_device_type(x)
  332. # Special case cpu and cuda as using the method below
  333. # to determine if the scheduler is a triton scheduler subclass
  334. # requires instantiating a scheduler for them
  335. if device in ["cpu", "cuda"]:
  336. if getattr(config, f"{device}_backend") == "triton":
  337. return True
  338. return False
  339. if (
  340. device is None
  341. or (device_scheduling := get_scheduling_for_device(device)) is None
  342. ):
  343. return False
  344. from .codegen.triton import TritonScheduling
  345. assert isinstance(device_scheduling, type), type(device_scheduling)
  346. return issubclass(device_scheduling, TritonScheduling)
  347. def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool:
  348. return get_device_type(x) == "cpu"
  349. def is_aligned_realized_tensor(x: Union[Buffer, TensorBox], alignment: int) -> bool:
  350. if (
  351. not isinstance(x, IRNode)
  352. or x.maybe_get_stride() is None
  353. or free_unbacked_symbols(x.get_stride())
  354. or free_unbacked_symbols(x.get_size())
  355. ):
  356. return False
  357. aligned_strides = sympy.And(
  358. *(sympy.Eq(Mod(s, alignment), 0) for s in x.get_stride()[:-1])
  359. )
  360. aligned_last_dim = sympy.Or(
  361. sympy.Eq(x.get_stride()[-1], 1), sympy.Le(x.get_size()[-1], 1)
  362. )
  363. is_aligned = sympy.And(aligned_strides, aligned_last_dim)
  364. # Make sure to guard to recompile when necessary.
  365. return V.graph.sizevars.guard_or_false(is_aligned)
  366. def significant_strides_equal(
  367. strides1: Sequence[_IntLike],
  368. strides2: Sequence[_IntLike],
  369. shape: Sequence[_IntLike],
  370. ) -> bool:
  371. """
  372. Returns true if the strides are equal, ignoring dimensions of size 1 .
  373. """
  374. assert len(shape) == len(strides1) and len(strides1) == len(strides2)
  375. for dim, s1, s2 in zip(shape, strides1, strides2):
  376. if V.graph.sizevars.statically_known_leq(dim, 1):
  377. continue
  378. if not V.graph.sizevars.statically_known_equals(
  379. s1, s2
  380. ) and V.graph.sizevars.symbolic_hint(s1) != V.graph.sizevars.symbolic_hint(s2):
  381. return False
  382. return True
  383. def try_match_insignificant_strides(
  384. tensor: IRNode,
  385. strides: Sequence[Union[int, torch.SymInt]],
  386. ) -> IRNode:
  387. """
  388. Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
  389. dimensions - size 0 or 1 - will be updated.
  390. If there are real stride differences (NHWC vs NCHW), or the tensor is not realized, then the input will be returned
  391. """
  392. if not is_storage_and_layout(tensor):
  393. return tensor
  394. if all(
  395. V.graph.sizevars.statically_known_equals(s1, s2)
  396. for s1, s2 in zip(strides, tensor.get_stride())
  397. ):
  398. return tensor
  399. if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()):
  400. return tensor
  401. storage, old_layout = as_storage_and_layout(tensor)
  402. new_stride = [*old_layout.stride]
  403. for i, s in enumerate(tensor.get_size()):
  404. if V.graph.sizevars.statically_known_leq(s, 1):
  405. new_stride[i] = strides[i]
  406. new_layout = FixedLayout(
  407. old_layout.device,
  408. old_layout.dtype,
  409. old_layout.size,
  410. new_stride,
  411. old_layout.offset,
  412. old_layout.is_pinned,
  413. )
  414. return TensorBox(ReinterpretView(data=storage, layout=new_layout))
  415. def gm_original_output_strides(gm: torch.fx.GraphModule) -> None:
  416. output_node = gm.graph.find_nodes(op="output")[0]
  417. output_node.meta["user_visible_output_idxs"] = [
  418. idx for idx, _ in enumerate(output_node.args)
  419. ]
  420. from torch._inductor.compile_fx import record_original_output_strides
  421. record_original_output_strides(gm)
  422. def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]:
  423. sym_vars: OrderedSet[Expr] = OrderedSet()
  424. for inp in inputs:
  425. sym_vars |= get_free_symbols(inp.get_size(), unbacked_only=False)
  426. sym_vars |= get_free_symbols(inp.get_stride(), unbacked_only=False)
  427. return list(sym_vars)
  428. def try_get_name(x):
  429. if isinstance(x, TensorBox):
  430. x = x.data
  431. if isinstance(x, BaseView):
  432. x = x.unwrap_view()
  433. if isinstance(x, StorageBox):
  434. x = x.data
  435. return x.get_name() if isinstance(x, Buffer) else None
  436. class IRNode:
  437. """Base class for all intermediate representation (IR) nodes in TorchInductor.
  438. Note:
  439. This is an abstract base class. Most methods raise NotImplementedError
  440. and must be overridden by concrete subclasses.
  441. """
  442. _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
  443. # NB: These are kinda weird,
  444. origins: OrderedSet[Any] = dataclasses.field(init=False)
  445. # traces back to where the IRNode is created in Inductor
  446. traceback: Optional[list[str]] = dataclasses.field(init=False)
  447. origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
  448. # Annotations dict for storing metadata (e.g., KernelTemplateChoice)
  449. annotations: dict[str, Any] = dataclasses.field(init=False)
  450. @staticmethod
  451. @contextlib.contextmanager
  452. def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]:
  453. old = IRNode._current_origins
  454. IRNode._current_origins = old | origins
  455. try:
  456. yield
  457. finally:
  458. IRNode._current_origins = old
  459. @staticmethod
  460. def is_realized_node(node: IRNode) -> bool:
  461. return isinstance(
  462. node,
  463. (
  464. ComputedBuffer,
  465. InputsKernel,
  466. InputBuffer,
  467. ReinterpretView,
  468. TemplateBuffer,
  469. ),
  470. )
  471. def _post_init_setattr(self, attr: str, value: Any) -> None:
  472. # Intended for use in __post_init__ for enforcing an invariant on a dataclass
  473. # If you must, can also be used for setting provenance info
  474. # We would like to try and minimize these usages though
  475. object.__setattr__(self, attr, value)
  476. def __post_init__(self) -> None:
  477. origins = OrderedSet(self._current_origins)
  478. self._post_init_setattr("origins", origins)
  479. self._post_init_setattr(
  480. "traceback", traceback.format_stack() if config.debug_ir_traceback else None
  481. )
  482. self._post_init_setattr("origin_node", None)
  483. # Annotations dict for storing metadata (e.g., KernelTemplateChoice)
  484. self._post_init_setattr("annotations", {})
  485. def get_read_names(self) -> OrderedSet[str]:
  486. return OrderedSet(dep.name for dep in self.get_reads())
  487. def get_traceback(self) -> Optional[list[str]]:
  488. return self.traceback
  489. def get_origin_node(self) -> Optional[torch.fx.Node]:
  490. return self.origin_node
  491. def get_defining_op(self) -> Optional[Operation]:
  492. return None
  493. def get_stack_traces(self) -> OrderedSet[str]:
  494. # Return stack traces to user model code
  495. # A single IRNode could correspond to multiple lines of code
  496. stack_traces: OrderedSet[str] = OrderedSet()
  497. origins = self.origins
  498. if isinstance(self, ExternKernel):
  499. origin_node = self.get_origin_node()
  500. if self.origin_node:
  501. origins = OrderedSet([origin_node])
  502. for node in origins:
  503. if hasattr(node, "stack_trace") and node.stack_trace:
  504. # nodes in the backward graph don't have mapping to pre_grad_graph
  505. stack_traces.add(node.stack_trace)
  506. else:
  507. pre_grad_nodes = (
  508. torch._inductor.debug._inductor_post_to_pre_grad_nodes.get(
  509. "postToPre",
  510. {},
  511. # pyrefly: ignore [missing-attribute]
  512. ).get(node.name, [])
  513. )
  514. if not isinstance(pre_grad_nodes, list):
  515. continue
  516. for node_name in pre_grad_nodes:
  517. stack_trace = (
  518. torch._inductor.debug._inductor_pre_grad_node_stack_trace.get(
  519. node_name, None
  520. )
  521. )
  522. if stack_trace:
  523. stack_traces.add(stack_trace)
  524. return stack_traces
  525. def common_repr(self, shorten: bool = True) -> Sequence[str]:
  526. origins = f"origins={getattr(self, 'origins', '')}"
  527. if shorten and len(origins) > 64:
  528. # this can get *very* long
  529. origins = f"{origins[:61]}..."
  530. if not self.get_stack_traces():
  531. return [origins]
  532. stack_trace_str = []
  533. for stack_trace in self.get_stack_traces():
  534. stack_trace_str.append("stack_traces = {")
  535. stack_trace_str += stack_trace.split("\n")
  536. stack_trace_str.append("}")
  537. return [origins] + stack_trace_str
  538. def str_helper(
  539. self, lines: Sequence[object], shorten: bool = True, multiline: bool = True
  540. ) -> str:
  541. lines = list(lines) + list(self.common_repr(shorten))
  542. lines = list(map(str, lines))
  543. if multiline:
  544. # pyrefly: ignore [no-matching-overload]
  545. new_lines = indent(",\n".join(lines))
  546. return f"{type(self).__name__}(\n{new_lines}\n)"
  547. else:
  548. return f"{type(self).__name__}({lines})"
  549. def get_dtype(self) -> torch.dtype:
  550. return self.dtype
  551. def maybe_get_dtype(self) -> Optional[torch.dtype]:
  552. try:
  553. return self.get_dtype()
  554. except NotImplementedError:
  555. return None
  556. def get_layout(self) -> Layout:
  557. raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")
  558. def maybe_get_layout(self) -> Optional[Layout]:
  559. try:
  560. return self.get_layout()
  561. except NotImplementedError:
  562. return None
  563. def get_output_spec(self) -> OutputSpec:
  564. return self.get_layout()
  565. def maybe_get_output_spec(self) -> Optional[OutputSpec]:
  566. try:
  567. return self.get_output_spec()
  568. except NotImplementedError:
  569. return None
  570. def has_tensor_output(self) -> bool:
  571. """True for single tensor output (excludes MultiOutput)"""
  572. return isinstance(self.maybe_get_output_spec(), Layout)
  573. def get_size(self) -> Sequence[Expr]:
  574. raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")
  575. def maybe_get_size(self) -> Optional[Sequence[_IntLike]]:
  576. try:
  577. return self.get_size()
  578. except NotImplementedError:
  579. return None
  580. @property
  581. def shape(self) -> Union[_IntLike, sympy.Rel, Sequence[_IntLike]]:
  582. return self.get_size()
  583. def get_numel(self) -> Expr:
  584. return sympy_product(self.get_size())
  585. def is_zero_elements(self) -> bool:
  586. return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
  587. def realize(self) -> Optional[str]:
  588. """
  589. If the IRNode refers to data which has not been materialized (e.g.,
  590. it is a Pointwise/Reduction that could potentially have more
  591. compute fused into it), realize the IRNode into physical memory,
  592. ending the possibility of fusing into it, but allowing, e.g., multiple
  593. users to access the data without having to recompute.
  594. Check StorageBox.realize for a particularly notable implementation.
  595. TODO(ezyang): I think, in principle, every IRNode should have an
  596. implementation of this, and most of the time no-op is OK, but you
  597. really do have to audit each IRNode for this, so for now, raise
  598. an error if it's not implemented. Note that some code in graph.py
  599. will catch this thrown error and suppress it with a warning.
  600. """
  601. raise NotImplementedError(f"realize NYI on {type(self)}")
  602. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  603. raise NotImplementedError(f"codegen_reference NYI on {type(self)}")
  604. def get_device(self) -> Optional[torch.device]:
  605. return None
  606. def get_device_or_error(self) -> torch.device:
  607. device = self.get_device()
  608. assert device is not None
  609. return device
  610. def has_exceeded_max_reads(self) -> bool:
  611. return False
  612. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  613. raise NotImplementedError(type(self).__name__)
  614. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  615. raise NotImplementedError(type(self).__name__)
  616. def get_stride(self) -> Sequence[_IntLike]:
  617. raise NotImplementedError(type(self).__name__)
  618. def maybe_get_stride(self) -> Optional[Sequence[_IntLike]]:
  619. try:
  620. return self.get_stride()
  621. except NotImplementedError:
  622. return None
  623. def get_name(self) -> str:
  624. raise NotImplementedError(type(self).__name__)
  625. def maybe_get_name(self) -> Optional[str]:
  626. try:
  627. return self.get_name()
  628. except NotImplementedError:
  629. return None
  630. def is_input_buffer(self) -> bool:
  631. try:
  632. return self.get_name() in V.graph.graph_inputs
  633. except NotImplementedError:
  634. return False
  635. def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
  636. return False
  637. def mark_reuse(self, users: int) -> None:
  638. pass
  639. def realize_hint(self) -> None:
  640. pass
  641. def unwrap_view(self) -> IRNode:
  642. raise NotImplementedError(type(self).__name__)
  643. def freeze_layout(self) -> None:
  644. raise NotImplementedError(type(self).__name__)
  645. def freeze_layout_with_stride_order(
  646. self, order: Sequence[int], allow_padding: bool = False
  647. ) -> None:
  648. raise NotImplementedError(type(self).__name__)
  649. def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
  650. raise NotImplementedError(type(self).__name__)
  651. def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None:
  652. raise NotImplementedError(type(self).__name__)
  653. def freeze_layout_with_exact_strides(
  654. self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  655. ) -> None:
  656. raise NotImplementedError(type(self).__name__)
  657. def get_read_writes(self) -> dependencies.ReadWrites:
  658. raise NotImplementedError(type(self).__name__)
  659. def get_reads(self) -> OrderedSet[Dep]:
  660. return self.get_read_writes().reads
  661. def num_reads(self) -> int:
  662. return len(self.get_reads())
  663. def get_storage_numel(self) -> _IntLike:
  664. raise NotImplementedError(type(self).__name__)
  665. def get_free_symbol_uses(
  666. self, unbacked_only: bool = False
  667. ) -> OrderedSet[sympy.Symbol]:
  668. raise NotImplementedError(type(self).__name__)
  669. def get_reduction_type(self) -> Optional[str]:
  670. raise NotImplementedError(type(self).__name__)
  671. def get_reduction_size(self) -> Sequence[Expr]:
  672. raise NotImplementedError(type(self).__name__)
  673. def is_extern(self) -> bool:
  674. return False
  675. def is_no_op(self) -> bool:
  676. return False
  677. def constant_to_device(self, device: torch.device) -> IRNode:
  678. raise NotImplementedError(type(self).__name__)
  679. def get_mutation_names(self) -> Sequence[str]:
  680. raise NotImplementedError(type(self).__name__)
  681. def get_operation_name(self) -> str:
  682. raise NotImplementedError(type(self).__name__)
  683. def get_inputs_that_alias_output(self) -> Sequence[str]:
  684. raise NotImplementedError(type(self).__name__)
    if TYPE_CHECKING:
        # Static-typing-only declaration: tells type checkers that IRNode
        # values expose a ``dtype`` attribute. TYPE_CHECKING is False at
        # runtime, so this block is skipped there and concrete subclasses
        # must supply ``dtype`` themselves (Loops, for example, declares it
        # as a dataclass field).

        @property
        def dtype(self) -> torch.dtype: ...
  688. @ir_dataclass(frozen=False)
  689. class Operation:
  690. def __post_init__(self) -> None:
  691. self.operation_name: Optional[str] = None
  692. def get_device(self) -> Optional[torch.device]:
  693. raise NotImplementedError
  694. def get_origin_node(self) -> Optional[torch.fx.Node]:
  695. assert hasattr(self, "origin_node")
  696. return self.origin_node
  697. def get_origins(self) -> OrderedSet[Any]:
  698. assert hasattr(self, "origins")
  699. return self.origins
  700. def get_operation_name(self) -> str:
  701. assert self.operation_name is not None
  702. return self.operation_name
  703. def is_extern(self) -> bool:
  704. return False
  705. def is_no_op(self) -> bool:
  706. return False
  707. def get_read_writes(self) -> dependencies.ReadWrites:
  708. raise NotImplementedError
  709. def is_user_of(self, name: str) -> bool:
  710. return name in self.get_read_names()
  711. def get_read_names(self) -> OrderedSet[str]:
  712. return OrderedSet(dep.name for dep in self.get_reads())
  713. def get_reads(self) -> OrderedSet[Dep]:
  714. return self.get_read_writes().reads
  715. def get_outputs(self) -> list[Buffer]:
  716. raise NotImplementedError
  717. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  718. return OrderedSet()
  719. def get_free_symbol_uses(
  720. self, unbacked_only: bool = False
  721. ) -> OrderedSet[sympy.Symbol]:
  722. """
  723. When unbacked_only=True:
  724. Returns the unbacked symbols which are required to be in scope in
  725. order to successfully perform codegen for this buffer. For example,
  726. a buffer that corresponds to an extern kernel call that takes i0 as
  727. an argument would return {i0} here. This is used to generate necessary
  728. dependencies that ensure we actually bind i0 in codegen before you
  729. try to use it.
  730. Note that this is NOT transitive; in particular, if this buffer takes
  731. in as input another buffer with dynamic shape (e.g., (i0,)), we will
  732. not report it here, because you will already have a dependency
  733. on that buffer, which will eventually have a dependency on i0 if
  734. necessary.
  735. When unbacked_only=False:
  736. Similar to `unbacked_only=True` but including all free symbols
  737. instead of only free unbacked symbols.
  738. """
  739. return OrderedSet()
  740. def get_workspace_size(self) -> int:
  741. """
  742. Gets extra global memory size needed by this buffer.
  743. Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
  744. """
  745. return 0
  746. @ir_dataclass
  747. class Loops(IRNode):
  748. device: torch.device
  749. dtype: torch.dtype
  750. inner_fn: Callable[..., Any]
  751. ranges: Sequence[_IntLike]
  752. @cache_on_self_and_args("Loops")
  753. def get_free_symbol_uses(
  754. self, unbacked_only: bool = False
  755. ) -> OrderedSet[sympy.Symbol]:
  756. return OrderedSet().union(
  757. *(get_free_symbols(e, unbacked_only) for e in self.ranges),
  758. self.inner_fn_free_symbols(unbacked_only),
  759. )
  760. def _to_str(self, names: Sequence[str]) -> str:
  761. return self.str_helper(
  762. [
  763. f"'{self.device.type}'",
  764. str(self.dtype),
  765. self.inner_fn_str(),
  766. ]
  767. + [f"{name}={getattr(self, name)}" for name in names]
  768. + [f"origin_node={self.origin_node!r}"]
  769. )
  770. def __str__(self) -> str:
  771. return self._to_str(("ranges",))
  772. __repr__ = __str__
  773. def get_device(self) -> Optional[torch.device]:
  774. return self.device
  775. def get_origin_node(self) -> Optional[torch.fx.Node]:
  776. return self.origin_node
  777. def get_size(self) -> Sequence[Expr]:
  778. return self.ranges
  779. def get_pointwise_size(self) -> Sequence[Expr]:
  780. return self.ranges
  781. @classmethod
  782. def create(cls, *args: Any, **kwargs: Any) -> TensorBox:
  783. origin_node = kwargs.pop("origin_node", None)
  784. tb = kwargs.pop("traceback", None)
  785. r = cls(*args, **kwargs)
  786. # Need to explicitly set origin_node here to propagate it down.
  787. # todo(chilli): I think it would be better for IRNode to directly set
  788. # origin_node
  789. r._post_init_setattr("origin_node", origin_node)
  790. r._post_init_setattr("traceback", tb or r.traceback)
  791. return TensorBox.create(r)
  792. @staticmethod
  793. def _index(ranges: Sequence[_IntLike], prefix: SymT = SymT.INDEX) -> Sequence[Expr]:
  794. return [
  795. sympy.S.Zero if s == 1 else sympy_index_symbol_with_prefix(prefix, n)
  796. for n, s in enumerate(ranges)
  797. ]
  798. @cache_on_self
  799. def inner_fn_opcount(self) -> OpCountResult:
  800. opcounter = OpCounterCSE(V.MockHandler())
  801. with (
  802. V.set_ops_handler(opcounter),
  803. patch.object(FlexibleLayout, "allow_indexing", True),
  804. ):
  805. self.inner_fn(*self.inner_fn_args())
  806. return opcounter.getvalue()
  807. def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
  808. return (self._index(self.ranges),)
  809. @cache_on_self
  810. def inner_fn_str(self) -> str:
  811. return V.KernelFormatterHandler.ir_to_string(
  812. self.inner_fn, *self.inner_fn_args()
  813. )
  814. def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
  815. if threshold is None:
  816. threshold = 0
  817. threshold = max(threshold, config.realize_opcount_threshold)
  818. return self.inner_fn_opcount().num_ops > threshold
  819. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  820. index = self._index(self.ranges)
  821. return extract_free_symbols(self.inner_fn, index, unbacked_only=unbacked_only)
  822. def get_reads(self) -> OrderedSet[Dep]:
  823. with patch.object(FlexibleLayout, "allow_indexing", True):
  824. if self.get_reduction_type():
  825. return extract_read_writes(
  826. self.make_loader(),
  827. self.get_size(),
  828. self.get_reduction_size(),
  829. ).reads
  830. else:
  831. return extract_read_writes(
  832. self.make_loader(),
  833. self.get_size(),
  834. ).reads
  835. def get_read_names(self) -> OrderedSet[str]:
  836. return OrderedSet(self.inner_fn_opcount().read_buffers)
  837. def num_reads(self) -> int:
  838. return len(self.inner_fn_opcount().read_buffers)
  839. def get_reduction_size(self) -> Sequence[Expr]:
  840. raise NotImplementedError(
  841. f"get_reduction_size() is not implemented by {type(self)}!"
  842. )
  843. def get_reduction_type(self) -> Optional[str]:
  844. raise NotImplementedError(
  845. f"get_reduction_type() is not implemented by {type(self)}!"
  846. )
  847. def constant_to_device(self, device: torch.device) -> IRNode:
  848. raise NotImplementedError(
  849. f"constant_to_device() is not implemented by {type(self)}!"
  850. )
  851. def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue:
  852. if dtype.is_floating_point:
  853. return ops.constant(float("nan"), dtype)
  854. else:
  855. return ops.constant(0, dtype)
  856. @ir_dataclass
  857. class Pointwise(Loops):
  858. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  859. # Make zero-element loops into a no-op
  860. if self.is_zero_elements():
  861. return partial(nop_loader_fn, dtype=self.dtype)
  862. return self.inner_fn
  863. def __str__(self) -> str:
  864. return self._to_str(("ranges",))
  865. __repr__ = __str__
  866. def get_reduction_size(self) -> Sequence[sympy.Expr]:
  867. return []
  868. def get_reduction_type(self) -> Optional[str]:
  869. return None
  870. def store_output(
  871. self,
  872. output_name: Optional[str],
  873. indexer: Callable[[Sequence[Expr]], Never],
  874. vars: Sequence[Expr],
  875. ) -> None:
  876. loader = self.make_loader()
  877. return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
  878. def constant_to_device(self, device: torch.device) -> IRNode:
  879. """Move this to a given device. Requires that all reads are to constants."""
  880. loader = self.make_loader()
  881. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  882. return Pointwise(
  883. device=device,
  884. dtype=self.dtype,
  885. inner_fn=loader,
  886. ranges=self.ranges,
  887. )
  888. @ir_dataclass
  889. class Scatter(Pointwise):
  890. output_indexer: Callable[[Sequence[Expr]], Expr]
  891. scatter_mode: StoreMode = None
  892. def constant_to_device(self, device: torch.device) -> IRNode:
  893. """Move this to a given device. Requires that all reads are to constants."""
  894. loader = self.make_loader()
  895. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  896. return Scatter(
  897. device=device,
  898. dtype=self.dtype,
  899. inner_fn=loader,
  900. ranges=self.ranges,
  901. output_indexer=self.output_indexer,
  902. scatter_mode=self.scatter_mode,
  903. )
  904. def store_output(
  905. self,
  906. output_name: Optional[str],
  907. indexer: Callable[[Sequence[Expr]], Never],
  908. vars: Sequence[Expr],
  909. ) -> Any:
  910. loader = self.make_loader()
  911. if output_name is None:
  912. output_name = "unnamed"
  913. return ops.store(
  914. output_name,
  915. indexer(self.output_indexer(vars)),
  916. loader(vars),
  917. mode=self.scatter_mode,
  918. )
  919. REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
  920. "any": ops_wrapper("logical_or"),
  921. "max": ops_wrapper("maximum"),
  922. "min": ops_wrapper("minimum"),
  923. "prod": ops_wrapper("mul"),
  924. "sum": ops_wrapper("add"),
  925. "dot": ops_wrapper("add"),
  926. "xor_sum": ops_wrapper("bitwise_xor"),
  927. }
  928. def get_reduction_combine_fn(
  929. reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True
  930. ) -> Callable[..., object]:
  931. if reduction_type in REDUCTION_COMBINE_FN:
  932. return REDUCTION_COMBINE_FN[reduction_type]
  933. elif reduction_type in ("argmax", "argmin"):
  934. def argmax_combine_fn(
  935. a: tuple[object, object], b: tuple[object, object]
  936. ) -> tuple[OpsValue, OpsValue]:
  937. a_value, a_index = a
  938. b_value, b_index = b
  939. if reduction_type == "argmin":
  940. mask = ops.lt(a_value, b_value)
  941. else:
  942. mask = ops.gt(a_value, b_value)
  943. equal = ops.eq(a_value, b_value)
  944. if is_float_dtype(dtype):
  945. a_isnan = ops.ne(a_value, a_value)
  946. b_isnan = ops.ne(b_value, b_value)
  947. mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan))
  948. equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan))
  949. tie = (
  950. ops.lt(a_index, b_index)
  951. if arg_break_ties_left
  952. else ops.gt(a_index, b_index)
  953. )
  954. mask = ops.logical_or(mask, ops.logical_and(equal, tie))
  955. return (
  956. ops.where(mask, a_value, b_value),
  957. ops.where(mask, a_index, b_index),
  958. )
  959. return argmax_combine_fn
  960. elif reduction_type == "welford_combine":
  961. def welford_combine_fn(
  962. a: tuple[OpsValue, OpsValue, OpsValue],
  963. b: tuple[OpsValue, OpsValue, OpsValue],
  964. ) -> tuple[OpsValue, OpsValue, OpsValue]:
  965. a_mean, a_m2, a_weight = a
  966. b_mean, b_m2, b_weight = b
  967. delta = b_mean - a_mean
  968. new_weight = a_weight + b_weight
  969. w2_over_w = b_weight / new_weight
  970. return (
  971. a_mean + delta * w2_over_w,
  972. a_m2 + b_m2 + delta * delta * a_weight * w2_over_w,
  973. new_weight,
  974. )
  975. return welford_combine_fn
  976. else:
  977. raise NotImplementedError(f"unknown reduction_type={reduction_type}")
@ir_dataclass
class Reduction(Loops):
    # Loops node whose inner_fn takes (index, rindex): for every point of
    # ``ranges`` the values over all points of ``reduction_ranges`` are
    # combined according to ``reduction_type`` (see inner_fn_args and
    # store_reduction).
    reduction_ranges: Sequence[_IntLike]
    reduction_type: ReductionType
    # self.dtype represents the dst dtype
    # src_dtype is the source dtype handed to ops.reduction next to it
    src_dtype: torch.dtype
    # Scheduling hint; create() replaces a DEFAULT hint with the one num_splits picks
    reduction_hint: ReductionHint
  985. def __str__(self) -> str:
  986. return self._to_str(("ranges", "reduction_ranges", "reduction_type"))
  987. __repr__ = __str__
  988. @cache_on_self_and_args("Reduction")
  989. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  990. return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
  991. *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
  992. )
  993. def get_reduction_size(self) -> Sequence[Expr]:
  994. return self.reduction_ranges
  995. def get_reduction_type(self) -> Optional[str]:
  996. return self.reduction_type
  997. def store_reduction(
  998. self,
  999. output_name: Optional[str],
  1000. indexer: Callable[[Sequence[Expr]], Never],
  1001. vars: Sequence[Expr],
  1002. reduction_vars: Sequence[Symbol],
  1003. ) -> None:
  1004. value = ops.reduction(
  1005. self.dtype,
  1006. self.src_dtype,
  1007. self.reduction_type,
  1008. self.inner_fn(vars, reduction_vars),
  1009. )
  1010. ops.store_reduction(output_name or "unnamed", indexer(vars), value)
  1011. def index_length(self) -> int:
  1012. return len(self.ranges) + len(self.reduction_ranges)
  1013. def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
  1014. index = self._index(self.ranges)
  1015. rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
  1016. return (index, rindex)
  1017. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  1018. index = self._index(self.ranges)
  1019. rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
  1020. return extract_free_symbols(
  1021. self.inner_fn, index, rindex, unbacked_only=unbacked_only
  1022. )
  1023. def constant_to_device(self, device: torch.device) -> IRNode:
  1024. """Move this to a given device. Requires that all reads are to constants."""
  1025. loader = self.make_loader()
  1026. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  1027. return Reduction(
  1028. device=device,
  1029. dtype=self.dtype,
  1030. inner_fn=loader,
  1031. ranges=self.ranges,
  1032. reduction_ranges=self.reduction_ranges,
  1033. reduction_type=self.reduction_type,
  1034. src_dtype=self.src_dtype,
  1035. reduction_hint=ReductionHint.DEFAULT,
  1036. )
    @staticmethod
    def num_splits(
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[_P, OpsValue],
        ranges: Sequence[_IntLike],
        reduction_ranges: Sequence[_IntLike],
        reduction_type: Union[ReductionType, Literal["scan"]],
        reduction_numel: Expr,
        input_node: Optional[IRNode] = None,
    ) -> tuple[ReductionHint, _IntLike]:
        """Pick a reduction hint and how many pieces to split the reduction into.

        Returns ``(hint, split)``:

        * ``split == 1``: do not split.
        * ``split > 1``: split into that many parts, as chosen by
          ``V.choices.reduction_split_factor`` (only when splitting is enabled).
        * ``split == -1``: only for a full reduction (``numel_hint == 1``)
          whose ``input_node`` ranges, found by
          ``extract_input_node_reduction_ranges``, cover the same number of
          elements. The caller should reuse those ranges instead of splitting.

        The hint is DEFAULT when no useful decision is possible: symbolic
        sizes, ``dot``, tiny reductions, a large number of outputs, or only
        broadcast reads. Otherwise it is INNER when more full-size reads have
        a non-zero stride of 1 along some reduction variable than do not,
        and OUTER if not.

        ``reduction_type == "scan"`` lets scans reuse this logic; the scan is
        analysed as a ``sum``.
        """
        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))

        # Scans always use the split factors. Other reductions use them only
        # when the backend lacks REDUCE_TO_SINGLE_ELEMENT, the reduction is
        # not argmax/argmin, and config.split_reductions is on.
        should_split = reduction_type == "scan" or (
            not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
            and reduction_type
            not in (
                "argmax",
                "argmin",
            )
            and config.split_reductions
        )
        if not (_is_static(reduction_numel_hint) and _is_static(numel_hint)):
            # We don't support unbacked symints
            return ReductionHint.DEFAULT, 1

        if reduction_type == "dot":
            # Don't split when doing native matmul
            return ReductionHint.DEFAULT, 1

        props = DeviceProperties.create(device)
        num_sm = props.multi_processor_count
        min_elements_per_thread = 32
        if should_split:
            inner_reduction_splits: Callable[[int, int], int] = functools.partial(
                V.choices.reduction_split_factor, device, inner_reduction=True
            )
            outer_reduction_splits: Callable[[int, int], int] = functools.partial(
                V.choices.reduction_split_factor, device, inner_reduction=False
            )
        else:
            # Splitting disabled: both split-factor functions always answer 1.
            def inner_reduction_splits(
                reduction_numel_hint: int,
                numel_hint: int,
            ) -> int:
                return 1

            outer_reduction_splits = inner_reduction_splits

        # easy cases
        if numel_hint == 1:
            # Full reduction down to a single output element.
            split = inner_reduction_splits(reduction_numel_hint, numel_hint)
            if split == 1:
                # No need to split.
                return ReductionHint.INNER, split
            if input_node is not None and isinstance(input_node, TensorBox):
                with patch.object(FlexibleLayout, "allow_indexing", True):
                    (
                        new_ranges,
                        new_reduction_ranges,
                    ) = extract_input_node_reduction_ranges(input_node)
                if new_ranges is not None and new_reduction_ranges is not None:
                    extracted_numel_hint = V.graph.sizevars.symbolic_hint(
                        sympy_product(new_ranges + new_reduction_ranges)
                    )
                    if reduction_numel_hint == extracted_numel_hint:
                        log.debug(
                            "Use previous IRNode's range and reduction_ranges instead of split. "
                            "current ranges: %s, current reduction ranges: %s, current split: %d, "
                            "new ranges: %s, new reduction ranges: %s",
                            ranges,
                            reduction_ranges,
                            split,
                            new_ranges,
                            new_reduction_ranges,
                        )
                        # If the input_node or its dependent nodes are also Reduction nodes,
                        # use reduction_sizes of this node or its dependent nodes directly.
                        return ReductionHint.INNER, -1
            return ReductionHint.INNER, split

        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            # Either very little work per output, or presumably enough
            # independent outputs to occupy the device without splitting.
            return ReductionHint.DEFAULT, 1

        # Throwaway node used only to inspect the read patterns; a scan is
        # analysed as a sum.
        r = Reduction(
            device=device,
            dtype=dst_dtype,
            inner_fn=inner_fn,
            ranges=ranges,
            reduction_ranges=reduction_ranges,
            reduction_type=reduction_type if reduction_type != "scan" else "sum",
            src_dtype=src_dtype,
            reduction_hint=ReductionHint.DEFAULT,
        )

        def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]:
            # Return the index expressions of reads that mention every
            # non-constant range variable, plus whether calling decide_layout()
            # on any of the read buffers changed its stride.
            device = r.get_device()
            assert device is not None
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=device,
                    dtype=r.get_dtype(),
                    size=r.get_size(),
                ),
                data=r,
            )
            read_writes = cb.get_read_writes()
            # try finding the full size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
            assert read_writes.range_vars is not None
            range_vars = [
                r
                for r in read_writes.range_vars
                if isinstance(r, Expr) and not isinstance(r, sympy.Number)
            ]
            indices = []
            changed = False
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all(r in md.index.free_symbols for r in range_vars):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = getattr(buf.layout, "stride", None)
                        buf.decide_layout()
                        if getattr(buf.layout, "stride", None) != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            # A producer's layout was just fixed; re-read the indices so the
            # stride analysis below sees the final layout.
            indices, _ = get_read_indices(r)

        if len(indices) == 0:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        (_, reduction_vars), ranges1 = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        # Vote per full-size read: "outer" if no reduction variable has
        # stride 1, "inner" otherwise.
        num_outer = 0
        num_inner = 0
        for i in indices:
            j = V.graph.sizevars.simplify_with_ranges(i, ranges1)
            strides = V.graph.sizevars.stride_hints(
                j, reduction_vars, list(ranges1.keys())
            )
            # A 0 stride does not make a reduction contiguous.
            # This can happen when the reduction ranges contains a 1.
            outer = all(s == 0 or s > 1 for s in strides)
            if outer:
                num_outer += 1
            else:
                num_inner += 1
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )
  1194. @staticmethod
  1195. def _unroll_reduction_fn(
  1196. inner_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], OpsValue],
  1197. reduction_ranges: Sequence[_IntLike],
  1198. reduction_type: str,
  1199. src_dtype: torch.dtype,
  1200. ) -> Callable[[Sequence[_IntLike]], OpsValue]:
  1201. """Convert inner_fn from a reduction to an pointwise"""
  1202. reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges)
  1203. combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)
  1204. def fn(index: Sequence[_IntLike]) -> Any:
  1205. return functools.reduce(
  1206. combine_fn,
  1207. (
  1208. value_fn(index, rindex)
  1209. for rindex in itertools.product(
  1210. *[range(x) for x in reduction_ranges]
  1211. )
  1212. ),
  1213. )
  1214. value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any]
  1215. if reduction_type in ("argmin", "argmax"):
  1216. flatten_index = _fixed_indexer(
  1217. reduction_ranges,
  1218. FlexibleLayout.contiguous_strides(reduction_ranges),
  1219. )
  1220. def value_fn(
  1221. index: Sequence[_IntLike], rindex: Sequence[_IntLike]
  1222. ) -> tuple[OpsValue, OpsValue]:
  1223. rindex = [sympy.expand(i) for i in rindex]
  1224. return (
  1225. inner_fn(index, rindex),
  1226. ops.index_expr(flatten_index(rindex), torch.int64),
  1227. )
  1228. return lambda index: fn(index)[1]
  1229. else:
  1230. value_fn = inner_fn
  1231. return fn
  1232. @classmethod
  1233. # pyrefly: ignore [bad-override]
  1234. def create(
  1235. cls,
  1236. device: torch.device,
  1237. dst_dtype: torch.dtype,
  1238. src_dtype: torch.dtype,
  1239. inner_fn: Callable[..., Any],
  1240. ranges: Sequence[Expr],
  1241. reduction_ranges: Sequence[Expr],
  1242. reduction_type: ReductionType,
  1243. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  1244. input_node: Optional[IRNode] = None,
  1245. ) -> TensorBox:
  1246. """
  1247. Create a reduction node. May split the reduction to multiple layers to expose
  1248. more parallelism.
  1249. """
  1250. reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
  1251. if reduction_numel == 0:
  1252. # N.B. This is a hack to generate the literal of the given type
  1253. # Ideally, we should be fixing `def constant` in triton.py
  1254. # but it breaks due to hardcoded dtypes in other places
  1255. def py_cnst(val: object) -> Union[bool, float, int]:
  1256. if dst_dtype == torch.bool:
  1257. return bool(val)
  1258. elif dst_dtype.is_floating_point:
  1259. assert isinstance(val, SupportsFloat), type(val)
  1260. return float(val)
  1261. else:
  1262. assert isinstance(val, SupportsInt), type(val)
  1263. return int(val)
  1264. rtypes_to_inits = {
  1265. "sum": py_cnst(0),
  1266. "xor_sum": py_cnst(0),
  1267. "prod": py_cnst(1),
  1268. "any": py_cnst(0),
  1269. # "all" is desugared to `!any(!val)`
  1270. }
  1271. assert reduction_type in rtypes_to_inits, (
  1272. f"{reduction_type} not supported for zero-dimension tensors!"
  1273. )
  1274. def const_fn(index: int) -> OpsValue:
  1275. return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
  1276. return Pointwise.create(
  1277. device=device,
  1278. dtype=src_dtype,
  1279. inner_fn=const_fn,
  1280. ranges=list(ranges),
  1281. )
  1282. if reduction_numel == 1:
  1283. # this reduction is actually a pointwise op
  1284. if reduction_type in ("argmin", "argmax"):
  1285. def fn(index: int) -> OpsValue:
  1286. return ops.constant(0, dst_dtype)
  1287. else:
  1288. def fn(index: int) -> OpsValue:
  1289. reduction_index = [sympy.S.Zero for _ in reduction_ranges]
  1290. return inner_fn(index, reduction_index)
  1291. return Pointwise.create(
  1292. device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges
  1293. )
  1294. if (
  1295. isinstance(reduction_numel, Integer)
  1296. and V.graph.sizevars.size_hint_or_throw(reduction_numel)
  1297. < config.unroll_reductions_threshold
  1298. and (sympy_product(ranges) != 1 or is_gpu(device.type))
  1299. and reduction_type != "dot"
  1300. ):
  1301. # When native matmul, don't unroll the dot reduction.
  1302. # NB: This works around https://github.com/pytorch/pytorch/issues/140457
  1303. # since turning reductions into pointwise ops can exacerbate this problem
  1304. return Pointwise.create(
  1305. device=device,
  1306. dtype=dst_dtype,
  1307. inner_fn=cls._unroll_reduction_fn(
  1308. inner_fn, reduction_ranges, reduction_type, src_dtype
  1309. ),
  1310. ranges=ranges,
  1311. )
  1312. # triton doesn't support reduce to single element well, so break it up
  1313. hint, split = cls.num_splits(
  1314. device,
  1315. dst_dtype,
  1316. src_dtype,
  1317. inner_fn,
  1318. ranges,
  1319. reduction_ranges,
  1320. reduction_type,
  1321. reduction_numel,
  1322. input_node,
  1323. )
  1324. def _maybe_increase_split(split: int) -> int:
  1325. # don't apply min_num_split constraint for static shape case.
  1326. if _is_static(reduction_numel):
  1327. return split
  1328. if split > 1:
  1329. return max(split, config.min_num_split)
  1330. else:
  1331. return split
  1332. split = _maybe_increase_split(split)
  1333. # intermediate reduction in split can contain complex indexing,
  1334. # and num_splits will fail to correctly set the hint
  1335. # reuse the passed hint if available
  1336. if reduction_hint == ReductionHint.DEFAULT:
  1337. reduction_hint = hint
  1338. if split == -1:
  1339. assert input_node is not None
  1340. with patch.object(FlexibleLayout, "allow_indexing", True):
  1341. new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
  1342. input_node
  1343. )
  1344. assert new_ranges is not None
  1345. assert new_reduction_ranges is not None
  1346. return cls.create_multilayer_existing_ranges(
  1347. device,
  1348. dst_dtype,
  1349. src_dtype,
  1350. inner_fn,
  1351. ranges,
  1352. reduction_ranges,
  1353. new_ranges,
  1354. new_reduction_ranges,
  1355. reduction_type,
  1356. reduction_hint,
  1357. )
  1358. elif split > 1:
  1359. # triton doesn't support reduce to single element well, so break it up
  1360. out = cls.create_multilayer(
  1361. device,
  1362. dst_dtype,
  1363. src_dtype,
  1364. inner_fn,
  1365. ranges,
  1366. reduction_ranges,
  1367. reduction_type,
  1368. split,
  1369. reduction_hint,
  1370. input_node,
  1371. )
  1372. # Find the reduction that get split
  1373. split_reduction = None
  1374. if config.triton.mix_order_reduction and isinstance(out, TensorBox):
  1375. def _find_split_reduction(
  1376. cur_node: TensorBox,
  1377. ) -> Optional[ComputedBuffer]:
  1378. read_names = cur_node.get_read_names()
  1379. if len(read_names) != 1:
  1380. return None
  1381. bufname = next(iter(read_names))
  1382. if bufname not in V.graph.name_to_buffer:
  1383. return None
  1384. buf = V.graph.name_to_buffer[bufname]
  1385. if not isinstance(buf, ComputedBuffer):
  1386. return None
  1387. assert buf.data.get_reduction_type() is not None
  1388. return buf
  1389. split_reduction = _find_split_reduction(out)
  1390. if split_reduction:
  1391. # If a reduction is split to more than 2 layers,
  1392. # say there are 3 layers,
  1393. # we always have the correct setting for layer1 (top layer).
  1394. # The setting on layer2 may be incorrect but it's fine
  1395. # since they are never get used.
  1396. # TODO: should we skip setting these fields for layer2
  1397. assert isinstance(split_reduction.data, Reduction), (
  1398. f"{type(split_reduction.data)}"
  1399. )
  1400. split_reduction._split_size = split_reduction.data.reduction_ranges[0]
  1401. split_reduction._original_inner_fn = inner_fn
  1402. split_reduction._original_ranges = ranges
  1403. split_reduction._original_reduction_ranges = reduction_ranges
  1404. return out
  1405. out = TensorBox.create(
  1406. Reduction(
  1407. device=device,
  1408. dtype=dst_dtype,
  1409. inner_fn=inner_fn,
  1410. ranges=ranges,
  1411. reduction_ranges=reduction_ranges,
  1412. reduction_type=reduction_type,
  1413. src_dtype=src_dtype,
  1414. reduction_hint=reduction_hint,
  1415. )
  1416. )
  1417. return out
  1418. @staticmethod
  1419. def default_accumulator(
  1420. reduction_type: str, dtype: torch.dtype
  1421. ) -> Union[_NumLike, Sequence[_NumLike]]:
  1422. if reduction_type in ("max", "argmax"):
  1423. if is_float_dtype(dtype):
  1424. return float("-inf")
  1425. elif is_boolean_dtype(dtype):
  1426. return False
  1427. else:
  1428. return torch.iinfo(dtype).min
  1429. if reduction_type in ("min", "argmin"):
  1430. if is_float_dtype(dtype):
  1431. return float("inf")
  1432. elif is_boolean_dtype(dtype):
  1433. return True
  1434. else:
  1435. return torch.iinfo(dtype).max
  1436. zero = False if is_boolean_dtype(dtype) else 0
  1437. one = True if is_boolean_dtype(dtype) else 1
  1438. return {
  1439. "sum": zero,
  1440. "prod": one,
  1441. "dot": zero,
  1442. "xor_sum": zero,
  1443. "any": zero,
  1444. "welford_reduce": (zero, zero, zero),
  1445. "welford_combine": (zero, zero, zero),
  1446. "online_softmax_reduce": (float("-inf"), zero),
  1447. }[reduction_type]
  1448. @staticmethod
  1449. def default_value(
  1450. reduction_type: str, dtype: torch.dtype
  1451. ) -> Union[_NumLike, Sequence[_NumLike]]:
  1452. if reduction_type == "welford_reduce":
  1453. return 0
  1454. return Reduction.default_accumulator(reduction_type, dtype)
  1455. @staticmethod
  1456. def _multilayer_second_step_hint(
  1457. split: _IntLike, numel_hint: int, reduction_hint: ReductionHint
  1458. ) -> ReductionHint:
  1459. if split == -1:
  1460. return reduction_hint
  1461. if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
  1462. return ReductionHint.OUTER_TINY
  1463. if (
  1464. split <= 1024
  1465. and numel_hint <= 256
  1466. and reduction_hint == ReductionHint.OUTER
  1467. ):
  1468. return ReductionHint.OUTER_TINY
  1469. return reduction_hint
  1470. @classmethod
  1471. def check_for_split_dense_dim_reindexing(
  1472. cls, reduction_numel: _IntLike, input_node: Optional[IRNode]
  1473. ) -> Optional[int]:
  1474. """
  1475. If we are reducing over the full tensor, and it is non-dense in the last dimension,
  1476. reindex so we reduce over the dense dimension. initially just handle complete
  1477. reduction case
  1478. """
  1479. if input_node is None:
  1480. return None
  1481. if not V.graph.sizevars.statically_known_equals(
  1482. input_node.get_numel(), reduction_numel
  1483. ):
  1484. return None
  1485. input_node.realize()
  1486. try:
  1487. # finalize layout
  1488. as_storage_and_layout(input_node)
  1489. except NotImplementedError:
  1490. return None
  1491. strides = input_node.get_stride()
  1492. for i, s in enumerate(strides[:-1]):
  1493. if V.graph.sizevars.statically_known_equals(s, 1):
  1494. return i
  1495. return None
  1496. @classmethod
  1497. def _multilayer_wrap_loader(
  1498. cls,
  1499. loader: Callable[..., OpsValue],
  1500. reduction_ranges: Sequence[_IntLike],
  1501. reduction_numel: _IntLike,
  1502. split: _IntLike,
  1503. block_size: _IntLike,
  1504. default: Union[_NumLike, Sequence[_NumLike]],
  1505. input_node: Optional[IRNode] = None,
  1506. ) -> Callable[..., object]:
  1507. dense_index = cls.check_for_split_dense_dim_reindexing(
  1508. reduction_numel, input_node
  1509. )
  1510. reindex = View.dynamic_reshape_indexer(
  1511. reduction_ranges, [reduction_numel], dense_index
  1512. )
  1513. need_mask = not V.graph.sizevars.statically_known_true(
  1514. sympy.Eq(reduction_numel % split, 0)
  1515. )
  1516. def wrapper_fn(
  1517. index: Sequence[Symbol], reduction_index: Sequence[Symbol]
  1518. ) -> OpsValue:
  1519. (reduction_index,) = reduction_index
  1520. *new_index, reduction_block = index
  1521. indices = block_size * reduction_block + reduction_index
  1522. def body() -> OpsValue:
  1523. return loader(new_index, reindex([indices]))
  1524. if need_mask:
  1525. index_dtype = dtype_from_size(reduction_numel)
  1526. mask = ops.lt(
  1527. ops.index_expr(indices, index_dtype),
  1528. ops.index_expr(reduction_numel, index_dtype),
  1529. )
  1530. return ops.masked(mask, body, default)
  1531. else:
  1532. return body()
  1533. return wrapper_fn
  1534. @classmethod
  1535. def _multilayer_wrap_loader_existing_ranges(
  1536. cls,
  1537. loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
  1538. original_ranges: Sequence[Expr],
  1539. original_reduction_ranges: Sequence[Expr],
  1540. new_ranges: Sequence[Integer],
  1541. new_reduction_ranges: Sequence[Integer],
  1542. ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
  1543. assert all(r == 1 for r in original_ranges), (
  1544. f"Only enabled for numel_hint == 1, found {original_ranges=}"
  1545. )
  1546. reindex = View.dynamic_reshape_indexer(
  1547. original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
  1548. )
  1549. def wrapper_fn(
  1550. merged_index: Sequence[Expr],
  1551. new_reduction_index: Sequence[Expr],
  1552. ) -> OpsValue:
  1553. original_idx = merged_index[: len(original_ranges)]
  1554. new_index = merged_index[len(original_ranges) :]
  1555. return loader(
  1556. original_idx,
  1557. reindex(tuple(new_index) + tuple(new_reduction_index)),
  1558. )
  1559. return wrapper_fn
  1560. @classmethod
  1561. def create_multilayer_helper(
  1562. cls,
  1563. device: torch.device,
  1564. dst_dtype: torch.dtype,
  1565. src_dtype: torch.dtype,
  1566. wrapper_fn: Callable[..., Any],
  1567. original_ranges: Sequence[Expr],
  1568. original_reduction_ranges: Sequence[Expr],
  1569. new_ranges: list[Expr],
  1570. new_reduction_ranges: list[Integer],
  1571. reduction_type: ReductionType,
  1572. split: _IntLike,
  1573. reduction_hint: ReductionHint,
  1574. ) -> TensorBox:
  1575. """
  1576. Break a large reduction up into multiple smaller reductions
  1577. recursively
  1578. """
  1579. # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
  1580. # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
  1581. # in fp32 and not reduce precision by breaking up the kernel into multiple layers
  1582. intermediate_dtype = (
  1583. dst_dtype
  1584. if dst_dtype not in (torch.float16, torch.bfloat16)
  1585. else torch.float
  1586. )
  1587. intermediate = Reduction.create(
  1588. device,
  1589. intermediate_dtype,
  1590. src_dtype,
  1591. wrapper_fn,
  1592. new_ranges,
  1593. new_reduction_ranges,
  1594. reduction_type,
  1595. reduction_hint,
  1596. )
  1597. intermediate.realize()
  1598. intermediate_loader = intermediate.make_loader()
  1599. def intermediate_fn(
  1600. index: Sequence[_IntLike], reduction_index: Sequence[_IntLike]
  1601. ) -> OpsValue:
  1602. return intermediate_loader([*index, *reduction_index])
  1603. numel_hint = V.graph.sizevars.optimization_hint(sympy_product(original_ranges))
  1604. reduction_hint = cls._multilayer_second_step_hint(
  1605. split, numel_hint, reduction_hint
  1606. )
  1607. assert original_ranges == new_ranges[: len(original_ranges)]
  1608. return TensorBox.create(
  1609. Reduction(
  1610. device=device,
  1611. dtype=dst_dtype,
  1612. inner_fn=intermediate_fn,
  1613. ranges=original_ranges,
  1614. reduction_ranges=new_ranges[len(original_ranges) :],
  1615. reduction_type=reduction_type,
  1616. src_dtype=src_dtype,
  1617. reduction_hint=reduction_hint,
  1618. )
  1619. )
  1620. @classmethod
  1621. def create_multilayer(
  1622. cls,
  1623. device: torch.device,
  1624. dst_dtype: torch.dtype,
  1625. src_dtype: torch.dtype,
  1626. inner_fn: Callable[..., Any],
  1627. ranges: Sequence[Expr],
  1628. reduction_ranges: Sequence[Expr],
  1629. reduction_type: ReductionType,
  1630. split: _IntLike,
  1631. reduction_hint: ReductionHint,
  1632. input_node: Optional[IRNode] = None,
  1633. ) -> TensorBox:
  1634. """
  1635. Break a large reduction up into multiple smaller reductions
  1636. recursively
  1637. """
  1638. # TODO(jansel): realize the reduction so we can do dynamic indexing
  1639. reduction_numel = sympy_product(reduction_ranges)
  1640. block_size = FloorDiv(reduction_numel + (split - 1), split)
  1641. default = cls.default_value(reduction_type, dst_dtype)
  1642. wrapper_fn = cls._multilayer_wrap_loader(
  1643. inner_fn,
  1644. reduction_ranges,
  1645. reduction_numel,
  1646. split,
  1647. block_size,
  1648. default,
  1649. input_node,
  1650. )
  1651. return cls.create_multilayer_helper(
  1652. device,
  1653. dst_dtype,
  1654. src_dtype,
  1655. wrapper_fn,
  1656. ranges,
  1657. reduction_ranges,
  1658. [*ranges, split],
  1659. [block_size],
  1660. reduction_type,
  1661. split,
  1662. reduction_hint,
  1663. )
  1664. @classmethod
  1665. def create_multilayer_existing_ranges(
  1666. cls,
  1667. device: torch.device,
  1668. dst_dtype: torch.dtype,
  1669. src_dtype: torch.dtype,
  1670. inner_fn: Callable[..., Any],
  1671. original_ranges: Sequence[Expr],
  1672. original_reduction_ranges: Sequence[Expr],
  1673. new_ranges: list[Integer],
  1674. new_reduction_ranges: list[Integer],
  1675. reduction_type: ReductionType,
  1676. reduction_hint: ReductionHint,
  1677. ) -> TensorBox:
  1678. """
  1679. Break a large reduction up into multiple smaller reductions
  1680. recursively
  1681. """
  1682. wrapper_fn = cls._multilayer_wrap_loader_existing_ranges(
  1683. inner_fn,
  1684. original_ranges,
  1685. original_reduction_ranges,
  1686. new_ranges,
  1687. new_reduction_ranges,
  1688. )
  1689. return cls.create_multilayer_helper(
  1690. device,
  1691. dst_dtype,
  1692. src_dtype,
  1693. wrapper_fn,
  1694. original_ranges,
  1695. original_reduction_ranges,
  1696. [*original_ranges, *new_ranges],
  1697. new_reduction_ranges,
  1698. reduction_type,
  1699. -1,
  1700. reduction_hint,
  1701. )
  1702. def _fixed_indexer(
  1703. size: Sequence[int],
  1704. stride: Optional[Sequence[int]] = None,
  1705. offset: Expr = Integer(0),
  1706. ) -> Callable[[Sequence[Expr]], Expr]:
  1707. """A closure containing math to read a given element"""
  1708. def indexer(index: Sequence[int]) -> int:
  1709. assert stride is not None and len(index) == len(stride)
  1710. assert len(index) == len(size)
  1711. result = offset
  1712. for idx, st, sz in zip(index, stride, size):
  1713. if sz != 1:
  1714. result = result + idx * st
  1715. return result
  1716. return indexer
# Signature of a reduction body: called with (pointwise index, reduction
# index) and returning the value to be reduced.
INNER_FN_TY: TypeAlias = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
  1718. class MultiOutputReduction(Reduction):
  1719. output_index: int
  1720. def __init__(
  1721. self,
  1722. device: torch.device,
  1723. dst_dtype: torch.dtype,
  1724. inner_fns: Union[INNER_FN_TY, Sequence[INNER_FN_TY]],
  1725. ranges: Sequence[Integer],
  1726. reduction_ranges: Sequence[Integer],
  1727. reduction_type: ReductionType,
  1728. src_dtype: torch.dtype,
  1729. reduction_hint: ReductionHint,
  1730. output_index: int,
  1731. ):
  1732. if callable(inner_fns):
  1733. inner_fns = (inner_fns,)
  1734. loader: Callable[[Sequence[Expr], Sequence[Expr]], Any]
  1735. if len(inner_fns) == 1:
  1736. loader = inner_fns[0]
  1737. else:
  1738. def loader(
  1739. idx: Sequence[Expr], reduction_idx: Sequence[Expr]
  1740. ) -> tuple[OpsValue, ...]:
  1741. return tuple(fn(idx, reduction_idx) for fn in inner_fns)
  1742. super().__init__(
  1743. device=device,
  1744. dtype=dst_dtype,
  1745. inner_fn=loader,
  1746. ranges=ranges,
  1747. reduction_ranges=reduction_ranges,
  1748. reduction_type=reduction_type,
  1749. src_dtype=src_dtype,
  1750. reduction_hint=reduction_hint,
  1751. )
  1752. self.output_index = output_index
  1753. def store_reduction(
  1754. self,
  1755. output_name: Optional[str],
  1756. indexer: Callable[[Sequence[Expr]], Never],
  1757. vars: Sequence[Expr],
  1758. reduction_vars: Sequence[Symbol],
  1759. ) -> Any:
  1760. values = ops.reduction(
  1761. self.dtype,
  1762. self.src_dtype,
  1763. self.reduction_type,
  1764. self.inner_fn(vars, reduction_vars),
  1765. )
  1766. assert isinstance(values, (tuple, list)), type(values)
  1767. value = values[self.output_index]
  1768. return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
  1769. class OnlineSoftmaxReduction(MultiOutputReduction):
  1770. @classmethod
  1771. def create( # type: ignore[override]
  1772. cls,
  1773. device: torch.device,
  1774. dst_dtype: torch.dtype,
  1775. src_dtype: torch.dtype,
  1776. inner_fn: Callable[..., Any],
  1777. ranges: Sequence[Expr],
  1778. reduction_ranges: Sequence[Expr],
  1779. num_output: int,
  1780. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  1781. input_node: Optional[IRNode] = None,
  1782. ) -> Sequence[TensorBox]:
  1783. """
  1784. Create the reduction disregarding splitting.
  1785. """
  1786. results = tuple(
  1787. TensorBox.create(
  1788. MultiOutputReduction(
  1789. device,
  1790. dst_dtype,
  1791. inner_fn,
  1792. ranges,
  1793. reduction_ranges,
  1794. "online_softmax_reduce",
  1795. src_dtype,
  1796. reduction_hint,
  1797. output_idx,
  1798. )
  1799. )
  1800. for output_idx in range(num_output)
  1801. )
  1802. for t in results:
  1803. t.realize()
  1804. return results
  1805. class WelfordReduction(MultiOutputReduction):
  1806. @classmethod
  1807. def create( # type: ignore[override]
  1808. cls,
  1809. device: torch.device,
  1810. dtype: torch.dtype,
  1811. inner_fns: Sequence[Callable[..., Any]],
  1812. ranges: list[Integer],
  1813. reduction_ranges: list[Integer],
  1814. reduction_type: ReductionType,
  1815. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  1816. ) -> Sequence[TensorBox]:
  1817. assert reduction_type in ("welford_reduce", "welford_combine")
  1818. reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
  1819. def const(val: int) -> TensorBox:
  1820. def inner_fn(idx: Sequence[Expr]) -> OpsValue:
  1821. return ops.constant(
  1822. val,
  1823. dtype,
  1824. )
  1825. return Pointwise.create(
  1826. device=device,
  1827. dtype=dtype,
  1828. inner_fn=inner_fn,
  1829. ranges=list(ranges),
  1830. )
  1831. if reduction_numel == 0:
  1832. mean = const(0)
  1833. m2 = const(0)
  1834. weight = const(0)
  1835. return mean, m2, weight
  1836. if reduction_numel == 1:
  1837. def copy(
  1838. loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
  1839. ) -> TensorBox:
  1840. def inner_fn(idx: Sequence[Expr]) -> OpsValue:
  1841. reduction_index = [sympy.S.Zero for _ in reduction_ranges]
  1842. return loader(idx, reduction_index)
  1843. return Pointwise.create(
  1844. device=device,
  1845. dtype=dtype,
  1846. inner_fn=inner_fn,
  1847. ranges=list(ranges),
  1848. )
  1849. if reduction_type == "welford_reduce":
  1850. return copy(inner_fns[0]), const(0), const(1)
  1851. else:
  1852. return tuple(copy(fn) for fn in inner_fns)
  1853. # TODO: Unrolled reduction
  1854. # if (
  1855. # isinstance(reduction_numel, Integer)
  1856. # and V.graph.sizevars.size_hint(reduction_numel)
  1857. # < config.unroll_reductions_threshold
  1858. # and sympy_product(ranges) != 1
  1859. # ):
  1860. # return Pointwise.create(
  1861. # device,
  1862. # dst_dtype,
  1863. # cls._unroll_reduction_fn(
  1864. # inner_fn, reduction_ranges, reduction_type, src_dtype,
  1865. # ),
  1866. # ranges,
  1867. # )
  1868. # triton doesn't support reduce to single element well, so break it up
  1869. hint, split = Reduction.num_splits(
  1870. device,
  1871. dtype,
  1872. dtype,
  1873. inner_fns[0],
  1874. ranges,
  1875. reduction_ranges,
  1876. reduction_type=reduction_type,
  1877. reduction_numel=reduction_numel,
  1878. )
  1879. # intermediate reduction in split can contain complex indexing,
  1880. # and num_splits will fail to correctly set the hint
  1881. # reuse the passed hint if available
  1882. if reduction_hint == ReductionHint.DEFAULT:
  1883. reduction_hint = hint
  1884. if split > 1:
  1885. # triton doesn't support reduce to single element well, so break it up
  1886. return cls.create_multilayer(
  1887. device,
  1888. dtype,
  1889. inner_fns,
  1890. ranges,
  1891. reduction_ranges,
  1892. reduction_type,
  1893. split,
  1894. reduction_hint,
  1895. )
  1896. results = [
  1897. TensorBox.create(
  1898. WelfordReduction(
  1899. device,
  1900. dtype,
  1901. inner_fns,
  1902. ranges,
  1903. reduction_ranges,
  1904. reduction_type,
  1905. dtype,
  1906. reduction_hint,
  1907. output_idx,
  1908. )
  1909. )
  1910. for output_idx in range(3)
  1911. ]
  1912. for t in results:
  1913. t.realize()
  1914. return results
  1915. @staticmethod
  1916. def default_value(
  1917. reduction_type: str, dtype: torch.dtype
  1918. ) -> Union[_NumLike, Sequence[_NumLike]]:
  1919. return (0, 0, 0)
  1920. @classmethod
  1921. def create_multilayer( # type: ignore[override]
  1922. cls,
  1923. device: torch.device,
  1924. dtype: torch.dtype,
  1925. inner_fns: Sequence[Callable[..., Any]],
  1926. ranges: list[Integer],
  1927. reduction_ranges: list[Integer],
  1928. reduction_type: ReductionType,
  1929. split: _IntLike,
  1930. reduction_hint: ReductionHint,
  1931. ) -> Sequence[TensorBox]:
  1932. """
  1933. Break a large reduction up into multiple smaller reductions
  1934. recursively
  1935. """
  1936. reduction_numel = sympy_product(reduction_ranges)
  1937. need_mask = not V.graph.sizevars.statically_known_true(
  1938. sympy.Eq(reduction_numel % split, 0)
  1939. )
  1940. if need_mask and reduction_type != "welford_combine":
  1941. # If we need mask, then "welford_reduce" doesn't work because
  1942. # masked inputs shouldn't count towards the welford weight
  1943. def constant(
  1944. idx: Sequence[Expr], reduction_idx: Sequence[Expr], value: int
  1945. ) -> OpsValue:
  1946. return ops.constant(value, dtype)
  1947. return cls.create_multilayer(
  1948. device=device,
  1949. dtype=dtype,
  1950. inner_fns=(
  1951. inner_fns[0],
  1952. partial(constant, value=0),
  1953. partial(constant, value=1),
  1954. ),
  1955. ranges=ranges,
  1956. reduction_ranges=reduction_ranges,
  1957. reduction_type="welford_combine",
  1958. split=split,
  1959. reduction_hint=reduction_hint,
  1960. )
  1961. block_size = FloorDiv(reduction_numel + (split - 1), split)
  1962. intermediates = WelfordReduction.create(
  1963. device,
  1964. dtype,
  1965. tuple(
  1966. cls._multilayer_wrap_loader(
  1967. loader,
  1968. reduction_ranges,
  1969. reduction_numel,
  1970. split,
  1971. block_size,
  1972. default=0,
  1973. )
  1974. for loader in inner_fns
  1975. ),
  1976. [*ranges, split],
  1977. [block_size],
  1978. reduction_type,
  1979. reduction_hint,
  1980. )
  1981. for i in intermediates:
  1982. i.realize()
  1983. def intermediate_loader_fn(
  1984. index: Sequence[Expr],
  1985. reduction_index: Sequence[Expr],
  1986. loader: Callable[[Sequence[Expr]], OpsValue],
  1987. ) -> OpsValue:
  1988. return loader([*index, *reduction_index])
  1989. numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
  1990. reduction_hint = cls._multilayer_second_step_hint(
  1991. split, numel_hint, reduction_hint
  1992. )
  1993. return WelfordReduction.create(
  1994. device,
  1995. dtype,
  1996. tuple(
  1997. partial(intermediate_loader_fn, loader=i.make_loader())
  1998. for i in intermediates
  1999. ),
  2000. ranges,
  2001. [split],
  2002. # welford_reduce turns one input into three outputs, which are combined with welford_combine
  2003. "welford_combine",
  2004. reduction_hint,
  2005. )
  2006. @ir_dataclass
  2007. class Scan(Loops):
  2008. scan_ranges: list[Integer]
  2009. size: list[Integer]
  2010. combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
  2011. reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
  2012. reduction_hint: ReductionHint
  2013. output_index: int
  2014. # output_index indexes the following tuples
  2015. dtypes: tuple[torch.dtype, ...]
  2016. inner_fns: tuple[Callable[..., Any], ...]
  2017. # HACK we mimic reduction
  2018. @cache_on_self_and_args("Scan")
  2019. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2020. # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
  2021. # need to explicitly represent the closure so we can pull out unbacked
  2022. # symbols here
  2023. return (
  2024. super().get_free_symbol_uses(unbacked_only)
  2025. | OrderedSet().union(
  2026. *(get_free_symbols(e, unbacked_only) for e in self.scan_ranges)
  2027. )
  2028. | OrderedSet().union(
  2029. *(get_free_symbols(e, unbacked_only) for e in self.size)
  2030. )
  2031. )
  2032. def __post_init__(self) -> None:
  2033. assert len(self.ranges) + len(self.scan_ranges) == len(self.size)
  2034. super().__post_init__()
  2035. def store_reduction(
  2036. self,
  2037. output_name: Optional[str],
  2038. indexer: Callable[[Sequence[_IntLike]], Never],
  2039. vars: Sequence[Expr],
  2040. scan_vars: Sequence[Symbol],
  2041. ) -> Any:
  2042. idx = self.reindex(vars, scan_vars)
  2043. values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
  2044. result = ops.scan(self.dtypes, self.combine_fn, values)
  2045. return ops.store(
  2046. output_name or "unnamed", indexer(idx), result[self.output_index]
  2047. )
  2048. def get_reduction_type(self) -> Optional[str]:
  2049. # return self.scan_op
  2050. return "custom"
  2051. def get_reduction_size(self) -> Sequence[Expr]:
  2052. return self.scan_ranges
  2053. def get_size(self) -> Sequence[Expr]:
  2054. return self.size
  2055. def get_pointwise_size(self) -> Sequence[Expr]:
  2056. return self.ranges
  2057. def index_length(self) -> int:
  2058. return len(self.ranges) + len(self.scan_ranges)
  2059. def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
  2060. index = self._index(self.ranges)
  2061. rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
  2062. idx = self.reindex(index, rindex)
  2063. return (idx,)
  2064. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2065. index = self._index(self.ranges)
  2066. rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
  2067. idx = self.reindex(index, rindex)
  2068. return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only)
  2069. @classmethod
  2070. def create( # type: ignore[override]
  2071. cls,
  2072. device: torch.device,
  2073. dtypes: tuple[torch.dtype, ...],
  2074. inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
  2075. size: list[Integer],
  2076. axis: int,
  2077. combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
  2078. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  2079. *,
  2080. # Whether we have the option to fallback to aten
  2081. can_fallback_to_aten: bool = True,
  2082. **kwargs: Any,
  2083. ) -> Sequence[Optional[TensorBox]]:
  2084. pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
  2085. scan_ranges = [size[axis]]
  2086. if not V.graph.has_feature(device, BackendFeature.SCAN):
  2087. return [None] * len(dtypes)
  2088. if len(dtypes) > 1 and not V.graph.has_feature(
  2089. device, BackendFeature.TUPLE_REDUCTION
  2090. ):
  2091. return [None] * len(dtypes)
  2092. sizevars = V.graph.sizevars
  2093. scan_numel = sizevars.simplify(sympy_product(scan_ranges))
  2094. assert len(dtypes) == len(inner_fns)
  2095. # Scan with a single element is just a copy
  2096. if sizevars.statically_known_true(sympy.Le(scan_numel, 1)):
  2097. return [
  2098. Pointwise.create(
  2099. device=device,
  2100. dtype=dtypes[output_index],
  2101. inner_fn=inner_fns[output_index],
  2102. ranges=size,
  2103. )
  2104. for output_index in range(len(dtypes))
  2105. ]
  2106. reduction_hint, num_splits = cls.num_splits(
  2107. device=device,
  2108. dtype=dtypes[0],
  2109. inner_fn=inner_fns[0],
  2110. axis=axis,
  2111. pointwise_ranges=pointwise_ranges,
  2112. scan_ranges=scan_ranges,
  2113. combine_fn=combine_fn,
  2114. scan_numel=scan_numel,
  2115. )
  2116. scan_type = Scan
  2117. if num_splits > 1:
  2118. supports_split = (
  2119. # pyrefly: ignore [unsupported-operation]
  2120. torch.version.hip is None or (has_triton and triton_version >= "3.3.0")
  2121. ) and (len(dtypes) == 1)
  2122. if not supports_split:
  2123. if can_fallback_to_aten:
  2124. # Fallback to ATen
  2125. return [None] * len(dtypes)
  2126. else:
  2127. num_splits = 1
  2128. else:
  2129. scan_type = SplitScan
  2130. def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
  2131. assert len(scan_index) == len(scan_ranges)
  2132. assert len(index) == len(pointwise_ranges)
  2133. return [*index[:axis], *scan_index, *index[axis:]]
  2134. results = [
  2135. TensorBox.create(
  2136. scan_type(
  2137. device=device,
  2138. dtype=dtypes[output_index],
  2139. dtypes=dtypes,
  2140. inner_fn=inner_fns[output_index],
  2141. inner_fns=inner_fns,
  2142. size=size,
  2143. ranges=pointwise_ranges,
  2144. scan_ranges=scan_ranges,
  2145. combine_fn=combine_fn,
  2146. reindex=reindex,
  2147. reduction_hint=reduction_hint,
  2148. output_index=output_index,
  2149. **kwargs,
  2150. )
  2151. )
  2152. for output_index in range(len(dtypes))
  2153. ]
  2154. for result in results:
  2155. result.realize()
  2156. return results
  2157. @classmethod
  2158. def num_splits(
  2159. cls,
  2160. device: torch.device,
  2161. dtype: torch.dtype,
  2162. inner_fn: Callable[[Sequence[Expr]], OpsValue],
  2163. axis: int,
  2164. pointwise_ranges: list[Integer],
  2165. scan_ranges: list[Integer],
  2166. combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
  2167. scan_numel: Expr,
  2168. ) -> tuple[ReductionHint, _IntLike]:
  2169. # TODO: custom splitting heuristic for scan
  2170. def wrapper_fn(idx: Sequence[Expr], reduction_idx: Sequence[Expr]) -> OpsValue:
  2171. return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])
  2172. return Reduction.num_splits(
  2173. device=device,
  2174. dst_dtype=dtype,
  2175. src_dtype=dtype,
  2176. inner_fn=wrapper_fn,
  2177. ranges=pointwise_ranges,
  2178. reduction_ranges=scan_ranges,
  2179. reduction_type="scan",
  2180. reduction_numel=scan_numel,
  2181. )
# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA.
# Scan.create picks this type instead of Scan when num_splits > 1 and
# splitting is supported.
@ir_dataclass
class SplitScan(Scan):
    pass
  2186. @ir_dataclass
  2187. class Sort(Loops):
  2188. # Sorts a tuple of key, value pairs
  2189. sort_ranges: list[Integer]
  2190. size: list[Integer]
  2191. reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
  2192. reduction_hint: ReductionHint
  2193. output_index: int
  2194. # output_index indexes the following tuples
  2195. dtypes: tuple[torch.dtype, ...]
  2196. inner_fns: tuple[Callable[..., Any], ...]
  2197. stable: bool
  2198. descending: bool
  2199. # HACK we mimic reduction
  2200. @cache_on_self_and_args("Sort")
  2201. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2202. return (
  2203. super().get_free_symbol_uses(unbacked_only)
  2204. | OrderedSet().union(
  2205. *(get_free_symbols(e, unbacked_only) for e in self.sort_ranges)
  2206. )
  2207. | OrderedSet().union(
  2208. *(get_free_symbols(e, unbacked_only) for e in self.size)
  2209. )
  2210. )
  2211. def __post_init__(self) -> None:
  2212. assert len(self.ranges) + len(self.sort_ranges) == len(self.size)
  2213. super().__post_init__()
  2214. def store_reduction(
  2215. self,
  2216. output_name: Optional[str],
  2217. indexer: Callable[[Sequence[Expr]], Expr],
  2218. vars: Sequence[Expr],
  2219. reduction_vars: Sequence[Expr],
  2220. ) -> Any:
  2221. idx = self.reindex(vars, reduction_vars)
  2222. values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
  2223. result = ops.sort(self.dtypes, values, self.stable, self.descending)
  2224. return ops.store(
  2225. output_name or "unnamed", indexer(idx), result[self.output_index]
  2226. )
  2227. def get_reduction_type(self) -> Optional[str]:
  2228. return "sort"
  2229. def get_reduction_size(self) -> Sequence[Expr]:
  2230. return self.sort_ranges
  2231. def get_size(self) -> Sequence[Expr]:
  2232. return self.size
  2233. def get_pointwise_size(self) -> Sequence[Expr]:
  2234. return self.ranges
  2235. def index_length(self) -> int:
  2236. return len(self.ranges) + len(self.sort_ranges)
  2237. def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
  2238. index = self._index(self.ranges)
  2239. rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
  2240. idx = self.reindex(index, rindex)
  2241. return (idx,)
  2242. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2243. index = self._index(self.ranges)
  2244. rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
  2245. idx = self.reindex(index, rindex)
  2246. return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only)
  2247. @classmethod
  2248. def create( # type: ignore[override]
  2249. cls,
  2250. device: torch.device,
  2251. dtypes: tuple[torch.dtype, ...],
  2252. inner_fns: tuple[Callable[[list[Expr]], Any], ...],
  2253. size: list[Integer],
  2254. axis: int,
  2255. stable: bool,
  2256. descending: bool,
  2257. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  2258. **kwargs: Any,
  2259. ) -> Sequence[Optional[TensorBox]]:
  2260. pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
  2261. sort_ranges = [size[axis]]
  2262. if not V.graph.has_feature(device, BackendFeature.SORT):
  2263. return [None] * len(dtypes)
  2264. sizevars = V.graph.sizevars
  2265. sort_numel = sizevars.simplify(sympy_product(sort_ranges))
  2266. # Heuristic, smallest rblock where triton usually outperforms aten.sort
  2267. # It also isn't bandwidth bound so fusion is unlikely to help.
  2268. max_rblock = 512
  2269. is_persistent_kernel = (
  2270. config.triton.persistent_reductions
  2271. and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock))
  2272. )
  2273. if not is_persistent_kernel:
  2274. # We only support persistent triton kernels
  2275. return [None] * len(dtypes)
  2276. assert len(dtypes) == len(inner_fns)
  2277. # Sort with a single element is just a copy
  2278. if sizevars.statically_known_true(sympy.Le(sort_numel, 1)):
  2279. return [
  2280. Pointwise.create(
  2281. device=device,
  2282. dtype=dtypes[output_index],
  2283. inner_fn=inner_fns[output_index],
  2284. ranges=size,
  2285. )
  2286. for output_index in range(len(dtypes))
  2287. ]
  2288. def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
  2289. assert len(sort_index) == len(sort_ranges)
  2290. assert len(index) == len(pointwise_ranges)
  2291. return [*index[:axis], *sort_index, *index[axis:]]
  2292. results = [
  2293. TensorBox.create(
  2294. Sort(
  2295. device=device,
  2296. dtype=dtypes[output_index],
  2297. dtypes=dtypes,
  2298. inner_fn=inner_fns[output_index],
  2299. inner_fns=inner_fns,
  2300. size=size,
  2301. ranges=pointwise_ranges,
  2302. sort_ranges=sort_ranges,
  2303. reindex=reindex,
  2304. reduction_hint=reduction_hint,
  2305. output_index=output_index,
  2306. stable=stable,
  2307. descending=descending,
  2308. **kwargs,
  2309. )
  2310. )
  2311. for output_index in range(len(dtypes))
  2312. ]
  2313. for result in results:
  2314. result.realize()
  2315. return results
  2316. def is_storage_and_layout(x: IRNode) -> bool:
  2317. try:
  2318. as_storage_and_layout(x, freeze=False)
  2319. return True
  2320. except NotImplementedError:
  2321. return False
  2322. def is_contiguous_storage_and_layout(x: IRNode) -> bool:
  2323. try:
  2324. _buffer, layout = as_storage_and_layout(x, freeze=False)
  2325. # pad the stride here so we will NOT claim an tensor as contiguous
  2326. # if a padding is gonna happen.
  2327. if layout.should_pad_strides():
  2328. layout.pad_strides()
  2329. return layout.is_contiguous()
  2330. except NotImplementedError:
  2331. return False
  2332. def as_storage_and_layout(
  2333. x: IRNode,
  2334. freeze: bool = True,
  2335. want_contiguous: bool = False,
  2336. stride_order: Optional[Sequence[Union[int, Integer]]] = None,
  2337. allow_padding: bool = False,
  2338. exact_strides: Optional[Sequence[Union[int, Integer]]] = None,
  2339. ) -> tuple[StorageBox, Layout]:
  2340. """
  2341. Try to simplify x into a StorageBox and a Layout.
  2342. allow_padding only affect how we apply stride_order. When allow_padding
  2343. is True, we have the freedom to add padding when applying the stride_order.
  2344. """
  2345. if isinstance(x, TensorBox):
  2346. return as_storage_and_layout(
  2347. x.data,
  2348. freeze=freeze,
  2349. want_contiguous=want_contiguous,
  2350. stride_order=stride_order,
  2351. allow_padding=allow_padding,
  2352. exact_strides=exact_strides,
  2353. )
  2354. if isinstance(x, StorageBox):
  2355. _, layout = as_storage_and_layout(
  2356. x.data,
  2357. freeze=freeze,
  2358. want_contiguous=want_contiguous,
  2359. stride_order=stride_order,
  2360. allow_padding=allow_padding,
  2361. exact_strides=exact_strides,
  2362. )
  2363. return x, x.data.get_layout()
  2364. if isinstance(x, Buffer):
  2365. if freeze:
  2366. if want_contiguous:
  2367. x.freeze_layout()
  2368. assert x.get_layout().is_contiguous()
  2369. elif stride_order is not None:
  2370. x.freeze_layout_with_stride_order(
  2371. stride_order, allow_padding=allow_padding
  2372. )
  2373. elif exact_strides is not None:
  2374. x.freeze_layout_with_exact_strides(
  2375. exact_strides, allow_padding=allow_padding
  2376. )
  2377. else:
  2378. x.decide_layout()
  2379. return StorageBox(x), x.get_layout()
  2380. if isinstance(x, ReinterpretView):
  2381. # making the base of x contiguous or stride_ordered will not necessarily make
  2382. # the ReinterpretView either, so don't pass along those arguments
  2383. buffer, _ = as_storage_and_layout(
  2384. x.data,
  2385. freeze=freeze,
  2386. )
  2387. return buffer, x.layout
  2388. raise NotImplementedError
  2389. def is_stride_order_storage_and_layout(
  2390. x: IRNode, stride_order: Sequence[Union[int, Integer]]
  2391. ) -> bool:
  2392. try:
  2393. _buffer, layout = as_storage_and_layout(x, freeze=False)
  2394. return layout.is_stride_ordered(stride_order)
  2395. except NotImplementedError:
  2396. return False
  2397. def is_unaligned(node: IRNode) -> bool:
  2398. if isinstance(node, (TensorBox, StorageBox)):
  2399. return is_unaligned(node.data)
  2400. if isinstance(node, ReinterpretView):
  2401. layout = node.layout
  2402. has_unaligned_layout = not V.graph.sizevars.statically_known_multiple_of(
  2403. layout.offset * get_dtype_size(layout.dtype), GPU_ALIGN_BYTES
  2404. )
  2405. return is_unaligned(node.data) or has_unaligned_layout
  2406. if isinstance(node, Buffer):
  2407. return node.get_name() in V.graph.unaligned_buffers
  2408. # assume to be aligned otherwise
  2409. return False
  2410. @ir_dataclass
  2411. class BaseView(IRNode):
  2412. data: IRNode
  2413. @cache_on_self_and_args("BaseView")
  2414. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2415. return self.data.get_free_symbol_uses(unbacked_only)
  2416. def make_reindexer(self) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2417. raise NotImplementedError(f"make_reindexer NYI on {self}")
  2418. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  2419. inner = self.data.make_indexer()
  2420. reindex = self.make_reindexer()
  2421. def indexer(idx: Sequence[Expr]) -> Expr:
  2422. return inner(reindex(idx))
  2423. return indexer
  2424. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2425. inner = self.data.make_loader()
  2426. reindex = self.make_reindexer()
  2427. def loader(idx: Sequence[Expr]) -> OpsValue:
  2428. return inner(reindex(idx))
  2429. return loader
  2430. @property
  2431. def dtype(self) -> torch.dtype:
  2432. return self.data.get_dtype()
  2433. def get_layout(self) -> Layout:
  2434. return self.data.get_layout()
  2435. def get_device(self) -> Optional[torch.device]:
  2436. return self.data.get_device()
  2437. def get_origin_node(self) -> Optional[torch.fx.Node]:
  2438. return None
  2439. def get_name(self) -> str:
  2440. return self.data.get_name()
  2441. def get_pointwise_size(self) -> Sequence[Expr]:
  2442. return self.get_size()
  2443. def mark_reuse(self, users: int) -> None:
  2444. return self.data.mark_reuse(users)
  2445. def has_exceeded_max_reads(self) -> bool:
  2446. return self.data.has_exceeded_max_reads()
  2447. def realize(self) -> Optional[str]:
  2448. return self.data.realize()
  2449. def realize_hint(self) -> None:
  2450. self.data.realize_hint()
  2451. def get_storage_numel(self) -> _IntLike:
  2452. return self.data.get_storage_numel()
  2453. def is_extern(self) -> bool:
  2454. return self.data.is_extern()
  2455. def is_module_buffer(self) -> bool:
  2456. assert isinstance(self.data, BaseView), type(self.data)
  2457. return self.data.is_module_buffer()
  2458. def get_read_names(self) -> OrderedSet[str]:
  2459. return self.data.get_read_names()
  2460. def get_reads(self) -> OrderedSet[Dep]:
  2461. with patch.object(FlexibleLayout, "allow_indexing", True):
  2462. return extract_read_writes(
  2463. self.make_loader(),
  2464. self.get_size(),
  2465. ).reads
  2466. def unwrap_view(self) -> IRNode:
  2467. x: IRNode = self
  2468. while isinstance(x, BaseView):
  2469. x = x.data
  2470. return x
  2471. def constant_to_device(self, device: torch.device) -> IRNode:
  2472. """Move this to a given device. Requires that all reads are to constants."""
  2473. loader = self.make_loader()
  2474. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  2475. return Pointwise(
  2476. device=device,
  2477. dtype=self.get_dtype(),
  2478. inner_fn=loader,
  2479. ranges=self.get_size(),
  2480. )
  2481. @ir_dataclass
  2482. class ExpandView(BaseView):
  2483. size: Sequence[Expr]
  2484. @staticmethod
  2485. def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLike]:
  2486. """Replace `-1` with correct sizes"""
  2487. sizevars = V.graph.sizevars
  2488. new_size = [sympy.expand(s) for s in new_size]
  2489. old_size = x.get_size()
  2490. old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
  2491. assert len(new_size) == len(old_size)
  2492. for i in range(len(new_size)):
  2493. if new_size[i] == -1:
  2494. assert old_size[i] is not None
  2495. new_size[i] = old_size[i]
  2496. elif old_size[i] is None or V.graph.sizevars.is_size_one_or_false(
  2497. old_size[i]
  2498. ):
  2499. pass
  2500. else:
  2501. # Sanity check: Expect broadcast compatibility
  2502. #
  2503. # NB: new_size[i] == old_size[i] is expected to already be
  2504. # guarded because the meta formula was expected to have taught
  2505. # us this equality.
  2506. v1 = new_size[i]
  2507. v2 = old_size[i]
  2508. assert v1 is not None
  2509. assert v2 is not None
  2510. diff = v1 - v2
  2511. assert (
  2512. sizevars.optimization_hint(
  2513. diff,
  2514. fallback=0,
  2515. )
  2516. == 0
  2517. ), (
  2518. f"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
  2519. )
  2520. return new_size
  2521. @classmethod
  2522. def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView:
  2523. new_size = cls._normalize_size(x, new_size)
  2524. if is_storage_and_layout(x):
  2525. storage, old_layout = as_storage_and_layout(x)
  2526. skip = len(new_size) - len(old_layout.size)
  2527. assert skip >= 0
  2528. new_stride = [sympy.S.Zero] * skip
  2529. for stride, size in zip(old_layout.stride, old_layout.size):
  2530. new_stride.append(
  2531. stride
  2532. if not V.graph.sizevars.is_size_one_or_false(size)
  2533. else sympy.S.Zero
  2534. )
  2535. new_layout = FixedLayout(
  2536. old_layout.device,
  2537. old_layout.dtype,
  2538. list(new_size),
  2539. new_stride,
  2540. old_layout.offset,
  2541. old_layout.is_pinned,
  2542. )
  2543. return ReinterpretView(data=storage, layout=new_layout)
  2544. return ExpandView(data=x, size=new_size)
  2545. def get_size(self) -> Sequence[Expr]:
  2546. return self.size
  2547. def make_reindexer(
  2548. self,
  2549. ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2550. target = self.get_size()
  2551. actual = self.data.get_size()
  2552. skip = len(target) - len(actual)
  2553. def reindex(
  2554. index: Sequence[Expr],
  2555. ) -> Sequence[Expr]:
  2556. index = list(index[skip:])
  2557. assert len(index) == len(actual)
  2558. for i in range(len(actual)):
  2559. if actual[i] == 1:
  2560. # zero out broadcast dimension
  2561. index[i] = sympy.S.Zero
  2562. return index
  2563. return reindex
  2564. @ir_dataclass
  2565. class PermuteView(BaseView):
  2566. dims: list[Expr]
  2567. @classmethod
  2568. def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView:
  2569. dims = cls._map_neg_dims(dims)
  2570. assert OrderedSet(dims) == OrderedSet(range(len(dims)))
  2571. if is_storage_and_layout(x):
  2572. storage, old_layout = as_storage_and_layout(x)
  2573. new_layout = FixedLayout(
  2574. old_layout.device,
  2575. old_layout.dtype,
  2576. [old_layout.size[i] for i in dims],
  2577. [old_layout.stride[i] for i in dims],
  2578. old_layout.offset,
  2579. old_layout.is_pinned,
  2580. )
  2581. return ReinterpretView(data=storage, layout=new_layout)
  2582. return PermuteView(data=x, dims=dims)
  2583. @classmethod
  2584. def _map_neg_dims(cls, dims: Sequence[int]) -> list[int]:
  2585. return [dim if dim >= 0 else len(dims) + dim for dim in dims]
  2586. def get_size(self) -> Sequence[Expr]:
  2587. assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet(
  2588. range(len(self.dims))
  2589. )
  2590. size = self.data.get_size()
  2591. return [size[i] for i in self.dims]
  2592. def make_reindexer(
  2593. self,
  2594. ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2595. inv = {j: i for i, j in enumerate(self.dims)}
  2596. inv = [inv[i] for i in range(len(self.dims))]
  2597. assert OrderedSet(inv) == OrderedSet(range(len(self.dims)))
  2598. def reindex(
  2599. index: Sequence[Expr],
  2600. ) -> Sequence[Expr]:
  2601. return [index[i] for i in inv]
  2602. return reindex
  2603. @ir_dataclass
  2604. class SqueezeView(BaseView):
  2605. @classmethod
  2606. def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode:
  2607. if is_storage_and_layout(x):
  2608. storage, old_layout = as_storage_and_layout(x)
  2609. new_size = []
  2610. new_stride = []
  2611. if dim is not None:
  2612. assert isinstance(dim, int), type(dim)
  2613. assert 0 <= dim and dim < len(old_layout.size)
  2614. for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
  2615. if dim is None:
  2616. # Only append if dim is not squeezed out
  2617. if not V.graph.sizevars.is_size_one_or_false(size):
  2618. new_size.append(size)
  2619. new_stride.append(stride)
  2620. else:
  2621. if i != dim:
  2622. new_size.append(size)
  2623. new_stride.append(stride)
  2624. else:
  2625. assert size == 1, "expected squeezed size to be 1"
  2626. new_layout = FixedLayout(
  2627. old_layout.device,
  2628. old_layout.dtype,
  2629. new_size,
  2630. new_stride,
  2631. old_layout.offset,
  2632. old_layout.is_pinned,
  2633. )
  2634. return ReinterpretView(data=storage, layout=new_layout)
  2635. if dim is None:
  2636. return View.create(
  2637. x,
  2638. [
  2639. s
  2640. for s in x.get_size()
  2641. if not V.graph.sizevars.is_size_one_or_false(s)
  2642. ],
  2643. )
  2644. else:
  2645. assert x.get_size()[dim] == 1
  2646. return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
  2647. @staticmethod
  2648. def squeezer(
  2649. size: Sequence[Expr],
  2650. ) -> tuple[list[int], Callable[[Sequence[Expr]], tuple[Expr, ...]]]:
  2651. new_size = [s for s in size if s != 1]
  2652. not_one = [i for i, s in enumerate(size) if s != 1]
  2653. length = len(size)
  2654. def reindex(index: Sequence[Expr]) -> tuple[Expr, ...]:
  2655. assert len(index) == len(not_one), f"{index} {not_one}"
  2656. new_index: list[Expr] = [sympy.S.Zero] * length
  2657. for idx, s in zip(not_one, index):
  2658. new_index[idx] = s
  2659. return tuple(new_index)
  2660. return new_size, reindex
  2661. def __init__(self, data: Any) -> None:
  2662. raise AssertionError("use SqueezeView.create()")
  2663. @ir_dataclass
  2664. class GenericView(BaseView):
  2665. size: Sequence[Expr]
  2666. reindex: Callable[[Sequence[Expr]], Sequence[Expr]]
  2667. def make_reindexer(
  2668. self,
  2669. ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2670. return self.reindex
  2671. def reindex_str(self) -> str:
  2672. index_old = [
  2673. sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size))
  2674. ]
  2675. index_new = list(self.reindex(index_old))
  2676. return f"lambda {', '.join(map(str, index_old))}: {index_new}"
  2677. def __str__(self) -> str:
  2678. return self.str_helper(
  2679. [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
  2680. )
  2681. __repr__ = __str__
  2682. @classmethod
  2683. def create(
  2684. cls,
  2685. x: IRNode,
  2686. new_size: Sequence[Expr],
  2687. reindex: Callable[[Sequence[Expr]], Sequence[Expr]],
  2688. ) -> BaseView:
  2689. return cls(data=x, size=list(new_size), reindex=reindex)
  2690. def get_size(self) -> Sequence[Expr]:
  2691. return self.size
@ir_dataclass
class View(GenericView):
    """
    This class handles tensor reshaping by computing appropriate index transformations
    to map the new shape back to the original storage layout.

    ``View.create`` prefers to return a ReinterpretView (a pure stride/offset
    change over existing storage). It falls back to an index-remapping ``View``
    (built via ``dynamic_reshape_indexer``) when no such strides exist.
    """

    @staticmethod
    def handle_negative_index(idx: Expr, size: Expr) -> Expr:
        """Wrap a negative index by adding ``size``.

        The sign test goes through the shape env's ``evaluate_expr``, so for
        symbolic indices it may record a guard.
        """
        idx = sympy.expand(idx)
        size = sympy.expand(size)
        evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
        if evaluate_expr(sympy.Lt(idx, 0)):
            idx = idx + size
        return idx

    @classmethod
    @override
    def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode:  # type: ignore[override]
        """Reshape ``x`` to ``new_size`` (which may contain one ``-1``).

        Strategy, in order:
          1. same size as ``x`` -> return ``x`` unchanged;
          2. result has a zero-size dim -> View with a constant all-zero reindex;
          3. ``x`` has contiguous storage -> ReinterpretView with contiguous strides;
          4. ``x`` has no storage/layout -> fallback (see
             ``handle_unbacked_or_dynamic_reshape``);
          5. strides computed by ``_compute_stride`` exist -> ReinterpretView;
          6. otherwise -> fallback.
        """
        assert isinstance(new_size, Sequence), type(new_size)
        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)

        # Skip pointless views
        if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
            return x

        unbacked_symbols_in_sizes = (
            len(free_unbacked_symbols(old_size)) > 0
            or len(free_unbacked_symbols(new_size)) > 0
        )
        is_contiguous = is_contiguous_storage_and_layout(x)

        def create_reinterpret_view(
            inp: IRNode, new_size: Sequence[Expr], new_stride: Sequence[Expr]
        ) -> ReinterpretView:
            # want_contiguous=True freezes `inp`'s buffer to a contiguous layout
            # before the new size/stride are laid over its storage.
            storage, old_layout = as_storage_and_layout(inp, want_contiguous=True)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset,
                old_layout.is_pinned,
            )
            return ReinterpretView(data=storage, layout=new_layout)

        def handle_unbacked_or_dynamic_reshape(
            x: IRNode,
        ) -> IRNode:
            """
            Handle the case where view is not possible with current strides.
            For unbacked symbols, make contiguous; otherwise use dynamic_reshape_indexer.
            """
            # NOTE(review): these names are only read here, never rebound, so
            # this `nonlocal` appears unnecessary.
            nonlocal old_size, new_size, unbacked_symbols_in_sizes
            if unbacked_symbols_in_sizes:
                # For unbacked symbols, we must require contiguous
                # dynamic_reshape_indexer cannot handle unbacked SymInts
                # https://github.com/pytorch/pytorch/issues/145561
                x = ExternKernel.require_contiguous(x)
                return create_reinterpret_view(
                    x, new_size, FlexibleLayout.contiguous_strides(new_size)
                )

            # For backed symbols, fall back to dynamic_reshape_indexer
            reindex = cls.dynamic_reshape_indexer(old_size, new_size)
            return cls(data=x, size=list(new_size), reindex=reindex)

        if 0 in new_size:
            # The result has zero elements, so the index map is never needed
            # for real data; map every index to zeros of the old rank.
            def fake_reindex(index: Any) -> tuple[int, ...]:
                return tuple([0] * len(old_size))

            return cls(data=x, size=list(new_size), reindex=fake_reindex)

        # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
        elif is_contiguous:
            # Input is contiguous, output can use contiguous strides
            return create_reinterpret_view(
                x, new_size, FlexibleLayout.contiguous_strides(new_size)
            )

        # Input is non-contiguous. Check if we can get storage/layout.
        if not is_storage_and_layout(x):
            # Can't get storage/layout (e.g., for Pointwise nodes)
            return handle_unbacked_or_dynamic_reshape(x)

        # Try to compute valid output strides.
        # NOTE(review): freeze=False means x's layout is not frozen here, yet its
        # current strides are baked into a FixedLayout below - confirm that the
        # layout cannot change afterwards for the inputs that reach this path.
        storage, old_layout = as_storage_and_layout(x, freeze=False)
        old_stride = old_layout.stride

        # Convert sympy exprs to SymInt for _compute_stride, then convert back
        old_size_symint = V.graph.sizevars.to_symints_or_ints(old_size)
        old_stride_symint = V.graph.sizevars.to_symints_or_ints(old_stride)
        new_size_symint = V.graph.sizevars.to_symints_or_ints(new_size)

        from torch._subclasses.fake_impls import _compute_stride

        # Use size_oblivious=True for unbacked symbols to avoid DDE errors
        new_stride_symint = _compute_stride(
            old_size_symint,
            old_stride_symint,
            new_size_symint,
            size_oblivious=unbacked_symbols_in_sizes,
        )

        # A None result means no stride assignment expresses the reshape as a view.
        if new_stride_symint is not None:
            # Convert SymInt back to sympy expressions
            new_stride = [
                s.node.expr if hasattr(s, "node") else sympy.Integer(s)
                for s in new_stride_symint
            ]
            # View is possible with computed strides
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset,
                old_layout.is_pinned,
            )
            return ReinterpretView(data=storage, layout=new_layout)

        # View not possible with current strides
        return handle_unbacked_or_dynamic_reshape(x)

    @staticmethod
    def resolve_negative_size(
        old_size: Sequence[Expr], new_size: Sequence[Expr]
    ) -> tuple[list[Expr], list[Expr]]:
        """Simplify both sizes and resolve the first ``-1`` in ``new_size``.

        The ``-1`` entry is replaced by ``prod(old_size) / prod(other new dims)``.
        The element counts are then checked to be equal.
        """
        new_size = [V.graph.sizevars.simplify(x) for x in new_size]
        old_size = [V.graph.sizevars.simplify(x) for x in old_size]

        new_size = list(new_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                # Temporarily set to 1 so the product below covers only the
                # other dims.
                new_size[i] = sympy.S.One
                new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
                break

        V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size))
        return old_size, new_size

    @classmethod
    def dynamic_reshape_indexer(
        cls,
        old_size: Sequence[_IntLike],
        new_size: Sequence[_IntLike],
        dense_dim: Optional[int] = None,
    ) -> Callable[[Sequence[_T]], Sequence[_V]]:
        """Build a reindex function from ``new_size`` indices to ``old_size`` indices.

        The direct algorithm is tried first. If it fails (AssertionError or
        IndexError), the reshape goes through a 1-D flattened intermediate
        instead: old -> [numel] -> new, fused into one function.
        """
        try:
            reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim)
        except (AssertionError, IndexError):
            # optimistic algorithm failed, lets do a fallback
            flat = [sympy_product(old_size)]
            reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
            reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
            reindex = fuse_reindexing(reindex1, reindex2)
        return reindex

    @staticmethod
    def _dynamic_reshape_indexer(
        old_size: Sequence[Expr],
        new_size: Sequence[Expr],
        dense_dim: Optional[int] = None,
    ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
        """
        Perform a reshape entirely by modifying indexing math

        Dims are matched from the innermost (last) outward, comparing size hints:
          * equal sizes map one new var directly to one old dim;
          * a larger old dim is formed by merging consecutive new vars;
          * a larger new dim is split across consecutive old dims with
            ModularIndexing.
        Size-one dims on either side are absorbed (old side gets index 0).
        May raise AssertionError/IndexError when the sizes cannot be matched
        this way; ``dynamic_reshape_indexer`` catches those.
        """
        size_hint = V.graph.sizevars.size_hint
        # TODO: These symbols may not escape, if they don't assert so and
        # treat them as temporary
        vars = [
            sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size))
        ]

        # Both lists are consumed from the end with pop().
        stack_new = list(zip(vars, new_size))
        stack_old = list(old_size)

        # process the dense dim first
        reordering_dense_dim = (
            dense_dim is not None
            and dense_dim != len(stack_old) - 1
            and len(new_size) == 1
        )
        if reordering_dense_dim:
            assert dense_dim is not None  # mypy
            old_dim = stack_old.pop(dense_dim)
            stack_old.append(old_dim)

        # Built innermost-first; reversed at the end into old-dim order.
        view_expr = []
        while stack_new and stack_old:
            size_old = stack_old.pop()
            var, size_new = stack_new.pop()
            if size_old == 1:
                view_expr.append(sympy.S.Zero)
                stack_new.append((var, size_new))  # re-add
            elif size_new == 1:
                stack_old.append(size_old)  # re-add
            elif size_hint(size_new) == size_hint(size_old):
                view_expr.append(var)
                V.graph.sizevars.check_equals(size_new, size_old)
            elif size_hint(size_new) < size_hint(size_old):
                # Merge more new vars (row-major) until they span this old dim.
                while size_hint(size_new) < size_hint(size_old):
                    var2, size_new2 = stack_new.pop()
                    var = var2 * size_new + var
                    size_new = size_new * size_new2
                view_expr.append(var)
                V.graph.sizevars.check_equals(size_new, size_old)
            elif size_hint(size_new) > size_hint(size_old):
                # Split this new var across several old dims.
                divisor = sympy.S.One
                modulus = size_old
                view_expr.append(ModularIndexing(var, divisor, modulus))
                divisor = divisor * modulus
                while size_hint(size_new) > size_hint(size_old):
                    modulus = stack_old.pop()
                    view_expr.append(ModularIndexing(var, divisor, modulus))
                    divisor = divisor * modulus
                    size_old = size_old * modulus
                V.graph.sizevars.check_equals(size_new, size_old)
            else:
                raise AssertionError

        # Leftover dims on either side must be size one.
        while stack_old:
            size_old = stack_old.pop()
            V.graph.sizevars.check_equals(size_old, 1)
            view_expr.append(sympy.S.Zero)

        while stack_new:
            var, size_new = stack_new.pop()
            V.graph.sizevars.check_equals(size_new, 1)

        if dense_dim is not None and len(new_size) == 1:
            view_expr.reverse()
            # Move the last expression (dense dim) to its original position
            dense_expr = view_expr.pop()
            view_expr.insert(dense_dim, dense_expr)
        else:
            view_expr.reverse()

        assert len(view_expr) == len(old_size)

        def reindex(
            index: Sequence[Expr],
        ) -> Sequence[Expr]:
            assert len(index) == len(vars), (len(index), len(vars))
            replacements = dict(zip(vars, index))
            return tuple(sympy_subs(x, replacements) for x in view_expr)

        return reindex
  2909. @ir_dataclass
  2910. class ReinterpretView(BaseView):
  2911. """Pretend our storage has a different layout"""
  2912. layout: Layout
  2913. def __post_init__(self) -> None:
  2914. super().__post_init__()
  2915. if isinstance(self.data, BaseView):
  2916. object.__setattr__(self, "data", self.data.unwrap_view())
  2917. def __str__(self) -> str:
  2918. return self.str_helper(
  2919. [
  2920. self.data,
  2921. self.layout,
  2922. ]
  2923. )
  2924. __repr__ = __str__
  2925. def get_name(self) -> str:
  2926. return self.data.get_name()
  2927. def get_device(self) -> Optional[torch.device]:
  2928. return self.layout.device
  2929. def get_origin_node(self) -> Optional[torch.fx.Node]:
  2930. return None
  2931. @property
  2932. def dtype(self) -> torch.dtype:
  2933. return self.layout.dtype
  2934. def get_size(self) -> Sequence[Expr]:
  2935. return list(self.layout.size)
  2936. def get_stride(self) -> Sequence[Expr]:
  2937. return list(self.layout.stride)
  2938. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2939. def loader(index: Sequence[Expr]) -> OpsValue:
  2940. indexer = self.layout.make_indexer()
  2941. tmp_loader = ops.load(self.get_name(), indexer(index))
  2942. if self.layout.dtype != self.data.dtype:
  2943. return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype)
  2944. else:
  2945. return tmp_loader
  2946. return loader
  2947. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  2948. return self.layout.make_indexer()
  2949. def get_layout(self) -> Layout:
  2950. return self.layout
  2951. def freeze_layout(self) -> None:
  2952. pass
  2953. @cache_on_self_and_args("ReinterpretView")
  2954. def get_free_symbol_uses(
  2955. self, unbacked_only: bool = False
  2956. ) -> OrderedSet[sympy.Symbol]:
  2957. return (
  2958. get_free_symbols(self.layout.size, unbacked_only)
  2959. | get_free_symbols(self.layout.stride, unbacked_only)
  2960. | get_free_symbols(self.layout.offset, unbacked_only)
  2961. )
  2962. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  2963. # reinterpret_tensor is similar to as_strided except:
  2964. # - offset is added to the existing offset (rather than replacing it)
  2965. # - view tracking is disabled similar to unsafe_view
  2966. return V.graph.wrapper_code.codegen_reinterpret_view(
  2967. self.data,
  2968. self.layout.size,
  2969. self.layout.stride,
  2970. self.layout.offset,
  2971. writer.writeline if writer is not None else V.graph.wrapper_code.writeline,
  2972. dtype=self.layout.dtype,
  2973. )
  2974. def num_reads(self) -> int:
  2975. return 1
  2976. @ir_dataclass
  2977. class DtypeView(BaseView):
  2978. """Pretend our storage has a different type"""
  2979. target_dtype: torch.dtype
  2980. @classmethod
  2981. def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView:
  2982. if is_storage_and_layout(x):
  2983. storage, old_layout = as_storage_and_layout(x)
  2984. new_layout = FixedLayout(
  2985. old_layout.device,
  2986. new_dtype,
  2987. old_layout.size,
  2988. old_layout.stride,
  2989. old_layout.offset,
  2990. old_layout.is_pinned,
  2991. )
  2992. return ReinterpretView(data=storage, layout=new_layout)
  2993. return DtypeView(data=x, target_dtype=new_dtype)
  2994. def __str__(self) -> str:
  2995. return self.str_helper([self.data, self.target_dtype])
  2996. __repr__ = __str__
  2997. @property
  2998. def dtype(self) -> torch.dtype:
  2999. return self.target_dtype
  3000. def get_size(self) -> Sequence[Expr]:
  3001. return self.data.get_size()
  3002. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3003. inner = self.data.make_loader()
  3004. def loader(idx: Sequence[Expr]) -> OpsValue:
  3005. return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype)
  3006. return loader
class SliceView(View):
    """View selecting ``start:end:step`` along one dimension of ``data``."""

    @classmethod
    def normalize_start_end(
        cls, x: IRNode, dim: int, start: int, end: int
    ) -> tuple[int, int]:
        """
        Normalize start and end such that both are in the range
        [0, x.get_size()[dim]] and start <= end.

        Negative values are wrapped via ``handle_negative_index``. If any of
        start/end/dim size has unbacked symbols, clamping builds symbolic
        ``sympy.Min``/``sympy.Max``. Otherwise it uses the sizevars
        ``evaluate_min``/``evaluate_max`` helpers.
        """
        sizevars = V.graph.sizevars
        dim_size = x.get_size()[dim]

        if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
            min_func = sympy.Min
            max_func = sympy.Max
        else:
            min_func = sizevars.evaluate_min
            max_func = sizevars.evaluate_max

        def clamp(x: Expr, lower: int, upper: int) -> Expr:
            # Only emit Max/Min when the bound is not already statically known.
            clamped_lower = (
                x if sizevars.statically_known_geq(x, lower) else max_func(x, lower)
            )
            clamped_full = (
                clamped_lower
                if sizevars.statically_known_leq(clamped_lower, upper)
                else min_func(clamped_lower, upper)
            )
            return clamped_full

        def clamp_wrap(
            val: Union[int, None], lower: int, upper: int, default: Union[Expr, int]
        ) -> Union[Expr, int]:
            if val is None:
                # TODO(rec): can this really happen?
                return default
            val = cls.handle_negative_index(val, dim_size)
            return clamp(val, lower, upper)

        start = clamp_wrap(start, 0, dim_size, 0)
        # end's lower bound is the already-clamped start, enforcing start <= end.
        end = clamp_wrap(end, start, dim_size, dim_size)
        return start, end

    @classmethod
    def create(  # type: ignore[override]
        cls,
        x: IRNode,
        dim: int,
        start: int,
        end: int,
        step: int = 1,
        clamp: bool = True,
    ) -> IRNode:
        """Slice ``x`` along ``dim``.

        Returns ``x`` itself for a full unit-step slice. Returns a
        ReinterpretView (stride scaled by ``step``, offset shifted by
        ``start * stride``) when ``x`` has storage. Otherwise returns a
        SliceView that remaps indices.
        """
        step = sympy.expand(step)
        # NOTE(review): after sympy.expand, `step` is presumably always an Expr,
        # which would make this assert vacuous - confirm.
        assert isinstance(step, Expr) or step > 0, step
        try:
            if start == 0 and end >= 2**63 - 1 and step == 1:
                return x
        except TypeError:
            # Presumably raised when the comparison cannot be decided (symbolic
            # start/end); fall through to the general path.
            pass

        new_size = list(x.get_size())

        # NB: Ordinarily we default to clamping.
        # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid
        # failing in this situation is ok, since invalid sizes could trigger silent errors.
        if clamp:
            start, end = cls.normalize_start_end(x, dim, start, end)

        # ceil((end - start) / step)
        new_size[dim] = FloorDiv(end - start + (step - 1), step)

        if is_storage_and_layout(x):
            # Fast path
            storage, old_layout = as_storage_and_layout(x)
            new_stride = list(old_layout.stride)
            new_stride[dim] = new_stride[dim] * step
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                new_stride,
                old_layout.offset + old_layout.stride[dim] * start,
                old_layout.is_pinned,
            )
            return ReinterpretView(data=storage, layout=new_layout)

        def reindex(
            index: Sequence[Expr],
        ) -> Sequence[Expr]:
            assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
            index = list(index)
            index[dim] = index[dim] * step + start
            return index

        # redirect to a generic view
        return SliceView(data=x, size=new_size, reindex=reindex)
  3092. @ir_dataclass
  3093. class BaseConstant(IRNode):
  3094. dtype: torch.dtype
  3095. device: torch.device
  3096. def get_size(self) -> Sequence[Expr]:
  3097. return ()
  3098. def get_device(self) -> Optional[torch.device]:
  3099. return self.device
  3100. def get_origin_node(self) -> Optional[torch.fx.Node]:
  3101. return None
  3102. def get_reads(self) -> OrderedSet[Dep]:
  3103. return OrderedSet()
  3104. @ir_dataclass
  3105. class Constant(BaseConstant):
  3106. value: Any
  3107. dtype: torch.dtype
  3108. device: torch.device
  3109. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3110. def loader(index: Sequence[Expr]) -> OpsValue:
  3111. return ops.constant(self.value, self.dtype)
  3112. return loader
  3113. def realize(self) -> Optional[str]:
  3114. pass
  3115. def constant_to_device(self, device: torch.device) -> IRNode:
  3116. return Constant(value=self.value, dtype=self.dtype, device=device)
  3117. @ir_dataclass
  3118. class IndexingConstant(BaseConstant):
  3119. index: Any
  3120. dtype: torch.dtype
  3121. device: torch.device
  3122. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3123. def loader(index: Sequence[Expr]) -> OpsValue:
  3124. return ops.index_expr(self.index, self.dtype)
  3125. return loader
  3126. def constant_to_device(self, device: torch.device) -> IRNode:
  3127. return IndexingConstant(index=self.index, dtype=self.dtype, device=device)
  3128. def is_contiguous_strides_for_shape(
  3129. stride: Sequence[_IntLike], shape: Sequence[_IntLike]
  3130. ) -> bool:
  3131. expected_stride = 1
  3132. expected_stride_max = 1
  3133. for x, y in reversed(tuple(zip(shape, stride))):
  3134. if x == 1:
  3135. continue
  3136. if not V.graph.sizevars.statically_known_equals(
  3137. y, expected_stride
  3138. ) and not V.graph.sizevars.statically_known_equals(y, expected_stride_max):
  3139. return False
  3140. expected_stride_max *= sympy.Max(1, x)
  3141. expected_stride *= x
  3142. return True
  3143. def get_align_for_dtype(dtype: torch.dtype) -> int:
  3144. return config.padding_alignment_bytes // dtype.itemsize
  3145. class OutputSpec:
  3146. """Abstract base for Layout, MultiOutputLayout, NoneLayout.
  3147. Represents the memory layout of the output of an Operation."""
  3148. def get_device(self) -> Optional[torch.device]:
  3149. raise NotImplementedError(type(self).__name__)
  3150. def storage_size(self) -> int:
  3151. raise NotImplementedError(type(self).__name__)
  3152. def get_free_symbol_uses(
  3153. self, unbacked_only: bool = False
  3154. ) -> OrderedSet[sympy.Symbol]:
  3155. raise NotImplementedError(type(self).__name__)
  3156. @ir_dataclass
  3157. class Layout(OutputSpec):
  3158. """
  3159. Layout base class
  3160. Carries tensor meta-information including offset and
  3161. whether it is pinned.
  3162. """
  3163. def __init__(
  3164. self,
  3165. device: torch.device,
  3166. dtype: torch.dtype,
  3167. size: Sequence[Expr],
  3168. stride: Optional[Sequence[Expr]] = None,
  3169. offset: Expr = Integer(0),
  3170. is_pinned: bool = False,
  3171. ) -> None:
  3172. if stride is None:
  3173. stride = FlexibleLayout.contiguous_strides(size)
  3174. # pyrefly: ignore [read-only]
  3175. self.device = device
  3176. self.dtype = dtype
  3177. assert len(size) == len(stride), f"size={size}, stride={stride}"
  3178. assert all(isinstance(s, (Expr, int)) for s in size)
  3179. self._size = size
  3180. self._stride = stride
  3181. self._offset = offset
  3182. self.is_pinned = is_pinned
  3183. # is_pinned implies cpu
  3184. assert (not self.is_pinned) or (self.device.type == "cpu"), (
  3185. "Only CPU tensors can be pinned"
  3186. )
  3187. @property
  3188. def size(self) -> Sequence[Expr]:
  3189. return self._size
  3190. @size.setter
  3191. def size(self, value: Sequence[Expr]) -> None:
  3192. self._size = value
  3193. @property
  3194. def stride(self) -> Sequence[Expr]:
  3195. return self._stride
  3196. @stride.setter
  3197. def stride(self, value: Sequence[Expr]) -> None:
  3198. self._stride = value
  3199. @property
  3200. def offset(self) -> Expr:
  3201. return self._offset
  3202. @offset.setter
  3203. def offset(self, value: Expr) -> None:
  3204. self._offset = value
  3205. def __str__(self) -> str:
  3206. offset = ""
  3207. if self.offset != 0:
  3208. offset = f", offset={self.offset}"
  3209. device_index_str = "" if self.device.index is None else f":{self.device.index}"
  3210. is_pinned_str = ""
  3211. if self.is_pinned:
  3212. is_pinned_str = f", is_pinned={self.is_pinned}"
  3213. return (
  3214. f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, "
  3215. f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})"
  3216. )
  3217. __repr__ = __str__
  3218. def get_device(self) -> torch.device:
  3219. return self.device
  3220. def get_example(self) -> torch.Tensor:
  3221. with V.fake_mode:
  3222. return torch.empty_strided(
  3223. convert_shape_to_symint(self.size),
  3224. convert_shape_to_symint(self.stride),
  3225. dtype=self.dtype,
  3226. device=self.device,
  3227. pin_memory=self.is_pinned,
  3228. )
  3229. def is_contiguous(self) -> bool:
  3230. return is_contiguous_strides_for_shape(self.stride, self.size)
  3231. @staticmethod
  3232. def is_channels_last_contiguous(
  3233. shape: Sequence[_IntLike], strides: Sequence[_IntLike]
  3234. ) -> bool:
  3235. ndim = len(shape)
  3236. if ndim not in [4, 5] or shape[1] == 1:
  3237. return False
  3238. for left, right, size in zip(
  3239. # pyrefly: ignore [bad-specialization]
  3240. strides,
  3241. # pyrefly: ignore [bad-specialization]
  3242. make_channels_last_strides_for(shape),
  3243. shape,
  3244. ):
  3245. if size != 1 and left != right:
  3246. return False
  3247. return True
  3248. def is_transposed(self) -> bool:
  3249. for left, right, size in zip(
  3250. self.stride,
  3251. reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))),
  3252. self.size,
  3253. ):
  3254. if size != 1 and left != right:
  3255. return False
  3256. return True
  3257. def is_stride_ordered(self, order: Sequence[int]) -> bool:
  3258. assert len(self.stride) == len(order)
  3259. # ignore dimensions of size 1, they dont affect layout
  3260. non_1_indices = [
  3261. i
  3262. for i, dim in enumerate(self.size)
  3263. if V.graph.sizevars.optimization_hint(dim, fallback=2) != 1
  3264. ]
  3265. stride = [self.stride[i] for i in non_1_indices]
  3266. order: Sequence[int] = [order[i] for i in non_1_indices]
  3267. def sorted_indices(arr: Sequence[int]) -> Sequence[int]:
  3268. sorted_arr = sorted(arr)
  3269. return [sorted_arr.index(element) for element in arr]
  3270. # since we may have removed dimensions, need to re-sort & re-index order
  3271. order = sorted_indices(order)
  3272. # reorder the stride given order
  3273. stride_ordered = [-1] * len(order)
  3274. for i in range(len(order)):
  3275. stride_ordered[order[i]] = stride[i]
  3276. # check if it is in ascending order
  3277. for i in range(len(order) - 1):
  3278. expr = stride_ordered[i] > stride_ordered[i + 1]
  3279. if not isinstance(expr, bool):
  3280. expr = V.graph._shape_env.evaluate_expr(
  3281. stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True
  3282. )
  3283. if expr:
  3284. return False
  3285. return True
  3286. def is_channels_last_stride_ordered(self) -> bool:
  3287. # create channels_last order(NCHW, NCDHW, the C is the first order).
  3288. order = [0] + list(reversed(range(1, len(self.stride) - 1)))
  3289. order = [len(order)] + order
  3290. return self.is_stride_ordered(order)
  3291. @staticmethod
  3292. def _pad_strides(
  3293. in_strides: Sequence[int], size: Sequence[Expr], dtype: torch.dtype
  3294. ) -> Sequence[int]:
  3295. """
  3296. The padding does not change stride order but makes sure all strides larger
  3297. than the threshold are multiple of align.
  3298. """
  3299. align = get_align_for_dtype(dtype)
  3300. if len(in_strides) == 0:
  3301. return in_strides
  3302. if not config.pad_channels_last and Layout.is_channels_last_contiguous(
  3303. size, in_strides
  3304. ):
  3305. return in_strides
  3306. current_fx_node = V.get_current_node()
  3307. if hasattr(current_fx_node, "meta") and current_fx_node.meta.get(
  3308. "dislike_padding", False
  3309. ):
  3310. return in_strides
  3311. # Skip padding the strides for dynamic shapes based on config.pad_dynamic_shape
  3312. # Checking both shape and strides, as there are cases where only one is dynamic
  3313. is_dynamic = not all(
  3314. isinstance(s, (int, sympy.Integer))
  3315. for s in itertools.chain(in_strides, size)
  3316. )
  3317. if not config.pad_dynamic_shapes and is_dynamic:
  3318. return in_strides
  3319. shape_env = V.graph._shape_env if hasattr(V.graph, "_shape_env") else None
  3320. def contains_unbacked_symints(expr: sympy.Expr | int) -> bool:
  3321. if shape_env is None:
  3322. return False
  3323. if not isinstance(expr, sympy.Expr):
  3324. return False
  3325. return any(shape_env.is_unbacked_symint(s) for s in expr.free_symbols)
  3326. # Skip padding the strides when it contains unbacked symints for now.
  3327. if shape_env and any(contains_unbacked_symints(s) for s in in_strides):
  3328. return in_strides
  3329. stride_order = get_stride_order(in_strides, shape_env)
  3330. fill_order = stride_order2fill_order(stride_order)
  3331. new_strides = [0 for _ in range(len(in_strides))]
  3332. # since we pad when the layout is flexible, we can decide the
  3333. # smallest stride to be 1.
  3334. new_strides[fill_order[0]] = 1
  3335. padded = False
  3336. for rank, idx in enumerate(fill_order[1:], start=1):
  3337. prev_idx = fill_order[rank - 1]
  3338. stride = new_strides[prev_idx] * size[prev_idx]
  3339. # Static stride and meets padding conditions OR
  3340. # Dynamic stride and config.pad_dynamic_shape=True
  3341. require_padding = (
  3342. isinstance(stride, (int, sympy.Integer))
  3343. and stride > config.padding_stride_threshold
  3344. and stride % align != 0
  3345. ) or (isinstance(stride, sympy.Expr) and config.pad_dynamic_shapes)
  3346. new_strides[idx] = stride
  3347. if require_padding:
  3348. new_strides[idx] = ceildiv(stride, align) * align
  3349. padded = True
  3350. if not padded:
  3351. # Consider a tensor with shape [256, 1, 5, 5]
  3352. # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides
  3353. # [25, 25, 5, 1].
  3354. return in_strides
  3355. # pyrefly: ignore [bad-assignment]
  3356. metrics.num_comprehensive_padding += 1
  3357. return new_strides
  3358. def pad_strides(self) -> None:
  3359. assert isinstance(self, FlexibleLayout), type(self)
  3360. assert self.stride is not None
  3361. self.stride = self._pad_strides(self.stride, self.size, self.dtype)
  3362. def should_pad_strides(self) -> bool:
  3363. return config.comprehensive_padding and isinstance(self, FlexibleLayout)
  3364. def as_fixed(self) -> FixedLayout:
  3365. if isinstance(self, FixedLayout):
  3366. return self
  3367. if self.should_pad_strides():
  3368. self.pad_strides()
  3369. return FixedLayout(
  3370. self.device,
  3371. self.dtype,
  3372. self.size,
  3373. self.stride,
  3374. self.offset,
  3375. self.is_pinned,
  3376. )
  3377. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3378. assert FlexibleLayout.allow_indexing, (
  3379. f"convert {type(self).__name__} to FixedLayout first"
  3380. )
  3381. return self.as_fixed().make_indexer()
  3382. def __eq__(self, other: object) -> bool:
  3383. return (
  3384. isinstance(other, Layout)
  3385. and self.device == other.device
  3386. and self.dtype == other.dtype
  3387. and self.size == other.size
  3388. and self.stride == other.stride
  3389. and self.offset == other.offset
  3390. and self.is_pinned == other.is_pinned
  3391. )
  3392. def storage_size(self) -> Expr:
  3393. return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
  3394. @cache_on_self_and_args("Layout")
  3395. def get_free_symbol_uses(
  3396. self, unbacked_only: bool = False
  3397. ) -> OrderedSet[sympy.Symbol]:
  3398. return (
  3399. get_free_symbols(self.size, unbacked_only)
  3400. | get_free_symbols(self.stride, unbacked_only)
  3401. | get_free_symbols(self.offset, unbacked_only)
  3402. )
  3403. class FixedLayout(Layout):
  3404. """A Tensor layout we cannot change"""
  3405. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3406. """A closure containing math to read a given element"""
  3407. return _fixed_indexer(self.size, self.stride, self.offset)
  3408. class FlexibleLayout(Layout):
  3409. """
  3410. A Tensor layout that we are allowed to change
  3411. Assumption: layout change should NOT add or remove free symbols
  3412. """
  3413. allow_indexing = False
  3414. def get_fixed_layout_without_freezing(self) -> FixedLayout:
  3415. """
  3416. Compute what the strides would be if this layout were frozen,
  3417. without actually modifying the layout. This is used for speculative
  3418. stride computation during Triton template code generation.
  3419. """
  3420. # Create a temporary copy and use as_fixed to keep freezing path in sync
  3421. return copy.deepcopy(self).as_fixed()
  3422. # WARNING! This doesn't handle zero size tensors correctly
  3423. @staticmethod
  3424. def contiguous_strides(sizes: Sequence[int]) -> list[Expr]:
  3425. if len(sizes) == 0:
  3426. return []
  3427. reversed_strides = [sympy.S.One]
  3428. for size in reversed(sizes[1:]):
  3429. reversed_strides.append(size * reversed_strides[-1])
  3430. return list(reversed(reversed_strides))
  3431. @staticmethod
  3432. def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]:
  3433. """
  3434. Create a stride based on the order the dimensions should be filled in.
  3435. In this format, channels last would be:
  3436. [1, 3, 2, 0]
  3437. """
  3438. assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order)
  3439. next_stride = sympy.S.One
  3440. strides = [None] * len(order)
  3441. for i in order:
  3442. strides[i] = next_stride
  3443. next_stride = next_stride * sizes[i]
  3444. return strides
  3445. @staticmethod
  3446. def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr]:
  3447. """
  3448. Create a stride based on the sorted order of a permuted range.
  3449. In this format, channels last would be:
  3450. [3, 0, 2, 1]
  3451. """
  3452. assert OrderedSet(range(len(sizes))) == OrderedSet(order)
  3453. fill_order = stride_order2fill_order(order)
  3454. return FlexibleLayout.fill_ordered(sizes, fill_order)
  3455. @staticmethod
  3456. def stride_ordered_for_memory_format(
  3457. sizes: Sequence[int], memory_format: torch.memory_format
  3458. ) -> Sequence[Expr]:
  3459. """
  3460. Create a stride based on a memory format.
  3461. Memory format is translasted into a stride order,
  3462. so channels_last is the same as:
  3463. FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1])
  3464. This interface does not support memory_format `torch.preserve_format`
  3465. which should be used to deduce a format from another source
  3466. """
  3467. if memory_format == torch.channels_last:
  3468. return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER)
  3469. elif memory_format == torch.channels_last_3d:
  3470. return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER)
  3471. elif memory_format == torch.contiguous_format:
  3472. return FlexibleLayout.contiguous_strides(sizes)
  3473. else:
  3474. log.debug(
  3475. "stride_ordered_for_memory_format, unsuppored memory_format: %s",
  3476. memory_format,
  3477. )
  3478. raise NotImplementedError
  3479. @staticmethod
  3480. def same_ordered(
  3481. sizes: Sequence[int], stride: Sequence[_IntLike]
  3482. ) -> Sequence[Expr]:
  3483. """
  3484. Create a stride that has the same stride order as given stride
  3485. For example, if given stride is [1000, 1, 100, 10],
  3486. the fill order should be [1, 3, 2, 0]
  3487. """
  3488. assert len(sizes) == len(stride)
  3489. stride = [V.graph.sizevars.size_hint_or_throw(x) for x in stride]
  3490. fill_order = sorted(range(len(stride)), key=stride.__getitem__)
  3491. return FlexibleLayout.fill_ordered(sizes, fill_order)
  3492. @property
  3493. def size(self) -> Sequence[Expr]:
  3494. return self._size
  3495. @size.setter
  3496. def size(self, value: Sequence[Expr]) -> None:
  3497. self.assert_free_symbol_uses_unchanged("size", value)
  3498. self._size = value
  3499. @property
  3500. def stride(self) -> Sequence[Expr]:
  3501. return self._stride
  3502. @stride.setter
  3503. def stride(self, value: Sequence[Expr]) -> None:
  3504. self.assert_free_symbol_uses_unchanged("stride", value)
  3505. self._stride = value
  3506. @property
  3507. def offset(self) -> Expr:
  3508. return self._offset
  3509. @offset.setter
  3510. def offset(self, value: Expr) -> None:
  3511. self.assert_free_symbol_uses_unchanged("offset", value)
  3512. self._offset = value
  3513. def as_stride_order(
  3514. self, order: Sequence[int], allow_padding: bool = False
  3515. ) -> FixedLayout:
  3516. new_stride = self.stride_ordered(self.size, order)
  3517. if self.should_pad_strides() and allow_padding:
  3518. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3519. return FixedLayout(
  3520. self.device,
  3521. self.dtype,
  3522. self.size,
  3523. new_stride,
  3524. self.offset,
  3525. self.is_pinned,
  3526. )
  3527. def as_exact_strides(
  3528. self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  3529. ) -> FixedLayout:
  3530. new_stride = exact_strides
  3531. if self.should_pad_strides() and allow_padding:
  3532. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3533. return FixedLayout(
  3534. self.device,
  3535. self.dtype,
  3536. self.size,
  3537. new_stride,
  3538. self.offset,
  3539. self.is_pinned,
  3540. )
  3541. def as_fill_order(self, order: Sequence[int]) -> FixedLayout:
  3542. new_stride: Sequence[int] = self.fill_ordered(self.size, order)
  3543. if self.should_pad_strides():
  3544. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3545. return FixedLayout(
  3546. self.device,
  3547. self.dtype,
  3548. self.size,
  3549. new_stride,
  3550. self.offset,
  3551. self.is_pinned,
  3552. )
  3553. def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout:
  3554. new_stride = self.same_ordered(self.size, stride)
  3555. if self.should_pad_strides():
  3556. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3557. return FixedLayout(
  3558. self.device,
  3559. self.dtype,
  3560. self.size,
  3561. new_stride,
  3562. self.offset,
  3563. self.is_pinned,
  3564. )
  3565. def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]:
  3566. initial_free_symbols = {}
  3567. for name in ["size", "stride", "offset"]:
  3568. for unbacked_only in [True, False]:
  3569. key = (name, unbacked_only)
  3570. initial_free_symbols[key] = OrderedSet(
  3571. get_free_symbols(getattr(self, name), unbacked_only)
  3572. )
  3573. return initial_free_symbols
  3574. def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
  3575. for unbacked_only in [True, False]:
  3576. old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
  3577. new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
  3578. assert new_free_symbols == old_free_symbols, (
  3579. f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
  3580. )
  3581. def __init__(
  3582. self,
  3583. device: torch.device,
  3584. dtype: torch.dtype,
  3585. size: Sequence[Expr],
  3586. stride_order: Optional[Sequence[Union[int, Integer]]] = None,
  3587. is_pinned: bool = False,
  3588. ) -> None:
  3589. if stride_order:
  3590. strides = FlexibleLayout.fill_ordered(size, stride_order)
  3591. else:
  3592. strides = FlexibleLayout.contiguous_strides(size)
  3593. super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
  3594. # record the initial free symbols to check that we do not add new free symbols
  3595. # later when modifying sizes, strides, and offsets.
  3596. self.initial_free_symbols = self.get_initial_free_symbol_uses()
  3597. class NonOwningLayout(Layout):
  3598. """Is a view into the storage of another tensor"""
  3599. def __init__(self, view: Union[BaseView, TensorBox]) -> None:
  3600. layout = view.get_layout()
  3601. super().__init__(
  3602. layout.device,
  3603. layout.dtype,
  3604. layout.size,
  3605. layout.stride,
  3606. )
  3607. self.view = view
  3608. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3609. return self.as_fixed().make_indexer()
  3610. def maybe_guard_aligned(self) -> bool:
  3611. offset = self.view.get_layout().offset
  3612. if offset == 0:
  3613. return True
  3614. from .utils import ALIGNMENT
  3615. return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
  3616. @cache_on_self_and_args("NonOwningLayout")
  3617. def get_free_symbol_uses(
  3618. self, unbacked_only: bool = False
  3619. ) -> OrderedSet[sympy.Symbol]:
  3620. assert isinstance(self.view, ReinterpretView)
  3621. box = self.view.data
  3622. assert isinstance(box, StorageBox), type(box)
  3623. input_buffer = box.data
  3624. assert isinstance(input_buffer, Buffer), type(box)
  3625. return input_buffer.layout.get_free_symbol_uses(unbacked_only)
class CommBufferType(Enum):
    """Kinds of communication buffer that a ``CommBufferLayout`` can describe."""

    # presumably "symmetric memory" comm buffers — confirm against callers
    SYMM_MEM = "symm_mem"
  3628. class CommBufferLayout(FixedLayout):
  3629. """
  3630. A layout that signifies the buffer is a comm buffer.
  3631. In terms of striding, the layout is identical to `FixedLayout`.
  3632. Buffers with this layout do not participate in in-place reuse - it can be
  3633. neither the source nor the target for in-place reuse.
  3634. For detailed motivation and usage of this layout, see
  3635. NOTE [lowering-time collective optimization].
  3636. """
  3637. comm_buffer_type: CommBufferType
  3638. group_name: str
  3639. def __init__(
  3640. self,
  3641. layout: Union[FlexibleLayout, FixedLayout],
  3642. comm_buffer_type: CommBufferType,
  3643. group_name: str,
  3644. ):
  3645. fixed = layout.as_fixed() if isinstance(layout, FlexibleLayout) else layout
  3646. super().__init__(
  3647. device=fixed.device,
  3648. dtype=fixed.dtype,
  3649. size=fixed.size,
  3650. stride=fixed.stride,
  3651. offset=fixed.offset,
  3652. is_pinned=fixed.is_pinned,
  3653. )
  3654. self.comm_buffer_type = comm_buffer_type
  3655. self.group_name = group_name
  3656. @ir_dataclass
  3657. class NoneLayout(OutputSpec):
  3658. # This is janky, I figured out what fields to populate by just running
  3659. # the model I was interested in and adding properties/methods as needed.
  3660. # This doesn't inherit from Layout because Layout assumes you have stuff
  3661. # like sizes, but I don't really have anything here.
  3662. #
  3663. # If you have an ir.Node with NoneLayout, you probably need to setup
  3664. # dependencies manually in scheduler
  3665. device: Optional[torch.device]
  3666. size: list[int] = dataclasses.field(default_factory=lambda: [0])
  3667. stride: list[int] = dataclasses.field(default_factory=lambda: [0])
  3668. def storage_size(self) -> int:
  3669. return 0
  3670. def as_fixed(self) -> OutputSpec:
  3671. return self
  3672. def get_device(self) -> Optional[torch.device]:
  3673. return self.device
  3674. class MutationLayoutSHOULDREMOVE(Layout):
  3675. def __init__(self, target: IRNode) -> None:
  3676. super().__init__(
  3677. target.get_device_or_error(),
  3678. target.get_dtype(),
  3679. target.get_size(),
  3680. None,
  3681. )
  3682. self.target = target
  3683. name = self.get_buffer().get_name()
  3684. V.graph.mark_buffer_mutated(name)
  3685. @property
  3686. def stride(self) -> Sequence[Expr]: # type: ignore[override]
  3687. return self.real_layout().stride
  3688. @stride.setter # type: ignore[override]
  3689. def stride(self, value: Never) -> None:
  3690. pass # ignore setting of stride
  3691. def storage_size(self) -> Expr:
  3692. return self.real_layout().storage_size()
  3693. def get_buffer(self) -> Buffer:
  3694. def unwrap_views(target: Any) -> Any:
  3695. if isinstance(target, MutationLayoutSHOULDREMOVE):
  3696. return unwrap_views(target.target)
  3697. if isinstance(target, BaseView):
  3698. return unwrap_views(target.unwrap_view())
  3699. if isinstance(target, MutableBox):
  3700. return unwrap_views(target.data)
  3701. return target
  3702. result = unwrap_views(self.target)
  3703. assert isinstance(result, Buffer), type(result)
  3704. return result
  3705. def real_layout(self) -> Layout:
  3706. layout = self.get_buffer().layout
  3707. assert isinstance(layout, Layout)
  3708. return layout
  3709. @classmethod
  3710. def realize_into(
  3711. cls, src: IRNode, dst: IRNode, unsafe_alias: bool = False
  3712. ) -> IRNode:
  3713. dst.realize()
  3714. # NOTE: We must realize users of `dst` before we realize `src`, since
  3715. # realization order determines scheduling order. Otherwise, src's
  3716. # mutation would be scheduled before the existing users of dst!
  3717. V.graph.mark_buffer_mutated(dst.get_name())
  3718. if isinstance(src, TensorBox):
  3719. src = src.data
  3720. # We copy the contents of src into dst. In most cases this should
  3721. # be fused into a single kernel by the scheduler.
  3722. # NOTE: We cannot change src's layout to mutate dst directly as this
  3723. # would alias src to dst, which is not correct as further mutations to
  3724. # dst would effect users of src. However if there are no more users of
  3725. # dst, we can alias src to dst.
  3726. src.realize_hint()
  3727. if not unsafe_alias:
  3728. node = Pointwise.create(
  3729. device=src.get_device(),
  3730. dtype=src.get_dtype(),
  3731. inner_fn=src.make_loader(),
  3732. ranges=[
  3733. V.graph.sizevars.check_equals_and_simplify(a, b)
  3734. for a, b in zip(src.get_size(), dst.get_size())
  3735. ],
  3736. )
  3737. assert isinstance(node, (BaseView, MutableBox))
  3738. src = node.data
  3739. src.realize()
  3740. assert hasattr(src, "data"), src
  3741. assert isinstance(src.data.layout, FlexibleLayout), type(src.data.layout)
  3742. src.data.layout = MutationLayoutSHOULDREMOVE(dst)
  3743. return src.data
  3744. def as_fixed(self) -> Self: # type: ignore[override]
  3745. return self
  3746. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3747. return self.target.make_indexer()
  3748. @ir_dataclass(frozen=False)
  3749. class Buffer(IRNode, CodegenSymbol):
  3750. # Name is sometimes None; e.g., ForceInPlace, where there isn't
  3751. # a meaningful name
  3752. name: Optional[str]
  3753. layout: OutputSpec
  3754. # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly,
  3755. # MultiOutput does NOT define this!
  3756. def __post_init__(self) -> None:
  3757. super().__post_init__()
  3758. self._post_init_setattr("origin_node", None)
  3759. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3760. return self.get_layout().make_indexer()
  3761. def get_name(self) -> str:
  3762. assert self.name, self
  3763. return self.name
  3764. def get_example(self) -> Union[torch.Tensor, torch.SymInt]:
  3765. if isinstance(self.layout, Layout):
  3766. return self.layout.get_example()
  3767. raise NotImplementedError(type(self.layout).__name__)
  3768. def get_device(self) -> Optional[torch.device]:
  3769. return self.get_output_spec().get_device()
  3770. def get_defining_op(self) -> Optional[Operation]:
  3771. return None
  3772. @property
  3773. def dtype(self) -> torch.dtype:
  3774. return self.get_layout().dtype
  3775. def get_size(self) -> Sequence[Expr]:
  3776. return [*self.get_layout().size]
  3777. def get_stride(self) -> list[Expr]:
  3778. return [*self.get_layout().stride]
  3779. def get_offset(self) -> Expr:
  3780. return self.get_layout().offset
  3781. def get_layout(self) -> Layout:
  3782. if isinstance(self.layout, Layout):
  3783. return self.layout
  3784. raise NotImplementedError(type(self.layout).__name__)
  3785. def get_output_spec(self) -> OutputSpec:
  3786. return self.layout
  3787. def get_storage_numel(self) -> int:
  3788. return self.get_numel()
  3789. def get_is_pinned(self) -> bool:
  3790. return self.get_layout().is_pinned
  3791. def freeze_layout(self) -> None:
  3792. if isinstance(self.layout, Layout) and not isinstance(
  3793. self.layout, NonOwningLayout
  3794. ):
  3795. self.layout = self.layout.as_fixed()
  3796. def freeze_layout_with_stride_order(
  3797. self, order: Sequence[int], allow_padding: bool = False
  3798. ) -> None:
  3799. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3800. self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding)
  3801. def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
  3802. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3803. self.layout = self.layout.as_fill_order(order)
  3804. def freeze_layout_with_same_order(self, stride: Sequence[int]) -> None:
  3805. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3806. self.layout = self.layout.as_same_order(stride)
  3807. def freeze_layout_with_exact_strides(
  3808. self, exact_strides: Sequence[int], allow_padding: bool = False
  3809. ) -> None:
  3810. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3811. self.layout = self.layout.as_exact_strides(
  3812. exact_strides, allow_padding=allow_padding
  3813. )
  3814. def is_zero_elements(self) -> bool:
  3815. return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
  3816. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3817. # Loading from a zero-element buffer is a no-op
  3818. if self.is_zero_elements():
  3819. return partial(nop_loader_fn, dtype=self.get_dtype())
  3820. def loader(index: Sequence[Expr]) -> OpsValue:
  3821. indexer = self.make_indexer()
  3822. return ops.load(self.name or "unnamed", indexer(index))
  3823. return loader
  3824. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  3825. return self.get_name()
  3826. def decide_layout(self) -> None:
  3827. pass
  3828. def get_inputs_that_alias_output(self) -> Sequence[str]:
  3829. if isinstance(self.layout, NonOwningLayout):
  3830. return [self.layout.view.get_name()]
  3831. return ()
  3832. def get_mutation_names(self) -> Sequence[str]:
  3833. if isinstance(self.layout, MutationLayoutSHOULDREMOVE):
  3834. return [self.layout.target.get_name()]
  3835. return ()
  3836. def get_read_names(self) -> OrderedSet[str]:
  3837. return OrderedSet([self.get_name()])
  3838. @cache_on_self_and_args("Buffer")
  3839. def get_free_symbol_uses(
  3840. self, unbacked_only: bool = False
  3841. ) -> OrderedSet[sympy.Symbol]:
  3842. return OrderedSet()
  3843. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  3844. return OrderedSet()
  3845. def realize(self) -> Optional[str]:
  3846. pass
  3847. def should_allocate(self) -> bool:
  3848. # Returns False by default.
  3849. return False
  3850. @ir_dataclass(frozen=False)
  3851. class OperationBuffer(Buffer, Operation):
  3852. # An operation that produces a single output buffer
  3853. def get_outputs(self) -> list[Buffer]:
  3854. return [self]
  3855. def get_defining_op(self) -> Operation:
  3856. return self
  3857. # Skip implementation in Buffer
  3858. get_operation_name = Operation.get_operation_name
  3859. def __post_init__(self) -> None:
  3860. Buffer.__post_init__(self)
  3861. Operation.__post_init__(self)
class InputBuffer(Buffer):
    """A Buffer representing a graph input."""

    def num_reads(self) -> int:
        # An input buffer is counted as a single read.
        return 1
class DonatedBuffer(InputBuffer):
    """
    Represents a donated buffer: a saved tensor that is not aliased to any
    forward input, forward user output, or backward output.

    Input tensor memory generally cannot be reused in place during backward,
    because it might be used in another function. A donated buffer, however,
    can be reused in place during backward to save memory.
    """
  3873. class ConstantBuffer(InputBuffer):
  3874. override_device: Optional[torch.device] = None
  3875. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3876. def loader(index: Sequence[Expr]) -> OpsValue:
  3877. indexer = self.get_layout().make_indexer()
  3878. return ops.load(
  3879. V.graph.constant_name(self.get_name(), self.override_device),
  3880. indexer(index),
  3881. )
  3882. return loader
  3883. def constant_to_device(self, device: torch.device) -> IRNode:
  3884. return ConstantBuffer(
  3885. name=V.graph.constant_name(self.get_name(), device), layout=self.layout
  3886. )
  3887. @ir_dataclass
  3888. class NoneAsConstantBuffer(IRNode):
  3889. def get_reads(self) -> OrderedSet[Dep]:
  3890. return OrderedSet()
  3891. @cache_on_self_and_args("NoneAsConstantBuffer")
  3892. def get_free_symbol_uses(
  3893. self, unbacked_only: bool = False
  3894. ) -> OrderedSet[sympy.Symbol]:
  3895. return OrderedSet()
  3896. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  3897. return V.graph.wrapper_code.none_str
  3898. def get_output_spec(self) -> OutputSpec:
  3899. return NoneLayout(device=None)
  3900. def has_tensor_output(self) -> bool:
  3901. return False
  3902. @ir_dataclass
  3903. class ShapeAsConstantBuffer(IRNode):
  3904. expr: Expr
  3905. @cache_on_self_and_args("ShapeAsConstantBuffer")
  3906. def get_free_symbol_uses(
  3907. self, unbacked_only: bool = False
  3908. ) -> OrderedSet[sympy.Symbol]:
  3909. return get_free_symbols(self.expr, unbacked_only)
  3910. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  3911. return V.graph.wrapper_code.codegen_sizevar(self.expr)
  3912. def has_tensor_output(self) -> bool:
  3913. return False
  3914. @ir_dataclass(frozen=False)
  3915. class ComputedBuffer(OperationBuffer):
  3916. """
  3917. Represents a buffer that is computed during kernel execution rather than being an input.
  3918. """
  3919. data: Loops
  3920. _force_realize: ClassVar[bool] = False
  3921. # fields for split reduction
  3922. _split_size: Optional[int] = None
  3923. _original_inner_fn: Optional[Callable[..., Any]] = None
  3924. _original_ranges: Optional[Sequence[_IntLike]] = None
  3925. _original_reduction_ranges: Optional[Sequence[_IntLike]] = None
  3926. @contextlib.contextmanager
  3927. def with_original_inner_fn(self) -> Iterator[None]:
  3928. assert self._split_size is not None
  3929. assert self._original_inner_fn is not None
  3930. assert self._original_ranges is not None
  3931. assert self._original_reduction_ranges is not None
  3932. assert isinstance(self.data, Reduction), f"{type(self.data)}"
  3933. old_data = self.data
  3934. old_layout = self.layout
  3935. try:
  3936. new_data = Reduction(
  3937. device=old_data.device,
  3938. dtype=old_data.dtype,
  3939. inner_fn=self._original_inner_fn,
  3940. ranges=self._original_ranges,
  3941. reduction_ranges=self._original_reduction_ranges,
  3942. reduction_type=old_data.reduction_type,
  3943. src_dtype=old_data.src_dtype,
  3944. reduction_hint=old_data.reduction_hint,
  3945. )
  3946. self.data = new_data
  3947. # this layout does not matter since we skip tl.store
  3948. # later
  3949. self.layout = FixedLayout(
  3950. old_data.device,
  3951. old_data.dtype,
  3952. self._original_ranges,
  3953. )
  3954. self.get_default_sizes_body.clear_cache(self)
  3955. yield
  3956. finally:
  3957. self.data = old_data
  3958. self.layout = old_layout
  3959. @staticmethod
  3960. @contextlib.contextmanager
  3961. def force_realize() -> Iterator[None]:
  3962. old_value = ComputedBuffer._force_realize
  3963. try:
  3964. ComputedBuffer._force_realize = True
  3965. yield
  3966. finally:
  3967. ComputedBuffer._force_realize = old_value
  3968. def get_computed_buffer_name(self) -> Optional[str]:
  3969. """
  3970. Returns self.name if it exists, otherwise returns the name of the data node if that exists.
  3971. If neither exist, returns None.
  3972. """
  3973. if self.name is not None:
  3974. return self.name
  3975. if hasattr(self.data, "name"):
  3976. return self.data.name
  3977. return None
  3978. def num_reads(self) -> int:
  3979. return self.data.num_reads()
  3980. def get_reads(self) -> OrderedSet[Dep]:
  3981. return self.data.get_reads()
  3982. def get_read_names(self) -> OrderedSet[str]:
  3983. return self.data.get_read_names()
  3984. def get_read_writes(self) -> dependencies.ReadWrites:
  3985. if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
  3986. return dependencies.ReadWrites(
  3987. reads=OrderedSet(),
  3988. writes=OrderedSet(),
  3989. index_exprs=OrderedSet(),
  3990. )
  3991. with patch.object(FlexibleLayout, "allow_indexing", True):
  3992. if self.data.get_reduction_type():
  3993. return extract_read_writes(
  3994. self.get_store_function(),
  3995. self.data.get_pointwise_size(),
  3996. self.data.get_reduction_size(),
  3997. )
  3998. else:
  3999. return extract_read_writes(
  4000. self.get_store_function(),
  4001. self.data.get_size(),
  4002. )
  4003. @cache_on_self_and_args("ComputedBuffer")
  4004. def get_free_symbol_uses(
  4005. self, unbacked_only: bool = False
  4006. ) -> OrderedSet[sympy.Symbol]:
  4007. # Ordinarily, we'd like to just peek at the arguments list,
  4008. # but ComputedBuffers have no argument list.
  4009. #
  4010. # Morally, this logic needs to be synchronized with the
  4011. # KernelArgs.size calls, which are responsible for making symbols make
  4012. # there way as kernel arguments (and it is precisely passing in one of
  4013. # those symbols that establishes a dependency). However, we haven't
  4014. # started codegen yet so we can't directly reuse that logic.
  4015. #
  4016. # One thing you might wonder is if this is enough for a ComputedBuffer
  4017. # denoting a reduction over i0. Empirically, it is enough, but for an
  4018. # unusual reason: we only need accurate dependencies for item() call,
  4019. # but it's impossible to end up with a reduction over i0 from an
  4020. # item() call without a regular non-reduction buffer first.
  4021. result = self.layout.get_free_symbol_uses(
  4022. unbacked_only
  4023. ) | self.data.get_free_symbol_uses(unbacked_only)
  4024. if self.has_store_function():
  4025. result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
  4026. return result
  4027. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  4028. if (
  4029. not self.get_reduction_type()
  4030. and self.name not in V.graph.mutated_buffers
  4031. and self.num_reads() == 0
  4032. and not self._force_realize
  4033. ):
  4034. # inline this op rather than generating ops.load()
  4035. return self.data.make_loader()
  4036. return super().make_loader()
  4037. def has_store_function(self) -> bool:
  4038. return isinstance(self.data, (Reduction, Scan, Sort, Pointwise))
  4039. def get_store_function(self) -> Callable[..., None]:
  4040. indexer = self.get_layout().as_fixed().make_indexer()
  4041. if isinstance(self.data, (Reduction, Scan, Sort)):
  4042. return partial(self.data.store_reduction, self.name, indexer)
  4043. else:
  4044. assert isinstance(self.data, Pointwise), type(self.data)
  4045. return partial(self.data.store_output, self.name, indexer)
  4046. def get_fill_order(self) -> Optional[list[int]]:
  4047. """
  4048. If our layout is still flexible, try to determine the stride order based on stride orders of reads.
  4049. TODO(jansel): A better algorithm here would look at downstream consumers of this
  4050. value and try to do global graph-level layout optimization.
  4051. This is also something just begging to be autotuned.
  4052. """
  4053. if isinstance(self.layout, FlexibleLayout):
  4054. (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
  4055. self.data.get_pointwise_size(), self.data.get_reduction_size()
  4056. )
  4057. reads = self.get_read_writes().reads
  4058. # only consider reads to buffer of same size
  4059. # ignore StarDeps because they don't contribute stride information
  4060. assert all(
  4061. isinstance(r, (dependencies.StarDep, dependencies.MemoryDep))
  4062. for r in reads
  4063. )
  4064. reads = [
  4065. sympy_subs(r.index, {v: sympy.S.Zero for v in reduction_vars if v != 0})
  4066. for r in reads
  4067. if isinstance(r, dependencies.MemoryDep)
  4068. ]
  4069. if reads:
  4070. if isinstance(self.data, (Scan, Sort)):
  4071. indices = self.data.reindex(index_vars, reduction_vars)
  4072. else:
  4073. indices = index_vars
  4074. stride_lengths = [
  4075. V.graph.sizevars.stride_hints(expr, indices) for expr in reads
  4076. ]
  4077. from .scheduler import pick_loop_order
  4078. return pick_loop_order(stride_lengths, self.get_size())
  4079. return None
  4080. def decide_layout(self) -> None:
  4081. if isinstance(self.layout, FlexibleLayout):
  4082. order = self.get_fill_order()
  4083. if order:
  4084. self.freeze_layout_with_fill_order(order)
  4085. else:
  4086. self.freeze_layout()
  4087. @cache_on_self
  4088. def get_default_sizes_body(
  4089. self,
  4090. ) -> tuple[
  4091. tuple[list[Expr], list[Expr]],
  4092. LoopBody,
  4093. tuple[list[Expr], list[Expr]],
  4094. ]:
  4095. args, var_ranges = dependencies.index_vars_squeeze(
  4096. self.get_pointwise_size(), self.get_reduction_size(), prefix="q"
  4097. )
  4098. with patch.object(ConstantBuffer, "override_device", self.get_device()):
  4099. body = LoopBody(
  4100. self.get_store_function(),
  4101. (args if self.get_reduction_type() else args[:1]),
  4102. var_ranges,
  4103. *args,
  4104. )
  4105. index_vars = []
  4106. reduce_vars: list[Any] = []
  4107. index_size = []
  4108. reduce_size = []
  4109. for v, s in var_ranges.items():
  4110. if v in args[0]:
  4111. assert not reduce_vars
  4112. index_vars.append(v)
  4113. index_size.append(s)
  4114. else:
  4115. assert v in args[1]
  4116. reduce_vars.append(v)
  4117. reduce_size.append(s)
  4118. return (index_size, reduce_size), body, (index_vars, reduce_vars)
  4119. def simplify_and_reorder(
  4120. self,
  4121. extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
  4122. recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
  4123. ) -> tuple[tuple[list[Expr], list[Expr]], Optional[LoopBody]]:
  4124. """
  4125. This is a main place where we do loop transformations in a
  4126. backend-agnostic way.
  4127. Here we:
  4128. 1) Remove any 1 dimensions
  4129. 2) Fuse contiguous dimensions together
  4130. 3) Reorder dimensions based on stride orders
  4131. Optional argument extra_indexing_constraints can be used to append additional
  4132. indexing expressions to existing ones derived from buffer's body. This can be useful
  4133. to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...)
  4134. on CPU by preventing indexing simplifications and obtaining index/reduce ranges for
  4135. the scheduler node compatible with other nodes.
  4136. Optional argument recompute_sizes_body_func can be used to recompute sizes and body
  4137. on the default body. This can be useful to append additional loop transformations.
  4138. """
  4139. (
  4140. (index_size, reduce_size),
  4141. body,
  4142. (index_vars, reduce_vars),
  4143. ) = self.get_default_sizes_body()
  4144. if recompute_sizes_body_func:
  4145. (
  4146. (index_size, reduce_size),
  4147. body,
  4148. (index_vars, reduce_vars),
  4149. ) = recompute_sizes_body_func(
  4150. (index_size, reduce_size), body, (index_vars, reduce_vars)
  4151. )
  4152. index_formulas = [*body.indexing_exprs.values()]
  4153. if extra_indexing_constraints is not None:
  4154. assert (
  4155. isinstance(extra_indexing_constraints, tuple)
  4156. and len(extra_indexing_constraints) == 2
  4157. )
  4158. extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints
  4159. assert isinstance(extra_indexing_ranges, dict), type(extra_indexing_ranges)
  4160. assert isinstance(extra_indexing_expr, list), type(extra_indexing_expr)
  4161. assert all(isinstance(f, Expr) for f in extra_indexing_expr)
  4162. expected_var_ranges = body.var_ranges
  4163. assert expected_var_ranges == extra_indexing_ranges, (
  4164. expected_var_ranges,
  4165. extra_indexing_ranges,
  4166. )
  4167. # remove already existing expressions
  4168. extra_indexing_expr = [
  4169. e for e in extra_indexing_expr if e not in index_formulas
  4170. ]
  4171. index_formulas += extra_indexing_expr
  4172. memory_addrs = [*body.get_write_exprs()]
  4173. if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER):
  4174. memory_addrs.extend(body.get_read_exprs())
  4175. def simplify_and_reorder(
  4176. x_vars: Sequence[sympy.Symbol],
  4177. support_vars: Sequence[sympy.Symbol],
  4178. sizes: Sequence[int],
  4179. simplify_loops: bool,
  4180. ) -> tuple[
  4181. list[int],
  4182. Callable[[Sequence[int]], Sequence[int]],
  4183. Callable[[Sequence[int]], Sequence[int]],
  4184. ]:
  4185. newsizes, reindex0, reindex1 = self._apply_loop_reordering(
  4186. x_vars, support_vars, sizes, memory_addrs
  4187. )
  4188. # When using native matmul, the codegen assumes the following loop order,
  4189. # regardless of the stride of A and B:
  4190. #
  4191. # for z -> y -> x -> r: C[z, y, x] += A[z, y, r] * B[z, r, x]
  4192. # or
  4193. # for z -> x -> y -> r: C[z, y, x] += A[z, y, r] * B[z, r, x]
  4194. #
  4195. # The critical point is the position of the "z" (batch) axis in bmm.
  4196. # It is fine to swap the y and x axes (e.g., (z, y, x, r) or (z, x, y, r)),
  4197. # but reordering the z axis (e.g., (y, x, z, r)) breaks codegen.
  4198. #
  4199. # Therefore, if loop reordering changes the "z" location in bmm,
  4200. # it should be reverted to the default.
  4201. # This may not always produce the optimal loop order when strides
  4202. # do not align with the default assumption.
  4203. #
  4204. # TODO: Consider extending tl.dot codegen to support arbitrary loop orders.
  4205. if self.get_reduction_type() == "dot" and len(sizes) == 3:
  4206. order = list(range(len(sizes))) # default order
  4207. # if z axis is not the outermost, use the default reorder.
  4208. if reindex0(order)[0] != 0:
  4209. newsizes = [sizes[i] for i in order]
  4210. reindex0 = same_reorder(order)
  4211. reindex1 = inverse_reorder(order)
  4212. # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
  4213. x_vars = reindex0(x_vars)
  4214. if simplify_loops:
  4215. newsizes, reindex2, _prune = V.graph.sizevars._simplify_loops(
  4216. x_vars,
  4217. newsizes,
  4218. index_prevent_reordering(index_formulas, x_vars, newsizes),
  4219. )
  4220. reindex = fuse_reindexing(reindex1, reindex2)
  4221. else:
  4222. reindex = reindex1
  4223. return newsizes, reindex, reindex1
  4224. support_vars = index_vars + reduce_vars
  4225. should_merge_loops = (
  4226. not is_gpu(get_device_type(self)) or not config.loop_ordering_after_fusion
  4227. )
  4228. iter_ranges, iter_reindex, _ = simplify_and_reorder(
  4229. index_vars,
  4230. support_vars,
  4231. index_size,
  4232. should_merge_loops,
  4233. )
  4234. # Like iteration dimensions, we may also want to delay merging reduction dimensions.
  4235. # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise
  4236. # kernel, merging M and N dimension too early makes it hard to decide what loop order
  4237. # we should pick for the piontwise kernel so that it is fusible with the reduction.
  4238. reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
  4239. reduce_vars, support_vars, reduce_size, should_merge_loops
  4240. )
  4241. # retrace the loop body with simplification and reordering applied
  4242. (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
  4243. iter_ranges,
  4244. reduce_ranges,
  4245. prefix="p",
  4246. )
  4247. body = LoopBody(
  4248. body,
  4249. [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
  4250. var_ranges,
  4251. iter_vars,
  4252. reduce_vars,
  4253. )
  4254. return (iter_ranges, reduce_ranges), body
  4255. @staticmethod
  4256. def _apply_loop_reordering(
  4257. index_vars: Sequence[sympy.Symbol],
  4258. support_vars: Sequence[sympy.Symbol],
  4259. sizes: Sequence[int],
  4260. memory_addrs: list[sympy.Expr],
  4261. priority_idx: Optional[list[int]] = None,
  4262. ) -> tuple[
  4263. list[int],
  4264. Callable[[Sequence[int]], Sequence[int]],
  4265. Callable[[Sequence[int]], Sequence[int]],
  4266. ]:
  4267. """
  4268. Shuffle the order of loops around to hopefully improve performance.
  4269. """
  4270. from .scheduler import pick_loop_order
  4271. if priority_idx is None:
  4272. priority_idx = []
  4273. try:
  4274. strides = [
  4275. V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
  4276. for expr in memory_addrs
  4277. ]
  4278. assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
  4279. index_vars
  4280. )
  4281. order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
  4282. except Exception:
  4283. if config.debug:
  4284. log.warning(
  4285. "Did not simplify complex index:\n%s\n%s",
  4286. dict(zip(index_vars, sizes)),
  4287. memory_addrs,
  4288. )
  4289. order = list(range(len(sizes)))
  4290. sizes = [sizes[i] for i in order]
  4291. return sizes, same_reorder(order), inverse_reorder(order)
  4292. def get_pointwise_size(self) -> Sequence[Expr]:
  4293. return self.data.get_pointwise_size()
  4294. def get_reduction_size(self) -> Sequence[Expr]:
  4295. return self.data.get_reduction_size()
  4296. def get_reduction_type(self) -> Optional[str]:
  4297. return self.data.get_reduction_type()
  4298. def is_no_op(self) -> bool:
  4299. return self.data.is_zero_elements()
  4300. def should_allocate(self) -> bool:
  4301. return True
  4302. def constant_to_device(self, device: torch.device) -> IRNode:
  4303. """Move this to a given device. Requires that all reads are to constants."""
  4304. return self.data.constant_to_device(device)
  4305. class TemplateBuffer(OperationBuffer):
  4306. """
  4307. Represents a Triton (in the future other type) of template operator
  4308. that we can fuse an epilogue onto.
  4309. """
  4310. def __init__(
  4311. self,
  4312. layout: OutputSpec,
  4313. inputs: Sequence[IRNode],
  4314. make_kernel_render: Optional[Callable[..., Any]],
  4315. ) -> None:
  4316. super().__init__(name=None, layout=layout)
  4317. self.inputs = InputsKernel.unwrap_storage(inputs)
  4318. self.make_kernel_render = make_kernel_render
  4319. self.name = V.graph.register_buffer(self)
  4320. V.graph.register_operation(self)
  4321. # Annotations dict for storing metadata (e.g., KernelTemplateChoice)
  4322. self.annotations: dict[str, Any] = {}
  4323. def get_read_writes(self) -> dependencies.ReadWrites:
  4324. return self.extract_read_writes(normalize=True)
  4325. def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites:
  4326. name = self.get_name()
  4327. indexer = self.get_layout().make_indexer()
  4328. def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
  4329. assert len(rindex) == 0
  4330. return ops.store(name, indexer(index), "fake")
  4331. deps = dependencies.extract_read_writes(
  4332. dummy, self.get_size(), (), normalize=normalize
  4333. )
  4334. for inp in self.inputs:
  4335. assert isinstance(inp, (ReinterpretView, Buffer)), type(inp)
  4336. assert isinstance(inp.layout, Layout), type(inp.layout)
  4337. indexer = inp.layout.make_indexer()
  4338. def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
  4339. assert len(rindex) == 0
  4340. # pyrefly: ignore [missing-attribute]
  4341. return ops.load(inp.get_name(), indexer(index))
  4342. deps.reads |= dependencies.extract_read_writes(
  4343. dummy, inp.get_size(), (), normalize=normalize
  4344. ).reads
  4345. return deps
  4346. def get_reduction_size(self) -> Sequence[Expr]:
  4347. return sympy.S.One
  4348. def get_reduction_type(self) -> Optional[str]:
  4349. return None
  4350. def should_allocate(self) -> bool:
  4351. return True
  4352. def simplify_and_reorder(
  4353. self,
  4354. extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
  4355. recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
  4356. ) -> tuple[tuple[Sequence[Expr], list[Expr]], Optional[LoopBody]]:
  4357. return (
  4358. (
  4359. self.get_size(),
  4360. [],
  4361. ),
  4362. None,
  4363. )
  4364. class TritonTemplateBuffer(TemplateBuffer):
  4365. def __init__(
  4366. self,
  4367. layout: Layout,
  4368. inputs: Sequence[IRNode],
  4369. make_kernel_render: Optional[Callable[_P, _T]],
  4370. mutated_inputs: Optional[Iterable[IRNode]] = None,
  4371. allowed_prologue_inps: Optional[OrderedSet[str]] = None,
  4372. ) -> None:
  4373. """
  4374. NOTE:[TritonTemplates with multiple outputs]
  4375. We want the ability for TritonTemplates to output multiple tensors. Triton
  4376. kernels have no notion of outputs and this is done by creating tensors that
  4377. are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't
  4378. support creating multinode outputs for triton templates.
  4379. We work around this by creating an extra input buffer during the lowering
  4380. and we mark them as mutated inputs.
  4381. """
  4382. super().__init__(layout, inputs, make_kernel_render)
  4383. self.mutated_inputs = mutated_inputs
  4384. self.outputs: list[Buffer] = [self]
  4385. if mutated_inputs is not None:
  4386. assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
  4387. device = self.inputs[0].get_device()
  4388. self.outputs += [
  4389. MutationOutput(NoneLayout(device=device), buf, self)
  4390. for buf in mutated_inputs
  4391. ]
  4392. self.allowed_prologue_inps = (
  4393. allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
  4394. )
  4395. self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
  4396. self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
  4397. @cache_on_self_and_args("TritonTemplateBuffer")
  4398. def get_free_symbol_uses(
  4399. self, unbacked_only: bool = False
  4400. ) -> OrderedSet[sympy.Symbol]:
  4401. res = super().get_free_symbol_uses(unbacked_only)
  4402. subgraph_outs = self.subgraph_outs if self.subgraph_outs else []
  4403. subgraph_inps = self.subgraph_inps if self.subgraph_inps else []
  4404. for inp in subgraph_inps:
  4405. if isinstance(inp, sympy.Expr):
  4406. res.update(get_free_symbols(inp, unbacked_only))
  4407. elif isinstance(inp, IRNode):
  4408. res.update(inp.get_free_symbol_uses(unbacked_only))
  4409. else:
  4410. assert inp is None
  4411. for out in subgraph_outs:
  4412. if isinstance(out, IRNode):
  4413. res.update(out.get_free_symbol_uses(unbacked_only))
  4414. else:
  4415. assert out is None
  4416. return res
  4417. def get_outputs(self) -> list[Buffer]:
  4418. return self.outputs
  4419. def get_allowed_prologue_inps(self) -> OrderedSet[str]:
  4420. return self.allowed_prologue_inps
  4421. def __str__(self) -> str:
  4422. out = f"TritonTemplateBuffer(layout={self.layout})"
  4423. return out
  4424. PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]
  4425. class ChoiceCaller:
  4426. """
  4427. Represents a possible choice used in autotune_process.py.
  4428. During autotuning, self.benchmark() is first called to get benchmark result,
  4429. and if this choice is selected, self.output_node() is called to get the output_node.
  4430. Children classes: TritonTemplateCaller, CUTLASSTemplateCaller.
  4431. """
  4432. def __init__(
  4433. self,
  4434. name: str,
  4435. input_nodes: list[Buffer],
  4436. layout: Layout,
  4437. description: str,
  4438. ) -> None:
  4439. super().__init__()
  4440. self.name = name
  4441. self.layout = layout
  4442. self.input_nodes = input_nodes
  4443. # An additional description used to describe the choice (useful for
  4444. # knowing what autotuning is choosing)
  4445. self.description = description
  4446. self.failed: bool = False
  4447. # A place to store annotations that can be read post benchmarking
  4448. # Use this to shuttle information between ChoieCaller generation
  4449. # and the end of benchmarking
  4450. self.annotations: dict[Any, Any] = {}
  4451. def benchmark(self, *args: Any, out: torch.Tensor) -> float:
  4452. algo = self.to_callable()
  4453. if config.profile_bandwidth_with_do_bench_using_profiling:
  4454. return do_bench_using_profiling(lambda: algo(*args)) # type: ignore[arg-type]
  4455. return benchmarker.benchmark(algo, args, {"out": out}, device=None)
  4456. def call_name(self) -> str:
  4457. raise NotImplementedError
  4458. def to_callable(self) -> Callable[..., Any]:
  4459. raise NotImplementedError
  4460. def kernel_hash_key(self) -> str:
  4461. """
  4462. Hash key for the underlying kernel. By default, we assume there are no
  4463. runtime params, so kernel hash key defaults to choice caller's hash key.
  4464. """
  4465. return self.hash_key()
  4466. def hash_key(self) -> str:
  4467. raise NotImplementedError
  4468. def output_node(self) -> TensorBox:
  4469. raise NotImplementedError
  4470. def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
  4471. """Information returned here is logged to the autotune log file when that is enabled."""
  4472. return {}
  4473. def autoheuristic_id(self) -> str:
  4474. return "unsupported_choice"
  4475. def mark_failed(self) -> None:
  4476. """
  4477. Mark the choice as failed so that it can be
  4478. removed later. Useful for when we decouple
  4479. compilation and tuning.
  4480. """
  4481. self.failed = True
  4482. class TritonTemplateCallerBase(ChoiceCaller):
  4483. def get_make_kernel_render(self) -> Any:
  4484. raise NotImplementedError
  4485. class MultiTemplateBuffer(TritonTemplateBuffer):
  4486. """
  4487. Represents a Buffer with multiple backing implementation choices.
  4488. Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential
  4489. epilogue we will benchmark each of the choices with the epilogue to determine an implementation.
  4490. Otherwise, the fastest base choice will be chosen.
  4491. """
  4492. def __init__(
  4493. self,
  4494. layout: Layout,
  4495. inputs: Sequence[IRNode],
  4496. choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]],
  4497. unfiltered_choices: list[ChoiceCaller],
  4498. allowed_prologue_inps: OrderedSet[str],
  4499. ) -> None:
  4500. super().__init__(
  4501. layout=layout,
  4502. inputs=inputs,
  4503. make_kernel_render=None,
  4504. allowed_prologue_inps=allowed_prologue_inps,
  4505. )
  4506. self._choice_timings_fn = choice_timings_fn
  4507. self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {}
  4508. self._choices: list[ChoiceCaller] = unfiltered_choices
  4509. self.original_inputs = inputs
  4510. self._output_plannable = all(
  4511. isinstance(choice, TritonTemplateCallerBase)
  4512. or (
  4513. isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller)
  4514. and choice.has_out_variant
  4515. )
  4516. for choice in unfiltered_choices
  4517. )
  4518. self._make_kernel_renders: dict[Optional[int], Any] = {}
  4519. @property
  4520. def output_plannable(self) -> bool:
  4521. """
  4522. Are all possible choices TritonTemplates or Extern Kernels with out variants
  4523. """
  4524. return self._output_plannable
  4525. @property
  4526. def choices(self) -> list[ChoiceCaller]:
  4527. return self._choices
  4528. def choice_timings(
  4529. self, hint_override: Optional[int] = None
  4530. ) -> dict[ChoiceCaller, float]:
  4531. if hint_override not in self._choice_timings:
  4532. self._choice_timings[hint_override] = self._choice_timings_fn(hint_override)
  4533. return self._choice_timings[hint_override]
  4534. @contextlib.contextmanager
  4535. def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]:
  4536. assert isinstance(
  4537. caller, torch._inductor.select_algorithm.TritonTemplateCaller
  4538. ), type(caller)
  4539. assert self.layout == caller.layout
  4540. render = self.make_kernel_render
  4541. self.make_kernel_render = caller.get_make_kernel_render()
  4542. try:
  4543. yield
  4544. finally:
  4545. self.make_kernel_render = render
  4546. def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None:
  4547. assert isinstance(
  4548. caller, torch._inductor.select_algorithm.TritonTemplateCaller
  4549. ), type(caller)
  4550. assert self.get_size() == caller.layout.size
  4551. assert self.get_stride() == caller.layout.stride
  4552. self.make_kernel_render = caller.get_make_kernel_render()
  4553. def get_min_choice(
  4554. self, hint_override: Optional[int] = None
  4555. ) -> tuple[ChoiceCaller, float]:
  4556. timings = self.choice_timings(hint_override=hint_override)
  4557. min_choice = min(timings, key=timings.get) # type: ignore[arg-type]
  4558. return (min_choice, timings[min_choice])
  4559. def finalize_as_triton_callers(
  4560. self, callers: dict[Optional[int], TritonTemplateCallerBase]
  4561. ) -> None:
  4562. """Finalize with multiple callers for different hint overrides"""
  4563. for hint_override, caller in callers.items():
  4564. self._make_kernel_renders[hint_override] = caller.get_make_kernel_render()
  4565. # Set the default to be the one without hint override
  4566. self.make_kernel_render = self._make_kernel_renders[None]
  4567. class CUTLASSTemplateBuffer(TemplateBuffer):
  4568. def __init__(
  4569. self,
  4570. layout: Layout,
  4571. inputs: Sequence[IRNode],
  4572. make_kernel_render: Callable[_P, _T],
  4573. workspace_size: int,
  4574. template: CUTLASSTemplate,
  4575. supports_epilogue_fusion: bool,
  4576. ) -> None:
  4577. super().__init__(layout, inputs, make_kernel_render)
  4578. # Global memory (in bytes) needed for this template.
  4579. self.workspace_size = workspace_size
  4580. self.template = template
  4581. self.supports_epilogue_fusion = supports_epilogue_fusion
  4582. def get_workspace_size(self) -> int:
  4583. return self.workspace_size if self.workspace_size is not None else 0
  4584. def emulate_store_fn(self) -> None:
  4585. for output in self.get_outputs():
  4586. ops.store(output.get_name(), None, None)
  4587. class CppTemplateBuffer(TemplateBuffer):
  4588. def __init__(
  4589. self,
  4590. layout: Layout,
  4591. inputs: Sequence[IRNode],
  4592. make_kernel_render: Callable[_P, _T],
  4593. template: CUTLASSTemplate,
  4594. choice: Any,
  4595. ) -> None:
  4596. super().__init__(layout, inputs, make_kernel_render)
  4597. self.template = template
  4598. self.choice = choice
  4599. self.outputs: Optional[list[Buffer]] = None
  4600. def get_layout(self) -> Layout:
  4601. if isinstance(self.layout, MultiOutputLayout):
  4602. assert isinstance(self.outputs, Iterable), type(self.outputs)
  4603. first_output = self.outputs[0]
  4604. assert isinstance(first_output, Buffer), type(first_output)
  4605. layout = first_output.layout
  4606. assert isinstance(layout, Layout), type(layout)
  4607. return layout
  4608. else:
  4609. return super().get_layout()
  4610. class CuteDSLTemplateBuffer(TemplateBuffer):
  4611. """
  4612. Buffer for CuteDSL (CUTLASS Python DSL) template kernels.
  4613. Similar to other template buffers but specialized for CuteDSL operations.
  4614. """
  4615. def __init__(
  4616. self,
  4617. layout: Layout,
  4618. inputs: Sequence[IRNode],
  4619. make_kernel_render: Callable[_P, _T],
  4620. template: Any,
  4621. mutated_inputs: Optional[Iterable[IRNode]] = None,
  4622. ) -> None:
  4623. super().__init__(layout, inputs, make_kernel_render)
  4624. self.template = template
  4625. self.mutated_inputs = mutated_inputs
  4626. self.outputs: list[Buffer] = [self]
  4627. if mutated_inputs is not None:
  4628. assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
  4629. device = self.inputs[0].get_device()
  4630. self.outputs += [
  4631. MutationOutput(NoneLayout(device=device), buf, self)
  4632. for buf in mutated_inputs
  4633. ]
  4634. def get_outputs(self) -> list[Buffer]:
  4635. return self.outputs
  4636. class NVUniversalGemmBuffer(TemplateBuffer):
  4637. """
  4638. Buffer for NVIDIA Universal GEMM kernels.
  4639. Unlike CuteDSL templates which use Jinja templates, this generates
  4640. simpler Python code that directly calls the cutlass_api library.
  4641. """
  4642. def __init__(
  4643. self,
  4644. layout: Layout,
  4645. inputs: Sequence[IRNode],
  4646. kernel: Any,
  4647. accumulator_type: Any,
  4648. variant: Any, # GemmVariant, use Any to avoid circular import
  4649. workspace_size: int = 0,
  4650. scale_type_a: Optional[Any] = None,
  4651. scale_type_b: Optional[Any] = None,
  4652. swizzle_type_a: Optional[Any] = None,
  4653. swizzle_type_b: Optional[Any] = None,
  4654. ) -> None:
  4655. # We pass None initially, then override with our method below
  4656. super().__init__(layout, inputs, make_kernel_render=None)
  4657. self.kernel = kernel
  4658. self.accumulator_type = accumulator_type
  4659. self.outputs: list[Buffer] = [self]
  4660. self.workspace_size = workspace_size
  4661. self.variant = variant
  4662. self.scale_type_a = scale_type_a
  4663. self.scale_type_b = scale_type_b
  4664. self.swizzle_type_a = swizzle_type_a
  4665. self.swizzle_type_b = swizzle_type_b
  4666. # Store kernel metadata for code generation since kernels aren't serializeable yet
  4667. self.kernel_metadata = {
  4668. "kernel_name": kernel.metadata.kernel_name,
  4669. "min_cc": kernel.metadata.min_cc,
  4670. }
  4671. # Override the instance attribute set by parent with our method
  4672. # This is necessary because TemplateBuffer stores make_kernel_render as instance attr
  4673. self.make_kernel_render = self._make_kernel_render
  4674. def get_workspace_size(self) -> int:
  4675. """Return the workspace size in bytes."""
  4676. return self.workspace_size
  4677. def get_outputs(self) -> list[Buffer]:
  4678. return self.outputs
  4679. def _make_kernel_render(
  4680. self, out_node: Any, hint_override: Optional[int] = None
  4681. ) -> tuple[Any, Any]:
  4682. """
  4683. Create a kernel renderer for code generation.
  4684. Returns (kernel, render) tuple where:
  4685. - kernel: NVUniversalGemmKernel object with call_kernel() method
  4686. - render: function that returns source code string
  4687. """
  4688. from torch._inductor.codegen.nv_universal_gemm.nv_universal_gemm_kernel import (
  4689. NVUniversalGemmKernel,
  4690. )
  4691. from torch._inductor.utils import Placeholder
  4692. input_nodes: list[Any] = []
  4693. for inp in self.inputs:
  4694. if isinstance(inp, TensorBox):
  4695. inp = inp.data
  4696. if isinstance(inp, StorageBox):
  4697. inp = inp.data
  4698. input_nodes.append(inp)
  4699. kernel_name = str(Placeholder.KERNEL_NAME)
  4700. render_kernel = NVUniversalGemmKernel(
  4701. kernel_name=kernel_name,
  4702. input_nodes=input_nodes,
  4703. output_node=out_node,
  4704. kernel_metadata=self.kernel_metadata,
  4705. accumulator_type=self.accumulator_type,
  4706. workspace_size=self.workspace_size,
  4707. variant=self.variant,
  4708. scale_type_a=self.scale_type_a,
  4709. scale_type_b=self.scale_type_b,
  4710. swizzle_type_a=self.swizzle_type_a,
  4711. swizzle_type_b=self.swizzle_type_b,
  4712. )
  4713. def render():
  4714. return render_kernel.render()
  4715. return render_kernel, render
  4716. def is_node_sequence(
  4717. nodes: Sequence[Union[IRNode, Sequence[IRNode]]],
  4718. ) -> TypeIs[Sequence[IRNode]]:
  4719. return all(isinstance(n, IRNode) for n in nodes)
  4720. @ir_dataclass(frozen=False)
  4721. class InputsKernel(OperationBuffer):
  4722. inputs: Sequence[Union[IRNode, Sequence[IRNode]]]
  4723. def input_name(self, i: int) -> str:
  4724. input = self.inputs[i]
  4725. assert isinstance(input, IRNode)
  4726. return input.get_name()
  4727. def get_read_writes(self) -> dependencies.ReadWrites:
  4728. reads = OrderedSet[dependencies.Dep]()
  4729. StarDep = dependencies.StarDep
  4730. for input in self.inputs:
  4731. if isinstance(input, Sequence):
  4732. reads.update(StarDep(x.get_name()) for x in input)
  4733. elif isinstance(input, ShapeAsConstantBuffer):
  4734. # Skip creating dependency for symbolics as they're visible globally
  4735. continue
  4736. else:
  4737. reads.add(StarDep(input.get_name()))
  4738. writes = OrderedSet[dependencies.Dep](
  4739. StarDep(buf.get_name()) for buf in self.get_outputs()
  4740. )
  4741. return dependencies.ReadWrites(
  4742. reads=reads,
  4743. writes=writes,
  4744. index_exprs=OrderedSet(),
  4745. )
  4746. def get_reads(self) -> OrderedSet[Dep]:
  4747. return self.get_read_writes().reads
  4748. @classmethod
  4749. def unwrap_storage_for_input(cls, x: IRNode) -> IRNode:
  4750. if isinstance(x, TensorBox):
  4751. x = x.data
  4752. if isinstance(x, StorageBox):
  4753. x = x.data
  4754. if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
  4755. x = ExternKernel.realize_input(x)
  4756. if isinstance(x, TensorBox):
  4757. # when converting to ReinterpretView fails in the
  4758. # realize_input call above, the result will be wrapped
  4759. # into TensorBox / StorageBox pair as a result of the
  4760. # cls.copy_input call; so we should unwrap recursively
  4761. return cls.unwrap_storage_for_input(x)
  4762. if isinstance(x, TorchBindObject):
  4763. return x
  4764. assert isinstance(x, (Buffer, ReinterpretView)), type(x)
  4765. return x
  4766. @staticmethod
  4767. def unwrap_storage(
  4768. inputs: Sequence[Union[IRNode, Sequence[IRNode]]],
  4769. ) -> list[Union[IRNode, Sequence[IRNode]]]:
  4770. inputs_new: list[Union[IRNode, Sequence[IRNode]]] = []
  4771. for x in inputs:
  4772. if isinstance(x, Sequence):
  4773. x = [InputsKernel.unwrap_storage_for_input(i) for i in x]
  4774. else:
  4775. x = InputsKernel.unwrap_storage_for_input(x)
  4776. inputs_new.append(x)
  4777. return inputs_new
  4778. def is_extern(self) -> bool:
  4779. return True
  4780. def num_reads(self) -> int:
  4781. return 1
  4782. @cache_on_self_and_args("InputsKernel")
  4783. def get_free_symbol_uses(
  4784. self, unbacked_only: bool = False
  4785. ) -> OrderedSet[sympy.Symbol]:
  4786. r = OrderedSet[sympy.Symbol]()
  4787. for inp in self.inputs:
  4788. if isinstance(inp, IRNode):
  4789. r |= inp.get_free_symbol_uses(unbacked_only)
  4790. else:
  4791. for inner_inp in inp:
  4792. r |= inner_inp.get_free_symbol_uses(unbacked_only)
  4793. return r
  4794. class NopKernel(InputsKernel):
  4795. def is_no_op(self) -> bool:
  4796. return True
  4797. def get_reads(self) -> OrderedSet[Dep]:
  4798. return OrderedSet()
  4799. class ConcatKernel(NopKernel):
  4800. """
  4801. There isn't actually a real kernel for concat, we just change the
  4802. storage for the upstream data.
  4803. """
  4804. @classmethod
  4805. def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox:
  4806. """
  4807. Create the concat kernel from inputs
  4808. """
  4809. device = inputs[0].get_device()
  4810. dtype = inputs[0].get_dtype()
  4811. new_size = list(inputs[0].get_size())
  4812. offsets_start = [0]
  4813. offsets_end = [new_size[dim]]
  4814. assert 0 <= dim < len(new_size)
  4815. for i in range(1, len(inputs)):
  4816. input_size = inputs[i].get_size()
  4817. offsets_start.append(new_size[dim])
  4818. assert len(input_size) == len(new_size)
  4819. assert inputs[i].get_dtype() == dtype
  4820. assert inputs[i].get_device() == device
  4821. for j in range(len(new_size)):
  4822. if j == dim:
  4823. new_size[j] = new_size[j] + input_size[j]
  4824. else:
  4825. new_size[j] = V.graph.sizevars.check_equals_and_simplify(
  4826. new_size[j], input_size[j]
  4827. )
  4828. offsets_end.append(new_size[dim])
  4829. output_stride: Sequence[int] = FlexibleLayout.contiguous_strides(new_size)
  4830. if config.comprehensive_padding:
  4831. # Ensure the output stride matches the alignment requirements
  4832. output_stride = Layout._pad_strides(
  4833. output_stride, new_size, inputs[0].dtype
  4834. )
  4835. # If any of the inputs is in CL format, use CL format for the output
  4836. for i in range(len(inputs)):
  4837. x = inputs[i]
  4838. if is_storage_and_layout(x):
  4839. layout = x.get_layout()
  4840. if isinstance(
  4841. layout, FixedLayout
  4842. ) and Layout.is_channels_last_contiguous(layout.size, layout.stride):
  4843. # use CL stride for the output
  4844. output_stride = make_channels_last_strides_for(new_size)
  4845. break
  4846. any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs)
  4847. fx_node_args = V.graph.current_node.args[0]
  4848. assert isinstance(fx_node_args, list), type(fx_node_args)
  4849. # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output
  4850. if any_input_is_storage_and_layout is False and any(
  4851. # pyrefly: ignore [missing-attribute]
  4852. "val" in arg.meta
  4853. and (
  4854. # pyrefly: ignore [missing-attribute]
  4855. arg.meta["val"].is_contiguous(memory_format=torch.channels_last)
  4856. # pyrefly: ignore [missing-attribute]
  4857. or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d)
  4858. )
  4859. for arg in fx_node_args
  4860. ):
  4861. output_stride = make_channels_last_strides_for(new_size)
  4862. is_pinned = all(
  4863. is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs
  4864. )
  4865. assert device is not None
  4866. concat_kernel = ConcatKernel(
  4867. name=None,
  4868. layout=FixedLayout(
  4869. device=device,
  4870. dtype=dtype,
  4871. size=new_size,
  4872. stride=output_stride,
  4873. is_pinned=is_pinned,
  4874. ),
  4875. inputs=[],
  4876. )
  4877. kernel = StorageBox(concat_kernel)
  4878. op_names = []
  4879. for i, inp in enumerate(inputs):
  4880. assert isinstance(inp, (BaseView, MutableBox)), type(inp)
  4881. input_buffer = cls.realize_into(
  4882. inp,
  4883. SliceView.create(
  4884. kernel, dim, offsets_start[i], offsets_end[i], clamp=False
  4885. ),
  4886. )
  4887. assert isinstance(input_buffer, Buffer), type(input_buffer)
  4888. assert isinstance(concat_kernel.inputs, list), type(concat_kernel.inputs)
  4889. concat_kernel.inputs.append(input_buffer)
  4890. if isinstance(inp.data, BaseView):
  4891. input_unwrapped = inp.data.unwrap_view()
  4892. else:
  4893. input_unwrapped = inp.data
  4894. if (
  4895. isinstance(input_unwrapped, StorageBox)
  4896. and input_unwrapped.is_input_buffer()
  4897. and (dev := inp.get_device()) is not None
  4898. and is_gpu(dev.type)
  4899. and not is_dynamic(input_buffer)
  4900. ):
  4901. op_names.append(input_buffer.get_operation_name())
  4902. if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH):
  4903. V.graph.register_operation_list(op_names)
  4904. concat_kernel.name = V.graph.register_buffer(concat_kernel)
  4905. concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)
  4906. V.graph.register_operation(concat_kernel)
  4907. return kernel
  4908. @classmethod
  4909. def can_realize_into_without_copy(
  4910. cls, src: IRNode, dst: Optional[IRNode] = None
  4911. ) -> bool:
  4912. if isinstance(src, TensorBox):
  4913. # unwrap a TensorBox
  4914. return cls.can_realize_into_without_copy(src.data, dst)
  4915. assert isinstance(src, (BaseView, StorageBox)), type(src)
  4916. if isinstance(src.data, MultiTemplateBuffer):
  4917. if (
  4918. not isinstance(src.data.layout, FixedLayout)
  4919. or not src.data.output_plannable
  4920. ):
  4921. return False
  4922. # we call can_realize_into_without_copy in cat lowering before we've decided
  4923. # on output format, optimistically assume layout matches
  4924. if dst is None:
  4925. return True
  4926. # otherwise, check equality of layouts
  4927. if len(src.get_stride()) != len(dst.get_stride()):
  4928. return False
  4929. return all(
  4930. V.graph.sizevars.statically_known_equals(s1, s2)
  4931. for s1, s2 in zip(src.get_stride(), dst.get_stride())
  4932. )
  4933. return (
  4934. hasattr(src.data, "layout")
  4935. and isinstance(src.data.layout, FlexibleLayout)
  4936. and not isinstance(src.data, ExternKernelAlloc)
  4937. )
  4938. @cache_on_self_and_args("ConcatKernel")
  4939. def get_free_symbol_uses(
  4940. self, unbacked_only: bool = False
  4941. ) -> OrderedSet[sympy.Symbol]:
  4942. return NopKernel.get_free_symbol_uses(self, unbacked_only)
  4943. @classmethod
  4944. def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
  4945. # Attempt to turn this into a ReinterpretView rather than assert.
  4946. # This has concessions around layout, as as_storage_and_layout
  4947. # can cause us to go from flexible to fixed layout.
  4948. if not isinstance(dst, ReinterpretView):
  4949. if is_storage_and_layout(dst):
  4950. storage, layout = as_storage_and_layout(dst)
  4951. dst = ReinterpretView(data=storage, layout=layout)
  4952. assert isinstance(dst, ReinterpretView), type(dst)
  4953. if isinstance(src, TensorBox):
  4954. # unwrap a TensorBox
  4955. return cls.realize_into(src.data, dst)
  4956. if isinstance(src, StorageBox):
  4957. src.realize()
  4958. # ExternKernelAlloc has specific requirements for output layout, should create a copy
  4959. assert hasattr(src.data, "layout")
  4960. if cls.can_realize_into_without_copy(src, dst):
  4961. # pyrefly: ignore [missing-attribute]
  4962. src.data.layout = NonOwningLayout(dst)
  4963. return src.data
  4964. # introduce a copy
  4965. pw = Pointwise.create(
  4966. device=src.get_device(),
  4967. dtype=src.get_dtype(),
  4968. inner_fn=src.make_loader(),
  4969. ranges=[
  4970. V.graph.sizevars.check_equals_and_simplify(a, b)
  4971. for a, b in zip(src.get_size(), dst.get_size())
  4972. ],
  4973. )
  4974. return cls.realize_into(pw, dst)
  4975. def should_allocate(self) -> bool:
  4976. return True
  4977. @ir_dataclass(frozen=False)
  4978. class ExternKernel(InputsKernel):
  4979. """
  4980. A class that represents Kernels which are not directly lowered to Inductor
  4981. Loop Level IR, such as custom operators, or aten operators which we fallback to.
  4982. """
    # Constant arguments passed alongside ``inputs``; assigned in __init__.
    constant_args: Sequence[Any] = ()
    # Keyword arguments for the kernel call; __init__ stores {} when given None/empty.
    kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
    output_view: Optional[ReinterpretView] = None
    # Kernel names for the Python / C++ wrappers; when not given they may be
    # derived from op_overload (see set_python_kernel_name / set_cpp_kernel_name).
    python_kernel_name: Optional[str] = None
    cpp_kernel_name: Optional[str] = None
    # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel
    # We shouldn't need to do this since the information can be retrieved from op_overload._schema.
    ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
        default_factory=list
    )
    op_overload: Optional[_OpOverloads] = None
    # Schema-derived argument metadata, filled by collect_arg_kwarg_properties.
    arg_properties: Optional[list[dict[str, Any]]] = None
    allarg_properties: dict[str, dict[str, Any]] = dataclasses.field(
        default_factory=dict
    )
    kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
    # Reset to {} in __init__; maps symbols to pytree key paths.
    unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
        default_factory=dict
    )
    # Extra outputs returned by get_outputs after this node itself.
    mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)
    def __init__(
        self,
        name: Optional[str],
        layout: OutputSpec,
        inputs: Sequence[Union[IRNode, Sequence[IRNode]]],
        constant_args: Sequence[Any] = (),
        kwargs: dict[str, Any] | None = None,
        output_view: Optional[ReinterpretView] = None,
        python_kernel_name: Optional[str] = None,
        cpp_kernel_name: Optional[str] = None,
        ordered_kwargs_for_cpp_kernel: Iterable[str] = (),
        op_overload: Optional[_OpOverloads] = None,
    ) -> None:
        """
        Initialize an extern kernel node.

        Args:
            name: buffer name (may be None).
            layout: output spec of this node.
            inputs: input IRNodes, or sequences of IRNodes.
            constant_args: constant arguments passed alongside ``inputs``.
            kwargs: keyword arguments; None or empty is stored as a new ``{}``.
            output_view: optional output view.
            python_kernel_name: explicit Python kernel name; if None it may be
                derived from ``op_overload`` by ``set_python_kernel_name``.
            cpp_kernel_name: explicit C++ kernel name; if None it may be derived
                from ``op_overload`` by ``set_cpp_kernel_name``.
            ordered_kwargs_for_cpp_kernel: kwarg order for the C++ call; if empty
                and ``op_overload`` is an OpOverload, taken from its schema.
            op_overload: the operator invoked by this kernel, if any.
        """
        super().__init__(
            name=name,
            layout=layout,
            inputs=inputs,
        )
        self.constant_args = constant_args
        self.kwargs = kwargs if kwargs else {}
        self.output_view = output_view
        # op_overload must be set before the setters below, which read it.
        self.op_overload = op_overload
        self.set_cpp_kernel_name(cpp_kernel_name)
        self.set_python_kernel_name(python_kernel_name)
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        # Reads op_overload, inputs and ordered_kwargs_for_cpp_kernel.
        self.collect_arg_kwarg_properties()
        self.unbacked_bindings = {}
        self.mutation_outputs = []
        # FX node being lowered when this kernel was constructed.
        self.fx_node = V.graph.current_node
        # Annotations dict for storing metadata (e.g., KernelTemplateChoice)
        self.annotations: dict[str, Any] = {}
  5034. def get_outputs(self) -> list[Buffer]:
  5035. return [self, *self.mutation_outputs]
    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        """The base extern kernel defines no unbacked symbols; returns an empty set."""
        return OrderedSet()
  5038. def collect_arg_kwarg_properties(self) -> None:
  5039. # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional
  5040. # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen
  5041. self.arg_properties = (
  5042. [
  5043. {
  5044. "name": x.name,
  5045. "type": x.real_type,
  5046. "default_value": x.default_value,
  5047. }
  5048. for x in self.op_overload._schema.arguments
  5049. if not x.kwarg_only
  5050. ]
  5051. if isinstance(self.op_overload, torch._ops.OpOverload)
  5052. else [{} for i in range(len(self.inputs))]
  5053. )
  5054. self.allarg_properties = (
  5055. {
  5056. x.name: {"type": x.real_type, "default_value": x.default_value}
  5057. for x in self.op_overload._schema.arguments
  5058. }
  5059. if isinstance(self.op_overload, torch._ops.OpOverload)
  5060. else {}
  5061. )
  5062. # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes
  5063. # ordered_kwargs_for_cpp_kernel is explicitly passed in.
  5064. if isinstance(self.op_overload, torch._ops.OpOverload):
  5065. if not self.ordered_kwargs_for_cpp_kernel:
  5066. self.ordered_kwargs_for_cpp_kernel = [
  5067. x.name for x in self.op_overload._schema.arguments if x.kwarg_only
  5068. ]
  5069. self.schema_kwargs = [
  5070. x for x in self.op_overload._schema.arguments if x.kwarg_only
  5071. ]
  5072. else:
  5073. self.schema_kwargs = []
  5074. def decide_layout(self) -> None:
  5075. if isinstance(self.layout, FlexibleLayout):
  5076. self.apply_constraint()
  5077. self.freeze_layout()
  5078. def codegen_comment(
  5079. self, wrapper: PythonWrapperCodegen, kernel_name: Optional[str] = None
  5080. ) -> None:
  5081. origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
  5082. if origin_str:
  5083. wrapper.make_comment(origin_str)
  5084. if not kernel_name:
  5085. kernel_name = self.try_get_kernel_name()
  5086. if kernel_name:
  5087. from .debug import set_kernel_post_grad_provenance_tracing
  5088. debug_handle = set_kernel_post_grad_provenance_tracing(
  5089. self, kernel_name, is_extern=True
  5090. )
  5091. wrapper.write_provenance_debug_handle(kernel_name, debug_handle)
    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        """Not implemented on the base class; raises NotImplementedError."""
        raise NotImplementedError
  5094. def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
  5095. self.cpp_kernel_name = cpp_kernel_name
  5096. if not V.graph.cpp_wrapper or not isinstance(
  5097. self.op_overload, torch._ops.OpOverload
  5098. ):
  5099. return
  5100. kernel = self.op_overload
  5101. if self.cpp_kernel_name is None:
  5102. # Try to construct cpp_kernel_name from op_overload
  5103. if kernel.namespace == "aten":
  5104. # Calling with the default kernel name can lead to ambiguous behavior like the following example.
  5105. # repeat_interleave(const at::Tensor & repeats, std::optional<int64_t> output_size=std::nullopt)
  5106. # repeat_interleave(const at::Tensor & self, int64_t repeats,
  5107. # std::optional<int64_t> dim=std::nullopt, std::optional<int64_t> output_size=std::nullopt)
  5108. opname = (
  5109. kernel.__name__.split(".")[0]
  5110. if kernel._overloadname == "default"
  5111. else kernel.__name__.replace(".", "_")
  5112. )
  5113. self.cpp_kernel_name = f"at::_ops::{opname}::call"
  5114. else:
  5115. self.cpp_kernel_name = kernel._schema.name
  5116. def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None:
  5117. self.python_kernel_name = python_kernel_name
  5118. if python_kernel_name is not None:
  5119. return
  5120. kernel = self.op_overload
  5121. if kernel is None:
  5122. pass
  5123. elif isinstance(kernel, torch._ops.HigherOrderOperator):
  5124. self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}"
  5125. else:
  5126. self.python_kernel_name = (
  5127. f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
  5128. )
  5129. def try_get_kernel_name(self) -> Optional[str]:
  5130. from .codegen.cpp_wrapper_cpu import CppWrapperCpu
  5131. device = d.type if (d := self.get_device()) else V.graph.device_type
  5132. if V.graph.fx_wrapper:
  5133. return self.python_kernel_name
  5134. elif V.graph.cpp_wrapper:
  5135. assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
  5136. V.graph.wrapper_code
  5137. )
  5138. if self.cpp_kernel_name is None:
  5139. return None
  5140. return V.graph.wrapper_code.get_c_shim_func_name(
  5141. self.cpp_kernel_name, device
  5142. )
  5143. else:
  5144. return self.python_kernel_name
    def get_kernel_name(self) -> str:
        """Like ``try_get_kernel_name`` but asserts that a name is available."""
        name = self.try_get_kernel_name()
        assert name is not None
        return name
    @staticmethod
    def copy_input(x: IRNode) -> TensorBox:
        """
        Materialize ``x`` into a new buffer via a realized Pointwise copy.

        The copy keeps ``x``'s device, dtype, size, origin node and traceback.
        """
        pw = Pointwise.create(
            device=x.get_device(),
            dtype=x.get_dtype(),
            inner_fn=x.make_loader(),
            ranges=x.get_size(),
            origin_node=x.get_origin_node(),
            traceback=x.get_traceback(),
        )
        pw.realize()
        return pw
    @classmethod
    def process_kernel(
        cls, kernel: _OpOverloads, *args: Any, **kwargs: Any
    ) -> tuple[
        Any,
        list[Any],
        list[Any],
        Callable[[Any, Any], Any],
        Optional[dict[sympy.Symbol, pytree.KeyPath]],
    ]:
        """
        Prepare an extern ``kernel`` call and compute its example output.

        ``(args, kwargs)`` is flattened with pytree and split into tensor-like
        leaves (IRNodes other than GeneratorState) and everything else. Tensor
        leaves are realized and their layouts frozen, then ``kernel`` is re-run
        on example values so output strides reflect Inductor's layout choices.

        Returns:
            ``(example_output, tensor_args, non_tensor_args, unflatten_args,
            unbacked_bindings)``. ``unflatten_args`` rebuilds ``(args, kwargs)``
            from replacement tensor / non-tensor sequences. ``unbacked_bindings``
            is None when the fake mode has no shape env.

        Side effect: sets ``V.graph.disable_cudagraphs_reason`` if an output is
        sparse and graph partitioning is disabled.
        """
        binded_args = {"args": args, "kwargs": kwargs}

        args_flat, args_spec = pytree.tree_flatten(binded_args)

        # Records each leaf's kind so unflatten_args can re-interleave them.
        is_arg_tensor = []
        # tensor_args can be either tensor or torchbind objects
        tensor_args = []
        non_tensor_args: list[Any] = []
        for arg in args_flat:
            is_arg_tensor.append(
                isinstance(arg, IRNode) and not isinstance(arg, GeneratorState)
            )
            if is_arg_tensor[-1]:
                tensor_args.append(arg)
            else:
                if isinstance(arg, Expr):
                    # Symbolic expressions are passed to the kernel as SymInts.
                    arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
                non_tensor_args.append(arg)

        def unflatten_args(
            new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T]
        ) -> tuple[list[_T], dict[str, _T]]:
            # Take the next tensor / non-tensor replacement for each leaf position,
            # then rebuild the original {"args": ..., "kwargs": ...} structure.
            result = []
            it_tensors = iter(new_tensor_args)
            it_non_tensors = iter(new_non_tensor_args)
            for is_tensor in is_arg_tensor:
                if is_tensor:
                    result.append(next(it_tensors))
                else:
                    result.append(next(it_non_tensors))
            r = pytree.tree_unflatten(result, args_spec)
            return r.get("args", []), r.get("kwargs", {})

        tensor_args = [cls.realize_input(x) for x in tensor_args]

        # freeze layout otherwise our output stride calculation might
        # become incorrect
        for x in tensor_args:
            if is_storage_and_layout(x):
                as_storage_and_layout(x, freeze=True)

        # Rerun fake tensor propagation, because Inductor may have changed the
        # strides of inputs and we need to determine accurately what the
        # output stride will be.
        example_args: list[
            Union[
                torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator
            ]
        ] = []

        # We need to retain the constant values of fake tensors that we originally
        # propagated the graph with, because for some operators running without a
        # constant would trigger an error / DataDependentException
        for x in tensor_args:
            # if x is a view of a constant, we need to realize the view
            # (we can't pass the constant into the kernel directly)
            if not isinstance(x, BaseView) and x.get_name() in V.graph.constants:
                example_args.append(V.graph.constants[x.get_name()])
            elif (
                not isinstance(x, BaseView)
                and x.get_name() in V.graph.torchbind_constants
            ):
                example_args.append(V.graph.torchbind_constants[x.get_name()])
            elif isinstance(x, TorchBindObject):
                example_args.append(x.get_value())
            elif isinstance(x, torch._inductor.ir.GeneratorState):
                device_index = x.device.index
                assert x.device.type == "cuda" and device_index is not None
                example_args.append(
                    torch.cuda.default_generators[device_index].clone_state()
                )
            else:
                example_args.append(ir_node_to_tensor(x, guard_shape=True))

        new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
        example_output = kernel(*new_args, **new_kwargs)

        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
        if shape_env := V.fake_mode.shape_env:
            node_meta_val = V.current_node.meta.get("val")
            ctx: AbstractContextManager[None] = nullcontext()
            if V.current_node.target is torch._higher_order_ops.effects.with_effects:
                # remove the first effect token in meta["val"] and meta["unbacked_bindings"]
                node_meta_val = node_meta_val[1]
                ctx = _remove_effect_token_unbacked_bindings(V.current_node)

            with ctx:
                rebind_unbacked(shape_env, V.current_node, example_output)
            unbacked_bindings = compute_unbacked_bindings(
                shape_env, example_output, node_meta_val
            )

        example_out_li = (
            [example_output]
            if not isinstance(example_output, (list, tuple))
            else example_output
        )
        # When graph_partition is enabled, skip - partitioning handles sparse outputs
        for t in example_out_li:
            if (
                isinstance(t, torch.Tensor)
                and t.is_sparse
                and not config.graph_partition
            ):
                msg = "sparsity not handled. Please file issue for sparse inference weights."
                if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
                    msg = f"{msg} Found from : \n {stack_trace}"
                V.graph.disable_cudagraphs_reason = msg

        return (
            example_output,
            tensor_args,
            non_tensor_args,
            unflatten_args,
            unbacked_bindings,
        )
    @classmethod
    def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView:
        """
        In order to pass this to an extern kernel we need a
        ReinterpretView not a View. This allows us to avoid some
        unneeded copies.

        Freezes the layout of the underlying buffer (channels-last if the
        eager value was channels-last contiguous). Raises NotImplementedError
        when ``x``'s indexing is not a plain strided (affine) access.
        """
        assert isinstance(x, BaseView), type(x)
        if isinstance(x, ReinterpretView):
            return x

        # NOTE: Don't use extract_read_writes here as it fails when
        # make_loader() inlines the computation
        x_unwrap_view = x.unwrap_view()
        buf = V.graph.get_buffer(x_unwrap_view.get_name())
        assert buf is not None
        x_unwrap_view_fx_node = buf.get_origin_node()
        # Prefer channels last format according to how the format is set from eager.
        if (
            x_unwrap_view_fx_node is not None
            and "val" in x_unwrap_view_fx_node.meta
            and isinstance(x_unwrap_view, (ReinterpretView, Buffer, MutableBox))
            and isinstance(x_unwrap_view.layout, FlexibleLayout)
            and (
                x_unwrap_view_fx_node.meta["val"].is_contiguous(
                    memory_format=torch.channels_last
                )
                or x_unwrap_view_fx_node.meta["val"].is_contiguous(
                    memory_format=torch.channels_last_3d
                )
            )
        ):
            x_unwrap_view.freeze_layout_with_same_order(
                make_channels_last_strides_for(x_unwrap_view.get_size())
            )
        else:
            x_unwrap_view.freeze_layout()

        index_args, var_ranges = dependencies.index_vars_squeeze(
            x.get_size(), prefix="r"
        )
        range_vars = index_args[0]
        index = x.make_indexer()(range_vars)

        index = V.graph.sizevars.simplify_with_ranges(index, var_ranges)
        strides = V.graph.sizevars.stride_vars(index, range_vars)
        offset = V.graph.sizevars.offset_var(index, range_vars)
        # Representable only if the index equals offset + sum(stride_i * var_i).
        expected = sympy_dot(range_vars, strides) + offset

        if index != expected:
            log.debug(
                "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
                strides,
                offset,
                index,
            )
            raise NotImplementedError

        return ReinterpretView(
            data=x.data,
            layout=FixedLayout(
                device=x.get_device_or_error(),
                dtype=x.get_dtype(),
                size=x.get_size(),
                stride=strides,
                offset=offset,
                is_pinned=False,
            ),
        )
    @classmethod
    def realize_input(cls, x: IRNode) -> IRNode:
        """
        Convert ``x`` into a form that can be passed to an extern kernel.

        None becomes NoneAsConstantBuffer. Sympy expressions, booleans and ints
        become ShapeAsConstantBuffer. A Constant becomes a registered tensor
        constant. TensorBoxes are unwrapped, and views are converted to
        ReinterpretViews when possible. StorageBoxes are realized in place.
        Anything not otherwise handled is copied via ``copy_input``.
        """
        if x is None:
            return NoneAsConstantBuffer()
        if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
            return ShapeAsConstantBuffer(expr=x)
        if isinstance(x, Constant):
            # We need to unset fake mode, or else the torch.tensor() call will
            # turn into a FakeTensor
            with _disable_current_modes():
                return V.graph.add_tensor_constant(
                    torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
                )
        if isinstance(x, ConstantBuffer):
            return x
        if isinstance(x, TensorBox):
            return cls.realize_input(x.data)
        if isinstance(x, ReinterpretView):
            # Keep the view's layout; only its underlying data is realized.
            return ReinterpretView(
                data=cls.realize_input(x.data), layout=x.get_layout()
            )
        if isinstance(x, BaseView):
            x.realize()
            if is_storage_and_layout(x.unwrap_view()):
                try:
                    return cls.convert_to_reinterpret_view(x)
                except NotImplementedError:
                    # Not expressible as a strided view; fall through to the checks below.
                    pass
        if isinstance(x, StorageBox):
            # TODO(jansel): impose layout preference on realized buffer
            x.realize()
            return x
        if isinstance(x, (NonTensorObj, ShapeAsConstantBuffer)):
            return x
        return cls.copy_input(x)
  5374. @classmethod
  5375. def require_stride1(cls, x: IRNode) -> IRNode:
  5376. if is_storage_and_layout(x):
  5377. if len(x.get_stride()) == 0:
  5378. return x
  5379. for stride in x.get_stride():
  5380. if stride == 1:
  5381. return x
  5382. return cls.copy_input(x)
  5383. @classmethod
  5384. def require_strides(
  5385. cls,
  5386. x: IRNode,
  5387. order: Optional[Sequence[int]] = None,
  5388. exact_strides: Optional[Sequence[_IntLike]] = None,
  5389. allow_padding: bool = False,
  5390. ) -> IRNode:
  5391. assert order is not None or exact_strides is not None
  5392. # Layout generally doesn't matter, but some consuming external ops might have requirements
  5393. if x.get_numel() in (0, 1) and not exact_strides:
  5394. return x
  5395. # require x to have the layout
  5396. if is_storage_and_layout(x):
  5397. if isinstance(x.get_layout(), FlexibleLayout):
  5398. if order:
  5399. # If the FlexibleLayout already has the size and stride in the required order,
  5400. # freeze it to a FixedLayout by using its current size and stride.
  5401. # The behavior of using its current size and stride or the given order can be different
  5402. # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1:
  5403. # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last),
  5404. # the current size and stride already satisfies this order.
  5405. # However by freezing it to the required order, the layout will be changed to:
  5406. # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary.
  5407. use_current_stride_order = is_stride_order_storage_and_layout(
  5408. x, order
  5409. ) and not free_unbacked_symbols(x.get_layout().stride)
  5410. # fix flexiblelayout to be FixedLayout with stride_order
  5411. as_storage_and_layout(
  5412. x,
  5413. freeze=True,
  5414. want_contiguous=False,
  5415. stride_order=(
  5416. get_stride_order(
  5417. V.graph.sizevars.size_hints_or_throw(
  5418. x.get_layout().stride
  5419. )
  5420. )
  5421. if use_current_stride_order
  5422. else order
  5423. ),
  5424. allow_padding=allow_padding,
  5425. )
  5426. return x
  5427. else:
  5428. # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides.
  5429. as_storage_and_layout(
  5430. x,
  5431. freeze=True,
  5432. want_contiguous=False,
  5433. stride_order=None,
  5434. allow_padding=allow_padding,
  5435. exact_strides=exact_strides,
  5436. )
  5437. return x
  5438. elif isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and (
  5439. (order and x.get_layout().is_stride_ordered(order))
  5440. or (
  5441. exact_strides
  5442. and significant_strides_equal(
  5443. exact_strides, x.get_layout().stride, x.get_size()
  5444. )
  5445. )
  5446. ):
  5447. return (
  5448. try_match_insignificant_strides(x, exact_strides)
  5449. if exact_strides is not None
  5450. else x
  5451. )
  5452. elif isinstance(
  5453. (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE
  5454. ):
  5455. if isinstance(
  5456. (real_layout := mutation_layout.real_layout()), FlexibleLayout
  5457. ):
  5458. raise AssertionError(
  5459. "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
  5460. )
  5461. elif isinstance(real_layout, FixedLayout) and (
  5462. (order and real_layout.is_stride_ordered(order))
  5463. or (
  5464. exact_strides
  5465. and significant_strides_equal(
  5466. exact_strides, real_layout.stride, x.get_size()
  5467. )
  5468. )
  5469. ):
  5470. return x
  5471. # TODO - Storage to InputBuffer
  5472. if isinstance(x, InputBuffer) and (
  5473. (order and x.get_layout().is_stride_ordered(order))
  5474. or (
  5475. exact_strides
  5476. and significant_strides_equal(
  5477. exact_strides, x.get_layout().stride, x.get_size()
  5478. )
  5479. )
  5480. ):
  5481. return x
  5482. if (
  5483. isinstance(x, TensorBox)
  5484. and isinstance(x.data, BaseView)
  5485. and not isinstance(x.data, ReinterpretView)
  5486. and is_storage_and_layout(unwrap_view := x.unwrap_view())
  5487. and hasattr(unwrap_view, "data")
  5488. and not isinstance(unwrap_view.data, ExternKernelAlloc)
  5489. ):
  5490. try:
  5491. x.data = cls.convert_to_reinterpret_view(x.data)
  5492. if order:
  5493. return cls.require_stride_order(
  5494. x, order, allow_padding=allow_padding
  5495. )
  5496. elif exact_strides:
  5497. return cls.require_exact_strides(
  5498. x, exact_strides, allow_padding=allow_padding
  5499. )
  5500. except NotImplementedError:
  5501. pass
  5502. # Preserve ExpandView representation that would be lost during copy_input
  5503. # Without representation of the expand in inductor IR, in codegen we end up
  5504. # launching a grid for the full size tensor and doing redundant computation
  5505. # across expanded dims.
  5506. # TODO: could also be good to have a codegen fix to recognize overlapping elements
  5507. expanded_dims: Optional[list[int]] = None
  5508. orig_size = x.get_size()
  5509. if exact_strides is not None:
  5510. sizevars = V.graph.sizevars
  5511. expanded_dims = [
  5512. i
  5513. for i in range(len(x.get_size()))
  5514. if sizevars.statically_known_equals(exact_strides[i], 0)
  5515. and sizevars.statically_known_geq(x.get_size()[i], 2)
  5516. ]
  5517. for dim in expanded_dims:
  5518. x = torch._inductor.lowering.slice_(x, dim, 0, 1)
  5519. # Although this is a clone, inductor is good about fusing clones into previous
  5520. # operations if they weren't realized and their layouts were flexible.
  5521. x = cls.copy_input(x)
  5522. as_storage_and_layout(
  5523. x,
  5524. freeze=True,
  5525. want_contiguous=False,
  5526. stride_order=order,
  5527. allow_padding=allow_padding,
  5528. exact_strides=exact_strides,
  5529. )
  5530. if order:
  5531. assert is_stride_order_storage_and_layout(x, order)
  5532. elif expanded_dims:
  5533. assert orig_size is not None and exact_strides is not None
  5534. x = torch._inductor.lowering.expand(x, orig_size)
  5535. # the expand will sometimes may change insignificant strides, so match them back
  5536. return try_match_insignificant_strides(x, exact_strides)
  5537. return x
  5538. @classmethod
  5539. def require_exact_strides(
  5540. cls, x: IRNode, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  5541. ) -> IRNode:
  5542. return cls.require_strides(
  5543. x,
  5544. exact_strides=[
  5545. s.node.expr if isinstance(s, torch.SymInt) else s for s in exact_strides
  5546. ],
  5547. allow_padding=allow_padding,
  5548. )
  5549. @classmethod
  5550. def require_stride_order(
  5551. cls, x: IRNode, order: Sequence[int], allow_padding: bool = False
  5552. ) -> IRNode:
  5553. return cls.require_strides(x, order=order, allow_padding=allow_padding)
  5554. @classmethod
  5555. def require_channels_last(cls, x: IRNode) -> IRNode:
  5556. return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
  5557. @classmethod
  5558. def require_channels_last_3d(cls, x: IRNode) -> IRNode:
  5559. return cls.require_stride_order(x, NHWDC_STRIDE_ORDER)
  5560. @classmethod
  5561. def require_contiguous(cls, x: IRNode) -> IRNode:
  5562. def is_mkldnn_tensor(x: IRNode) -> bool:
  5563. try:
  5564. name = x.get_name()
  5565. except (AttributeError, NotImplementedError):
  5566. return False
  5567. return name in V.graph.constants and V.graph.constants[name].is_mkldnn
  5568. # TODO move this to the more proper places
  5569. if is_mkldnn_tensor(x):
  5570. return x
  5571. else:
  5572. return cls.require_exact_strides(
  5573. x, FlexibleLayout.contiguous_strides(x.get_size())
  5574. )
  5575. @classmethod
  5576. def require_contiguous_strides(cls, x: IRNode) -> IRNode:
  5577. # TODO: combine this with require_contiguous after
  5578. # https://github.com/pytorch/pytorch/pull/148235 lands.
  5579. return cls.require_exact_strides(
  5580. x, FlexibleLayout.contiguous_strides(x.get_size())
  5581. )
  5582. def apply_constraint(self) -> None:
  5583. pass
  5584. def fill_non_provided_args(
  5585. self, args: Sequence[Any], kwargs: dict[str, Any]
  5586. ) -> Sequence[Any]:
  5587. # Previously, we want to maintain forward-compatibility by skipping
  5588. # default args in the serialized artifacts in fbcode. However,
  5589. # some of our shim interfaces require default values being OrderedSet.
  5590. # Discussed with Sherlock offline and we decided to allow serializing
  5591. # default args into the C++ wrapper code for now. We will refine this
  5592. # part if we see real FC requirement. More details related to FC
  5593. # can be found at:
  5594. # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
  5595. assert isinstance(args, Sequence), type(args)
  5596. if not isinstance(args, list):
  5597. args = list(args)
  5598. assert self.arg_properties, "ExternKernel.arg_properties should not be empty"
  5599. n_args = len(args)
  5600. n_pos_args = len(self.arg_properties)
  5601. # For cpp wrapper, if some positional args are not provided, we need to check
  5602. # if they're in the kwargs or use their default value
  5603. if n_args < n_pos_args:
  5604. log.debug(
  5605. "%s has %d unprovided positional arguments. "
  5606. "Will check if they are in the keyword arguments or will use default values.",
  5607. self.op_overload,
  5608. n_pos_args - n_args,
  5609. )
  5610. for i in range(n_args, n_pos_args):
  5611. arg_name = self.arg_properties[i]["name"]
  5612. args.append(
  5613. kwargs[arg_name]
  5614. if arg_name in kwargs
  5615. else self.arg_properties[i]["default_value"]
  5616. )
  5617. return args
  5618. def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]:
  5619. if V.graph.cpp_wrapper:
  5620. result = []
  5621. # Aten ops follow the convention that tensor args are before non-tensor args,
  5622. # in which case the following 'len(self.inputs) + i' logic works. But this
  5623. # may not be true for other ops, and if that is the case, caller needs to
  5624. # pass in a list of const arg names for arg_properties lookup.
  5625. name_to_arg_properties = None
  5626. if names and self.arg_properties:
  5627. assert len(self.constant_args) == len(names), (
  5628. "names passed to codegen_const_args does not match self.constant_args"
  5629. )
  5630. name_to_arg_properties = {
  5631. arg.get("name"): arg for arg in self.arg_properties
  5632. }
  5633. for i, x in enumerate(self.constant_args):
  5634. if name_to_arg_properties is not None:
  5635. assert names is not None
  5636. prop = name_to_arg_properties.get(names[i])
  5637. type_ = prop.get("type") if prop else None
  5638. else:
  5639. idx = len(self.inputs) + i
  5640. type_ = (
  5641. self.arg_properties[idx].get("type")
  5642. if self.arg_properties and idx < len(self.arg_properties)
  5643. else None
  5644. )
  5645. result.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
  5646. return result
  5647. else:
  5648. return [V.graph.wrapper_code.val_to_arg_str(a) for a in self.constant_args]
  5649. def codegen_args(self) -> list[str]:
  5650. if V.graph.cpp_wrapper and self.op_overload is not None:
  5651. # cpp wrapper needs special logic to fill in missing args with default values
  5652. inputs = self.fill_non_provided_args(
  5653. [*self.inputs, *self.constant_args], self.kwargs
  5654. )
  5655. # fill_non_provided_args has handled constant args, so no need to codegen for that later
  5656. need_codegen_constant_args = False
  5657. else:
  5658. inputs = self.inputs
  5659. need_codegen_constant_args = True
  5660. args = []
  5661. for i, x in enumerate(inputs):
  5662. if V.graph.cpp_wrapper:
  5663. assert self.arg_properties and i < len(self.arg_properties), (
  5664. "Invalid access to ExternKernel.arg_properties"
  5665. )
  5666. type_ = self.arg_properties[i].get("type")
  5667. args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
  5668. else:
  5669. args.append(V.graph.wrapper_code.val_to_arg_str(x))
  5670. if need_codegen_constant_args:
  5671. args.extend(self.codegen_const_args())
  5672. return args
  5673. def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any:
  5674. """Given an argument name, queries for values in (in order):
  5675. 1. any provided kwargs for this function.
  5676. 2. the class self.kwargs member.
  5677. 3. any available default arguments in self.allarg_properties."""
  5678. if arg_name in kwargs:
  5679. return kwargs.get(arg_name)
  5680. if arg_name in self.kwargs:
  5681. return self.kwargs.get(arg_name)
  5682. if (arg := self.allarg_properties.get(arg_name)) is not None:
  5683. return arg.get("default_value")
  5684. raise AssertionError(f"{arg_name} not in self.allarg_properties")
  5685. def codegen_kwargs(self, skip_out: bool = False) -> list[str]:
  5686. if V.graph.cpp_wrapper:
  5687. if self.op_overload is not None and len(self.schema_kwargs) == 0:
  5688. # All the args should have been generated by fill_non_provided_args in codegen_args
  5689. return []
  5690. kwargs = []
  5691. for arg_name in self.ordered_kwargs_for_cpp_kernel:
  5692. if skip_out and arg_name == "out":
  5693. # ExternKernelOut has its own logic for inserting the out parameter
  5694. continue
  5695. v = self.get_kwargs_value(arg_name)
  5696. if isinstance(v, Expr):
  5697. kwargs.append(v)
  5698. else:
  5699. assert self.allarg_properties is not None
  5700. type_ = self.allarg_properties.get(arg_name, {}).get("type")
  5701. kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_))
  5702. else:
  5703. kwargs = [
  5704. f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
  5705. for k, v in self.kwargs.items()
  5706. ]
  5707. return kwargs
  5708. def get_op_name(self) -> str:
  5709. if self.fx_node is not None:
  5710. target = self.fx_node.target
  5711. op_namespace = getattr(target, "__module__", "unknown_namespace")
  5712. op_namespace = op_namespace.replace("._ops.", ".ops.")
  5713. op_namespace = op_namespace.rsplit(".", 1)[0]
  5714. op_name = f"{op_namespace}.{target}"
  5715. else:
  5716. op_name = "unknown_op"
  5717. return op_name
  5718. def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None:
  5719. if config.size_asserts and not V.graph.cpp_wrapper:
  5720. # comparing strides for 0 size tensor is tricky. Ignore them for now.
  5721. if sympy_product(self.get_size()) == 0:
  5722. return
  5723. size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
  5724. stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
  5725. op_name = self.get_op_name()
  5726. wrapper.writeline(
  5727. f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
  5728. )
  5729. def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None:
  5730. if config.alignment_asserts and not V.graph.cpp_wrapper:
  5731. name = self.get_name()
  5732. aligned = name not in V.graph.unaligned_buffers
  5733. op_name = self.get_op_name()
  5734. if aligned:
  5735. wrapper.writeline(
  5736. f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
  5737. )
  5738. else:
  5739. wrapper.writeline(
  5740. f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
  5741. )
  5742. def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None:
  5743. """
  5744. Track outputs of fallback operators if config.test_configs.track_memory_lifecycle
  5745. """
  5746. if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper:
  5747. return
  5748. wrapper.write_memory_track_allocation_once()
  5749. name = self.get_name()
  5750. wrapper.writeline(f"track_tensor({name}, '{name}')")
  5751. def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]:
  5752. """
  5753. get output sizes and strides, for template_codegen
  5754. """
  5755. _size = self.get_size()
  5756. _stride = self.get_stride()
  5757. # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
  5758. return [_size, []], _stride
    def canonicalize(self) -> tuple[Expr, Sequence[Expr]]:
        """
        Manually get canonicalization of the output index

        Builds this node's output index expression over index symbols that
        are numbered by descending stride hint. The loop nest is then
        simplified with ``V.graph.sizevars._simplify_loops``, and the
        surviving loop variables are renamed to fresh variables produced by
        ``var_builder("c")``.

        Returns:
            ``(index, new_sizes)``: the canonical index expression and the
            simplified loop sizes as a tuple.
        """
        # manually generate index formula for conv
        sizevars = V.graph.sizevars
        sizes = self.get_size()
        strides = self.get_stride()
        # Concrete size hints so that strides can be compared and sorted.
        strides = [sizevars.size_hint(x) for x in strides]
        # TODO: I can't tell if the symbols here are temporary
        index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))]
        # reorder index vars according to stride
        # index_order[r] is the dim with the r-th largest stride; sorted() is
        # stable, so tied strides keep their original dim order.
        index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
        # lookup[dim] is that dim's rank in index_order (the inverse permutation).
        lookup = {pos: idx for idx, pos in enumerate(index_order)}
        order = [lookup[i] for i in range(len(lookup))]
        # Dim k now gets symbol d{rank of k}, so the largest-stride dim uses d0.
        index_vars = [index_vars[i] for i in order]
        indexer = self.make_indexer()
        index = indexer(index_vars)

        new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
            index_vars, sizes, [index]
        )

        # assign new variables each dimension to deal with numbering mismatches
        # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
        _, add_var = var_builder("c")
        replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))

        index = sympy_subs(sympy.expand(index), replacement)
        return index, tuple(new_sizes)
  5786. @cache_on_self_and_args("ExternKernel")
  5787. def get_free_symbol_uses(
  5788. self, unbacked_only: bool = False
  5789. ) -> OrderedSet[sympy.Symbol]:
  5790. # NB: It's not necessary to check regular inputs as we automatically
  5791. # have dependencies on them
  5792. maybe_get_symbols = (
  5793. maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
  5794. )
  5795. r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
  5796. for arg in self.constant_args:
  5797. r |= maybe_get_symbols(arg)
  5798. for arg in self.kwargs.values():
  5799. r |= maybe_get_symbols(arg)
  5800. return r
  5801. def __str__(self) -> str:
  5802. kernel_name = getattr(self, "python_kernel_name", None)
  5803. lines = [
  5804. f"python_kernel_name={kernel_name!r}",
  5805. ]
  5806. lines += [
  5807. f"{field.name}={getattr(self, field.name)}"
  5808. for field in dataclasses.fields(self)
  5809. ]
  5810. lines.append(f"origin_node={self.origin_node!r}")
  5811. return self.str_helper(lines)
  5812. __repr__ = __str__
@ir_dataclass(frozen=False)
class ExternKernelOut(ExternKernel):
    """Extern kernel call whose output buffer is allocated separately.

    ``should_allocate`` returns True, and codegen is delegated to
    ``wrapper.generate_extern_kernel_out``. The output is presumably passed
    to the kernel as its ``out`` argument (see the ``skip_out`` handling in
    ``ExternKernel.codegen_kwargs``); TODO confirm. On construction the node
    registers itself with the current graph as a buffer and an operation.
    """

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        wrapper.generate_extern_kernel_out(self)

    def __init__(
        self,
        layout: Layout,
        inputs: Sequence[IRNode],
        constant_args: Sequence[Any] = (),
        kwargs: Optional[dict[str, Any]] = None,
        output_view: Optional[ReinterpretView] = None,
        python_kernel_name: Optional[str] = None,
        cpp_kernel_name: Optional[str] = None,
        ordered_kwargs_for_cpp_kernel: Sequence[Any] = (),
        op_overload: Optional[_OpOverloads] = None,
    ) -> None:
        unwrapped_inputs = self.unwrap_storage(inputs)
        assert isinstance(unwrapped_inputs, Sequence), type(unwrapped_inputs)
        super().__init__(
            None,
            layout,
            unwrapped_inputs,
            constant_args,
            kwargs or {},
            # NOTE(review): `output_view` is accepted by this constructor but a
            # literal None is forwarded in this position, so the argument looks
            # silently dropped. Confirm against ExternKernel.__init__'s sixth
            # parameter.
            None,
            python_kernel_name,
            cpp_kernel_name,
            ordered_kwargs_for_cpp_kernel,
            op_overload,
        )
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

    def should_allocate(self) -> bool:
        return True
  5847. class RandomSeeds(ExternKernelOut):
  5848. def __init__(self, count: int, device: torch.device) -> None:
  5849. limits = torch.iinfo(torch.int64)
  5850. super().__init__(
  5851. layout=FixedLayout(
  5852. device=device,
  5853. dtype=torch.int64,
  5854. size=[count],
  5855. ),
  5856. inputs=[],
  5857. constant_args=[limits.min, limits.max, [count]],
  5858. python_kernel_name="aten.randint.low_out",
  5859. # FIXME: Ideally we should only use at::_ops::randint_low_out::call here,
  5860. # but the signature is different from is at::randint_out. Again,
  5861. # we can simplify the code when only keeping an ABI-compatible version.
  5862. cpp_kernel_name="at::_ops::randint_low_out::call",
  5863. op_overload=aten.randint.low_out,
  5864. )
class ExternKernelAlloc(ExternKernel):
    """Extern kernel call whose output is not pre-allocated (``should_allocate``
    is False).

    Presumably the kernel returns its own output; confirm in
    ``generate_extern_kernel_alloc``. Codegen is delegated to
    ``wrapper.generate_extern_kernel_alloc``, and ``apply_constraint`` raises
    ``NotImplementedError``. On construction the node registers itself with
    the current graph as a buffer and an operation.
    """

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        wrapper.generate_extern_kernel_alloc(self)

    def __init__(
        self,
        layout: OutputSpec,
        inputs: Sequence[IRNode],
        constant_args: Sequence[Any] = (),
        kwargs: Optional[dict[str, Any]] = None,
        python_kernel_name: Optional[str] = None,
        cpp_kernel_name: Optional[str] = None,
        ordered_kwargs_for_cpp_kernel: Sequence[Any] = (),
        op_overload: Optional[_OpOverloads] = None,
    ) -> None:
        unwrapped_inputs = self.unwrap_storage(inputs)
        assert all(isinstance(i, IRNode) for i in unwrapped_inputs)
        super().__init__(
            None,
            layout,
            cast(Sequence[IRNode], unwrapped_inputs),
            constant_args,
            kwargs or {},
            None,
            python_kernel_name,
            cpp_kernel_name,
            ordered_kwargs_for_cpp_kernel,
            op_overload,
        )
        # We need output buffers for generating kernel arguments in the
        # abi-compatible mode, where we retrieve outputs by passing each
        # individual output through the abi-compatible interface.
        self.outputs: Sequence[Any] = []
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

    def should_allocate(self) -> bool:
        return False

    def apply_constraint(self) -> None:
        raise NotImplementedError
  5903. class MutationOutput(Buffer):
  5904. """
  5905. An output buffer that represents the mutation of a pre-existing buffer
  5906. """
  5907. def __init__(
  5908. self, layout: OutputSpec, mutated_node: IRNode, mutating_node: Operation
  5909. ) -> None:
  5910. super().__init__(name=None, layout=layout)
  5911. mutated_node_name = mutated_node.get_name()
  5912. V.graph.mark_buffer_mutated(mutated_node_name)
  5913. self.mutation_names = [mutated_node_name]
  5914. self.mutating_node: Operation = mutating_node
  5915. self.name = V.graph.register_buffer(self)
  5916. def get_defining_op(self) -> Operation:
  5917. return self.mutating_node
  5918. def get_mutation_names(self) -> Sequence[str]:
  5919. return self.mutation_names
  5920. def should_allocate(self) -> bool:
  5921. return False
  5922. def get_mutation_buffers(self) -> Sequence[IRNode]:
  5923. mutation_names = self.get_mutation_names()
  5924. return [
  5925. buf
  5926. for buf in (V.graph.try_get_buffer(name) for name in mutation_names)
  5927. if buf is not None
  5928. ]
class TMADescriptor(ExternKernel):
    """
    An IR node representing a generic host-side TMA descriptor in the Triton API
    Mostly useful for user-defined Triton kernels relying on host-side TMA;
    but can, in principle, be used for Inductor's Triton templates, too.
    See TMADescriptorExperimental and TMADescriptorStable for the two implementations
    (the old API and the new API)
    """

    # as TMA descriptors are immutable,
    # we can dedup them by the input args
    # NOTE(review): this cache is class-level and nothing visible here clears
    # it, so entries (and the IR nodes they reference) outlive the graph that
    # created them. Keying on id(tensor) is only safe because each cached
    # descriptor keeps `tensor` alive via self.tensor, so that id cannot be
    # reused while the entry exists. `tma_meta` must be hashable.
    _CACHE: dict[Any, TMADescriptor] = {}

    @classmethod
    def _create_impl(
        cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]]
    ) -> TMADescriptor:
        """Build a descriptor from ``tma_meta = (api, ctor_args)``.

        ``api`` must be "experimental" or "stable" and selects the
        implementation. ``ctor_args`` is splatted into its constructor after
        ``tensor``.
        """
        assert len(tma_meta) == 2
        if tma_meta[0] == "experimental":
            return TMADescriptorExperimental(tensor, *tma_meta[1])
        else:
            assert tma_meta[0] == "stable"
            return TMADescriptorStable(tensor, *tma_meta[1])

    @classmethod
    def create(
        cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]]
    ) -> TMADescriptor:
        """Return the cached descriptor for ``(tensor, tma_meta)``, creating it once."""
        key = (id(tensor), tma_meta)
        if key not in cls._CACHE:
            cls._CACHE[key] = cls._create_impl(tensor, tma_meta)
        return cls._CACHE[key]

    def __init__(
        self, tensor: IRNode, inputs: Sequence[Any], constant_args: Sequence[Any]
    ) -> None:
        super().__init__(
            None,
            # link back to the underlying tensor in terms of ownership
            # to avoid getting the underlying tensor deleted *before*
            # the TMADescriptor node can be deleted.
            NonOwningLayout(
                ReinterpretView(
                    data=tensor,
                    layout=tensor.get_layout(),
                )
            ),
            cast(Sequence[Buffer], inputs),
            tuple(constant_args),
            None,
        )

        self.tensor = tensor
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        wrapper.generate_tma_descriptor(self)

    def get_tensor(self) -> IRNode:
        return self.tensor
  5983. class TMADescriptorExperimental(TMADescriptor):
  5984. """
  5985. the new host-side TMA Descriptor API:
  5986. (the ones obtained via create_{1d,2d}_tma_descriptor calls).
  5987. See also TMADescriptorStable for the new API.
  5988. """
  5989. def __init__(
  5990. self,
  5991. tensor: IRNode,
  5992. dims: list[Union[int, torch.SymInt]],
  5993. block_dims: list[Union[int, torch.SymInt]],
  5994. element_size: Optional[int] = None,
  5995. ) -> None:
  5996. assert len(dims) in (1, 2)
  5997. assert len(dims) == len(block_dims)
  5998. if element_size is None:
  5999. element_size = tensor.get_dtype().itemsize
  6000. self.dims = dims
  6001. self.block_dims = block_dims
  6002. self.element_size = element_size
  6003. self.rank = len(self.dims)
  6004. inputs = [tensor]
  6005. constant_args = [
  6006. *self.dims,
  6007. *self.block_dims,
  6008. self.element_size,
  6009. ]
  6010. super().__init__(
  6011. tensor=tensor,
  6012. inputs=inputs,
  6013. constant_args=constant_args,
  6014. )
  6015. class TMADescriptorStable(TMADescriptor):
  6016. """
  6017. the new host-side TMA descriptor API
  6018. (the ones obtained via TensorDescriptor.from_tensor).
  6019. See also TMADescriptorExperimental for the old API.
  6020. """
  6021. def __init__(self, tensor: IRNode, block_shape: list[Union[int, torch.SymInt]]):
  6022. self.block_shape = block_shape
  6023. super().__init__(
  6024. tensor=tensor,
  6025. inputs=[tensor],
  6026. constant_args=block_shape,
  6027. )
class SubgraphBuffer(ExternKernel):
    """Buffer computed by lowering and running an FX subgraph.

    On construction ``gm`` is turned into a nested graph via
    ``V.graph.make_subgraph`` and run on ``example_inputs`` with max-autotune
    disabled. ``codegen`` then emits a call to that subgraph whose single
    output is this buffer.
    """

    def __init__(
        self,
        layout: Layout,
        input_nodes: list[Buffer],
        gm: torch.fx.GraphModule,
        example_inputs: list[Any],
        subgraph_name: str,
    ) -> None:
        super().__init__(None, layout, input_nodes)
        self.gm = gm
        self.example_inputs = example_inputs
        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

        self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name)

        assert is_node_sequence(self.inputs)
        # Symbolic values used by the inputs are registered as subgraph graph
        # inputs *before* the subgraph is run below.
        sym_inputs = get_symbolic_inputs(self.inputs)

        for sym_inp in sym_inputs:
            self.subgraph.graph_inputs[sym_inp.name] = sym_inp
            self.subgraph.graph_input_names.append(sym_inp.name)

        self.sym_inputs = [sym_var.name for sym_var in sym_inputs]

        import torch._inductor.config as inductor_config

        with V.set_graph_handler(self.subgraph):
            # Don't bother autotuning on Triton here
            with inductor_config.patch(
                max_autotune=False,
                max_autotune_gemm=False,
                max_autotune_gemm_backends="ATEN",
            ):
                self.subgraph.run(*self.example_inputs)

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        # Minimal adapter exposing the lowered subgraph and its name to the
        # wrapper (presumably the interface codegen_subgraph_with_flattened_outputs
        # expects - TODO confirm).
        class CodegenGraph:
            def __init__(self, graph: GraphLowering):
                self.graph = graph
                self.name = graph.name

        assert is_node_sequence(self.inputs)
        outer_inputs = [t.codegen_reference() for t in self.inputs]
        # Symbolic inputs are passed first. They were inserted into
        # subgraph.graph_inputs before run(); presumably that keeps the orders
        # aligned - confirm.
        wrapper.codegen_subgraph_with_flattened_outputs(
            CodegenGraph(self.subgraph),
            [*self.sym_inputs, *outer_inputs],
            [self.name],
        )
class UserDefinedTritonKernel(ExternKernel):
    """IR node for a call to a user-written Triton kernel.

    The kernel object is looked up in ``kernel_side_table`` by ``kernel_idx``.
    Tensor arguments become node inputs, wrapped in a ``TMADescriptor`` when
    listed in ``tma_descriptor_metadata``; all other arguments are constant
    args. The node has a ``NoneLayout`` and exposes one ``MutationOutput``
    per argument that ``identify_mutated_tensors`` reports as mutated.
    """

    def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]:
        """Fetch the kernel for ``self.kernel_idx`` and unwrap an ``Autotuner``.

        Returns:
            ``(kernel, configs, restore_value_args, reset_to_zero_args)``.
            For an ``Autotuner`` entry, ``kernel`` is its ``fn`` and the other
            three come from the autotuner. Otherwise they are empty lists.
        """
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        kernel = kernel_side_table.get_kernel(self.kernel_idx)
        configs = []
        restore_value_args: list[str] = []
        reset_to_zero_args: list[str] = []
        if isinstance(kernel, Autotuner):
            # https://github.com/triton-lang/triton/pull/5083
            # changes kernel.restore_idx to kernel.restore_value
            if hasattr(kernel, "restore_idx"):
                # Indices into the kernel's arg list: map them to arg names.
                restore_value_args.extend(
                    kernel.fn.arg_names[i] for i in kernel.restore_idx
                )
            else:
                assert hasattr(kernel, "restore_value")
                restore_value_args.extend(kernel.restore_value)

            # Reset-to-zero args come either as indices (reset_idx) or
            # already as names (reset_to_zero).
            if hasattr(kernel, "reset_idx"):
                for i in kernel.reset_idx:
                    reset_to_zero_args.append(kernel.fn.arg_names[i])
            else:
                assert hasattr(kernel, "reset_to_zero")
                reset_to_zero_args.extend(kernel.reset_to_zero)

            configs = kernel.configs
            kernel = kernel.fn

        return kernel, configs, restore_value_args, reset_to_zero_args

    @override
    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        """Overrides the parent member.
        See https://github.com/pytorch/pytorch/issues/151692

        Defines the user kernel in ``wrapper`` and emits its launch. Arguments
        are taken in ``ordered_kwargs_for_cpp_kernel`` order, followed by the
        unnamed ``extra_launch_args`` returned by the kernel definition.
        """
        from torch._inductor.utils import triton_version_uses_attrs_dict

        (
            kernel,
            configs,
            restore_value_args,
            reset_to_zero_args,
        ) = self.get_kernel_and_metadata()

        # Definition of kernel
        (
            new_name,
            triton_meta,
            inductor_meta,
            extra_launch_args,
        ) = wrapper.define_user_defined_triton_kernel(
            kernel,
            configs,
            self.kwargs,
            restore_value_args,
            reset_to_zero_args,
            self.grid,
        )
        named_args = {
            k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
        }
        arg_names = [p.name for p in kernel.params]  # type: ignore[attr-defined]
        constexprs = [p.num for p in kernel.params if p.is_constexpr]  # type: ignore[attr-defined]
        constexpr_names = OrderedSet(arg_names[i] for i in constexprs)

        # args/arg_types are the rendered launch arguments; raw_keys_filtered /
        # raw_args_filtered keep the matching unrendered (name, value) pairs and
        # are passed as raw_keys / raw_args to generate_kernel_call.
        args: list[Any] = []
        arg_types: list[Any] = []
        raw_keys_filtered: list[Any] = []
        raw_args_filtered: list[Any] = []
        for name, arg in itertools.chain(
            named_args.items(), zip(itertools.repeat(""), extra_launch_args)
        ):
            if name in constexpr_names and triton_version_uses_attrs_dict():
                # see #160000 - we don't pass in constexpr args to speed up runtime.
                continue

            raw_keys_filtered.append(name)
            raw_args_filtered.append(arg)
            if isinstance(arg, IRNode):
                args.append(arg.codegen_reference())
                arg_types.append(arg.get_dtype())
            elif isinstance(arg, (int, float, bool, sympy.Expr)):
                args.append(arg)
                arg_types.append(type(arg))
            elif name in constexpr_names:
                # insert a dummy value for constexpr args of unsupported type
                # constexprs will end up getting baked into the kernel at compile time
                args.append(-1)
                arg_types.append(int)
            elif arg is None:
                """
                Filter out None args.

                see https://github.com/pytorch/pytorch/issues/115344

                Two cases for a None arg:
                1. The arg is already tl.constexpr, so leave it in
                2. The arg is not tl.constexpr so we have to remove it
                """
                if triton_version_uses_attrs_dict():
                    args.append(-1)
                    arg_types.append(int)
                else:
                    # Undo the raw_* appends above so this None arg is dropped.
                    raw_keys_filtered.pop()
                    raw_args_filtered.pop()
            else:
                raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")

        self.codegen_comment(wrapper, new_name)
        wrapper.generate_kernel_call(
            new_name,
            args,
            arg_types=arg_types,
            raw_args=raw_args_filtered,
            raw_keys=raw_keys_filtered,
            triton_meta=triton_meta,
            inductor_meta=inductor_meta,
            triton=True,
            device=self.get_device(),
            original_fxnode_name=self.fx_node.name,
        )

    @cache_on_self_and_args("UserDefinedTritonKernel")
    def get_free_symbol_uses(
        self, unbacked_only: bool = False
    ) -> OrderedSet[sympy.Symbol]:
        # add unbacked symbols used in the grid to the ones used
        # in the kwargs (the latter is generated by ExternKernel)
        return super().get_free_symbol_uses(unbacked_only) | get_free_symbols(
            self.grid, unbacked_only
        )

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        return OrderedSet()

    def __init__(
        self,
        *,
        kernel_idx: int,
        grid: Any,
        tma_descriptor_metadata: dict[str, Any],
        kernel_args: dict[str, Any],
    ) -> None:
        inputs: list[IRNode] = []
        # Values are either realized tensor inputs (IRNodes) or plain constants.
        kwargs: dict[str, Any] = {}
        constant_args: list[Any] = []
        for k, v in kernel_args.items():
            if isinstance(v, TensorBox):
                t = InputsKernel.unwrap_storage_for_input(self.realize_input(v))
                if k in tma_descriptor_metadata:
                    t = TMADescriptor.create(t, tma_descriptor_metadata[k])
                inputs.append(t)
                kwargs[k] = t
            else:
                constant_args.append(v)
                kwargs[k] = v

        assert len(inputs) != 0
        # The device is taken from the first tensor input.
        self.device = inputs[0].get_device()

        assert isinstance(inputs, Sequence), type(inputs)
        super().__init__(
            None,
            NoneLayout(device=self.device),
            inputs,
            tuple(constant_args),
            kwargs,
        )
        self.kernel_idx = kernel_idx
        self.grid = grid

        kernel, configs, _, _ = self.get_kernel_and_metadata()

        # If we are autotuning, not all arguments will be passed
        assert hasattr(kernel, "arg_names")
        self.ordered_kwargs_for_cpp_kernel = [
            arg for arg in kernel.arg_names if arg in kernel_args
        ]

        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

        # Values from the first autotune config (if any) are merged into the
        # kernel args for mutation analysis.
        autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {}

        self.mutable_args = [
            kernel_args[key]
            for key in identify_mutated_tensors(
                kernel,
                {**kernel_args, **autotuned_kwargs},
                tma_descriptor_metadata,
            )
        ]

        self.mutation_outputs = [
            MutationOutput(NoneLayout(device=self.device), buf, self)
            for buf in self.mutable_args
        ]
        V.graph.register_operation(self)

    def get_outputs(self) -> list[Buffer]:
        return list(self.mutation_outputs)

    def get_device(self) -> Optional[torch.device]:
        return self.device
  6249. class InplaceBernoulliFallback(ExternKernel):
  6250. """
  6251. This needs to be a custom class to handle mutation properly
  6252. """
  6253. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6254. assert all(isinstance(t, IRNode) for t in self.inputs)
  6255. (x,) = (cast(IRNode, t).codegen_reference() for t in self.inputs)
  6256. if V.graph.cpp_wrapper:
  6257. # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
  6258. # which needs to be explicitly generated for cpp wrapper
  6259. wrapper.writeline(
  6260. f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
  6261. )
  6262. else:
  6263. wrapper.writeline(
  6264. f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
  6265. )
  6266. def should_allocate(self) -> bool:
  6267. return False
  6268. def get_mutation_names(self) -> Sequence[str]:
  6269. return [self.input_name(0)]
  6270. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6271. return OrderedSet()
  6272. def __init__(
  6273. self, op_overload: _OpOverloads, x: IRNode, *constant_args: Any
  6274. ) -> None:
  6275. super().__init__(
  6276. None,
  6277. NoneLayout(device=x.get_device()),
  6278. self.unwrap_storage([x]),
  6279. constant_args,
  6280. op_overload=op_overload,
  6281. )
  6282. V.graph.mark_buffer_mutated(x.get_name())
  6283. self.name = V.graph.register_buffer(self)
  6284. V.graph.register_operation(self)
  6285. # Used to deal with torch.complex types
  6286. class InplaceCopyFallback(ExternKernel):
  6287. """
  6288. This needs to be a custom class to handle mutation properly
  6289. """
  6290. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6291. (dst, src, non_blocking) = self.codegen_args()
  6292. wrapper.codegen_device_copy(src, dst, non_blocking)
  6293. def should_allocate(self) -> bool:
  6294. return False
  6295. def get_mutation_names(self) -> Sequence[str]:
  6296. return [self.input_name(0)]
  6297. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6298. return OrderedSet()
  6299. def __init__(
  6300. self,
  6301. layout: OutputSpec,
  6302. inputs: Sequence[IRNode],
  6303. constant_args: Sequence[Any],
  6304. ) -> None:
  6305. super().__init__(
  6306. None,
  6307. layout,
  6308. inputs,
  6309. constant_args,
  6310. python_kernel_name="aten.copy_",
  6311. cpp_kernel_name="aoti_torch_copy_",
  6312. )
  6313. V.graph.mark_buffer_mutated(inputs[0].get_name())
  6314. self.name = V.graph.register_buffer(self)
  6315. V.graph.register_operation(self)
  6316. @classmethod
  6317. def create(
  6318. cls, dst: IRNode, src: IRNode, non_blocking: bool = False
  6319. ) -> InplaceCopyFallback:
  6320. inputs = [cls.realize_input(t) for t in [dst, src]]
  6321. constant_args = (non_blocking,)
  6322. result = InplaceCopyFallback(
  6323. NoneLayout(device=dst.get_device()),
  6324. inputs,
  6325. constant_args,
  6326. )
  6327. return result
  6328. class MutatingFirstArgExternKernel(ExternKernel):
  6329. """
  6330. This needs to be a custom class to handle mutation properly
  6331. """
  6332. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6333. assert is_node_sequence(self.inputs)
  6334. argrefs = [
  6335. *(t.codegen_reference() for t in self.inputs),
  6336. *map(repr, self.constant_args),
  6337. ]
  6338. wrapper.writeline(
  6339. f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}"
  6340. )
  6341. def should_allocate(self) -> bool:
  6342. return False
  6343. def get_mutation_names(self) -> Sequence[str]:
  6344. return [self.input_name(0)]
  6345. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6346. return OrderedSet()
  6347. def has_side_effects(self) -> bool:
  6348. return True
  6349. class ResizeStorageBytes(MutatingFirstArgExternKernel):
  6350. def __init__(self, variable: IRNode, new_size: int) -> None:
  6351. assert isinstance(new_size, int), "TODO: dynamic shapes"
  6352. super().__init__(
  6353. None,
  6354. NoneLayout(device=variable.get_device()),
  6355. self.unwrap_storage([variable]),
  6356. constant_args=(new_size,),
  6357. )
  6358. V.graph.mark_buffer_mutated(variable.get_name())
  6359. self.name = V.graph.register_buffer(self)
  6360. V.graph.register_operation(self)
  6361. self.python_kernel_name = "inductor_ops.resize_storage_bytes_"
  6362. self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_"
  6363. assert isinstance(variable, (BaseView, StorageBox, TensorBox)), type(variable)
  6364. V.graph.never_reuse_buffers.add(variable.data.get_name())
  6365. class SetSourceTensorKernel(ExternKernelAlloc):
  6366. def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None:
  6367. storage_tensor.freeze_layout()
  6368. super().__init__(
  6369. storage_tensor.get_layout(),
  6370. [self_tensor, storage_tensor],
  6371. python_kernel_name="torch.ops.aten.set_.source_Tensor",
  6372. op_overload=torch.ops.aten.set_.source_Tensor,
  6373. )
  6374. assert isinstance(self_tensor, (BaseView, StorageBox, TensorBox)), type(
  6375. self_tensor
  6376. )
  6377. V.graph.never_reuse_buffers.add(self_tensor.data.get_name())
  6378. V.graph.never_reuse_buffers.add(storage_tensor.get_name())
  6379. V.graph.never_reuse_buffers.add(self.get_name())
  6380. device = storage_tensor.get_device()
  6381. self.mutation_outputs = [
  6382. MutationOutput(NoneLayout(device=device), self_tensor, self),
  6383. MutationOutput(NoneLayout(device=device), storage_tensor, self),
  6384. ]
  6385. def get_inputs_that_alias_output(self) -> Sequence[str]:
  6386. return [self.input_name(0), self.input_name(1)]
  6387. class ScatterFallback(ExternKernel):
  6388. """
  6389. This needs to be a custom class to handle mutation properly.
  6390. This class handles both aten.scatter_ and aten.scatter_reduce_.
  6391. It also handle the case `src` being a scalar properly.
  6392. """
  6393. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6394. wrapper.generate_scatter_fallback(self)
  6395. def should_allocate(self) -> bool:
  6396. return False
  6397. def get_mutation_names(self) -> list[str]:
  6398. inp = self.inputs[0]
  6399. assert isinstance(inp, IRNode)
  6400. return [inp.get_name()]
  6401. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6402. return OrderedSet()
  6403. def __init__(
  6404. self,
  6405. op_overload: _OpOverloads,
  6406. x: IRNode,
  6407. dim: int,
  6408. index: IRNode,
  6409. src: IRNode,
  6410. *,
  6411. reduce: Optional[str] = None,
  6412. include_self: bool = True,
  6413. ) -> None:
  6414. self.src_is_tensor = isinstance(src, TensorBox)
  6415. constant_args: tuple[Any, ...]
  6416. if self.src_is_tensor:
  6417. tensors = [self.realize_input(t) for t in [x, index, src]]
  6418. constant_args = (dim,)
  6419. else:
  6420. tensors = [self.realize_input(t) for t in [x, index]]
  6421. constant_args = (dim, src)
  6422. super().__init__(
  6423. None,
  6424. NoneLayout(device=x.get_device()),
  6425. self.unwrap_storage(tensors),
  6426. constant_args,
  6427. {"reduce": reduce, "include_self": include_self},
  6428. python_kernel_name=str(op_overload),
  6429. ordered_kwargs_for_cpp_kernel=["reduce", "include_self"],
  6430. op_overload=op_overload,
  6431. )
  6432. V.graph.mark_buffer_mutated(x.get_name())
  6433. self.name = V.graph.register_buffer(self)
  6434. V.graph.register_operation(self)
  6435. class IndexPutFallback(ExternKernel):
  6436. """
  6437. This needs to be a custom class to handle mutation and indices properly
  6438. """
  6439. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6440. wrapper.generate_index_put_fallback(self)
  6441. def should_allocate(self) -> bool:
  6442. return False
  6443. def get_mutation_names(self) -> Sequence[str]:
  6444. return [self.input_name(0)]
  6445. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6446. return OrderedSet()
  6447. def __init__(
  6448. self,
  6449. op_overload: torch._ops.OpOverload,
  6450. x: IRNode,
  6451. indices: list[Any],
  6452. values: Sequence[Any],
  6453. accumulate: Any,
  6454. ) -> None:
  6455. self.indices = indices
  6456. valid_indices = [i for i in indices if i is not None]
  6457. # pyrefly: ignore [bad-argument-type]
  6458. tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
  6459. cpp_kernel_name = "aoti_torch_index_put_out"
  6460. super().__init__(
  6461. None,
  6462. NoneLayout(device=x.get_device()),
  6463. self.unwrap_storage(tensors),
  6464. (accumulate,),
  6465. python_kernel_name="aten.index_put_",
  6466. cpp_kernel_name=cpp_kernel_name,
  6467. op_overload=op_overload,
  6468. )
  6469. V.graph.mark_buffer_mutated(self.input_name(0))
  6470. self.name = V.graph.register_buffer(self)
  6471. V.graph.register_operation(self)
  6472. class DeviceCopy(ExternKernelOut):
  6473. @classmethod
  6474. def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode:
  6475. x_device = x.get_device()
  6476. assert x_device is not None
  6477. if (
  6478. not x.is_extern()
  6479. # Can not apply this optimization if x has been mutated
  6480. and try_get_name(x) not in V.graph.mutated_buffers
  6481. and all(r in V.graph.constants for r in x.get_read_names())
  6482. and not config.aot_inductor.use_runtime_constant_folding
  6483. ):
  6484. if V.graph.cpp_wrapper:
  6485. # Even if x is promoted to be a device constant, we still need to
  6486. # register device info to construct the correct CppWrapper class later
  6487. V.graph.add_device_info(device)
  6488. V.graph.add_device_info(x_device)
  6489. return x.constant_to_device(device)
  6490. V.graph.add_device_info(device)
  6491. V.graph.add_device_info(x_device)
  6492. developer_warning("DeviceCopy in input program")
  6493. constant_args = (non_blocking,)
  6494. # Device Copy should keep the same layout as input
  6495. x = ExternKernel.require_contiguous(x)
  6496. stride = None
  6497. if x.get_size():
  6498. # x.get_stride() may be unimplemented if x's size is empty
  6499. stride = x.get_stride()
  6500. is_destination_pinned = (
  6501. is_gpu(x_device.type) and device.type == "cpu" and non_blocking
  6502. )
  6503. is_source_pinned = (
  6504. x_device.type == "cpu" and is_gpu(device.type) and non_blocking
  6505. )
  6506. if is_source_pinned and is_storage_and_layout(x):
  6507. x.get_layout().is_pinned = True
  6508. return DeviceCopy(
  6509. FixedLayout(
  6510. device,
  6511. x.get_dtype(),
  6512. x.get_size(),
  6513. stride,
  6514. is_pinned=is_destination_pinned,
  6515. ),
  6516. [cls.realize_input(x)],
  6517. constant_args,
  6518. )
  6519. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6520. args = self.codegen_args()
  6521. assert len(args) == 2
  6522. if self.output_view:
  6523. wrapper.codegen_device_copy(
  6524. args[0], self.output_view.codegen_reference(), args[1]
  6525. )
  6526. else:
  6527. wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
  6528. class DynamicSelectStorageOffset(ExternKernel):
  6529. """
  6530. The result of computing a dynamic selection index is determined as follows: when the index in the
  6531. select operation is unbacked, the actual index calculation is ambiguous for negative indices
  6532. (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked
  6533. SymInt to represent the storage offset and decompose the select operation into a call to as_strided,
  6534. computing the storage offset at runtime with this node.
  6535. """
  6536. def get_reads(self) -> OrderedSet[Dep]:
  6537. return OrderedSet()
  6538. def should_allocate(self) -> bool:
  6539. return False
  6540. def __init__(
  6541. self,
  6542. unbacked_offset_symbol: sympy.Symbol,
  6543. index: sympy.Symbol,
  6544. base_offset: Union[sympy.Symbol, int],
  6545. base_dim_stride: Union[sympy.Symbol, int],
  6546. size: Union[sympy.Symbol, int],
  6547. clamp: bool,
  6548. ) -> None:
  6549. super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
  6550. # This node codegen the following:
  6551. # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size)
  6552. self.unbacked_offset_symbol = unbacked_offset_symbol
  6553. self.index = index
  6554. self.base_offset = base_offset
  6555. self.base_dim_stride = base_dim_stride
  6556. self.size = size
  6557. self.clamp = clamp
  6558. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6559. return OrderedSet([self.unbacked_offset_symbol])
  6560. @cache_on_self_and_args("DynamicSelectStorageOffset")
  6561. def get_free_symbol_uses(
  6562. self, unbacked_only: bool = False
  6563. ) -> OrderedSet[sympy.Symbol]:
  6564. return get_free_symbols(self.index, unbacked_only)
  6565. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6566. wrapper.codegen_dynamic_select_index(self, clamp=self.clamp)
  6567. class DynamicSliceSize(ExternKernel):
  6568. """
  6569. Computes the output size of a slice call, handling the correct semantics in codegen.
  6570. We do this for flexible handling for unbacked indices (to not data-dependent error).
  6571. Slicing has 4 semantics for indices, i.e. x[start:] could be:
  6572. 1) start < -x.size(0) -> x[0:] # negative out-of-bounds
  6573. 2) start in [-x.size(0), 0) -> x[x.size(0) + start:] # negative slicing
  6574. 3) start in [0, x.size(0)) -> x[start:] # standard slicing
  6575. 4) start >= x.size(0) -> empty slice # positive out-of-bounds
  6576. If the appropriate semantics are known beforehand, the output size is computed based on
  6577. the start & end indices. If not (with unbacked indices), a new unbacked symbol is created
  6578. to represent the output size, and codegen handles computing the correct case.
  6579. """
  6580. def get_reads(self) -> OrderedSet[Dep]:
  6581. return OrderedSet()
  6582. def should_allocate(self) -> bool:
  6583. return False
  6584. def __init__(
  6585. self,
  6586. unbacked_size_symbol: sympy.Symbol,
  6587. start: Union[sympy.Symbol, int],
  6588. end: Union[sympy.Symbol, int],
  6589. step: Union[sympy.Symbol, int],
  6590. size: Union[sympy.Symbol, int],
  6591. ):
  6592. super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
  6593. # This node codegen
  6594. self.unbacked_size_symbol = unbacked_size_symbol
  6595. self.start = start
  6596. self.end = end
  6597. self.step = step
  6598. self.size = size
  6599. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6600. return OrderedSet([self.unbacked_size_symbol])
  6601. @cache_on_self_and_args("DynamicSliceSize")
  6602. def get_free_symbol_uses(
  6603. self, unbacked_only: bool = False
  6604. ) -> OrderedSet[sympy.Symbol]:
  6605. return get_free_symbols(self.start, unbacked_only).union(
  6606. get_free_symbols(self.end, unbacked_only)
  6607. )
  6608. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6609. wrapper.codegen_dynamic_slice_size(self)
  6610. class DynamicScalar(ExternKernel):
  6611. """
  6612. The result of a call to aten._local_scalar_dense.
  6613. """
  6614. def get_reads(self) -> OrderedSet[Dep]:
  6615. return OrderedSet()
  6616. def should_allocate(self) -> bool:
  6617. return False
  6618. def __init__(
  6619. self, sym: sympy.Symbol, keypath: pytree.KeyPath, data: IRNode
  6620. ) -> None:
  6621. data.realize()
  6622. super().__init__(
  6623. None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data])
  6624. )
  6625. self.sym = sym
  6626. self.keypath = keypath
  6627. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6628. return OrderedSet([self.sym])
  6629. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6630. wrapper.codegen_dynamic_scalar(self)
  6631. class AssertScalar(ExternKernel):
  6632. """
  6633. The result of a call to aten._assert_scalar
  6634. """
  6635. def get_reads(self) -> OrderedSet[Dep]:
  6636. return OrderedSet()
  6637. def should_allocate(self) -> bool:
  6638. return False
  6639. def __init__(self, scalar: SympyBoolean, msg: str) -> None:
  6640. super().__init__(
  6641. # Buffer(name, layotu)
  6642. None,
  6643. NoneLayout(device=torch.device("cpu")),
  6644. # InputsKernel(inputs)
  6645. [],
  6646. )
  6647. self.scalar = scalar
  6648. self.msg = msg
  6649. def has_side_effects(self) -> bool:
  6650. return True
  6651. @cache_on_self_and_args("AssertScalar")
  6652. def get_free_symbol_uses(
  6653. self, unbacked_only: bool = False
  6654. ) -> OrderedSet[sympy.Symbol]:
  6655. return get_free_symbols(self.scalar, unbacked_only)
  6656. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6657. if not config.scalar_asserts:
  6658. return
  6659. # NB: It is EXTREMELY important not to simplify the scalar under assertion here,
  6660. # because simplify is done with respect to runtime asserts. So if you have
  6661. # "u0 == 0" in the runtime asserts, if you subsequently try to
  6662. # simplify(u0 == 0), you will get True (because we've already runtime assert'ed
  6663. # that it's true). But we're code generating the actual runtime assert here!!
  6664. symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False)))
  6665. if V.graph.fx_wrapper:
  6666. # TODO fix
  6667. pass
  6668. elif V.graph.cpp_wrapper:
  6669. symbol_str = f"std::to_string({symbol})"
  6670. sizevar = V.graph.wrapper_code.codegen_cpp_sizevar(
  6671. self.scalar, simplify=False
  6672. )
  6673. # TODO: when we start compiling in C++20, annotate with [[unlikely]].
  6674. wrapper.writeline(
  6675. f'if (!({sizevar})) {{ throw std::runtime_error("Expected {self.msg} but received " + {symbol_str}); }}'
  6676. )
  6677. else:
  6678. sizevar = V.graph.wrapper_code.codegen_python_sizevar(
  6679. self.scalar, simplify=False
  6680. )
  6681. wrapper.writeline(f"if not ({sizevar}):")
  6682. wrapper.writeline(f" raise RuntimeError({repr(self.msg)})")
  6683. # No one should ever use this buffer, but for uniformity
  6684. # define the variable and assign it None
  6685. wrapper.writeline(f"{self.get_name()} = None")
@ir_dataclass(frozen=False)
class ExternKernelNode:
    """A serialized extern-kernel call.

    Built by ``FallbackKernel.export_extern_kernel_node`` in AOT mode and
    appended to ``V.extern_kernel_nodes``.
    """

    # Name of the IR buffer for the call (``self.get_name()`` at creation time).
    name: str
    # Export-schema node describing the call target, inputs and outputs.
    node: export_schema.Node
  6690. class FallbackKernel(ExternKernelAlloc):
  6691. """
  6692. A class that represents a fallback kernel for handling operators that are not
  6693. directly support by inductor. It currently supports functional ops, view ops,
  6694. inplace aten ops, and mutating ops that are auto-functionalizable.
  6695. """
    def __init__(
        self,
        layout: OutputSpec,
        kernel: _OpOverloads,
        tensor_args: Sequence[IRNode],
        nontensor_args: Sequence[Any],
        unflatten_args: Callable[..., Any],
        kwargs: Optional[dict[str, Any]] = None,
        *,
        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
    ) -> None:
        """Create a fallback call to ``kernel`` and record its aliasing/mutation.

        Args:
            layout: Output spec forwarded to the base class.
            kernel: The ``OpOverload`` or ``HigherOrderOperator`` to call.
            tensor_args: IR nodes for the tensor arguments.
            nontensor_args: The remaining (constant) arguments.
            unflatten_args: Callable rebuilding ``(args, kwargs)`` from
                ``(tensor args, constant args)``.
            kwargs: Extra keyword arguments kept on ``self.kwargs`` (``{}`` if None).
            unbacked_bindings: Unbacked symbols defined by this call, mapped to a
                ``pytree.KeyPath`` (presumably into the outputs — confirm).

        Raises:
            NotImplementedError: For a mutable op that can neither be
                auto-functionalized nor has a Functionalize kernel.
        """
        super().__init__(
            layout,
            tuple(tensor_args),
            tuple(nontensor_args),
            op_overload=kernel,
        )
        self.use_runtime_dispatch = False
        self.unbacked_bindings = unbacked_bindings or {}
        assert isinstance(
            kernel, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
        ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
        self.op_overload = kernel
        self.unflatten_args = unflatten_args
        self.kwargs = {} if kwargs is None else kwargs
        assert self.python_kernel_name is not None
        V.graph.warn_fallback(self.python_kernel_name)
        # args that are aliased
        self.alias_names: list[str] = []
        # args that are mutated AND returned from the op
        self.mutation_names: list[str] = []
        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
            # We assume here that HOPs with FallbackKernel are functional.
            # This may not always be true! HOPs must individually opt-in to
            # FallbackKernel, so please check this if you opt-in.
            return
        if "_c10d_functional" in self.op_overload.name():
            # _c10d_functional kernels are lowered into _CollectiveKernel which
            # derives from FallbackKernel for the cpp codegen. The kernels
            # don't pass the can_auto_functionalize check, but their mutation
            # is handled properly by _CollectiveKernel.
            return
        schema = self.op_overload._schema
        # NOTE: [FallbackKernel supported operators]
        # We only support four types of operators:
        # - functional ops
        # - view ops
        # - inplace aten ops
        # - mutating ops that are auto-functionalizable. That is,
        # the operator may mutate any number of inputs, but its outputs
        # may not alias any of the inputs.
        #
        # The unsupported cases usually do not show up here (because
        # AOTAutograd functionalized them away); the only way for an in-place
        # op to show up here is if a lowering or pass introduced it.
        if torch._library.utils.mutates_and_returns_first_arg(self.op_overload):
            self.mutation_names.append(tensor_args[0].get_name())
            return

        # True if the op has a Functionalize kernel, either registered with the
        # C++ dispatcher or as a Python kernel in op.py_kernels.
        def has_functionalize_impl(op: torch._ops.OpOverload) -> bool:
            return torch._C._dispatch_has_kernel_for_dispatch_key(
                op.name(), torch._C.DispatchKey.Functionalize
            ) or (
                hasattr(op, "py_kernels")
                and torch._C.DispatchKey.Functionalize in op.py_kernels
            )

        if (
            schema.is_mutable
            and not can_auto_functionalize(self.op_overload)
            and not has_functionalize_impl(self.op_overload)
        ):
            raise NotImplementedError(
                f"NYI: Can't generate FallbackKernel for {self.op_overload}"
            )
        # NB: this rebinds the local ``kwargs``; self.kwargs was already set above.
        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)

        # For one schema argument: record aliased inputs in self.alias_names and
        # add a MutationOutput for each input the schema marks as written.
        def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None:
            # Assertions to make sure we didn't mismatch args
            if isinstance(info.type, torch.ListType):
                assert isinstance(arg, (list, tuple)), type(arg)
            if library_utils.is_tensor_like_type(info.type):
                # PyTorch also accepts None and scalar types for args marked as "Tensor".
                # We're not going to check all of them here.
                assert not isinstance(arg, (tuple, list))
            if arg is None:
                return
            if info.alias_info is None:
                return

            def add_alias(t: IRNode) -> None:
                self.alias_names.append(t.get_name())
                assert info.alias_info is not None
                if info.alias_info.is_write:
                    self.mutation_outputs.append(
                        MutationOutput(NoneLayout(device=t.get_device()), t, self)
                    )

            if library_utils.is_tensorlist_like_type(info.type):
                if arg is not None:
                    # NOTE(review): for Tensor?[] arguments an element may be None,
                    # which add_alias would reject (None has no get_name) — confirm
                    # such schemas never carry alias_info.
                    for optional_tensor_arg in arg:
                        add_alias(optional_tensor_arg)
            else:
                assert library_utils.is_tensor_like_type(info.type)
                add_alias(arg)

        for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
            handle_aliasing_and_mutation(info, arg)
  6798. def get_read_writes(self) -> dependencies.ReadWrites:
  6799. read_writes = super().get_read_writes()
  6800. if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state:
  6801. for arg in self.constant_args:
  6802. if isinstance(arg, GeneratorState):
  6803. read_writes = read_writes.with_read(
  6804. dependencies.StarDep(arg.get_name())
  6805. )
  6806. return read_writes
  6807. def codegen_unbacked_symbol_defs(self, wrapper: PythonWrapperCodegen) -> None:
  6808. return wrapper.codegen_unbacked_symbol_defs_for_outputs(
  6809. self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
  6810. )
  6811. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  6812. if unbacked_bindings := getattr(self, "unbacked_bindings", None):
  6813. resolved = resolve_unbacked_bindings(
  6814. V.graph.sizevars.shape_env, unbacked_bindings
  6815. )
  6816. assert resolved is not None
  6817. return OrderedSet(resolved.keys())
  6818. else:
  6819. return OrderedSet()
  6820. def codegen_args(self) -> list[str]:
  6821. @dataclasses.dataclass
  6822. class Shim:
  6823. ref: Any
  6824. def __repr__(self) -> str:
  6825. return self.ref
  6826. assert is_node_sequence(self.inputs)
  6827. tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
  6828. args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
  6829. if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
  6830. args = self.fill_non_provided_args(args, kwargs)
  6831. args = [
  6832. V.graph.wrapper_code.val_to_arg_str(x, param.real_type)
  6833. for param, x in zip(self.op_overload._schema.arguments, args)
  6834. ]
  6835. else:
  6836. args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
  6837. # let self.codegen_kwargs handle kwargs
  6838. self.kwargs.update(kwargs)
  6839. return args
  6840. @staticmethod
  6841. def find_device(
  6842. tensor_args: Optional[Sequence[torch.Tensor]], example_output: Sequence[Any]
  6843. ) -> Any:
  6844. non_torch_bind_tensor_args = (
  6845. [t for t in tensor_args if not isinstance(t, TorchBindObject)]
  6846. if tensor_args
  6847. else None
  6848. )
  6849. if non_torch_bind_tensor_args:
  6850. assert tensor_args
  6851. devices = [arg.get_device() for arg in tensor_args if arg.get_device()]
  6852. return devices[0]
  6853. if isinstance(example_output, torch.Tensor):
  6854. return example_output.device
  6855. if isinstance(example_output, (list, tuple)):
  6856. device_set = OrderedSet(
  6857. # pyrefly: ignore [bad-argument-type]
  6858. FallbackKernel.find_device(None, x)
  6859. for x in example_output
  6860. )
  6861. # Remove None
  6862. devices = [device for device in device_set if device]
  6863. if len(devices) == 1:
  6864. return devices[0]
  6865. for device in devices:
  6866. assert isinstance(device, torch.device)
  6867. if is_gpu(device.type):
  6868. return device
  6869. return devices[0]
  6870. return None
  6871. def has_side_effects(self) -> bool:
  6872. from torch._library.utils import is_impure
  6873. # Note: We don't pass args/kwargs here because they're IRNodes, not actual values
  6874. # The check is done on the op_overload itself
  6875. return is_impure(self.op_overload) # pyrefly: ignore[bad-argument-type]
  6876. def get_inputs_that_alias_output(self) -> Sequence[str]:
  6877. assert isinstance(
  6878. self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
  6879. ), (
  6880. f"Fails to create FallbackKernel for {self.op_overload}: "
  6881. f"{type(self.op_overload)} not supported"
  6882. )
  6883. # See [Note: FallbackKernel supported operators]: for a mutating
  6884. # op that is auto-functionalizable, its outputs does NOT
  6885. # alias any of the inputs.
  6886. if (
  6887. not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
  6888. and "_c10d_functional" not in self.op_overload.name()
  6889. and self.op_overload._schema.is_mutable
  6890. and can_auto_functionalize(self.op_overload)
  6891. ):
  6892. return []
  6893. else:
  6894. return self.alias_names
  6895. def get_mutation_names(self) -> Sequence[str]:
  6896. assert len(self.mutation_names) <= 1
  6897. return self.mutation_names
    def export_extern_kernel_node(self):  # type: ignore[no-untyped-def]
        """
        Return the argument list for this fallback call. In AOT mode, also
        record a serialized ``ExternKernelNode`` in ``V.extern_kernel_nodes``.

        Returns:
            ``[*args, *ordered_kwargs]``: the positional args (with unprovided
            ones filled by ``fill_non_provided_args``), followed by the values
            for ``self.ordered_kwargs_for_cpp_kernel``.

        Raises:
            RuntimeError: If a schema return type is not one of Tensor, None,
                Tensor[], Tensor? or int.

        ProxyExecutor Design Note
        We export the ExternFallbackNodes (for custom ops) into a serialized file
        and run it with a host side proxy executor to address the ABI problem
        This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
        Detailed design doc can be found at
        https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
        """
        log.debug(
            "Extern kernel node added for node %s with target %s.",
            self.get_name(),
            self.op_overload,
        )
        assert isinstance(self, FallbackKernel), type(self)
        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
        args = self.fill_non_provided_args(args, kwargs)
        ordered_kwargs = [
            self.get_kwargs_value(key, **kwargs)
            for key in self.ordered_kwargs_for_cpp_kernel
        ]
        target = self.op_overload
        if not V.graph.aot_mode:
            # No need to serialize in the cpp wrapper JIT mode
            return [*args, *ordered_kwargs]
        serializer = GraphModuleSerializer(None, [])  # type: ignore[arg-type]
        named_arguments = serializer.serialize_inputs(target, args, kwargs)

        # serialize_outputs
        # Convert one schema return (type + IR output) into an export_schema.Argument.
        def handle_single_output(
            return_type: Union[torch.TensorType, torch.ListType, torch.JitType],
            output: Union[IRNode, Sequence[IRNode]],
        ) -> export_schema.Argument:
            if isinstance(return_type, (torch.TensorType, torch.NoneType)):
                # For single Tensor or None
                out = output
                if isinstance(output, (list, tuple)):
                    # A one-element list (e.g. self.mutation_outputs) stands in
                    # for a single return value.
                    assert len(output) == 1
                    out = output[0]
                if isinstance(return_type, torch.TensorType):
                    assert isinstance(out, IRNode)
                    return export_schema.Argument.create(
                        as_tensor=export_schema.TensorArgument(name=out.get_name())
                    )
                else:  # NoneType
                    assert out is None
                    return export_schema.Argument.create(as_none=True)
            elif isinstance(return_type, torch.ListType) and isinstance(
                return_type.getElementType(), torch.TensorType
            ):
                assert isinstance(output, Sequence), type(output)
                # For single TensorList
                return export_schema.Argument.create(
                    as_tensors=[
                        export_schema.TensorArgument(name=out.get_name())
                        for out in output
                    ]
                )
            elif isinstance(return_type, torch.OptionalType) and isinstance(
                return_type.getElementType(), torch.TensorType
            ):
                # For OptionalTensor
                if output is None:
                    return export_schema.Argument.create(
                        as_optional_tensor=export_schema.OptionalTensorArgument.create(
                            as_none=True
                        )
                    )
                else:
                    assert isinstance(output, IRNode)
                    return export_schema.Argument.create(
                        as_optional_tensor=export_schema.OptionalTensorArgument.create(
                            as_tensor=export_schema.TensorArgument(
                                name=output.get_name()
                            )
                        )
                    )
            elif isinstance(return_type, torch.IntType):
                return export_schema.Argument.create(as_int=output)
            else:
                raise RuntimeError(f"Unsupported return type {type(return_type)}")

        # TorchBind method calls resolve their schema from the first two args;
        # regular ops carry it on the overload.
        if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind):
            returns = target.schema(args[0], args[1]).returns
        else:
            returns = target._schema.returns  # type: ignore[union-attr]
        if len(returns) == 1:
            # NOTE: [special handling of all_reduce_coalesced_'s return value]
            # all_reduce_coalesced_ return a list of tensors via self.mutation_outputs
            outputs = self.outputs if self.outputs else self.mutation_outputs
            return_type = returns[0].real_type
            output_arguments = [handle_single_output(return_type, outputs)]
        else:
            # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tensor, Tensor[])"
            # Not generating output args for self.mutation_outputs
            output_arguments = [
                handle_single_output(
                    return_schema.real_type,  # type: ignore[attr-defined]
                    output,
                )
                for return_schema, output in zip(returns, self.outputs)
            ]
        assert self.op_overload is not None
        node = ExternKernelNode(
            name=self.get_name(),
            node=export_schema.Node(
                target=self.op_overload.name(),
                inputs=named_arguments,
                outputs=output_arguments,
                metadata={},
            ),
        )
        V.extern_kernel_nodes.append(node)
        return [*args, *ordered_kwargs]
  7010. @override
  7011. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7012. """Overrides the parent member.
  7013. See https://github.com/pytorch/pytorch/issues/151692"""
  7014. kernel = self.op_overload
  7015. assert kernel is not None
  7016. if kernel.namespace == "aten":
  7017. # Aten Fallback Ops
  7018. assert isinstance(kernel, torch._ops.OpOverload), type(kernel)
  7019. if V.graph.cpp_wrapper:
  7020. from torchgen.aoti.fallback_ops import inductor_fallback_ops
  7021. if str(kernel) not in inductor_fallback_ops:
  7022. # C shim v2 is torchgen-ed, which should cover all aten ops.
  7023. # If you do hit a missed op, please update fallback_ops.py.
  7024. log.warning(
  7025. "%s is missing a c-shim implementation, using proxy executor as fallback",
  7026. kernel,
  7027. )
  7028. self.use_runtime_dispatch = True
  7029. elif kernel.namespace == "_quantized":
  7030. # Internal Quantized Fallback Ops
  7031. assert isinstance(kernel, torch._ops.OpOverload), type(kernel)
  7032. elif V.graph.cpp_wrapper:
  7033. # For non-aten OpOverload, i.e. custom ops
  7034. # If the op is in custom_ops_to_c_shims, generate direct function call
  7035. self.use_runtime_dispatch = (
  7036. kernel not in config.aot_inductor.custom_ops_to_c_shims
  7037. )
  7038. # Handle the special case where a complex number is input to a C-shim kernel for
  7039. # a scalar input. The torchgen'ed shim API will use type "double", which is
  7040. # incompatible with complex numbers, forcing a fallback to runtime dispatch.
  7041. if (
  7042. V.graph.cpp_wrapper
  7043. and isinstance(kernel, torch._ops.OpOverload)
  7044. and not self.use_runtime_dispatch
  7045. ):
  7046. def is_number(t: torch.JitType) -> bool:
  7047. if isinstance(t, torch.OptionalType):
  7048. return is_number(t.getElementType())
  7049. return isinstance(t, torch.NumberType)
  7050. # Using unflatten_args is a bit of a hack, but all the complex arguments we
  7051. # care about are in self.constant_args, and calling unflatten_args puts them
  7052. # in the correct order without triggering codegen.
  7053. args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
  7054. # Append kwarg values to args. ordered_kwargs_for_cpp_kernel is guaranteed
  7055. # to be set, since this is an OpOverload kernel.
  7056. args_iter = itertools.chain(
  7057. args,
  7058. (
  7059. self.get_kwargs_value(k, **kwargs)
  7060. for k in self.ordered_kwargs_for_cpp_kernel
  7061. ),
  7062. )
  7063. self.use_runtime_dispatch = any(
  7064. isinstance(v, complex) and is_number(a.real_type)
  7065. for v, a in zip(args_iter, kernel._schema.arguments)
  7066. )
  7067. self.codegen_comment(wrapper)
  7068. if self.use_runtime_dispatch:
  7069. exported_args = self.export_extern_kernel_node()
  7070. assert self.python_kernel_name is not None
  7071. assert self.op_overload is not None
  7072. wrapper.generate_fallback_kernel_with_runtime_lookup(
  7073. self.get_name(),
  7074. self.python_kernel_name,
  7075. lambda: [*self.codegen_args(), *self.codegen_kwargs()],
  7076. self.op_overload,
  7077. exported_args,
  7078. # NOTE: [special handling of all_reduce_coalesced_'s return value]
  7079. self.outputs if self.outputs else self.mutation_outputs,
  7080. )
  7081. else:
  7082. wrapper.generate_fallback_kernel(self)
  7083. if isinstance(self.layout, Layout):
  7084. self.codegen_size_asserts(wrapper)
  7085. self.codegen_alignment_asserts(wrapper)
  7086. self.codegen_memory_tracking(wrapper)
  7087. self.codegen_unbacked_symbol_defs(wrapper)
  7088. @staticmethod
  7089. def tensor_to_layout(output: torch.Tensor) -> FixedLayout:
  7090. is_pinned = False
  7091. try:
  7092. is_pinned = output.is_pinned()
  7093. except RuntimeError:
  7094. # dispatch not implemented
  7095. pass
  7096. return FixedLayout(
  7097. output.device,
  7098. output.dtype,
  7099. convert_shape_to_inductor(output.size()),
  7100. convert_shape_to_inductor(output.stride()),
  7101. is_pinned=is_pinned,
  7102. )
  7103. @classmethod
  7104. def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKernel:
  7105. """Create an instance of FallbackKernel from an _OpOverloads"""
  7106. fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
  7107. if kernel not in fake_incorrect_kernels:
  7108. context = cast(AbstractContextManager[None], V.graph.fake_mode)
  7109. else:
  7110. context = nullcontext()
  7111. with context:
  7112. (
  7113. example_output,
  7114. tensor_args,
  7115. non_tensor_args,
  7116. unflatten_args,
  7117. unbacked_bindings,
  7118. ) = cls.process_kernel(kernel, *args, **kwargs)
  7119. # We need this extra check for input alignment since the example
  7120. # inputs we created are always aligned.
  7121. has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args)
  7122. device = cls.find_device(tensor_args, example_output)
  7123. # Default to CPU for torchbind methods or HOPs that don't produce tensors
  7124. if not device and (
  7125. isinstance(kernel, torch._higher_order_ops.torchbind.CallTorchBind)
  7126. or kernel is torch.ops.higher_order.print
  7127. ):
  7128. device = torch.device("cpu")
  7129. if example_output is None:
  7130. packed = cls(
  7131. NoneLayout(device=device),
  7132. kernel,
  7133. tensor_args,
  7134. non_tensor_args,
  7135. unflatten_args,
  7136. kwargs=kwargs,
  7137. unbacked_bindings=unbacked_bindings,
  7138. )
  7139. else:
  7140. assert device, "Not sure where to find device info"
  7141. packed = cls(
  7142. MultiOutputLayout(device=device),
  7143. kernel,
  7144. tensor_args,
  7145. non_tensor_args,
  7146. unflatten_args,
  7147. kwargs=kwargs,
  7148. unbacked_bindings=unbacked_bindings,
  7149. )
  7150. def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any:
  7151. if isinstance(output, (list, tuple)):
  7152. return type(output)(
  7153. generate_output(output[i], indices + [(type(output), i)])
  7154. for i in range(len(output))
  7155. )
  7156. elif isinstance(output, dict):
  7157. return {
  7158. key: generate_output(val, indices + [(type(output), key)])
  7159. for key, val in output.items()
  7160. }
  7161. elif isinstance(output, torch.Tensor):
  7162. buf = MultiOutput(
  7163. cls.tensor_to_layout(output),
  7164. packed,
  7165. indices,
  7166. )
  7167. if (
  7168. config.assume_unaligned_fallback_output
  7169. or has_unaligned_input
  7170. or not tensor_is_aligned(output)
  7171. ):
  7172. V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type]
  7173. return buf
  7174. elif isinstance(output, int):
  7175. return output
  7176. elif isinstance(output, torch.SymInt):
  7177. return output.node.expr
  7178. else:
  7179. assert output is None, (
  7180. f"FallbackKernel output type {type(output)} is not supported"
  7181. )
  7182. return None
  7183. outputs = generate_output(example_output, [])
  7184. if isinstance(outputs, (list, tuple)):
  7185. packed.outputs = outputs
  7186. elif isinstance(outputs, dict):
  7187. packed.outputs = tuple(outputs)
  7188. else:
  7189. packed.outputs = [outputs]
  7190. return outputs
  7191. @ir_dataclass(frozen=False)
  7192. class ComplexView(FallbackKernel):
  7193. """View a complex number as two dtyped numbers or vice versa"""
  7194. def should_allocate(self) -> bool:
  7195. return False
  7196. def get_inputs_that_alias_output(self) -> Sequence[str]:
  7197. # Signal to codegen that our output buffer isn't safe to reuse
  7198. return [self.input_name(0)]
  7199. def __init__(
  7200. self,
  7201. layout: OutputSpec,
  7202. kernel: _OpOverloads,
  7203. tensor_args: Sequence[IRNode],
  7204. nontensor_args: Sequence[Any],
  7205. unflatten_args: Callable[..., Any],
  7206. *,
  7207. kwargs: dict[str, Any] | None = None,
  7208. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
  7209. ) -> None:
  7210. super().__init__(
  7211. layout,
  7212. kernel,
  7213. tensor_args,
  7214. nontensor_args,
  7215. unflatten_args,
  7216. kwargs=kwargs,
  7217. unbacked_bindings=unbacked_bindings,
  7218. )
  7219. class MemoryCheckKernel(FallbackKernel):
  7220. """
  7221. Custom kernel for memory checking that generates direct function calls
  7222. TODO - the custom op was erroring with str inputs. should be able to custom op directly.
  7223. """
  7224. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7225. """Override codegen to write direct function call"""
  7226. # Extract our arguments from nontensor_args
  7227. wrapper.write_memory_track_allocation_once()
  7228. alive_list, dead_list, is_final_step = self.constant_args
  7229. alive_repr = repr(alive_list)
  7230. dead_repr = repr(dead_list)
  7231. if is_final_step:
  7232. wrapper.writeline(
  7233. "# note: dont currently distinguish between buffers returned and dealloc'd in last step"
  7234. )
  7235. call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})"
  7236. else:
  7237. call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})"
  7238. wrapper.writeline(call)
  7239. @ir_dataclass
  7240. class MultiOutputLayout(OutputSpec):
  7241. device: torch.device
  7242. def get_device(self) -> Optional[torch.device]:
  7243. return self.device
  7244. class MultiOutput(ExternKernel):
  7245. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7246. wrapper.codegen_multi_output(self)
  7247. if not self.skip_size_stride_alignment_checks:
  7248. self.codegen_size_asserts(wrapper)
  7249. self.codegen_alignment_asserts(wrapper)
  7250. def __init__(
  7251. self,
  7252. layout: OutputSpec,
  7253. input: IRNode,
  7254. indices: list[tuple[Any, ...]],
  7255. skip_size_stride_alignment_checks: bool = False,
  7256. ) -> None:
  7257. super().__init__(None, layout, [input], ())
  7258. self.name = V.graph.register_buffer(self)
  7259. V.graph.register_operation(self)
  7260. self.indices = indices
  7261. self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
  7262. @cache_on_self_and_args("MultiOutput")
  7263. def get_free_symbol_uses(
  7264. self, unbacked_only: bool = False
  7265. ) -> OrderedSet[sympy.Symbol]:
  7266. input_node = self.inputs[0]
  7267. assert isinstance(input_node, IRNode), input_node
  7268. return input_node.get_free_symbol_uses(unbacked_only)
  7269. def should_allocate(self) -> bool:
  7270. return len(self.inputs) == 1 and (
  7271. isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM
  7272. )
  7273. def get_inputs_that_alias_output(self) -> Sequence[str]:
  7274. return [
  7275. inp.get_name()
  7276. for inp in self.inputs
  7277. if isinstance(inp, FallbackKernel)
  7278. and len(inp.get_inputs_that_alias_output()) > 0
  7279. ]
  7280. # We just use a normal dataclass for MutableBox/TensorBox/StorageBox since
  7281. # they're mainly lowering-time constructs that we expect to mutate and such.
  7282. @dataclasses.dataclass
  7283. class MutableBox(IRNode):
  7284. """
  7285. TensorBox / StorageBox allow in-place mutation of Tensors
  7286. """
  7287. data: IRNode
  7288. def has_exceeded_max_reads(self) -> bool:
  7289. return self.data.has_exceeded_max_reads()
  7290. def get_device(self) -> Optional[torch.device]:
  7291. return self.data.get_device()
  7292. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  7293. return self.data.make_loader()
  7294. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  7295. return self.data.make_indexer()
  7296. def get_stride(self) -> Sequence[_IntLike]:
  7297. return self.data.get_stride()
  7298. def get_name(self) -> str:
  7299. return self.data.get_name()
  7300. def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
  7301. return self.data.has_large_inner_fn(threshold)
  7302. def mark_reuse(self, users: int) -> None:
  7303. return self.data.mark_reuse(users)
  7304. def realize_hint(self) -> None:
  7305. return self.data.realize_hint()
  7306. def unwrap_view(self) -> IRNode:
  7307. return self.data.unwrap_view()
  7308. def is_input_buffer(self) -> bool:
  7309. return self.data.is_input_buffer()
  7310. def freeze_layout(self) -> None:
  7311. return self.data.freeze_layout()
  7312. def freeze_layout_with_stride_order(
  7313. self, order: Sequence[int], allow_padding: bool = False
  7314. ) -> None:
  7315. return self.data.freeze_layout_with_stride_order(order, allow_padding)
  7316. def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
  7317. return self.data.freeze_layout_with_fill_order(order)
  7318. def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None:
  7319. return self.data.freeze_layout_with_same_order(stride)
  7320. def freeze_layout_with_exact_strides(
  7321. self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  7322. ) -> None:
  7323. return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
  7324. def get_read_writes(self) -> dependencies.ReadWrites:
  7325. return self.data.get_read_writes()
  7326. def get_reads(self) -> OrderedSet[Dep]:
  7327. return self.data.get_reads()
  7328. def num_reads(self) -> int:
  7329. return self.data.num_reads()
  7330. def get_storage_numel(self) -> _IntLike:
  7331. return self.data.get_storage_numel()
  7332. def get_reduction_type(self) -> Optional[str]:
  7333. return self.data.get_reduction_type()
  7334. def get_reduction_size(self) -> Sequence[Expr]:
  7335. return self.data.get_reduction_size()
  7336. def is_extern(self) -> bool:
  7337. return self.data.is_extern()
  7338. def is_no_op(self) -> bool:
  7339. return self.data.is_no_op()
  7340. def constant_to_device(self, device: torch.device) -> IRNode:
  7341. return self.data.constant_to_device(device)
  7342. def get_mutation_names(self) -> Sequence[str]:
  7343. return self.data.get_mutation_names()
  7344. def get_operation_name(self) -> str:
  7345. return self.data.get_operation_name()
  7346. def get_inputs_that_alias_output(self) -> Sequence[str]:
  7347. return self.data.get_inputs_that_alias_output()
  7348. def realize(self) -> Optional[str]:
  7349. return self.data.realize()
  7350. @cache_on_self_and_args("MutableBox")
  7351. def get_free_symbol_uses(
  7352. self, unbacked_only: bool = False
  7353. ) -> OrderedSet[sympy.Symbol]:
  7354. return self.data.get_free_symbol_uses(unbacked_only)
  7355. def get_read_names(self) -> OrderedSet[str]:
  7356. return self.data.get_read_names()
  7357. def get_defining_op(self) -> Optional[Operation]:
  7358. return self.data.get_defining_op()
  7359. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  7360. return self.data.codegen_reference(writer)
  7361. @property
  7362. def layout(self) -> OutputSpec:
  7363. # we intentionally call get_output_spec (rather than get_layout) since Buffer.layout is an OutputSpec
  7364. return self.data.get_output_spec()
  7365. def get_layout(self) -> Layout:
  7366. return self.data.get_layout()
  7367. def get_output_spec(self) -> OutputSpec:
  7368. return self.data.get_output_spec()
  7369. def get_size(self) -> Sequence[Expr]:
  7370. return self.data.get_size()
  7371. @property
  7372. def dtype(self) -> torch.dtype:
  7373. return self.data.dtype
  7374. def __str__(self) -> str:
  7375. if isinstance(self.data, MutableBox):
  7376. line0 = f"{type(self).__name__}({type(self.data).__name__}("
  7377. endl = "))"
  7378. inner = self.data.data
  7379. else:
  7380. line0 = f"{type(self).__name__}("
  7381. inner = self.data
  7382. endl = ")"
  7383. lines = [
  7384. line0,
  7385. indent(str(inner)),
  7386. endl,
  7387. ]
  7388. return "\n".join(lines)
  7389. __repr__ = __str__
  7390. class TensorBox(MutableBox):
  7391. @overload
  7392. @staticmethod
  7393. def create(data: ShapeAsConstantBuffer) -> ShapeAsConstantBuffer: ...
  7394. @overload
  7395. @staticmethod
  7396. def create(data: IRNode) -> TensorBox: ...
  7397. @staticmethod
  7398. def create(data: IRNode):
  7399. if isinstance(data, ShapeAsConstantBuffer):
  7400. return data
  7401. return TensorBox(StorageBox(data))
  7402. class StorageBox(MutableBox):
  7403. """
  7404. StorageBox allow in-place mutation of Tensors
  7405. """
  7406. def is_input_buffer(self) -> bool:
  7407. if isinstance(self.data, (InputBuffer, ReinterpretView)):
  7408. return self.data.get_name() in V.graph.graph_inputs
  7409. return False
  7410. def is_module_buffer(self) -> bool:
  7411. return (
  7412. isinstance(self.data, (ConstantBuffer))
  7413. and self.data.get_name() in V.graph.constants
  7414. )
  7415. def realize(self) -> Optional[str]:
  7416. if IRNode.is_realized_node(self.data):
  7417. return self.data.get_name()
  7418. assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type(
  7419. self.data
  7420. )
  7421. origin_node = self.data.get_origin_node()
  7422. traceback = self.data.get_traceback()
  7423. device = self.data.get_device()
  7424. assert device is not None
  7425. self.data = ComputedBuffer(
  7426. name=None,
  7427. layout=FlexibleLayout(
  7428. device=device,
  7429. dtype=self.data.get_dtype(),
  7430. size=self.data.get_size(),
  7431. is_pinned=False,
  7432. ),
  7433. data=self.data,
  7434. )
  7435. self.data.name = V.graph.register_buffer(self.data)
  7436. V.graph.register_operation(self.data)
  7437. self.data.origins = self.origins
  7438. self.data.origin_node = origin_node
  7439. self.data.traceback = traceback
  7440. return self.data.name
  7441. def realize_hint(self) -> None:
  7442. """
  7443. Called on buffers we expect to be forced to realize later.
  7444. """
  7445. if (
  7446. isinstance(self.data, (Pointwise, Reduction))
  7447. and self.data.inner_fn_opcount().nontrivial_read_count > 1
  7448. ):
  7449. self.realize()
  7450. def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
  7451. from torch._inductor.utils import is_nonfreeable_buffers
  7452. size_of_reads = [
  7453. V.graph.get_dep_size_hint(dep)
  7454. for dep in self.get_reads()
  7455. if not is_nonfreeable_buffers(dep)
  7456. ]
  7457. if not size_of_reads:
  7458. return False
  7459. total_size = sum(size_of_reads)
  7460. max_size = max(size_of_reads)
  7461. min_size = min(size_of_reads)
  7462. return (
  7463. total_size >= threshold
  7464. and total_size / max_size >= 2
  7465. and max_size == min_size
  7466. )
  7467. def has_exceeded_max_reads(self) -> bool:
  7468. return isinstance(self.data, Pointwise) and (
  7469. self.num_reads() > config.realize_acc_reads_threshold
  7470. or self.has_large_inner_fn()
  7471. or (
  7472. config.realize_acc_reads_size_threshold is not None
  7473. and self.has_accumulated_enough_reads_by_size(
  7474. config.realize_acc_reads_size_threshold
  7475. )
  7476. )
  7477. )
  7478. def should_realize_on_reuse(self, users: int) -> bool:
  7479. """
  7480. A heuristic to decide if we should realize a tensor
  7481. that is used multiple times.
  7482. """
  7483. if users > 1 and isinstance(self.data, (Pointwise, Reduction)):
  7484. if is_cpu(self.data):
  7485. # Heuristic for realizing reused result of heavy ops on cpu
  7486. opcount = self.data.inner_fn_opcount()
  7487. heavy_ops = ["exp", "sigmoid"] # a list of heavy ops
  7488. if any(x in opcount.used_ops for x in heavy_ops):
  7489. return True
  7490. return (
  7491. self.num_reads() > config.realize_reads_threshold
  7492. or self.has_large_inner_fn()
  7493. )
  7494. return False
  7495. def mark_reuse(self, users: int) -> None:
  7496. if self.should_realize_on_reuse(users):
  7497. self.realize()
  7498. def num_reads(self) -> int:
  7499. return self.data.num_reads()
  7500. @ir_dataclass(frozen=False)
  7501. class Subgraph(IRNode):
  7502. name: str
  7503. graph_module: torch.fx.GraphModule
  7504. graph: Optional[GraphLowering] = None
  7505. def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
  7506. buffers = [
  7507. buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
  7508. for buffer in buffers
  7509. ]
  7510. # assuming the same buffer is represented by the same IRNode object
  7511. return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers)
  7512. @ir_dataclass(frozen=False)
  7513. class InvokeSubgraph(ExternKernel):
  7514. """
  7515. Ir node for the invoke_subgraph HOP.
  7516. """
  7517. subgraph: Optional[Subgraph] = None
  7518. operands: Optional[Sequence[IRNode]] = None
  7519. outputs: Optional[Sequence[IRNode]] = None
  7520. def __init__(
  7521. self, subgraph: Subgraph, operands: Sequence[IRNode], layout: MultiOutputLayout
  7522. ) -> None:
  7523. super().__init__(
  7524. name=None,
  7525. layout=layout,
  7526. inputs=operands,
  7527. )
  7528. self.subgraph = subgraph
  7529. self.name = V.graph.register_buffer(self)
  7530. V.graph.register_operation(self)
  7531. @classmethod
  7532. def create(
  7533. cls, subgraph: Subgraph, *operands: IRNode
  7534. ) -> list[Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]]:
  7535. """For each operand, get a realized input, force it to have the same
  7536. strides as the subgraph inputs, then use an InvokeSubgraph"""
  7537. from .lowering import constrain_to_fake_tensor
  7538. # TODO(anijain2305) - Support sym expr as operands in future.
  7539. current_node = V.graph.current_node
  7540. fake_operands = None
  7541. if eager_input_vals := current_node.meta.get("eager_input_vals"):
  7542. # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph
  7543. offset = 2
  7544. if current_node.target is torch.ops.higher_order.with_effects:
  7545. # Aruguments eagerly are (token, subgraph, identifier, *operands)
  7546. assert current_node.args[1] is torch.ops.higher_order.invoke_subgraph
  7547. offset = 3
  7548. fake_operands = eager_input_vals[0][offset:]
  7549. else:
  7550. offset = 2
  7551. if current_node.target is torch.ops.higher_order.with_effects:
  7552. # with_effects args: (token, invoke_subgraph, subgraph, identifier, *operands)
  7553. assert current_node.args[1] is torch.ops.higher_order.invoke_subgraph
  7554. offset = 4
  7555. # For the partitioned backward graph, we do not have
  7556. # eager_input_vals. Here, we rely on the recorded example values.
  7557. fx_operands = current_node.args[offset:]
  7558. fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr]
  7559. # Realize the inputs. Also intermediates can have different strides than
  7560. # the inputs of the subgraph. So, force the intermediates to have same
  7561. # strides as that of subgraph inputs.
  7562. # pyrefly: ignore [annotation-mismatch, redefinition]
  7563. operands: list[IRNode] = [cls.realize_input(x) for x in operands]
  7564. new_operands: list[IRNode] = []
  7565. for idx, operand in enumerate(operands):
  7566. if isinstance(operand, (ShapeAsConstantBuffer, GeneratorState)):
  7567. new_operands.append(operand)
  7568. else:
  7569. new_operands.append(
  7570. constrain_to_fake_tensor(operand, fake_operands[idx])
  7571. )
  7572. # pyrefly: ignore [bad-assignment]
  7573. operands = new_operands
  7574. if subgraph.graph is None:
  7575. # create and lower subgraphs
  7576. subgraph.graph = V.graph.make_subgraph(
  7577. gm=subgraph.graph_module,
  7578. example_inputs=fake_operands,
  7579. subgraph_name=subgraph.name,
  7580. )
  7581. with V.set_graph_handler(subgraph.graph):
  7582. subgraph.graph.run(*fake_operands)
  7583. outputs = subgraph.graph.graph_outputs
  7584. # Find the device - operands could be integers from shapes, so we can't
  7585. # use operands[0]
  7586. device = None
  7587. for operand in operands:
  7588. if not isinstance(operand, ShapeAsConstantBuffer):
  7589. device = operand.get_device()
  7590. break
  7591. assert device is not None
  7592. invoke_subgraph = InvokeSubgraph(
  7593. subgraph=subgraph,
  7594. operands=operands,
  7595. layout=MultiOutputLayout(device=device),
  7596. )
  7597. def create_output(
  7598. output: IRNode, ind: int
  7599. ) -> Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]:
  7600. if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
  7601. return output
  7602. else:
  7603. device = output.get_device()
  7604. assert device is not None
  7605. return MultiOutput(
  7606. FixedLayout(
  7607. device=device,
  7608. dtype=output.get_dtype(),
  7609. size=output.get_size(),
  7610. stride=output.get_stride(),
  7611. offset=output.get_layout().offset,
  7612. is_pinned=output.get_layout().is_pinned,
  7613. ),
  7614. invoke_subgraph, # type: ignore[has-type]
  7615. [(list, ind)],
  7616. skip_size_stride_alignment_checks=True,
  7617. )
  7618. outs = [create_output(output, i) for i, output in enumerate(outputs)]
  7619. invoke_subgraph.outputs = outs # type: ignore[assignment]
  7620. return outs
  7621. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7622. wrapper.codegen_invoke_subgraph(self)
  7623. @ir_dataclass(frozen=False)
  7624. class Conditional(ExternKernel):
  7625. """
  7626. IR node representing torch.cond
  7627. Attributes:
  7628. predicate: A boolean scalar tensor determining which branch to execute.
  7629. operands: Input tensors passed to both true and false subgraphs.
  7630. true_subgraph: Subgraph executed when predicate is True.
  7631. false_subgraph: Subgraph executed when predicate is False.
  7632. outputs: MultiOutput nodes representing the conditional's outputs.
  7633. """
  7634. predicate: Optional[IRNode] = None
  7635. operands: Optional[Sequence[IRNode]] = None
  7636. true_subgraph: Optional[Subgraph] = None
  7637. false_subgraph: Optional[Subgraph] = None
  7638. outputs: Optional[Sequence[MultiOutput]] = None
  7639. def __init__(
  7640. self,
  7641. predicate: IRNode,
  7642. operands: Sequence[IRNode],
  7643. true_subgraph: Subgraph,
  7644. false_subgraph: Subgraph,
  7645. layout: MultiOutputLayout,
  7646. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
  7647. ) -> None:
  7648. self.predicate = predicate
  7649. self.operands = operands
  7650. self.true_subgraph = true_subgraph
  7651. self.false_subgraph = false_subgraph
  7652. sym_args, tensor_args = _split_by_sym_type([predicate, *operands])
  7653. super().__init__(
  7654. name=None,
  7655. layout=layout,
  7656. inputs=tensor_args,
  7657. constant_args=sym_args,
  7658. )
  7659. if unbacked_bindings is not None:
  7660. self.unbacked_bindings = unbacked_bindings
  7661. self.name = V.graph.register_buffer(self)
  7662. V.graph.register_operation(self)
  7663. @staticmethod
  7664. def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
  7665. if isinstance(s, int):
  7666. return s
  7667. return s.node.expr
  7668. @classmethod
  7669. def create(
  7670. cls,
  7671. predicate: TensorBox,
  7672. true_fn: Subgraph,
  7673. false_fn: Subgraph,
  7674. operands: list[TensorBox],
  7675. ) -> list[MultiOutput]:
  7676. """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)"""
  7677. # pyrefly: ignore [bad-assignment]
  7678. predicate = cls.realize_input(predicate)
  7679. # pyrefly: ignore [bad-assignment]
  7680. operands = [cls.realize_input(x) for x in operands]
  7681. fx_operands: Argument = V.graph.current_node.args[-1]
  7682. assert isinstance(fx_operands, Sequence), type(fx_operands)
  7683. assert all(isinstance(n, Node) for n in fx_operands)
  7684. fake_operands = [cast(Node, x).meta["val"] for x in fx_operands]
  7685. fake_outputs = V.graph.current_node.meta["val"]
  7686. def _require_exact_strides(
  7687. graph_outputs: Sequence[IRNode],
  7688. fake_tensors: Sequence[torch.Tensor],
  7689. ) -> list[IRNode]:
  7690. ret = []
  7691. for output, fake in zip(graph_outputs, fake_tensors):
  7692. if isinstance(output, ShapeAsConstantBuffer):
  7693. ret.append(output)
  7694. else:
  7695. ret.append(
  7696. # pyrefly: ignore [bad-argument-type]
  7697. ExternKernel.require_exact_strides(
  7698. TensorBox(output), fake.stride(), allow_padding=False
  7699. )
  7700. )
  7701. # pyrefly: ignore [bad-return]
  7702. return ret
  7703. for subgraph in (true_fn, false_fn):
  7704. if subgraph.graph is None:
  7705. # create and lower subgraphs
  7706. subgraph.graph = V.graph.make_subgraph(
  7707. gm=subgraph.graph_module,
  7708. example_inputs=fake_operands,
  7709. subgraph_name=subgraph.name,
  7710. )
  7711. with V.set_graph_handler(subgraph.graph):
  7712. subgraph.graph.run(*fake_operands)
  7713. # Force subgraph outputs to have the expected strides from
  7714. # FakeTensor metadata. This ensures both branches produce
  7715. # outputs with consistent strides.
  7716. subgraph.graph.graph_outputs = _require_exact_strides(
  7717. subgraph.graph.graph_outputs, fake_outputs
  7718. )
  7719. assert true_fn.graph is not None
  7720. assert false_fn.graph is not None
  7721. true_outputs = true_fn.graph.graph_outputs
  7722. false_outputs = false_fn.graph.graph_outputs
  7723. for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
  7724. if _has_aliased_buffers(true_outputs):
  7725. raise AssertionError(
  7726. "Output aliasing is currently not supported in compiled torch.cond. "
  7727. f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
  7728. )
  7729. # make sure true and false outputs are structurally equivalent
  7730. assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
  7731. for i, (t_o, f_o) in enumerate(zip(true_outputs, false_outputs)):
  7732. assert t_o.get_device() == f_o.get_device(), (i, t_o, f_o)
  7733. assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o)
  7734. assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o)
  7735. # Determine device from operands and predicate
  7736. # The predicate can be on a different device (e.g., CPU for control flow)
  7737. # while the data operands and outputs should be on the compute device, so
  7738. # using predicate device as a fallback.
  7739. device = next(
  7740. o.get_device()
  7741. for o in operands + [predicate]
  7742. if not isinstance(o, ShapeAsConstantBuffer)
  7743. )
  7744. unbacked_bindings = resolve_unbacked_bindings(
  7745. V.graph.sizevars.shape_env,
  7746. V.graph.current_node.meta.get("unbacked_bindings", None),
  7747. )
  7748. assert device is not None, "cannot determine device"
  7749. conditional = Conditional(
  7750. predicate=predicate,
  7751. operands=operands,
  7752. true_subgraph=true_fn,
  7753. false_subgraph=false_fn,
  7754. layout=MultiOutputLayout(device=device),
  7755. unbacked_bindings=unbacked_bindings,
  7756. )
  7757. outputs = [
  7758. MultiOutput(
  7759. FixedLayout(
  7760. # pyrefly: ignore [bad-argument-type]
  7761. device=output.get_device()
  7762. if output.get_device() is not None
  7763. else device, # type: ignore[arg-type]
  7764. dtype=output.get_dtype(),
  7765. size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
  7766. stride=[
  7767. Conditional._maybe_expr(sz) for sz in merged_output.stride()
  7768. ],
  7769. offset=output.get_layout().offset,
  7770. is_pinned=output.get_layout().is_pinned,
  7771. ),
  7772. conditional,
  7773. [(list, i)],
  7774. )
  7775. # as the true and false outputs are equivalent,
  7776. # we can use either of them here as a "template"
  7777. for i, (output, merged_output) in enumerate(
  7778. zip(true_outputs, V.graph.current_node.meta["val"])
  7779. )
  7780. ]
  7781. conditional.outputs = outputs # type: ignore[assignment]
  7782. return outputs
  def codegen(self, wrapper: PythonWrapperCodegen) -> None:
      """Emit wrapper code for this conditional.

      First the conditional itself is generated, then definitions for any
      unbacked symbols bound by its outputs (an empty mapping is used when this
      node has no ``unbacked_bindings`` attribute).
      """
      wrapper.codegen_conditional(self)
      bindings = getattr(self, "unbacked_bindings", {})
      wrapper.codegen_unbacked_symbol_defs_for_outputs(
          self.get_name(), self.outputs, bindings
      )
  def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
      """Return the unbacked symbols this node defines.

      The node's ``unbacked_bindings`` (if present and non-empty) are resolved
      against the current shape env and their keys are returned; otherwise an
      empty set is returned.
      """
      bindings = getattr(self, "unbacked_bindings", None)
      if not bindings:
          return OrderedSet()
      resolved = resolve_unbacked_bindings(V.graph.sizevars.shape_env, bindings)
      assert resolved is not None
      return OrderedSet(resolved.keys())
  def _split_by_sym_type(
      args: list[Any],
  ) -> tuple[list[sympy.Expr], list[Any]]:
      """Partition ``args`` into symbolic expressions and everything else.

      For each ``ShapeAsConstantBuffer`` in ``args`` its ``.expr`` (not the
      buffer object itself) is appended to the first list; every other argument
      is appended, unchanged, to the second list. Relative order is preserved
      within each list.

      Returns:
          ``(sym_exprs, non_sym_args)``.
      """
      sym_exprs: list[sympy.Expr] = []
      non_sym_args: list[Any] = []
      for arg in args:
          if isinstance(arg, ShapeAsConstantBuffer):
              # Only the underlying expression is kept (WhileLoop forwards these
              # as constant args rather than tensor inputs).
              sym_exprs.append(arg.expr)
          else:
              non_sym_args.append(arg)
      return sym_exprs, non_sym_args
  @ir_dataclass(frozen=False)
  class WhileLoop(ExternKernel):
      """The IR node for while_loop and while_loop_stack_output. It supports input mutation."""

      carried_inputs: Optional[Sequence[IRNode]] = None
      additional_inputs: Optional[Sequence[IRNode]] = None
      cond_subgraph: Optional[Subgraph] = None
      body_subgraph: Optional[Subgraph] = None
      outputs: Optional[Sequence[MultiOutput]] = None

      def __init__(
          self,
          carried_inputs: Sequence[IRNode],
          additional_inputs: Sequence[IRNode],
          cond_subgraph: Subgraph,
          body_subgraph: Subgraph,
          layout: MultiOutputLayout,
          unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
          stack_output: bool,
      ) -> None:
          """Build the node and register it (as buffer and operation) with V.graph.

          ``ShapeAsConstantBuffer`` inputs are handed to ``ExternKernel`` as
          constant args (their expressions); all other inputs become tensor
          inputs.
          """
          self.carried_inputs = carried_inputs
          self.additional_inputs = additional_inputs
          self.cond_subgraph = cond_subgraph
          self.body_subgraph = body_subgraph

          sym_args, tensor_args = _split_by_sym_type(
              [*carried_inputs, *additional_inputs]
          )
          super().__init__(
              name=None,
              layout=layout,
              inputs=tensor_args,
              constant_args=sym_args,
          )
          if unbacked_bindings is not None:
              self.unbacked_bindings = unbacked_bindings
          self.stack_output = stack_output

          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      # Accidental aliasing can be created by CSE: the empty buffers we allocate
      # for backward to use can get CSE'd into the same buffer by fx_graph_cse.
      # See test_scan_multiple_layers_gradient for a concrete example.
      @staticmethod
      def _clone_aliased_inputs(carried_inputs: Sequence[IRNode]) -> Sequence[IRNode]:
          """Copy every carried input whose underlying buffer was already seen.

          The first occurrence of a buffer is kept as-is; each later occurrence
          is replaced by ``ExternKernel.copy_input`` of that input. When
          ``_has_aliased_buffers`` reports no aliasing, ``carried_inputs`` is
          returned unchanged.
          """
          if not _has_aliased_buffers(carried_inputs):
              return carried_inputs

          # Unwrap views to get the underlying buffers for comparison
          unwrapped_buffers = [
              buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
              for buffer in carried_inputs
          ]

          # ids of the underlying buffers encountered so far
          seen_buffers: OrderedSet[int] = OrderedSet()
          result: list[Union[IRNode, TensorBox]] = []
          for original_input, unwrapped_buffer in zip(carried_inputs, unwrapped_buffers):
              if id(unwrapped_buffer) in seen_buffers:
                  result.append(ExternKernel.copy_input(original_input))
              else:
                  seen_buffers.add(id(unwrapped_buffer))
                  result.append(original_input)
          return result

      @staticmethod
      def _maybe_wrap_as_tensor_box(out: IRNode) -> IRNode:
          """Return ``out`` as a TensorBox, wrapping StorageBox/ReinterpretView/MultiOutput.

          Raises:
              RuntimeError: for any other IR node type.
          """
          if isinstance(out, TensorBox):
              return out
          elif isinstance(out, (StorageBox, ReinterpretView)):
              return TensorBox(out)
          elif isinstance(out, MultiOutput):
              return TensorBox.create(out)
          else:
              raise RuntimeError(f"NYI unsupported output type: {type(out)}")

      @classmethod
      def create(
          cls,
          cond_fn: Subgraph,
          body_fn: Subgraph,
          carried_inputs: Sequence[IRNode],
          additional_inputs: Sequence[IRNode],
          stack_output: bool,
      ) -> Union[IRNode, Sequence[IRNode]]:
          """Create the while_loop IR node and its outputs.

          Args:
              cond_fn: predicate subgraph; must return a single boolean scalar.
              body_fn: loop body subgraph; must return one output per carried
                  input, with matching size, stride, device and dtype.
              carried_inputs: values threaded from one iteration to the next.
              additional_inputs: extra operands passed to every iteration but
                  not returned by the body.
              stack_output: whether each iteration's output is stacked, which
                  is necessary for training. Input mutation is not supported
                  in this mode.

          Returns:
              One entry per output: a ``MultiOutput`` for regular outputs, or
              the mutated carried input itself for mutated carries.
          """
          from torch._higher_order_ops.utils import check_input_alias_and_mutation

          def _require_exact_strides(
              tensor_boxes: Sequence[IRNode],
              fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]],
          ) -> list[IRNode]:
              """Force each tensor entry to the stride of its fake tensor; pass non-tensors through."""
              assert len(tensor_boxes) == len(fake_tensors)
              ret = []
              for tb, fk in zip(tensor_boxes, fake_tensors):
                  if isinstance(fk, torch.Tensor):
                      # Subgraph lowering always return StorageBox as graph_outputs because
                      # it realizes the outputs.
                      #
                      # However, require_exact_strides is expecting TensorBox
                      # e.g. in require_exact_strides when an expand happens,
                      # the fake tensor's stride is (0, 0, 0) but the storage
                      # box might have a different stride so lowering.slice_
                      # is used to make the stride consistent and it expects input to
                      # be TensorBox.
                      #
                      # So we wrap the inputs as tensor boxes if they're not yet.
                      new_tb = WhileLoop._maybe_wrap_as_tensor_box(tb)
                      ret.append(
                          ExternKernel.require_exact_strides(
                              new_tb, fk.stride(), allow_padding=False
                          )
                      )
                  else:
                      ret.append(tb)
              return ret

          def _guard_list_equals(
              lhs_exprs: Sequence[Union[int, sympy.Expr]],
              rhs_exprs: Sequence[Union[int, sympy.Expr]],
          ) -> None:
              """Check element-wise equality of two size/stride lists via sizevars."""
              assert len(lhs_exprs) == len(rhs_exprs)
              for lhs, rhs in zip(lhs_exprs, rhs_exprs):
                  V.graph.sizevars.check_equals(lhs, rhs)

          fx_carried_inputs = V.graph.current_node.args[-2]
          fx_additional_inputs = V.graph.current_node.args[-1]
          fx_all_inputs = fx_carried_inputs + fx_additional_inputs  # type: ignore[operator]
          fake_all_inputs = [x.meta["val"] for x in fx_all_inputs]  # type: ignore[union-attr]
          fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs]  # type: ignore[union-attr]
          fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs]  # type: ignore[union-attr]

          carried_inputs_ = [cls.realize_input(x) for x in carried_inputs]
          carried_inputs_ = WhileLoop._clone_aliased_inputs(carried_inputs_)
          carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs)
          additional_inputs_ = [cls.realize_input(x) for x in additional_inputs]
          additional_inputs_ = _require_exact_strides(
              additional_inputs_, fake_additional_inputs
          )
          all_inputs = carried_inputs_ + additional_inputs_

          for subgraph in (cond_fn, body_fn):
              if subgraph.graph is None:
                  # create and lower subgraphs
                  assert isinstance(fx_all_inputs, Sequence), type(fx_all_inputs)
                  subgraph.graph = V.graph.make_subgraph(
                      gm=subgraph.graph_module,
                      example_inputs=fx_all_inputs,  # type: ignore[arg-type]
                      subgraph_name=subgraph.name,
                  )
                  with V.set_graph_handler(subgraph.graph):
                      subgraph.graph.run(*fake_all_inputs)
                      # For body_fn, we require its output to have the exact same stride
                      # as inputs because the previous output is the input of next iteration.
                      #
                      # This cannot be automatically done in graph lowering because body_fn's graph outputs
                      # are not user-facing so the special handling for strides of user-facing output in graph
                      # lowering is not applicable.
                      if subgraph is body_fn:
                          assert len(subgraph.graph.graph_outputs) == len(
                              fake_carried_inputs
                          )
                          subgraph.graph.graph_outputs = _require_exact_strides(  # type: ignore[assignment]
                              subgraph.graph.graph_outputs,
                              fake_carried_inputs,
                          )

          assert cond_fn.graph and body_fn.graph
          cond_outputs = cond_fn.graph.graph_outputs
          body_outputs = body_fn.graph.graph_outputs

          if _has_aliased_buffers(body_outputs):
              raise AssertionError(
                  "Output aliasing is currently not supported in compiled torch.while_loop. "
                  f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}"
              )

          # make sure cond_fn returns a boolean scalar Tensor
          assert len(cond_outputs) == 1, cond_outputs
          p = cond_outputs[0]
          if not isinstance(p, ShapeAsConstantBuffer):
              assert p.get_dtype() == torch.bool, p
              assert len(p.get_size()) == 0, p

          assert len(all_inputs) > 0, (
              "torch.while_loop is assumed to have at least one operand."
          )

          device = all_inputs[0].get_device()
          assert device is not None  # to make linter happy

          # make sure carried_inputs_ and body outputs are structurally equivalent
          assert len(carried_inputs_) == len(body_outputs), (
              carried_inputs_,
              body_outputs,
          )
          for i, (op, bo) in enumerate(zip(carried_inputs_, body_outputs)):
              _guard_list_equals(op.get_size(), bo.get_size())
              _guard_list_equals(op.get_stride(), bo.get_stride())
              # assume all carried_inputs_ and outputs are on the same device
              # as the MultiOutputLayout below requires single device
              assert op.get_device() == bo.get_device(), (i, op, bo, device)
              assert op.get_dtype() == bo.get_dtype(), (i, op, bo)

          unbacked_bindings = resolve_unbacked_bindings(
              V.graph.sizevars.shape_env,
              V.graph.current_node.meta.get("unbacked_bindings", None),
          )

          while_loop = WhileLoop(
              carried_inputs=carried_inputs_,
              additional_inputs=additional_inputs_,
              cond_subgraph=cond_fn,
              body_subgraph=body_fn,
              # asserted above that there is at least one operand
              layout=MultiOutputLayout(device=device),
              unbacked_bindings=unbacked_bindings,
              stack_output=stack_output,
          )

          assert body_fn.graph is not None and isinstance(
              body_fn.graph.module, torch.fx.GraphModule
          )  # to make linter happy

          # Handling input mutations
          mutated_idxs = check_input_alias_and_mutation(
              body_fn.graph.module, fake_all_inputs
          )[3]
          mutated_idx_set = OrderedSet(mutated_idxs)

          # Create all outputs first
          all_outputs: list[IRNode] = []
          while_loop.outputs = []
          while_loop.mutation_outputs = []
          if stack_output:
              assert len(mutated_idx_set) == 0, (
                  "NYI: while_loop_stack_output input mutations."
              )
              for idx, output in enumerate(V.graph.current_node.meta["val"]):
                  # Create MultiOutput for regular outputs
                  multi_out = MultiOutput(
                      FixedLayout(
                          device=output.device,  # type: ignore[arg-type]
                          dtype=output.dtype,
                          size=[Conditional._maybe_expr(sz) for sz in output.size()],
                          stride=[Conditional._maybe_expr(st) for st in output.stride()],
                      ),
                      while_loop,
                      [(list, idx)],
                  )
                  while_loop.outputs.append(multi_out)
                  all_outputs.append(multi_out)
          else:
              # Only carried inputs have a body-output slot that can become a
              # MutationOutput below. The loop below only visits indices
              # < len(body_outputs) == len(carried_inputs), so a mutated
              # additional input would otherwise be dropped silently; reject it
              # up front instead.
              assert all(idx < len(carried_inputs) for idx in mutated_idx_set), (
                  "only carries can be mutated."
              )
              for idx, output in enumerate(body_outputs):
                  if idx in mutated_idx_set:
                      # Create MutationOutput for mutated inputs. Index directly
                      # so the pairing does not depend on the iteration order of
                      # mutated_idx_set.
                      mutated_input = all_inputs[idx]
                      while_loop.mutation_outputs.append(
                          MutationOutput(mutated_input.layout, mutated_input, while_loop)  # type: ignore[attr-defined, union-attr]
                      )
                      all_outputs.append(mutated_input)
                  else:
                      multi_out = MultiOutput(
                          FixedLayout(
                              device=output.get_device(),  # type: ignore[arg-type]
                              dtype=output.get_dtype(),
                              size=output.get_size(),
                              stride=output.get_stride(),
                              offset=output.get_layout().offset,
                          ),
                          while_loop,
                          [(list, idx)],
                      )
                      while_loop.outputs.append(multi_out)
                      all_outputs.append(multi_out)

          for inp, out in zip(carried_inputs, all_outputs):
              if inp.get_name() in V.graph.graph_inputs:
                  # if a carried input of the while_loop is a graph input,
                  # it can be returned as is when the number of iterations
                  # is zero. due to this, we can't (generally) reuse the
                  # output buffers corresponding to the graph inputs, as
                  # the inputs may end up being mutated.
                  V.graph.never_reuse_buffers.add(out.get_name())
          return all_outputs

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the loop, then definitions for unbacked symbols bound by its outputs."""
          wrapper.codegen_while_loop(self, self.stack_output)
          wrapper.codegen_unbacked_symbol_defs_for_outputs(
              self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
          )

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """Return the keys of the resolved ``unbacked_bindings``, or an empty set."""
          if unbacked_bindings := getattr(self, "unbacked_bindings", None):
              resolved = resolve_unbacked_bindings(
                  V.graph.sizevars.shape_env, unbacked_bindings
              )
              assert resolved is not None
              return OrderedSet(resolved.keys())
          else:
              return OrderedSet()
  class EffectfulKernel(FallbackKernel):
      """A FallbackKernel for an op that has a registered effect type.

      Kernels sharing an effect type are chained: on construction each kernel
      remembers the previous kernel registered for its effect type in
      ``V.graph.effectful_ops`` and then replaces it there. The remembered
      kernel is added as a StarDep read in ``get_read_writes``.
      """

      def __init__(
          self,
          layout: OutputSpec,
          kernel: _OpOverloads,
          tensor_args: Sequence[IRNode],
          nontensor_args: Sequence[Any],
          unflatten_args: Callable[..., Any],
          kwargs: Optional[dict[str, Any]] = None,
          *,
          unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
      ) -> None:
          # NOTE(review): `kwargs` is accepted but not forwarded (the base class
          # always receives kwargs=None); confirm this is intentional.
          super().__init__(
              layout,
              kernel,
              tensor_args,
              nontensor_args,
              unflatten_args,
              kwargs=None,
              unbacked_bindings=unbacked_bindings,
          )

          from torch._higher_order_ops.effects import _get_effect

          effect = _get_effect(kernel)
          assert effect is not None
          self.effect_type = effect
          # Link to whichever kernel last claimed this effect type, then claim it.
          effectful_ops = V.graph.effectful_ops
          self.prev_effect_buffer = effectful_ops.get(effect, None)
          effectful_ops[effect] = self

      def get_read_writes(self) -> dependencies.ReadWrites:
          """Base read/writes plus a StarDep on the previous same-effect kernel, if any."""
          rw = super().get_read_writes()
          previous = self.prev_effect_buffer
          if previous is not None:
              rw.reads.add(dependencies.StarDep(previous.get_name()))
          return rw

      def has_side_effects(self) -> bool:
          return True
  class NonTensorObj(IRNode):
      """Base IR node for values that are not tensors.

      Its free-symbol query always reports an empty set.
      """

      @cache_on_self_and_args("NonTensorObj")
      def get_free_symbol_uses(
          self, unbacked_only: bool = False
      ) -> OrderedSet[sympy.Symbol]:
          no_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
          return no_symbols
  @ir_dataclass
  class TorchBindObject(NonTensorObj):
      """IR node wrapping a TorchBind script object (real or fake)."""

      name: str
      value: Union[FakeScriptObject, torch.ScriptObject]

      def get_name(self) -> str:
          return self.name

      def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
          # The object is referenced in generated code by its name.
          return self.name

      def get_value(self) -> Union[FakeScriptObject, torch.ScriptObject]:
          return self.value

      def get_real_obj(self) -> torch.ScriptObject:
          """Return the real ScriptObject, unwrapping a FakeScriptObject via ``real_obj``."""
          wrapped = self.value
          return wrapped if isinstance(wrapped, torch.ScriptObject) else wrapped.real_obj

      def get_buf_bytes(self) -> int:
          """Total bytes (``element_size() * numel()``) of all tensor leaves.

          The leaves come from flattening ``__obj_flatten__()`` of the real
          object; non-tensor leaves are ignored. Opaque types report 0.
          """
          script_obj = self.get_real_obj()
          if is_opaque_type(script_obj):
              return 0
          assert hasattr(script_obj, "__obj_flatten__")
          leaves, _spec = pytree.tree_flatten(dict(script_obj.__obj_flatten__()))
          return sum(
              leaf.element_size() * leaf.numel()
              for leaf in leaves
              if isinstance(leaf, torch.Tensor)
          )
  @ir_dataclass
  class GeneratorState(NonTensorObj):
      """IR node for a generator-state value, identified by name and device."""

      name: str
      device: torch.device

      def get_name(self) -> str:
          return self.name

      def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
          # Generated code refers to the state purely by its name.
          return self.get_name()
  class _CollectiveKernel(FallbackKernel):
      """Base FallbackKernel for functional collectives.

      It never allocates its own output (``should_allocate`` is False) and
      always reports side effects.
      """

      def should_allocate(self) -> bool:
          return False

      def has_side_effects(self) -> bool:
          return True

      # This is identical to FallbackKernel.set_cpp_kernel(), minus the
      # part that checks against input aliasing and mutation.
      def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
          """Set the C++ kernel name (defaulting to the schema name) and kwarg-only arg order."""
          assert type(self.op_overload) is torch._ops.OpOverload, (
              "Setting cpp kernel needs a valid op_overload"
          )
          schema = self.op_overload._schema
          self.cpp_kernel_name = (
              schema.name if cpp_kernel_name is None else cpp_kernel_name
          )
          self.ordered_kwargs_for_cpp_kernel = [
              argument.name for argument in schema.arguments if argument.kwarg_only
          ]

      # NOTE: [In-Place Collective Safety]
      # Between the initiation and completion of an in-place collective, the
      # input buffers are subject to both volatile reads and volatile writes.
      # They must not be read, written to or reused by another kernel. To ensure
      # these constraints, we model collective -> wait_tensor as a two-step
      # mutation of the input buffers.
      @classmethod
      def create_inplace(
          cls,
          kernel: _OpOverloads,
          inputs: Union[IRNode, list[IRNode]],
          *args: Any,
          **kwargs: Any,
      ) -> None:
          """Create an in-place collective that mutates (and aliases) its inputs and any ``out=``."""
          with V.graph.fake_mode:
              (
                  _example_output,
                  tensor_args,
                  non_tensor_args,
                  unflatten_args,
                  unbacked_bindings,
              ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
          assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
          for arg in tensor_args:
              arg.realize()
              V.graph.mark_buffer_mutated(arg.get_name())

          device = tensor_args[0].get_device()
          packed = cls(
              NoneLayout(device=device),
              kernel,
              tensor_args,
              non_tensor_args,
              unflatten_args,
          )

          flat_inputs = pytree.tree_leaves(inputs)
          packed.mutation_outputs.extend(
              [
                  MutationOutput(NoneLayout(device=device), leaf, packed)
                  for leaf in flat_inputs
              ]
          )
          # For inplace collective ops, the input is guaranteed to be alias of the returned value of op.
          packed.alias_names.extend([leaf.get_name() for leaf in flat_inputs])

          if "out" in kwargs:
              out_buf = kwargs["out"]
              packed.mutation_outputs.append(
                  MutationOutput(NoneLayout(device=device), out_buf, packed)
              )
              # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op.
              packed.alias_names.append(out_buf.get_name())

      # NOTE: [Out-of-Place Collective Safety]
      # Between the initiation and completion of an out-of-place collective:
      #
      # Input buffers:
      # - Are subject to volatile reads
      # - Can be read by another kernel
      # - Must not be written to or reused by another kernel
      #
      # Output buffers:
      # - Are subject to volatile writes
      # - Must not be read, written to or reused by another kernel
      #
      # To ensure the safety of input buffers without sacrificing read
      # availability, we add input buffers as read deps of wait_tensor kernels.
      #
      # To ensure the safety of output buffers, we model wait_tensor as a
      # mutation to the output buffer. Note we also assume the user program is
      # correct and the output buffer is not consumed by kernels other than
      # wait_tensor.
      #
      # TODO(yifu): add a pre-grad pass to validate the correctness of collective
      # usage in the user program.
      @classmethod
      def create_out_of_place(
          cls,
          kernel: _OpOverloads,
          inputs: Union[TensorBox, list[TensorBox]],
          *args: Any,
          **kwargs: Any,
      ) -> Union[list[MultiOutput], _CollectiveKernel]:
          """Create an out-of-place collective.

          Returns the kernel itself for a single-tensor result, or a list of
          ``MultiOutput`` nodes when the example output is a list. Outputs that
          may be unaligned are recorded in ``V.graph.unaligned_buffers``.
          """
          with V.graph.fake_mode:
              (
                  example_output,
                  tensor_args,
                  non_tensor_args,
                  unflatten_args,
                  unbacked_bindings,
              ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
          assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}"
          for arg in tensor_args:
              if not isinstance(arg, TorchBindObject):
                  arg.realize()

          def _maybe_unaligned(tensor: Any) -> bool:
              return config.assume_unaligned_fallback_output or not tensor_is_aligned(
                  tensor
              )

          if not isinstance(example_output, list):
              # Single output: the kernel node is its own output buffer.
              single = cls(
                  cls.tensor_to_layout(example_output),
                  kernel,
                  tensor_args,
                  non_tensor_args,
                  unflatten_args,
              )
              if _maybe_unaligned(example_output):
                  V.graph.unaligned_buffers.add(single.name)  # type: ignore[arg-type]
              single.outputs = [single]
              return single

          device = cls.find_device(tensor_args, example_output)
          assert device is not None
          packed = cls(
              MultiOutputLayout(device=device),
              kernel,
              tensor_args,
              non_tensor_args,
              unflatten_args,
          )
          packed.outputs = [
              MultiOutput(cls.tensor_to_layout(tensor), packed, [(list, pos)])
              for pos, tensor in enumerate(example_output)
          ]
          for buf, tensor in zip(packed.outputs, example_output):
              if _maybe_unaligned(tensor):
                  V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
          return packed.outputs
  class _AllReduce_Kernel(_CollectiveKernel):
      """All-reduce collective bound to ``aoti_torch_cpu__c10d_functional_all_reduce_``.

      NOTE(review): the trailing underscore presumably marks the in-place
      variant; confirm against the shim.
      """

      def __init__(
          self,
          layout: OutputSpec,
          kernel: _OpOverloads,
          tensor_args: Sequence[IRNode],
          nontensor_args: Sequence[Any],
          unflatten_args: Callable[..., Any],
          kwargs: Optional[dict[str, Any]] = None,
          *,
          unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
      ) -> None:
          # `kwargs` is not forwarded; the base class always receives None.
          super().__init__(
              layout,
              kernel,
              tensor_args,
              nontensor_args,
              unflatten_args,
              kwargs=None,
              unbacked_bindings=unbacked_bindings,
          )
          self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_")

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the call (pulling in the CPU shim header) plus size asserts for a concrete layout."""
          wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
          wrapper.generate_extern_kernel_alloc(self)
          if not isinstance(self.layout, Layout):
              return
          self.codegen_size_asserts(wrapper)
  class _AllReduceKernel(_CollectiveKernel):
      """All-reduce collective bound to ``aoti_torch_cpu__c10d_functional_all_reduce``."""

      def __init__(
          self,
          layout: OutputSpec,
          kernel: _OpOverloads,
          tensor_args: Sequence[IRNode],
          nontensor_args: Sequence[Any],
          unflatten_args: Callable[..., Any],
          kwargs: Optional[dict[str, Any]] = None,
          *,
          unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
      ) -> None:
          # `kwargs` is not forwarded; the base class always receives None.
          super().__init__(
              layout,
              kernel,
              tensor_args,
              nontensor_args,
              unflatten_args,
              kwargs=None,
              unbacked_bindings=unbacked_bindings,
          )
          self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce")

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the call (pulling in the CPU shim header) plus size asserts for a concrete layout."""
          wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
          wrapper.generate_extern_kernel_alloc(self)
          if not isinstance(self.layout, Layout):
              return
          self.codegen_size_asserts(wrapper)
  class _WaitKernel(_CollectiveKernel):
      """wait_tensor kernel bound to ``aoti_torch_cpu__c10d_functional_wait_tensor``."""

      def __init__(
          self,
          layout: OutputSpec,
          kernel: _OpOverloads,
          tensor_args: Sequence[IRNode],
          nontensor_args: Sequence[Any],
          unflatten_args: Callable[..., Any],
          kwargs: Optional[dict[str, Any]] = None,
          *,
          unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
      ) -> None:
          # `kwargs` is not forwarded; the base class always receives None.
          super().__init__(
              layout,
              kernel,
              tensor_args,
              nontensor_args,
              unflatten_args,
              kwargs=None,
              unbacked_bindings=unbacked_bindings,
          )
          self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor")

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the call (pulling in the CPU shim header) plus size asserts for a concrete layout."""
          wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
          wrapper.generate_extern_kernel_alloc(self)
          if not isinstance(self.layout, Layout):
              return
          self.codegen_size_asserts(wrapper)

      def get_volatile_reads(self) -> Sequence[IRNode]:
          """Return the collective inputs this wait must keep alive as reads.

          Only out-of-place collectives contribute volatile reads; in-place
          collectives are already modeled as mutations of their inputs.
          """
          waited_on = self.inputs[0]
          assert isinstance(waited_on, IRNode)

          if isinstance(waited_on, _CollectiveKernel):
              # Out-of-place single-output
              source = waited_on.inputs[0]
              assert isinstance(source, IRNode), type(source)
              return [source]

          if isinstance(waited_on, MultiOutput):
              # Either an out-of-place multi-output collective, or an in-place
              # collective whose inputs come from another MultiOutput.
              producer = waited_on.inputs[0]
              if isinstance(producer, _CollectiveKernel):
                  # Out-of-place multi-output: read the matching collective input.
                  _, idx = waited_on.indices[0]
                  return [producer.inputs[idx]]
              return []

          # In-place requires no additional deps handling for volatile
          # reads since the inputs are mutated.
          return []

      @classmethod
      def create_wait(cls, kernel: _OpOverloads, inp: TensorBox) -> None:
          """Create a wait kernel modeled as a mutation of ``inp``."""
          with V.graph.fake_mode:
              (
                  _example_output,
                  tensor_args,
                  non_tensor_args,
                  unflatten_args,
                  unbacked_bindings,
              ) = cls.process_kernel(kernel, inp)
          assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
          wait = cls(
              NoneLayout(device=inp.get_device()),
              kernel,
              tensor_args,
              non_tensor_args,
              unflatten_args,
          )
          wait.mutation_outputs.append(
              MutationOutput(NoneLayout(device=inp.get_device()), inp, wait)
          )

      def get_read_writes(self) -> dependencies.ReadWrites:
          """Base read/writes plus a StarDep read for every volatile read."""
          rw = super().get_read_writes()
          # See [Out-of-Place Collective Safety].
          for node in self.get_volatile_reads():
              rw.reads.add(dependencies.StarDep(node.get_name()))
          return rw
  # NB: recursive structure here reflects val_to_arg_str, avoid
  # calling free_unbacked_symbols on "exotic" types that don't get pexpr
  # treatment
  def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]:
      """Collect the unbacked symbols reachable from ``s``.

      Symbolic scalars/expressions and tensors are passed to
      ``free_unbacked_symbols``; tuples and lists are searched recursively;
      any other value contributes nothing.
      """
      if isinstance(s, (SymTypes, Expr)):
          # This branch should be impossible in return position
          return free_unbacked_symbols(s)
      if isinstance(s, (tuple, list)):
          collected = OrderedSet[sympy.Symbol]()
          for element in s:
              collected |= maybe_free_unbacked_symbols(element)
          return collected
      if isinstance(s, torch.Tensor):
          # This branch is impossible in constant-args position
          return free_unbacked_symbols(s)
      return OrderedSet()
  def maybe_free_symbols(s: object) -> OrderedSet[Symbol]:
      """Collect all free symbols reachable from ``s``.

      Same traversal as ``maybe_free_unbacked_symbols`` but using
      ``free_symbols``: symbolic scalars/expressions and tensors are queried
      directly, tuples and lists are searched recursively, and anything else
      contributes nothing.
      """
      if isinstance(s, (SymTypes, Expr)):
          # This branch should be impossible in return position
          return free_symbols(s)
      if isinstance(s, (tuple, list)):
          collected = OrderedSet[sympy.Symbol]()
          for element in s:
              collected |= maybe_free_symbols(element)
          return collected
      if isinstance(s, torch.Tensor):
          # This branch is impossible in constant-args position
          return free_symbols(s)
      return OrderedSet()
  8476. def assign_origin_node(result: Any, n: torch.fx.Node) -> None:
  8477. # This is not complete, but it doesn't have to be: origin_node
  8478. # tracking is best effort. The logic here critically relies on direct
  8479. # TensorBox -> StorageBox denoting a non-view; we don't bother trying
  8480. # to get views to work. Feel free to add any extra cases as needed.
  8481. #
  8482. # Note: we can't YOLO tree_map over this result, because if there are
  8483. # buffers or a view involved, we might not be able to validly assign
  8484. # the origin_node here.
  8485. if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
  8486. if isinstance(result.data.data, Loops):
  8487. result.data.data._post_init_setattr("origin_node", n)
  8488. elif isinstance(result.data.data, Buffer):
  8489. result.data.data._post_init_setattr("origin_node", n)
  8490. if isinstance(result.data.data, ComputedBuffer) and isinstance(
  8491. result.data.data.data, Loops
  8492. ):
  8493. result.data.data.data._post_init_setattr("origin_node", n)
  8494. # Not really multi-output, can straightforwardly recurse in
  8495. elif (
  8496. isinstance(result.data.data, MultiOutput)
  8497. and not result.data.data.indices
  8498. ):
  8499. if isinstance(result.data.data.inputs[0], Buffer):
  8500. result.data.data.inputs[0]._post_init_setattr("origin_node", n)