scheduler.py 274 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
7677867796780678167826783678467856786678767886789679067916792679367946795679667976798679968006801680268036804680568066807680868096810681168126813681468156816681768186819682068216822682368246825682668276828682968306831683268336834683568366837683868396840684168426843684468456846684768486849685068516852685368546855685668576858685968606861686268636864686568666867686868696870687168726873687468756876687768786879688068816882688368846885688668876888688968906891689268936894689568966897689868996900690169026903690469056906690769086909691069116912691369146915691669176918691969206921692269236924692569266927692869296930693169326933693469356936693769386939694069416942694369446945694669476948694969506951695269536954695569566957695869596960696169626963696469656966696769686969697069716972697369746975697669776978697969806981698269836984698569866987698869896990699169926993699469956996699769986999700070017002700370047005700670077008700970107011701270137014701570167017701870197020702170227023702470257026702770287029703070317032703370347035703670377038703970407041704270437044704570467047704870497050705170527053705470557056705770587059706070617062706370647065706670677068706970707071707270737074707570767077707870797080708170827083708470857086708770887089709070917092709370947095709670977098709971007101710271037104710571067107710871097110711171127113711471157116711771187119712071217122712371247125712671277128712971307131713271337134713571367137713871397140714171427143714471457146714771487149715071517152715371547155715671577158
  1. from __future__ import annotations
  2. import collections
  3. import contextlib
  4. import dataclasses
  5. import functools
  6. import inspect
  7. import itertools
  8. import logging
  9. import math
  10. import operator
  11. import os
  12. import pprint
  13. import textwrap
  14. import traceback
  15. import typing
  16. from collections import Counter, defaultdict
  17. from concurrent.futures import as_completed, Future
  18. from typing import Any, Generic, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union
  19. from typing_extensions import ParamSpec
  20. from torch.utils._ordered_set import OrderedSet
  21. from .ir import ComputedBuffer
  22. if TYPE_CHECKING:
  23. from collections.abc import Callable, Iterator, Sequence
  24. from types import ModuleType
  25. import sympy
  26. import torch
  27. import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
  28. import torch.utils._pytree as pytree
  29. from torch._dynamo.utils import counters, dynamo_timed
  30. from torch._inductor.autotune_process import use_pipelined_autotuning
  31. from torch._inductor.codecache import LambdaFuture, PyCodeCache
  32. from torch._inductor.ir import TritonTemplateCallerBase
  33. from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
  34. from torch.fx.experimental.symbolic_shapes import free_symbols
  35. from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
  36. from torch.utils._triton import has_triton
  37. from . import comms, config, config_comms, dependencies, ir, metrics
  38. from .analyze_preserves_zero_mask import can_codegen_without_upcasts
  39. from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
  40. from .comm_analysis import (
  41. estimate_nccl_collective_runtime,
  42. estimate_nccl_collective_runtime_nccl_estimator,
  43. )
  44. from .dependencies import Dep, MemoryDep, StarDep, WeakDep
  45. from .exc import GPUTooOldForTriton, TritonMissing
  46. from .fx_utils import count_flops_fx
  47. from .ir import (
  48. assign_origin_node,
  49. get_device_type,
  50. GraphPartitionSignature,
  51. MultiOutput,
  52. MultiOutputLayout,
  53. NoneLayout,
  54. )
  55. from .loop_body import LoopBody
  56. from .memory import MemoryPlanningInfoForBuffer, MemoryPlanningInfoForNode
  57. from .runtime.hints import ReductionHint
  58. from .runtime.runtime_utils import green_text, red_text
  59. from .sizevars import SimplifyIndexing
  60. from .utils import (
  61. _unstable_customized_partition_wrapper,
  62. cache_on_self,
  63. cmp,
  64. device_need_guard,
  65. get_current_backend,
  66. get_device_tflops,
  67. get_dtype_size,
  68. get_gpu_dram_gbps,
  69. get_op_names,
  70. GraphPartitionMap,
  71. IndentedBuffer,
  72. is_collective,
  73. is_cudagraph_unsafe_op,
  74. is_gpu,
  75. is_multi_outputs_template,
  76. is_output_of_multi_outputs_template,
  77. is_wait,
  78. sympy_product,
  79. )
  80. from .virtualized import V
  81. log = logging.getLogger(__name__)
  82. fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
  83. loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
  84. compute_dependencies_log = torch._logging.getArtifactLogger(
  85. __name__, "compute_dependencies"
  86. )
  87. cudagraphs_log = torch._logging.getArtifactLogger(__name__, "cudagraphs")
  88. PartitionType: TypeAlias = list["BaseSchedulerNode"]
  89. _T = TypeVar("_T")
  90. _P = ParamSpec("_P")
  91. @dataclasses.dataclass
  92. class FusionResult:
  93. should_fuse: Optional[bool] = None
  94. callable_fn: Optional[Callable[[], bool]] = None
  95. future: Optional[LambdaFuture] = None
  96. def __post_init__(self):
  97. assert (self.should_fuse is not None) ^ (self.callable_fn is not None), (
  98. "Fusion result should contain either fusion decision or callable_fn, not both"
  99. )
  100. @classmethod
  101. def fuse(cls, should_fuse: bool):
  102. return FusionResult(should_fuse=should_fuse)
  103. @classmethod
  104. def from_callable(
  105. cls, callable_fn: Callable[[], bool], future: Optional[LambdaFuture] = None
  106. ):
  107. return FusionResult(callable_fn=callable_fn, future=future)
  108. @dataclasses.dataclass
  109. class PendingFusion:
  110. callable_fn: Callable[[], bool]
  111. node1: BaseSchedulerNode
  112. node2: BaseSchedulerNode
  113. future: Optional[LambdaFuture] = None
  114. def get_fusion_nodes(self) -> tuple[BaseSchedulerNode, BaseSchedulerNode]:
  115. return (self.node1, self.node2)
  116. class MixOrderReduction:
  117. """
  118. This class contains utility functions to decide if we should fuse reductions
  119. reducing across different dimensions of the same input tensor.
  120. """
  121. @staticmethod
  122. def is_split_reduction(node: BaseSchedulerNode) -> bool:
  123. return node.is_reduction() and all(
  124. subnode.node._split_size is not None
  125. for subnode in node.get_nodes()
  126. if isinstance(subnode, SchedulerNode)
  127. and subnode.is_reduction()
  128. and isinstance(subnode.node, ComputedBuffer)
  129. )
  130. @classmethod
  131. def get_numel_rnumel(cls, node: BaseSchedulerNode) -> tuple[sympy.Expr, sympy.Expr]:
  132. if cls.is_split_reduction(node):
  133. xnumel = None
  134. rnumel = None
  135. for subnode in node.get_nodes():
  136. if not (
  137. isinstance(subnode, SchedulerNode)
  138. and subnode.is_reduction()
  139. and isinstance(subnode.node, ComputedBuffer)
  140. ):
  141. continue
  142. assert subnode.node._original_ranges is not None
  143. curxnumel = V.graph.sizevars.simplify(
  144. sympy_product(subnode.node._original_ranges)
  145. )
  146. assert subnode.node._original_reduction_ranges is not None
  147. currnumel = V.graph.sizevars.simplify(
  148. sympy_product(subnode.node._original_reduction_ranges)
  149. )
  150. if xnumel is None:
  151. xnumel = curxnumel
  152. rnumel = currnumel
  153. else:
  154. assert V.graph.sizevars.statically_known_equals(
  155. xnumel, curxnumel
  156. ), f"{xnumel} v.s. {curxnumel}"
  157. assert V.graph.sizevars.statically_known_equals(
  158. rnumel, currnumel
  159. ), f"{rnumel} v.s. {currnumel}"
  160. assert xnumel is not None
  161. return (xnumel, rnumel)
  162. else:
  163. return node.group[1] # type: ignore[return-value]
  164. @classmethod
  165. def has_mix_reduction_orders(
  166. cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  167. ) -> bool:
  168. g1 = cls.get_numel_rnumel(node1)
  169. g2 = cls.get_numel_rnumel(node2)
  170. if len(g1) != 2 or len(g2) != 2 or g1 == g2:
  171. return False
  172. return tuple(g1) == tuple(reversed(g2))
  173. @classmethod
  174. def _is_full_access(cls, buf: str, node: BaseSchedulerNode) -> bool:
  175. """
  176. The access to 'buf' is not a broadcast access.
  177. """
  178. found_dep = None
  179. for dep in node.read_writes.reads:
  180. if isinstance(dep, MemoryDep) and dep.name == buf:
  181. found_dep = dep
  182. break
  183. if not found_dep:
  184. return False
  185. index = found_dep.index
  186. var_ranges = node.read_writes.var_ranges
  187. if not var_ranges:
  188. assert isinstance(node, FusedSchedulerNode), f"{type(node)}"
  189. var_ranges = node.snodes[0].read_writes.var_ranges
  190. assert var_ranges
  191. if not (OrderedSet(var_ranges) - OrderedSet(index.free_symbols)):
  192. return True
  193. # cases that happen after merging loops:
  194. # MemoryDep('arg0_1', c0, {c0: 25165824})])
  195. # var_ranges={d0: 32768, d1: 768}
  196. if V.graph.sizevars.statically_known_equals(
  197. sympy_product(found_dep.size), sympy_product(var_ranges.values())
  198. ):
  199. return True
  200. return False
  201. @classmethod
  202. def get_common_read(
  203. cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  204. ) -> list[str]:
  205. out = []
  206. common_reads = node1.used_buffer_names() & node2.used_buffer_names()
  207. for buf in common_reads:
  208. if cls._is_full_access(buf, node1) and cls._is_full_access(buf, node2):
  209. out.append(buf)
  210. return out
  211. @classmethod
  212. def has_common_read(
  213. cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  214. ) -> bool:
  215. return len(cls.get_common_read(node1, node2)) > 0
  216. @classmethod
  217. def get_numel(cls, node: BaseSchedulerNode) -> int:
  218. g1 = cls.get_numel_rnumel(node)
  219. return V.graph.sizevars.optimization_hint(g1[0] * g1[1], fallback=0)
  220. @classmethod
  221. def get_fusion_score(
  222. cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  223. ) -> int:
  224. # node2 is ignored for now
  225. return cls.get_numel(node1)
  226. # TODO add a cache
  227. @classmethod
  228. def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
  229. """
  230. Check whether we can fuse two reductions with mix loop orders.
  231. """
  232. if not config.triton.mix_order_reduction:
  233. return False
  234. # TODO: Mix order reduction is not supported with cpp_wrapper yet
  235. if V.graph.cpp_wrapper:
  236. return False
  237. if not node1.is_gpu() or not node2.is_gpu():
  238. return False
  239. device_type = node1.get_device().type # type: ignore[union-attr]
  240. if (
  241. device_type not in ("cuda", "xpu")
  242. or get_current_backend(device_type) != "triton"
  243. ):
  244. return False
  245. if not node1.is_reduction() or not node2.is_reduction():
  246. return False
  247. if (node1.ancestors & node2.get_operation_names()) or (
  248. node2.ancestors & node1.get_operation_names()
  249. ):
  250. # the two reductions have no producer/consumer relationship
  251. return False
  252. # check for mix reduction orders
  253. if not cls.has_mix_reduction_orders(node1, node2):
  254. return False
  255. # check common buffer accesses
  256. common_reads = MixOrderReduction.get_common_read(node1, node2)
  257. if len(common_reads) == 0:
  258. return False
  259. if cls.is_contiguous_node(node1):
  260. contiguous_node, other_node = node1, node2
  261. elif cls.is_contiguous_node(node2):
  262. contiguous_node, other_node = node2, node1
  263. else:
  264. return False
  265. g1 = cls.get_numel_rnumel(contiguous_node)
  266. nrow, ncol = g1
  267. # in non strict mode, we will skip the non-critical checks
  268. if not config.triton.mix_order_reduction_non_strict_mode:
  269. # the fused version has worse perf than non-fused version for
  270. # small workload. When a workload is small enough, data can be
  271. # fully cached by L2
  272. size_thres = 5 * 2**20
  273. # Call evaluate_expr rather than statically_known_geq since nrow can
  274. # have dynamic shape in real models.
  275. # Don't use hint directly since hint can be non-representative.
  276. if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow * ncol, size_thres)):
  277. return False
  278. # We require more more row than columns since
  279. # 1, we prefer doing persistent reduction for each row
  280. # 2, we will split the reduction across the rows
  281. if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow, ncol * 2)):
  282. return False
  283. # When nrow is small, ncol should also be small (due to the check
  284. # above). Thus the entire tensor should be well cached in L2.
  285. # Mix order reduction is less beneficial.
  286. if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow, 4096)):
  287. return False
  288. # Make sure a persistent reduction will be generated
  289. if any(
  290. subnode.node.data.reduction_hint # type: ignore[union-attr]
  291. not in (
  292. ReductionHint.INNER,
  293. ReductionHint.DEFAULT,
  294. )
  295. for subnode in contiguous_node.get_nodes()
  296. if subnode.is_reduction()
  297. ):
  298. return False
  299. # rnumel so large that we will not generated persistent reduction
  300. # We don't see real use cases with dynamic ncol. But if we do,
  301. # we should call evaluete_expr here which adds guards.
  302. if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16):
  303. return False
  304. # Other reduction types like max/min is not supported yet.
  305. # There are no real use case as well.
  306. out = all(
  307. subnode.node.get_reduction_type() # type: ignore[union-attr]
  308. in {
  309. "sum",
  310. "prod",
  311. }
  312. for subnode in other_node.get_nodes()
  313. if subnode.is_reduction()
  314. )
  315. return out
  316. @classmethod
  317. def are_mix_order_reductions(
  318. cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  319. ) -> bool:
  320. return cls.can_fuse(node1, node2)
  321. @classmethod
  322. def is_contiguous_node(cls, node: BaseSchedulerNode) -> bool:
  323. if not all(
  324. cls.is_contiguous_load(dep.name, node) for dep in node.read_writes.reads
  325. ):
  326. return False
  327. return True
  328. @classmethod
  329. def is_contiguous_load(cls, buf: str, parent_node: BaseSchedulerNode) -> bool:
  330. from torch._inductor.loop_body import MemoryUsageType
  331. for node in parent_node.get_nodes():
  332. assert isinstance(node, SchedulerNode)
  333. loop_body = node._body
  334. entries = loop_body.memory_usage[MemoryUsageType.LOAD]
  335. index_names = [e.index_name for e in entries if e.buffer_name == buf]
  336. if len(index_names) == 0:
  337. continue
  338. # there can be multiple index_names some times
  339. for index_name in index_names:
  340. index_expr = loop_body.indexing_exprs[index_name]
  341. var_ranges = loop_body.var_ranges
  342. # assumes the final symbol is for reduction
  343. var_symbols = list(var_ranges.keys())
  344. stride_vars = V.graph.sizevars.stride_vars(
  345. index_expr,
  346. var_symbols,
  347. var_symbols,
  348. )
  349. # stride==0 means a broadcast
  350. if not (stride_vars[-1] == 0 or stride_vars[-1] == 1):
  351. return False
  352. return True
  353. @dataclasses.dataclass
  354. class SchedulerBuffer:
  355. scheduler: Scheduler
  356. node: ir.Buffer
  357. defining_op: Optional[BaseSchedulerNode]
  358. users: list[NodeUser] = dataclasses.field(default_factory=list)
  359. mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
  360. default_factory=MemoryPlanningInfoForBuffer
  361. )
  362. def defining_op_name(self) -> str:
  363. op = self.defining_op
  364. assert op is not None
  365. return op.get_name()
  366. def __hash__(self) -> int:
  367. return hash(self.node.name)
  368. def debug_str(self) -> str:
  369. result = IndentedBuffer()
  370. name = self.get_name()
  371. result.writeline(f"{name}: {type(self.node).__name__}")
  372. result.writeline(f"{name}.layout = {self.node.layout}")
  373. if self.get_aliases():
  374. result.writeline(f"{name}.aliases = {pformat(self.get_aliases())}")
  375. if self.get_mutations():
  376. result.writeline(f"{name}.mutations = {pformat(self.get_mutations())}")
  377. if len(self.users) <= 1:
  378. result.writeline(f"{name}.users = {self.users}")
  379. else:
  380. result.writeline(f"{name}.users = [")
  381. with result.indent(1):
  382. for user in self.users:
  383. result.writeline(f"{user},")
  384. result.writeline("]")
  385. return result.getrawvalue()
  386. def get_name(self) -> str:
  387. return self.node.get_name()
  388. def allocate(self) -> None:
  389. assert self.node is not None
  390. if not self.node.should_allocate():
  391. return
  392. if (
  393. self.node.get_inputs_that_alias_output()
  394. or self.node.get_mutation_names()
  395. or isinstance(self.node.get_output_spec(), ir.CommBufferLayout)
  396. ):
  397. V.graph.wrapper_code.codegen_allocation(self.node)
  398. return
  399. # hacky check for if V.kernel is a real kernel or NullHandler
  400. if (
  401. hasattr(V.kernel, "args")
  402. and self.get_name() in V.kernel.inplace_update_buffers
  403. ):
  404. input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
  405. input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
  406. if input_buffer_name in self.scheduler.name_to_donated_buffer:
  407. input_buffer = self.scheduler.name_to_donated_buffer[
  408. input_buffer_name
  409. ].node
  410. else:
  411. input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
  412. V.graph.wrapper_code.codegen_inplace_reuse(
  413. input_buffer,
  414. self.node,
  415. )
  416. else:
  417. V.graph.wrapper_code.codegen_allocation(self.node)
  418. def can_free(self) -> bool:
  419. # There's no real allocated buffer, no need to free it
  420. assert self.node is not None
  421. if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
  422. self.node
  423. ):
  424. return False
  425. for use in self.users:
  426. if isinstance(use.node, OutputNode):
  427. return False
  428. return True
  429. def set_users(self, users: list[NodeUser]) -> None:
  430. # deduplicate
  431. result: dict[int, NodeUser] = {}
  432. for use in users:
  433. if id(use.node) in result:
  434. result[id(use.node)] = use.merge(result[id(use.node)])
  435. else:
  436. result[id(use.node)] = use
  437. self.users = list(result.values())
  438. def get_aliases(self) -> Sequence[str]:
  439. assert self.node is not None
  440. return self.node.get_inputs_that_alias_output()
  441. def get_mutations(self) -> Sequence[str]:
  442. assert self.node is not None
  443. return self.node.get_mutation_names()
  444. def get_device(self) -> Optional[torch.device]:
  445. return self.node.get_output_spec().get_device()
  446. @dataclasses.dataclass
  447. class SchedulerDonatedBuffer(SchedulerBuffer):
  448. defining_op: Optional[BaseSchedulerNode] = None
  449. class BaseSchedulerNode:
  450. ancestors: OrderedSet[str]
  451. group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
  452. last_usage: OrderedSet[str]
  453. # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
  454. # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
  455. # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
  456. # For non-"grouped" nodes (i.e. regular SchedulerNode),
  457. # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
  458. min_order: int
  459. max_order: int
  460. mpi_node: MemoryPlanningInfoForNode
  461. mutation_renames: dict[str, str]
  462. node: Optional[ir.Operation] = None
  463. outputs: list[SchedulerBuffer]
  464. outputs_by_name: dict[str, SchedulerBuffer]
  465. override_estimated_runtime: Optional[float] = None
  466. read_writes: dependencies.ReadWrites
  467. unmet_dependencies: OrderedSet[Dep]
  468. written: bool = False
  469. def __init__(self, scheduler: Scheduler) -> None:
  470. self.scheduler: Scheduler = scheduler
  471. self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
  472. lambda *args, **kwargs: []
  473. )
  474. def _init_from_node(self, node: ir.Operation) -> None:
  475. self.node = node
  476. self.ancestors = OrderedSet()
  477. self.last_usage = OrderedSet[
  478. str
  479. ]() # buffers that won't be used after this kernel
  480. self.written = False
  481. self.outputs = [
  482. SchedulerBuffer(
  483. scheduler=self.scheduler,
  484. node=output,
  485. defining_op=self,
  486. )
  487. for output in node.get_outputs()
  488. ]
  489. self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
  490. # mutation_renames for the current node. Due to potential
  491. # more mutations happening later, this can be different
  492. # to Scheduler.mutation_renames. Also this dict should be small
  493. # since only mutation information relevant to the deps for this
  494. # node is stored here.
  495. self.mutation_renames = {}
  496. def __repr__(self) -> str:
  497. return f"{type(self).__name__}(name={self.get_name()!r})"
  498. def debug_str(self) -> str:
  499. """Longer form printout for trace logs"""
  500. name = self.get_name()
  501. buf = IndentedBuffer()
  502. buf.splice(
  503. f"""\
  504. {name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
  505. {name}.writes = {pformat(self.read_writes.writes)}
  506. {name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
  507. {name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
  508. {name}.outputs = [
  509. """
  510. )
  511. with buf.indent():
  512. for out in self.get_outputs():
  513. buf.splice(out.debug_str())
  514. buf.writeline("]")
  515. try:
  516. buf.splice(self.debug_str_extra())
  517. except Exception:
  518. log.warning("Ignoring error in debug_str()", exc_info=True)
  519. return buf.getrawvalue().rstrip()
  520. def debug_str_extra(self) -> str:
  521. return ""
  522. def _debug_str_for_device(self) -> list[str]:
  523. return self.debug_device_str(self)
  524. def debug_str_short(self) -> str:
  525. maybe_data = getattr(self.node, "data", None)
  526. data_str = ""
  527. if isinstance(maybe_data, torch._inductor.ir.Pointwise):
  528. data_str = ", " + maybe_data.str_helper(
  529. [maybe_data.get_size()], shorten=False, multiline=False
  530. )
  531. elif isinstance(maybe_data, torch._inductor.ir.Reduction):
  532. data_str = ", " + maybe_data.str_helper(
  533. [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()],
  534. shorten=False,
  535. multiline=False,
  536. )
  537. return f"{self}{data_str}"
  538. def log_details(self) -> None:
  539. log.info(
  540. "%s: unmet_dependencies = %s, writes = %s",
  541. self,
  542. self.unmet_dependencies,
  543. self.read_writes.writes,
  544. )
  545. def reorder_loops_by_dep_pair(
  546. self, self_dep: MemoryDep, other_dep: MemoryDep
  547. ) -> bool:
  548. return False
  549. def update_mutated_names(self, renames: dict[str, str]) -> None:
  550. self.mutation_renames = {
  551. name: renames[name]
  552. for name in (dep.name for dep in self.read_writes.reads_and_writes())
  553. if name in renames
  554. }
  555. self.set_read_writes(self.read_writes.rename(self.mutation_renames))
  556. def add_fake_dep(self, dep: Dep) -> None:
  557. self.set_read_writes(self.read_writes.with_read(dep))
  558. def has_aliasing_or_mutation(self) -> bool:
  559. return any(
  560. buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs()
  561. )
  562. def set_read_writes(self, rw: dependencies.ReadWrites) -> None:
  563. self.read_writes = rw
  564. self.unmet_dependencies = self.read_writes.reads
  565. self.prune_deps()
  566. def set_last_usage(
  567. self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
  568. ) -> None:
  569. used_buffers = self.used_or_aliased_buffer_names()
  570. used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers)
  571. self.last_usage = used_buffers - future_used_buffers
  572. def mark_run(self) -> None:
  573. for buf in self.outputs:
  574. buf.allocate()
  575. def used_buffer_names(self) -> OrderedSet[str]:
  576. return OrderedSet(
  577. dep.name
  578. for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
  579. )
  580. def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
  581. """
  582. Returns buffer names used by this node, including aliases.
  583. Note: is_fake WeakDeps are excluded since they are purely for ordering
  584. and should not affect buffer lifetime.
  585. """
  586. used_names: OrderedSet[str] = OrderedSet()
  587. deps = [
  588. dep.name
  589. for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
  590. if not (isinstance(dep, WeakDep) and dep.is_fake)
  591. ]
  592. while len(deps) > 0:
  593. dep = deps.pop()
  594. used_names.add(dep)
  595. if V.graph.name_to_buffer.get(dep):
  596. deps.extend(
  597. alias
  598. for alias in V.graph.name_to_buffer[
  599. dep
  600. ].get_inputs_that_alias_output()
  601. if alias not in used_names
  602. )
  603. return used_names
  604. def prune_deps(self) -> None:
  605. self.unmet_dependencies = OrderedSet(
  606. dep
  607. for dep in self.unmet_dependencies
  608. if dep.name not in self.scheduler.available_buffer_names
  609. )
  610. def prune_weak_deps(self) -> None:
  611. # Prune weak dependencies on operations that have been removed
  612. def should_prune(dep: Dep) -> bool:
  613. if not isinstance(dep, WeakDep):
  614. return False
  615. if dep.name not in self.scheduler.name_to_buf:
  616. return False
  617. op_name = self.scheduler.name_to_buf[dep.name].defining_op_name()
  618. return op_name in V.graph.removed_operations
  619. to_remove = OrderedSet(
  620. dep for dep in self.read_writes.reads if should_prune(dep)
  621. )
  622. self.set_read_writes(self.read_writes.remove_reads(to_remove))
  623. def prune_redundant_deps(
  624. self, name_to_fused_node: dict[str, BaseSchedulerNode]
  625. ) -> None:
  626. _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
  627. def get_name(self) -> str:
  628. assert self.node is not None
  629. return self.node.get_operation_name()
  630. def get_first_name(self) -> str:
  631. return self.get_name()
  632. @cache_on_self
  633. def get_operation_names(self) -> OrderedSet[str]:
  634. return OrderedSet(node.get_name() for node in self.get_nodes())
  635. @cache_on_self
  636. def get_buffer_names(self) -> OrderedSet[str]:
  637. return OrderedSet(out.get_name() for out in self.outputs)
  638. @cache_on_self
  639. def can_codegen_in_low_precision(self) -> bool:
  640. return all(
  641. isinstance(n, SchedulerNode)
  642. and can_codegen_without_upcasts(n, disallow_fp32_ops=True)
  643. for n in self.get_nodes()
  644. )
  645. @cache_on_self
  646. def can_codegen_without_upcasts(self) -> bool:
  647. return all(
  648. isinstance(n, SchedulerNode) and can_codegen_without_upcasts(n)
  649. for n in self.get_nodes()
  650. )
  651. def get_nodes(self) -> Sequence[BaseSchedulerNode]:
  652. return [self]
  653. def get_outputs(self) -> Sequence[SchedulerBuffer]:
  654. return self.outputs
  655. def get_output(self, buf_name: str) -> SchedulerBuffer:
  656. return self.outputs_by_name[buf_name]
  657. def get_device(self) -> Optional[torch.device]:
  658. assert self.node is not None
  659. return self.node.get_device()
  660. def is_cpu(self) -> bool:
  661. device = self.get_device()
  662. return device is not None and device.type == "cpu"
  663. def is_gpu(self) -> bool:
  664. device = self.get_device()
  665. return device is not None and is_gpu(device.type)
  666. def is_reduction(self) -> bool:
  667. return False
  668. def is_native_matmul(self) -> bool:
  669. return False
  670. def is_split_scan(self) -> bool:
  671. return False
  672. def is_template(self) -> bool:
  673. return False
  674. def is_extern(self) -> bool:
  675. return False
  676. def is_foreach(self) -> bool:
  677. return False
  678. def can_inplace(self, read_dep: dependencies.Dep) -> bool:
  679. return False
  680. def has_side_effects(self) -> bool:
  681. return False
    def decide_inplace_update(self) -> None:
        """
        Decide if there should be inplace updates for the node
        and record the decision in the active kernel.

        For each output buffer that needs its own allocation, the node's reads
        are scanned for an input buffer that can be reused. The first input
        that passes every check is registered on ``V.kernel`` (via
        ``args.make_inplace`` and ``inplace_update_buffers``) and the scan for
        that output stops.
        """
        from .codegen.wrapper import can_match_buffer_size

        # Nothing to do unless: this is a SchedulerNode, inplace buffers are
        # enabled and supported on the device, and the active kernel can
        # record the decision.
        if not (
            isinstance(self, SchedulerNode)
            and config.inplace_buffers
            and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS)
            and (
                not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel)
                or getattr(V.kernel, "mutations", None) is not None
            )
            # hacky check for if V.kernel is a real kernel or NullHandler
            and hasattr(V.kernel, "args")
        ):
            return

        # Users whose name is in this set are ignored when counting the
        # remaining uses of an input buffer.
        # NOTE remove V.graph.removed_operations once deps issue is fixed
        inconsequential_nodes = (
            self.ancestors
            | V.graph.removed_operations
            | self.scheduler.completed_operations
        )

        def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool:
            # Inside of NodeUser, we track that the read and write are equivalent
            # before deciding if the use can be inplace.
            # But if that use is fused into a larger kernel, we need to check equivalence
            # of other accesses in fused scheduler node as well.
            fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self)
            buf_name = buf_to_be_inplaced.get_name()
            # Dedup read/writes with equivalent indices
            # TODO - would be nice if we could just cache accesses on ReadWrites,
            # and enforce variant that this class & members are functional..
            deps: OrderedSet[Dep] = OrderedSet()
            for user in buf_to_be_inplaced.users:
                user_node = user.node
                if not isinstance(user_node, BaseSchedulerNode):
                    continue

                # only consider users that live in the same fused node as self
                if (
                    user_node.get_first_name()
                    not in buf_to_be_inplaced.scheduler.name_to_fused_node
                    or buf_to_be_inplaced.scheduler.get_fused_node(user_node)
                    is not fused_node
                ):
                    continue

                deps |= (
                    o
                    for o in user_node.read_writes.reads_and_writes()
                    if o.name == buf_name
                )
                # more than one distinct access to the buffer: not safe
                if len(deps) > 1:
                    return False

            return True

        for buf in self.get_outputs():
            buf_node = buf.node
            assert buf_node is not None
            # skip outputs that do not get their own fresh allocation
            if (
                not buf_node.should_allocate()
                or buf_node.get_inputs_that_alias_output()
                or buf_node.get_mutation_names()
                or buf.get_name() in V.graph.removed_buffers
            ):
                continue

            for read in self.read_writes.reads:
                input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
                # donated buffers take precedence over regular scheduler buffers
                if read.name in self.scheduler.name_to_donated_buffer:
                    input_buf = self.scheduler.name_to_donated_buffer[read.name]
                else:
                    input_buf = self.scheduler.name_to_buf.get(read.name)

                if (
                    input_buf
                    and V.graph.wrapper_code.can_reuse(input_buf, self)
                    and not isinstance(input_buf.defining_op, NopKernelSchedulerNode)
                ):
                    assert input_buf.users is not None
                    remaining_uses = [
                        x
                        for x in input_buf.users
                        if x.node.get_name() not in inconsequential_nodes
                    ]
                    # this node must be the only remaining user, and the
                    # layouts / sizes / accesses must allow reuse
                    if (
                        len(remaining_uses) == 1
                        and remaining_uses[0].can_inplace
                        and remaining_uses[0].node is self
                        and input_buf.node is not None
                        and not isinstance(
                            input_buf.node.get_output_spec(),
                            (
                                ir.NoneLayout,
                                ir.MultiOutputLayout,
                                ir.MutationLayoutSHOULDREMOVE,
                            ),
                        )
                        and not (
                            input_buf.defining_op
                            and isinstance(
                                input_buf.defining_op.node,
                                (ir.FallbackKernel, ir.MultiOutput),
                            )
                            and len(input_buf.node.get_inputs_that_alias_output()) > 0
                        )
                        and can_match_buffer_size(input_buf.node, buf.node)
                        and single_index_in_fused_node(input_buf)
                    ):
                        # if there isn't a triton kernel, then we don't need to call triton-specific things.
                        # but TODO this might be a convenient place to signal to the Collective kernels to inplace
                        # (and, can we make "kernel" less generic of a name?)
                        V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name())
                        # mutations not tracked in cpp kernels
                        if isinstance(
                            V.kernel, torch._inductor.codegen.simd.SIMDKernel
                        ):
                            V.kernel.mutations.add(input_buf.get_name())
                            V.kernel.mutations.add(buf.get_name())

                        V.kernel.inplace_update_buffers[buf.get_name()] = (
                            input_buf.get_name()
                        )
                        # an input has been chosen for this output; stop scanning reads
                        break
    def codegen_originating_info(
        self, buffer: IndentedBuffer, only_once: bool = True
    ) -> None:
        """
        Write "#pragma CMT" comment lines describing the origins of this node
        into ``buffer``.

        Does nothing when ``config.comment_origin`` is off, when ``only_once``
        is set and the info was already written, or when no lines are produced.
        Sets ``self.written`` after writing.
        """
        if not config.comment_origin:
            return

        if only_once and self.written:
            return
        assert self.node is not None
        origins = self.node.get_origins()
        out_lines = []

        for o in origins:
            if o.op == "output":
                # These are boring and samey
                continue

            out_lines.append("")
            # TODO(voz): Should the pragma be constant somewhere?
            out_lines.append("#pragma CMT ORIGIN:")
            op_info_str = f"#pragma CMT {o.op} {o.target}"
            if "seq_nr" in o.meta:
                op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}"
            out_lines.append(op_info_str)
            if "stack_trace" in o.meta:
                stack_trace = f"{o.meta['stack_trace']}"
                # keep only the text after the last "|" separator
                stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1]
                # Braces are doubled; newlines are turned into a backslash and
                # then every backslash is doubled. The order of the replace
                # calls matters.
                out_lines.append(
                    "#pragma CMT "
                    + stack_trace_last_line.replace("{", "{{")
                    .replace("}", "}}")
                    .replace("\n", "\\")
                    .replace(
                        "\\", "\\\\"
                    )  # For windows safe path, avoid for example \x, \U.
                )
                out_lines.append("#pragma CMT END ORIGIN")
                out_lines.append("")

        if len(out_lines) == 0:
            return

        # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
        # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
        buffer.writelines(out_lines)
        self.written = True
  842. @cache_on_self
  843. def get_read_write_buffers_sizes(self) -> int:
  844. return self.get_read_write_buffers_sizes_impl(
  845. include_reads=True, include_writes=True
  846. )
  847. @cache_on_self
  848. def get_read_buffer_sizes(self) -> int:
  849. return self.get_read_write_buffers_sizes_impl(
  850. include_reads=True, include_writes=False
  851. )
  852. @cache_on_self
  853. def get_write_buffer_sizes(self) -> int:
  854. return self.get_read_write_buffers_sizes_impl(
  855. include_reads=False, include_writes=True
  856. )
  857. def get_read_write_buffers_sizes_impl(
  858. self, include_reads: bool, include_writes: bool
  859. ) -> int:
  860. return sum(
  861. self.get_read_write_buffer_accesses(
  862. include_reads=include_reads, include_writes=include_writes
  863. ).values(),
  864. start=0,
  865. )
  866. def get_read_write_buffer_accesses(
  867. self, include_reads: bool, include_writes: bool
  868. ) -> dict[str, int]:
  869. """
  870. Counting the number of bytes accessed for a kernel is
  871. surprisingly tricky. In particular, there is a differentiation
  872. between 'theoretical' memory accesses and practical memory
  873. accesses. For example, a layernorm kernel may actually access an
  874. input 3 times, but in theory, it only needs to access its input
  875. once (and may be optimized to do so through say, persistent
  876. reductions)
  877. Another example is that even though a buffer is passed in, we may
  878. not access the entire buffer. This may occur if we are accessing
  879. a slice of the buffer. Another tricky case is for indirect
  880. indexing, where the amount of bytes accessed depends on the
  881. values of the input.
  882. What this function aims to compute is the memory accesses for
  883. worst-case inputs, best-case optimization. What this means is
  884. that for each buffer we compute the amount of potential accesses in two ways and take the minimum.
  885. 1. Numel in ranges multiplied by number of deps the buffer has
  886. 2. The buffer size
  887. Returns memory accesses per buffer.
  888. """
  889. if isinstance(self, NopKernelSchedulerNode):
  890. return {}
  891. if isinstance(self, ExternKernelSchedulerNode) and isinstance(
  892. self.node, MultiOutput
  893. ):
  894. # todo: Calculate this - it's kinda annoying.
  895. return {}
  896. if (
  897. isinstance(self, ExternKernelSchedulerNode)
  898. and isinstance(self.node, ir.FallbackKernel)
  899. and self.node.op_overload
  900. is torch._prims.rng_prims.graphsafe_run_with_rng_state
  901. ):
  902. return {}
  903. def try_size_hint(s: sympy.Expr) -> int:
  904. return V.graph.sizevars.optimization_hint(s, fallback=0)
  905. if isinstance(self, SchedulerNode):
  906. node_numel = try_size_hint(
  907. sympy_product(self.get_ranges()[0])
  908. * sympy_product(self.get_ranges()[1]),
  909. )
  910. else:
  911. node_numel = int(1e9)
  912. buf_accesses = collections.defaultdict(list)
  913. if include_reads:
  914. for dep in self.read_writes.reads:
  915. buf_accesses[dep.name].append(dep)
  916. if include_writes:
  917. for dep in self.read_writes.writes:
  918. buf_accesses[dep.name].append(dep)
  919. reads = (
  920. OrderedSet(dep.name for dep in self.read_writes.reads)
  921. if include_reads
  922. else OrderedSet()
  923. )
  924. writes = (
  925. OrderedSet(dep.name for dep in self.read_writes.writes)
  926. if include_writes
  927. else OrderedSet()
  928. )
  929. def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool:
  930. users = self.scheduler.name_to_buf[buf].users
  931. buf_uses = OrderedSet(user.node for user in users)
  932. return len(buf_uses - OrderedSet(snodes)) > 0
  933. if isinstance(self, FusedSchedulerNode):
  934. removed_buffers = OrderedSet(
  935. dep for dep in writes if not is_materialized(dep, self.snodes)
  936. )
  937. writes = writes - removed_buffers
  938. reads = reads - removed_buffers
  939. buf_byte_accesses: dict[str, int] = {}
  940. for buf_name in reads | writes:
  941. buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name])
  942. buf: Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]
  943. if buf_name in V.graph.name_to_buffer:
  944. buf = V.graph.name_to_buffer[buf_name]
  945. elif buf_name in V.graph.graph_inputs:
  946. buf = V.graph.graph_inputs[buf_name]
  947. else:
  948. continue
  949. def get_buf_bytes(
  950. buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
  951. ) -> int:
  952. if not buf:
  953. return 0
  954. if isinstance(buf, ir.TorchBindObject):
  955. return buf.get_buf_bytes()
  956. elif isinstance(buf.layout, MultiOutputLayout):
  957. # Kind of a lazy way to get the MultiOutput nodes corresponding to
  958. # a MultiOutputLayout
  959. users = self.scheduler.name_to_buf[buf.get_name()].users
  960. tot = 0
  961. for user in users:
  962. if isinstance(user.node, OutputNode):
  963. continue
  964. assert isinstance(user.node, BaseSchedulerNode)
  965. if isinstance(user.node.node, MultiOutput):
  966. for sched_buf in user.node.get_outputs():
  967. tot += get_buf_bytes(sched_buf.node)
  968. else:
  969. # Buf is a MultiOutputLayout but not all of its
  970. # users are MultiOutputs...
  971. # TODO: Figure out what's going on
  972. return 0
  973. return tot
  974. elif isinstance(buf.layout, ir.NoneLayout):
  975. return sum(
  976. get_buf_bytes(V.graph.get_buffer(mut_name))
  977. for mut_name in buf.get_mutation_names()
  978. )
  979. else:
  980. buf_elems = try_size_hint(sympy_product(buf.get_size()))
  981. return get_dtype_size(buf.get_dtype()) * min(
  982. buf_accessed_elems, buf_elems
  983. )
  984. buf_bytes = get_buf_bytes(buf)
  985. if buf_name not in buf_byte_accesses:
  986. buf_byte_accesses[buf_name] = buf_bytes
  987. else:
  988. buf_byte_accesses[buf_name] += buf_bytes
  989. return buf_byte_accesses
  @cache_on_self
  def estimate_flops(self) -> int | None:
      """
      Estimate the FLOP count of this node from its originating FX node.

      Returns ``None`` when there is no IR node, no origin FX node, or when
      ``count_flops_fx`` yields ``None``. Otherwise the count is resolved to
      a concrete hint (fallback 0), added to the
      ``counters["inductor"]["flop_count"]`` counter, and returned.
      """
      ir_node = self.node
      origin = None if ir_node is None else ir_node.get_origin_node()
      if origin is None:
          return None

      raw = count_flops_fx(origin)
      if raw is None:
          return None

      # A symbolic count is unwrapped to its underlying expression before
      # being resolved to a hint.
      expr = raw.node.expr if isinstance(raw, torch.SymInt) else raw
      hinted = V.graph.sizevars.optimization_hint(expr, fallback=0)
      counters["inductor"]["flop_count"] += hinted
      return hinted
  def get_estimated_runtime(self) -> float:
      """
      Return the runtime estimate for this node.

      An explicitly set ``override_estimated_runtime`` (anything other than
      ``None``) wins; otherwise the value comes from
      ``_get_estimated_runtime()``.
      """
      override = self.override_estimated_runtime
      return self._get_estimated_runtime() if override is None else override
  @cache_on_self
  def _get_estimated_runtime(self) -> float:
      """
      Returns estimated op runtime in milliseconds (ms)

      Decision order:
        1. If the first output of the first node is not on a GPU, return 0.
        2. Collectives use the NCCL estimators (results are cached when
           ``runtime_estimations_use_nccl_lib_estimations`` is set); a
           ``ValueError``/``TypeError`` from estimation is logged and
           yields 0. Wait nodes yield 0.
        3. If ``maybe_estimate_runtime_benchmark`` produces a value, use it.
        4. Otherwise take the larger of a compute-time and a memory
           transfer-time estimate; with no FLOP estimate only the transfer
           time is used. Any failure to obtain positive device bandwidth or
           FLOPS yields 0.
      """
      buf = self.get_nodes()[0].get_outputs()[0]
      layout = buf.node.get_output_spec()
      if not is_gpu(get_device_type(layout)):
          # default to no reordering based on runtime
          return 0

      # Collective kernels
      if is_collective(self.node):
          assert isinstance(self.node, ir.IRNode)
          try:
              if config_comms.runtime_estimations_use_nccl_lib_estimations:
                  cache_key = get_estimate_runtime_cache_key_from_snode(self)
                  cache = get_estimate_runtime_cache()
                  cache_val = cache.lookup(cache_key)
                  if cache_val is not None:
                      assert isinstance(cache_val, float)
                      return cache_val

                  ms = estimate_nccl_collective_runtime_nccl_estimator(self)
                  if ms is None:
                      # NCCL estimations fail: fallback to in-tree algorithmic estimation.
                      ms = estimate_nccl_collective_runtime(self.node)

                  cache.set_value(cache_key, value=ms)
                  return ms
              return estimate_nccl_collective_runtime(self.node)
          except (ValueError, TypeError) as e:
              # ValueError: we don't know how to estimate runtime for this
              # collective. TypeError: the collective is not of type
              # ir._CollectiveKernel. Either way fall back to 0.
              log.info(e)  # noqa: G200
              return 0
      elif is_wait(self.node):
          # ir.Wait is only used for collective ops.
          # The time needed for the collective op is already estimated and considered
          # when we are processing the collective op IR node, so ir.Wait takes 0 time
          # since it doesn't take extra time to get the result after the collective is completed.
          return 0

      ret = maybe_estimate_runtime_benchmark(self)
      if ret is not None:
          return ret

      dtype = buf.node.maybe_get_dtype()
      try:
          gpu_memory_bandwidth = get_gpu_dram_gbps()
          gpu_flops = get_device_tflops(dtype) * 10**12
          # If the device query reports a non-positive bandwidth or FLOPS
          # value, bail out with 0 here; otherwise the divisions below
          # would fail with ZeroDivisionError.
          if gpu_memory_bandwidth <= 0 or gpu_flops <= 0:
              return 0
      except Exception:
          # Best effort: any failure while querying device capabilities
          # means "no estimate".
          return 0

      flops_est = self.estimate_flops()

      # Guard the byte count once so both the memory-only path and the
      # roofline path treat a missing value as 0 bytes.
      counted_bytes = self.get_read_write_buffers_sizes()
      counted_bytes = 0 if counted_bytes is None else counted_bytes
      transfer_time = counted_bytes / gpu_memory_bandwidth

      if flops_est == 0 or flops_est is None:
          # no flops estimate, so fall back to memory estimate
          ns = transfer_time
          return ns / 1e6

      # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
      factor = 1.0
      compute_time = (factor * flops_est / gpu_flops) * 1e9

      # Return estimated runtime in milliseconds
      ns = max(compute_time, transfer_time)
      return ns / 1e6
  def get_template_node(self) -> Optional[ir.TemplateBuffer]:
      """Return the template buffer behind this node; this implementation always returns ``None``."""
      return None
  def get_template_node_or_throw(self) -> ir.TemplateBuffer:
      """
      Return ``self.get_template_node()``, asserting that it is not ``None``.
      """
      node = self.get_template_node()
      assert node is not None
      return node
  1092. @staticmethod
  1093. def get_prologue_template_epilogue(
  1094. nodes: list[BaseSchedulerNode],
  1095. ) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]:
  1096. """
  1097. For the list of nodes, get the prologue, template, and epilogue
  1098. """
  1099. template_index = next(i for i, n in enumerate(nodes) if n.is_template())
  1100. prologue = nodes[:template_index]
  1101. template_node = nodes[template_index]
  1102. epilogue = nodes[template_index + 1 :]
  1103. return prologue, template_node, epilogue
  @functools.cache
  def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
      """
      Return the ``LocalCache`` used for runtime estimates.

      ``functools.cache`` on a zero-argument function means the cache object
      is created on the first call and the same instance is returned after.
      """
      runtime_cache = torch._inductor.codecache.LocalCache()
      return runtime_cache
  def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
      """
      Build the string key under which a runtime estimate for ``snode`` is cached.

      The key is the ``str()`` of a tuple holding the node's
      ``python_kernel_name`` (``""`` if absent) followed by one entry per
      flattened argument: the size tuple for IR nodes other than
      ``ir.GeneratorState``, and ``None`` for everything else.
      """
      kernel = snode.node
      kernel_name = getattr(kernel, "python_kernel_name", "")

      inputs = kernel.inputs  # type: ignore[union-attr]
      full_args = kernel.fill_non_provided_args(  # type: ignore[union-attr]
          [*inputs, *kernel.constant_args],  # type: ignore[union-attr]
          kernel.kwargs,  # type: ignore[union-attr]
      )
      # The tree spec returned by tree_flatten is not needed for the key.
      leaves, _ = pytree.tree_flatten((full_args, kernel.kwargs))  # type: ignore[union-attr]

      shapes = []
      for leaf in leaves:
          if isinstance(leaf, ir.IRNode) and not isinstance(leaf, ir.GeneratorState):
              shapes.append(tuple(leaf.get_size()))
          else:
              shapes.append(None)

      return str((kernel_name, *shapes))
  def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
      """
      Map an extern mm/bmm/addmm scheduler node to the matching aten op.

      Returns ``None`` unless ``snode`` is an ``ExternKernelSchedulerNode``
      whose IR node is an ``ir.ExternKernel`` with a ``python_kernel_name``
      of ``extern_kernels.mm``, ``extern_kernels.bmm`` or
      ``extern_kernels.addmm``.
      """
      if not isinstance(snode, ExternKernelSchedulerNode):
          return None

      aten = torch.ops.aten
      dispatch = {
          "extern_kernels.mm": aten.mm,
          "extern_kernels.bmm": aten.bmm,
          "extern_kernels.addmm": aten.addmm,
      }
      kernel_name = getattr(snode.node, "python_kernel_name", "")
      if kernel_name not in dispatch:
          return None
      if isinstance(snode.node, ir.ExternKernel):
          return dispatch[kernel_name]
      return None
  def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
      """
      Estimate ``snode``'s runtime by benchmarking it, when that is enabled.

      Returns ``None`` if ``config.runtime_estimations_mms_benchmark`` is off
      or ``snode`` is not one of the mm-like extern kernels recognised by
      ``_get_mm_like_fn``. Otherwise returns the cached estimate for the
      node's cache key, or benchmarks the op, stores the result in the cache
      and returns it.
      """
      if not config.runtime_estimations_mms_benchmark:
          return None

      bench_fn = _get_mm_like_fn(snode)
      if bench_fn is None:
          return None

      key = get_estimate_runtime_cache_key_from_snode(snode)
      runtime_cache = get_estimate_runtime_cache()
      cached = runtime_cache.lookup(key)
      if cached is not None:
          assert isinstance(cached, float)
          return cached

      # Both imports stay function-local and only run on a cache miss.
      from .utils import snode_args_kwargs

      args, kwargs = snode_args_kwargs(snode)

      from torch._inductor.runtime.benchmarking import benchmarker

      ms = benchmarker.benchmark(
          bench_fn,
          args,  # pyrefly: ignore[bad-argument-type]
          kwargs,
          memory_warmup_iters=5,
          benchmark_iters=10,
          max_benchmark_duration=10,
      )  # type: ignore[arg-type]
      runtime_cache.set_value(key, value=ms)
      return ms
  1167. @dataclasses.dataclass(slots=True)
  1168. class WhyNoFuse:
  1169. name1: str
  1170. name2: str
  1171. reason: str
  1172. args: tuple[Any, ...]
  1173. def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
  1174. self.name1 = node1.get_name()
  1175. self.name2 = node2.get_name()
  1176. def __call__(self, reason: str, *args: Any) -> None:
  1177. self.reason = reason
  1178. self.args = args
  1179. fusion_log.debug(self)
  1180. def __str__(self) -> str:
  1181. return f"cannot fuse {self.name1} with {self.name2}: " + (
  1182. self.reason % self.args
  1183. )
  1184. def pformat(obj: Any) -> str:
  1185. if isinstance(obj, (OrderedSet, set)): # noqa: set_linter
  1186. # pformat has trouble with sets of sympy exprs
  1187. obj = sorted(obj, key=str)
  1188. result = pprint.pformat(obj, indent=4)
  1189. if "\n" in result:
  1190. return f"\n{textwrap.indent(result, ' ' * 4)}"
  1191. return result
  1192. class OutputNode:
  1193. def __init__(self, dep: StarDep) -> None:
  1194. self.unmet_dependencies = OrderedSet([dep])
  1195. def is_reduction(self) -> bool:
  1196. return False
  1197. def get_inputs_that_alias_output(self) -> Sequence[str]:
  1198. return ()
  1199. def get_name(self) -> str:
  1200. return "OUTPUT"
  1201. __repr__ = get_name
  def _prune_redundant_deps(
      node: BaseSchedulerNode,
      name_to_fused_node: dict[str, BaseSchedulerNode],
      name_to_buf: dict[str, SchedulerBuffer],
  ) -> None:
      """
      Drop WeakDeps from ``node`` that no longer constrain scheduling.

      WeakDeps exist for mutation ordering. One is pruned when either:

      * ``node`` also has a non-weak unmet dependency on the same (fused)
        upstream node and ``scheduler.fusable_weak_dep`` accepts it, which
        makes the weak dependency redundant; or
      * the upstream fused node is ``node`` itself.

      Pruned deps are removed from both ``node.unmet_dependencies`` and the
      reads of ``node.read_writes``. As fusions occur, weakdeps are removed
      incrementally, which enables further fusions in order.
      """

      def fused_producer(dep: Dep) -> BaseSchedulerNode:
          # Fused node that currently owns the op defining dep's buffer.
          return name_to_fused_node[name_to_buf[dep.name].defining_op_name()]

      # Number of non-weak unmet deps per upstream fused node name.
      strong_dep_counts: Counter[str] = collections.Counter(
          fused_producer(dep).get_name()
          for dep in node.unmet_dependencies
          if not isinstance(dep, WeakDep)
      )

      def should_prune(dep: Dep) -> bool:
          if not isinstance(dep, WeakDep):
              return False
          producer = fused_producer(dep)
          is_redundant = strong_dep_counts[
              producer.get_name()
          ] > 0 and node.scheduler.fusable_weak_dep(dep, producer, node)
          # These can occur because fused nodes always gather deps from their snodes
          # If B has a weakdep on A
          # B gets fused with C, then any time BC is fused, the weakdep will reappear
          is_self_dep = producer == node
          return is_redundant or is_self_dep

      deps_to_prune = OrderedSet(
          dep for dep in node.unmet_dependencies if should_prune(dep)
      )
      if not deps_to_prune:
          return

      node.unmet_dependencies = node.unmet_dependencies - deps_to_prune
      node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))
  class ExternKernelSchedulerNode(BaseSchedulerNode):
      """Scheduler node that reports itself as an extern kernel (``is_extern()`` is true)."""

      def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
          """Initialise from ``node`` and adopt the read/writes it reports."""
          super().__init__(scheduler)
          self._init_from_node(node)
          read_writes = node.get_read_writes()
          self.set_read_writes(read_writes)

      def debug_str_extra(self) -> str:
          """Return one line naming the node's ``python_kernel_name`` (``None`` if absent)."""
          return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}"

      def is_extern(self) -> bool:
          """Always ``True``."""
          return True

      def has_side_effects(self) -> bool:
          """Defer to the IR node's ``has_side_effects()`` when it defines one; otherwise ``False``."""
          assert self.node is not None
          if not hasattr(self.node, "has_side_effects"):
              return False
          return self.node.has_side_effects()
  class NopKernelSchedulerNode(BaseSchedulerNode):
      """Scheduler node for a nop kernel; it only records the IR node's reads and writes."""

      def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
          """Initialise from ``node`` and adopt the read/writes it reports."""
          super().__init__(scheduler)
          self._init_from_node(node)
          read_writes = node.get_read_writes()
          self.set_read_writes(read_writes)
  class SchedulerNode(BaseSchedulerNode):
      """
      Scheduling wrapper around a single ``ir.ComputedBuffer`` or
      ``ir.TemplateBuffer``.
      """

      # Iteration sizes, as produced by simplify_and_reorder or taken from
      # the loop body after a transformation.
      _sizes: tuple[Sequence[sympy.Expr], ...]
      # Loop body used for dependency extraction and code generation.
      _body: LoopBody

      def __init__(
          self,
          scheduler: Scheduler,
          node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
      ) -> None:
          """Initialise from ``node`` and compute sizes, body, group and dependencies."""
          super().__init__(scheduler)
          self._init_from_node(node)
          self._compute_attrs()

      def _compute_attrs(
          self,
          extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
          recompute_sizes_body_func: Optional[Callable[_P, _T]] = None,
      ) -> None:
          """
          (Re)derive ``_sizes``, ``_body``, ``group`` and the read/writes from the IR node.

          Both arguments are forwarded unchanged to ``simplify_and_reorder``.
          """
          assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
          sizes, body = self.node.simplify_and_reorder(
              extra_indexing_constraints=extra_indexing_constraints,
              recompute_sizes_body_func=recompute_sizes_body_func,
          )
          self._sizes = sizes
          self._body = body  # type: ignore[assignment]

          device = self.node.get_device_or_error()
          backend = self.scheduler.get_backend(device)
          self.group = (device, backend.group_fn(self._sizes))

          # Don't normalize since normalization will merge loops which
          # makes it hard to decide new loop orders.
          should_normalize = not (
              config.loop_ordering_after_fusion and is_gpu(device.type)
          )

          if isinstance(self.node, ir.TemplateBuffer):
              read_writes = self.node.extract_read_writes(normalize=should_normalize)
          else:
              read_writes = dependencies.extract_read_writes(
                  self._body, *self._sizes, normalize=should_normalize
              )
          self.set_read_writes(read_writes)

      def recompute_size_and_body(
          self,
          extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
          recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
      ) -> None:
          """Public entry point that re-runs ``_compute_attrs`` with the given arguments."""
          self._compute_attrs(
              extra_indexing_constraints=extra_indexing_constraints,
              recompute_sizes_body_func=recompute_sizes_body_func,
          )

      def refresh_dependencies(
          self, normalize: bool, need_clear_tiling_cache: bool
      ) -> None:
          """
          Re-extract read/writes from the current body and sizes.

          WeakDep/StarDep reads are carried over from the current read set,
          mutation renames are re-applied, and the cached pointwise
          read/writes are invalidated. When ``need_clear_tiling_cache`` is
          true the ``SIMDScheduling.candidate_tilings`` cache is cleared too.
          """
          # Fake dependencies are added manually. They can not be analyzed from
          # extract_read_writes. Find them out and apply manually.
          fake_deps: OrderedSet[Dep] = OrderedSet()
          for dep in self.read_writes.reads:
              if isinstance(dep, (WeakDep, StarDep)):
                  fake_deps.add(dep)

          # don't normalize since the loop order may need to be further changed
          # later
          extracted = dependencies.extract_read_writes(
              self._body, *self._sizes, normalize=normalize
          )
          self.set_read_writes(
              extracted.with_read(fake_deps).rename(self.mutation_renames)
          )

          self.pointwise_read_writes.clear_cache(self)

          if need_clear_tiling_cache:
              from .codegen.simd import SIMDScheduling

              # TODO(shunting) if this cause compilation time increase when
              # enabling LOAF by default, try just clearing the specific cache
              # entry by using a customized cache implementation rather than
              # lru_cache.
              SIMDScheduling.candidate_tilings.cache_clear()

      def apply_new_loop_order(self, new_order: Sequence[int]) -> None:
          """Reorder the body's iteration loops to ``new_order`` and refresh sizes and dependencies."""
          reordered = self._body.reorder_iter_loops(
              new_order,
          )
          self._body = reordered
          self._sizes = reordered.sizes
          self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True)

      def swap_pw_red_dimension(self) -> None:
          """Move the reduction dims in front of the pointwise dims and swap the two group entries."""
          num_rdims = self._body.get_original_num_rdims()
          num_vars = len(self._body.iter_vars)
          num_pwdims = num_vars - num_rdims

          pwdims = tuple(range(num_pwdims))
          rdims = tuple(range(num_pwdims, num_vars))
          self.apply_new_loop_order((*rdims, *pwdims))

          iteration = self.group[1]
          assert len(iteration) == 2
          self.group = self.group[0], (iteration[1], iteration[0])

      def extract_pw_from_reduction(self) -> BaseSchedulerNode:
          """Replace the body with ``body.extract_pw_from_reduction()`` and return ``self``."""
          self._body = self._body.extract_pw_from_reduction()
          return self

      def cancel_reduction_split(self) -> None:
          """
          Recompute attributes from the original inner function.

          Does nothing unless ``MixOrderReduction.is_split_reduction(self)``.
          """
          if MixOrderReduction.is_split_reduction(self):
              assert isinstance(self.node, ir.ComputedBuffer)
              with self.node.with_original_inner_fn():
                  self._compute_attrs()

      def expand_dimension_for_pointwise_node(
          self, dimension: int, new_range: int
      ) -> None:
          """Expand ``dimension`` of the body to ``new_range`` and refresh sizes, group and dependencies."""
          assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
          expanded = self._body.expand_dimension_for_pointwise_node(
              dimension, new_range
          )
          self._body = expanded
          self._sizes = expanded.sizes

          device = self.node.get_device_or_error()
          backend = self.scheduler.get_backend(device)
          self.group = (device, backend.group_fn(self._sizes))

          # Need normalize the prefix name to facilitate finding common dependencies
          self.refresh_dependencies(normalize=True, need_clear_tiling_cache=True)

      def merge_loops(self) -> None:
          """Merge loops in the body, then refresh sizes and (normalized) dependencies."""
          merged = self._body.merge_loops()
          self._body = merged
          self._sizes = merged.sizes

          # merge_loops is called after loop reordering.
          # We still need retain fake dependencies since codegen the
          # estimated amount of memory access rely on them.
          #
          # Merge loops does not affect the tiling decision. So we
          # don't need clear the tiling cache.
          self.refresh_dependencies(normalize=True, need_clear_tiling_cache=False)

      def reorder_loops_by_dep_pair(
          self, self_dep: MemoryDep, other_dep: MemoryDep
      ) -> bool:
          """
          Try to reorder this node's loops so ``self_dep`` matches ``other_dep``.

          Returns ``True`` if a new loop order was applied, ``False`` otherwise.
          """
          rank = len(self._sizes[0])
          if rank == self_dep.num_vars == other_dep.num_vars:
              new_order = self_dep.decide_loop_order_to_match(other_dep)
          else:
              new_order = None

          if not new_order:
              loop_ordering_log.debug(
                  "Don't reordering %s because we can not decide the suitable loop order",
                  self.get_name(),
              )
              return False

          # pyrefly: ignore [bad-assignment]
          metrics.num_loop_reordering += 1
          loop_ordering_log.debug(
              "Reorder loops for %s with order %s", self.get_name(), new_order
          )
          self.apply_new_loop_order(new_order)
          return True

      def debug_str_extra(self) -> str:
          """Return group, sizes, buffer layouts, the loop body and device-specific debug lines as text."""
          name = self.get_name()
          lines = [
              f"{name}.group.device = {self.group[0]}",
              f"{name}.group.iteration = {self.group[1]}",
              f"{name}.sizes = {self._sizes}",
          ]
          for dep in self.read_writes.reads_and_writes():
              if isinstance(dep, WeakDep):
                  continue
              buf_name = dep.name
              buf = V.graph.get_buffer(buf_name)
              if isinstance(buf, ir.TorchBindObject):
                  continue
              lines.append(f"{buf_name}_layout = {pformat(buf.layout)}")
          if isinstance(self._body, LoopBody):
              lines += [
                  f"class {name}_loop_body:",
                  textwrap.indent(self._body.debug_str(), " "),
              ]

          assert self.node is not None
          lines.extend(self._debug_str_for_device())

          return "\n".join(lines)

      def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]:
          """Return the stored iteration sizes."""
          return self._sizes

      def is_reduction(self) -> bool:
          """
          Whether this node is still a reduction.

          True when the IR node has a reduction type and the body is either
          missing or has no partial accumulate.
          """
          assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
              f"{type(self.node)=}"
          )
          if not bool(self.node.get_reduction_type()):
              return False
          # self._body containing partial accumulate means the reduction is
          # converted to a pointwise node. Need this extra check since
          # we change self._body but didn't change self.node (IRNode)
          # when converting a reduction to a pointwise
          return self._body is None or not self._body.has_partial_accumulate

      def is_native_matmul(self) -> bool:
          """Whether the IR node's reduction type is ``"dot"``."""
          assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
          reduction_type = self.node.get_reduction_type()
          return reduction_type == "dot"

      def is_split_scan(self) -> bool:
          """Whether the IR node is a ``ComputedBuffer`` whose data is an ``ir.SplitScan``."""
          assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
              f"{type(self.node)=}"
          )
          if not isinstance(self.node, ir.ComputedBuffer):
              return False
          return isinstance(self.node.data, ir.SplitScan)

      def is_template(self) -> bool:
          """Whether the IR node is an ``ir.TemplateBuffer``."""
          return isinstance(self.node, ir.TemplateBuffer)

      def get_template_node(self) -> Optional[ir.TemplateBuffer]:
          """Return the IR node if it is a ``TemplateBuffer``, else ``None``."""
          if isinstance(self.node, ir.TemplateBuffer):
              return self.node
          return None

      def run(self, *index_vars: Sequence[sympy.Expr]) -> None:
          """Decide in-place updates, mark the node as run, then generate code for ``index_vars``."""
          self.decide_inplace_update()
          self.mark_run()
          self.codegen(index_vars)

      def ranges_from_index_vars(
          self, index_vars: Sequence[Sequence[sympy.Expr]]
      ) -> dict[sympy.Expr, sympy.Expr]:
          """Map each index variable to its size, pairing flattened ``index_vars`` with flattened sizes."""
          sizes = self._sizes
          assert sum(map(len, sizes)) == sum(map(len, index_vars))
          flat_vars = itertools.chain.from_iterable(index_vars)
          flat_sizes = itertools.chain.from_iterable(sizes)
          return dict(zip(flat_vars, flat_sizes))

      def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
          """
          Emit code for this node by invoking its body on ``index_vars``.

          The body runs with a ``SimplifyIndexing`` ops handler built from
          the variable ranges, and with this node set as the kernel's
          current node. Any exception is logged and re-raised.

          Args:
              index_vars: A sequence of sequences of sympy expressions
                  representing the index variables for each dimension of
                  the computation.
          """
          var_ranges = self.ranges_from_index_vars(index_vars)
          try:
              with V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)):
                  with V.kernel.set_current_node(self):
                      self._body(*index_vars)
          except Exception:
              log.fatal("Error in codegen for %s", self.node)
              raise

      def pointwise_or_reduction_read_writes(
          self, pointwise: bool = True
      ) -> dependencies.ReadWrites:
          """
          Get the memory dependencies in either the pointwise or the reduction axes.

          The sizes of the other axis group are passed as hidden zero arguments.
          """
          if pointwise:
              keep_sizes, ignore_sizes = self._sizes
          else:
              keep_sizes, ignore_sizes = reversed(self._sizes)
          hidden = [[sympy.S.Zero] * len(ignore_sizes)]
          return dependencies.extract_read_writes(
              self._body, keep_sizes, hidden_args=hidden
          )

      @cache_on_self
      def pointwise_read_writes(self) -> dependencies.ReadWrites:
          """
          Get the memory dependencies in the non-reduction axes.
          """
          return self.pointwise_or_reduction_read_writes(pointwise=True)

      @cache_on_self
      def reduction_read_writes(self) -> dependencies.ReadWrites:
          """
          Get the memory dependencies in the reduction axes.
          """
          return self.pointwise_or_reduction_read_writes(pointwise=False)

      def can_inplace(self, read_dep: dependencies.Dep) -> bool:
          """
          Whether the buffer behind ``read_dep`` may be reused in place.

          Requires a non-template node with no aliased outputs and exactly
          one write whose index and size equal those of ``read_dep`` (which
          must be a ``MemoryDep``).
          """
          if self.is_template():
              return False
          if any(out.get_aliases() for out in self.get_outputs()):
              return False
          if len(self.read_writes.writes) != 1 or not isinstance(
              read_dep, dependencies.MemoryDep
          ):
              return False

          write_dep = next(iter(self.read_writes.writes))
          assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}"
          return read_dep.index == write_dep.index and read_dep.size == write_dep.size

      @cache_on_self
      def _get_atomic_add_buffers(self) -> OrderedSet[str]:
          """Collect the buffer names stored to with mode ``"atomic_add"`` in the loop body."""
          buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
          if not isinstance(self._body, LoopBody):
              return buffers_store_as_atomic_add

          for fx_node in self._body.get_nodes():
              if fx_node.op != "call_method" or fx_node.target != "store":
                  continue
              # The store mode may be given by keyword or as the fifth
              # positional argument.
              is_atomic_add = (
                  "mode" in fx_node.kwargs and fx_node.kwargs["mode"] == "atomic_add"
              ) or (len(fx_node.args) == 5 and fx_node.args[4] == "atomic_add")
              if not is_atomic_add:
                  continue

              if "name" in fx_node.kwargs:
                  stored_name = fx_node.kwargs["name"]
              elif len(fx_node.args) >= 2:
                  stored_name = fx_node.args[1]
              else:
                  stored_name = ""
              buffers_store_as_atomic_add.add(stored_name)
          return buffers_store_as_atomic_add

      @cache_on_self
      def has_side_effects(self) -> bool:
          """True if the body contains a ``device_assert_async`` op; otherwise defer to the base class."""
          # self._body is None sometimes that's why this check was added
          body = self._body
          if body is not None and body.has_op("device_assert_async"):
              return True
          return super().has_side_effects()
  def refresh_group_node_dependencies(
      group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
  ) -> None:
      """
      Recompute a group node's read/writes and unmet dependencies from its members.

      Read/writes become the merge of every member's read/writes. Unmet
      dependencies become the union of the members' unmet dependencies,
      minus deps on buffers the group itself owns and minus the group's
      own writes.
      """
      members = group_snode.snodes
      merged = dependencies.ReadWrites.merge_list([m.read_writes for m in members])
      group_snode.set_read_writes(merged)

      pending = OrderedSet.union(*[m.unmet_dependencies for m in members])
      external = OrderedSet(
          dep for dep in pending if dep.name not in group_snode.get_buffer_names()
      )
      group_snode.unmet_dependencies = external - group_snode.read_writes.writes
  def init_group_node(
      group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
      scheduler: Scheduler,
      snodes: list[BaseSchedulerNode],
  ) -> None:
      """
      Populate the shared state of a fused/grouped scheduler node.

      Sets the member list, scheduler, a ``None`` IR node, the union of the
      members' ancestors, the merged dependencies, the min/max order over
      the members, and a name -> buffer map of the group's outputs.
      """
      assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode))
      group_snode.snodes = snodes
      group_snode.scheduler = scheduler
      group_snode.node = None

      known_ancestors = [m.ancestors for m in snodes if m.ancestors is not None]
      group_snode.ancestors = OrderedSet.union(*known_ancestors)

      refresh_group_node_dependencies(group_snode)

      group_snode.min_order = min(m.min_order for m in group_snode.snodes)
      group_snode.max_order = max(m.max_order for m in group_snode.snodes)

      outputs_by_name = {}
      for out_buf in group_snode.get_outputs():
          outputs_by_name[out_buf.get_name()] = out_buf
      group_snode.outputs_by_name = outputs_by_name
  1576. class FusedSchedulerNode(BaseSchedulerNode):
  1577. """
  1578. This is a "fake" scheduler node that represents a group of scheduler nodes
  1579. that are meant to be fused together. The way it does this is by maintaining
  1580. its unmet dependencies as the union of its constituent nodes.
  1581. """
  1582. snodes: list[BaseSchedulerNode]
  @classmethod
  def fuse(
      cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> FusedSchedulerNode:
      """
      Build a fused node from the constituent nodes of ``node1`` then ``node2``.

      Both nodes must share a scheduler. When ``node1`` is a template and
      ``node2`` is an extern kernel node (a ``MultiOutput``), ``node2``'s
      single StarDep write is replaced in place by a MemoryDep that copies
      the template's write index, var names, size and mode.
      """
      assert node1.scheduler is node2.scheduler
      assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))

      if node1.is_template() and isinstance(node2, ExternKernelSchedulerNode):
          # Fuse multi outputs template and its outputs
          #   * Node1 has memorydep of MultiOutput in reads
          #   * Node2 has StarDep of MultiOutput in writes
          # Rewrite the Node2' StarDep to MemoryDep, because calculate score_fusion_memory
          # of the template node and its epilogue requires the same type of dependencies
          assert isinstance(node2.node, MultiOutput)
          assert len(node2.read_writes.writes) == 1
          star_write = next(iter(node2.read_writes.writes))
          assert isinstance(star_write, StarDep)
          name = star_write.name

          templates = [member for member in node1.get_nodes() if member.is_template()]
          assert len(templates) == 1
          template_writes = templates[0].read_writes.writes
          assert len(template_writes) == 1
          write = next(iter(template_writes))
          assert isinstance(write, MemoryDep)

          rewritten = MemoryDep(
              name, write.index, write.var_names, write.size, write.mode
          )
          node2.read_writes.writes = OrderedSet([rewritten])
      else:
          assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))

      nodes = [*node1.get_nodes(), *node2.get_nodes()]
      return cls(node1.scheduler, nodes)
  def extract_pw_from_reduction(self) -> BaseSchedulerNode:
      """
      Call ``extract_pw_from_reduction`` on every member and return ``self``.

      Every member must be a ``SchedulerNode`` that is a reduction.
      """
      for member in self.snodes:
          assert isinstance(member, SchedulerNode)
          assert member.is_reduction()
          member.extract_pw_from_reduction()
      return self
  def swap_pw_red_dimension(self) -> None:
      """Call ``swap_pw_red_dimension`` on every member, each of which must be a ``SchedulerNode``."""
      for member in self.snodes:
          assert isinstance(member, SchedulerNode)
          member.swap_pw_red_dimension()
  @cache_on_self
  def estimate_flops(self) -> int | None:
      """
      Sum the FLOP estimates of the template and extern members.

      Members that are neither templates nor extern kernels are skipped,
      and so are estimates that are ``None`` or 0. Returns ``None`` when
      nothing is left to sum.
      """
      # don't increment counters in fused methods so we don't double count
      contributions = []
      for member in self.get_nodes():
          if not (member.is_template() or member.is_extern()):
              continue
          flops = member.estimate_flops()
          if flops:
              contributions.append(flops)

      if not contributions:
          return None
      return sum(contributions)
  def reorder_loops_by_dep_pair(
      self, self_dep: MemoryDep, other_dep: MemoryDep
  ) -> bool:
      """
      Return true if a loop reordering is performed.

      Reordering is refused for templates, when members differ in their
      first size group, or when no matching loop order can be decided.
      On success the new order is applied to every member and the group's
      dependencies are refreshed.
      """
      if self.is_template():
          # We can not really reorder loops for a triton template
          return False

      common_sizes = None
      for member in self.snodes:
          assert isinstance(member, SchedulerNode)
          member_sizes = member._sizes[0]
          if common_sizes is not None and tuple(common_sizes) != tuple(member_sizes):
              loop_ordering_log.debug(
                  "Can not reorder fused node due to different sizes"
              )
              return False
          common_sizes = member_sizes

      assert common_sizes is not None
      if len(common_sizes) == self_dep.num_vars == other_dep.num_vars:
          new_order = self_dep.decide_loop_order_to_match(other_dep)
      else:
          new_order = None

      if not new_order:
          loop_ordering_log.debug(
              "Dont reordering fused node %s because we can not decide the suitable loop order",
              self.get_name(),
          )
          return False

      # pyrefly: ignore [bad-assignment]
      metrics.num_loop_reordering += 1
      loop_ordering_log.debug(
          "Reorder loops for fused node %s with order %s", self.get_name(), new_order
      )
      for member in self.snodes:
          assert isinstance(member, SchedulerNode)
          member.apply_new_loop_order(new_order)

      refresh_group_node_dependencies(self)
      return True
  def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
      """
      Build a fused node from ``snodes``.

      The fused node takes its ``group`` from the first reduction sub-node,
      or from the first sub-node when none is a reduction.
      """
      super().__init__(scheduler)
      init_group_node(self, scheduler, snodes)
      self.users: list[NodeUser] = []
      # max() keeps the first element with the largest key, so a reduction
      # (key 1) wins over pointwise nodes (key 0).
      representative = max(snodes, key=lambda snode: int(snode.is_reduction()))
      self.group = representative.group
  @cache_on_self
  def get_name(self) -> str:
      """Name of the fused node: the sub-node names joined by underscores."""
      return "_".join(child.get_name() for child in self.snodes)
  def get_first_name(self) -> str:
      """Return the name of the first fused sub-node."""
      first_child = self.snodes[0]
      return first_child.get_name()
  @cache_on_self
  def get_buffer_names(self) -> OrderedSet[str]:
      """Union of ``get_buffer_names()`` over all fused sub-nodes."""
      per_child = [child.get_buffer_names() for child in self.snodes]
      return OrderedSet.union(*per_child)
  def get_outputs(self) -> list[SchedulerBuffer]:
      """Return the outputs of all sub-nodes, concatenated in sub-node order."""
      return [buf for child in self.snodes for buf in child.get_outputs()]
  def debug_str_extra(self) -> str:
      """
      Build the indented per-sub-node section of the debug printout.

      One entry is emitted per sub-node; device-specific lines are appended
      when the first sub-node has an IR node attached.
      """
      lines = []
      for i, node in enumerate(self.snodes):
          lines.append(f"{self.get_name()}.snodes[{i}] =\n{node.debug_str()}")
      if self.snodes[0].node is not None:
          lines.extend(self._debug_str_for_device())
      body = "\n".join(lines).rstrip()
      # NOTE(review): the indent literal is reproduced exactly as it appears
      # in this copy of the file; confirm its width against the canonical source.
      return textwrap.indent(body, " ")
  def debug_str_short(self) -> str:
      """One-line summary: this node followed by each sub-node's short form."""
      snodes_str = [child.debug_str_short() for child in self.snodes]
      return f"{self}, snodes: {snodes_str}"
  def set_last_usage(
      self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
  ) -> None:
      """
      Compute ``last_usage`` for this fused node and for each sub-node.

      The fused node uses the caller-supplied global view; the sub-nodes
      use a fresh set that only tracks usage inside this fused node.
      """
      # Set self.last_usage using the global information
      # This will be used for inter-kernel optimisations
      super().set_last_usage(future_used_buffers, mutation_real_name)

      # Set self.last_usage on the snodes
      # This will be used for optimisations within the kernel
      used_later: OrderedSet[str] = OrderedSet()
      # Walk backwards so each sub-node sees what later sub-nodes last used.
      for child in reversed(self.snodes):
          child.set_last_usage(used_later, mutation_real_name)
          used_later.update(child.last_usage)
  @cache_on_self
  def used_buffer_names(self) -> OrderedSet[str]:
      """Union of ``used_buffer_names()`` over all fused sub-nodes."""
      per_child = [child.used_buffer_names() for child in self.snodes]
      return OrderedSet.union(*per_child)
  @cache_on_self
  def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
      """Union of ``used_or_aliased_buffer_names()`` over all fused sub-nodes."""
      per_child = [child.used_or_aliased_buffer_names() for child in self.snodes]
      return OrderedSet.union(*per_child)
  def get_nodes(self) -> Sequence[BaseSchedulerNode]:
      """Return the fused sub-nodes (the stored sequence itself, not a copy)."""
      children = self.snodes
      return children
  def __repr__(self) -> str:
      """Show the concrete class name together with the fused node's name."""
      return f"{type(self).__name__}(nodes={self.get_name()})"
  @cache_on_self
  def is_reduction(self) -> bool:
      """Return True when at least one fused sub-node is a reduction."""
      for child in self.snodes:
          if child.is_reduction():
              return True
      return False
  @cache_on_self
  def is_native_matmul(self) -> bool:
      """Return True when at least one fused sub-node is a native matmul."""
      for child in self.snodes:
          if child.is_native_matmul():
              return True
      return False
  @cache_on_self
  def is_split_scan(self) -> bool:
      """Return True when at least one fused sub-node is a split scan."""
      for child in self.snodes:
          if child.is_split_scan():
              return True
      return False
  @cache_on_self
  def is_template(self) -> bool:
      """Return True when at least one fused sub-node is a template."""
      for child in self.snodes:
          if child.is_template():
              return True
      return False
  @cache_on_self
  def get_template_node(self) -> Optional[ir.TemplateBuffer]:
      """
      Return the template node of the first sub-node that is a template,
      or None when no sub-node is a template.
      """
      template_nodes = (
          child.get_template_node() for child in self.snodes if child.is_template()
      )
      # Only the first match is consumed, as with an early return.
      return next(template_nodes, None)
  def get_device(self) -> torch.device:
      """Return the device, stored as the first element of ``self.group``."""
      device = self.group[0]
      return device
  @cache_on_self
  def has_aliasing_or_mutation(self) -> bool:
      """Return True when at least one fused sub-node aliases or mutates."""
      for child in self.snodes:
          if child.has_aliasing_or_mutation():
              return True
      return False
  # None of these need to be implemented, as a FusedSchedulerNode is just an
  # abstraction for scheduling purposes
  def update_mutated_names(self, renames: dict[str, str]) -> None:
      """Not supported on fused nodes; always raises NotImplementedError."""
      raise NotImplementedError
  def add_fake_dep(self, name: Dep) -> None:
      """Not supported on fused nodes; always raises NotImplementedError."""
      raise NotImplementedError
  def can_inplace(self, read_dep: dependencies.Dep) -> bool:
      """Not supported on fused nodes; always raises NotImplementedError."""
      raise NotImplementedError
  def debug_str(self) -> str:
      """
      Longer form printout for trace logs.

      Emits a header with the sub-node types, the writes, the unmet and met
      dependencies, then each output's debug string, then the text from
      ``debug_str_extra()``.
      """
      name = self.get_name()
      node_typestr = ",".join(type(n).__name__ for n in self.snodes)
      buf = IndentedBuffer()
      # Met dependencies are the reads that are not listed as unmet.
      buf.splice(
          f"""\
{name}: {type(self).__name__}({node_typestr})
{name}.writes = {pformat(self.read_writes.writes)}
{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
{name}.outputs = [
          """
      )
      with buf.indent():
          for out in self.get_outputs():
              buf.splice(out.debug_str())
      buf.writeline("]")

      # Best effort: a failure in the extra section is logged, not raised,
      # so the header built above is still returned.
      try:
          buf.splice(self.debug_str_extra())
      except Exception:
          log.warning("Ignoring error in debug_str()", exc_info=True)

      return buf.getrawvalue().rstrip()
  @cache_on_self
  def has_side_effects(self) -> bool:
      """
      Return True when any sub-node has side effects.

      Falls back to the base-class answer when ``self.snodes`` is None.
      """
      if self.snodes is None:
          return super().has_side_effects()
      return any(child.has_side_effects() for child in self.snodes)
  class FusedMixOrderReductions(FusedSchedulerNode):
      """
      Fused node built from two sub-nodes of a mix order reduction.

      ``node1`` is always the sub-node for which
      ``MixOrderReduction.is_contiguous_node`` holds; the constructor swaps
      its arguments when needed.
      """

      def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
          # Normalise the pair so that the contiguous node comes first.
          if not MixOrderReduction.is_contiguous_node(node1):
              assert MixOrderReduction.is_contiguous_node(node2)
              node1, node2 = node2, node1

          self.node1, self.node2 = node1, node2
          all_nodes = [*node1.get_nodes(), *node2.get_nodes()]
          super().__init__(node1.scheduler, all_nodes)

          self.numel = MixOrderReduction.get_numel(self.node1)

      def sub_node_can_fuse(
          self,
          node1: BaseSchedulerNode,
          node2: BaseSchedulerNode,
          other_nodes: tuple[BaseSchedulerNode, ...],
      ):
          """
          Decide whether ``node2`` may be fused into ``node1``.

          node1 is from the current mix order reduction; node2 is another node we want to fuse in.

          other_nodes are passed in to check if fusion will introduce producer/consumer relationship
          between the inner and outer reduction. If yes, we don't fuse.
          """
          assert not isinstance(node1, FusedMixOrderReductions)
          assert not isinstance(node2, FusedMixOrderReductions)

          # When we fuse extra nodes into a FusedMixOrderReductions node,
          # we should not allow recursive mix-order reduction being
          # created.
          if not self.scheduler.can_fuse(node1, node2, allow_mix_order_reduction=False):
              return False

          # Since node1 is from the current mix order reduction, if node1 is
          # contiguous, the fused node should also be contiguous.
          if MixOrderReduction.is_contiguous_node(node1):
              if not MixOrderReduction.is_contiguous_node(node2):
                  return False

          def _union_of(name_sets):
              return OrderedSet().union(*name_sets)

          if other_nodes:
              pair = (node1, node2)
              # The pair must not depend on the other side ...
              pair_ancestors = _union_of(n.ancestors for n in pair)
              other_ops = _union_of(n.get_operation_names() for n in other_nodes)
              if pair_ancestors & other_ops:
                  return False
              # ... and the other side must not depend on the pair.
              other_ancestors = _union_of(n.ancestors for n in other_nodes)
              pair_ops = _union_of(n.get_operation_names() for n in pair)
              if other_ancestors & pair_ops:
                  return False

          if not node2.is_reduction():
              return True
          shared_memory_score = typing.cast(
              int, self.scheduler.score_fusion_memory(node1, node2, count_bytes=False)
          )
          return shared_memory_score >= self.numel

      def can_fuse_with(self, other: BaseSchedulerNode):
          """Return whether ``other`` can be fused into this mix order reduction."""
          if isinstance(other, FusedMixOrderReductions):
              # pass empty tuple for the second since the producer/consumer relationship has
              # already been checked in the first call
              return self.sub_node_can_fuse(
                  self.node1, other.node1, (self.node2, other.node2)
              ) and self.sub_node_can_fuse(self.node2, other.node2, tuple())

          # A plain node only needs to be fusible into one of the two halves.
          return self.sub_node_can_fuse(
              self.node1, other, (self.node2,)
          ) or self.sub_node_can_fuse(self.node2, other, (self.node1,))

      def fuse_with(self, other: BaseSchedulerNode):
          """Fuse ``other`` into this node and return the new FusedMixOrderReductions."""
          backend = self.scheduler.get_backend(self.node1.get_device())

          if isinstance(other, FusedMixOrderReductions):
              merged_node1 = backend.fuse(self.node1, other.node1)
              merged_node2 = backend.fuse(self.node2, other.node2)
              return FusedMixOrderReductions(merged_node1, merged_node2)

          # Prefer node1; fall back to node2 otherwise.
          if self.sub_node_can_fuse(self.node1, other, (self.node2,)):
              return FusedMixOrderReductions(backend.fuse(self.node1, other), self.node2)
          return FusedMixOrderReductions(self.node1, backend.fuse(self.node2, other))
  class ForeachKernelSchedulerNode(FusedSchedulerNode):
      """
      This is a scheduler node that consists of a set of scheduler nodes that
      have no data dependencies among them and can be executed in parallel.
      """

      def get_consumer_subnode_for(
          self, producer: BaseSchedulerNode
      ) -> Optional[BaseSchedulerNode]:
          """
          Return the sub-node recorded in ``read_to_node`` for the first output
          buffer of ``producer`` that appears there, or None if none does.
          """
          for buf in producer.get_outputs():
              if buf.get_name() in self.read_to_node:
                  return self.read_to_node[buf.get_name()]

          return None

      def get_producer_subnode_for(
          self, consumer: BaseSchedulerNode
      ) -> Optional[BaseSchedulerNode]:
          """
          Return the single sub-node that defines buffers read by ``consumer``.

          Returns None when no sub-node, or more than one sub-node, produces
          something that ``consumer`` reads.
          """
          producers = OrderedSet[BaseSchedulerNode]()
          for rd in consumer.read_writes.reads:
              # Reads that the scheduler has no buffer entry for are skipped.
              if rd.name not in self.scheduler.name_to_buf:
                  continue

              node_name = self.scheduler.name_to_buf[rd.name].defining_op_name()
              if node_name in self.name_to_node:
                  producers.add(self.name_to_node[node_name])

          # Don't permit fusion if there are multiple subnodes
          # that this consumer reads from
          if len(producers) == 1:
              return next(iter(producers))
          else:
              return None

      @classmethod
      def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
          """
          Decide whether ``producer`` and ``consumer`` can be fused, given that
          at least one of them is a foreach node.

          Raises AssertionError when neither node is a foreach node.
          """
          why = WhyNoFuse(producer, consumer)
          if producer.is_foreach() and consumer.is_foreach():
              # foreach + foreach: same length and pairwise fusible sub-nodes.
              producer = typing.cast(ForeachKernelSchedulerNode, producer)
              consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
              foreach_match = len(producer.snodes) == len(consumer.snodes)
              if not foreach_match:
                  why("foreach do not have same length")
              return foreach_match and all(
                  producer.scheduler.can_fuse(l, r)
                  for l, r in zip(producer.snodes, consumer.snodes)
              )
          elif consumer.is_foreach():
              # plain producer + foreach consumer
              if producer.is_reduction():
                  why(
                      "candidate producer is a reduction, foreach ops cannot be fused with reductions currently"
                  )
                  return False

              consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
              consumer_subnode = consumer.get_consumer_subnode_for(producer)
              if consumer_subnode is not None:
                  return consumer.scheduler.can_fuse(producer, consumer_subnode)

              why("candidate producer is not dep of any foreach consumer")
              return False

          elif producer.is_foreach():
              # foreach producer + plain consumer
              if consumer.is_reduction():
                  why(
                      "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently"
                  )
                  return False

              producer = typing.cast(ForeachKernelSchedulerNode, producer)
              producer_subnode = producer.get_producer_subnode_for(consumer)
              if producer_subnode is not None:
                  return producer.scheduler.can_fuse(producer_subnode, consumer)

              why("candidate consumer has no dep in any foreach producer")
              return False

          raise AssertionError(
              "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node"
          )

      @classmethod
      def fuse(
          cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode
      ) -> ForeachKernelSchedulerNode:
          """
          Fuse ``producer`` and ``consumer`` into a new foreach node.

          When both are foreach nodes their sub-nodes are fused pairwise.
          Otherwise the plain node is fused into the one matching sub-node of
          the foreach node and the remaining sub-nodes are kept as they are.
          """
          assert producer.is_foreach() or consumer.is_foreach()
          # The settings come from the producer when it is a foreach node,
          # otherwise from the consumer.
          if producer.is_foreach():
              producer = typing.cast(ForeachKernelSchedulerNode, producer)
              use_custom_partition_algo = producer.use_custom_partition_algo
              enable_autotune = producer.enable_autotune
          else:
              consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
              use_custom_partition_algo = consumer.use_custom_partition_algo
              enable_autotune = consumer.enable_autotune
          # prev_node_1 / prev_node_2 stay None in the foreach + foreach case,
          # which makes __init__ rebuild the group state from the sub-nodes.
          prev_node_1 = None
          prev_node_2 = None
          fused_nodes: list[BaseSchedulerNode]
          if producer.is_foreach() and consumer.is_foreach():
              producer = typing.cast(ForeachKernelSchedulerNode, producer)
              consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
              fused_nodes = [
                  FusedSchedulerNode.fuse(l, r)
                  for l, r in zip(producer.snodes, consumer.snodes)
              ]
          elif producer.is_foreach():
              producer = typing.cast(ForeachKernelSchedulerNode, producer)
              producer_subnode = producer.get_producer_subnode_for(consumer)
              fused_nodes = []
              prev_node_1 = producer
              prev_node_2 = None
              for node in producer.snodes:
                  if node is producer_subnode:
                      new_node = FusedSchedulerNode.fuse(node, consumer)
                      prev_node_2 = new_node
                      fused_nodes.append(new_node)
                  else:
                      fused_nodes.append(node)

          elif consumer.is_foreach():
              consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
              consumer_subnode = consumer.get_consumer_subnode_for(producer)
              fused_nodes = []
              prev_node_1 = consumer
              prev_node_2 = None

              for node in consumer.snodes:
                  if node is consumer_subnode:
                      new_node = FusedSchedulerNode.fuse(producer, node)
                      prev_node_2 = new_node
                      fused_nodes.append(new_node)
                  else:
                      fused_nodes.append(node)
          else:
              raise AssertionError(
                  "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node"
              )

          return cls(
              producer.scheduler,
              fused_nodes,
              use_custom_partition_algo=use_custom_partition_algo,
              prev_node_1=prev_node_1,
              prev_node_2=prev_node_2,
              enable_autotune=enable_autotune,
          )

      def __init__(
          self,
          scheduler: Scheduler,
          snodes: list[BaseSchedulerNode],
          use_custom_partition_algo: bool,
          prev_node_1: Optional[BaseSchedulerNode] = None,
          prev_node_2: Optional[BaseSchedulerNode] = None,
          enable_autotune: bool = False,
      ) -> None:
          """
          Build a foreach node over ``snodes``.

          With both ``prev_node_1`` and ``prev_node_2`` given, the dependency
          state is derived from those two nodes instead of being rebuilt from
          ``snodes``; one of the two must be a ForeachKernelSchedulerNode.
          """
          self.read_to_node = {}
          self.name_to_node = {}

          if prev_node_1 is None or prev_node_2 is None:
              super().__init__(scheduler, snodes)

              # Index the sub-nodes by the buffers they read and by the
              # operations they contain.
              for node in snodes:
                  for read in node.read_writes.reads:
                      self.read_to_node[read.name] = node

                  for name in node.get_operation_names():
                      self.name_to_node[name] = node
          else:
              # The base-class constructor is not called on this path; the
              # attributes are set by hand. read_to_node is left empty here.
              self.scheduler = scheduler
              self.snodes = snodes
              self.node = None
              self.users: list[NodeUser] = []

              self.set_read_writes(
                  dependencies.ReadWrites.merge_list(
                      [prev_node_1.read_writes, prev_node_2.read_writes]
                  )
              )

              # Drop dependencies that this node now provides itself.
              self.unmet_dependencies = (
                  OrderedSet(
                      dep
                      for dep in OrderedSet.union(
                          prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
                      )
                      if dep.name not in self.get_buffer_names()
                  )
                  - self.read_writes.writes
              )

              self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
              self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])

              if prev_node_1.is_foreach():
                  assert isinstance(prev_node_1, ForeachKernelSchedulerNode)
                  foreach_node, other_node = prev_node_1, prev_node_2
              else:
                  assert isinstance(prev_node_2, ForeachKernelSchedulerNode)
                  foreach_node, other_node = prev_node_2, prev_node_1

              # ancestors and name_to_node are taken over from the previous
              # foreach node without copying, so the updates below also change
              # that node's objects.
              self.ancestors = foreach_node.ancestors
              self.ancestors.update(other_node.ancestors)

              self.name_to_node = foreach_node.name_to_node
              for name in other_node.get_operation_names():
                  self.name_to_node[name] = other_node

              self.outputs_by_name: dict[str, SchedulerBuffer] = {
                  k: v for snode in self.snodes for k, v in snode.outputs_by_name.items()
              }

          self.use_custom_partition_algo = use_custom_partition_algo
          device = snodes[0].get_device()
          assert device
          # Every foreach node gets the same "combo_kernel" group marker,
          # paired with the device of its first sub-node.
          self.group = (device, ((sympy.Expr("combo_kernel"),),))
          self.origins = OrderedSet[torch.fx.Node]()
          self.enable_autotune = enable_autotune

      @classmethod
      def combinable_nodes(
          cls, nodes: list[BaseSchedulerNode]
      ) -> list[BaseSchedulerNode]:
          """
          Return the subset of ``nodes`` that may go into a combo kernel.

          Nop, extern, grouped, mix-order-reduction, foreach and template nodes
          are removed, as are reduction nodes when
          ``config.combo_kernels_pointwise_only`` is set. Most removals are
          reported through ``log.debug``.
          """
          extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
          if extern:
              log.debug(
                  "ComboKernels: %d external nodes are filtered %s",
                  len(extern),
                  [node.node.get_origins() for node in extern if node.node is not None],
              )
          grouped = [x for x in nodes if isinstance(x, GroupedSchedulerNode)]
          if grouped:
              log.debug(
                  "ComboKernels: %d grouped nodes are filtered",
                  len(grouped),
              )
          mix_order = [x for x in nodes if isinstance(x, FusedMixOrderReductions)]
          if mix_order:
              log.debug(
                  "ComboKernels: %d FusedMixOrderReductions nodes are filtered",
                  len(mix_order),
              )
          # Nop nodes are dropped here as well, without a log message.
          filtered_nodes = [
              x
              for x in nodes
              if not isinstance(
                  x,
                  (
                      NopKernelSchedulerNode,
                      ExternKernelSchedulerNode,
                      GroupedSchedulerNode,
                      FusedMixOrderReductions,
                  ),
              )
          ]
          foreach_nodes = [
              x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode)
          ]
          if foreach_nodes:
              log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes))
          filtered_nodes = [
              x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode)
          ]
          template_nodes = [x for x in filtered_nodes if x.is_template()]
          if template_nodes:
              log.debug(
                  "ComboKernels: %d template nodes are filtered: %s",
                  len(template_nodes),
                  template_nodes,
              )
          filtered_nodes = [x for x in filtered_nodes if x not in template_nodes]

          # Filter out reduction nodes if combo_kernels_pointwise_only is enabled
          if config.combo_kernels_pointwise_only:
              reduction_nodes = [x for x in filtered_nodes if x.is_reduction()]
              if reduction_nodes:
                  log.debug(
                      "ComboKernels: %d reduction nodes are filtered (pointwise_only mode)",
                      len(reduction_nodes),
                  )
              filtered_nodes = [x for x in filtered_nodes if not x.is_reduction()]

          return filtered_nodes

      @staticmethod
      def _default_group_nodes_for_combo_kernels(
          scheduler: Scheduler,
      ) -> list[list[BaseSchedulerNode]]:
          """
          Returns a list of lists of nodes that are to be grouped together.

          Within each batch from ``scheduler._topological_sort_nodes()`` the
          nodes are bucketed by device and then split into chunks of at most
          ``max_num_nodes``.
          """
          sorted_nodes = scheduler._topological_sort_nodes()
          grouped_nodes = []
          max_num_nodes = 8

          # Buffer names that belong to any FusedMixOrderReductions node.
          excluded_buffer_names: OrderedSet[str] = OrderedSet(
              [
                  buf_name
                  for group in sorted_nodes
                  for node in group
                  if isinstance(node, FusedMixOrderReductions)
                  for buf_name in node.get_buffer_names()
              ]
          )

          for nodes in sorted_nodes:
              # Group nodes by device first to avoid mixed-device fusion
              device_groups: dict[Optional[torch.device], list[BaseSchedulerNode]] = (
                  defaultdict(list)
              )
              for node in nodes:
                  device = node.get_device()
                  # Nodes on "mps" or "cpu" devices are never grouped.
                  if device and (device.type == "mps" or device.type == "cpu"):
                      continue

                  # exclude nodes that read from FusedMixOrderReductions output buffers
                  if node.used_buffer_names() & excluded_buffer_names:
                      continue
                  device_groups[device].append(node)

              # Chunk each device group separately
              for device_nodes in device_groups.values():
                  grouped_nodes.extend(
                      [
                          device_nodes[i : i + max_num_nodes]
                          for i in range(0, len(device_nodes), max_num_nodes)
                      ]
                  )

          return grouped_nodes

      # Grouping strategy used by group_nodes_for_combo_kernels; it can be
      # replaced through set_group_algorithm_for_combo_kernels.
      group_algorithm_for_combo_kernels: Callable[
          [Scheduler], list[list[BaseSchedulerNode]]
      ] = _default_group_nodes_for_combo_kernels

      @staticmethod
      def set_group_algorithm_for_combo_kernels(
          custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
      ) -> None:
          """Install ``custom_group_algorithm`` as the class-wide grouping strategy."""
          ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
              custom_group_algorithm
          )

      @staticmethod
      def group_nodes_for_combo_kernels(
          scheduler: Scheduler,
      ) -> list[list[BaseSchedulerNode]]:
          """Run the currently installed grouping strategy on ``scheduler``."""
          return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler)

      def mark_run(self) -> None:
          """Not supported on foreach nodes; always raises NotImplementedError."""
          raise NotImplementedError

      def codegen(self) -> None:
          """Not supported on foreach nodes; always raises NotImplementedError."""
          raise NotImplementedError

      def is_foreach(self) -> bool:
          """Always True for this class."""
          return True

      def get_subkernel_nodes(self) -> list[BaseSchedulerNode]:
          """Returns a list of nodes which comprise the combo kernel.
          These nodes may be vertically fused."""
          return list(self.snodes)

      def get_nodes(self) -> Sequence[BaseSchedulerNode]:
          """Returns all nodes contained in this kernel, unpacking fused nodes
          into their constituent scheduler nodes."""
          return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))

      def get_first_name(self) -> str:
          """Return ``get_first_name()`` of the first sub-node."""
          return self.snodes[0].get_first_name()

      def prune_redundant_deps(
          self, name_to_fused_node: dict[str, BaseSchedulerNode]
      ) -> None:
          """Prune redundant dependencies on this node, then on each sub-node."""
          _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)

          for node in self.snodes:
              node.prune_redundant_deps(name_to_fused_node)
  class GroupedSchedulerNode(BaseSchedulerNode):
      """
      A "fake" scheduler node standing in for a set of scheduler nodes that
      must stay *grouped*: no other node may be scheduled between the
      constituents, and no other node may fuse into any of them.

      This works by keeping the group's unmet dependencies as the union of
      those of its constituents. Fusion still happens among the nodes inside
      each GroupedSchedulerNode. At codegen time the group is unpacked and
      codegen is called on each constituent node.
      """

      snodes: list[BaseSchedulerNode]

      @classmethod
      def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode:
          """
          Build a group over ``snodes`` and register it in the scheduler's
          ``name_to_fused_node`` under every member name and its own name.
          """
          scheduler = snodes[0].scheduler
          assert all(node.scheduler is scheduler for node in snodes)
          group = cls(scheduler, snodes)
          fused_map = scheduler.name_to_fused_node
          for member in snodes:
              fused_map[member.get_name()] = group
          fused_map[group.get_name()] = group
          return group

      def __init__(
          self,
          scheduler: Scheduler,
          snodes: list[BaseSchedulerNode],
          temp_grouping: bool = False,
      ) -> None:
          super().__init__(scheduler)
          init_group_node(self, scheduler, snodes)
          # ``temp_grouping`` marks a "temporary" group used during some
          # passes: nodes are grouped and moved together, then flattened again
          # after the pass. It reuses the grouped unmet-dependency
          # computation and involves no fusion logic.
          self.temp_grouping = temp_grouping

      def unpack(self) -> list[BaseSchedulerNode]:
          """
          Fuse the nodes inside this group, then return them as regular nodes.

          A temporary group is returned as it is, without fusion.
          """
          if self.temp_grouping:
              return self.snodes

          fused_map = self.scheduler.name_to_fused_node
          # Point every member name back at the member and drop the group entry.
          for member in self.snodes:
              fused_map[member.get_name()] = member
          del fused_map[self.get_name()]
          return self.scheduler.fuse_nodes(self.snodes)

      def add_fake_dep(self, fake_dep: Dep) -> None:
          """Record ``fake_dep`` both as a read and as an unmet dependency."""
          updated_read_writes = self.read_writes.with_read(fake_dep)
          self.set_read_writes(updated_read_writes)
          self.unmet_dependencies.add(fake_dep)

      @cache_on_self
      def get_name(self) -> str:
          """Name of the group: the member names joined by underscores."""
          return "_".join(member.get_name() for member in self.snodes)

      def get_first_name(self) -> str:
          """Return the name of the first member."""
          first_member = self.snodes[0]
          return first_member.get_name()

      @cache_on_self
      def get_buffer_names(self) -> OrderedSet[str]:
          """Union of ``get_buffer_names()`` over all members."""
          per_member = [member.get_buffer_names() for member in self.snodes]
          return OrderedSet.union(*per_member)

      def get_outputs(self) -> list[SchedulerBuffer]:
          """Return the outputs of all members, concatenated in member order."""
          return [buf for member in self.snodes for buf in member.get_outputs()]

      @cache_on_self
      def estimate_flops(self) -> int | None:
          """
          Sum the flop estimates of the template / extern members.

          Returns None when no such member reports a truthy estimate.
          """
          # don't increment counters in fused methods so we don't double count
          flop_counts = []
          for node in self.get_nodes():
              if not (node.is_template() or node.is_extern()):
                  continue
              flops = node.estimate_flops()
              if flops:
                  flop_counts.append(flops)
          return sum(flop_counts) if flop_counts else None

      def get_nodes(self) -> Sequence[BaseSchedulerNode]:
          """Return the members (the stored sequence itself, not a copy)."""
          members = self.snodes
          return members

      def get_device(self) -> Optional[torch.device]:
          """Return the first member's device, or None for an empty group."""
          if self.snodes:
              return self.snodes[0].get_device()
          return None

      @classmethod
      def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
          """Always False."""
          # GroupedSchedulerNode cannot be fused with another node
          return False
  def pick_loop_order(
      stride_lengths: list[list[int]],
      sizes: Sequence[sympy.Expr],
      priority_idx: Sequence[int] = (),
  ) -> list[int]:
      """
      A heuristic to decide loop iteration orders.  This has not been well
      tuned and may be something we should autotune.

      Args:
          stride_lengths: one list of per-dimension strides for each entry.
          sizes: the size of each loop dimension.
          priority_idx: when non-empty, only the stride lists at these
              indices take part in the comparison.

      Returns:
          A permutation of the dimension indices. It is the reversed index
          order unless ``config.pick_loop_orders`` is set, in which case it is
          additionally sorted with ``index_cmp``.
      """

      @functools.cmp_to_key
      def index_cmp(a: int, b: int) -> int:
          # Comparator over dimension indices; a negative result sorts ``a``
          # ahead of ``b``.
          if sizes[a] == 1 or sizes[b] == 1:
              # 1-sizes don't matter, just move them to the end
              return cmp(sizes[a] == 1, sizes[b] == 1)

          # Take abs, otherwise flipped dimensions are treated as smaller
          # strides than contiguous dims
          stride_len_a = [abs(sl[a]) for sl in stride_lengths]
          stride_len_b = [abs(sl[b]) for sl in stride_lengths]

          # Vote per stride list: ``a_first`` counts the lists where b's stride
          # is 0 or a's stride is smaller than b's; ``b_first`` is the mirror
          # image. (The code sums the votes rather than requiring all of them.)
          a_first = sum(
              sl_b == 0 or sl_a < sl_b for sl_a, sl_b in zip(stride_len_a, stride_len_b)
          )
          b_first = sum(
              sl_a == 0 or sl_b < sl_a for sl_a, sl_b in zip(stride_len_a, stride_len_b)
          )
          if a_first > b_first:
              return -1
          if b_first > a_first:
              return 1

          # otherwise contiguous
          return cmp(b, a)

      # The index list is built from the first stride list before any
      # priority filtering is applied.
      order = list(reversed(range(len(stride_lengths[0]))))
      if len(priority_idx) > 0:
          # if we have priority node, only use that node's order
          # (index_cmp reads stride_lengths through its closure, so it sees
          # this rebinding)
          stride_lengths = [stride_lengths[pi] for pi in priority_idx]
      if config.pick_loop_orders:
          order.sort(key=index_cmp)
      return order
  def _replace_operation_buffer(
      orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
  ) -> None:
      """
      Put ``new_node`` in the place of ``orig_node`` in ``V.graph``.

      ``new_node`` takes over ``orig_node``'s buffer name and operation name
      and occupies ``orig_node``'s position in ``V.graph.buffers`` and
      ``V.graph.operations``; the entries registered under ``new_node``'s
      previous names are removed.
      """
      replaced_buf_name = new_node.get_name()
      orig_buf_name = orig_node.get_name()
      assert isinstance(orig_buf_name, str) and isinstance(replaced_buf_name, str)

      replaced_op_name = new_node.get_operation_name()
      orig_op_name = orig_node.get_operation_name()
      assert isinstance(orig_op_name, str) and isinstance(replaced_op_name, str)

      # Unregister new_node's old names, then rename it after orig_node.
      del V.graph.name_to_buffer[replaced_buf_name]
      new_node.name = orig_buf_name

      del V.graph.name_to_op[replaced_op_name]
      new_node.operation_name = orig_op_name

      # The index of orig_node is taken before new_node is removed from the
      # list, and the slot is then overwritten with new_node.
      # NOTE(review): this assumes new_node sits after orig_node in the list;
      # otherwise the removal shifts the index. Confirm against callers.
      orig = V.graph.buffers.index(orig_node)
      V.graph.buffers.remove(new_node)
      V.graph.buffers[orig] = new_node
      V.graph.name_to_buffer[orig_buf_name] = new_node

      # Same replacement for the operations list.
      orig = V.graph.operations.index(orig_node)
      V.graph.operations.remove(new_node)
      V.graph.operations[orig] = new_node
      V.graph.name_to_op[orig_op_name] = new_node
  def _estimate_fused_epilogue_runtime(node1, node2, epilogue_runtime) -> float:
      """
      Scale ``epilogue_runtime`` by how much extra memory the epilogue reads.

      ``node1`` supplies the written byte count and ``node2`` the read byte
      count. With no extra bytes read the result is zero; as the extra bytes
      grow relative to the written bytes the factor approaches one.
      """
      # If no extra memory read by epilogue, assume epilogue is free
      # if extra memory is read by epilogue, add to minimum choice
      bytes_read = node2.get_read_buffer_sizes()
      bytes_written = node1.get_write_buffer_sizes()
      ratio = (bytes_read - bytes_written) / bytes_written

      # Smoothly approaches 1 as the ratio increases
      scale = ratio / (1 + ratio)
      return scale * epilogue_runtime
  @dataclasses.dataclass
  class NodeUser:
      """
      A user of a scheduler node's result.

      Equality and hashing go by the user node's name together with the
      ``can_inplace`` and ``is_weak`` flags.
      """

      node: Union[BaseSchedulerNode, OutputNode]
      can_inplace: bool = False

      # A weak user must be scheduled after a given node, but doesn't actually
      # use the result
      is_weak: bool = False

      def __hash__(self) -> int:
          key = (self.get_name(), self.can_inplace, self.is_weak)
          return hash(key)

      def __eq__(self, other: object) -> bool:
          if not isinstance(other, NodeUser):
              return False
          if self.get_name() != other.get_name():
              return False
          return self.can_inplace == other.can_inplace and self.is_weak == other.is_weak

      def get_name(self) -> str:
          """Return the name of the user node."""
          return self.node.get_name()

      def merge(self, other: NodeUser) -> NodeUser:
          """
          Combine two users of the same node; a flag is kept only if both
          users have it set.
          """
          assert self.node is other.node
          return NodeUser(
              node=self.node,
              can_inplace=self.can_inplace and other.can_inplace,
              is_weak=self.is_weak and other.is_weak,
          )
  # Module-level counter; Scheduler._init takes the next value from it as
  # the scheduler's post_grad_graph_id.
  _post_grad_graph_counter = itertools.count()
  def used_non_deterministic_runtime_estimations() -> bool:
      """Return the ``config.runtime_estimations_mms_benchmark`` setting."""
      benchmark_enabled = config.runtime_estimations_mms_benchmark
      return benchmark_enabled
  def get_layout_symints(node: ir.IRNode) -> OrderedSet[sympy.Symbol]:
      """
      Get free symbols from a node's layout (size, stride, offset).

      A node without a layout yields an empty set. For a mutation layout the
      symbols of the target's layout are included as well.
      """
      symbols: OrderedSet[sympy.Symbol] = OrderedSet()
      layout = node.maybe_get_layout()

      if not isinstance(layout, ir.Layout):
          assert layout is None, f"Expect layout to be None but found layout={layout}"
          return symbols

      size_symbols = free_symbols(layout.size)
      stride_symbols = free_symbols(layout.stride)
      offset_symbols = free_symbols(layout.offset)
      symbols.update(size_symbols | stride_symbols | offset_symbols)

      if isinstance(layout, ir.MutationLayoutSHOULDREMOVE):
          # symint may be used as index in layout.target
          symbols.update(get_layout_symints(layout.target))
      return symbols
  def get_scheduler_node_symbol_uses(
      node: BaseSchedulerNode,
  ) -> OrderedSet[sympy.Symbol]:
      """
      Collect the symbols used by a scheduler node: the free symbols of its
      operation plus the layout symints of its outputs. A fused node yields
      the union over its sub-nodes.
      """
      if isinstance(node, FusedSchedulerNode):
          per_child = (get_scheduler_node_symbol_uses(child) for child in node.snodes)
          return OrderedSet().union(*per_child)

      ir_op = node.node
      assert ir_op is not None
      symbols = ir_op.get_free_symbol_uses()
      output_symints = (get_layout_symints(out) for out in ir_op.get_outputs())
      symbols.update(*output_symints)
      return symbols
  2419. def is_epilogue_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
  2420. return node1.is_template() and config.epilogue_fusion and not node2.is_template()
  2421. def is_prologue_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
  2422. return node2.is_template() and config.prologue_fusion and not node1.is_template()
  2423. def is_template_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
  2424. return is_epilogue_fusion(node1, node2) or is_prologue_fusion(node1, node2)
  2425. def template_fusion_pw_node(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
  2426. return node2 if is_epilogue_fusion(node1, node2) else node1
  2427. class Scheduler:
  2428. """
  2429. A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
  2430. optimizations such as fusion, reorder, and graph partition.
  2431. """
  2432. def __init__(self, nodes: list[ir.Operation]) -> None:
  2433. with dynamo_timed("Scheduler.__init__"):
  2434. self._init(nodes)
    def _init(self, nodes: list[ir.Operation]) -> None:
        """
        Build scheduler nodes for `nodes` and run the scheduling pipeline.

        Registers this scheduler on V.graph, wraps every IR operation in a
        scheduler node, computes dependencies, sorts topologically, removes
        dead nodes, creates foreach nodes, fuses, and then applies the
        config-gated passes (combo kernels, peak-memory reordering,
        compute/comm overlap reordering, partition reordering) before
        computing last usage. The order of the passes below matters; see the
        inline comments.
        """
        super().__init__()
        V.graph.scheduler = self
        self.backends: dict[torch.device, BaseScheduling] = {}
        self.post_grad_graph_id = next(_post_grad_graph_counter)
        self._graph_partition_counter = itertools.count()

        self.completed_operations: OrderedSet[str] = OrderedSet()
        # Starts out as graph inputs plus constants (regular and torchbind).
        self.available_buffer_names = OrderedSet(
            [
                *V.graph.graph_inputs.keys(),
                *V.graph.constants.keys(),
                *V.graph.torchbind_constants.keys(),
            ]
        )
        self.nodes = [self.create_scheduler_node(n) for n in nodes]
        self.previous_node: Optional[BaseSchedulerNode] = None
        self.current_node: Optional[BaseSchedulerNode] = None
        self.update_zero_dim_cpu_tensor()
        # some new constants could have been created above
        self.available_buffer_names.update(V.graph.constants.keys())
        for node in self.nodes:
            node.prune_deps()

        # See [Note: Graph Partition Device Contexts]
        self.default_device_context: Optional[torch.device] = None

        self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
            self.get_donated_buffers()
        )
        self.name_to_node: dict[str, BaseSchedulerNode] = {
            n.get_name(): n for n in self.nodes
        }
        self.name_to_buf: dict[str, SchedulerBuffer] = {
            buf.get_name(): buf for node in self.nodes for buf in node.get_outputs()
        }
        self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy()

        # mutation_real_name: Maps back to the original name for codegen
        # Example:
        # If you mutate buf0 inside of buf1's kernel, then
        # (as populated by compute_dependencies):
        # mutation_real_name = {"buf1" : "buf0"}
        # in codegen we only use buf0, never buf1
        self.mutation_real_name: dict[str, str] = {}

        # We handle mutation by renaming modified versions of the same
        # buffer in the dependency graph to prevent cycles.
        # mutation_renames: tracks the current name for a given buffer
        # (changed once per mutation)
        # Example:
        # If you mutate buf0 inside of buf1's kernel, then
        # (as populated by compute_dependencies):
        # mutation_renames = {"buf0" : "buf1"}
        # all subsequent uses of buf0 become buf1's usage in dependency graph
        self.mutation_renames: dict[str, str] = {}

        self.seen_template_fusions: OrderedSet[
            tuple[BaseSchedulerNode, BaseSchedulerNode]
        ] = OrderedSet()

        # Must run first to correctly set dependencies, before all other passes that rely on
        # reading from .read_writes.reads or .unmet_dependencies
        self.nodes = comms.decide_global_ordering_of_comms(
            self.nodes,
            self.name_to_buf,
            self.name_to_fused_node,
        )

        self.compute_dependencies()
        self.nodes = self.topological_sort_schedule(self.nodes)
        self.dead_node_elimination()
        # Rebuild the name -> node map from the nodes that remain after DCE.
        self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
        self.compute_ancestors()

        # pyrefly: ignore [bad-assignment]
        metrics.ir_nodes_pre_fusion += len(self.nodes)
        from torch._inductor.debug import log_ir_post_fusion, log_ir_pre_fusion

        log_ir_pre_fusion(self.nodes)
        self.num_orig_nodes = len(self.nodes)
        self.create_foreach_nodes()
        self.nodes = self.topological_sort_schedule(self.nodes)
        self.logged_slow_fusion = OrderedSet[tuple[str, str]]()
        if config._pre_fusion_custom_pass is not None:
            self.nodes = config._pre_fusion_custom_pass(self.nodes)

        if config.distributed_max_autotune_gemm:
            from . import distributed_autotune

            distributed_autotune.schedule(self)
            self.compute_ancestors()

        self.nodes = self.fuse_nodes(self.nodes)
        if config._post_fusion_custom_pass is not None:
            self.nodes = config._post_fusion_custom_pass(self.nodes)

        self.merge_loops()
        self.finalize_multi_template_buffers()
        if (
            config.max_autotune_gemm or config.max_autotune
        ) and use_pipelined_autotuning():
            torch._inductor.select_algorithm.PrecompileThreadPool.shutdown_instance()

        if config.combo_kernels:
            with dynamo_timed(
                "Scheduler.create_combo_kernel_nodes",
                log_pt2_compile_event=True,
                log_waitcounter=True,
            ):
                self.create_combo_kernel_nodes(num_ck_nodes=None)

        # Peak memory pass and overlap pass must run last, otherwise
        # other reordering passes could undo their effects.
        if config.reorder_for_peak_memory:
            from .memory import reorder_for_peak_memory

            self.nodes = reorder_for_peak_memory(
                self.nodes,
                self.name_to_buf,
                self.name_to_fused_node,
                OrderedSet(V.graph.graph_inputs.keys()),
                OrderedSet(V.graph.get_output_names()),
            )

        # reorder_for_compute_comm_overlap may do benchmarking to estimate
        # op runtime. Disable it for now in deterministic mode.
        if not config.deterministic and config.reorder_for_compute_comm_overlap:
            if not config.reorder_for_peak_memory:
                from .memory import assign_memory_planning_info_for_scheduler_buffers

                assign_memory_planning_info_for_scheduler_buffers(
                    self.nodes, self.name_to_buf
                )

            if (
                used_non_deterministic_runtime_estimations()
                and config_comms.runtime_estimations_align_across_all_distributed_ranks
                and (
                    config.runtime_estimations_mms_benchmark
                    or config_comms.runtime_estimations_use_nccl_lib_estimations
                )
            ):
                # Only align estimations when at least one node is a collective.
                has_collectives = False
                for node in self.nodes:
                    if is_collective(node.node):
                        has_collectives = True
                        break
                if has_collectives:
                    from .comms import (
                        align_runtime_estimations_across_all_distributed_ranks,
                    )

                    align_runtime_estimations_across_all_distributed_ranks(self.nodes)

            # pyrefly: ignore [unbound-name]
            if config_comms.reorder_sink_verbose_logging:
                from torch._logging import trace_structured

                trace_structured(
                    "artifact",
                    metadata_fn=lambda: {
                        "name": "scheduler_nodes_before_comm_overlap",
                        "encoding": "string",
                    },
                    payload_fn=lambda: "\n\n".join(
                        [
                            f"snode[{i}]"
                            + n.debug_str()
                            + f" buffer_names:{n.get_buffer_names()}"
                            for i, n in enumerate(self.nodes)
                        ]
                    ),
                )
            self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
        self.process_grouped_nodes()

        if (
            # pyrefly: ignore[unbound-name]
            config.graph_partition
            # pyrefly: ignore[unbound-name]
            and config.triton.cudagraphs
            # pyrefly: ignore[unbound-name]
            and config.triton.reorder_for_reducing_graph_partitions
        ):
            self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
            self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)

        self.compute_last_usage()

        if torch._inductor.config.test_configs.track_memory_lifecycle:
            self.insert_memory_check_nodes()

        log_ir_post_fusion(self.nodes)
        # pyrefly: ignore[unbound-name]
        V.debug.graph_diagram(self.nodes)
        self.debug_draw_graph()

        # used during codegen:
        self.buffer_names_to_free: OrderedSet[str] = OrderedSet()

        # fx graph node to the position it appears in the graph
        # for debug attribution
        self.origin_to_index: dict[torch.fx.Node, int] = {}

        get_metric_table("graph_stats").add_row(
            lambda: {
                "graph_id": self.post_grad_graph_id,
                "num_nodes_before_fusion": self.num_orig_nodes,
                "num_nodes_after_fusion": len(self.nodes),
            }
        )

        # Unlike V.graph.removed_buffers, the op recorded here is removed but
        # we still need the buffer (generated in alternative ways)
        self.removed_ops: OrderedSet[str] = OrderedSet()
  2618. def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]:
  2619. name_to_donated_buf = {}
  2620. for name in V.graph.graph_inputs_original:
  2621. if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
  2622. name_to_donated_buf[name] = SchedulerDonatedBuffer(
  2623. self,
  2624. V.graph.graph_inputs_original[name],
  2625. defining_op=None,
  2626. )
  2627. return name_to_donated_buf
  2628. @property
  2629. def current_device(self) -> Optional[torch.device]:
  2630. return V.graph.current_device
  2631. @current_device.setter
  2632. def current_device(self, device: Optional[torch.device]) -> None:
  2633. V.graph.current_device = device
  2634. def debug_draw_graph(self) -> None:
  2635. """Generate an image of the graph for debugging"""
  2636. if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH", None) == "1":
  2637. from .debug import draw_buffers
  2638. draw_buffers(self.nodes, print_graph=True)
  2639. def debug_print_nodes(self, label: str) -> None:
  2640. if log.isEnabledFor(logging.INFO):
  2641. log.info("%s:", label)
  2642. for node in self.nodes:
  2643. node.log_details()
  2644. def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
  2645. assert node.get_origins() is not None, (
  2646. "All nodes passed to scheduling must have an origin"
  2647. )
  2648. if node.is_no_op():
  2649. return NopKernelSchedulerNode(self, node)
  2650. elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
  2651. return SchedulerNode(self, node)
  2652. elif isinstance(node, ir.ExternKernel):
  2653. return ExternKernelSchedulerNode(self, node)
  2654. else:
  2655. raise NotImplementedError(node)
  2656. def create_foreach_nodes(self) -> None:
  2657. removed_node_names: OrderedSet[str] = OrderedSet()
  2658. fe_nodes = []
  2659. kept_node_names = self.name_to_fused_node.keys()
  2660. for names in V.graph.lists.values():
  2661. names = [
  2662. name
  2663. for name in names
  2664. if name in kept_node_names
  2665. and not isinstance(self.name_to_node[name], NopKernelSchedulerNode)
  2666. ]
  2667. if not names:
  2668. # All nodes eliminated
  2669. continue
  2670. removed_node_names.update(names)
  2671. snodes = [self.name_to_node[name] for name in names]
  2672. enable_autotune = config.combo_kernels_autotune > 1
  2673. fe_node = ForeachKernelSchedulerNode(
  2674. self,
  2675. snodes,
  2676. use_custom_partition_algo=False,
  2677. enable_autotune=enable_autotune,
  2678. )
  2679. fe_nodes.append(fe_node)
  2680. for name in names:
  2681. self.name_to_fused_node[name] = fe_node
  2682. self.nodes = [
  2683. node for node in self.nodes if node.get_name() not in removed_node_names
  2684. ] + list(fe_nodes)
  2685. def compute_dependencies(self) -> None:
  2686. """
  2687. Create dependency edges between nodes, handling aliasing and
  2688. mutation properly.
  2689. """
  2690. class DedupList(Generic[_T]):
  2691. """
  2692. This data structure behaves like a list except it makes sure the
  2693. elements remain unique.
  2694. Normally one could use a OrderedSet/dict for this purpose however
  2695. the list in question gets elements appended as it is being
  2696. iterated over which means that we need to keep the list
  2697. semantics.
  2698. """
  2699. def __init__(
  2700. self,
  2701. items: Optional[list[_T]] = None,
  2702. membership: Optional[OrderedSet[_T]] = None,
  2703. ) -> None:
  2704. self.items = items or []
  2705. self.membership = membership or OrderedSet()
  2706. def append(self, node_user: _T) -> None:
  2707. if node_user in self.membership:
  2708. return
  2709. self.items.append(node_user)
  2710. self.membership.add(node_user)
  2711. def __add__(self, other: DedupList[_T]) -> DedupList[_T]:
  2712. new_membership = OrderedSet.union(self.membership, other.membership)
  2713. new_items = self.items + [
  2714. x for x in other.items if x not in self.membership
  2715. ]
  2716. return DedupList(new_items, new_membership)
  2717. # pyrefly: ignore [not-a-type]
  2718. name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict(
  2719. DedupList
  2720. )
  2721. # handle aliasing by using python aliasing in name_to_users
  2722. # if foo aliases bar then we will make name_to_users["foo"] point
  2723. # to the same python list as name_to_users["bar"]
  2724. for node in self.nodes:
  2725. for buf1 in node.get_outputs():
  2726. buf1_name = buf1.get_name()
  2727. # This is for handling auto functionized ops which return None
  2728. # and mutate more than 1 inputs, we shouldn't let them all
  2729. # point to the same user list since buffers in the aliases
  2730. # list might not be alias to each other.
  2731. if (
  2732. isinstance(buf1.node.layout, ir.NoneLayout)
  2733. and len(buf1.get_aliases()) > 1
  2734. ):
  2735. continue
  2736. for buf2_name in buf1.get_aliases():
  2737. if buf1_name in name_to_users and buf2_name in name_to_users:
  2738. # merge the two
  2739. list1 = name_to_users[buf1_name]
  2740. list2 = name_to_users[buf2_name]
  2741. combined = list1 + list2
  2742. for key in name_to_users:
  2743. if (
  2744. name_to_users[key] is list1
  2745. or name_to_users[key] is list2
  2746. ):
  2747. name_to_users[key] = combined
  2748. elif buf1_name in name_to_users:
  2749. name_to_users[buf2_name] = name_to_users[buf1_name]
  2750. else:
  2751. name_to_users[buf1_name] = name_to_users[buf2_name]
  2752. # pyrefly: ignore [not-a-type]
  2753. def rename(n: str) -> str:
  2754. if n in self.mutation_renames:
  2755. return rename(self.mutation_renames[n])
  2756. return n
  2757. def add_user(
  2758. # pyrefly: ignore [not-a-type]
  2759. used_by_name: str,
  2760. user_node: Union[BaseSchedulerNode, OutputNode],
  2761. can_inplace: bool = False,
  2762. is_weak: bool = False,
  2763. ) -> None:
  2764. name_to_users[rename(used_by_name)].append(
  2765. NodeUser(user_node, can_inplace, is_weak)
  2766. )
  2767. # pyrefly: ignore [not-a-type]
  2768. unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {}
  2769. # NB: None means that the dependency is on an input. Don't actually
  2770. # generate a dependency because if we do, Inductor will start trying
  2771. # to free the unbacked int but that's pointless
  2772. for val in V.graph.graph_inputs.values():
  2773. if isinstance(val, sympy.Expr):
  2774. for fs in val.free_symbols:
  2775. unbacked_symbol_to_origin_node[fs] = None
  2776. elif isinstance(val, ir.TensorBox):
  2777. # We also need to add symbols from input size as well because
  2778. # AOTI doesn't lift the unbacked symints to inputs
  2779. sym_size = [s for s in val.get_size() if isinstance(s, sympy.Expr)]
  2780. for s in sym_size:
  2781. for fs in s.free_symbols:
  2782. unbacked_symbol_to_origin_node[fs] = None
  2783. has_non_input_unbacked_defs = False
  2784. for node in self.nodes:
  2785. assert node.node is not None
  2786. # unbacked symbols don't follow ordinary buffer dependencies, so
  2787. # we track their def/uses separately
  2788. unbacked_symbol_defs = sorted(
  2789. node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
  2790. )
  2791. for s in unbacked_symbol_defs:
  2792. assert isinstance(s, sympy.Symbol)
  2793. # Pick the first definer as canonical. There may be multiple
  2794. # because if a MultiOutputLayout buffer propagates an unbacked
  2795. # symint to multiple outputs, they will all claim to def it.
  2796. has_non_input_unbacked_defs = True
  2797. if s not in unbacked_symbol_to_origin_node:
  2798. unbacked_symbol_to_origin_node[s] = node.get_name()
  2799. for node in self.nodes:
  2800. log.debug("scheduling %s", node.node)
  2801. if has_non_input_unbacked_defs:
  2802. assert node.node is not None
  2803. unbacked_symbol_uses = sorted(
  2804. node.node.get_free_symbol_uses(unbacked_only=True),
  2805. key=lambda x: x.name,
  2806. )
  2807. # if a kernel takes unbacked symints, register dependencies
  2808. for s in unbacked_symbol_uses:
  2809. assert s in unbacked_symbol_to_origin_node, (
  2810. f"{s} not in {unbacked_symbol_to_origin_node}"
  2811. )
  2812. if (r := unbacked_symbol_to_origin_node[s]) is not None:
  2813. for buf in self.name_to_node[r].get_outputs():
  2814. node.add_fake_dep(StarDep(buf.get_name()))
  2815. if (
  2816. len(node.read_writes.writes) == 1
  2817. and (dep := next(iter(node.read_writes.writes)))
  2818. and isinstance(dep, MemoryDep)
  2819. ):
  2820. node_mode = dep.mode
  2821. else:
  2822. node_mode = None
  2823. # Handle output mutations
  2824. for buf in node.get_outputs():
  2825. # a node will mutate either 0 or 1 buffers
  2826. assert len(buf.get_mutations()) <= 1
  2827. for alt_name in buf.get_mutations():
  2828. alt_name = rename(alt_name)
  2829. # this node must run after the prior writer
  2830. add_user(alt_name, node)
  2831. node.add_fake_dep(StarDep(alt_name, mode=node_mode))
  2832. for user in name_to_users[alt_name].items:
  2833. if user.get_name() == node.get_name():
  2834. continue
  2835. assert isinstance(user.node, BaseSchedulerNode)
  2836. for out_buf in user.node.get_outputs():
  2837. other_name = out_buf.get_name()
  2838. # this node must run after all prior readers
  2839. other_name = rename(other_name)
  2840. # Check if the prior reader is a true alias (view) vs a clone.
  2841. # Views share underlying storage with the mutated buffer, so we
  2842. # need a real dependency (is_fake=False) to keep the view's
  2843. # buffer alive until after this mutation completes. Clones have
  2844. # independent storage, so we only need an ordering dependency
  2845. # (is_fake=True) that won't extend their buffer lifetime.
  2846. is_alias = alt_name in out_buf.get_aliases()
  2847. node.add_fake_dep(
  2848. WeakDep(
  2849. other_name,
  2850. mutating_buf=buf.get_name(),
  2851. is_fake=not is_alias,
  2852. )
  2853. )
  2854. add_user(other_name, node, is_weak=True)
  2855. for add_dep in V.graph.additional_buffer_deps[node.get_name()]:
  2856. add_user(add_dep, node, is_weak=True)
  2857. # is_fake=True because these are control dependencies for ordering only,
  2858. # they should not extend buffer lifetimes
  2859. node.add_fake_dep(WeakDep(add_dep, node.get_name(), is_fake=True))
  2860. for add_dep in V.graph.additional_star_deps[node.get_name()]:
  2861. add_user(add_dep, node, is_weak=False) # Strong dependency
  2862. node.add_fake_dep(StarDep(add_dep))
  2863. # add normal non-mutation dependencies
  2864. for read in node.read_writes.reads:
  2865. if not isinstance(read, WeakDep):
  2866. add_user(read.name, node, node.can_inplace(read))
  2867. node.update_mutated_names(self.mutation_renames)
  2868. # update our renaming scheme for the next iteration
  2869. for buf in node.get_outputs():
  2870. for alt_name in buf.get_mutations():
  2871. self.mutation_renames[rename(alt_name)] = buf.get_name()
  2872. self.mutation_renames[alt_name] = buf.get_name()
  2873. self.mutation_real_name[buf.get_name()] = (
  2874. self.mutation_real_name.get(alt_name, alt_name)
  2875. )
  2876. # make sure outputs aren't dead-code-eliminated
  2877. for buf_name in V.graph.get_output_names():
  2878. log.debug("scheduling output %s", buf_name)
  2879. add_user(buf_name, OutputNode(StarDep(buf_name)))
  2880. # make sure unbacked symints aren't dead-code-eliminated
  2881. if has_non_input_unbacked_defs:
  2882. for out in V.graph.graph_outputs:
  2883. for s in out.get_free_symbol_uses(unbacked_only=True):
  2884. assert s in unbacked_symbol_to_origin_node, (
  2885. f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
  2886. )
  2887. if r := unbacked_symbol_to_origin_node[s]:
  2888. for buf_name in self.name_to_node[r].get_buffer_names():
  2889. log.debug(
  2890. "scheduling output %s for unbacked symint %s",
  2891. buf_name,
  2892. s,
  2893. )
  2894. add_user(buf_name, OutputNode(StarDep(buf_name)))
  2895. # make sure input mutation isn't dead-code-eliminated
  2896. for name in self.mutation_renames:
  2897. if name in V.graph.graph_inputs:
  2898. add_user(name, OutputNode(StarDep(name)))
  2899. V.graph.mutated_inputs.add(name)
  2900. elif name in V.graph.constants:
  2901. # In AOTI, module parameters and buffers are not lifted as graph inputs
  2902. add_user(name, OutputNode(StarDep(name)))
  2903. inp_names = {
  2904. name: index for index, name in enumerate(V.graph.graph_inputs.keys())
  2905. }
  2906. V.graph.mutated_input_idxs = [
  2907. inp_names[name] for name in V.graph.mutated_inputs
  2908. ]
  2909. # copy users information onto the nodes
  2910. for node in self.nodes:
  2911. for buf in node.get_outputs():
  2912. buf.set_users(name_to_users[buf.get_name()].items)
  2913. for name in self.name_to_donated_buffer:
  2914. self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
  2915. # For debug logging
  2916. logbuf = IndentedBuffer()
  2917. logbuf.splice("{")
  2918. for key, value in name_to_users.items():
  2919. with logbuf.indent():
  2920. users = [v.get_name() for v in value.items]
  2921. logbuf.splice(f"'{key}': {users},")
  2922. logbuf.splice("}")
  2923. str = logbuf.getrawvalue().rstrip()
  2924. compute_dependencies_log.debug("BUFFER USER LIST\n")
  2925. compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)
  2926. def insert_memory_check_nodes(self) -> None:
  2927. from .memory import (
  2928. assign_memory_planning_info_for_scheduler_buffers,
  2929. compute_memory_timeline,
  2930. FreeableInputBuffer,
  2931. get_freeable_input_buf,
  2932. )
  2933. graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
  2934. name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = (
  2935. get_freeable_input_buf(self.nodes, graph_inputs)
  2936. )
  2937. if not torch._inductor.config.reorder_for_peak_memory:
  2938. assign_memory_planning_info_for_scheduler_buffers(
  2939. self.nodes, self.name_to_buf
  2940. )
  2941. graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
  2942. buf_info_list, _, _ = compute_memory_timeline(
  2943. self.nodes,
  2944. name_to_freeable_input_buf,
  2945. graph_outputs,
  2946. )
  2947. step_allocs_deallocs: list[tuple[list[str], list[str]]] = [
  2948. ([], []) for _ in range(len(self.nodes))
  2949. ]
  2950. for buf_info in buf_info_list:
  2951. # Skip zero-size buffers
  2952. if buf_info.size_alloc == 0 and buf_info.size_free == 0:
  2953. continue
  2954. buf_name = buf_info.buffer.get_name()
  2955. step_allocs_deallocs[buf_info.start_step][0].append(buf_name)
  2956. step_allocs_deallocs[buf_info.end_step][1].append(buf_name)
  2957. from torch._inductor.runtime.debug_utils import register_check_mem_op
  2958. register_check_mem_op()
  2959. def construct_mem_check_node(
  2960. step_idx: int, is_final_step: bool
  2961. ) -> ExternKernelSchedulerNode:
  2962. expected_newly_alive = step_allocs_deallocs[step_idx][0]
  2963. expected_newly_dead = step_allocs_deallocs[step_idx][1]
  2964. nontensor_args = [expected_newly_alive, expected_newly_dead, is_final_step]
  2965. node = ir.MemoryCheckKernel(
  2966. layout=NoneLayout(device=torch.device("cpu")),
  2967. kernel=torch.ops._inductor_debug.check_memory_step.default,
  2968. tensor_args=[],
  2969. nontensor_args=nontensor_args,
  2970. unflatten_args=lambda tensor_args, constant_args: (
  2971. tensor_args,
  2972. {
  2973. "alive": constant_args[0],
  2974. "dead": constant_args[1],
  2975. "is_final_step": constant_args[2],
  2976. },
  2977. ),
  2978. )
  2979. node.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}"
  2980. return ExternKernelSchedulerNode(self, node)
  2981. new_nodes = []
  2982. for i, node in enumerate(self.nodes):
  2983. new_nodes.append(node)
  2984. new_nodes.append(
  2985. construct_mem_check_node(i, is_final_step=(i == len(self.nodes) - 1))
  2986. )
  2987. self.nodes = new_nodes
  2988. def dead_node_elimination(self) -> None:
  2989. """
  2990. Remove any nodes without users
  2991. """
  2992. if not config.use_dce:
  2993. return
  2994. # self.nodes is in topological order, so by iterating in reverse order
  2995. # we have visited (and potentially removed) all users before visiting a
  2996. # given node.
  2997. updated_nodes = []
  2998. for node in reversed(self.nodes):
  2999. def can_eliminate_user(user: NodeUser) -> bool:
  3000. return user.is_weak or user.get_name() in V.graph.removed_operations
  3001. active_buffers = False
  3002. for buf in node.get_outputs():
  3003. can_eliminate = all(can_eliminate_user(u) for u in buf.users)
  3004. if can_eliminate:
  3005. log.debug("removed dead buffer: %s", buf.get_name())
  3006. V.graph.removed_buffers.add(buf.get_name())
  3007. else:
  3008. active_buffers = True
  3009. can_eliminate = not node.has_side_effects() and not active_buffers
  3010. if not can_eliminate:
  3011. updated_nodes.append(node)
  3012. else:
  3013. # dead code
  3014. log.debug("removed dead operation: %s", node.get_name())
  3015. V.graph.removed_operations.add(node.get_name())
  3016. for read in node.read_writes.reads:
  3017. if read.name in self.name_to_buf:
  3018. users = self.name_to_buf[read.name].users
  3019. self.name_to_buf[read.name].users = [
  3020. u for u in users if u.node.get_name() != node.get_name()
  3021. ]
  3022. self.nodes = list(reversed(updated_nodes))
  3023. # Prune any WeakDeps no longer needed
  3024. for node in self.nodes:
  3025. node.prune_weak_deps()
  3026. def mode_requires_synchronization(self, mode: Optional[str]) -> bool:
  3027. """Check if store mode requires cross-thread synchronization."""
  3028. return mode is not None # Currently all non-None modes need sync
  3029. def topological_sort_schedule(
  3030. self, nodes: list[BaseSchedulerNode]
  3031. ) -> list[BaseSchedulerNode]:
  3032. """
  3033. Ensure nodes is in topologically sorted order
  3034. """
  3035. seen = OrderedSet[BaseSchedulerNode]()
  3036. name_to_node: dict[str, BaseSchedulerNode] = dict()
  3037. result: list[BaseSchedulerNode] = []
  3038. def visit(n: BaseSchedulerNode) -> None:
  3039. if n not in seen:
  3040. seen.add(n)
  3041. for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
  3042. # We only care about doing toposort within `nodes`
  3043. if dep.name not in name_to_node:
  3044. continue
  3045. visit(name_to_node[dep.name])
  3046. result.append(n)
  3047. for node in nodes:
  3048. for name in node.get_buffer_names():
  3049. name_to_node[name] = node
  3050. for node in nodes:
  3051. visit(node)
  3052. return result
  3053. def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]:
  3054. unmet_deps: OrderedSet[str] = OrderedSet()
  3055. if isinstance(
  3056. snode,
  3057. (
  3058. SchedulerNode,
  3059. ExternKernelSchedulerNode,
  3060. NopKernelSchedulerNode,
  3061. FusedSchedulerNode,
  3062. GroupedSchedulerNode,
  3063. ),
  3064. ):
  3065. for dep in snode.unmet_dependencies:
  3066. unmet_deps.add(dep.name)
  3067. else:
  3068. raise RuntimeError(
  3069. f"get_unmet_dep_nodes is not implemented for {type(snode)}."
  3070. )
  3071. unmet_dep_ops = (self.name_to_buf[dep].defining_op_name() for dep in unmet_deps)
  3072. return list(OrderedSet(self.name_to_fused_node[n] for n in unmet_dep_ops))
  3073. def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]:
  3074. """
  3075. Sort nodes by their topological order, return a list of node lists.
  3076. """
  3077. order = []
  3078. nodes = dict.fromkeys(self.nodes, 0)
  3079. children: dict[Any, Any] = {}
  3080. for node in self.nodes:
  3081. deps = self._get_unmet_dep_nodes(node)
  3082. nodes[node] = len(deps)
  3083. for dep in deps:
  3084. c = children.get(dep, [])
  3085. c.append(node)
  3086. children[dep] = c
  3087. zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
  3088. while zero_deg_nodes:
  3089. order.append(zero_deg_nodes)
  3090. for n in zero_deg_nodes:
  3091. for user in children.get(n, []):
  3092. nodes[user] -= 1
  3093. nodes.pop(n)
  3094. zero_deg_nodes = [n for n, v in nodes.items() if v == 0]
  3095. assert not nodes, "Topological sort failed!"
  3096. return order
  3097. def compute_ancestors(self) -> None:
  3098. """
  3099. Populate each node.ancestors
  3100. """
  3101. # note self.nodes is topologically sorted
  3102. name_to_ancestors: dict[str, OrderedSet[str]] = {}
  3103. for node in self.nodes:
  3104. ancestors: OrderedSet[str] = OrderedSet()
  3105. for dep in node.unmet_dependencies:
  3106. dep_node_name = self.name_to_buf[dep.name].defining_op_name()
  3107. ancestors.add(dep_node_name)
  3108. ancestors |= name_to_ancestors[dep_node_name]
  3109. name_to_ancestors[node.get_name()] = ancestors
  3110. node.ancestors = ancestors
  3111. for order, node in enumerate(self.nodes):
  3112. node.min_order = order
  3113. node.max_order = order
  3114. def merge_loops(self) -> None:
  3115. if not config.loop_ordering_after_fusion:
  3116. return
  3117. for node in self.nodes:
  3118. # Even for CPU, if we are using the halide backend, we still need
  3119. # the merge loops steps below
  3120. if not isinstance(node, (SchedulerNode, FusedSchedulerNode)) or (
  3121. not node.is_gpu() and config.cpu_backend != "halide"
  3122. ):
  3123. continue
  3124. for snode in node.get_nodes():
  3125. # merge loops for the scheduler node
  3126. if not isinstance(snode, SchedulerNode) or snode.is_template():
  3127. continue
  3128. snode.merge_loops()
  3129. # Note that for CPU backend, merging loops will change
  3130. # snode.group. It's fine for Triton backend.
  3131. # But if we simplify update snode.group like this:
  3132. # group_fn = self.get_backend(snode.node.get_device()).group_fn
  3133. # snode.group = (snode.node.get_device(), group_fn(snode._sizes))
  3134. # There is still an issue due to different snode in a
  3135. # FusedSchedulerNode having different merged loops.
  3136. # Skip CPU backend for now.
  def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
      """
      Run up to 10 fusion rounds over `nodes` and return the resulting list.

      Rounds stop early once a round leaves the node count unchanged or
      only a single node remains. When loop reordering or loop index
      inversion is enabled in the config, one extra reorder round runs
      after the regular rounds.
      """
      with dynamo_timed(
          "Scheduler.fused_nodes", log_pt2_compile_event=True, log_waitcounter=True
      ):
          for round_idx in range(10):
              round_no = round_idx + 1
              count_before = len(nodes)
              fusion_log.debug(
                  "===== attempting fusion (%d/10): %d nodes =====",
                  round_no,
                  count_before,
              )
              nodes = self.fuse_nodes_once(nodes, is_reorder_round=False)
              count_after = len(nodes)
              fusion_log.debug(
                  "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
                  round_no,
                  count_before,
                  count_after,
              )
              # Fixed point reached, or nothing left to fuse with.
              if count_after in (count_before, 1):
                  fusion_log.debug(
                      "===== fusion complete (%d iterations) =====", round_no
                  )
                  break

          wants_reorder_round = (
              config.loop_ordering_after_fusion
              or config.loop_index_inversion_in_fusion
          )
          if wants_reorder_round:
              nodes = self.fuse_nodes_once(nodes, is_reorder_round=True)
          return nodes
  def process_grouped_nodes(self) -> None:
      """
      Flatten self.nodes: every GroupedSchedulerNode is replaced, in order,
      by the nodes returned from its unpack(); other nodes are kept as-is.
      """
      flattened: list[BaseSchedulerNode] = []
      for node in self.nodes:
          if isinstance(node, GroupedSchedulerNode):
              flattened.extend(node.unpack())
          else:
              flattened.append(node)
      self.nodes = flattened
  def benchmark_fused_nodes(
      self, nodes: Sequence[BaseSchedulerNode]
  ) -> tuple[float, str]:
      """
      Benchmark a non-empty list of nodes as one fused kernel.

      The device of the first node selects the backend (and becomes
      self.current_device); the backend's benchmark_fused_nodes result is
      returned unchanged. Callers unpack it as (time in milliseconds, path).
      """
      assert len(nodes) > 0
      target_device = nodes[0].get_device()
      self.current_device = target_device
      device_backend = self.get_backend(target_device)
      with dynamo_timed(
          "benchmark_fused_nodes",
          log_pt2_compile_event=True,
          dynamo_compile_column_us="compile_time_autotune_time_us",
      ):
          return device_backend.benchmark_fused_nodes(nodes)
  def generate_kernel_code_from_nodes(
      self,
      nodes: Sequence[BaseSchedulerNode],
      benchmark_kernel: bool,
      hint_override: Optional[int] = None,
  ) -> str:
      """
      Generate kernel source code for a non-empty list of nodes.

      The device of the first node selects the backend (and becomes
      self.current_device). `benchmark_kernel` and `hint_override` are
      forwarded to the backend unchanged, and the string it produces is
      returned.
      """
      assert len(nodes) > 0
      target_device = nodes[0].get_device()
      self.current_device = target_device
      device_backend = self.get_backend(target_device)
      with dynamo_timed("generate_kernel_code_from_nodes"):
          return device_backend.generate_kernel_code_from_nodes(
              nodes, benchmark_kernel, hint_override=hint_override
          )
  def benchmark_codegened_module(
      self, module: ModuleType, device: torch.device
  ) -> tuple[float, str]:
      """
      Benchmark an already generated kernel module on `device`.

      Sets self.current_device to `device`, then delegates to that device's
      backend and returns its result unchanged. Callers unpack it as
      (time in milliseconds, path).
      """
      self.current_device = device
      device_backend = self.get_backend(device)
      with dynamo_timed("benchmark_codegened_module"):
          return device_backend.benchmark_codegened_module(module)
  def _has_layout_conflict_for_template(
      self, multi_node: ir.MultiTemplateBuffer
  ) -> bool:
      """
      Report whether choosing a Triton template for `multi_node` would clash
      with the graph's buffer layout constraints.

      Side effect: an input that still has a FlexibleLayout and has a
      constraint registered is frozen to the constrained strides.

      Returns True on the first input whose FixedLayout differs from its
      constraint (the caller should then fall back to ATen), else False.
      """
      layout_constraints = V.graph.buffer_layout_constraints
      if not layout_constraints:
          return False

      log.debug("Node %s has constraints %s", multi_node, layout_constraints)

      for template_input in multi_node.inputs:
          # pyrefly: ignore [missing-attribute]
          input_name = template_input.get_name()
          # Inputs without a layout, or without a constraint, cannot conflict.
          if not getattr(template_input, "layout", None):
              continue
          if input_name not in layout_constraints:
              continue

          current_layout = template_input.layout
          required_layout = layout_constraints[input_name]

          if isinstance(current_layout, ir.FlexibleLayout):
              # Freeze to the expected layout to avoid conflicts
              # pyrefly: ignore [missing-attribute]
              template_input.freeze_layout_with_exact_strides(required_layout.stride)
              current_layout = template_input.layout

          if (
              isinstance(current_layout, ir.FixedLayout)
              and required_layout != current_layout
          ):
              # Layout already frozen to a different layout - conflict
              log.warning(
                  "Layout conflict detected for %s: template expects %s but layout is frozen to %s",
                  input_name,
                  required_layout,
                  current_layout,
              )
              return True

      return False
  def finalize_multi_template_buffers(self) -> None:
      """
      Finalize a backing choice for MultiTemplateBuffers which did not already have a
      choice finalized through fusion. In the case of an extern choice, this will result
      in replacing the SchedulerNode.

      If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice
      will force completion of compilation and benchmarking.

      Raises:
          AssertionError: if the selected Triton template conflicts with the
              layout constraints and there is no ExternKernelCaller choice
              to fall back to.
      """
      for i, node in enumerate(self.nodes):
          if isinstance(node, SchedulerNode) and isinstance(
              node.node, ir.MultiTemplateBuffer
          ):
              multi_node = node.node
              if not config.test_configs.force_extern_kernel_in_multi_template:
                  min_node_unfused, _ = multi_node.get_min_choice()
              else:
                  # Test configuration: take the first extern kernel choice.
                  min_node_unfused = next(
                      (
                          timing
                          for timing in multi_node.choice_timings()
                          if isinstance(
                              timing,
                              torch._inductor.select_algorithm.ExternKernelCaller,
                          )
                      ),
                  )

              if isinstance(
                  min_node_unfused,
                  torch._inductor.ir.TritonTemplateCallerBase,
              ):
                  # Check for layout conflicts before committing to Triton template
                  if self._has_layout_conflict_for_template(multi_node):
                      # Fall back to first ExternKernelCaller (ATen).
                      # for/else is used so that "no extern choice" is reported
                      # with the intended message even when choice_timings()
                      # is empty; asserting on the loop variable afterwards
                      # would hit an unbound (or stale) name in that case.
                      for choice in multi_node.choice_timings():
                          if isinstance(
                              choice,
                              torch._inductor.select_algorithm.ExternKernelCaller,
                          ):
                              min_node_unfused = choice
                              break
                      else:
                          raise AssertionError(
                              "No extern kernel detected to fallback to when layout constraints fail for Triton templates"
                          )

              # Re-check: the fallback above may have swapped in an extern kernel.
              if isinstance(
                  min_node_unfused,
                  torch._inductor.ir.TritonTemplateCallerBase,
              ):
                  # pyrefly: ignore [unbound-name]
                  if config.multi_kernel_hints:
                      # One Triton caller per hint; key None holds the choice
                      # selected above.
                      callers: dict[Optional[int], TritonTemplateCallerBase] = {}
                      callers[None] = min_node_unfused

                      # pyrefly: ignore [unbound-name]
                      for hint in config.multi_kernel_hints:
                          timings = multi_node.choice_timings(hint_override=hint)
                          triton_timings = {
                              k: v
                              for k, v in timings.items()
                              if isinstance(k, TritonTemplateCallerBase)
                          }
                          choice = min(triton_timings.items(), key=lambda x: x[1])[0]
                          callers[hint] = choice

                      node.node.finalize_as_triton_callers(callers)
                  else:
                      node.node.finalize_as_triton_caller(min_node_unfused)
                  continue

              # Extern choice: build its output buffer and replace the
              # scheduler node with one created for that buffer.
              with ir.IRNode.current_origins(multi_node.origins):
                  out_tensorbox = min_node_unfused.output_node()
              out_storage = out_tensorbox.data  # type: ignore[union-attr]
              assert isinstance(out_storage, ir.StorageBox)
              out_buffer = out_storage.data
              assert isinstance(out_buffer, ir.OperationBuffer)

              if multi_node.origin_node:
                  assign_origin_node(out_tensorbox, multi_node.origin_node)

              out_buffer.layout = multi_node.layout
              self._replace_node(out_buffer, multi_node, i, node)
  def _replace_node(
      self,
      out_buffer: ir.OperationBuffer,
      multi_node: ir.MultiTemplateBuffer,
      i: int,
      node: SchedulerNode,
  ) -> None:
      """
      Swap the scheduler node at index `i` for a new one built from `out_buffer`.

      `multi_node` is replaced by `out_buffer` via _replace_operation_buffer,
      the scheduler's name lookups are redirected to the new node, and the
      old node's mutation renames, output users, ordering, ancestors and
      last-usage info are carried over.
      """
      _replace_operation_buffer(multi_node, out_buffer)
      replacement = self.create_scheduler_node(out_buffer)
      self.nodes[i] = replacement
      old_name = node.get_name()
      self.name_to_node[old_name] = replacement
      self.name_to_fused_node[old_name] = replacement

      # We need to reflect the mutation renames that were recorded in the original node
      renames = {
          real_name: dep.name
          for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies)
          if (real_name := self.mutation_real_name.get(dep.name, None))
      }

      replacement.unmet_dependencies = OrderedSet(
          dep.rename(renames) for dep in replacement.unmet_dependencies
      )
      replacement.read_writes.reads = OrderedSet(
          dep.rename(renames) for dep in replacement.read_writes.reads
      )

      # Old output names now resolve to the new outputs, which inherit the users.
      for fresh_out, stale_out in zip(replacement.get_outputs(), node.get_outputs()):
          self.name_to_buf[stale_out.get_name()] = fresh_out
          fresh_out.users = stale_out.users

      # Scheduling metadata is copied over verbatim.
      for attr in ("min_order", "max_order", "ancestors", "last_usage"):
          setattr(replacement, attr, getattr(node, attr))
  3370. def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
  3371. return any(
  3372. hasattr(n.node, "data")
  3373. and n.node is not None
  3374. and hasattr(n.node.data, "scatter_mode")
  3375. and n.node.data.scatter_mode == "atomic_add"
  3376. for n in node_list
  3377. )
  def compile_kernel(
      self, nodes: Sequence[BaseSchedulerNode], hint_override: Optional[int] = None
  ) -> tuple[Optional[LambdaFuture], ModuleType]:
      """
      Generate benchmark kernel code for `nodes` and load it as a module.

      Returns (future, module). The future is the result of submitting the
      source to AsyncCompile.triton when a process pool is in use, and None
      otherwise.
      """
      src_code = self.generate_kernel_code_from_nodes(
          nodes, benchmark_kernel=True, hint_override=hint_override
      )
      mod = PyCodeCache.load(src_code)

      compiler = torch._inductor.async_compile.AsyncCompile()
      fut: Optional[LambdaFuture] = None
      if compiler.use_process_pool():
          fut = compiler.triton(kernel_name="triton_", source_code=src_code)
          assert isinstance(fut, LambdaFuture)
      return (fut, mod)
  def speedup_by_fusion(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> FusionResult:
      """
      Decide whether fusing node1 with node2 should go ahead.

      The answer is returned as a FusionResult, which is either:
        - an immediate decision built with FusionResult.fuse(bool), or
        - a deferred decision built with FusionResult.from_callable(...),
          whose callable waits for the compiled kernels, evaluates them and
          returns True if the fusion should happen.

      Immediate FusionResult.fuse(True) is returned without benchmarking when
      config.benchmark_fusion is False and neither node is a
      MultiTemplateBuffer template, and for the unsupported cases handled
      below (non-Triton template in node1, foreach nodes, CPU device with a
      non-triton backend, kernels using atomic_add).

      In the MultiTemplateBuffer branch, a successful deferred callable also
      finalizes the template buffer as the chosen Triton caller(s).
      """
      is_multi_template = any(
          n.is_template()
          and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
          for n in (node1, node2)
      )
      if not config.benchmark_fusion and not is_multi_template:
          return FusionResult.fuse(True)

      if (
          node1.is_template()
          and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
          or node1.is_foreach()
          or node2.is_foreach()
      ):
          # TODO support benchmarking epilogue fusion
          return FusionResult.fuse(True)

      node_list_1 = node1.get_nodes()
      device = node_list_1[0].get_device()
      assert device

      # don't support benchmark fusion for CPU C++ backend right now.
      if device.type == "cpu" and config.cpu_backend != "triton":
          return FusionResult.fuse(True)

      node_list_2 = node2.get_nodes()
      node_list_fused = list(itertools.chain(node_list_1, node_list_2))

      # We can not accurately benchmark kernel using atomic_add
      # due to how we generate random integer inputs.
      # Skip benchmarking them by allowing fusion.
      if self._any_atomic_add(node_list_fused):
          return FusionResult.fuse(True)

      from triton.compiler.errors import CompilationError

      why = WhyNoFuse(node1, node2)

      device = node_list_fused[0].get_device()
      assert device is not None

      def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None:
          # Debug-only report comparing the fused time with the sum of the
          # two unfused times.
          if fusion_log.isEnabledFor(logging.DEBUG):
              if ms_fused < ms1 + ms2:
                  fusion_log.debug(
                      "can fuse (benchmark): fusing %s with %s cause %sx speedup",
                      node1.get_buffer_names(),
                      node2.get_buffer_names(),
                      green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
                  )
              else:
                  fusion_log.debug(
                      "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
                      node1.get_buffer_names(),
                      node2.get_buffer_names(),
                      red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
                  )

      if is_multi_template and any(
          n.get_template_node() is not None for n in (node1, node2)
      ):
          # node1 carrying the template means node2 is fused after it
          # (epilogue); otherwise node2 carries it (prologue).
          epilogue_fusion = node1.get_template_node() is not None
          multi_node = (
              node1.get_template_node()
              if epilogue_fusion
              else node2.get_template_node()
          )
          assert isinstance(multi_node, ir.MultiTemplateBuffer)

          # Check for layout conflicts before committing to Triton template
          if self._has_layout_conflict_for_template(multi_node):
              return FusionResult.fuse(False)

          # Best fused Triton choice per hint override; the None key is
          # filled in later by benchmark_when_ready.
          hint_override_best_fusion_choice: dict[
              Optional[int], TritonTemplateCallerBase
          ] = {}
          future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
          # NOTE(review): future_choices is created once before this loop and
          # only appended to, so each later hint_override iteration also
          # re-benchmarks the entries compiled for earlier hints - confirm
          # this is intended.
          for hint_override in config.multi_kernel_hints:
              choice_timings = multi_node.choice_timings(hint_override)
              for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]):
                  if not isinstance(
                      choice, torch._inductor.select_algorithm.TritonTemplateCaller
                  ):
                      continue
                  with multi_node.swap_as_triton_caller(choice):
                      future_choices.append(
                          (
                              choice,
                              *self.compile_kernel(
                                  node_list_fused, hint_override=choice.hint_override
                              ),
                          )
                      )

              min_ms_fused = float("inf")
              ms_fused_choice: Optional[TritonTemplateCallerBase] = None
              new_timings = {}
              for choice, future, mod_fused in future_choices:
                  try:
                      if future is not None:
                          future.result()
                  except Exception as e:
                      # Choices that fail to compile are skipped.
                      if fusion_log.isEnabledFor(logging.DEBUG):
                          fusion_log.debug(  # noqa: G200
                              "Exception in compiling %s: %s",
                              "prologue" if not epilogue_fusion else "epilogue",
                              str(e),
                          )
                      continue
                  with multi_node.swap_as_triton_caller(choice):
                      ms_fused, path = self.benchmark_codegened_module(
                          mod_fused, device
                      )
                      new_timings[choice] = ms_fused
                      if ms_fused < min_ms_fused:
                          min_ms_fused = ms_fused
                          ms_fused_choice = choice
              multi_node._choice_timings[hint_override] = new_timings
              assert isinstance(ms_fused_choice, TritonTemplateCallerBase)
              hint_override_best_fusion_choice[hint_override] = ms_fused_choice

          bench_epilogue = config.benchmark_epilogue_fusion
          num_triton_callers = sum(
              isinstance(c, TritonTemplateCallerBase) for c in multi_node.choices
          )
          # Track if the choice timings can be retrieved async after compilation
          get_choice_timings_async = (
              use_pipelined_autotuning()
              and not bench_epilogue
              and num_triton_callers <= config.max_epilogue_benchmarked_choices
          )

          # ms1: time of the best unfused template choice; ms2: time of the
          # other (non-template) side. Both start at infinity.
          ms1, ms2 = float("inf"), float("inf")
          min_choice: ir.ChoiceCaller | None = None
          if not get_choice_timings_async:
              # Eagerly compile and benchmark non-template nodes
              choice_timings = multi_node.choice_timings()
              min_choice, ms1 = multi_node.get_min_choice()
              choice_timings_iter = sorted(
                  choice_timings.items(), key=operator.itemgetter(1)
              )
          else:
              # Use 0 for unfused time, won't be used as bench_epilogue
              # is guaranteed to be False here
              choice_timings_iter = [(c, 0) for c in multi_node.choices]

          if bench_epilogue:
              # Measure the non-template side by actually benchmarking it.
              ms2, path2 = (
                  self.benchmark_fused_nodes(node_list_2)
                  if epilogue_fusion
                  else self.benchmark_fused_nodes(node_list_1)
              )
          else:
              # By default, don't do prologue fusion. Generally slower
              if not epilogue_fusion:
                  return FusionResult.fuse(False)

              # Without benchmarking, rely on runtime estimates instead.
              ms2 = node2._get_estimated_runtime()
              ms2_fused = _estimate_fused_epilogue_runtime(node1, node2, ms2)

          # Start compiling choices in parallel
          future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
          triton_choices = 0
          for choice, unfused_time in choice_timings_iter:
              if not isinstance(choice, TritonTemplateCallerBase):
                  continue

              # For prologue fusion we check if the underlying template of the choice
              # supports all allowed prologue inputs. If not, we skip this choice in
              # the fusion benchmark.
              # TODO: Remove this check after all Triton templates support prologue fusion.
              # Currently, persistent+TMA Triton template does not due to the TMA-based loads.
              if (
                  not epilogue_fusion
                  and hasattr(choice, "allowed_prologue_inps")
                  and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps
              ):
                  continue

              # When benchmarking, choices are visited in ascending unfused
              # time, so once one is no better than the unfused baseline the
              # remaining ones are not either.
              if bench_epilogue and unfused_time >= ms1 + ms2:
                  break

              triton_choices += 1
              if triton_choices > config.max_epilogue_benchmarked_choices:
                  break

              with multi_node.swap_as_triton_caller(choice):
                  future_choices.append(
                      (choice, *self.compile_kernel(node_list_fused))
                  )

          if len(future_choices) == 0:
              return FusionResult.fuse(False)

          def benchmark_when_ready() -> bool:
              # Deferred decision for the template case: pick a fused Triton
              # choice, finalize multi_node with it and return True, or
              # return False if no choice qualifies.
              nonlocal choice_timings, future_choices, ms1, min_choice, multi_node
              min_ms_fused = float("inf")
              ms_fused_choice = None
              new_timings = {}

              if get_choice_timings_async:
                  # Timings were not fetched eagerly above; fetch them now and
                  # visit the compiled choices in ascending unfused time.
                  assert multi_node and isinstance(multi_node, ir.MultiTemplateBuffer)
                  choice_timings = multi_node.choice_timings()
                  min_choice, ms1 = multi_node.get_min_choice()

                  future_choices = sorted(
                      future_choices,
                      key=lambda x: choice_timings[x[0]],
                  )

              # Benchmark each choice after compilation completes
              for choice, future, mod_fused in future_choices:
                  try:
                      if future is not None:
                          res = future.result()
                      elif not bench_epilogue:
                          res = mod_fused.triton_
                          res.precompile()
                      else:
                          res = None

                  # Ideally we would more narrowly catch Exceptions here but
                  # triton will unpredictably error with valid prologue fusions
                  except Exception as e:
                      if fusion_log.isEnabledFor(logging.DEBUG):
                          fusion_log.debug(  # noqa: G200
                              "Exception in compiling %s: %s",
                              "prologue" if not epilogue_fusion else "epilogue",
                              str(e),
                          )
                      continue

                  if bench_epilogue:
                      # Measured path: keep the fastest fused choice.
                      # pyrefly: ignore [missing-attribute]
                      with multi_node.swap_as_triton_caller(choice):
                          ms_fused, path = self.benchmark_codegened_module(
                              mod_fused,
                              # pyrefly: ignore [bad-argument-type]
                              device,
                          )
                          new_timings[choice] = ms_fused
                          if ms_fused < min_ms_fused:
                              min_ms_fused = ms_fused
                              ms_fused_choice = choice
                  else:
                      # Estimated path: accept the first choice that is either
                      # the unfused minimum or estimated faster when fused,
                      # and that compiled to a single launcher with at most
                      # 8 spills.
                      fusible_choice = (
                          min_choice == choice
                          or ms2 + ms1 > choice_timings[choice] + ms2_fused
                      )
                      if (
                          res
                          # pyrefly: ignore [missing-attribute]
                          and len(res.launchers) == 1
                          # pyrefly: ignore [bad-index]
                          and res.launchers[0].n_spills <= 8
                          and fusible_choice
                      ):
                          ms_fused_choice = choice
                          break

              if bench_epilogue:
                  log_fusion(min_ms_fused, ms1, ms2)

              if (
                  not bench_epilogue or min_ms_fused < (ms1 + ms2)
              ) and ms_fused_choice is not None:
                  if config.multi_kernel_hints:
                      hint_override_best_fusion_choice[None] = ms_fused_choice
                      # pyrefly: ignore [missing-attribute]
                      multi_node.finalize_as_triton_callers(
                          hint_override_best_fusion_choice
                      )
                  else:
                      # pyrefly: ignore [missing-attribute]
                      multi_node.finalize_as_triton_caller(ms_fused_choice)

                  # pyrefly: ignore [missing-attribute]
                  multi_node._choice_timings[None] = new_timings
                  return True
              else:
                  return False

          return FusionResult.from_callable(
              benchmark_when_ready, future_choices[0][1]
          )

      else:
          # Start parallel compilation for all three kernels
          future_and_mod_l1 = self.compile_kernel(node_list_1)
          future_and_mod_l2 = self.compile_kernel(node_list_2)
          future_and_mod_l1_fused = self.compile_kernel(node_list_fused)

          def benchmark_when_ready() -> bool:
              # Deferred decision for the non-template case: fuse only if the
              # fused kernel is faster than the two separate kernels combined.
              from torch._inductor.runtime.triton_heuristics import (
                  NoTritonConfigsError,
              )

              try:
                  # Wait for all compilations to complete
                  for fut in (
                      future_and_mod_l1[0],
                      future_and_mod_l2[0],
                      future_and_mod_l1_fused[0],
                  ):
                      if fut is not None:
                          fut.result()

                  ms1, path1 = self.benchmark_codegened_module(
                      future_and_mod_l1[1],
                      # pyrefly: ignore [bad-argument-type]
                      device,
                  )
                  # An infinite time is treated as register spilling.
                  if math.isinf(ms1):
                      why("register spilling of the first kernel")
                      return False

                  ms2, path2 = self.benchmark_codegened_module(
                      future_and_mod_l2[1],
                      # pyrefly: ignore [bad-argument-type]
                      device,
                  )
                  if math.isinf(ms2):
                      why("register spilling of the second kernel")
                      return False

                  ms_fused, path_fused = self.benchmark_codegened_module(
                      future_and_mod_l1_fused[1],
                      # pyrefly: ignore [bad-argument-type]
                      device,
                  )
                  if math.isinf(ms_fused):
                      why("register spilling of the fused kernel")
                      return False

                  log_fusion(ms_fused, ms1, ms2)

                  # Record each slow fusion once, keyed by the two kernel paths.
                  if (
                      is_metric_table_enabled("slow_fusion")
                      and ms_fused >= ms1 + ms2
                      and (path1, path2) not in self.logged_slow_fusion
                  ):
                      self.logged_slow_fusion.add((path1, path2))
                      get_metric_table("slow_fusion").add_row(
                          lambda: {
                              "kernel1_path": path1,
                              "kernel1_latency": ms1,
                              "kernel2_path": path2,
                              "kernel2_latency": ms2,
                              "fused_kernel_path": path_fused,
                              "fused_kernel_latency": ms_fused,
                              "slow_down_ratio": ms_fused / (ms1 + ms2),
                          }
                      )

                  return ms_fused < ms1 + ms2

              except NoTritonConfigsError:
                  return False

              except CompilationError as e:
                  # This specific Triton compilation error is answered with
                  # "fuse"; any other compilation error propagates.
                  if "Loop-carried variable" in str(e):
                      return True
                  raise

          return FusionResult.from_callable(
              callable_fn=benchmark_when_ready, future=future_and_mod_l1_fused[0]
          )
  def get_fused_node(self, node: BaseSchedulerNode) -> BaseSchedulerNode:
      """Return the entry of self.name_to_fused_node keyed by the node's first name."""
      first_name = node.get_first_name()
      return self.name_to_fused_node[first_name]
  def fuse_two_nodes(
      self,
      node1: BaseSchedulerNode,
      node2: BaseSchedulerNode,
      fused_nodes: OrderedSet[BaseSchedulerNode],
  ) -> BaseSchedulerNode:
      """
      Fuse node1 and node2 with their device's backend and return the result.

      Both nodes must be on the same device. `fused_nodes` is updated in
      place (the two inputs are removed, the fused node added), and every
      member of the fused node is re-pointed to it in self.name_to_fused_node.
      """
      fusion_log.debug("fusing %s with %s", node1.get_name(), node2.get_name())

      shared_device = node1.get_device()
      assert node2.get_device() == shared_device
      merged = self.get_backend(shared_device).fuse(node1, node2)

      fused_nodes.remove(node1)
      fused_nodes.remove(node2)
      fused_nodes.add(merged)

      self.name_to_fused_node.update(
          dict.fromkeys((member.get_name() for member in merged.get_nodes()), merged)
      )
      return merged
  def fuse_if_speedup(
      self,
      node1: BaseSchedulerNode,
      node2: BaseSchedulerNode,
      speedup_fn: Callable[[], bool],
      fused_nodes: OrderedSet[BaseSchedulerNode],
  ):
      """
      Fuse node1 with node2 if the fusion is legal, acyclic and profitable.

      The checks run in that order and stop at the first failure, so
      `speedup_fn` is only called for a legal, cycle-free pair.

      Returns True if the nodes were fused, False otherwise.
      """
      if not self.can_fuse(node1, node2):
          return False
      if self.will_fusion_create_cycle(node1, node2):
          return False
      if not speedup_fn():
          return False

      self.fuse_two_nodes(node1, node2, fused_nodes)
      return True
  def _evaluate_pending_template_fusions(
      self,
      template_fusion_candidates: dict[BaseSchedulerNode, list[PendingFusion]],
      fused_nodes: OrderedSet[BaseSchedulerNode],
  ) -> None:
      """
      Evaluate pending template fusions for a set of fusion candidate nodes.
      The fusion candidate nodes are pointwise nodes as potential epilogue
      or prologue fusions

      Each pass of the outer loop takes the first pending fusion of every
      remaining candidate. A candidate is dropped from
      `template_fusion_candidates` once it has been fused or has no pending
      fusions left, so the dict is emptied by the time this returns.
      `fused_nodes` is updated in place by the fusions that succeed.
      """
      while template_fusion_candidates:
          template_futures: list[Future] = []
          future_to_pending_fusion: dict[
              Future, tuple[PendingFusion, BaseSchedulerNode]
          ] = {}
          # Candidates to delete after this pass; collected separately so the
          # dict is not resized while it is being iterated.
          fusions_to_remove: OrderedSet[BaseSchedulerNode] = OrderedSet()

          for candidate in template_fusion_candidates:
              assert (
                  candidate in template_fusion_candidates
                  and len(template_fusion_candidates[candidate]) >= 1
              )
              # Take this candidate's next pending fusion from the front.
              pending_fusion = template_fusion_candidates[candidate].pop(0)
              if len(template_fusion_candidates[candidate]) == 0:
                  fusions_to_remove.add(candidate)

              # The candidate is node2 for an epilogue fusion and node1 for a
              # prologue fusion; the other node is the template.
              node1, node2 = pending_fusion.get_fusion_nodes()
              if node2 == candidate:
                  assert is_epilogue_fusion(node1, node2)
                  template_node = node1
              else:
                  assert node1 == candidate
                  assert is_prologue_fusion(node1, node2)
                  template_node = node2

              # template node fused with same class of pointwise (prologue/epilogue)
              # move onto next candidate as not fusible
              # TODO (PaulZhang12): Does not support fusions of templates with
              # multiple potential epilogues
              if self.get_fused_node(template_node) is not template_node:
                  continue

              if pending_fusion.future:
                  # Async path: defer until the compile future completes.
                  f = pending_fusion.future.future
                  assert f is not None
                  template_futures.append(f)
                  future_to_pending_fusion[f] = (pending_fusion, candidate)
              else:
                  # Non AsyncCompile path, perform fusion
                  if self.fuse_if_speedup(
                      node1, node2, pending_fusion.callable_fn, fused_nodes
                  ):
                      fusions_to_remove.add(candidate)

          # Evaluate fusion candidates as async_compile completes
          for f in as_completed(template_futures):
              pending_fusion, cand = future_to_pending_fusion[f]
              # Look the fused nodes up again: an earlier fusion in this pass
              # may already have absorbed either side.
              if self.fuse_if_speedup(
                  self.get_fused_node(pending_fusion.node1),
                  self.get_fused_node(pending_fusion.node2),
                  pending_fusion.callable_fn,
                  fused_nodes,
              ):
                  fusions_to_remove.add(cand)

          for f in fusions_to_remove:
              template_fusion_candidates.pop(f)
  def _try_fusion_pairs(
      self,
      possible_fusion_pairs: list[tuple[BaseSchedulerNode, BaseSchedulerNode]],
      pending_fusions: dict[BaseSchedulerNode, PendingFusion],
      template_fusion_nodes: dict[BaseSchedulerNode, list[PendingFusion]],
      fused_nodes: OrderedSet[BaseSchedulerNode],
      is_reorder_round: bool,
  ) -> None:
      """
      Walk the candidate pairs in order and fuse, defer, or skip each one.

      For every legal, cycle-free pair, speedup_by_fusion decides the outcome:
        - an immediate "fuse" result fuses the pair right away;
        - a deferred result (one with a callable) is recorded instead: for
          template fusions in `template_fusion_nodes`, keyed by the node
          returned from template_fusion_pw_node, otherwise in
          `pending_fusions` under both nodes.

      `pending_fusions`, `template_fusion_nodes`, `fused_nodes` and
      self.seen_template_fusions are all updated in place.
      """

      def resolve_pending_fusions(
          node1: BaseSchedulerNode,
          node2: BaseSchedulerNode,
      ) -> None:
          # Settle every pending fusion involving the current fused node of
          # node1 or node2 before the pair itself is considered.
          while (
              self.get_fused_node(node1) in pending_fusions
              or self.get_fused_node(node2) in pending_fusions
          ):
              pending_fusion = pending_fusions.get(
                  self.get_fused_node(node1),
                  pending_fusions.get(self.get_fused_node(node2)),
              )
              assert pending_fusion is not None

              node_key1, node_key2 = pending_fusion.get_fusion_nodes()
              is_speedup = pending_fusion.callable_fn
              # Drop both entries first so the loop makes progress whether or
              # not the fusion goes ahead.
              pending_fusions.pop(node_key1, None)
              pending_fusions.pop(node_key2, None)

              assert self.get_fused_node(node_key1) is node_key1
              assert self.get_fused_node(node_key2) is node_key2

              # NOTE(review): the cycle check uses the enclosing node1/node2
              # rather than node_key1/node_key2, the pair actually being
              # fused - confirm this is intended.
              if not is_speedup() or self.will_fusion_create_cycle(node1, node2):
                  continue

              self.fuse_two_nodes(node_key1, node_key2, fused_nodes)

      for node1, node2 in possible_fusion_pairs:
          # if either node is in a pending fusion, resolve it.
          # since we iterate on potential fusions based on profitability
          # the first potential fusion should take precedence.
          resolve_pending_fusions(node1, node2)
          node1 = self.get_fused_node(node1)
          node2 = self.get_fused_node(node2)

          # A template pair that was already queued once is not retried.
          if (
              is_template_fusion(node1, node2)
              and (node1, node2) in self.seen_template_fusions
          ):
              continue

          if self.can_fuse(
              node1, node2, is_reorder_round
          ) and not self.will_fusion_create_cycle(node1, node2):
              fusion_res = self.speedup_by_fusion(node1, node2)
              if fusion_res.callable_fn is not None:
                  # Deferred decision: record it and move on.
                  pending_fusion = PendingFusion(
                      callable_fn=fusion_res.callable_fn,
                      node1=node1,
                      node2=node2,
                      future=fusion_res.future,
                  )

                  if is_template_fusion(node1, node2):
                      assert (node1, node2) not in self.seen_template_fusions
                      self.seen_template_fusions.add((node1, node2))
                      template_pw_node = template_fusion_pw_node(node1, node2)
                      if template_pw_node not in template_fusion_nodes:
                          template_fusion_nodes[template_pw_node] = []
                      template_fusion_nodes[template_pw_node].append(pending_fusion)
                  else:
                      pending_fusions[node1] = pending_fusion
                      pending_fusions[node2] = pending_fusion
                  continue

              if not fusion_res.should_fuse:
                  continue

              self.fuse_two_nodes(node1, node2, fused_nodes)
  def _finish_pending_fusions(
      self,
      fused_nodes: OrderedSet[BaseSchedulerNode],
      pending_fusions: dict[BaseSchedulerNode, PendingFusion],
  ):
      """
      Resolve the non-template fusions still waiting in `pending_fusions`.

      Entries are deduplicated by their speedup callable, so a fusion that
      is stored under more than one key is evaluated only once. Template
      fusions are skipped here.
      """
      # Resolve pending fusions for non templates in case of benchmark_kernel=True
      already_evaluated: OrderedSet[Callable[[], bool]] = OrderedSet()
      for pending in pending_fusions.values():
          first, second = pending.get_fusion_nodes()
          speedup_fn = pending.callable_fn

          if speedup_fn in already_evaluated:
              continue
          if is_template_fusion(first, second):
              continue
          already_evaluated.add(speedup_fn)

          # Both nodes must still be unfused at this point.
          assert self.get_fused_node(first) is first
          assert self.get_fused_node(second) is second

          self.fuse_if_speedup(first, second, speedup_fn, fused_nodes)
  def _handle_template_overlap(
      self,
      possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]],
      deferred_prologue_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]],
  ):
      """
      Move prologue fusions that overlap with an epilogue fusion out of
      `possible_fusions` and into `deferred_prologue_fusions`.

      Both lists are modified in place; nothing is returned. The relative
      order of the pairs that stay in `possible_fusions` is preserved.
      """
      # Potentially a prologue fusion might have the same template as an epilogue
      # the prologue fusion therefore has to be evaluated on the potential
      # fused template + epilogue
      epilogue_template_nodes = OrderedSet(
          [n1 for n1, n2 in possible_fusions if is_epilogue_fusion(n1, n2)]
      )

      # NOTE(review): the membership test below checks n1 of the prologue
      # pair, whereas the collected epilogue templates are the n1 side of
      # epilogue pairs - confirm n1 (and not n2) is the template node of a
      # prologue fusion.
      remaining_fusions = []
      for n1, n2 in possible_fusions:
          if is_prologue_fusion(n1, n2) and n1 in epilogue_template_nodes:
              deferred_prologue_fusions.append((n1, n2))
          else:
              remaining_fusions.append((n1, n2))

      # Slice-assign so the caller's list is actually filtered. Rebinding the
      # parameter name would only change the local variable and leave the
      # deferred pairs in the caller's list as well.
      possible_fusions[:] = remaining_fusions
  def fuse_nodes_once(
      self,
      nodes: list[BaseSchedulerNode],
      is_reorder_round: bool,
  ) -> list[BaseSchedulerNode]:
      """
      Run a single fusion round over `nodes` and return the new node list.

      This relies on two key functions to control the logic:
          - self.can_fuse(): checks if a fusion is legal
          - self.score_fusion(): assigns priority to a given fusion

      The result is ordered by min_order and then passed through
      topological_sort_schedule.
      """
      self.prune_redundant_deps(nodes)
      fused_nodes = OrderedSet(nodes)

      if fusion_log.isEnabledFor(logging.DEBUG):
          fusion_log.debug("fuse_nodes_once, candidates:")
          for candidate in fused_nodes:
              fusion_log.debug(" %s", candidate.debug_str_short())

      # These are potential fusions which we are async compiling,
      # and which we will benchmark profitability of.
      # Maps node -> (is_speedup_fn, LambdaFuture, node1, node2)
      # Only used in the case of benchmark_kernel=True
      pending_fusions: dict[BaseSchedulerNode, PendingFusion] = {}
      template_fusion_nodes: dict[BaseSchedulerNode, list[PendingFusion]] = {}
      deferred_prologue_fusions: list[
          tuple[BaseSchedulerNode, BaseSchedulerNode]
      ] = []

      possible_fusions = self.get_possible_fusions(nodes, is_reorder_round)

      # Overlap handling only applies when autotuning is on and both
      # prologue and epilogue fusion are enabled.
      autotune_enabled = config.max_autotune_gemm or config.max_autotune
      if autotune_enabled and config.prologue_fusion and config.epilogue_fusion:
          self._handle_template_overlap(possible_fusions, deferred_prologue_fusions)

      # First pass: the regular candidates.
      self._try_fusion_pairs(
          possible_fusions,
          pending_fusions,
          template_fusion_nodes,
          fused_nodes,
          is_reorder_round,
      )
      self._finish_pending_fusions(fused_nodes, pending_fusions)
      self._evaluate_pending_template_fusions(template_fusion_nodes, fused_nodes)
      template_fusion_nodes.clear()

      # Second pass: prologue fusions that had to wait for the first pass.
      if deferred_prologue_fusions:
          self._try_fusion_pairs(
              deferred_prologue_fusions,
              pending_fusions,
              template_fusion_nodes,
              fused_nodes,
              is_reorder_round,
          )
          self._evaluate_pending_template_fusions(template_fusion_nodes, fused_nodes)

      ordered = sorted(fused_nodes, key=operator.attrgetter("min_order"))
      return self.topological_sort_schedule(ordered)
  def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None:
      """
      Group parallel nodes of ``self.nodes`` into combo kernels.

      Each accepted group is replaced by one ForeachKernelSchedulerNode, and
      ``self.name_to_fused_node`` is pointed at that node for every member.
      ``self.nodes`` is then re-sorted and redundant deps are pruned.

      Args:
          num_ck_nodes: optional limit checked against the number of combo
              kernels created so far; None means no limit.
      """
      remaining = OrderedSet(self.nodes)
      num_nodes_orig = len(self.nodes)
      count = 0
      log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes)

      groups = ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self)
      for num, node_list in enumerate(groups):
          node_list = ForeachKernelSchedulerNode.combinable_nodes(node_list)
          if len(node_list) < 2:
              continue
          # NOTE(review): the strict `>` lets the loop run while
          # count == num_ck_nodes - confirm whether that is intended.
          if num_ck_nodes is not None and count > num_ck_nodes:
              break
          if not self.speedup_by_combo_kernel(node_list):
              log.debug("ComboKernels: Not speeding up %d-th group", num)
              continue

          count += 1
          autotune_enabled = config.combo_kernels_autotune > 0
          group_snode = ForeachKernelSchedulerNode(
              node_list[0].scheduler,
              node_list,
              use_custom_partition_algo=True,
              enable_autotune=autotune_enabled,
          )
          log.info(
              "ComboKernels: Combining %d nodes for %d-th group",
              len(node_list),
              num,
          )
          for member in node_list:
              remaining.remove(member)
          remaining.add(group_snode)
          self.name_to_fused_node.update(
              {member.get_name(): group_snode for member in group_snode.get_nodes()}
          )

      self.nodes = sorted(remaining, key=lambda x: x.min_order)
      self.nodes = self.topological_sort_schedule(self.nodes)
      log.info(
          "Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodes",
          count,
          num_nodes_orig,
          len(self.nodes),
      )
      self.prune_redundant_deps(self.nodes)
  def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None:
      """
      Call ``prune_redundant_deps`` on every node in ``nodes``, handing each
      one the scheduler's ``name_to_fused_node`` mapping.
      """
      for snode in nodes:
          snode.prune_redundant_deps(self.name_to_fused_node)
  def get_possible_fusions(
      self,
      nodes: list[BaseSchedulerNode],
      is_reorder_round: bool,
  ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
      """
      Collect the legal fusion pairs among ``nodes``, ordered by
      self.score_fusion_key() from highest to lowest.

      Candidate pairs come from nodes that use a common buffer name and, when
      ``config.aggressive_fusion`` is set, from nodes with the same ``group``.
      """
      possible_fusions = []
      seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]()

      def check_all_pairs(group: list[BaseSchedulerNode]) -> None:
          # Each node is only paired with a bounded window of its successors.
          for next_index, first in enumerate(group, start=1):
              window_end = (
                  next_index + config.max_fusion_buffer_group_pairwise_attempts
              )
              for second in group[next_index:window_end]:
                  pair = (first, second)
                  if pair in seen:
                      continue
                  seen.add(pair)

                  if self.can_fuse(first, second, is_reorder_round):
                      possible_fusions.append(pair)
                  elif (second.is_template() or second.is_foreach()) and self.can_fuse(
                      second, first, is_reorder_round
                  ):
                      # foreach fusions and epilogue fusions are order dependent
                      possible_fusions.append((second, first))

      # Bucket fusable nodes by every buffer name they use.
      buffer_names_grouping = collections.defaultdict(list)
      for snode in nodes:
          if not self.unfusable_node(snode):
              for buf_name in snode.used_buffer_names():
                  buffer_names_grouping[buf_name].append(snode)
      for bucket in buffer_names_grouping.values():
          check_all_pairs(bucket)

      if config.aggressive_fusion:
          # Additionally bucket by the (truthy) `group` attribute, when present.
          group_grouping = collections.defaultdict(list)
          for snode in nodes:
              node_group = getattr(snode, "group", None)
              if node_group:
                  group_grouping[node_group].append(snode)
          for bucket in group_grouping.values():
              check_all_pairs(bucket)

      possible_fusions = self.get_possible_fusions_with_highest_priority(
          possible_fusions
      )
      possible_fusions.sort(key=self.score_fusion_key, reverse=True)
      fusion_log.debug("found %d possible fusions", len(possible_fusions))
      return possible_fusions
  def will_fusion_create_cycle(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> bool:
      """
      Finds whether there's a path from node1 to node2 (or vice-versa)
      caused indirectly by other fusions.

      Returns:
          True if such a path is found, in which case "will create cycle" is
          also reported through WhyNoFuse; False otherwise.
      """
      # Fused nodes that found_path has already expanded, so that each one is
      # explored at most once.
      # NOTE(review): this was previously described as a "slightly faster,
      # unordered set", but the container constructed here is an OrderedSet.
      visited = OrderedSet[FusedSchedulerNode]()

      def found_path(node: BaseSchedulerNode) -> bool:
          # only fused nodes can introduce new ancestors.
          if isinstance(node, FusedSchedulerNode) and node not in visited:
              visited.add(node)
              if node.get_operation_names().issubset(combined_ancestors):
                  # All fusion outputs are in ancestors of node1 and node2, thus
                  # cannot introduce new path:
                  #
                  # 1. if output is neither descendent of node1 or node2, the
                  # output cannot introduce a path
                  # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be
                  # on path(node1->node2), hence it cannot be ancestor of node2
                  # 3. due to [acyclic]: if WLOG output is descendent of node1, it cannot be
                  # ancestor of node1
                  return False
              else:
                  # continue DFS of new ancestors introduced by the fusion:
                  # either this fused node already has node1/node2 operations
                  # among its ancestors, or one of its not-yet-covered ancestors
                  # leads to them.
                  return bool(combined_names & node.ancestors) or any(
                      found_path(self.name_to_fused_node[n])
                      for n in node.ancestors - combined_ancestors
                  )
          # Non-fused nodes and already-visited fused nodes add no new path.
          return False

      # Built by unioning the `_dict` key views directly (presumably plain dict
      # views, which would make the results plain unordered sets) - only
      # membership matters here since just a boolean is returned.
      combined_names = (
          node1.get_operation_names()._dict.keys()
          | node2.get_operation_names()._dict.keys()
      )
      # Ancestors of either node, excluding the operations of the two nodes themselves.
      combined_ancestors = (
          node1.ancestors._dict.keys() | node2.ancestors._dict.keys()
      ) - combined_names
      cycle = any(found_path(self.name_to_fused_node[n]) for n in combined_ancestors)
      if cycle:
          WhyNoFuse(node1, node2)("will create cycle")
      return cycle
  def can_fusion_increase_peak_memory(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> bool:
      """
      Heuristically decide whether fusing node1 and node2 may raise peak memory.

      This is only a heuristic: at this point we do not know whether we are at
      the memory peak, and node order may still change later.

      A LOWER BOUND on the extra allocation caused by fusing is estimated as:
        1. for each node, collect the buffers it reads that have a single user;
           those buffers could be reused if the two nodes stay separate
        2. intersect the reuse keys of both nodes and add up the sizes; this is
           the allocation that could at least be avoided by not fusing.

      The extra allocation does not necessarily increase peak memory.

      Returns:
          True only when the estimated overhead is statically known to exceed
          32x the memory-traffic saving of the fusion; False otherwise,
          including when a size in a reuse key is not an integer.
      """
      from .codegen.wrapper import buffer_reuse_key

      def _find_single_user_inputs(
          node: BaseSchedulerNode,
      ) -> list[ir.Buffer]:
          # IR nodes of the read buffers that have exactly one user and a tensor output.
          return [
              sched_buf.node
              for rd in node.read_writes.reads
              if (sched_buf := self.name_to_buf.get(rd.name))
              and len(sched_buf.users) == 1
              and sched_buf.node.has_tensor_output()
          ]

      # Check inputs that can be potentially reused
      lhs_inputs = _find_single_user_inputs(node1)
      rhs_inputs = _find_single_user_inputs(node2)
      lhs_keys = OrderedSet(buffer_reuse_key(buf) for buf in lhs_inputs)
      rhs_keys = OrderedSet(buffer_reuse_key(buf) for buf in rhs_inputs)
      shared_keys = lhs_keys.intersection(rhs_keys)

      try:
          memory_overhead = sum(int(key[2]) for key in shared_keys)
      except ValueError:
          # not an integer. Fallback is to fuse
          return False

      bw_saving = self.score_fusion_memory(node1, node2)
      # The factor 32 here is quite arbitrary.
      return bool(
          V.graph.sizevars.statically_known_gt(memory_overhead, 32 * bw_saving)
      )
  def fusion_prevent_too_many_reads_and_writes(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
  ) -> bool:
      """
      Return True when the fused node would touch more than ``threshold``
      distinct I/O buffers.

      Buffers that become internal to the fused kernel are not counted:
      node2 reads that node1 writes, and node1 writes that
      can_buffer_be_removed_through_fusion() reports as removable.
      """

      def dep_names(deps) -> OrderedSet[str]:
          return OrderedSet(dep.name for dep in deps)

      # Names of every node that will live inside the fused node.
      fused_node_names = OrderedSet(
          [snode.get_name() for snode in node1.get_nodes()]
          + [snode.get_name() for snode in node2.get_nodes()]
      )

      # node2 reads that are outputs of node1 disappear after fusion.
      internal_reads = dep_names(node2.read_writes.reads) & dep_names(
          node1.read_writes.writes
      )

      # node1 writes that are only needed inside the fused node disappear too.
      internal_writes: OrderedSet[str] = OrderedSet()
      for write_dep in node1.read_writes.writes:
          if self.can_buffer_be_removed_through_fusion(
              write_dep.name, fused_node_names
          ):
              internal_writes.add(write_dep.name)

      # Union of both nodes' reads / writes, minus what became internal.
      external_reads = (
          dep_names(node1.read_writes.reads) | dep_names(node2.read_writes.reads)
      ) - internal_reads
      external_writes = (
          dep_names(node1.read_writes.writes) | dep_names(node2.read_writes.writes)
      ) - internal_writes

      # A buffer that is both read and written counts once.
      return len(external_reads | external_writes) > threshold
  def are_long_distant_nodes(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> bool:
      """
      Report whether node1 and node2 are far apart in the original order.

      Fusing nodes that are far apart (more common in horizontal fusion) can
      lengthen the live intervals of tensors and thereby grow the memory
      footprint. This is very evident with activation checkpointing, where
      recomputed nodes from different checkpointed regions get fused.

      This is a quick, possibly hacky, heuristic. A better but harder approach
      would use the live intervals of buffers, find the region of peak pressure
      in the original program, and block fusions crossing that region; doing so
      needs care because each fusion changes live intervals, and recomputing
      them after every fusion can add large compilation overhead.

      Returns:
          True when the larger of |node1.min_order - node2.max_order| and
          |node2.min_order - node1.max_order| is greater than 64.
      """
      span_a = abs(node1.min_order - node2.max_order)
      span_b = abs(node2.min_order - node1.max_order)
      proximity_score = span_a if span_a >= span_b else span_b
      return proximity_score > 64
  def decide_fusion_fail_reason(
      self,
      node1: BaseSchedulerNode,
      node2: BaseSchedulerNode,
      common_buf_names: Union[tuple[str, ...], OrderedSet[str]],
  ) -> str:
      """
      Explain, per common buffer, why two nodes were found to share no memory
      even though they have buffers in common.

      Returns:
          ``str()`` of a dict mapping each buffer name to its diagnosed reason.
      """

      def diagnose(buf, lhs_dep, rhs_dep) -> str:
          # Rules are tried in order; the first one that applies wins.
          if not (isinstance(lhs_dep, MemoryDep) and isinstance(rhs_dep, MemoryDep)):
              return f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"

          if lhs_dep.get_numel() != rhs_dep.get_numel():
              return (
                  f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
              )

          # same numel but different MemoryDep.size. Should be broadcasting
          if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size):
              return "broadcast"

          lhs_off = lhs_dep.get_offset()
          rhs_off = rhs_dep.get_offset()
          if lhs_off != rhs_off:
              # One example is in transformer, we use a concatenated linear layer
              # to project Q/K/V and then split the result. The 3 splits will
              # point to the same buffer with different offsets.
              return f"different offset: {lhs_off} v.s. {rhs_off}"

          if (
              lhs_dep.normalize_with_stride_order()
              == rhs_dep.normalize_with_stride_order()
          ):
              return f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}"

          # Add more rules here
          layout_str = ""
          if not isinstance(buf, ir.TorchBindObject):
              layout_str = f"Layout: {buf.layout}"
          return f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"

      node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
      node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}

      reasons = {}
      for buf_name in common_buf_names:
          buf = V.graph.get_buffer(buf_name)
          lhs_dep = node1_name2dep[buf_name]
          rhs_dep = node2_name2dep[buf_name]
          reasons[buf_name] = diagnose(buf, lhs_dep, rhs_dep)
      return str(reasons)
  def shared_data_after_inverting_indexing(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> int:
      """
      Attempts to enable fusion between two nodes by inverting indexing patterns.

      This optimization targets cases where node1 has a contiguous write and
      node2 has a contiguous write but discontiguous read. By inverting the
      indexing in node2's read and write operations, we can make them compatible
      with node1 for potential fusion.

      On success this MUTATES node2: its body's two indexing expressions are
      rewritten and its dependencies are refreshed.

      Args:
          node1: First scheduler node (source)
          node2: Second scheduler node (target for inversion)

      Returns:
          int: Fusion score (from score_fusion_memory) if the inversion was
          applied, -1 if the optimization is not applicable.
      """
      # Feature gate, and CPU nodes are excluded.
      if not config.loop_index_inversion_in_fusion:
          return -1
      if any(n.is_cpu() for n in [node1, node2]):
          return -1
      # Check for shared buffers between nodes
      node1_buffer_names = node1.read_writes.buffer_names()
      node2_buffer_names = node2.read_writes.buffer_names()
      common_buffer_names = node1_buffer_names & node2_buffer_names
      if not common_buffer_names:
          return -1
      # only invert if node1 is single unmet dep
      node2_unmet_dependencies = OrderedSet(
          dep.name for dep in node2.unmet_dependencies
      )
      if node2_unmet_dependencies - node1_buffer_names:
          return -1
      if len(node2_unmet_dependencies) > 1:
          return -1
      # Currently only handle single read/write operations
      if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1:
          return -1
      node2_read = next(iter(node2.read_writes.reads))
      node2_write = next(iter(node2.read_writes.writes))
      if not isinstance(node2_read, MemoryDep) or not isinstance(
          node2_write, MemoryDep
      ):
          return -1
      # node2's single read must be a buffer that node1 writes.
      node1_writes = {dep.name: dep for dep in node1.read_writes.writes}
      if node2_read.name not in node1_writes:
          return -1
      node1_write = node1_writes[node2_read.name]
      if not isinstance(node1_write, MemoryDep):
          return -1
      # We are checking for compatibility with the normalized node1 write
      # then modifying node2 reads/writes. since the node1 write will be just used
      # for compatibility, while node2 will be used in actual modification, just
      # normalize node1 not node2.
      node1_write = node1_write.normalize()
      # NOTE(review): this bails out only when BOTH index and size differ;
      # confirm that `and` (rather than `or`) is the intended condition.
      if (
          node1_write.index != node2_write.index
          and node1_write.size != node2_write.size
      ):
          return -1
      if node2_read.size != node2_write.size or len(node2_read.var_names) != 1:
          return -1
      # Verify we have exactly two indexing expressions (one read, one write)
      if len(node2._body.indexing_exprs) != 2:  # type: ignore[attr-defined]
          return -1
      # No subblocks allowed for this optimization
      if node2._body.subblocks:  # type: ignore[attr-defined]
          return -1
      assert (
          "index0" in node2._body.indexing_exprs  # type: ignore[attr-defined]
          and "index1" in node2._body.indexing_exprs  # type: ignore[attr-defined]
      )
      # Extract and verify single read expression
      node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs())  # type: ignore[attr-defined]
      if len(node2_read_exprs) != 1:
          return -1
      read_expr = next(iter(node2_read_exprs))
      # Determine which index is for reading vs writing
      if read_expr == node2._body.indexing_exprs["index0"]:  # type: ignore[attr-defined]
          read_expr_index = "index0"
          write_expr_index = "index1"
      else:
          assert read_expr == node2._body.indexing_exprs["index1"]  # type: ignore[attr-defined]
          read_expr_index = "index1"
          write_expr_index = "index0"
      from torch._inductor.invert_expr_analysis import generate_inverse_formula

      # Only a single iteration variable is supported.
      index_vars = node2._body.vars[0]  # type: ignore[attr-defined]
      if len(index_vars) != 1:
          return -1
      # Simplify each additive term of the read expression before inverting it.
      simplified_terms = []
      for term in sympy.Add.make_args(read_expr):
          simplified_terms.append(
              V.graph.sizevars.combine_modular_indexing_pairs(term)
          )
      simplified_read_expr = sum(simplified_terms)
      inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0])
      # formula is not invertible
      if inverse_formula is None:
          return -1
      # === Apply Inversion ===
      # Swap the indexing expressions using the inverse formula:
      # the read slot takes over the old write expression, and the write slot
      # gets the inverse of the old read expression.
      node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[  # type: ignore[attr-defined]
          write_expr_index
      ]
      node2._body.indexing_exprs[write_expr_index] = inverse_formula  # type: ignore[attr-defined]
      # Refresh dependencies and calculate fusion score
      node2.refresh_dependencies(True, False)  # type: ignore[attr-defined]
      score = self.score_fusion_memory(node1, node2)
      assert isinstance(score, int)
      fusion_log.info("Shared memory after inversion: %d", score)
      return score
  def shared_data_after_reordering_loop(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> int:
      """
      Try to reorder the loops of one node so that it matches the other, and
      return the amount of shared data recomputed afterwards.

      The choice is greedy: node1 is reordered if it is not a reduction,
      otherwise node2 if it is not a reduction. Ideally a heuristic would pick
      whichever direction is more efficient.

      Returns:
          The recomputed shared-data amount, or -1 when no recomputation
          happened (0 is not used for that since 0 is a valid amount).
      """
      # TODO Don't do loop reordering for CPU for now.
      # Should debug more why it does not work for CPU codegen
      if not config.loop_ordering_after_fusion:
          return -1
      if any(n.is_cpu() for n in [node1, node2]):
          return -1

      # in some rare case, a template can be passed in.
      # Check test_interaction_with_multi_template in test_loop_ordering.py
      # and https://github.com/pytorch/pytorch/issues/165579
      if node1.is_template() or node2.is_template():
          return -1

      node1_buffer_names = node1.read_writes.buffer_names()
      node2_buffer_names = node2.read_writes.buffer_names()
      # Fast path: no common buffers.
      shared_names = node1_buffer_names & node2_buffer_names
      if not shared_names:
          return -1

      deps_of_node1 = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
      deps_of_node2 = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}

      # Common buffers whose deps agree once normalized with stride order,
      # each tagged with a size hint for its element count.
      candidates = []
      for name in shared_names:
          dep_a = deps_of_node1[name]
          dep_b = deps_of_node2[name]
          if dep_a.normalize_with_stride_order() != dep_b.normalize_with_stride_order():
              continue
          numel_hint = V.graph.sizevars.size_hint(dep_a.get_numel(), fallback=0)
          candidates.append((numel_hint, dep_a, dep_b))
      if not candidates:
          return -1

      # Pick the largest buffer to guide the loop reordering
      _numel, lhs_dep, rhs_dep = max(candidates, key=operator.itemgetter(0))

      if not (isinstance(lhs_dep, MemoryDep) and isinstance(rhs_dep, MemoryDep)):
          return -1

      if lhs_dep.num_vars != rhs_dep.num_vars:
          # this can happen due to we don't merge loops.
          # We can not do loop reordering in this case right now
          # Simply returning true if the two Deps are the same after
          # normalization (merging loops)
          if lhs_dep.normalize() == rhs_dep.normalize():
              return self.dep_size_hint(lhs_dep)
          return -1

      # Only reorder loops for pointwise for now
      if not node1.is_reduction():
          reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
      elif not node2.is_reduction():
          reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
      else:
          loop_ordering_log.debug(
              "Don't reorder loops since both nodes are reductions: %s v.s. %s",
              node1.get_name(),
              node2.get_name(),
          )
          return -1

      if not reordered:
          return -1
      return typing.cast(int, self.score_fusion_memory(node1, node2))
  def unfusable_node(self, node: BaseSchedulerNode) -> bool:
      """
      Tell whether ``node`` can never take part in a fusion.

      That is the case for extern-kernel and nop-kernel scheduler nodes, unless
      the node is a template or its IR node is an output of a multi-output
      template.
      """
      if not isinstance(node, (ExternKernelSchedulerNode, NopKernelSchedulerNode)):
          return False
      if node.is_template():
          return False
      return not is_output_of_multi_outputs_template(node.node)
  def check_prologue_fusion_heuristics_fusable(
      self,
      prologue_node: BaseSchedulerNode,
      template_node: BaseSchedulerNode,
      why: WhyNoFuse,
  ) -> bool:
      """
      Cheap filters that reject prologue fusions expected to be slow, so they
      are never benchmarked.

      Args:
          prologue_node: the candidate prologue.
          template_node: the template the prologue would be fused into.
          why: callback used to record the rejection reason.

      Returns:
          True if the fusion passes every heuristic, False otherwise.
      """
      # user opt into more aggressive prologue fusion, dont use heuristics
      if prologue_node.get_operation_names() <= V.graph.invoke_quant_ops:
          return True

      read_bytes = prologue_node.get_read_buffer_sizes()
      write_bytes = prologue_node.get_write_buffer_sizes()

      # Initially, only do fusions which will result in fewer memory accesses inside of the template to avoid
      # potential bad cache behavior and shared memory use.
      # we also want to avoid benchmarking reliably unprofitable fusions like downcasts from fp32 -> fp16 inside kernel.
      # allowing gathers by allowing increasing write_bytes by small factor
      # TODO - make configurable per input, for instance, bias can fuse fp32 -> fp16 profitably
      BYTES_THRESHOLD_MULTIPLIER = 1.1
      if read_bytes > (write_bytes * BYTES_THRESHOLD_MULTIPLIER):
          why("prologue fusion will not increase amount of bytes read in kernel")
          return False

      # we want to avoid attempting to fuse predictably unprofitable prologues
      # such as increasing the unaligned reads or writes.
      # TODO - would be nice to generalize this, however, we would need more explicit
      # knowledge of memory access patterns in the TritonTemplate in order to know
      # the stride order to check alignment.
      call_targets = []
      for snode in prologue_node.get_nodes():
          if snode.node is None:
              continue
          for origin in snode.node.get_origins():
              if origin.op == "call_function":
                  call_targets.append(origin.target)
      if tuple(call_targets) == (torch.ops.aten.constant_pad_nd.default,):
          why(
              "prologue fusion will not increase attempt to fuse in padding bc it increases unaligned reads"
          )
          return False

      # A floating-point dtype of at most 2 bytes counts as low precision.
      template_dtype = template_node.get_template_node_or_throw().dtype
      if (
          template_dtype.itemsize <= 2
          and template_dtype.is_floating_point
          and not prologue_node.can_codegen_in_low_precision()
      ):
          why(
              "prologue fusion that must be upcast to fp32 not profitable for low precision templates"
          )
          return False

      return True
  def get_expand_dim_for_pointwise_nodes(
      self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  ) -> Optional[tuple[int, SchedulerNode, sympy.Expr]]:
      """
      Fusing two small pointwise nodes significantly reduces kernel overhead
      and launch overhead. However, slightly different sizes would prevent fusion.
      Here, we decide if expanding sizes of one node is profitable by allowing
      fusion.

      Args:
          node1: first candidate node.
          node2: second candidate node.

      Returns:
          ``(dim, smaller_node, new_size)``: the dimension to expand, the node
          with the smaller size in that dimension, and the size it should be
          expanded to. None when expansion is not supported or not applicable.
      """
      # only support scheduler node
      if not isinstance(node1, SchedulerNode) or not isinstance(node2, SchedulerNode):
          return None

      # only support computed buffer
      if not (
          isinstance(node1.node, ir.ComputedBuffer)
          and isinstance(node2.node, ir.ComputedBuffer)
      ):
          return None

      # does not support mutation yet since relying on index mod to handle
      # out-of-boundary access.
      if node1.has_aliasing_or_mutation() or node2.has_aliasing_or_mutation():
          return None

      # skip halide which does not support mod for index
      if config.cpu_backend == "halide":
          return None

      # only support pointwise nodes with the same reduction size
      n1_sizes, n2_sizes = node1._sizes, node2._sizes
      n1_iter_sizes, n1_reduce_sizes = n1_sizes
      n2_iter_sizes, n2_reduce_sizes = n2_sizes
      if (
          node1.is_reduction()
          or node2.is_reduction()
          or n1_reduce_sizes != n2_reduce_sizes
          or len(n1_iter_sizes) != len(n2_iter_sizes)
      ):
          return None

      # only support nodes with 1 write for simplification
      if len(node1.read_writes.writes) > 1 or len(node2.read_writes.writes) > 1:
          return None

      # When memory access is small, reducing gpu kernel overhead is profitable over
      # slightly larger memory access.
      node1_write_memory = self.dep_size_hint(next(iter(node1.read_writes.writes)))
      # Measure node2's own write; previously node1's write was measured twice,
      # so a large node2 write was never compared against the threshold.
      node2_write_memory = self.dep_size_hint(next(iter(node2.read_writes.writes)))
      if (
          max(node1_write_memory, node2_write_memory)
          > config.small_memory_access_threshold
      ):
          return None

      # does not support reinplace since `index % boundary` may lead to
      # race condition
      def has_reusable_buffer(node: BaseSchedulerNode) -> bool:
          for read in node.read_writes.reads:
              input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
              if read.name in self.name_to_donated_buffer:
                  input_buf = self.name_to_donated_buffer[read.name]
              else:
                  input_buf = self.name_to_buf.get(read.name)
              if (
                  input_buf
                  and V.graph.wrapper_code.can_reuse(input_buf, node)
                  and not isinstance(input_buf.defining_op, NopKernelSchedulerNode)
              ):
                  return True
          return False

      if has_reusable_buffer(node1) or has_reusable_buffer(node2):
          return None

      # only support nodes with 1 mismatch dimension
      mismatch_dimensions = [
          idx
          for idx, (n1_size, n2_size) in enumerate(zip(n1_iter_sizes, n2_iter_sizes))
          if n1_size != n2_size
      ]
      if len(mismatch_dimensions) != 1:
          return None
      mismatch_dim = mismatch_dimensions[0]
      mismatch_size1, mismatch_size2 = (
          n1_iter_sizes[mismatch_dim],
          n2_iter_sizes[mismatch_dim],
      )

      # Expand whichever node is statically known to be smaller in that dimension.
      if V.graph.sizevars.statically_known_lt(mismatch_size1, mismatch_size2):
          return mismatch_dim, node1, mismatch_size2
      elif V.graph.sizevars.statically_known_lt(mismatch_size2, mismatch_size1):
          return mismatch_dim, node2, mismatch_size1
      else:
          return None
  4613. def can_fuse(
  4614. self,
  4615. node1: BaseSchedulerNode,
  4616. node2: BaseSchedulerNode,
  4617. can_reorder: bool = False,
  4618. allow_mix_order_reduction: bool = True,
  4619. ) -> bool:
  4620. """
  4621. Determine if it is possible to combine node1 and node2 into a
  4622. single fused node.
  4623. """
  4624. if node1 is node2:
  4625. return False
  4626. if isinstance(node1, FusedMixOrderReductions):
  4627. return node1.can_fuse_with(node2)
  4628. if isinstance(node2, FusedMixOrderReductions):
  4629. # We don't fuse something before a FusedMixOrderReductions
  4630. # right now
  4631. return False
  4632. why = WhyNoFuse(node1, node2)
  4633. if node1.is_template() and self.get_backend(
  4634. node1.get_device()
  4635. ).can_fuse_multi_outputs_template(node1, node2):
  4636. return True
  4637. if isinstance(node1, GroupedSchedulerNode) or isinstance(
  4638. node2, GroupedSchedulerNode
  4639. ):
  4640. why("grouped node must not be fused with other nodes")
  4641. return False
  4642. if (
  4643. isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
  4644. and not node1.is_template()
  4645. ):
  4646. why("node1 is extern or nop")
  4647. return False
  4648. if (
  4649. isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
  4650. and not node2.is_template()
  4651. ):
  4652. why("node2 is extern or nop")
  4653. return False
  4654. if node2.get_operation_names() & node1.ancestors:
  4655. why("node1 must go before node2")
  4656. return False
  4657. if node2.is_template():
  4658. if not config.prologue_fusion:
  4659. why("prologue fusion turned off")
  4660. return False
  4661. if node1.is_reduction() or node1.is_template():
  4662. why("prologue fusion only supported for pointwise nodes")
  4663. return False
  4664. template = node2.get_template_node_or_throw()
  4665. if not isinstance(template, ir.TritonTemplateBuffer):
  4666. why("prologue fusion only supported for TritonTemplates")
  4667. return False
  4668. allowed_prologue_inps = template.get_allowed_prologue_inps()
  4669. unsupported_prologue_args = (
  4670. OrderedSet(inp.get_name() for inp in template.inputs) # type: ignore[union-attr]
  4671. - allowed_prologue_inps
  4672. )
  4673. if node1.get_buffer_names() & unsupported_prologue_args:
  4674. why("prologue fusion not implemented for kernel for these inputs")
  4675. return False
  4676. if node1.has_aliasing_or_mutation() or node1.has_aliasing_or_mutation():
  4677. why("template prologue can only fuse functional pointwise nodes")
  4678. return False
  4679. prologue_nodes = node1.get_nodes()
  4680. for node in prologue_nodes[:-1]:
  4681. node_outs = node.get_outputs()
  4682. for out in node_outs:
  4683. if not all(user.node in prologue_nodes for user in out.users):
  4684. why("template prologue can only fuse nodes with a single use")
  4685. return False
  4686. template_snodes = (
  4687. [node2]
  4688. if not isinstance(node2, FusedSchedulerNode)
  4689. else [n for n in node2.snodes if n.is_template()]
  4690. )
  4691. assert len(template_snodes) == 1
  4692. template_snode = template_snodes[0]
  4693. if not (
  4694. len(prologue_nodes[-1].outputs) == 1
  4695. and len(prologue_nodes[-1].outputs[0].users) == 1
  4696. and prologue_nodes[-1].outputs[0].users[0].node is template_snode
  4697. ):
  4698. why(
  4699. "template prologue can only fuse nodes with a single use into template"
  4700. )
  4701. return False
  4702. if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why):
  4703. return False
  4704. if node1.is_template() and (
  4705. node2.has_aliasing_or_mutation()
  4706. or node2.is_reduction()
  4707. or not config.epilogue_fusion
  4708. ):
  4709. why("template epilogue not satisfied")
  4710. return False
  4711. if (node1.get_buffer_names() & V.graph.no_fuse_buffer_names) or (
  4712. node2.get_buffer_names() & V.graph.no_fuse_buffer_names
  4713. ):
  4714. why("fusion for buffer explicit disabled")
  4715. return False
  4716. device = node1.get_device()
  4717. device2 = node2.get_device()
  4718. if device != device2:
  4719. why("device mismatch (%s vs %s)", device, device2)
  4720. return False
  4721. del device2
  4722. shared_data_score = self.score_fusion_memory(
  4723. node1, node2, allow_mix_order_reduction=allow_mix_order_reduction
  4724. )
  4725. assert isinstance(shared_data_score, int)
  4726. if (
  4727. can_reorder
  4728. and shared_data_score < config.score_fusion_memory_threshold
  4729. and config.loop_ordering_after_fusion
  4730. ):
  4731. new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
  4732. if new_shared_data_score >= 0:
  4733. shared_data_score = new_shared_data_score
  4734. if config.expand_dimension_for_pointwise_nodes and (
  4735. expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
  4736. ):
  4737. (expand_dim, smaller_node, expand_size) = expand_analysis
  4738. smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
  4739. shared_data_score = self.score_fusion_memory(node1, node2)
  4740. assert isinstance(shared_data_score, int)
  4741. if (
  4742. config.loop_index_inversion_in_fusion
  4743. and shared_data_score < config.score_fusion_memory_threshold
  4744. ):
  4745. new_shared_data_score = self.shared_data_after_inverting_indexing(
  4746. node1, node2
  4747. )
  4748. if new_shared_data_score >= 0:
  4749. shared_data_score = new_shared_data_score
  4750. if loop_ordering_log.isEnabledFor(logging.DEBUG):
  4751. loop_ordering_log.debug(
  4752. "%s and %s has %s shared data",
  4753. node1.get_name(),
  4754. node2.get_name(),
  4755. shared_data_score,
  4756. )
  4757. if not V.choices.can_fuse(self, node1, node2, shared_data_score):
  4758. return False
  4759. if node1.get_operation_names() & node2.ancestors:
  4760. # node2 depends on node1 outputs
  4761. return (
  4762. self.can_fuse_vertical(node1, node2)
  4763. and V.choices.can_fuse_vertical(self, node1, node2, shared_data_score)
  4764. and self.get_backend(device).can_fuse_vertical(node1, node2)
  4765. )
  4766. else: # nodes don't depend on each other, but may have common reads
  4767. return V.choices.can_fuse_horizontal(
  4768. self, node1, node2, shared_data_score
  4769. ) and self.get_backend(device).can_fuse_horizontal(node1, node2)
  4770. def can_fuse_vertical(
  4771. self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  4772. ) -> bool:
  4773. """
  4774. Check if it is legal to fuse a consumer (node2) into a producer (node1).
  4775. We can fuse them if all the reads of node2 either match
  4776. corresponding writes in node1, or are written by nodes that can
  4777. be scheduled before the fusion of node1 and node2.
  4778. """
  4779. node1_buf_names = node1.get_buffer_names()
  4780. why = WhyNoFuse(node1, node2)
  4781. remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list)
  4782. for dep in node2.unmet_dependencies:
  4783. name = self.mutation_renames.get(dep.name, dep.name)
  4784. if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2):
  4785. continue
  4786. remaining_deps_by_name[name].append(dep)
  4787. for cd in node1.read_writes.writes:
  4788. if not isinstance(cd, MemoryDep):
  4789. continue
  4790. remaining = remaining_deps_by_name.get(
  4791. self.mutation_renames.get(cd.name, cd.name)
  4792. )
  4793. if remaining:
  4794. for rd in remaining:
  4795. if self.fusable_read_and_write(rd, cd):
  4796. remaining.remove(rd) # noqa: B909
  4797. remaining_deps = OrderedSet(
  4798. dep.name
  4799. for dep in itertools.chain.from_iterable(remaining_deps_by_name.values())
  4800. )
  4801. if remaining_deps & node1_buf_names:
  4802. # MemoryDeps didn't match and read different locations of the same buffer.
  4803. # Examples here include:
  4804. # - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
  4805. # - MemoryDep("foo", x) != StarDep("foo")
  4806. why("memory deps did not match")
  4807. return False
  4808. node1_op_names = node1.get_operation_names()
  4809. for name in remaining_deps:
  4810. op_name = self.name_to_buf[name].defining_op_name()
  4811. if node1_op_names & self.name_to_fused_node[op_name].ancestors:
  4812. why("intermediate nodes between node1 & node2")
  4813. return False
  4814. return True
  4815. def fusable_weak_dep(
  4816. self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode
  4817. ) -> bool:
  4818. if weak_dep.name not in node1.get_buffer_names():
  4819. return False
  4820. # A weak dep can be fused if and only if the fused operation acts inplace
  4821. # on the buffer being mutated. i.e. the same index is being read then mutated
  4822. mutating_writes = [
  4823. write
  4824. for write in node2.read_writes.writes
  4825. if write.name == weak_dep.mutating_buf
  4826. ]
  4827. if len(mutating_writes) != 1:
  4828. return False
  4829. write = mutating_writes[0]
  4830. if isinstance(write, StarDep):
  4831. return False
  4832. assert isinstance(write, MemoryDep)
  4833. if free_symbol_is_type(write.index, SymT.TMP):
  4834. return False
  4835. real_name = self.mutation_real_name[weak_dep.mutating_buf]
  4836. relevant_reading_nodes = [node1]
  4837. if isinstance(node1, ForeachKernelSchedulerNode):
  4838. relevant_reading_nodes = node1.snodes
  4839. num_concurrent_reads = 0
  4840. for reading_node in relevant_reading_nodes:
  4841. relevant_reads = [
  4842. read
  4843. for read in reading_node.read_writes.reads
  4844. if read.name == real_name
  4845. ]
  4846. if not relevant_reads:
  4847. continue
  4848. num_concurrent_reads += 1
  4849. if not all(
  4850. isinstance(read, MemoryDep)
  4851. and not free_symbol_is_type(read.index, SymT.TMP)
  4852. and read.index == write.index
  4853. and read.size == write.size
  4854. for read in relevant_reads
  4855. ):
  4856. return False
  4857. return num_concurrent_reads <= 1
  4858. # StarDep doesn't match MemoryDep, different indices don't match
  4859. # However, broadcasting sometimes strips dimensions, and if that's the case
  4860. # we still can match unmet dep
  4861. # if there's indirect indexing, don't match it
  4862. def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool:
  4863. if isinstance(read, MemoryDep):
  4864. read_name = self.mutation_renames.get(read.name, read.name)
  4865. if (
  4866. read_name != write.name
  4867. or free_symbol_is_type(read.index, SymT.TMP)
  4868. or free_symbol_is_type(write.index, SymT.TMP)
  4869. ):
  4870. return False
  4871. if config.loop_ordering_after_fusion and read.num_vars != write.num_vars:
  4872. # Need merge loops if we do loop ordering after fusion since
  4873. # we have not merged the loops yet when creating the scheduler
  4874. # nodes.
  4875. read = read.normalize()
  4876. write = write.normalize()
  4877. # Operations like index_add_, scatter_add_, etc. require global
  4878. # synchronization - all threads must complete writes before any reads.
  4879. # These cannot be safely fused into the same kernel. Atomic modes and TMA stores require synchronization barriers
  4880. if self.mode_requires_synchronization(write.mode):
  4881. return False
  4882. return (
  4883. read.index == write.index
  4884. and len(read.size) >= len(write.size)
  4885. and read.size[: len(write.size)] == write.size
  4886. )
  4887. elif isinstance(read, StarDep):
  4888. read_name = self.mutation_renames.get(read.name, read.name)
  4889. write_name = self.mutation_renames.get(write.name, write.name)
  4890. if (
  4891. read.mode == write.mode
  4892. and write.mode is not None
  4893. and read_name == write_name
  4894. ):
  4895. return True
  4896. return False
  4897. def dep_size_hint(self, dep: Dep, count_bytes: bool = True) -> int:
  4898. return V.graph.get_dep_size_hint(dep, count_bytes)
  4899. def score_fusion_memory(
  4900. self,
  4901. node1: BaseSchedulerNode,
  4902. node2: BaseSchedulerNode,
  4903. count_bytes: bool = True,
  4904. return_is_mix_order_reduction: bool = False,
  4905. allow_mix_order_reduction: bool = True,
  4906. ) -> int | tuple[int, bool]:
  4907. """
  4908. The first term in our fusion score that estimates number of saved
  4909. memory operations.
  4910. """
  4911. def _construct_return_value(score, is_mix_order_reduction):
  4912. return (
  4913. (score, is_mix_order_reduction)
  4914. if return_is_mix_order_reduction
  4915. else score
  4916. )
  4917. if allow_mix_order_reduction and MixOrderReduction.can_fuse(node1, node2):
  4918. # The fusion score for mix order reduction only count
  4919. # numel so far. It's actually fine. This makes other fusions
  4920. # sharing the same amount of numels go first; but make
  4921. # fusions only share weight/bias go later.
  4922. score = MixOrderReduction.get_fusion_score(node1, node2)
  4923. return _construct_return_value(score, True)
  4924. node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes)
  4925. node2_dep_len = len(node2.read_writes.reads) + len(node2.read_writes.writes)
  4926. # optimization: iter over smaller set
  4927. if min(node1_dep_len, node2_dep_len) * 4 < max(node1_dep_len, node2_dep_len):
  4928. if node1_dep_len > node2_dep_len:
  4929. node1, node2 = node2, node1
  4930. deps = [
  4931. dep
  4932. for dep in node1.read_writes.reads | node1.read_writes.writes
  4933. if dep in node2.read_writes.reads or dep in node2.read_writes.writes
  4934. ]
  4935. return _construct_return_value(
  4936. sum(self.dep_size_hint(dep, count_bytes) for dep in deps), False
  4937. )
  4938. common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
  4939. node2.read_writes.reads | node2.read_writes.writes
  4940. )
  4941. return _construct_return_value(
  4942. sum(self.dep_size_hint(dep) for dep in common_memory_deps), False
  4943. )
  4944. def get_possible_fusions_with_highest_priority(
  4945. self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
  4946. ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
  4947. # Group the possible fusions based on their priority from the backend.
  4948. # Only return the group of possible fusions with highest priority.
  4949. if len(possible_fusions) == 0:
  4950. return possible_fusions
  4951. possible_fusions_group_by_priority: dict[
  4952. int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
  4953. ] = {}
  4954. for node1, node2 in possible_fusions:
  4955. assert node1.get_device() == node2.get_device()
  4956. device = node1.get_device()
  4957. fusion_pair_priority = int(
  4958. self.get_backend(device).get_fusion_pair_priority(node1, node2)
  4959. )
  4960. if fusion_pair_priority not in possible_fusions_group_by_priority:
  4961. possible_fusions_group_by_priority[fusion_pair_priority] = [
  4962. (node1, node2),
  4963. ]
  4964. else:
  4965. possible_fusions_group_by_priority[fusion_pair_priority].append(
  4966. (node1, node2)
  4967. )
  4968. # return the possible fusions with highest priority
  4969. possible_fusions_with_highest_priority = min(
  4970. possible_fusions_group_by_priority.items(), key=operator.itemgetter(0)
  4971. )[1]
  4972. assert len(possible_fusions_with_highest_priority) > 0
  4973. return possible_fusions_with_highest_priority
  4974. def score_fusion_key(
  4975. self, nodes: tuple[BaseSchedulerNode, BaseSchedulerNode]
  4976. ) -> Any:
  4977. """
  4978. Shim for list.sort(key=...)
  4979. """
  4980. return V.choices.score_fusion(self, *nodes)
  4981. def compute_last_usage(self) -> None:
  4982. """
  4983. Populate node.last_usage recursively (also for the nodes within a FusedSchedulerNode)
  4984. """
  4985. future_used_buffers = OrderedSet(V.graph.get_output_names())
  4986. for node in reversed(self.nodes):
  4987. node.set_last_usage(future_used_buffers, self.mutation_real_name)
  4988. future_used_buffers.update(node.last_usage)
  4989. def free_buffers(self) -> None:
  4990. """Free any buffers that are no longer needed"""
  4991. for name in sorted(
  4992. self.buffer_names_to_free
  4993. - V.graph.removed_buffers
  4994. - V.graph.wrapper_code.freed # type: ignore[has-type]
  4995. ):
  4996. if name in self.name_to_buf:
  4997. buf = self.name_to_buf[name]
  4998. if buf.can_free():
  4999. V.graph.wrapper_code.codegen_free(buf.node)
  5000. elif name in V.graph.graph_inputs:
  5001. inp = V.graph.graph_inputs[name]
  5002. if isinstance(inp, ir.TorchBindObject):
  5003. V.graph.wrapper_code.codegen_free(inp)
  5004. elif isinstance(inp, ir.GeneratorState):
  5005. continue
  5006. else:
  5007. storage = inp.data
  5008. assert (
  5009. isinstance(storage, ir.StorageBox) and storage.is_input_buffer()
  5010. )
  5011. V.graph.wrapper_code.codegen_free(storage.data)
  5012. self.buffer_names_to_free.clear()
  5013. def flush(self) -> None:
  5014. for backend in self.backends.values():
  5015. backend.flush()
  5016. self.free_buffers()
  5017. def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None:
  5018. assert isinstance(scheduler_node, ExternKernelSchedulerNode)
  5019. # 'decide_inplace_update' stores the inplace update decisions in
  5020. # the current kernel from where 'allocate' retrieve those decisions.
  5021. # We have to make sure there is a non-NULL kernel handler to store
  5022. # those inplace update decisions.
  5023. counters["inductor"]["extern_calls"] += 1
  5024. with V.set_kernel_handler(Kernel(increase_kernel_count=False)):
  5025. scheduler_node.decide_inplace_update()
  5026. scheduler_node.mark_run()
  5027. node = scheduler_node.node
  5028. assert isinstance(node, ir.ExternKernel), f"{type(node)=}"
  5029. node.codegen(V.graph.wrapper_code)
  5030. self.free_buffers()
  5031. def create_backend(self, device: torch.device) -> BaseScheduling:
  5032. assert not is_gpu(device.type) or device.index is not None, (
  5033. f"{device} should have been normalized in lowering"
  5034. )
  5035. V.graph.add_device_info(device)
  5036. device_scheduling = get_scheduling_for_device(device.type)
  5037. if device_scheduling is None:
  5038. raise RuntimeError(f"Unsupported device type: {device.type}")
  5039. if not has_triton():
  5040. if (
  5041. device.type == "cuda"
  5042. and (device_props := torch.cuda.get_device_properties(device)).major < 7
  5043. ):
  5044. raise GPUTooOldForTriton(device_props, inspect.currentframe())
  5045. elif is_gpu(device.type) and not device.type == "mps":
  5046. raise TritonMissing(inspect.currentframe())
  5047. return device_scheduling(self)
  5048. def get_backend(self, device: Optional[torch.device]) -> BaseScheduling:
  5049. assert device is not None
  5050. if device not in self.backends:
  5051. self.backends[device] = self.create_backend(device)
  5052. return self.backends[device]
  5053. def enter_context(self, node: BaseSchedulerNode) -> None:
  5054. def get_order(n: torch.fx.Node) -> int:
  5055. if n not in self.origin_to_index:
  5056. self.origin_to_index.update({n: i for i, n in enumerate(n.graph.nodes)})
  5057. return self.origin_to_index[n]
  5058. # Use a dict to have ordering
  5059. origins = {
  5060. (get_order(e), e): None
  5061. for n in node.get_nodes()
  5062. if n.node is not None
  5063. for e in n.node.get_origins()
  5064. }
  5065. origins = list(origins.keys())
  5066. if origins:
  5067. _, last = max(origins, key=operator.itemgetter(0))
  5068. V.graph.wrapper_code.enter_context(last)
  5069. def can_buffer_be_removed_through_fusion(
  5070. self, name: str, fused_node_names: OrderedSet[str]
  5071. ) -> bool:
  5072. try:
  5073. users = self.name_to_buf[name].users
  5074. except KeyError:
  5075. return False
  5076. return (
  5077. all(user.is_weak or user.get_name() in fused_node_names for user in users)
  5078. and name not in self.mutation_renames
  5079. and name not in self.mutation_real_name
  5080. )
  5081. def should_partition(self, node: BaseSchedulerNode) -> Optional[str]:
  5082. """
  5083. Return the reason why we should partition the inductor graph on this node,
  5084. or None if the node is cudagraphable.
  5085. """
  5086. # Allow users to manually specify if a node should be partitioned
  5087. # Can only do this for FallbackKernels
  5088. ir_node = node.node
  5089. if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
  5090. op := ir_node.op_overload
  5091. ):
  5092. op_overload_packet_name, op_overload_name = get_op_names(op)
  5093. if (
  5094. op_overload_packet_name in config.custom_should_partition_ops
  5095. or op_overload_name in config.custom_should_partition_ops
  5096. ):
  5097. assert isinstance(op, torch._ops.OpOverload)
  5098. return f"custom partition op: {op_overload_name}"
  5099. # When not using cudagraphs, keep all kernels in the `call` function
  5100. # instead of graph partition functions, since graph partition only brings
  5101. # benefit to cudagraph
  5102. if (
  5103. not torch._inductor.config.triton.cudagraphs
  5104. and _unstable_customized_partition_wrapper.wrapper is None
  5105. ):
  5106. return "partition includes all ops when cudagraphs is disabled"
  5107. if isinstance(node, FusedSchedulerNode):
  5108. for snode in node.snodes:
  5109. reason = self.should_partition(snode)
  5110. if reason:
  5111. return reason
  5112. return None
  5113. assert node.node is not None
  5114. if not node.is_gpu():
  5115. return f"{node.get_device()} ops"
  5116. if isinstance(node.node, ir.DeviceCopy):
  5117. return "DeviceCopy ops"
  5118. if isinstance(node.node, ir.Conditional):
  5119. return "Conditional ops"
  5120. if getattr(node.node, "unbacked_bindings", None):
  5121. return "unbacked binding ops"
  5122. if is_cudagraph_unsafe_op(node.node):
  5123. return "CUDAGraph-unsafe custom ops"
  5124. if reason := self._uses_cudagraph_unsafe_unbacked_symint(node):
  5125. return reason
  5126. # Partition around nodes with dynamic shapes when cudagraph_skip_dynamic_graphs is enabled
  5127. if config.triton.cudagraph_skip_dynamic_graphs:
  5128. if get_scheduler_node_symbol_uses(node):
  5129. return "dynamic shape ops"
  5130. return None
  5131. @cache_on_self
  5132. def _get_cudagraph_unsafe_unbacked_symints(self) -> OrderedSet[sympy.Symbol]:
  5133. """
  5134. Collect output unbacked symints from ops in config.cudagraph_unsafe_unbacked_ops.
  5135. """
  5136. unsafe_symints: OrderedSet[sympy.Symbol] = OrderedSet()
  5137. if not config.cudagraph_unsafe_unbacked_ops:
  5138. return unsafe_symints
  5139. for node in self.nodes:
  5140. ir_node = node.node
  5141. if ir_node is None:
  5142. continue
  5143. if not isinstance(ir_node, torch._inductor.ir.FallbackKernel):
  5144. continue
  5145. op = ir_node.op_overload
  5146. if op is None:
  5147. continue
  5148. op_overload_packet_name, op_overload_name = get_op_names(op)
  5149. if (
  5150. op_overload_packet_name not in config.cudagraph_unsafe_unbacked_ops
  5151. and op_overload_name not in config.cudagraph_unsafe_unbacked_ops
  5152. ):
  5153. continue
  5154. for sym in ir_node.get_unbacked_symbol_defs():
  5155. sym = V.graph.sizevars.simplify(sym)
  5156. if symbol_is_type(sym, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)):
  5157. unsafe_symints.add(sym)
  5158. return unsafe_symints
  5159. def _uses_cudagraph_unsafe_unbacked_symint(
  5160. self, node: BaseSchedulerNode
  5161. ) -> Optional[str]:
  5162. unsafe_symints = self._get_cudagraph_unsafe_unbacked_symints()
  5163. if not unsafe_symints:
  5164. return None
  5165. node_symbols = get_scheduler_node_symbol_uses(node)
  5166. for sym in node_symbols:
  5167. simplified_sym = V.graph.sizevars.simplify(sym)
  5168. for free_sym in simplified_sym.free_symbols:
  5169. if free_sym in unsafe_symints:
  5170. return f"uses cudagraph-unsafe unbacked symint: {free_sym}"
  5171. return None
  5172. def get_name_to_nodes(
  5173. self,
  5174. ) -> dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]]:
  5175. """
  5176. Return a mapping from name strings to the corresponding graph inputs or
  5177. base scheduler node outputs.
  5178. """
  5179. name_to_node: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]] = {}
  5180. name_to_node.update(V.graph.graph_inputs)
  5181. for node in self.nodes:
  5182. for name, scheduler_buffer in node.outputs_by_name.items():
  5183. name_to_node[name] = scheduler_buffer.node
  5184. return name_to_node
  5185. def compute_graph_partition_maps(
  5186. self,
  5187. signatures: list[GraphPartitionSignature],
  5188. ) -> None:
  5189. """
  5190. computes a mapping from partition input/output indices to graph input/output
  5191. indices for each partition.
  5192. """
  5193. name_to_graph_input_index = {
  5194. name: idx for idx, name in enumerate(V.graph.graph_inputs)
  5195. }
  5196. name_to_graph_output_index = {
  5197. name: idx for idx, name in enumerate(V.graph.get_output_names())
  5198. }
  5199. V.graph.partition_maps = []
  5200. for partition_id, signature in enumerate(signatures):
  5201. if signature.skip_cudagraph:
  5202. # Note: [Graph Partition Map for CUDAGraph]
  5203. # number of partition map should be the same as the number of generated
  5204. # partition functions. This assumption will be used when cudagraphify
  5205. # each partition function.
  5206. continue
  5207. input_mapping = []
  5208. for name in signature.input_nodes:
  5209. input_mapping.append(name_to_graph_input_index.get(name))
  5210. output_mapping = []
  5211. for node in signature.output_nodes:
  5212. output_mapping.append(name_to_graph_output_index.get(node.get_name()))
  5213. V.graph.partition_maps.append(
  5214. GraphPartitionMap(
  5215. partition_id,
  5216. input_mapping,
  5217. output_mapping,
  5218. signature.constant_names,
  5219. )
  5220. )
  5221. def get_graph_partition_symbol_inputs(
  5222. self,
  5223. partition: PartitionType,
  5224. input_nodes: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]],
  5225. ) -> OrderedSet[sympy.Symbol]:
  5226. """
  5227. Returns all symbol inputs which are required to be in scope to successfully
  5228. perform codegen for this graph partition, including:
  5229. - free symbols used in partition nodes
  5230. - free symbols in partition input/node shapes, strides, and offsets. This is needed
  5231. for recording cudagraphs for tensors with dynamic shapes.
  5232. """
  5233. def get_input_node_symbols(
  5234. node: Union[ir.IRNode, sympy.Expr, ir.TorchBindObject],
  5235. ) -> OrderedSet[sympy.Symbol]:
  5236. """
  5237. Gets symbols used in input node shapes, strides, and offsets.
  5238. """
  5239. if isinstance(node, ir.TorchBindObject):
  5240. # TorchBindObject does not involve dynamic shapes yet
  5241. return OrderedSet()
  5242. elif isinstance(node, ir.IRNode):
  5243. return get_layout_symints(node)
  5244. else:
  5245. # node cannot be sympy.Expr since node comes from read_writes and
  5246. # read_writes does not contain sympy.Expr
  5247. raise NotImplementedError(f"Unsupported input node type: {type(node)}")
  5248. def filter_symbols(
  5249. symbols: OrderedSet[sympy.Symbol],
  5250. ) -> OrderedSet[sympy.Symbol]:
  5251. """
  5252. Filters a set of symbols that are required for codegen. Skip symbols
  5253. that are always internal to kernels, such as SymT.TMP, SymT.INDEX,
  5254. and SymT.R0_INDEX.
  5255. """
  5256. return OrderedSet(
  5257. s
  5258. for s in symbols
  5259. if symbol_is_type(
  5260. s,
  5261. (
  5262. SymT.SIZE,
  5263. SymT.FLOAT,
  5264. SymT.UNBACKED_INT,
  5265. SymT.UNBACKED_FLOAT,
  5266. ),
  5267. )
  5268. )
  5269. candidate_symbols: OrderedSet[sympy.Symbol] = OrderedSet().union(
  5270. *(get_scheduler_node_symbol_uses(node) for node in partition)
  5271. )
  5272. candidate_symbols.update(
  5273. *(get_input_node_symbols(node) for _, node in input_nodes.items())
  5274. )
  5275. candidate_symbols = filter_symbols(candidate_symbols)
  5276. res: OrderedSet[sympy.Symbol] = OrderedSet()
  5277. for s in candidate_symbols:
  5278. symplified_s = V.graph.sizevars.simplify(s)
  5279. # use free_symbols only when s is simplified to an Integer or expr
  5280. res.update(symplified_s.free_symbols)
  5281. return OrderedSet(sorted(res, key=operator.attrgetter("name")))
def get_graph_partition_signature(
    self, partitions: list[PartitionType], skip_cudagraphs: list[bool]
) -> list[GraphPartitionSignature]:
    """
    Gets signature for each graph partition, including input nodes, output nodes, and
    whether deallocating an input within graph partition.

    Partitions are processed from last to first so that the names a later
    partition needs as inputs are known when deciding what an earlier
    partition must return. The returned list is in the original partition
    order.
    """
    signatures = []

    # Names that some later consumer still needs; starts as the graph outputs.
    unmet_output_names = OrderedSet(V.graph.get_output_names())
    name_to_node = self.get_name_to_nodes()

    def is_unallocated_buffer(buf_name: str) -> bool:
        """
        Checks if buf_name resolves to a NoneLayout buffer (following mutation_real_name).
        Buffers with NoneLayout are not allocated so graph partition should not
        take them as inputs or outputs.
        """
        buf = self.name_to_buf.get(buf_name, None)

        # Names without a scheduler buffer are treated as allocated.
        if buf is None:
            return False

        if isinstance(buf.node.layout, NoneLayout):
            # If there's a mutation real name, check the underlying buffer
            # This handles both MutationOutput and other mutation ops like
            # IndexPutFallback that have NoneLayout but mutate real buffers
            if real_name := self.mutation_real_name.get(buf_name, None):
                return is_unallocated_buffer(real_name)

            return True

        return False

    for partition, skip_cudagraph in zip(
        reversed(partitions), reversed(skip_cudagraphs)
    ):
        # Every buffer produced by a node of this partition.
        output_names: OrderedSet[str] = OrderedSet()

        for node in partition:
            output_names.update(node.outputs_by_name.keys())

        # Produced buffers that something downstream still needs.
        returned_output_names = output_names.intersection(unmet_output_names)

        # all reads/writes are partition inputs except those generated
        # within the partition and tensor constants
        read_writes = dependencies.ReadWrites.merge_list(
            [node.read_writes for node in partition]
        )

        # WeakDep is fake dependency on unused buffer. It should not appear
        # in partition_input_names for inputs that are actually read or written.
        partition_input_names = (
            OrderedSet(
                [
                    x.name
                    for x in read_writes.reads | read_writes.writes
                    if not isinstance(x, WeakDep)
                ]
            )
            - output_names
        )

        # Replace each name by its mutation_real_name entry when one exists.
        partition_input_names = OrderedSet(
            self.mutation_real_name.get(name, name)
            for name in partition_input_names
        )

        # Buffers whose last use is inside this partition.
        buffer_names_to_free: OrderedSet[str] = OrderedSet()
        for node in partition:
            buffer_names_to_free.update(node.last_usage)

        # buffer_names_to_free may contain buffers allocated in previous
        # graph partitions. These buffers should also be a partition
        # input.
        extra_input_names = [
            name
            for name in (buffer_names_to_free - output_names)
            if name in name_to_node
        ]
        partition_input_names.update(extra_input_names)

        # Only names known to name_to_node become actual input nodes.
        input_nodes = {
            name: name_to_node[name]
            for name in partition_input_names
            if name in name_to_node
        }
        # True for inputs whose last use is in this partition.
        input_deallocation = {
            name: name in buffer_names_to_free
            for name in partition_input_names
            if name in name_to_node
        }

        # if an input tensor is not freed in the partition function, it should
        # also be returned as an output. This brings benefits to cudagraph
        # since the returned output tensor is a cudagraph managed tensor with
        # a static tensor address.
        extra_output_names = [
            name
            for name in partition_input_names
            if name in name_to_node and name not in buffer_names_to_free
        ]

        returned_output_names.update(extra_output_names)

        # Replace each name by its mutation_real_name entry when one exists.
        returned_output_names = OrderedSet(
            self.mutation_real_name.get(name, name)
            for name in returned_output_names
        )

        # Unallocated (NoneLayout) buffers are never returned.
        output_nodes = [
            name_to_node[name]
            for name in returned_output_names
            if not is_unallocated_buffer(name)
        ]

        constant_names = [
            name for name in partition_input_names if name in V.graph.constants
        ]

        symbol_inputs = self.get_graph_partition_symbol_inputs(
            partition, input_nodes
        )

        partition_signature = GraphPartitionSignature(
            symbol_inputs,
            input_nodes,
            output_nodes,
            input_deallocation,
            skip_cudagraph,
            constant_names,
        )

        signatures.append(partition_signature)

        # For the preceding partition: this partition's inputs are now
        # needed, and whatever this partition returned is no longer unmet.
        unmet_output_names = partition_input_names.union(
            unmet_output_names - returned_output_names
        )

    # Signatures were built last-to-first; restore partition order.
    return signatures[::-1]
  def clean_removed_buffer_from_partition_signatures(
      self, signature: GraphPartitionSignature
  ) -> GraphPartitionSignature:
      """
      Build a new partition signature that omits every buffer listed in
      V.graph.removed_buffers. See [Note: Removed Graph Partition Arguments]

      Input nodes, input deallocation entries and constant names are filtered
      by name; output nodes are filtered by ``maybe_get_name()``. The symbol
      inputs and the skip_cudagraph flag are passed through untouched.
      """
      removed = V.graph.removed_buffers

      kept_inputs = {}
      for name, buffer in signature.input_nodes.items():
          if name not in removed:
              kept_inputs[name] = buffer

      kept_deallocation = {}
      for name, val in signature.input_deallocation.items():
          if name not in removed:
              kept_deallocation[name] = val

      kept_outputs = []
      for node in signature.output_nodes:
          if node.maybe_get_name() not in removed:
              kept_outputs.append(node)

      kept_constants = []
      for name in signature.constant_names:
          if name not in removed:
              kept_constants.append(name)

      return GraphPartitionSignature(
          signature.symbol_inputs,
          kept_inputs,
          kept_outputs,
          kept_deallocation,
          signature.skip_cudagraph,
          kept_constants,
      )
  def reorder_for_minimizing_partition(
      self,
      nodes: list[BaseSchedulerNode],
  ) -> list[BaseSchedulerNode]:
      """
      Reorder nodes to minimize the number of partitions via a bfs
      topological sort. This is the optimal reordering such that the
      number of partitions cannot be reduced further. This may be
      sub-optimal for other metrics such as peak memory. This does not
      change relative orders of two cudagraphable nodes, nor the
      relative order of two non_cudagraphable nodes.

      Args:
          nodes: nodes to reorder. Dependencies are read from
              ``node.mpi_node.pred_nodes`` / ``node.mpi_node.succ_nodes``.

      Returns:
          A new list holding the same nodes in the reordered schedule.

      Raises:
          RuntimeError: if some nodes could not be scheduled, i.e. their
              in-degree never dropped to zero (for example because of a
              dependency cycle).
      """
      import heapq

      node_to_indegree: dict[BaseSchedulerNode, int] = {}
      cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = []
      non_cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = []
      node_to_index = {node: idx for idx, node in enumerate(nodes)}

      def insert_pending_nodes(node: BaseSchedulerNode) -> None:
          # Heap entries are keyed by the original position so that nodes of
          # the same kind are popped in their original relative order.
          node_with_index = (node_to_index[node], node)
          if self.should_partition(node):
              heapq.heappush(non_cudagraphable_nodes, node_with_index)
          else:
              heapq.heappush(cudagraphable_nodes, node_with_index)

      def update_indegree(node: BaseSchedulerNode) -> None:
          # A scheduled node releases its successors; those that become
          # ready are queued.
          for succ_node in node.mpi_node.succ_nodes:
              assert node_to_indegree[succ_node] > 0
              node_to_indegree[succ_node] -= 1
              if node_to_indegree[succ_node] == 0:
                  insert_pending_nodes(succ_node)

      for node in nodes:
          node_to_indegree[node] = len(node.mpi_node.pred_nodes)
          if node_to_indegree[node] == 0:
              insert_pending_nodes(node)

      schedule: list[BaseSchedulerNode] = []
      num_iters: int = 0
      # Every round schedules at least one node, so len(nodes) rounds are
      # always enough; the bound only protects against runaway loops.
      while num_iters < len(nodes) and (
          non_cudagraphable_nodes or cudagraphable_nodes
      ):
          # Drain one kind completely before switching to the other so that
          # consecutive nodes of the same kind end up in the same partition.
          while non_cudagraphable_nodes:
              _, node = heapq.heappop(non_cudagraphable_nodes)
              schedule.append(node)
              update_indegree(node)

          while cudagraphable_nodes:
              _, node = heapq.heappop(cudagraphable_nodes)
              schedule.append(node)
              update_indegree(node)

          num_iters += 1

      # The loop above can never run more than len(nodes) rounds, so a
      # failure shows up as missing nodes rather than as an iteration count.
      if len(schedule) != len(nodes):
          raise RuntimeError(
              "Failed to schedule all nodes when reordering for minimizing "
              f"the num of partitions: scheduled {len(schedule)} of "
              f"{len(nodes)} nodes"
          )

      return schedule
  def maybe_reorder_for_minimizing_partition(
      self,
      nodes: list[BaseSchedulerNode],
  ) -> list[BaseSchedulerNode]:
      """
      Try the partition-minimizing order and keep it only when its estimated
      peak memory stays within the budget relative to the given order;
      otherwise return ``nodes`` unchanged.
      """
      from .memory import estimate_peak_memory, prepare_planning_info

      output_names = OrderedSet(V.graph.get_output_names())
      baseline_peak, freeable_input_bufs = prepare_planning_info(
          nodes,
          self.name_to_buf,
          self.name_to_fused_node,
          OrderedSet(V.graph.graph_inputs.keys()),
          output_names,
      )

      candidate = self.reorder_for_minimizing_partition(nodes)
      candidate_peak, _ = estimate_peak_memory(
          candidate, freeable_input_bufs, output_names
      )

      # 1.1 here means 10% extra peak memory budget which is quite arbitrary
      within_budget = candidate_peak < baseline_peak * 1.1
      return candidate if within_budget else nodes
  def reorder_for_partition_with_simple_dependency(
      self, nodes: list[BaseSchedulerNode]
  ) -> list[BaseSchedulerNode]:
      """
      Move partitioned nodes with trivial dependencies out of the way:
      1. a partitioned node without unmet dependencies goes to the front
      2. a partitioned node whose outputs are only used by OutputNode goes
         to the back
      3. every other node keeps its place in the middle
      The relative order inside each of the three groups is preserved.
      """

      def only_output_user(node: BaseSchedulerNode) -> bool:
          # True when no output buffer has a user other than an OutputNode
          # (vacuously true when there are no outputs or no users).
          return all(
              isinstance(use.node, OutputNode)
              for buf in node.get_outputs()
              for use in buf.users
          )

      front: list[BaseSchedulerNode] = []
      middle: list[BaseSchedulerNode] = []
      back: list[BaseSchedulerNode] = []

      for node in nodes:
          if self.should_partition(node) is None:
              bucket = middle
          elif len(node.unmet_dependencies) == 0:
              bucket = front
          elif only_output_user(node):
              bucket = back
          else:
              bucket = middle
          bucket.append(node)

      return front + middle + back
  def graph_partition(
      self,
  ) -> tuple[list[PartitionType], list[GraphPartitionSignature]]:
      """
      Split self.nodes into graph partitions and compute the input/output
      signature of every partition.

      A partition is a maximal run of consecutive nodes that agree on whether
      ``should_partition`` reports a reason for them.
      """
      from itertools import groupby

      partitions: list[PartitionType] = []
      skip_cudagraphs: list[bool] = []

      # groupby yields one group per run of consecutive equal keys, which is
      # exactly the partition boundary rule.
      runs = groupby(
          self.nodes, key=lambda node: self.should_partition(node) is not None
      )
      for skip_cudagraph, run in runs:
          partitions.append(list(run))
          skip_cudagraphs.append(skip_cudagraph)

      signatures = self.get_graph_partition_signature(
          partitions=partitions, skip_cudagraphs=skip_cudagraphs
      )
      self.compute_graph_partition_maps(signatures)
      self._log_graph_partitions(partitions, signatures)

      return partitions, signatures
  def _log_graph_partitions(
      self,
      partitions: list[PartitionType],
      signatures: list[GraphPartitionSignature],
  ) -> None:
      """
      Emit a debug-level summary of the partitions on cudagraphs_log: overall
      counts, one line per partition, and per-node details for partitions
      that skip cudagraph. Does nothing unless debug logging is enabled.
      """
      if not cudagraphs_log.isEnabledFor(logging.DEBUG):
          return

      # Don't log partition reasons for CPU-only graphs since cudagraph
      # partitioning is not relevant when there are no GPU devices
      if not any(is_gpu(device) for device in V.graph.device_types):
          return

      num_cudagraphable = len([s for s in signatures if not s.skip_cudagraph])
      num_non_cudagraphable = len(signatures) - num_cudagraphable
      cudagraphs_log.debug(
          "Created %d graph partitions: %d cudagraphable, %d non-cudagraphable",
          len(partitions),
          num_cudagraphable,
          num_non_cudagraphable,
      )

      for idx, (partition, signature) in enumerate(zip(partitions, signatures)):
          if signature.skip_cudagraph:
              kind = "non-cudagraphable"
          else:
              kind = "cudagraphable"
          cudagraphs_log.debug(
              " Partition %d: %d nodes, %s, inputs=%d, outputs=%d",
              idx,
              len(partition),
              kind,
              len(signature.input_nodes),
              len(signature.output_nodes),
          )

          if not signature.skip_cudagraph:
              continue
          # Log details for each non-cudagraphable node
          for node in partition:
              self._log_non_cudagraphable_node(node)
  def _log_non_cudagraphable_node(self, node: BaseSchedulerNode) -> None:
      """
      Log why ``node`` is not cudagraphable: the reason reported by
      should_partition, the IR node type, the originating fx node (when
      there is one) and that fx node's stack trace (when recorded).
      Nodes for which should_partition reports no reason are skipped.
      """
      reason = self.should_partition(node)
      if not reason:
          return

      node_name = node.get_name()
      ir_node = node.node
      fx_node = None
      if ir_node is not None:
          fx_node = ir_node.get_origin_node()

      details = [f"reason={reason}", f"ir={type(ir_node).__name__}"]
      if fx_node is not None:
          rendered_args = ", ".join(str(a) for a in fx_node.args)
          details.append(f"fx={fx_node.target}({rendered_args})")

      cudagraphs_log.debug(" %s: %s", node_name, ", ".join(details))

      # Log full stack trace if available
      if fx_node is None:
          return
      stack_trace = fx_node.meta.get("stack_trace", None)
      if not stack_trace:
          return
      for line in stack_trace.strip().split("\n"):
          cudagraphs_log.debug(" %s", line)
  def codegen(self) -> None:
      """
      Entry point for code generation, timed as "Scheduler.codegen".
      Generates per-partition code when the graph_partition config is set,
      otherwise generates code for all nodes at once.
      """
      with dynamo_timed("Scheduler.codegen"):
          if torch._inductor.config.graph_partition:
              return self._codegen_partitions()
          return self._codegen(self.nodes)
  def _codegen_partition_wrapper(
      self,
      partition: PartitionType,
      signature: GraphPartitionSignature,
  ) -> None:
      """
      Generate the code of one partition as a subgraph function named
      ``partition_<id>`` and emit the call to it.
      """
      from .codegen.wrapper import SubgraphPythonWrapperCodegen

      outer_wrapper_code = V.graph.wrapper_code
      partition_id = next(self._graph_partition_counter)

      with V.graph.set_current_wrapper_code():
          V.graph.init_wrapper_code(
              is_subgraph=True,
              subgraph_name=f"partition_{partition_id}",
              parent_wrapper_code=outer_wrapper_code,
              partition_signatures=signature,
          )
          self._codegen(partition)

          # Note: [Removed Graph Partition Arguments]
          # Partition inputs and outputs are derived from node.read_writes.
          # Codegen may later decide that a buffer lives only inside a kernel
          # (e.g., a triton kernel) and is never defined; such buffers are
          # recorded in V.graph.removed_buffers. The signature is therefore
          # cleaned up, and the prefix (call function and returned outputs)
          # written, only after the partition itself has been generated.
          assert isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
          signature = self.clean_removed_buffer_from_partition_signatures(signature)
          V.graph.wrapper_code.partition_signatures = signature
          V.graph.wrapper_code.write_prefix()

          graph_name = V.graph.name
          partition_code, _ = V.graph.wrapper_code.generate(V.graph.is_inference)

      V.graph.wrapper_code.define_subgraph_launcher_fn(graph_name, partition_code)
      V.graph.wrapper_code.codegen_partition_call(partition_id, signature)

      output_names = [node.get_name() for node in signature.output_nodes]
      V.graph.wrapper_code.allocated.update(  # type: ignore[has-type]
          output_names
      )
  def use_default_device_context(
      self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature]
  ) -> contextlib.AbstractContextManager[None]:
      """
      Return a context manager that selects the default device for the given
      partitions on entry, emits a device guard enter/exit pair around the
      body when that device needs one, and clears
      self.default_device_context on exit.
      """

      def guard_required() -> bool:
          # Evaluated separately on entry and on exit, against the current
          # value of self.default_device_context.
          device = self.default_device_context
          return bool(device and device_need_guard(device.type))

      @contextlib.contextmanager
      def ctx() -> Iterator[None]:
          self.update_graph_partition_default_device(partitions, signatures)

          if guard_required():
              assert self.default_device_context.index is not None, (
                  "device should have an index"
              )
              V.graph.wrapper_code.codegen_device_guard_enter(
                  self.default_device_context.index
              )

          try:
              yield
          finally:
              if guard_required():
                  V.graph.wrapper_code.codegen_device_guard_exit()
              self.default_device_context = None

      return ctx()
  def update_graph_partition_default_device(
      self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature]
  ) -> None:
      """
      Set self.default_device_context to the device of the first cudagraph
      partition when every node of every non-cudagraph partition is on that
      same device; otherwise leave it untouched.
      """
      # Note: [Graph Partition Device Contexts]
      # Entering a device context takes 60 microseconds and exiting a device
      # context takes 20 microseconds. If all graph partitions and
      # cudagraph-unsafe ops happen on the same device, we can share the
      # device context.

      if len(partitions) == 1 and not signatures[0].skip_cudagraph:
          # If there is only 1 cudagraph partition, the device context
          # should happen within the cudagraph partition, which
          # would be removed by cudagraph.
          return

      # The target is the device of the first node of the first partition
      # that does not skip cudagraph.
      target_device = None
      for partition, signature in zip(partitions, signatures):
          if signature.skip_cudagraph:
              continue
          target_device = partition[0].get_device()
          assert target_device is not None
          break

      # all partitions skip cudagraph
      if target_device is None:
          return

      for partition, signature in zip(partitions, signatures):
          if not signature.skip_cudagraph:
              continue
          if any(node.get_device() != target_device for node in partition):
              return

      self.default_device_context = target_device
  def _codegen_partitions(self) -> None:
      """
      Partition the nodes and generate code partition by partition.
      Partitions that skip cudagraph are generated inline; every other
      partition is generated as its own function, which allows applying
      different optimizations (e.g., cudagraph) per function.
      """
      partitions, signatures = self.graph_partition()

      num_created = len(partitions)
      if num_created > 1:
          counters["inductor"]["cudagraph_partitions"] += num_created

      with self.use_default_device_context(partitions, signatures):
          for partition, signature in zip(partitions, signatures):
              assert len(partition) >= 1, (
                  f"Each partition must have at least one node but found {len(partition)}"
              )

              if not signature.skip_cudagraph:
                  self._codegen_partition_wrapper(partition, signature)
                  continue
              self._codegen(partition)

      # The counter was advanced once per wrapped partition, so its next
      # value is the number of partition functions that were generated.
      num_partitions = next(self._graph_partition_counter)
      V.graph.wrapper_code.set_all_partition_names(num_partitions)

      # See [Note: Graph Partition Map for CUDAGraph]
      if num_partitions > 0:
          assert V.graph.partition_maps is not None
          assert num_partitions == len(V.graph.partition_maps), (
              f"Expect {num_partitions} partition maps but got {len(V.graph.partition_maps)}"
          )
  def _codegen(self, nodes: list[BaseSchedulerNode]) -> None:
      """
      Generate code for ``nodes`` in the given order.

      For every node this switches the device guard in the wrapper code when
      the device changes, dispatches to the template / extern / foreach /
      mix-order-reduction / regular code path, and flushes the backend when
      required. On return the outermost device guard opened here is closed,
      self.previous_node is reset and a final flush is issued.
      """
      if config.check_stack_no_cycles_TESTING_ONLY:
          import torch._dynamo.convert_frame

          # Walk the call stack from the innermost frame outwards and make
          # sure no (file, line) pair occurs twice before the compile entry.
          seen: OrderedSet[tuple[str, int | None]] = OrderedSet()
          for frame in reversed(traceback.extract_stack()):
              # This is where maybe_cprofile is
              if (
                  frame.name == "_compile_inner"
                  and frame.filename == torch._dynamo.convert_frame.__file__
              ):
                  break
              key = (frame.filename, frame.lineno)
              assert key not in seen, (
                  f"Duplicate stack frame {frame.filename}:{frame.lineno}; "
                  "did you add a decorator to one of the functions in this stack "
                  "trace? If so, try using a context manager instead."
              )
              seen.add(key)

      self.current_device = self.default_device_context

      assert self.previous_node is None

      # pyrefly: ignore [unbound-name]
      if self.default_device_context and config.triton.autotune_at_compile_time:
          V.graph.wrapper_code.write_get_raw_stream_header()

      for node in nodes:
          if log.isEnabledFor(logging.DEBUG):
              try:
                  log.debug(
                      "Generating code for node %s with estimated runtime %f",
                      node.get_name(),
                      node.get_estimated_runtime(),
                  )
              except Exception:
                  # Fall back to a message without the runtime estimate.
                  log.debug(
                      "Generating code for node %s with estimated runtime 0.0",
                      node.get_name(),
                  )

          self.enter_context(node)

          device = node.get_device()
          if device:
              if (
                  device != self.current_device
                  or node.is_extern()
                  or node.is_template()
              ):
                  self.flush()

              if device != self.current_device:
                  # Close the guard of the device we are leaving, then open
                  # one for the device we are entering.
                  if self.current_device and device_need_guard(
                      self.current_device.type
                  ):
                      V.graph.wrapper_code.codegen_device_guard_exit()
                  self.current_device = device
                  if device_need_guard(device.type):
                      assert device.index is not None, "device should have an index"
                      V.graph.wrapper_code.codegen_device_guard_enter(device.index)

          self.current_node = node
          self.buffer_names_to_free.update(node.last_usage)

          if node.is_template():
              prologue, template_node, epilogue = node.get_prologue_template_epilogue(
                  list(node.get_nodes())
              )
              # pyrefly: ignore [unbound-name]
              self.get_backend(device).codegen_template(
                  template_node, epilogue, prologue
              )
          elif node.is_extern():
              node = typing.cast(ExternKernelSchedulerNode, node)
              self.codegen_extern_call(node)
          elif node.is_foreach():
              node = typing.cast(ForeachKernelSchedulerNode, node)
              # pyrefly: ignore [unbound-name]
              backend_ = self.get_backend(device)
              from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
              from .codegen.simd import SIMDScheduling

              # Only these two backend kinds are accepted for foreach nodes.
              if not isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
                  raise AssertionError(f"{type(self)=}")
              backend_.codegen_combo_kernel(node)
          elif isinstance(node, FusedMixOrderReductions):
              # pyrefly: ignore [unbound-name]
              self.get_backend(device).codegen_mix_order_reduction(node)
          elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
              # pyrefly: ignore [unbound-name]
              self.get_backend(device).codegen_node(node)
          else:
              assert isinstance(node, NopKernelSchedulerNode)
              node.mark_run()

          # pyrefly: ignore [unbound-name]
          if config.triton.debug_sync_kernel:
              # pyrefly: ignore [unbound-name]
              self.get_backend(device).codegen_sync()

          self.available_buffer_names.update(node.get_buffer_names())
          self.completed_operations.update(node.get_operation_names())

          if not isinstance(node, NopKernelSchedulerNode):
              device = node.get_device()
              if (
                  device is not None
                  and device.type != "meta"
                  and self.get_backend(device).ready_to_flush()
              ):
                  self.flush()

          # Remember the node only when it consists purely of SchedulerNodes.
          only_scheduler_nodes = all(
              isinstance(n, SchedulerNode) for n in node.get_nodes()
          )
          self.previous_node = node if only_scheduler_nodes else None

      if self.current_device != self.default_device_context:
          # when default_device_context is not None, we are codegen
          # for graph partitions and all nodes must be on
          # the same default device.
          assert self.current_device is not None
          if device_need_guard(self.current_device.type):
              # exit the outermost CUDA device guard. this is
              # important for nested indentation codegen-ing.
              V.graph.wrapper_code.codegen_device_guard_exit()

      self.previous_node = None
      self.flush()
  def benchmark_combo_kernel(
      self, node_list: Sequence[BaseSchedulerNode], node_benchmark_results
  ) -> tuple[float, float, list[Optional[str]]]:
      """
      Delegate the combo-kernel benchmark of ``node_list`` to the backend of
      the first node's device and return the backend's result unchanged.
      Registers this scheduler on V.graph and makes that device current
      before delegating.
      """
      target_device = node_list[0].get_device()
      V.graph.scheduler = self
      self.current_device = target_device
      assert target_device is not None
      return self.get_backend(target_device).benchmark_combo_kernel(
          node_list, node_benchmark_results
      )
  def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool:
      """
      Decide whether combining ``nodes`` into one combo kernel is worthwhile.

      If config.benchmark_combo_kernel is False, always return True.
      Otherwise every sub-kernel and the combo kernel are benchmarked and
      True is returned when the combo kernel (with its clone time discounted)
      is faster than the sum of the sub-kernels, or when the kernels are too
      small to benchmark reliably.

      Returns:
          False as soon as a sub-kernel benchmark reports an infinite time
          (logged as register spilling). True when a triton CompilationError
          mentioning "Loop-carried variable" is hit (fusion is allowed).

      Raises:
          AssertionError: if the nodes are not all on the same device.
          CompilationError: any other triton compilation error is re-raised.
      """
      subkernel_nodes = nodes
      device = subkernel_nodes[0].get_device()

      assert all(node.get_device() == device for node in subkernel_nodes), (
          "All nodes in a combo kernel group must be on the same device"
      )

      if not config.benchmark_combo_kernel:
          return True

      from triton.compiler.errors import CompilationError

      ms1 = 0.0
      node_benchmark_results = {}
      for i, snode in enumerate(subkernel_nodes):
          node_list = snode.get_nodes()
          # We can not accurately benchmark kernel using atomic_add
          # due to how we generate random integer inputs.
          if self._any_atomic_add(node_list):
              fusion_log.debug(
                  "ComboKernel: benchmarking may not accurate due to atomic_add"
              )

          try:
              ms, path = self.benchmark_fused_nodes(node_list)
              node_benchmark_results[snode] = (ms, path)
              if math.isinf(ms):
                  fusion_log.debug(
                      "ComboKernel benchmark: register spilling of %d-th subkernel",
                      i,
                  )
                  return False
          except CompilationError as e:
              # workaround triton issue: https://github.com/triton-lang/triton/issues/2151
              if "Loop-carried variable" in str(e):
                  fusion_log.debug(
                      "ComboKernel benchmark: return True because of loop-carried variable"
                  )
                  return True  # allow fusion
              else:
                  raise
          ms1 += ms

      try:
          ms2, ms2_clone, _path2_list = self.benchmark_combo_kernel(
              subkernel_nodes, node_benchmark_results
          )
      except CompilationError as e:
          # workaround triton issue: https://github.com/triton-lang/triton/issues/2151
          if "Loop-carried variable" in str(e):
              fusion_log.debug(
                  "ComboKernel benchmark: return True because of loop-carried variable"
              )
              return True  # allow fusion
          else:
              raise

      # ms1 returned by benchmark_fused_nodes discounted clone time, so the
      # combo kernel time is compared with its clone time discounted as well.
      combo_ms = ms2 - ms2_clone
      # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking.
      small_kernel = combo_ms < 0.3 or ms1 < 0.3
      can_fuse = combo_ms < ms1 or small_kernel

      if fusion_log.isEnabledFor(logging.DEBUG):
          # Guard the ratio so that enabling debug logging can never raise.
          ratio = ms1 / ms2 if ms2 else float("inf")
          # The message follows the actual decision so that the log never
          # contradicts the returned value.
          if can_fuse:
              fusion_log.debug(
                  "can fuse (benchmark): fusing causes %sx speedup",
                  green_text(f"{ratio:.3f}"),
              )
          else:
              fusion_log.debug(
                  "cannot fuse (benchmark): fusing causes %sx slowdown",
                  red_text(f"{ratio:.3f}"),
              )

      return can_fuse
  def get_buffer_layout(self, buf_name: str) -> ir.Layout:
      """
      Return the layout of the IR node behind the scheduler buffer named
      ``buf_name``. Raises KeyError for an unknown name and asserts that the
      buffer has an IR node.
      """
      ir_node = self.name_to_buf[buf_name].node
      assert ir_node is not None
      return ir_node.get_layout()
  def update_zero_dim_cpu_tensor(self) -> None:
      """
      Record in V.graph.zero_dim_cpu_tensor_list the name of every buffer
      read by a GPU node that is a CPU buffer with an empty size list and a
      layout other than NoneLayout / MultiOutputLayout.
      """
      for node in self.nodes:
          if not node.is_gpu():
              continue
          for read in node.read_writes.reads:
              buffer = V.graph.name_to_buffer.get(read.name)
              if not buffer:
                  continue
              if get_device_type(buffer) != "cpu":
                  continue
              if isinstance(buffer.layout, (NoneLayout, MultiOutputLayout)):
                  continue
              if buffer.get_size() == []:
                  V.graph.zero_dim_cpu_tensor_list.add(read.name)
  class BaseScheduling:  # noqa: docstring_linter
      def __init__(self, scheduler: Optional[Scheduler]):
          super().__init__()
          self.scheduler = scheduler

      def free_buffers_in_scheduler(self) -> None:
          """Ask the attached scheduler, if any, to free its buffers."""
          scheduler = self.scheduler
          if scheduler:
              scheduler.free_buffers()

      def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
          """Return a set of .codegen.common.BackendFeature()"""
          return OrderedSet()

      def can_fuse_vertical(
          self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
      ) -> bool:
          """
          Tell whether node1 and node2 may be fused vertically.
          Must be provided by the concrete backend.
          """
          raise NotImplementedError

      def can_fuse_horizontal(
          self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
      ) -> bool:
          """
          Tell whether node1 and node2 may be fused horizontally.
          Must be provided by the concrete backend.
          """
          raise NotImplementedError

      def can_fuse_multi_outputs_template(
          self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
      ) -> bool:
          """
          A Multi-Output Template (referenced in #144012) is a template node
          with MultiOutputLayout whose output buffers are instances of
          MultiOutput. This hook answers whether node1 is such a template,
          node2 is one of its outputs, and the backend supports fusing them.
          The base implementation never allows it.
          """
          return False

      def fuse(
          self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
      ) -> FusedSchedulerNode:
          """
          Fuse two nodes, picking the fused node type from the kind of the
          inputs (foreach, mix-order reduction, or the generic fused node).
          """
          if node1.is_foreach() or node2.is_foreach():
              return ForeachKernelSchedulerNode.fuse(node1, node2)
          if MixOrderReduction.are_mix_order_reductions(node1, node2):
              return FusedMixOrderReductions(node1, node2)
          if isinstance(node1, FusedMixOrderReductions):
              return node1.fuse_with(node2)
          return FusedSchedulerNode.fuse(node1, node2)

      def group_fn(
          self, sizes: Sequence[Sequence[sympy.Expr]]
      ) -> tuple[tuple[sympy.Expr, ...], ...]:
          """
          Process the iteration sizes in case a transformation needs to be applied.
          Must be provided by the concrete backend.
          """
          raise NotImplementedError

      def codegen_template(
          self,
          template_node: BaseSchedulerNode,
          epilogue_nodes: Sequence[BaseSchedulerNode],
          prologue_nodes: Sequence[BaseSchedulerNode],
      ) -> Optional[str]:
          """
          Given a template node, generate a kernel.
          This function is only available for triton now. A third-party
          backend that behaves as a sub-class of TritonScheduling can
          override it or reuse it.
          """
          raise NotImplementedError

      def generate_kernel_code_from_nodes(
          self,
          nodes: Sequence[BaseSchedulerNode],
          benchmark_kernel: bool,
          hint_override: Optional[int] = None,
      ) -> str:
          """
          Generate a kernel given a list of pre-fused nodes.
          Must be provided by the concrete backend.
          """
          raise NotImplementedError

      def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
          """
          Generate a kernel for a single (possibly fused) scheduler node.
          Must be provided by the concrete backend.
          """
          raise NotImplementedError

      def codegen_mix_order_reduction(self, node: FusedMixOrderReductions) -> None:
          """
          Generate code for a FusedMixOrderReductions node.
          Must be provided by the concrete backend.
          """
          raise NotImplementedError

      def codegen_sync(self) -> None:
          """
          Generate synchronization code for the kernel. This method depends
          on the hardware characteristics.
          """
          raise NotImplementedError

      def ready_to_flush(self) -> bool:
          """
          Check whether the backend is requesting the scheduler to flush the
          generated kernel. Backends that do not support this keep the
          default answer, False.
          """
          return False

      def flush(self) -> None:
          """
          Flush the generated kernel and python wrapper code to the source code file.
          """
          raise NotImplementedError

      def benchmark_fused_nodes(
          self, nodes: Sequence[BaseSchedulerNode]
      ) -> tuple[float, str]:
          """
          Benchmark a fused list of nodes on randomly generated inputs and
          return the execution time in milliseconds.
          """
          raise NotImplementedError

      def benchmark_codegened_module(self, module: ModuleType) -> tuple[float, str]:
          """
          Benchmark a compiled module on randomly generated inputs and
          return the execution time in milliseconds.
          """
          raise NotImplementedError

      def get_fusion_pair_priority(
          self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
      ) -> int:
          """
          Return an unsigned integer which represents the priority of this
          fusion pair. The smaller is with higher priority.
          """
          return 0

      def benchmark_combo_kernel(
          self, node_list: Sequence[BaseSchedulerNode], node_benchmark_results
      ) -> tuple[float, float, list[Optional[str]]]:
          """
          Benchmark the list of nodes to combine on randomly generated inputs
          and return the execution time and memory copy time in milliseconds.
          """
          raise NotImplementedError

      def codegen_comment(
          self,
          node_schedule: Sequence[BaseSchedulerNode],
          kernel_name: Optional[str] = None,
      ) -> None:
          """
          When a kernel name is given, register provenance tracing for it and
          write the resulting debug handle to the wrapper code; otherwise do
          nothing.
          """
          if not kernel_name:
              return

          from torch._inductor.debug import set_kernel_post_grad_provenance_tracing

          debug_handle = set_kernel_post_grad_provenance_tracing(
              node_schedule,  # type: ignore[arg-type]
              kernel_name,
          )
          V.graph.wrapper_code.write_provenance_debug_handle(
              kernel_name, debug_handle
          )