| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
27762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766
777677867796780678167826783678467856786678767886789679067916792679367946795679667976798679968006801680268036804680568066807680868096810681168126813681468156816681768186819682068216822682368246825682668276828682968306831683268336834683568366837683868396840684168426843684468456846684768486849685068516852685368546855685668576858685968606861686268636864686568666867686868696870687168726873687468756876687768786879688068816882688368846885688668876888688968906891689268936894689568966897689868996900690169026903690469056906690769086909691069116912691369146915691669176918691969206921692269236924692569266927692869296930693169326933693469356936693769386939694069416942694369446945694669476948694969506951695269536954695569566957695869596960696169626963696469656966696769686969697069716972697369746975697669776978697969806981698269836984698569866987698869896990699169926993699469956996699769986999700070017002700370047005700670077008700970107011701270137014701570167017701870197020702170227023702470257026702770287029703070317032703370347035703670377038703970407041704270437044704570467047704870497050705170527053705470557056705770587059706070617062706370647065706670677068706970707071707270737074707570767077707870797080708170827083708470857086708770887089709070917092709370947095709670977098709971007101710271037104710571067107710871097110711171127113711471157116711771187119712071217122712371247125712671277128712971307131713271337134713571367137713871397140714171427143714471457146714771487149715071517152715371547155715671577158 |
- from __future__ import annotations
- import collections
- import contextlib
- import dataclasses
- import functools
- import inspect
- import itertools
- import logging
- import math
- import operator
- import os
- import pprint
- import textwrap
- import traceback
- import typing
- from collections import Counter, defaultdict
- from concurrent.futures import as_completed, Future
- from typing import Any, Generic, Optional, TYPE_CHECKING, TypeAlias, TypeVar, Union
- from typing_extensions import ParamSpec
- from torch.utils._ordered_set import OrderedSet
- from .ir import ComputedBuffer
- if TYPE_CHECKING:
- from collections.abc import Callable, Iterator, Sequence
- from types import ModuleType
- import sympy
- import torch
- import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
- import torch.utils._pytree as pytree
- from torch._dynamo.utils import counters, dynamo_timed
- from torch._inductor.autotune_process import use_pipelined_autotuning
- from torch._inductor.codecache import LambdaFuture, PyCodeCache
- from torch._inductor.ir import TritonTemplateCallerBase
- from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
- from torch.fx.experimental.symbolic_shapes import free_symbols
- from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
- from torch.utils._triton import has_triton
- from . import comms, config, config_comms, dependencies, ir, metrics
- from .analyze_preserves_zero_mask import can_codegen_without_upcasts
- from .codegen.common import BackendFeature, get_scheduling_for_device, Kernel
- from .comm_analysis import (
- estimate_nccl_collective_runtime,
- estimate_nccl_collective_runtime_nccl_estimator,
- )
- from .dependencies import Dep, MemoryDep, StarDep, WeakDep
- from .exc import GPUTooOldForTriton, TritonMissing
- from .fx_utils import count_flops_fx
- from .ir import (
- assign_origin_node,
- get_device_type,
- GraphPartitionSignature,
- MultiOutput,
- MultiOutputLayout,
- NoneLayout,
- )
- from .loop_body import LoopBody
- from .memory import MemoryPlanningInfoForBuffer, MemoryPlanningInfoForNode
- from .runtime.hints import ReductionHint
- from .runtime.runtime_utils import green_text, red_text
- from .sizevars import SimplifyIndexing
- from .utils import (
- _unstable_customized_partition_wrapper,
- cache_on_self,
- cmp,
- device_need_guard,
- get_current_backend,
- get_device_tflops,
- get_dtype_size,
- get_gpu_dram_gbps,
- get_op_names,
- GraphPartitionMap,
- IndentedBuffer,
- is_collective,
- is_cudagraph_unsafe_op,
- is_gpu,
- is_multi_outputs_template,
- is_output_of_multi_outputs_template,
- is_wait,
- sympy_product,
- )
- from .virtualized import V
# Module-level logger, named after this module.
- log = logging.getLogger(__name__)
# Artifact loggers obtained from torch._logging, one per artifact name:
# "fusion", "loop_ordering", "compute_dependencies" and "cudagraphs".
- fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
- loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
- compute_dependencies_log = torch._logging.getArtifactLogger(
- __name__, "compute_dependencies"
- )
- cudagraphs_log = torch._logging.getArtifactLogger(__name__, "cudagraphs")
# Alias for a list of scheduler nodes; "BaseSchedulerNode" is a forward
# reference, so it is spelled as a string.
- PartitionType: TypeAlias = list["BaseSchedulerNode"]
# Generic type variable and parameter specification placeholders.
- _T = TypeVar("_T")
- _P = ParamSpec("_P")
@dataclasses.dataclass
class FusionResult:
    """
    Result of a fusion query: an immediate decision or a deferred one.

    Exactly one of ``should_fuse`` and ``callable_fn`` must be set; this is
    checked in ``__post_init__``. Prefer the ``fuse`` and ``from_callable``
    constructors, which make the two variants explicit.
    """

    # Immediate fusion decision; None when the decision is deferred.
    should_fuse: Optional[bool] = None
    # Zero-argument callable that returns the decision when invoked; None when
    # the decision is already known.
    callable_fn: Optional[Callable[[], bool]] = None
    # Optional future carried alongside ``callable_fn``.
    # NOTE(review): presumably the async work the callable waits on - confirm
    # against callers.
    future: Optional[LambdaFuture] = None

    def __post_init__(self) -> None:
        # XOR: reject both "decision and callable" and "neither of them".
        assert (self.should_fuse is not None) ^ (self.callable_fn is not None), (
            "Fusion result should contain exactly one of fusion decision or "
            "callable_fn (got both or neither)"
        )

    @classmethod
    def fuse(cls, should_fuse: bool) -> FusionResult:
        """Build a result that carries an immediate fusion decision."""
        # Use ``cls`` so subclasses get an instance of their own type.
        return cls(should_fuse=should_fuse)

    @classmethod
    def from_callable(
        cls, callable_fn: Callable[[], bool], future: Optional[LambdaFuture] = None
    ) -> FusionResult:
        """Build a deferred result whose decision is produced by ``callable_fn``."""
        return cls(callable_fn=callable_fn, future=future)
@dataclasses.dataclass
class PendingFusion:
    """
    A fusion of two nodes whose decision is still outstanding.

    Holds the zero-argument callable that yields the decision, the two
    candidate nodes, and an optional future.
    """

    # Callable returning the fusion decision when invoked.
    callable_fn: Callable[[], bool]
    # The pair of nodes being considered for fusion.
    node1: BaseSchedulerNode
    node2: BaseSchedulerNode
    # Optional future stored with the pending decision.
    future: Optional[LambdaFuture] = None

    def get_fusion_nodes(self) -> tuple[BaseSchedulerNode, BaseSchedulerNode]:
        """Return the candidate nodes as a ``(node1, node2)`` tuple."""
        first, second = self.node1, self.node2
        return first, second
- class MixOrderReduction:
- """
- This class contains utility functions to decide if we should fuse reductions
- reducing across different dimensions of the same input tensor.
- """
@staticmethod
def is_split_reduction(node: BaseSchedulerNode) -> bool:
    """
    Tell whether ``node`` is a reduction made up of split reductions.

    Returns False when ``node`` is not a reduction. Otherwise every sub-node
    that is a reduction ``SchedulerNode`` backed by a ``ComputedBuffer`` must
    have a non-None ``_split_size``; sub-nodes of any other kind are ignored.
    The result is vacuously True when no sub-node qualifies.
    """
    if not node.is_reduction():
        return False

    for snode in node.get_nodes():
        qualifies = (
            isinstance(snode, SchedulerNode)
            and snode.is_reduction()
            and isinstance(snode.node, ComputedBuffer)
        )
        # A single qualifying sub-node without a split size settles it.
        if qualifies and snode.node._split_size is None:
            return False
    return True
- @classmethod
- def get_numel_rnumel(cls, node: BaseSchedulerNode) -> tuple[sympy.Expr, sympy.Expr]:
- if cls.is_split_reduction(node):
- xnumel = None
- rnumel = None
- for subnode in node.get_nodes():
- if not (
- isinstance(subnode, SchedulerNode)
- and subnode.is_reduction()
- and isinstance(subnode.node, ComputedBuffer)
- ):
- continue
- assert subnode.node._original_ranges is not None
- curxnumel = V.graph.sizevars.simplify(
- sympy_product(subnode.node._original_ranges)
- )
- assert subnode.node._original_reduction_ranges is not None
- currnumel = V.graph.sizevars.simplify(
- sympy_product(subnode.node._original_reduction_ranges)
- )
- if xnumel is None:
- xnumel = curxnumel
- rnumel = currnumel
- else:
- assert V.graph.sizevars.statically_known_equals(
- xnumel, curxnumel
- ), f"{xnumel} v.s. {curxnumel}"
- assert V.graph.sizevars.statically_known_equals(
- rnumel, currnumel
- ), f"{rnumel} v.s. {currnumel}"
- assert xnumel is not None
- return (xnumel, rnumel)
- else:
- return node.group[1] # type: ignore[return-value]
- @classmethod
- def has_mix_reduction_orders(
- cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
- ) -> bool:
- g1 = cls.get_numel_rnumel(node1)
- g2 = cls.get_numel_rnumel(node2)
- if len(g1) != 2 or len(g2) != 2 or g1 == g2:
- return False
- return tuple(g1) == tuple(reversed(g2))
- @classmethod
- def _is_full_access(cls, buf: str, node: BaseSchedulerNode) -> bool:
- """
- The access to 'buf' is not a broadcast access.
- """
- found_dep = None
- for dep in node.read_writes.reads:
- if isinstance(dep, MemoryDep) and dep.name == buf:
- found_dep = dep
- break
- if not found_dep:
- return False
- index = found_dep.index
- var_ranges = node.read_writes.var_ranges
- if not var_ranges:
- assert isinstance(node, FusedSchedulerNode), f"{type(node)}"
- var_ranges = node.snodes[0].read_writes.var_ranges
- assert var_ranges
- if not (OrderedSet(var_ranges) - OrderedSet(index.free_symbols)):
- return True
- # cases that happen after merging loops:
- # MemoryDep('arg0_1', c0, {c0: 25165824})])
- # var_ranges={d0: 32768, d1: 768}
- if V.graph.sizevars.statically_known_equals(
- sympy_product(found_dep.size), sympy_product(var_ranges.values())
- ):
- return True
- return False
- @classmethod
- def get_common_read(
- cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
- ) -> list[str]:
- out = []
- common_reads = node1.used_buffer_names() & node2.used_buffer_names()
- for buf in common_reads:
- if cls._is_full_access(buf, node1) and cls._is_full_access(buf, node2):
- out.append(buf)
- return out
- @classmethod
- def has_common_read(
- cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
- ) -> bool:
- return len(cls.get_common_read(node1, node2)) > 0
- @classmethod
- def get_numel(cls, node: BaseSchedulerNode) -> int:
- g1 = cls.get_numel_rnumel(node)
- return V.graph.sizevars.optimization_hint(g1[0] * g1[1], fallback=0)
- @classmethod
- def get_fusion_score(
- cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
- ) -> int:
- # node2 is ignored for now
- return cls.get_numel(node1)
- # TODO add a cache
- @classmethod
- def can_fuse(cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool:
- """
- Check whether we can fuse two reductions with mix loop orders.
- """
- if not config.triton.mix_order_reduction:
- return False
- # TODO: Mix order reduction is not supported with cpp_wrapper yet
- if V.graph.cpp_wrapper:
- return False
- if not node1.is_gpu() or not node2.is_gpu():
- return False
- device_type = node1.get_device().type # type: ignore[union-attr]
- if (
- device_type not in ("cuda", "xpu")
- or get_current_backend(device_type) != "triton"
- ):
- return False
- if not node1.is_reduction() or not node2.is_reduction():
- return False
- if (node1.ancestors & node2.get_operation_names()) or (
- node2.ancestors & node1.get_operation_names()
- ):
- # the two reductions have no producer/consumer relationship
- return False
- # check for mix reduction orders
- if not cls.has_mix_reduction_orders(node1, node2):
- return False
- # check common buffer accesses
- common_reads = MixOrderReduction.get_common_read(node1, node2)
- if len(common_reads) == 0:
- return False
- if cls.is_contiguous_node(node1):
- contiguous_node, other_node = node1, node2
- elif cls.is_contiguous_node(node2):
- contiguous_node, other_node = node2, node1
- else:
- return False
- g1 = cls.get_numel_rnumel(contiguous_node)
- nrow, ncol = g1
- # in non strict mode, we will skip the non-critical checks
- if not config.triton.mix_order_reduction_non_strict_mode:
- # the fused version has worse perf than non-fused version for
- # small workload. When a workload is small enough, data can be
- # fully cached by L2
- size_thres = 5 * 2**20
- # Call evaluate_expr rather than statically_known_geq since nrow can
- # have dynamic shape in real models.
- # Don't use hint directly since hint can be non-representative.
- if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow * ncol, size_thres)):
- return False
- # We require more more row than columns since
- # 1, we prefer doing persistent reduction for each row
- # 2, we will split the reduction across the rows
- if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow, ncol * 2)):
- return False
- # When nrow is small, ncol should also be small (due to the check
- # above). Thus the entire tensor should be well cached in L2.
- # Mix order reduction is less beneficial.
- if not V.graph.sizevars.guard_or_true(sympy.Ge(nrow, 4096)):
- return False
- # Make sure a persistent reduction will be generated
- if any(
- subnode.node.data.reduction_hint # type: ignore[union-attr]
- not in (
- ReductionHint.INNER,
- ReductionHint.DEFAULT,
- )
- for subnode in contiguous_node.get_nodes()
- if subnode.is_reduction()
- ):
- return False
- # rnumel so large that we will not generated persistent reduction
- # We don't see real use cases with dynamic ncol. But if we do,
- # we should call evaluete_expr here which adds guards.
- if not V.graph.sizevars.statically_known_leq(ncol, 1024 * 16):
- return False
- # Other reduction types like max/min is not supported yet.
- # There are no real use case as well.
- out = all(
- subnode.node.get_reduction_type() # type: ignore[union-attr]
- in {
- "sum",
- "prod",
- }
- for subnode in other_node.get_nodes()
- if subnode.is_reduction()
- )
- return out
- @classmethod
- def are_mix_order_reductions(
- cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
- ) -> bool:
- return cls.can_fuse(node1, node2)
- @classmethod
- def is_contiguous_node(cls, node: BaseSchedulerNode) -> bool:
- if not all(
- cls.is_contiguous_load(dep.name, node) for dep in node.read_writes.reads
- ):
- return False
- return True
- @classmethod
- def is_contiguous_load(cls, buf: str, parent_node: BaseSchedulerNode) -> bool:
- from torch._inductor.loop_body import MemoryUsageType
- for node in parent_node.get_nodes():
- assert isinstance(node, SchedulerNode)
- loop_body = node._body
- entries = loop_body.memory_usage[MemoryUsageType.LOAD]
- index_names = [e.index_name for e in entries if e.buffer_name == buf]
- if len(index_names) == 0:
- continue
- # there can be multiple index_names some times
- for index_name in index_names:
- index_expr = loop_body.indexing_exprs[index_name]
- var_ranges = loop_body.var_ranges
- # assumes the final symbol is for reduction
- var_symbols = list(var_ranges.keys())
- stride_vars = V.graph.sizevars.stride_vars(
- index_expr,
- var_symbols,
- var_symbols,
- )
- # stride==0 means a broadcast
- if not (stride_vars[-1] == 0 or stride_vars[-1] == 1):
- return False
- return True
@dataclasses.dataclass
class SchedulerBuffer:
    """
    Scheduler-side wrapper around an ``ir.Buffer``.

    Stores the owning scheduler, the operation that defines the buffer
    (if any), the users of the buffer and memory-planning info.
    """

    scheduler: Scheduler
    node: ir.Buffer
    defining_op: Optional[BaseSchedulerNode]
    users: list[NodeUser] = dataclasses.field(default_factory=list)
    mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
        default_factory=MemoryPlanningInfoForBuffer
    )

    def defining_op_name(self) -> str:
        """Name of the defining operation; asserts that there is one."""
        defining_op = self.defining_op
        assert defining_op is not None
        return defining_op.get_name()

    def __hash__(self) -> int:
        # hash by the name of the wrapped IR node
        return hash(self.node.name)

    def debug_str(self) -> str:
        """Multi-line description: type, layout, aliases, mutations, users."""
        out = IndentedBuffer()
        name = self.get_name()
        out.writeline(f"{name}: {type(self.node).__name__}")
        out.writeline(f"{name}.layout = {self.node.layout}")
        if self.get_aliases():
            out.writeline(f"{name}.aliases = {pformat(self.get_aliases())}")
        if self.get_mutations():
            out.writeline(f"{name}.mutations = {pformat(self.get_mutations())}")

        if len(self.users) > 1:
            # one user per line when there are several
            out.writeline(f"{name}.users = [")
            with out.indent(1):
                for user in self.users:
                    out.writeline(f"{user},")
            out.writeline("]")
        else:
            out.writeline(f"{name}.users = {self.users}")
        return out.getrawvalue()

    def get_name(self) -> str:
        """Name of the wrapped IR buffer."""
        return self.node.get_name()

    def allocate(self) -> None:
        """
        Emit wrapper code that makes this buffer available: either a fresh
        allocation or an in-place reuse of an input buffer recorded by the
        active kernel.
        """
        assert self.node is not None
        if not self.node.should_allocate():
            return

        plain_allocation_only = (
            self.node.get_inputs_that_alias_output()
            or self.node.get_mutation_names()
            or isinstance(self.node.get_output_spec(), ir.CommBufferLayout)
        )
        if plain_allocation_only:
            V.graph.wrapper_code.codegen_allocation(self.node)
            return

        # hacky check for if V.kernel is a real kernel or NullHandler
        marked_inplace = (
            hasattr(V.kernel, "args")
            and self.get_name() in V.kernel.inplace_update_buffers
        )
        if not marked_inplace:
            V.graph.wrapper_code.codegen_allocation(self.node)
            return

        input_buffer: Union[ir.DonatedBuffer, ir.Buffer]
        input_buffer_name = V.kernel.inplace_update_buffers[self.get_name()]
        if input_buffer_name in self.scheduler.name_to_donated_buffer:
            input_buffer = self.scheduler.name_to_donated_buffer[
                input_buffer_name
            ].node
        else:
            input_buffer = self.scheduler.name_to_buf[input_buffer_name].node
        V.graph.wrapper_code.codegen_inplace_reuse(
            input_buffer,
            self.node,
        )

    def can_free(self) -> bool:
        """
        False for NoneLayout / multi-output-template buffers and for buffers
        that have an OutputNode among their users; True otherwise.
        """
        # There's no real allocated buffer, no need to free it
        assert self.node is not None
        if isinstance(self.node.layout, ir.NoneLayout) or is_multi_outputs_template(
            self.node
        ):
            return False
        return not any(isinstance(use.node, OutputNode) for use in self.users)

    def set_users(self, users: list[NodeUser]) -> None:
        """Store ``users``, merging entries that refer to the same node object."""
        # deduplicate, keyed by the identity of the user node
        merged: dict[int, NodeUser] = {}
        for use in users:
            key = id(use.node)
            if key in merged:
                merged[key] = use.merge(merged[key])
            else:
                merged[key] = use
        self.users = list(merged.values())

    def get_aliases(self) -> Sequence[str]:
        """Inputs of the wrapped node that alias its output."""
        assert self.node is not None
        return self.node.get_inputs_that_alias_output()

    def get_mutations(self) -> Sequence[str]:
        """Mutation names reported by the wrapped node."""
        assert self.node is not None
        return self.node.get_mutation_names()

    def get_device(self) -> Optional[torch.device]:
        """Device taken from the output spec of the wrapped node."""
        output_spec = self.node.get_output_spec()
        return output_spec.get_device()
@dataclasses.dataclass
class SchedulerDonatedBuffer(SchedulerBuffer):
    """
    A SchedulerBuffer whose ``defining_op`` is optional at construction time
    (defaults to None).
    """

    defining_op: Optional[BaseSchedulerNode] = None

    # @dataclass with the default eq=True sets __hash__ to None on any class
    # that does not define __hash__ in its own body. Without this line the
    # subclass would silently become unhashable even though SchedulerBuffer
    # defines a hash; re-export the parent's implementation explicitly.
    __hash__ = SchedulerBuffer.__hash__
- class BaseSchedulerNode:
- ancestors: OrderedSet[str]
- group: tuple[torch.device, tuple[tuple[sympy.Expr, ...], ...]]
- last_usage: OrderedSet[str]
- # .min_order and .max_order are only relevant for "grouped" nodes such as FusedSchedulerNode.
- # e.g. if the FusedSchedulerNode includes nodes (op_1, op_2, op_3), and op_X is X-th node
- # in `self.scheduler.nodes`, then for this FusedSchedulerNode, .min_order is 1 and .max_order is 3.
- # For non-"grouped" nodes (i.e. regular SchedulerNode),
- # .min_order = .max_order = X if this node is X-th node in `self.scheduler.nodes`.
- min_order: int
- max_order: int
- mpi_node: MemoryPlanningInfoForNode
- mutation_renames: dict[str, str]
- node: Optional[ir.Operation] = None
- outputs: list[SchedulerBuffer]
- outputs_by_name: dict[str, SchedulerBuffer]
- override_estimated_runtime: Optional[float] = None
- read_writes: dependencies.ReadWrites
- unmet_dependencies: OrderedSet[Dep]
- written: bool = False
- def __init__(self, scheduler: Scheduler) -> None:
- self.scheduler: Scheduler = scheduler
- self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
- lambda *args, **kwargs: []
- )
- def _init_from_node(self, node: ir.Operation) -> None:
- self.node = node
- self.ancestors = OrderedSet()
- self.last_usage = OrderedSet[
- str
- ]() # buffers that won't be used after this kernel
- self.written = False
- self.outputs = [
- SchedulerBuffer(
- scheduler=self.scheduler,
- node=output,
- defining_op=self,
- )
- for output in node.get_outputs()
- ]
- self.outputs_by_name = {buf.get_name(): buf for buf in self.outputs}
- # mutation_renames for the current node. Due to potential
- # more mutations happening later, this can be different
- # to Scheduler.mutation_renames. Also this dict should be small
- # since only mutation information relevant to the deps for this
- # node is stored here.
- self.mutation_renames = {}
- def __repr__(self) -> str:
- return f"{type(self).__name__}(name={self.get_name()!r})"
- def debug_str(self) -> str:
- """Longer form printout for trace logs"""
- name = self.get_name()
- buf = IndentedBuffer()
- buf.splice(
- f"""\
- {name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
- {name}.writes = {pformat(self.read_writes.writes)}
- {name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
- {name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
- {name}.outputs = [
- """
- )
- with buf.indent():
- for out in self.get_outputs():
- buf.splice(out.debug_str())
- buf.writeline("]")
- try:
- buf.splice(self.debug_str_extra())
- except Exception:
- log.warning("Ignoring error in debug_str()", exc_info=True)
- return buf.getrawvalue().rstrip()
- def debug_str_extra(self) -> str:
- return ""
- def _debug_str_for_device(self) -> list[str]:
- return self.debug_device_str(self)
- def debug_str_short(self) -> str:
- maybe_data = getattr(self.node, "data", None)
- data_str = ""
- if isinstance(maybe_data, torch._inductor.ir.Pointwise):
- data_str = ", " + maybe_data.str_helper(
- [maybe_data.get_size()], shorten=False, multiline=False
- )
- elif isinstance(maybe_data, torch._inductor.ir.Reduction):
- data_str = ", " + maybe_data.str_helper(
- [maybe_data.get_reduction_size(), maybe_data.get_reduction_type()],
- shorten=False,
- multiline=False,
- )
- return f"{self}{data_str}"
- def log_details(self) -> None:
- log.info(
- "%s: unmet_dependencies = %s, writes = %s",
- self,
- self.unmet_dependencies,
- self.read_writes.writes,
- )
- def reorder_loops_by_dep_pair(
- self, self_dep: MemoryDep, other_dep: MemoryDep
- ) -> bool:
- return False
- def update_mutated_names(self, renames: dict[str, str]) -> None:
- self.mutation_renames = {
- name: renames[name]
- for name in (dep.name for dep in self.read_writes.reads_and_writes())
- if name in renames
- }
- self.set_read_writes(self.read_writes.rename(self.mutation_renames))
- def add_fake_dep(self, dep: Dep) -> None:
- self.set_read_writes(self.read_writes.with_read(dep))
- def has_aliasing_or_mutation(self) -> bool:
- return any(
- buf.get_aliases() or buf.get_mutations() for buf in self.get_outputs()
- )
- def set_read_writes(self, rw: dependencies.ReadWrites) -> None:
- self.read_writes = rw
- self.unmet_dependencies = self.read_writes.reads
- self.prune_deps()
- def set_last_usage(
- self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
- ) -> None:
- used_buffers = self.used_or_aliased_buffer_names()
- used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers)
- self.last_usage = used_buffers - future_used_buffers
- def mark_run(self) -> None:
- for buf in self.outputs:
- buf.allocate()
- def used_buffer_names(self) -> OrderedSet[str]:
- return OrderedSet(
- dep.name
- for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
- )
- def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
- """
- Returns buffer names used by this node, including aliases.
- Note: is_fake WeakDeps are excluded since they are purely for ordering
- and should not affect buffer lifetime.
- """
- used_names: OrderedSet[str] = OrderedSet()
- deps = [
- dep.name
- for dep in itertools.chain(self.read_writes.reads, self.read_writes.writes)
- if not (isinstance(dep, WeakDep) and dep.is_fake)
- ]
- while len(deps) > 0:
- dep = deps.pop()
- used_names.add(dep)
- if V.graph.name_to_buffer.get(dep):
- deps.extend(
- alias
- for alias in V.graph.name_to_buffer[
- dep
- ].get_inputs_that_alias_output()
- if alias not in used_names
- )
- return used_names
- def prune_deps(self) -> None:
- self.unmet_dependencies = OrderedSet(
- dep
- for dep in self.unmet_dependencies
- if dep.name not in self.scheduler.available_buffer_names
- )
- def prune_weak_deps(self) -> None:
- # Prune weak dependencies on operations that have been removed
- def should_prune(dep: Dep) -> bool:
- if not isinstance(dep, WeakDep):
- return False
- if dep.name not in self.scheduler.name_to_buf:
- return False
- op_name = self.scheduler.name_to_buf[dep.name].defining_op_name()
- return op_name in V.graph.removed_operations
- to_remove = OrderedSet(
- dep for dep in self.read_writes.reads if should_prune(dep)
- )
- self.set_read_writes(self.read_writes.remove_reads(to_remove))
- def prune_redundant_deps(
- self, name_to_fused_node: dict[str, BaseSchedulerNode]
- ) -> None:
- _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
- def get_name(self) -> str:
- assert self.node is not None
- return self.node.get_operation_name()
- def get_first_name(self) -> str:
- return self.get_name()
- @cache_on_self
- def get_operation_names(self) -> OrderedSet[str]:
- return OrderedSet(node.get_name() for node in self.get_nodes())
- @cache_on_self
- def get_buffer_names(self) -> OrderedSet[str]:
- return OrderedSet(out.get_name() for out in self.outputs)
- @cache_on_self
- def can_codegen_in_low_precision(self) -> bool:
- return all(
- isinstance(n, SchedulerNode)
- and can_codegen_without_upcasts(n, disallow_fp32_ops=True)
- for n in self.get_nodes()
- )
- @cache_on_self
- def can_codegen_without_upcasts(self) -> bool:
- return all(
- isinstance(n, SchedulerNode) and can_codegen_without_upcasts(n)
- for n in self.get_nodes()
- )
- def get_nodes(self) -> Sequence[BaseSchedulerNode]:
- return [self]
- def get_outputs(self) -> Sequence[SchedulerBuffer]:
- return self.outputs
- def get_output(self, buf_name: str) -> SchedulerBuffer:
- return self.outputs_by_name[buf_name]
- def get_device(self) -> Optional[torch.device]:
- assert self.node is not None
- return self.node.get_device()
- def is_cpu(self) -> bool:
- device = self.get_device()
- return device is not None and device.type == "cpu"
- def is_gpu(self) -> bool:
- device = self.get_device()
- return device is not None and is_gpu(device.type)
- def is_reduction(self) -> bool:
- return False
- def is_native_matmul(self) -> bool:
- return False
- def is_split_scan(self) -> bool:
- return False
- def is_template(self) -> bool:
- return False
- def is_extern(self) -> bool:
- return False
- def is_foreach(self) -> bool:
- return False
- def can_inplace(self, read_dep: dependencies.Dep) -> bool:
- return False
- def has_side_effects(self) -> bool:
- return False
- def decide_inplace_update(self) -> None:
- """
- Decide if there should be inplace updates for the node
- and record the decision in the active kernel.
- """
- from .codegen.wrapper import can_match_buffer_size
- if not (
- isinstance(self, SchedulerNode)
- and config.inplace_buffers
- and V.graph.has_feature(self.get_device(), BackendFeature.INPLACE_BUFFERS)
- and (
- not isinstance(V.kernel, torch._inductor.codegen.simd.SIMDKernel)
- or getattr(V.kernel, "mutations", None) is not None
- )
- # hacky check for if V.kernel is a real kernel or NullHandler
- and hasattr(V.kernel, "args")
- ):
- return
- # NOTE remove V.graph.removed_operations once deps issue is fixed
- inconsequential_nodes = (
- self.ancestors
- | V.graph.removed_operations
- | self.scheduler.completed_operations
- )
- def single_index_in_fused_node(buf_to_be_inplaced: SchedulerBuffer) -> bool:
- # Inside of NodeUser, we track that the read and write are equivalent
- # before deciding if the use can be inplace.
- # But if that use is fused into a larger kernel, we need to check equivalence
- # of other accesses in fused scheduler node as well.
- fused_node = buf_to_be_inplaced.scheduler.get_fused_node(self)
- buf_name = buf_to_be_inplaced.get_name()
- # Dedup read/writes with equivalent indices
- # TODO - would be nice if we could just cache accesses on ReadWrites,
- # and enforce variant that this class & members are functional..
- deps: OrderedSet[Dep] = OrderedSet()
- for user in buf_to_be_inplaced.users:
- user_node = user.node
- if not isinstance(user_node, BaseSchedulerNode):
- continue
- if (
- user_node.get_first_name()
- not in buf_to_be_inplaced.scheduler.name_to_fused_node
- or buf_to_be_inplaced.scheduler.get_fused_node(user_node)
- is not fused_node
- ):
- continue
- deps |= (
- o
- for o in user_node.read_writes.reads_and_writes()
- if o.name == buf_name
- )
- if len(deps) > 1:
- return False
- return True
- for buf in self.get_outputs():
- buf_node = buf.node
- assert buf_node is not None
- if (
- not buf_node.should_allocate()
- or buf_node.get_inputs_that_alias_output()
- or buf_node.get_mutation_names()
- or buf.get_name() in V.graph.removed_buffers
- ):
- continue
- for read in self.read_writes.reads:
- input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
- if read.name in self.scheduler.name_to_donated_buffer:
- input_buf = self.scheduler.name_to_donated_buffer[read.name]
- else:
- input_buf = self.scheduler.name_to_buf.get(read.name)
- if (
- input_buf
- and V.graph.wrapper_code.can_reuse(input_buf, self)
- and not isinstance(input_buf.defining_op, NopKernelSchedulerNode)
- ):
- assert input_buf.users is not None
- remaining_uses = [
- x
- for x in input_buf.users
- if x.node.get_name() not in inconsequential_nodes
- ]
- if (
- len(remaining_uses) == 1
- and remaining_uses[0].can_inplace
- and remaining_uses[0].node is self
- and input_buf.node is not None
- and not isinstance(
- input_buf.node.get_output_spec(),
- (
- ir.NoneLayout,
- ir.MultiOutputLayout,
- ir.MutationLayoutSHOULDREMOVE,
- ),
- )
- and not (
- input_buf.defining_op
- and isinstance(
- input_buf.defining_op.node,
- (ir.FallbackKernel, ir.MultiOutput),
- )
- and len(input_buf.node.get_inputs_that_alias_output()) > 0
- )
- and can_match_buffer_size(input_buf.node, buf.node)
- and single_index_in_fused_node(input_buf)
- ):
- # if there isn't a triton kernel, then we don't need to call triton-specific things.
- # but TODO this might be a convenient place to signal to the Collective kernels to inplace
- # (and, can we make "kernel" less generic of a name?)
- V.kernel.args.make_inplace(input_buf.get_name(), buf.get_name())
- # mutations not tracked in cpp kernels
- if isinstance(
- V.kernel, torch._inductor.codegen.simd.SIMDKernel
- ):
- V.kernel.mutations.add(input_buf.get_name())
- V.kernel.mutations.add(buf.get_name())
- V.kernel.inplace_update_buffers[buf.get_name()] = (
- input_buf.get_name()
- )
- break
- def codegen_originating_info(
- self, buffer: IndentedBuffer, only_once: bool = True
- ) -> None:
- if not config.comment_origin:
- return
- if only_once and self.written:
- return
- assert self.node is not None
- origins = self.node.get_origins()
- out_lines = []
- for o in origins:
- if o.op == "output":
- # These are boring and samey
- continue
- out_lines.append("")
- # TODO(voz): Should the pragma be constant somewhere?
- out_lines.append("#pragma CMT ORIGIN:")
- op_info_str = f"#pragma CMT {o.op} {o.target}"
- if "seq_nr" in o.meta:
- op_info_str = op_info_str + f" seq_nr:{o.meta['seq_nr']}"
- out_lines.append(op_info_str)
- if "stack_trace" in o.meta:
- stack_trace = f"{o.meta['stack_trace']}"
- stack_trace_last_line = stack_trace.rsplit("|", maxsplit=1)[-1]
- out_lines.append(
- "#pragma CMT "
- + stack_trace_last_line.replace("{", "{{")
- .replace("}", "}}")
- .replace("\n", "\\")
- .replace(
- "\\", "\\\\"
- ) # For windows safe path, avoid for example \x, \U.
- )
- out_lines.append("#pragma CMT END ORIGIN")
- out_lines.append("")
- if len(out_lines) == 0:
- return
- # TODO(voz): Ostensibly, we should not need this. But there are cases where C++ codegen does
- # not use BracesBuffer, so we have no good indicator of a C++ buffer atm.
- buffer.writelines(out_lines)
- self.written = True
- @cache_on_self
- def get_read_write_buffers_sizes(self) -> int:
- return self.get_read_write_buffers_sizes_impl(
- include_reads=True, include_writes=True
- )
- @cache_on_self
- def get_read_buffer_sizes(self) -> int:
- return self.get_read_write_buffers_sizes_impl(
- include_reads=True, include_writes=False
- )
- @cache_on_self
- def get_write_buffer_sizes(self) -> int:
- return self.get_read_write_buffers_sizes_impl(
- include_reads=False, include_writes=True
- )
- def get_read_write_buffers_sizes_impl(
- self, include_reads: bool, include_writes: bool
- ) -> int:
- return sum(
- self.get_read_write_buffer_accesses(
- include_reads=include_reads, include_writes=include_writes
- ).values(),
- start=0,
- )
- def get_read_write_buffer_accesses(
- self, include_reads: bool, include_writes: bool
- ) -> dict[str, int]:
- """
- Counting the number of bytes accessed for a kernel is
- surprisingly tricky. In particular, there is a differentiation
- between 'theoretical' memory accesses and practical memory
- accesses. For example, a layernorm kernel may actually access an
- input 3 times, but in theory, it only needs to access its input
- once (and may be optimized to do so through say, persistent
- reductions)
- Another example is that even though a buffer is passed in, we may
- not access the entire buffer. This may occur if we are accessing
- a slice of the buffer. Another tricky case is for indirect
- indexing, where the amount of bytes accessed depends on the
- values of the input.
- What this function aims to compute is the memory accesses for
- worst-case inputs, best-case optimization. What this means is
- that for each buffer we compute the amount of potential accesses in two ways and take the minimum.
- 1. Numel in ranges multiplied by number of deps the buffer has
- 2. The buffer size
- Returns memory accesses per buffer.
- """
- if isinstance(self, NopKernelSchedulerNode):
- return {}
- if isinstance(self, ExternKernelSchedulerNode) and isinstance(
- self.node, MultiOutput
- ):
- # todo: Calculate this - it's kinda annoying.
- return {}
- if (
- isinstance(self, ExternKernelSchedulerNode)
- and isinstance(self.node, ir.FallbackKernel)
- and self.node.op_overload
- is torch._prims.rng_prims.graphsafe_run_with_rng_state
- ):
- return {}
- def try_size_hint(s: sympy.Expr) -> int:
- return V.graph.sizevars.optimization_hint(s, fallback=0)
- if isinstance(self, SchedulerNode):
- node_numel = try_size_hint(
- sympy_product(self.get_ranges()[0])
- * sympy_product(self.get_ranges()[1]),
- )
- else:
- node_numel = int(1e9)
- buf_accesses = collections.defaultdict(list)
- if include_reads:
- for dep in self.read_writes.reads:
- buf_accesses[dep.name].append(dep)
- if include_writes:
- for dep in self.read_writes.writes:
- buf_accesses[dep.name].append(dep)
- reads = (
- OrderedSet(dep.name for dep in self.read_writes.reads)
- if include_reads
- else OrderedSet()
- )
- writes = (
- OrderedSet(dep.name for dep in self.read_writes.writes)
- if include_writes
- else OrderedSet()
- )
- def is_materialized(buf: str, snodes: Sequence[BaseSchedulerNode]) -> bool:
- users = self.scheduler.name_to_buf[buf].users
- buf_uses = OrderedSet(user.node for user in users)
- return len(buf_uses - OrderedSet(snodes)) > 0
- if isinstance(self, FusedSchedulerNode):
- removed_buffers = OrderedSet(
- dep for dep in writes if not is_materialized(dep, self.snodes)
- )
- writes = writes - removed_buffers
- reads = reads - removed_buffers
- buf_byte_accesses: dict[str, int] = {}
- for buf_name in reads | writes:
- buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name])
- buf: Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]
- if buf_name in V.graph.name_to_buffer:
- buf = V.graph.name_to_buffer[buf_name]
- elif buf_name in V.graph.graph_inputs:
- buf = V.graph.graph_inputs[buf_name]
- else:
- continue
- def get_buf_bytes(
- buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
- ) -> int:
- if not buf:
- return 0
- if isinstance(buf, ir.TorchBindObject):
- return buf.get_buf_bytes()
- elif isinstance(buf.layout, MultiOutputLayout):
- # Kind of a lazy way to get the MultiOutput nodes corresponding to
- # a MultiOutputLayout
- users = self.scheduler.name_to_buf[buf.get_name()].users
- tot = 0
- for user in users:
- if isinstance(user.node, OutputNode):
- continue
- assert isinstance(user.node, BaseSchedulerNode)
- if isinstance(user.node.node, MultiOutput):
- for sched_buf in user.node.get_outputs():
- tot += get_buf_bytes(sched_buf.node)
- else:
- # Buf is a MultiOutputLayout but not all of its
- # users are MultiOutputs...
- # TODO: Figure out what's going on
- return 0
- return tot
- elif isinstance(buf.layout, ir.NoneLayout):
- return sum(
- get_buf_bytes(V.graph.get_buffer(mut_name))
- for mut_name in buf.get_mutation_names()
- )
- else:
- buf_elems = try_size_hint(sympy_product(buf.get_size()))
- return get_dtype_size(buf.get_dtype()) * min(
- buf_accessed_elems, buf_elems
- )
- buf_bytes = get_buf_bytes(buf)
- if buf_name not in buf_byte_accesses:
- buf_byte_accesses[buf_name] = buf_bytes
- else:
- buf_byte_accesses[buf_name] += buf_bytes
- return buf_byte_accesses
- @cache_on_self
- def estimate_flops(self) -> int | None:
- if self.node is None:
- return None
- fx_node = self.node.get_origin_node()
- if fx_node is None:
- return None
- flops = count_flops_fx(fx_node)
- if flops is None:
- return None
- if isinstance(flops, torch.SymInt):
- flops = flops.node.expr
- resolved_flops = V.graph.sizevars.optimization_hint(flops, fallback=0)
- counters["inductor"]["flop_count"] += resolved_flops
- return resolved_flops
- def get_estimated_runtime(self) -> float:
- if self.override_estimated_runtime is not None:
- return self.override_estimated_runtime
- return self._get_estimated_runtime()
@cache_on_self
def _get_estimated_runtime(self) -> float:
    """
    Returns estimated op runtime in milliseconds (ms)

    The estimate is picked in this order:
      * 0 when the first output of the first sub-node is not on a GPU;
      * for collectives: the NCCL-library estimate (looked up in / stored to
        the estimate cache) when
        ``config_comms.runtime_estimations_use_nccl_lib_estimations`` is set,
        with the in-tree estimate as a fallback; otherwise the in-tree
        estimate directly. 0 if estimation raises ValueError or TypeError;
      * 0 for wait nodes;
      * the value from ``maybe_estimate_runtime_benchmark`` when it is not None;
      * 0 when device bandwidth / FLOPS cannot be obtained or are <= 0;
      * otherwise the memory-transfer time, or the larger of compute time and
        memory-transfer time when a non-zero FLOP estimate is available.
    """
    buf = self.get_nodes()[0].get_outputs()[0]
    layout = buf.node.get_output_spec()
    if not is_gpu(get_device_type(layout)):
        # default to no reordering based on runtime
        return 0

    # Collective kernels
    if is_collective(self.node):
        assert isinstance(self.node, ir.IRNode)
        try:
            if config_comms.runtime_estimations_use_nccl_lib_estimations:
                cache_key = get_estimate_runtime_cache_key_from_snode(self)
                cache = get_estimate_runtime_cache()
                cache_val = cache.lookup(cache_key)
                if cache_val is not None:
                    assert isinstance(cache_val, float)
                    return cache_val

                ms = estimate_nccl_collective_runtime_nccl_estimator(self)
                if ms is None:
                    # NCCL estimations fail: fallback to in-tree algorithmic estimation.
                    ms = estimate_nccl_collective_runtime(self.node)

                cache.set_value(cache_key, value=ms)
                return ms
            return estimate_nccl_collective_runtime(self.node)
        except ValueError as e:
            # We don't know how to estimate runtime for this collective,
            # falling back to 0
            log.info(e)  # noqa: G200
            return 0
        except TypeError as e:
            # this happens when the collective is not of type ir._CollectiveKernel
            log.info(e)  # noqa: G200
            return 0
    elif is_wait(self.node):
        # ir.Wait is only used for collective ops.
        # The time needed for the collective op is already estimated and considered
        # when we are processing the collective op IR node, so ir.Wait takes 0 time
        # since it doesn't take extra time to get the result after the collective is completed.
        return 0

    ret = maybe_estimate_runtime_benchmark(self)
    if ret is not None:
        return ret

    dtype = buf.node.maybe_get_dtype()
    try:
        gpu_memory_bandwidth = get_gpu_dram_gbps()
        gpu_flops = get_device_tflops(dtype) * 10**12
        # If cudaGetDeviceProperties returns 0 for gpu_memory_bandwidth or gpu_flops
        # there is a chance to continue execution successfully. Otherwise, it would fail with
        # ZeroDivisionError below.
        if gpu_memory_bandwidth <= 0:
            raise AssertionError(
                f"gpu_memory_bandwidth cannot be <= 0, but got {gpu_memory_bandwidth}"
            )
        if gpu_flops <= 0:
            raise AssertionError(f"gpu_flops cannot be <= 0, but got {gpu_flops}")
    except Exception:
        # Deliberate best effort: without usable device numbers no estimate is made.
        return 0

    flops_est = self.estimate_flops()

    # Both the memory-only fallback and the compute path use the same byte
    # count, and both treat a missing count as 0 bytes.
    counted_bytes = self.get_read_write_buffers_sizes()
    if counted_bytes is None:
        counted_bytes = 0
    # NOTE(review): presumably bandwidth is in GB/s, which makes
    # bytes / bandwidth a time in nanoseconds - confirm against get_gpu_dram_gbps.
    transfer_time = counted_bytes / gpu_memory_bandwidth

    if flops_est is None or flops_est == 0:
        # no flops estimate, so fall back to memory estimate
        return transfer_time / 1e6

    # TODO(xmfan): find a better heuristic to model FLOPS/latency relationship
    factor = 1.0
    compute_time = (factor * flops_est / gpu_flops) * 1e9

    # Return estimated runtime in milliseconds
    ns = max(compute_time, transfer_time)
    return ns / 1e6
def get_template_node(self) -> Optional[ir.TemplateBuffer]:
    """
    Return the template buffer associated with this node, if any.

    This implementation always answers ``None``.
    """
    return None
def get_template_node_or_throw(self) -> ir.TemplateBuffer:
    """
    Return the node's template buffer, failing loudly when there is none.

    Returns:
        The value of ``self.get_template_node()``.

    Raises:
        AssertionError: if ``get_template_node()`` returns ``None``.
    """
    template = self.get_template_node()
    if template is None:
        # Raised explicitly (instead of via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError("expected a template node, but got None")
    return template
- @staticmethod
- def get_prologue_template_epilogue(
- nodes: list[BaseSchedulerNode],
- ) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]:
- """
- For the list of nodes, get the prologue, template, and epilogue
- """
- template_index = next(i for i, n in enumerate(nodes) if n.is_template())
- prologue = nodes[:template_index]
- template_node = nodes[template_index]
- epilogue = nodes[template_index + 1 :]
- return prologue, template_node, epilogue
@functools.cache
def get_estimate_runtime_cache() -> torch._inductor.codecache.LocalCache:
    """
    Return the cache object used for runtime estimates.

    ``functools.cache`` memoizes this zero-argument function, so the
    ``LocalCache`` is built on the first call and the same object is
    returned on later calls.
    """
    runtime_cache = torch._inductor.codecache.LocalCache()
    return runtime_cache
def get_estimate_runtime_cache_key_from_snode(snode: BaseSchedulerNode) -> str:
    """
    Build the string key under which a runtime estimate for ``snode`` is cached.

    The key is ``str()`` of a tuple holding the node's ``python_kernel_name``
    (``""`` when absent) followed by one entry per flattened argument: the
    argument's size tuple when it is a tensor-like IR node, ``None`` otherwise.

    NOTE(review): only sizes enter the key; dtypes and non-tensor argument
    values do not, so nodes differing only in those share a key - confirm
    this is intended.
    """
    node = snode.node
    python_kernel_name = getattr(node, "python_kernel_name", "")
    args = node.fill_non_provided_args(  # type: ignore[union-attr]
        [*node.inputs, *node.constant_args],  # type: ignore[union-attr]
        node.kwargs,  # type: ignore[union-attr]
    )
    kwargs = node.kwargs  # type: ignore[union-attr]
    # Only the flattened leaves are needed; the tree spec is discarded.
    flat_args, _ = pytree.tree_flatten((args, kwargs))

    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
        return isinstance(x, ir.IRNode) and not isinstance(x, ir.GeneratorState)

    size_signature = tuple(
        tuple(a.get_size()) if _is_tensor_ir(a) else None for a in flat_args
    )
    return str((python_kernel_name,) + size_signature)
def _get_mm_like_fn(snode: BaseSchedulerNode) -> Optional[Callable[[Any], Any]]:
    """
    Map an extern matmul-like scheduler node to the matching aten op.

    Returns the aten ``mm`` / ``bmm`` / ``addmm`` op when ``snode`` is an
    ``ExternKernelSchedulerNode`` whose node is an ``ir.ExternKernel`` with the
    corresponding ``python_kernel_name``; ``None`` in every other case.
    """
    if not isinstance(snode, ExternKernelSchedulerNode):
        return None

    kernel_name_to_op = {
        "extern_kernels.mm": torch.ops.aten.mm,
        "extern_kernels.bmm": torch.ops.aten.bmm,
        "extern_kernels.addmm": torch.ops.aten.addmm,
    }
    kernel_name = getattr(snode.node, "python_kernel_name", "")
    if kernel_name not in kernel_name_to_op or not isinstance(
        snode.node, ir.ExternKernel
    ):
        return None
    return kernel_name_to_op[kernel_name]
def maybe_estimate_runtime_benchmark(snode: BaseSchedulerNode) -> Optional[float]:
    """
    Estimate the runtime of ``snode`` by benchmarking it, when enabled.

    Benchmarking only happens when ``config.runtime_estimations_mms_benchmark``
    is set and ``_get_mm_like_fn`` recognises the node. Results are stored in
    the estimate cache and reused for nodes that produce the same cache key.

    Returns:
        The cached or freshly benchmarked value, or ``None`` when benchmarking
        is disabled or the node is not a recognised matmul-like kernel.
    """
    if not config.runtime_estimations_mms_benchmark:
        return None

    bench_fn = _get_mm_like_fn(snode)
    if bench_fn is None:
        return None

    cache_key = get_estimate_runtime_cache_key_from_snode(snode)
    cache = get_estimate_runtime_cache()
    cache_val = cache.lookup(cache_key)
    if cache_val is not None:
        assert isinstance(cache_val, float)
        return cache_val

    # Imported only on a cache miss, right before first use.
    from .utils import snode_args_kwargs

    args, kwargs = snode_args_kwargs(snode)

    from torch._inductor.runtime.benchmarking import benchmarker

    ms = benchmarker.benchmark(
        bench_fn,
        args,  # pyrefly: ignore[bad-argument-type]
        kwargs,
        memory_warmup_iters=5,
        benchmark_iters=10,
        max_benchmark_duration=10,
    )  # type: ignore[arg-type]
    cache.set_value(cache_key, value=ms)
    return ms
- @dataclasses.dataclass(slots=True)
- class WhyNoFuse:
- name1: str
- name2: str
- reason: str
- args: tuple[Any, ...]
- def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
- self.name1 = node1.get_name()
- self.name2 = node2.get_name()
- def __call__(self, reason: str, *args: Any) -> None:
- self.reason = reason
- self.args = args
- fusion_log.debug(self)
- def __str__(self) -> str:
- return f"cannot fuse {self.name1} with {self.name2}: " + (
- self.reason % self.args
- )
def pformat(obj: Any) -> str:
    """
    Pretty-print ``obj`` for debug output.

    Sets are first turned into a list sorted by ``str``. A multi-line result
    is indented by four spaces and prefixed with a newline; a single-line
    result is returned unchanged.
    """
    if isinstance(obj, (OrderedSet, set)):  # noqa: set_linter
        # pformat has trouble with sets of sympy exprs
        obj = sorted(obj, key=str)
    rendered = pprint.pformat(obj, indent=4)
    if "\n" not in rendered:
        return rendered
    return "\n" + textwrap.indent(rendered, " " * 4)
class OutputNode:
    """
    Placeholder node named ``"OUTPUT"`` that holds a single unmet dependency.
    """

    def __init__(self, dep: StarDep) -> None:
        self.unmet_dependencies = OrderedSet([dep])

    def is_reduction(self) -> bool:
        """An output node is never a reduction."""
        return False

    def get_inputs_that_alias_output(self) -> Sequence[str]:
        """An output node reports no aliasing inputs."""
        return ()

    def get_name(self) -> str:
        """Return the fixed name of the output node."""
        return "OUTPUT"

    # repr() shows the same fixed name.
    __repr__ = get_name
def _prune_redundant_deps(
    node: BaseSchedulerNode,
    name_to_fused_node: dict[str, BaseSchedulerNode],
    name_to_buf: dict[str, SchedulerBuffer],
) -> None:
    """
    Prunes weakdeps intended for mutation ordering
    on an upstream fused node if after fusion there is another dependency
    on the fused upstream node, making the weakdep redundant

    In essence this enforces an ordering on fusions. As fusions occur, weakdeps will
    be incrementally removed, enabling other fusions, ensuring they are fused in order.

    Pruned deps are removed from both ``node.unmet_dependencies`` and the
    reads of ``node.read_writes``.
    """

    def _producer_of(dep: Dep) -> BaseSchedulerNode:
        # Fused node that currently owns the op defining dep's buffer.
        return name_to_fused_node[name_to_buf[dep.name].defining_op_name()]

    # How many non-weak unmet deps point at each (fused) producer.
    name_to_dep_count: Counter[str] = collections.Counter(
        _producer_of(dep).get_name()
        for dep in node.unmet_dependencies
        if not isinstance(dep, WeakDep)
    )

    def should_prune(dep: Dep) -> bool:
        if not isinstance(dep, WeakDep):
            return False
        producer = _producer_of(dep)
        is_redundant = name_to_dep_count[
            producer.get_name()
        ] > 0 and node.scheduler.fusable_weak_dep(dep, producer, node)
        # These can occur because fused nodes always gather deps from their snodes
        # If B has a weakdep on A
        # B gets fused with C, then any time BC is fused, the weakdep will reappear
        is_self_dep = producer == node
        return is_redundant or is_self_dep

    deps_to_prune = OrderedSet(
        dep for dep in node.unmet_dependencies if should_prune(dep)
    )
    if not deps_to_prune:
        return

    node.unmet_dependencies = node.unmet_dependencies - deps_to_prune
    node.set_read_writes(node.read_writes.remove_reads(deps_to_prune))
class ExternKernelSchedulerNode(BaseSchedulerNode):
    """Scheduler node wrapping an extern kernel operation."""

    def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
        super().__init__(scheduler)
        self._init_from_node(node)
        read_writes = node.get_read_writes()
        self.set_read_writes(read_writes)

    def debug_str_extra(self) -> str:
        """Report the node's ``python_kernel_name`` (``None`` when absent)."""
        return f"{self.get_name()}.node.kernel = {getattr(self.node, 'python_kernel_name', None)}"

    def is_extern(self) -> bool:
        return True

    def has_side_effects(self) -> bool:
        """Defer to the wrapped node's ``has_side_effects()`` when it defines one."""
        assert self.node is not None
        if not hasattr(self.node, "has_side_effects"):
            return False
        return self.node.has_side_effects()
class NopKernelSchedulerNode(BaseSchedulerNode):
    """Scheduler node that takes its read/write set directly from the wrapped operation."""

    def __init__(self, scheduler: Scheduler, node: ir.Operation) -> None:
        super().__init__(scheduler)
        self._init_from_node(node)
        read_writes = node.get_read_writes()
        self.set_read_writes(read_writes)
class SchedulerNode(BaseSchedulerNode):
    """
    A SchedulerNode is a node for scheduling that encapsulates either
    a ComputedBuffer or a TemplateBuffer.
    """

    # Iteration sizes and loop body produced by ``simplify_and_reorder``.
    _sizes: tuple[Sequence[sympy.Expr], ...]
    _body: LoopBody

    def __init__(
        self,
        scheduler: Scheduler,
        node: Union[ir.ComputedBuffer, ir.TemplateBuffer],
    ) -> None:
        super().__init__(scheduler)
        self._init_from_node(node)
        self._compute_attrs()

    def _compute_attrs(
        self,
        extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
        recompute_sizes_body_func: Optional[Callable[_P, _T]] = None,
    ) -> None:
        """
        (Re)compute ``_sizes``, ``_body``, ``group`` and the read/write set
        from the wrapped IR node.
        """
        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
        self._sizes, body = self.node.simplify_and_reorder(
            extra_indexing_constraints=extra_indexing_constraints,
            recompute_sizes_body_func=recompute_sizes_body_func,
        )
        self._body = body  # type: ignore[assignment]

        device = self.node.get_device_or_error()
        group_fn = self.scheduler.get_backend(device).group_fn
        self.group = (device, group_fn(self._sizes))

        # Don't normalize since normalization will merge loops which
        # makes it hard to decide new loop orders.
        should_normalize = not (
            config.loop_ordering_after_fusion and is_gpu(device.type)
        )

        if isinstance(self.node, ir.TemplateBuffer):
            read_writes = self.node.extract_read_writes(normalize=should_normalize)
        else:
            read_writes = dependencies.extract_read_writes(
                self._body, *self._sizes, normalize=should_normalize
            )
        self.set_read_writes(read_writes)

    def recompute_size_and_body(
        self,
        extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
        recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
    ) -> None:
        """Public entry point that forwards to ``_compute_attrs``."""
        self._compute_attrs(
            extra_indexing_constraints=extra_indexing_constraints,
            recompute_sizes_body_func=recompute_sizes_body_func,
        )

    def refresh_dependencies(
        self, normalize: bool, need_clear_tiling_cache: bool
    ) -> None:
        """
        Re-extract the read/write set from the current body and sizes.

        Weak and star deps already present in the reads are carried over, the
        mutation renames are re-applied, and cached pointwise read/writes are
        dropped. The tiling cache is cleared only on request.
        """
        # Fake dependencies are added manually. They can not be analyzed from
        # extract_read_writes. Find them out and apply manually.
        fake_deps: OrderedSet[Dep] = OrderedSet(
            dep for dep in self.read_writes.reads if isinstance(dep, (WeakDep, StarDep))
        )

        # Whether to normalize is the caller's decision (see ``normalize``).
        extracted = dependencies.extract_read_writes(
            self._body, *self._sizes, normalize=normalize
        )
        self.set_read_writes(
            extracted.with_read(fake_deps).rename(self.mutation_renames)
        )

        self.pointwise_read_writes.clear_cache(self)

        if not need_clear_tiling_cache:
            return

        from .codegen.simd import SIMDScheduling

        # TODO(shunting) if this cause compilation time increase when
        # enabling LOAF by default, try just clearing the specific cache
        # entry by using a customized cache implementation rather than
        # lru_cache.
        SIMDScheduling.candidate_tilings.cache_clear()

    def apply_new_loop_order(self, new_order: Sequence[int]) -> None:
        """Reorder the iteration loops of the body and refresh sizes and dependencies."""
        reordered = self._body.reorder_iter_loops(
            new_order,
        )
        self._body = reordered
        self._sizes = reordered.sizes
        self.refresh_dependencies(normalize=False, need_clear_tiling_cache=True)

    def swap_pw_red_dimension(self) -> None:
        """Move the reduction dimensions in front of the pointwise ones and swap the group sizes."""
        body = self._body
        num_rdims = body.get_original_num_rdims()
        num_pwdims = len(body.iter_vars) - num_rdims
        pwdims = tuple(range(num_pwdims))
        rdims = tuple(range(num_pwdims, num_pwdims + num_rdims))

        self.apply_new_loop_order(rdims + pwdims)

        device, iteration = self.group[0], self.group[1]
        assert len(iteration) == 2
        self.group = device, (iteration[1], iteration[0])

    def extract_pw_from_reduction(self) -> BaseSchedulerNode:
        """Replace the body with ``body.extract_pw_from_reduction()`` and return self."""
        self._body = self._body.extract_pw_from_reduction()
        return self

    def cancel_reduction_split(self) -> None:
        """Recompute attributes from the original inner fn when this node is a split reduction."""
        if MixOrderReduction.is_split_reduction(self):
            assert isinstance(self.node, ir.ComputedBuffer)
            with self.node.with_original_inner_fn():
                self._compute_attrs()

    def expand_dimension_for_pointwise_node(
        self, dimension: int, new_range: int
    ) -> None:
        """Expand one dimension of the body, then refresh sizes, group and dependencies."""
        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))

        expanded = self._body.expand_dimension_for_pointwise_node(
            dimension, new_range
        )
        self._body = expanded
        self._sizes = expanded.sizes

        device = self.node.get_device_or_error()
        group_fn = self.scheduler.get_backend(device).group_fn
        self.group = (device, group_fn(self._sizes))

        # Need normalize the prefix name to facilitate finding common dependencies
        self.refresh_dependencies(normalize=True, need_clear_tiling_cache=True)

    def merge_loops(self) -> None:
        """Merge loops in the body and refresh sizes and dependencies."""
        merged = self._body.merge_loops()
        self._body = merged
        self._sizes = merged.sizes

        # merge_loops is called after loop reordering.
        # We still need retain fake dependencies since codegen the
        # estimated amount of memory access rely on them.
        #
        # Merge loops does not affect the tiling decision. So we
        # don't need clear the tiling cache.
        self.refresh_dependencies(normalize=True, need_clear_tiling_cache=False)

    def reorder_loops_by_dep_pair(
        self, self_dep: MemoryDep, other_dep: MemoryDep
    ) -> bool:
        """
        Try to reorder this node's loops so ``self_dep`` matches ``other_dep``.

        Returns True if a loop reordering was applied, False otherwise.
        """
        new_order = None
        self_sizes = self._sizes[0]
        if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
            new_order = self_dep.decide_loop_order_to_match(other_dep)

        if not new_order:
            loop_ordering_log.debug(
                "Don't reordering %s because we can not decide the suitable loop order",
                self.get_name(),
            )
            return False

        # pyrefly: ignore [bad-assignment]
        metrics.num_loop_reordering += 1
        loop_ordering_log.debug(
            "Reorder loops for %s with order %s", self.get_name(), new_order
        )
        self.apply_new_loop_order(new_order)
        return True

    def debug_str_extra(self) -> str:
        """Describe group, sizes, buffer layouts, loop body and device info."""
        name = self.get_name()
        lines = [
            f"{name}.group.device = {self.group[0]}",
            f"{name}.group.iteration = {self.group[1]}",
            f"{name}.sizes = {self._sizes}",
        ]
        for dep in self.read_writes.reads_and_writes():
            if isinstance(dep, WeakDep):
                continue
            buf_name = dep.name
            buf = V.graph.get_buffer(buf_name)
            if isinstance(buf, ir.TorchBindObject):
                continue
            lines.append(f"{buf_name}_layout = {pformat(buf.layout)}")
        if isinstance(self._body, LoopBody):
            lines.append(f"class {name}_loop_body:")
            lines.append(textwrap.indent(self._body.debug_str(), " "))

        assert self.node is not None
        lines.extend(self._debug_str_for_device())

        return "\n".join(lines)

    def get_ranges(self) -> Sequence[Sequence[sympy.Expr]]:
        return self._sizes

    def is_reduction(self) -> bool:
        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
            f"{type(self.node)=}"
        )
        if not self.node.get_reduction_type():
            return False
        # self._body containing partial accumulate means the reduction is
        # converted to a pointwise node. Need this extra check since
        # we change self._body but didn't change self.node (IRNode)
        # when converting a reduction to a pointwise
        return self._body is None or not self._body.has_partial_accumulate

    def is_native_matmul(self) -> bool:
        assert isinstance(self.node, ir.ComputedBuffer), f"{type(self.node)=}"
        reduction_type = self.node.get_reduction_type()
        return reduction_type == "dot"

    def is_split_scan(self) -> bool:
        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
            f"{type(self.node)=}"
        )
        return isinstance(self.node, ir.ComputedBuffer) and isinstance(
            self.node.data, ir.SplitScan
        )

    def is_template(self) -> bool:
        return isinstance(self.node, ir.TemplateBuffer)

    def get_template_node(self) -> Optional[ir.TemplateBuffer]:
        if isinstance(self.node, ir.TemplateBuffer):
            return self.node
        return None

    def run(self, *index_vars: Sequence[sympy.Expr]) -> None:
        """Decide in-place updates, mark the node as run, then generate its code."""
        self.decide_inplace_update()
        self.mark_run()
        self.codegen(index_vars)

    def ranges_from_index_vars(
        self, index_vars: Sequence[Sequence[sympy.Expr]]
    ) -> dict[sympy.Expr, sympy.Expr]:
        """Pair every index variable with its size, in flattened order."""
        sizes = self._sizes
        assert sum(map(len, sizes)) == sum(map(len, index_vars))
        flat_vars = itertools.chain.from_iterable(index_vars)
        flat_sizes = itertools.chain.from_iterable(sizes)
        return dict(zip(flat_vars, flat_sizes))

    def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
        """
        Generate code for this node using the provided index variables.

        The body is invoked with ``index_vars`` while an index-simplifying ops
        handler (built from the variable ranges) is installed and this node is
        set as the kernel's current node. Any exception is logged and re-raised.

        Args:
            index_vars: A sequence of sequences of sympy expressions representing
                the index variables for each dimension of the computation.
        """
        var_ranges = self.ranges_from_index_vars(index_vars)
        try:
            with (
                V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)),
                V.kernel.set_current_node(self),
            ):
                self._body(*index_vars)
        except Exception:
            log.fatal("Error in codegen for %s", self.node)
            raise

    def pointwise_or_reduction_read_writes(
        self, pointwise: bool = True
    ) -> dependencies.ReadWrites:
        """
        Get the memory dependencies in either the pointwise or the reduction axes.
        """
        if pointwise:
            keep_sizes, ignore_sizes = self._sizes
        else:
            keep_sizes, ignore_sizes = reversed(self._sizes)
        # The ignored axes are pinned to index zero.
        hidden = [[sympy.S.Zero] * len(ignore_sizes)]
        return dependencies.extract_read_writes(
            self._body, keep_sizes, hidden_args=hidden
        )

    @cache_on_self
    def pointwise_read_writes(self) -> dependencies.ReadWrites:
        """
        Get the memory dependencies in the non-reduction axes.
        """
        return self.pointwise_or_reduction_read_writes(pointwise=True)

    @cache_on_self
    def reduction_read_writes(self) -> dependencies.ReadWrites:
        """
        Get the memory dependencies in the reduction axes.
        """
        return self.pointwise_or_reduction_read_writes(pointwise=False)

    def can_inplace(self, read_dep: dependencies.Dep) -> bool:
        """
        Return True when ``read_dep`` matches this node's single write in both
        index and size, the node is not a template and no output has aliases.
        """
        if self.is_template():
            return False
        if any(out.get_aliases() for out in self.get_outputs()):
            return False
        writes = self.read_writes.writes
        if len(writes) != 1 or not isinstance(read_dep, dependencies.MemoryDep):
            return False
        write_dep = next(iter(writes))
        assert isinstance(write_dep, dependencies.MemoryDep), f"{type(write_dep)=}"
        return read_dep.index == write_dep.index and read_dep.size == write_dep.size

    @cache_on_self
    def _get_atomic_add_buffers(self) -> OrderedSet[str]:
        """Collect the names of buffers the body stores to with mode ``atomic_add``."""
        buffers_store_as_atomic_add: OrderedSet[str] = OrderedSet()
        if not isinstance(self._body, LoopBody):
            return buffers_store_as_atomic_add

        for node in self._body.get_nodes():
            if node.op != "call_method" or node.target != "store":
                continue
            # The mode is given either as a keyword or as the fifth positional argument.
            mode_by_kwarg = (
                "mode" in node.kwargs and node.kwargs["mode"] == "atomic_add"
            )
            if not (
                mode_by_kwarg
                or (len(node.args) == 5 and node.args[4] == "atomic_add")
            ):
                continue
            if "name" in node.kwargs:
                stored_name = node.kwargs["name"]
            elif len(node.args) >= 2:
                stored_name = node.args[1]
            else:
                stored_name = ""
            buffers_store_as_atomic_add.add(stored_name)
        return buffers_store_as_atomic_add

    @cache_on_self
    def has_side_effects(self) -> bool:
        # self._body is None sometimes that's why this check was added
        body = self._body
        if body is not None and body.has_op("device_assert_async"):
            return True
        return super().has_side_effects()
def refresh_group_node_dependencies(
    group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
) -> None:
    """
    Recompute the read/write set and unmet dependencies of a group node.

    The read/writes are the merge of the members' read/writes. The unmet
    dependencies are the union of the members' unmet dependencies, minus
    those naming a buffer of the group itself and minus the group's writes.
    """
    snodes = group_snode.snodes
    merged = dependencies.ReadWrites.merge_list([sub.read_writes for sub in snodes])
    group_snode.set_read_writes(merged)

    own_buffer_names = group_snode.get_buffer_names()
    all_unmet = OrderedSet.union(*[sub.unmet_dependencies for sub in snodes])
    external_unmet = OrderedSet(
        dep for dep in all_unmet if dep.name not in own_buffer_names
    )
    group_snode.unmet_dependencies = external_unmet - group_snode.read_writes.writes
def init_group_node(
    group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
    scheduler: Scheduler,
    snodes: list[BaseSchedulerNode],
) -> None:
    """
    Populate the shared state of a fused or grouped scheduler node.

    Sets ``snodes``, ``scheduler``, ``node`` (always ``None``), ``ancestors``
    (union over members that have them), the dependency sets, ``min_order`` /
    ``max_order`` and ``outputs_by_name``.
    """
    assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode))
    group_snode.snodes = snodes
    group_snode.scheduler = scheduler
    group_snode.node = None

    member_ancestors = [sub.ancestors for sub in snodes if sub.ancestors is not None]
    group_snode.ancestors = OrderedSet.union(*member_ancestors)

    refresh_group_node_dependencies(group_snode)

    group_snode.min_order = min(sub.min_order for sub in group_snode.snodes)
    group_snode.max_order = max(sub.max_order for sub in group_snode.snodes)

    outputs_by_name = {}
    for buf in group_snode.get_outputs():
        outputs_by_name[buf.get_name()] = buf
    group_snode.outputs_by_name = outputs_by_name
class FusedSchedulerNode(BaseSchedulerNode):
    """
    This is a "fake" scheduler node that represents a group of scheduler nodes
    that are meant to be fused together. The way it does this is by maintaining
    its unmet dependencies as the union of its constituent nodes.
    """

    # The constituent nodes, in fusion order.
    snodes: list[BaseSchedulerNode]

    @classmethod
    def fuse(
        cls, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> FusedSchedulerNode:
        """Build a fused node from the sub-nodes of ``node1`` followed by those of ``node2``."""
        assert node1.scheduler is node2.scheduler
        assert isinstance(node1, (SchedulerNode, FusedSchedulerNode))

        if node1.is_template() and isinstance(node2, ExternKernelSchedulerNode):
            # Fuse multi outputs template and its outputs.
            # node2's single write is a StarDep while the template's single
            # write is a MemoryDep. Rewrite node2's write as a MemoryDep with the
            # template's index/vars/size/mode, because calculate score_fusion_memory
            # of the template node and its epilogue requires the same type of dependencies
            assert isinstance(node2.node, MultiOutput)
            assert len(node2.read_writes.writes) == 1
            assert isinstance(next(iter(node2.read_writes.writes)), StarDep)
            name = next(iter(node2.read_writes.writes)).name

            template_nodes = [sub for sub in node1.get_nodes() if sub.is_template()]
            assert len(template_nodes) == 1
            template_node = template_nodes[0]
            assert len(template_node.read_writes.writes) == 1
            write = next(iter(template_node.read_writes.writes))
            assert isinstance(write, MemoryDep)

            rewritten = MemoryDep(
                name, write.index, write.var_names, write.size, write.mode
            )
            node2.read_writes.writes = OrderedSet([rewritten])
        else:
            assert isinstance(node2, (SchedulerNode, FusedSchedulerNode))

        nodes = [*node1.get_nodes(), *node2.get_nodes()]
        return cls(node1.scheduler, nodes)

    def extract_pw_from_reduction(self) -> BaseSchedulerNode:
        """Apply ``extract_pw_from_reduction`` to every sub-node (each must be a reduction)."""
        for subnode in self.snodes:
            assert isinstance(subnode, SchedulerNode)
            assert subnode.is_reduction()
            subnode.extract_pw_from_reduction()
        return self

    def swap_pw_red_dimension(self) -> None:
        """Apply ``swap_pw_red_dimension`` to every sub-node."""
        for subnode in self.snodes:
            assert isinstance(subnode, SchedulerNode)
            subnode.swap_pw_red_dimension()

    @cache_on_self
    def estimate_flops(self) -> int | None:
        """
        Sum the non-zero FLOP estimates of the template / extern sub-nodes.

        Returns ``None`` when no such sub-node yields a truthy estimate.
        """
        # don't increment counters in fused methods so we don't double count
        fps = []
        for node in self.get_nodes():
            if not (node.is_template() or node.is_extern()):
                continue
            estimate = node.estimate_flops()
            if estimate:
                fps.append(estimate)
        if not fps:
            return None
        return sum(fps)

    def reorder_loops_by_dep_pair(
        self, self_dep: MemoryDep, other_dep: MemoryDep
    ) -> bool:
        """
        Return true if a loop reordering is performed.
        """
        if self.is_template():
            # We can not really reorder loops for a triton template
            return False

        # All sub-nodes must agree on their first size group.
        self_sizes = None
        for snode in self.snodes:
            assert isinstance(snode, SchedulerNode)
            if self_sizes is not None and tuple(self_sizes) != tuple(snode._sizes[0]):
                loop_ordering_log.debug(
                    "Can not reorder fused node due to different sizes"
                )
                return False
            self_sizes = snode._sizes[0]

        assert self_sizes is not None
        new_order = None
        if len(self_sizes) == self_dep.num_vars == other_dep.num_vars:
            new_order = self_dep.decide_loop_order_to_match(other_dep)

        if not new_order:
            loop_ordering_log.debug(
                "Dont reordering fused node %s because we can not decide the suitable loop order",
                self.get_name(),
            )
            return False

        # pyrefly: ignore [bad-assignment]
        metrics.num_loop_reordering += 1
        loop_ordering_log.debug(
            "Reorder loops for fused node %s with order %s", self.get_name(), new_order
        )
        for snode in self.snodes:
            assert isinstance(snode, SchedulerNode)
            snode.apply_new_loop_order(new_order)

        refresh_group_node_dependencies(self)
        return True

    def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
        super().__init__(scheduler)
        init_group_node(self, scheduler, snodes)
        self.users: list[NodeUser] = []
        # Take the group of the first reduction sub-node if there is one,
        # otherwise of the first sub-node (max() keeps the first maximum).
        self.group = max(snodes, key=lambda x: int(x.is_reduction())).group

    @cache_on_self
    def get_name(self) -> str:
        return "_".join(sub.get_name() for sub in self.snodes)

    def get_first_name(self) -> str:
        first = self.snodes[0]
        return first.get_name()

    @cache_on_self
    def get_buffer_names(self) -> OrderedSet[str]:
        return OrderedSet.union(*[sub.get_buffer_names() for sub in self.snodes])

    def get_outputs(self) -> list[SchedulerBuffer]:
        """Return the outputs of all sub-nodes, concatenated in sub-node order."""
        result: list[SchedulerBuffer] = [
            out for node in self.snodes for out in node.get_outputs()
        ]
        return result

    def debug_str_extra(self) -> str:
        name = self.get_name()
        lines = [
            f"{name}.snodes[{i}] =\n{node.debug_str()}"
            for i, node in enumerate(self.snodes)
        ]
        first_ir_node = self.snodes[0].node
        if first_ir_node is not None:
            lines.extend(self._debug_str_for_device())
        return textwrap.indent("\n".join(lines).rstrip(), " ")

    def debug_str_short(self) -> str:
        snodes_str = [node.debug_str_short() for node in self.snodes]
        return f"{self}, snodes: {snodes_str}"

    def set_last_usage(
        self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
    ) -> None:
        # Set self.last_usage using the global information
        # This will be used for inter-kernel optimisations
        super().set_last_usage(future_used_buffers, mutation_real_name)

        # Set self.last_usage on the snodes
        # This will be used for optimisations within the kernel
        # (tracked in a fresh set that starts empty, walking the sub-nodes backwards)
        used_later_in_kernel: OrderedSet[str] = OrderedSet()
        for node in reversed(self.snodes):
            node.set_last_usage(used_later_in_kernel, mutation_real_name)
            used_later_in_kernel.update(node.last_usage)

    @cache_on_self
    def used_buffer_names(self) -> OrderedSet[str]:
        return OrderedSet.union(*[sub.used_buffer_names() for sub in self.snodes])

    @cache_on_self
    def used_or_aliased_buffer_names(self) -> OrderedSet[str]:
        per_node = [sub.used_or_aliased_buffer_names() for sub in self.snodes]
        return OrderedSet.union(*per_node)

    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
        return self.snodes

    def __repr__(self) -> str:
        return f"{type(self).__name__}(nodes={self.get_name()})"

    @cache_on_self
    def is_reduction(self) -> bool:
        return any(sub.is_reduction() for sub in self.snodes)

    @cache_on_self
    def is_native_matmul(self) -> bool:
        return any(sub.is_native_matmul() for sub in self.snodes)

    @cache_on_self
    def is_split_scan(self) -> bool:
        return any(sub.is_split_scan() for sub in self.snodes)

    @cache_on_self
    def is_template(self) -> bool:
        return any(sub.is_template() for sub in self.snodes)

    @cache_on_self
    def get_template_node(self) -> Optional[ir.TemplateBuffer]:
        """Return the template buffer of the first template sub-node, or ``None``."""
        first_template = next(
            (sub for sub in self.snodes if sub.is_template()), None
        )
        if first_template is None:
            return None
        return first_template.get_template_node()

    def get_device(self) -> torch.device:
        return self.group[0]

    @cache_on_self
    def has_aliasing_or_mutation(self) -> bool:
        return any(sub.has_aliasing_or_mutation() for sub in self.snodes)

    # None of these need to be implemented, as a FusedSchedulerNode is just an
    # abstraction for scheduling purposes
    def update_mutated_names(self, renames: dict[str, str]) -> None:
        raise NotImplementedError

    def add_fake_dep(self, name: Dep) -> None:
        raise NotImplementedError

    def can_inplace(self, read_dep: dependencies.Dep) -> bool:
        raise NotImplementedError

    def debug_str(self) -> str:
        """Longer form printout for trace logs"""
        name = self.get_name()
        node_typestr = ",".join(type(n).__name__ for n in self.snodes)
        buf = IndentedBuffer()
        buf.splice(
            f"""\
{name}: {type(self).__name__}({node_typestr})
{name}.writes = {pformat(self.read_writes.writes)}
{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
{name}.outputs = [
"""
        )
        with buf.indent():
            for out in self.get_outputs():
                buf.splice(out.debug_str())
        buf.writeline("]")

        # The extra section is best effort: a failure there is logged, not raised.
        try:
            buf.splice(self.debug_str_extra())
        except Exception:
            log.warning("Ignoring error in debug_str()", exc_info=True)

        return buf.getrawvalue().rstrip()

    @cache_on_self
    def has_side_effects(self) -> bool:
        if self.snodes is None:
            return super().has_side_effects()
        return any(node.has_side_effects() for node in self.snodes)
class FusedMixOrderReductions(FusedSchedulerNode):
    """
    Fused node made of two sub-groups, ``node1`` and ``node2``.

    The constructor orders them so that ``node1`` is the one for which
    ``MixOrderReduction.is_contiguous_node`` holds.
    """

    def __init__(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> None:
        if not MixOrderReduction.is_contiguous_node(node1):
            # Exactly the other one must be contiguous then; swap roles.
            assert MixOrderReduction.is_contiguous_node(node2)
            node1, node2 = node2, node1

        self.node1 = node1
        self.node2 = node2
        all_nodes = list(node1.get_nodes()) + list(node2.get_nodes())
        super().__init__(node1.scheduler, all_nodes)
        self.numel = MixOrderReduction.get_numel(self.node1)

    def sub_node_can_fuse(
        self,
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
        other_nodes: tuple[BaseSchedulerNode, ...],
    ):
        """
        node1 is from the current mix order reduction; node2 is another node we want to fuse in.

        other_nodes are passed in to check if fusion will introduce producer/consumer relationship
        between the inner and outer reduction. If yes, we don't fuse.
        """
        assert not isinstance(node1, FusedMixOrderReductions)
        assert not isinstance(node2, FusedMixOrderReductions)

        # When we fuse extra nodes into a FusedMixOrderReductions node,
        # we should not allow recursive mix-order reduction being
        # created.
        if not self.scheduler.can_fuse(node1, node2, allow_mix_order_reduction=False):
            return False

        # Since node1 is from the current mix order reduction, if node1 is
        # contiguous, the fused node should also be contiguous.
        if MixOrderReduction.is_contiguous_node(
            node1
        ) and not MixOrderReduction.is_contiguous_node(node2):
            return False

        def _get_ancestors(nodes: tuple[BaseSchedulerNode, ...]) -> OrderedSet[str]:
            empty = OrderedSet()
            return empty.union(*(n.ancestors for n in nodes))

        def _get_operation_names(
            nodes: tuple[BaseSchedulerNode, ...],
        ) -> OrderedSet[str]:
            empty = OrderedSet()
            return empty.union(*(n.get_operation_names() for n in nodes))

        if other_nodes:
            candidate_pair = (node1, node2)
            # Reject if either side would become an ancestor of the other.
            if (
                _get_ancestors(candidate_pair) & _get_operation_names(other_nodes)
            ) or (_get_ancestors(other_nodes) & _get_operation_names(candidate_pair)):
                return False

        if not node2.is_reduction():
            return True
        shared_memory_score = typing.cast(
            int, self.scheduler.score_fusion_memory(node1, node2, count_bytes=False)
        )
        return shared_memory_score >= self.numel

    def can_fuse_with(self, other: BaseSchedulerNode):
        """Return whether ``other`` can be fused into either sub-group of this node."""
        if isinstance(other, FusedMixOrderReductions):
            # pass empty tuple for the second since the producer/consumer relationship has
            # already been checked in the first call
            return self.sub_node_can_fuse(
                self.node1, other.node1, (self.node2, other.node2)
            ) and self.sub_node_can_fuse(self.node2, other.node2, tuple())

        return self.sub_node_can_fuse(
            self.node1, other, (self.node2,)
        ) or self.sub_node_can_fuse(self.node2, other, (self.node1,))

    def fuse_with(self, other: BaseSchedulerNode):
        """
        Fuse ``other`` into this node and return a new FusedMixOrderReductions.

        Another FusedMixOrderReductions is fused pairwise (node1 with node1,
        node2 with node2); any other node goes into ``node1`` when that is
        allowed and into ``node2`` otherwise.
        """
        device = self.node1.get_device()
        backend = self.scheduler.get_backend(device)

        if isinstance(other, FusedMixOrderReductions):
            fused_node1 = backend.fuse(self.node1, other.node1)
            fused_node2 = backend.fuse(self.node2, other.node2)
            return FusedMixOrderReductions(fused_node1, fused_node2)

        if self.sub_node_can_fuse(self.node1, other, (self.node2,)):
            fused_node = backend.fuse(self.node1, other)
            return FusedMixOrderReductions(fused_node, self.node2)

        fused_node = backend.fuse(self.node2, other)
        return FusedMixOrderReductions(self.node1, fused_node)
- class ForeachKernelSchedulerNode(FusedSchedulerNode):
- """
- This is a scheduler node that consists of a set of scheduler nodes that
- have no data dependencies among them and can be executed in parallel.
- """
def get_consumer_subnode_for(
    self, producer: BaseSchedulerNode
) -> Optional[BaseSchedulerNode]:
    """
    Find the sub-node registered in ``read_to_node`` for one of ``producer``'s outputs.

    Outputs are checked in order; the entry for the first output whose name
    is a key of ``self.read_to_node`` is returned, or ``None`` if there is none.
    """
    output_names = (buf.get_name() for buf in producer.get_outputs())
    matched = next((name for name in output_names if name in self.read_to_node), None)
    if matched is None:
        return None
    return self.read_to_node[matched]
def get_producer_subnode_for(
    self, consumer: BaseSchedulerNode
) -> Optional[BaseSchedulerNode]:
    """
    Find the single sub-node of this node that ``consumer`` reads from.

    Returns ``None`` when the consumer reads from no sub-node or from more
    than one.
    """
    name_to_buf = self.scheduler.name_to_buf
    producers = OrderedSet[BaseSchedulerNode]()
    for rd in consumer.read_writes.reads:
        if rd.name in name_to_buf:
            node_name = name_to_buf[rd.name].defining_op_name()
            if node_name in self.name_to_node:
                producers.add(self.name_to_node[node_name])

    # Don't permit fusion if there are multiple subnodes
    # that this consumer reads from
    if len(producers) != 1:
        return None
    return next(iter(producers))
- @classmethod
- def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
- why = WhyNoFuse(producer, consumer)
- if producer.is_foreach() and consumer.is_foreach():
- producer = typing.cast(ForeachKernelSchedulerNode, producer)
- consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
- foreach_match = len(producer.snodes) == len(consumer.snodes)
- if not foreach_match:
- why("foreach do not have same length")
- return foreach_match and all(
- producer.scheduler.can_fuse(l, r)
- for l, r in zip(producer.snodes, consumer.snodes)
- )
- elif consumer.is_foreach():
- if producer.is_reduction():
- why(
- "candidate producer is a reduction, foreach ops cannot be fused with reductions currently"
- )
- return False
- consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
- consumer_subnode = consumer.get_consumer_subnode_for(producer)
- if consumer_subnode is not None:
- return consumer.scheduler.can_fuse(producer, consumer_subnode)
- why("candidate producer is not dep of any foreach consumer")
- return False
- elif producer.is_foreach():
- if consumer.is_reduction():
- why(
- "candidate consumer is a reduction, foreach ops cannot be fused with reductions currently"
- )
- return False
- producer = typing.cast(ForeachKernelSchedulerNode, producer)
- producer_subnode = producer.get_producer_subnode_for(consumer)
- if producer_subnode is not None:
- return producer.scheduler.can_fuse(producer_subnode, consumer)
- why("candidate consumer has no dep in any foreach producer")
- return False
- raise AssertionError(
- "At least one node passed to ForeachKernelSchedulerNode.can_fuse should be a foreach node"
- )
- @classmethod
- def fuse(
- cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode
- ) -> ForeachKernelSchedulerNode:
- assert producer.is_foreach() or consumer.is_foreach()
- if producer.is_foreach():
- producer = typing.cast(ForeachKernelSchedulerNode, producer)
- use_custom_partition_algo = producer.use_custom_partition_algo
- enable_autotune = producer.enable_autotune
- else:
- consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
- use_custom_partition_algo = consumer.use_custom_partition_algo
- enable_autotune = consumer.enable_autotune
- prev_node_1 = None
- prev_node_2 = None
- fused_nodes: list[BaseSchedulerNode]
- if producer.is_foreach() and consumer.is_foreach():
- producer = typing.cast(ForeachKernelSchedulerNode, producer)
- consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
- fused_nodes = [
- FusedSchedulerNode.fuse(l, r)
- for l, r in zip(producer.snodes, consumer.snodes)
- ]
- elif producer.is_foreach():
- producer = typing.cast(ForeachKernelSchedulerNode, producer)
- producer_subnode = producer.get_producer_subnode_for(consumer)
- fused_nodes = []
- prev_node_1 = producer
- prev_node_2 = None
- for node in producer.snodes:
- if node is producer_subnode:
- new_node = FusedSchedulerNode.fuse(node, consumer)
- prev_node_2 = new_node
- fused_nodes.append(new_node)
- else:
- fused_nodes.append(node)
- elif consumer.is_foreach():
- consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
- consumer_subnode = consumer.get_consumer_subnode_for(producer)
- fused_nodes = []
- prev_node_1 = consumer
- prev_node_2 = None
- for node in consumer.snodes:
- if node is consumer_subnode:
- new_node = FusedSchedulerNode.fuse(producer, node)
- prev_node_2 = new_node
- fused_nodes.append(new_node)
- else:
- fused_nodes.append(node)
- else:
- raise AssertionError(
- "At least one node passed to ForeachKernelSchedulerNode.fuse should be a foreach node"
- )
- return cls(
- producer.scheduler,
- fused_nodes,
- use_custom_partition_algo=use_custom_partition_algo,
- prev_node_1=prev_node_1,
- prev_node_2=prev_node_2,
- enable_autotune=enable_autotune,
- )
- def __init__(
- self,
- scheduler: Scheduler,
- snodes: list[BaseSchedulerNode],
- use_custom_partition_algo: bool,
- prev_node_1: Optional[BaseSchedulerNode] = None,
- prev_node_2: Optional[BaseSchedulerNode] = None,
- enable_autotune: bool = False,
- ) -> None:
- self.read_to_node = {}
- self.name_to_node = {}
- if prev_node_1 is None or prev_node_2 is None:
- super().__init__(scheduler, snodes)
- for node in snodes:
- for read in node.read_writes.reads:
- self.read_to_node[read.name] = node
- for name in node.get_operation_names():
- self.name_to_node[name] = node
- else:
- self.scheduler = scheduler
- self.snodes = snodes
- self.node = None
- self.users: list[NodeUser] = []
- self.set_read_writes(
- dependencies.ReadWrites.merge_list(
- [prev_node_1.read_writes, prev_node_2.read_writes]
- )
- )
- self.unmet_dependencies = (
- OrderedSet(
- dep
- for dep in OrderedSet.union(
- prev_node_1.unmet_dependencies, prev_node_2.unmet_dependencies
- )
- if dep.name not in self.get_buffer_names()
- )
- - self.read_writes.writes
- )
- self.min_order = min([prev_node_1.min_order, prev_node_2.min_order])
- self.max_order = max([prev_node_1.max_order, prev_node_2.max_order])
- if prev_node_1.is_foreach():
- assert isinstance(prev_node_1, ForeachKernelSchedulerNode)
- foreach_node, other_node = prev_node_1, prev_node_2
- else:
- assert isinstance(prev_node_2, ForeachKernelSchedulerNode)
- foreach_node, other_node = prev_node_2, prev_node_1
- self.ancestors = foreach_node.ancestors
- self.ancestors.update(other_node.ancestors)
- self.name_to_node = foreach_node.name_to_node
- for name in other_node.get_operation_names():
- self.name_to_node[name] = other_node
- self.outputs_by_name: dict[str, SchedulerBuffer] = {
- k: v for snode in self.snodes for k, v in snode.outputs_by_name.items()
- }
- self.use_custom_partition_algo = use_custom_partition_algo
- device = snodes[0].get_device()
- assert device
- self.group = (device, ((sympy.Expr("combo_kernel"),),))
- self.origins = OrderedSet[torch.fx.Node]()
- self.enable_autotune = enable_autotune
- @classmethod
- def combinable_nodes(
- cls, nodes: list[BaseSchedulerNode]
- ) -> list[BaseSchedulerNode]:
- extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
- if extern:
- log.debug(
- "ComboKernels: %d external nodes are filtered %s",
- len(extern),
- [node.node.get_origins() for node in extern if node.node is not None],
- )
- grouped = [x for x in nodes if isinstance(x, GroupedSchedulerNode)]
- if grouped:
- log.debug(
- "ComboKernels: %d grouped nodes are filtered",
- len(grouped),
- )
- mix_order = [x for x in nodes if isinstance(x, FusedMixOrderReductions)]
- if mix_order:
- log.debug(
- "ComboKernels: %d FusedMixOrderReductions nodes are filtered",
- len(mix_order),
- )
- filtered_nodes = [
- x
- for x in nodes
- if not isinstance(
- x,
- (
- NopKernelSchedulerNode,
- ExternKernelSchedulerNode,
- GroupedSchedulerNode,
- FusedMixOrderReductions,
- ),
- )
- ]
- foreach_nodes = [
- x for x in filtered_nodes if isinstance(x, ForeachKernelSchedulerNode)
- ]
- if foreach_nodes:
- log.debug("ComboKernels: %d foreach nodes are filtered", len(foreach_nodes))
- filtered_nodes = [
- x for x in filtered_nodes if not isinstance(x, ForeachKernelSchedulerNode)
- ]
- template_nodes = [x for x in filtered_nodes if x.is_template()]
- if template_nodes:
- log.debug(
- "ComboKernels: %d template nodes are filtered: %s",
- len(template_nodes),
- template_nodes,
- )
- filtered_nodes = [x for x in filtered_nodes if x not in template_nodes]
- # Filter out reduction nodes if combo_kernels_pointwise_only is enabled
- if config.combo_kernels_pointwise_only:
- reduction_nodes = [x for x in filtered_nodes if x.is_reduction()]
- if reduction_nodes:
- log.debug(
- "ComboKernels: %d reduction nodes are filtered (pointwise_only mode)",
- len(reduction_nodes),
- )
- filtered_nodes = [x for x in filtered_nodes if not x.is_reduction()]
- return filtered_nodes
- @staticmethod
- def _default_group_nodes_for_combo_kernels(
- scheduler: Scheduler,
- ) -> list[list[BaseSchedulerNode]]:
- """
- Returns a list of lists of nodes that are to be grouped together.
- """
- sorted_nodes = scheduler._topological_sort_nodes()
- grouped_nodes = []
- max_num_nodes = 8
- excluded_buffer_names: OrderedSet[str] = OrderedSet(
- [
- buf_name
- for group in sorted_nodes
- for node in group
- if isinstance(node, FusedMixOrderReductions)
- for buf_name in node.get_buffer_names()
- ]
- )
- for nodes in sorted_nodes:
- # Group nodes by device first to avoid mixed-device fusion
- device_groups: dict[Optional[torch.device], list[BaseSchedulerNode]] = (
- defaultdict(list)
- )
- for node in nodes:
- device = node.get_device()
- if device and (device.type == "mps" or device.type == "cpu"):
- continue
- # exclude nodes that read from FusedMixOrderReductions output buffers'
- if node.used_buffer_names() & excluded_buffer_names:
- continue
- device_groups[device].append(node)
- # Chunk each device group separately
- for device_nodes in device_groups.values():
- grouped_nodes.extend(
- [
- device_nodes[i : i + max_num_nodes]
- for i in range(0, len(device_nodes), max_num_nodes)
- ]
- )
- return grouped_nodes
- group_algorithm_for_combo_kernels: Callable[
- [Scheduler], list[list[BaseSchedulerNode]]
- ] = _default_group_nodes_for_combo_kernels
- @staticmethod
- def set_group_algorithm_for_combo_kernels(
- custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
- ) -> None:
- ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
- custom_group_algorithm
- )
- @staticmethod
- def group_nodes_for_combo_kernels(
- scheduler: Scheduler,
- ) -> list[list[BaseSchedulerNode]]:
- return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler)
- def mark_run(self) -> None:
- raise NotImplementedError
- def codegen(self) -> None:
- raise NotImplementedError
- def is_foreach(self) -> bool:
- return True
- def get_subkernel_nodes(self) -> list[BaseSchedulerNode]:
- """Returns a list of nodes which comprise the combo kernel.
- These nodes may be vertically fused."""
- return list(self.snodes)
- def get_nodes(self) -> Sequence[BaseSchedulerNode]:
- """Returns all nodes contained in this kernel, unpacking fused nodes
- into their constituent scheduler nodes."""
- return list(itertools.chain.from_iterable(x.get_nodes() for x in self.snodes))
- def get_first_name(self) -> str:
- return self.snodes[0].get_first_name()
- def prune_redundant_deps(
- self, name_to_fused_node: dict[str, BaseSchedulerNode]
- ) -> None:
- _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
- for node in self.snodes:
- node.prune_redundant_deps(name_to_fused_node)
class GroupedSchedulerNode(BaseSchedulerNode):
    """
    A "fake" scheduler node standing in for a group of scheduler nodes that
    must stay *grouped* together: no other node may be scheduled in between
    its constituent nodes, and no other node may fuse into any of them.
    It achieves this by maintaining its unmet dependencies as the union of
    those of its constituent nodes.
    Fusion will still happen among the nodes within each GroupedSchedulerNode.
    At codegen time, this scheduler node will be unpacked and codegen is
    called on each constituent node.
    """

    snodes: list[BaseSchedulerNode]

    @classmethod
    def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode:
        """
        Build a group from ``snodes`` (which must all share one scheduler) and
        point the scheduler's ``name_to_fused_node`` entries for every member,
        and for the group's own name, at the new group.
        """
        scheduler = snodes[0].scheduler
        assert all(node.scheduler is scheduler for node in snodes)
        group = cls(scheduler, snodes)
        fused_map = scheduler.name_to_fused_node
        for member in snodes:
            fused_map[member.get_name()] = group
        fused_map[group.get_name()] = group
        return group

    def __init__(
        self,
        scheduler: Scheduler,
        snodes: list[BaseSchedulerNode],
        temp_grouping: bool = False,
    ) -> None:
        super().__init__(scheduler)
        init_group_node(self, scheduler, snodes)
        # "Temporary" grouping is used by passes that group nodes only to
        # move them together and flatten them again afterwards.  It reuses
        # the grouped unmet_dependencies computation etc., but no fusion
        # logic runs when the group is unpacked.
        self.temp_grouping = temp_grouping

    def unpack(self) -> list[BaseSchedulerNode]:
        """
        Do fusion among nodes within this GroupedSchedulerNode,
        and then unpack this GroupedSchedulerNode into regular nodes.

        A temporary grouping is returned as-is, without fusion and without
        touching ``name_to_fused_node``.
        """
        if self.temp_grouping:
            return self.snodes

        fused_map = self.scheduler.name_to_fused_node
        for member in self.snodes:
            fused_map[member.get_name()] = member
        del fused_map[self.get_name()]
        return self.scheduler.fuse_nodes(self.snodes)

    def add_fake_dep(self, fake_dep: Dep) -> None:
        """Record ``fake_dep`` both as a read and as an unmet dependency."""
        self.set_read_writes(self.read_writes.with_read(fake_dep))
        self.unmet_dependencies.add(fake_dep)

    @cache_on_self
    def get_name(self) -> str:
        """Names of all members joined with underscores."""
        return "_".join(member.get_name() for member in self.snodes)

    def get_first_name(self) -> str:
        return self.snodes[0].get_name()

    @cache_on_self
    def get_buffer_names(self) -> OrderedSet[str]:
        """Union of the buffer names of all members."""
        per_member = [member.get_buffer_names() for member in self.snodes]
        return OrderedSet.union(*per_member)

    def get_outputs(self) -> list[SchedulerBuffer]:
        """Outputs of all members, concatenated in member order."""
        return [buf for member in self.snodes for buf in member.get_outputs()]

    @cache_on_self
    def estimate_flops(self) -> int | None:
        """
        Sum of the flop estimates of the template/extern members, or None
        when none of them reports a truthy estimate.
        """
        # don't increment counters in fused methods so we don't double count
        estimates = [
            flops
            for flops in (
                node.estimate_flops()
                for node in self.get_nodes()
                if node.is_template() or node.is_extern()
            )
            if flops
        ]
        if not estimates:
            return None
        return sum(estimates)

    def get_nodes(self) -> Sequence[BaseSchedulerNode]:
        return self.snodes

    def get_device(self) -> Optional[torch.device]:
        """Device of the first member, or None for an empty group."""
        if not self.snodes:
            return None
        return self.snodes[0].get_device()

    @classmethod
    def can_fuse(cls, producer: BaseSchedulerNode, consumer: BaseSchedulerNode) -> bool:
        # GroupedSchedulerNode cannot be fused with another node
        return False
def pick_loop_order(
    stride_lengths: list[list[int]],
    sizes: Sequence[sympy.Expr],
    priority_idx: Sequence[int] = (),
) -> list[int]:
    """
    Heuristically choose a loop iteration order.

    Returns the dimension indices in reversed order, sorted by stride-based
    comparison only when ``config.pick_loop_orders`` is enabled.  If
    ``priority_idx`` is non-empty, only those rows of ``stride_lengths`` take
    part in the comparison.  The heuristic has not been well tuned and may be
    something we should autotune.
    """
    # The dimension count always comes from the first, unfiltered row.
    order = list(range(len(stride_lengths[0]) - 1, -1, -1))

    active_strides = stride_lengths
    if len(priority_idx) > 0:
        # if we have priority node, only use that node's order
        active_strides = [stride_lengths[pi] for pi in priority_idx]

    def compare_dims(a: int, b: int) -> int:
        a_is_one = sizes[a] == 1
        b_is_one = sizes[b] == 1
        if a_is_one or b_is_one:
            # 1-sizes don't matter, just move them to the end
            return cmp(a_is_one, b_is_one)

        # Count, per row, which dimension should go first.  abs() is taken
        # because flipped dimensions would otherwise look like smaller
        # strides than contiguous dims.  Equivalent to
        # np.logical_or(stride_lengths[:, b] == 0, stride_lengths[:, a] < stride_lengths[:, b]).all()
        a_votes = 0
        b_votes = 0
        for row in active_strides:
            stride_a = abs(row[a])
            stride_b = abs(row[b])
            if stride_b == 0 or stride_a < stride_b:
                a_votes += 1
            if stride_a == 0 or stride_b < stride_a:
                b_votes += 1

        if a_votes > b_votes:
            return -1
        if b_votes > a_votes:
            return 1

        # otherwise contiguous
        return cmp(b, a)

    if config.pick_loop_orders:
        order.sort(key=functools.cmp_to_key(compare_dims))
    return order
def _replace_operation_buffer(
    orig_node: ir.MultiTemplateBuffer, new_node: ir.OperationBuffer
) -> None:
    """
    Make ``new_node`` take over ``orig_node``'s place in ``V.graph``.

    ``new_node`` is renamed to the buffer and operation names of
    ``orig_node``, its entries under its old names are removed, and it is
    moved into the slots ``orig_node`` occupied in ``V.graph.buffers`` and
    ``V.graph.operations``.
    """
    new_buf_name = new_node.get_name()
    old_buf_name = orig_node.get_name()
    assert isinstance(old_buf_name, str) and isinstance(new_buf_name, str)

    new_op_name = new_node.get_operation_name()
    old_op_name = orig_node.get_operation_name()
    assert isinstance(old_op_name, str) and isinstance(new_op_name, str)

    graph = V.graph

    # Drop the registrations under the new node's own names, then rename it.
    del graph.name_to_buffer[new_buf_name]
    new_node.name = old_buf_name

    del graph.name_to_op[new_op_name]
    new_node.operation_name = old_op_name

    # The slot index is looked up before new_node is removed from the list.
    buffer_slot = graph.buffers.index(orig_node)
    graph.buffers.remove(new_node)
    graph.buffers[buffer_slot] = new_node
    graph.name_to_buffer[old_buf_name] = new_node

    operation_slot = graph.operations.index(orig_node)
    graph.operations.remove(new_node)
    graph.operations[operation_slot] = new_node
    graph.name_to_op[old_op_name] = new_node
def _estimate_fused_epilogue_runtime(node1, node2, epilogue_runtime) -> float:
    """
    Estimate how much of ``epilogue_runtime`` remains once the epilogue is
    fused.

    The estimate scales ``epilogue_runtime`` by a factor in [0, 1) derived
    from the "extra" bytes: what ``node2`` reads beyond what ``node1`` writes.

    Args:
        node1: node providing ``get_write_buffer_sizes()`` (the template side).
        node2: node providing ``get_read_buffer_sizes()`` (the epilogue side).
        epilogue_runtime: runtime of the epilogue to be scaled.

    Returns:
        ``0.0`` when the epilogue reads no extra memory, otherwise
        ``epilogue_runtime * r / (1 + r)`` with ``r = extra / written``.
    """
    total_read_bytes = node2.get_read_buffer_sizes()
    template_write_bytes = node1.get_write_buffer_sizes()
    extra_bytes = total_read_bytes - template_write_bytes

    # If no extra memory read by epilogue, assume epilogue is free.  This also
    # keeps the estimate from going negative (or dividing by zero) when the
    # epilogue reads less than the template writes.
    if extra_bytes <= 0:
        return 0.0

    # Nothing written by the template: the ratio below would divide by zero.
    # Its limit as the written size goes to 0 is 1, i.e. the full runtime.
    if template_write_bytes <= 0:
        return 1.0 * epilogue_runtime

    # if extra memory is read by epilogue, add to minimum choice
    extra_bytes_ratio = extra_bytes / template_write_bytes
    # Smoothly approaches 1 as extra_bytes_ratio increases
    extra_memory_ratio = extra_bytes_ratio / (1 + extra_bytes_ratio)
    return extra_memory_ratio * epilogue_runtime
@dataclasses.dataclass
class NodeUser:
    """
    A user of a scheduler node's result.

    Equality and hashing are based on the *name* of ``node`` together with
    the two flags, not on the identity of ``node``.
    """

    node: Union[BaseSchedulerNode, OutputNode]
    can_inplace: bool = False

    # A weak user must be scheduled after a given node, but doesn't actually
    # use the result
    is_weak: bool = False

    def _identity(self) -> tuple:
        # The tuple that both __hash__ and __eq__ are defined over.
        return (self.node.get_name(), self.can_inplace, self.is_weak)

    def __hash__(self) -> int:
        return hash(self._identity())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, NodeUser):
            return False
        return self._identity() == other._identity()

    def get_name(self) -> str:
        """Name of the underlying node."""
        return self.node.get_name()

    def merge(self, other: NodeUser) -> NodeUser:
        """
        Combine two users of the same node object; a flag survives only if
        it is set on both.
        """
        assert self.node is other.node
        return NodeUser(
            node=self.node,
            can_inplace=self.can_inplace and other.can_inplace,
            is_weak=self.is_weak and other.is_weak,
        )
# Module-level counter; Scheduler._init takes the next value from it and
# stores it as the scheduler's post_grad_graph_id.
_post_grad_graph_counter = itertools.count()
def used_non_deterministic_runtime_estimations() -> bool:
    """Return the value of ``config.runtime_estimations_mms_benchmark``."""
    return config.runtime_estimations_mms_benchmark
def get_layout_symints(node: ir.IRNode) -> OrderedSet[sympy.Symbol]:
    """
    Collect the free symbols of a node's layout (size, stride, offset).

    A node without a layout yields an empty set.  For a
    ``MutationLayoutSHOULDREMOVE`` layout the symbols of ``layout.target``
    are included as well.
    """
    symbols: OrderedSet[sympy.Symbol] = OrderedSet()
    layout = node.maybe_get_layout()

    if not isinstance(layout, ir.Layout):
        assert layout is None, f"Expect layout to be None but found layout={layout}"
        return symbols

    layout_symbols = (
        free_symbols(layout.size)
        | free_symbols(layout.stride)
        | free_symbols(layout.offset)
    )
    symbols.update(layout_symbols)

    if isinstance(layout, ir.MutationLayoutSHOULDREMOVE):
        # symint may be used as index in layout.target
        symbols.update(get_layout_symints(layout.target))

    return symbols
def get_scheduler_node_symbol_uses(
    node: BaseSchedulerNode,
) -> OrderedSet[sympy.Symbol]:
    """
    Collect the symbols used by a scheduler node: the free symbols of its
    operation plus the layout symints of each of its outputs.  A fused node
    contributes the union over its subnodes.
    """
    if isinstance(node, FusedSchedulerNode):
        per_subnode = [get_scheduler_node_symbol_uses(sub) for sub in node.snodes]
        return OrderedSet().union(*per_subnode)

    operation = node.node
    assert operation is not None
    symbol_uses = operation.get_free_symbol_uses()
    for output in operation.get_outputs():
        symbol_uses.update(get_layout_symints(output))
    return symbol_uses
def is_epilogue_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
    """
    Truthy when ``node1`` is a template, ``config.epilogue_fusion`` is set
    and ``node2`` is not a template.
    """
    return node1.is_template() and config.epilogue_fusion and not node2.is_template()
def is_prologue_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
    """
    Truthy when ``node2`` is a template, ``config.prologue_fusion`` is set
    and ``node1`` is not a template.
    """
    return node2.is_template() and config.prologue_fusion and not node1.is_template()
def is_template_fusion(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
    """Truthy when the pair is either an epilogue fusion or a prologue fusion."""
    return is_epilogue_fusion(node1, node2) or is_prologue_fusion(node1, node2)
def template_fusion_pw_node(node1: BaseSchedulerNode, node2: BaseSchedulerNode):
    """
    Pick the non-template side of a template fusion: ``node2`` when the pair
    is an epilogue fusion, ``node1`` in every other case.
    """
    if is_epilogue_fusion(node1, node2):
        return node2
    return node1
- class Scheduler:
- """
- A Scheduler is a graph of BaseSchedulerNodes. It is responsible for
- optimizations such as fusion, reorder, and graph partition.
- """
def __init__(self, nodes: list[ir.Operation]) -> None:
    """Run ``_init(nodes)`` inside a ``dynamo_timed("Scheduler.__init__")`` block."""
    with dynamo_timed("Scheduler.__init__"):
        self._init(nodes)
def _init(self, nodes: list[ir.Operation]) -> None:
    """
    Build the scheduler state from ``nodes`` and run the scheduling passes.

    In order: register this scheduler on ``V.graph``, create one scheduler
    node per IR operation, build the name lookup tables, decide comm
    ordering, compute dependencies, sort topologically, eliminate dead
    nodes, create foreach nodes, fuse, merge loops, finalize multi-template
    buffers, and then run the config-gated passes (combo kernels, peak
    memory reordering, compute/comm overlap, partition reordering) before
    computing last usage.  The exact statement order matters: later passes
    read state produced by earlier ones.
    """
    super().__init__()
    V.graph.scheduler = self
    self.backends: dict[torch.device, BaseScheduling] = {}
    self.post_grad_graph_id = next(_post_grad_graph_counter)
    self._graph_partition_counter = itertools.count()

    self.completed_operations: OrderedSet[str] = OrderedSet()
    # Starts out as graph inputs plus (torchbind) constants.
    self.available_buffer_names = OrderedSet(
        [
            *V.graph.graph_inputs.keys(),
            *V.graph.constants.keys(),
            *V.graph.torchbind_constants.keys(),
        ]
    )

    self.nodes = [self.create_scheduler_node(n) for n in nodes]
    self.previous_node: Optional[BaseSchedulerNode] = None
    self.current_node: Optional[BaseSchedulerNode] = None
    self.update_zero_dim_cpu_tensor()
    # some new constants could have been created above
    self.available_buffer_names.update(V.graph.constants.keys())
    for node in self.nodes:
        node.prune_deps()

    # See [Note: Graph Partition Device Contexts]
    self.default_device_context: Optional[torch.device] = None

    self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
        self.get_donated_buffers()
    )
    self.name_to_node: dict[str, BaseSchedulerNode] = {
        n.get_name(): n for n in self.nodes
    }
    self.name_to_buf: dict[str, SchedulerBuffer] = {
        buf.get_name(): buf for node in self.nodes for buf in node.get_outputs()
    }
    self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy()

    # mutation_real_name: Maps back to the original name for codegen
    # Example:
    # If you mutate buf0 inside of buf1's kernel, then:
    # mutation_real_name = {"buf0" : "buf1"}
    # all subsequent uses of buf0 become buf1's usage in dependency graph
    self.mutation_real_name: dict[str, str] = {}

    # We handle mutation by renaming modified versions of the same
    # buffer in the dependency graph to prevent cycles.
    # mutation_renames: tracks the current name for a given buffer
    # (changed once per mutation)
    # Example:
    # If you mutate buf0 inside of buf1's kernel, then:
    # mutation_renames = {"buf1" : "buf0"}
    # in codegen we only use buf0, never buf1
    self.mutation_renames: dict[str, str] = {}

    self.seen_template_fusions: OrderedSet[
        tuple[BaseSchedulerNode, BaseSchedulerNode]
    ] = OrderedSet()

    # Must run first to correctly set dependencies, before all other passes that rely on
    # reading from .read_writes.reads or .unmet_dependencies
    self.nodes = comms.decide_global_ordering_of_comms(
        self.nodes,
        self.name_to_buf,
        self.name_to_fused_node,
    )

    self.compute_dependencies()
    self.nodes = self.topological_sort_schedule(self.nodes)
    self.dead_node_elimination()
    # Rebuilt from self.nodes after dead node elimination.
    self.name_to_fused_node = {n.get_name(): n for n in self.nodes}
    self.compute_ancestors()

    # pyrefly: ignore [bad-assignment]
    metrics.ir_nodes_pre_fusion += len(self.nodes)
    from torch._inductor.debug import log_ir_post_fusion, log_ir_pre_fusion

    log_ir_pre_fusion(self.nodes)
    self.num_orig_nodes = len(self.nodes)
    self.create_foreach_nodes()
    self.nodes = self.topological_sort_schedule(self.nodes)
    self.logged_slow_fusion = OrderedSet[tuple[str, str]]()
    if config._pre_fusion_custom_pass is not None:
        self.nodes = config._pre_fusion_custom_pass(self.nodes)

    if config.distributed_max_autotune_gemm:
        from . import distributed_autotune

        distributed_autotune.schedule(self)
        self.compute_ancestors()

    self.nodes = self.fuse_nodes(self.nodes)
    if config._post_fusion_custom_pass is not None:
        self.nodes = config._post_fusion_custom_pass(self.nodes)

    self.merge_loops()
    self.finalize_multi_template_buffers()
    if (
        config.max_autotune_gemm or config.max_autotune
    ) and use_pipelined_autotuning():
        torch._inductor.select_algorithm.PrecompileThreadPool.shutdown_instance()

    if config.combo_kernels:
        with dynamo_timed(
            "Scheduler.create_combo_kernel_nodes",
            log_pt2_compile_event=True,
            log_waitcounter=True,
        ):
            self.create_combo_kernel_nodes(num_ck_nodes=None)

    # Peak memory pass and overlap pass must run last, otherwise
    # other reordering passes could undo their effects.
    if config.reorder_for_peak_memory:
        from .memory import reorder_for_peak_memory

        self.nodes = reorder_for_peak_memory(
            self.nodes,
            self.name_to_buf,
            self.name_to_fused_node,
            OrderedSet(V.graph.graph_inputs.keys()),
            OrderedSet(V.graph.get_output_names()),
        )

    # reorder_for_compute_comm_overlap may do benchmarking to estimate
    # op runtime. Disable it for now in deterministic mode.
    if not config.deterministic and config.reorder_for_compute_comm_overlap:
        if not config.reorder_for_peak_memory:
            from .memory import assign_memory_planning_info_for_scheduler_buffers

            assign_memory_planning_info_for_scheduler_buffers(
                self.nodes, self.name_to_buf
            )

        if (
            used_non_deterministic_runtime_estimations()
            and config_comms.runtime_estimations_align_across_all_distributed_ranks
            and (
                config.runtime_estimations_mms_benchmark
                or config_comms.runtime_estimations_use_nccl_lib_estimations
            )
        ):
            # Estimations are only aligned across ranks when the graph
            # contains at least one collective.
            has_collectives = False
            for node in self.nodes:
                if is_collective(node.node):
                    has_collectives = True
                    break
            if has_collectives:
                from .comms import (
                    align_runtime_estimations_across_all_distributed_ranks,
                )

                align_runtime_estimations_across_all_distributed_ranks(self.nodes)

        # pyrefly: ignore [unbound-name]
        if config_comms.reorder_sink_verbose_logging:
            from torch._logging import trace_structured

            trace_structured(
                "artifact",
                metadata_fn=lambda: {
                    "name": "scheduler_nodes_before_comm_overlap",
                    "encoding": "string",
                },
                payload_fn=lambda: "\n\n".join(
                    [
                        f"snode[{i}]"
                        + n.debug_str()
                        + f" buffer_names:{n.get_buffer_names()}"
                        for i, n in enumerate(self.nodes)
                    ]
                ),
            )
        self.nodes = comms.reorder_compute_and_comm_for_overlap(self.nodes)
    self.process_grouped_nodes()

    if (
        # pyrefly: ignore[unbound-name]
        config.graph_partition
        # pyrefly: ignore[unbound-name]
        and config.triton.cudagraphs
        # pyrefly: ignore[unbound-name]
        and config.triton.reorder_for_reducing_graph_partitions
    ):
        self.nodes = self.maybe_reorder_for_minimizing_partition(self.nodes)
        self.nodes = self.reorder_for_partition_with_simple_dependency(self.nodes)

    self.compute_last_usage()

    if torch._inductor.config.test_configs.track_memory_lifecycle:
        self.insert_memory_check_nodes()

    log_ir_post_fusion(self.nodes)
    # pyrefly: ignore[unbound-name]
    V.debug.graph_diagram(self.nodes)
    self.debug_draw_graph()

    # used during codegen:
    self.buffer_names_to_free: OrderedSet[str] = OrderedSet()

    # fx graph node to the position it appears in the graph
    # for debug attribution
    self.origin_to_index: dict[torch.fx.Node, int] = {}

    get_metric_table("graph_stats").add_row(
        lambda: {
            "graph_id": self.post_grad_graph_id,
            "num_nodes_before_fusion": self.num_orig_nodes,
            "num_nodes_after_fusion": len(self.nodes),
        }
    )

    # Unlike V.graph.removed_buffers, the op recorded here is removed but
    # we still need the buffer (generated in alternative ways)
    self.removed_ops: OrderedSet[str] = OrderedSet()
def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]:
    """
    Wrap every ``ir.DonatedBuffer`` found in ``V.graph.graph_inputs_original``
    in a ``SchedulerDonatedBuffer`` (with no defining op), keyed by input name.
    """
    return {
        name: SchedulerDonatedBuffer(
            self,
            graph_input,
            defining_op=None,
        )
        for name, graph_input in V.graph.graph_inputs_original.items()
        if isinstance(graph_input, ir.DonatedBuffer)
    }
@property
def current_device(self) -> Optional[torch.device]:
    """The device currently recorded on ``V.graph.current_device``."""
    return V.graph.current_device

@current_device.setter
def current_device(self, device: Optional[torch.device]) -> None:
    # The value is stored on V.graph, not on the scheduler instance.
    V.graph.current_device = device
def debug_draw_graph(self) -> None:
    """
    Generate an image of the graph for debugging.

    Does nothing unless the environment variable
    ``INDUCTOR_WRITE_SCHEDULER_GRAPH`` is set to ``"1"``.
    """
    if os.environ.get("INDUCTOR_WRITE_SCHEDULER_GRAPH") != "1":
        return

    from .debug import draw_buffers

    draw_buffers(self.nodes, print_graph=True)
def debug_print_nodes(self, label: str) -> None:
    """
    When INFO logging is enabled, log ``label`` followed by the details of
    every node; otherwise do nothing.
    """
    if not log.isEnabledFor(logging.INFO):
        return

    log.info("%s:", label)
    for scheduler_node in self.nodes:
        scheduler_node.log_details()
def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
    """
    Wrap an IR operation in the matching scheduler node type.

    No-ops become ``NopKernelSchedulerNode``; computed and template buffers
    become ``SchedulerNode``; extern kernels become
    ``ExternKernelSchedulerNode``.

    Raises:
        NotImplementedError: for any other kind of operation.
    """
    assert node.get_origins() is not None, (
        "All nodes passed to scheduling must have an origin"
    )

    # The no-op check takes precedence over the type-based dispatch below.
    if node.is_no_op():
        return NopKernelSchedulerNode(self, node)

    if isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
        return SchedulerNode(self, node)

    if isinstance(node, ir.ExternKernel):
        return ExternKernelSchedulerNode(self, node)

    raise NotImplementedError(node)
def create_foreach_nodes(self) -> None:
    """
    Replace the nodes named in each ``V.graph.lists`` entry with a single
    ``ForeachKernelSchedulerNode``.

    Names that are no longer in ``name_to_fused_node`` or that refer to a
    nop node are ignored.  The new foreach nodes are appended after the
    remaining nodes, and ``name_to_fused_node`` is updated so each absorbed
    name maps to its foreach node.
    """
    absorbed_names: OrderedSet[str] = OrderedSet()
    foreach_nodes = []
    live_names = self.name_to_fused_node.keys()

    for listed_names in V.graph.lists.values():
        member_names = [
            name
            for name in listed_names
            if name in live_names
            and not isinstance(self.name_to_node[name], NopKernelSchedulerNode)
        ]
        if not member_names:
            # All nodes eliminated
            continue

        absorbed_names.update(member_names)
        foreach_node = ForeachKernelSchedulerNode(
            self,
            [self.name_to_node[name] for name in member_names],
            use_custom_partition_algo=False,
            enable_autotune=config.combo_kernels_autotune > 1,
        )
        foreach_nodes.append(foreach_node)

        for name in member_names:
            self.name_to_fused_node[name] = foreach_node

    remaining = [
        node for node in self.nodes if node.get_name() not in absorbed_names
    ]
    self.nodes = remaining + foreach_nodes
- def compute_dependencies(self) -> None:
- """
- Create dependency edges between nodes, handling aliasing and
- mutation properly.
- """
- class DedupList(Generic[_T]):
- """
- This data structure behaves like a list except it makes sure the
- elements remain unique.
- Normally one could use a OrderedSet/dict for this purpose however
- the list in question gets elements appended as it is being
- iterated over which means that we need to keep the list
- semantics.
- """
- def __init__(
- self,
- items: Optional[list[_T]] = None,
- membership: Optional[OrderedSet[_T]] = None,
- ) -> None:
- self.items = items or []
- self.membership = membership or OrderedSet()
- def append(self, node_user: _T) -> None:
- if node_user in self.membership:
- return
- self.items.append(node_user)
- self.membership.add(node_user)
- def __add__(self, other: DedupList[_T]) -> DedupList[_T]:
- new_membership = OrderedSet.union(self.membership, other.membership)
- new_items = self.items + [
- x for x in other.items if x not in self.membership
- ]
- return DedupList(new_items, new_membership)
- # pyrefly: ignore [not-a-type]
- name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict(
- DedupList
- )
- # handle aliasing by using python aliasing in name_to_users
- # if foo aliases bar then we will make name_to_users["foo"] point
- # to the same python list as name_to_users["bar"]
- for node in self.nodes:
- for buf1 in node.get_outputs():
- buf1_name = buf1.get_name()
- # This is for handling auto functionized ops which return None
- # and mutate more than 1 inputs, we shouldn't let them all
- # point to the same user list since buffers in the aliases
- # list might not be alias to each other.
- if (
- isinstance(buf1.node.layout, ir.NoneLayout)
- and len(buf1.get_aliases()) > 1
- ):
- continue
- for buf2_name in buf1.get_aliases():
- if buf1_name in name_to_users and buf2_name in name_to_users:
- # merge the two
- list1 = name_to_users[buf1_name]
- list2 = name_to_users[buf2_name]
- combined = list1 + list2
- for key in name_to_users:
- if (
- name_to_users[key] is list1
- or name_to_users[key] is list2
- ):
- name_to_users[key] = combined
- elif buf1_name in name_to_users:
- name_to_users[buf2_name] = name_to_users[buf1_name]
- else:
- name_to_users[buf1_name] = name_to_users[buf2_name]
- # pyrefly: ignore [not-a-type]
- def rename(n: str) -> str:
- if n in self.mutation_renames:
- return rename(self.mutation_renames[n])
- return n
- def add_user(
- # pyrefly: ignore [not-a-type]
- used_by_name: str,
- user_node: Union[BaseSchedulerNode, OutputNode],
- can_inplace: bool = False,
- is_weak: bool = False,
- ) -> None:
- name_to_users[rename(used_by_name)].append(
- NodeUser(user_node, can_inplace, is_weak)
- )
- # pyrefly: ignore [not-a-type]
- unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {}
- # NB: None means that the dependency is on an input. Don't actually
- # generate a dependency because if we do, Inductor will start trying
- # to free the unbacked int but that's pointless
- for val in V.graph.graph_inputs.values():
- if isinstance(val, sympy.Expr):
- for fs in val.free_symbols:
- unbacked_symbol_to_origin_node[fs] = None
- elif isinstance(val, ir.TensorBox):
- # We also need to add symbols from input size as well because
- # AOTI doesn't lift the unbacked symints to inputs
- sym_size = [s for s in val.get_size() if isinstance(s, sympy.Expr)]
- for s in sym_size:
- for fs in s.free_symbols:
- unbacked_symbol_to_origin_node[fs] = None
- has_non_input_unbacked_defs = False
- for node in self.nodes:
- assert node.node is not None
- # unbacked symbols don't follow ordinary buffer dependencies, so
- # we track their def/uses separately
- unbacked_symbol_defs = sorted(
- node.node.get_unbacked_symbol_defs(), key=lambda x: x.name
- )
- for s in unbacked_symbol_defs:
- assert isinstance(s, sympy.Symbol)
- # Pick the first definer as canonical. There may be multiple
- # because if a MultiOutputLayout buffer propagates an unbacked
- # symint to multiple outputs, they will all claim to def it.
- has_non_input_unbacked_defs = True
- if s not in unbacked_symbol_to_origin_node:
- unbacked_symbol_to_origin_node[s] = node.get_name()
- for node in self.nodes:
- log.debug("scheduling %s", node.node)
- if has_non_input_unbacked_defs:
- assert node.node is not None
- unbacked_symbol_uses = sorted(
- node.node.get_free_symbol_uses(unbacked_only=True),
- key=lambda x: x.name,
- )
- # if a kernel takes unbacked symints, register dependencies
- for s in unbacked_symbol_uses:
- assert s in unbacked_symbol_to_origin_node, (
- f"{s} not in {unbacked_symbol_to_origin_node}"
- )
- if (r := unbacked_symbol_to_origin_node[s]) is not None:
- for buf in self.name_to_node[r].get_outputs():
- node.add_fake_dep(StarDep(buf.get_name()))
- if (
- len(node.read_writes.writes) == 1
- and (dep := next(iter(node.read_writes.writes)))
- and isinstance(dep, MemoryDep)
- ):
- node_mode = dep.mode
- else:
- node_mode = None
- # Handle output mutations
- for buf in node.get_outputs():
- # a node will mutate either 0 or 1 buffers
- assert len(buf.get_mutations()) <= 1
- for alt_name in buf.get_mutations():
- alt_name = rename(alt_name)
- # this node must run after the prior writer
- add_user(alt_name, node)
- node.add_fake_dep(StarDep(alt_name, mode=node_mode))
- for user in name_to_users[alt_name].items:
- if user.get_name() == node.get_name():
- continue
- assert isinstance(user.node, BaseSchedulerNode)
- for out_buf in user.node.get_outputs():
- other_name = out_buf.get_name()
- # this node must run after all prior readers
- other_name = rename(other_name)
- # Check if the prior reader is a true alias (view) vs a clone.
- # Views share underlying storage with the mutated buffer, so we
- # need a real dependency (is_fake=False) to keep the view's
- # buffer alive until after this mutation completes. Clones have
- # independent storage, so we only need an ordering dependency
- # (is_fake=True) that won't extend their buffer lifetime.
- is_alias = alt_name in out_buf.get_aliases()
- node.add_fake_dep(
- WeakDep(
- other_name,
- mutating_buf=buf.get_name(),
- is_fake=not is_alias,
- )
- )
- add_user(other_name, node, is_weak=True)
- for add_dep in V.graph.additional_buffer_deps[node.get_name()]:
- add_user(add_dep, node, is_weak=True)
- # is_fake=True because these are control dependencies for ordering only,
- # they should not extend buffer lifetimes
- node.add_fake_dep(WeakDep(add_dep, node.get_name(), is_fake=True))
- for add_dep in V.graph.additional_star_deps[node.get_name()]:
- add_user(add_dep, node, is_weak=False) # Strong dependency
- node.add_fake_dep(StarDep(add_dep))
- # add normal non-mutation dependencies
- for read in node.read_writes.reads:
- if not isinstance(read, WeakDep):
- add_user(read.name, node, node.can_inplace(read))
- node.update_mutated_names(self.mutation_renames)
- # update our renaming scheme for the next iteration
- for buf in node.get_outputs():
- for alt_name in buf.get_mutations():
- self.mutation_renames[rename(alt_name)] = buf.get_name()
- self.mutation_renames[alt_name] = buf.get_name()
- self.mutation_real_name[buf.get_name()] = (
- self.mutation_real_name.get(alt_name, alt_name)
- )
- # make sure outputs aren't dead-code-eliminated
- for buf_name in V.graph.get_output_names():
- log.debug("scheduling output %s", buf_name)
- add_user(buf_name, OutputNode(StarDep(buf_name)))
- # make sure unbacked symints aren't dead-code-eliminated
- if has_non_input_unbacked_defs:
- for out in V.graph.graph_outputs:
- for s in out.get_free_symbol_uses(unbacked_only=True):
- assert s in unbacked_symbol_to_origin_node, (
- f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
- )
- if r := unbacked_symbol_to_origin_node[s]:
- for buf_name in self.name_to_node[r].get_buffer_names():
- log.debug(
- "scheduling output %s for unbacked symint %s",
- buf_name,
- s,
- )
- add_user(buf_name, OutputNode(StarDep(buf_name)))
- # make sure input mutation isn't dead-code-eliminated
- for name in self.mutation_renames:
- if name in V.graph.graph_inputs:
- add_user(name, OutputNode(StarDep(name)))
- V.graph.mutated_inputs.add(name)
- elif name in V.graph.constants:
- # In AOTI, module parameters and buffers are not lifted as graph inputs
- add_user(name, OutputNode(StarDep(name)))
- inp_names = {
- name: index for index, name in enumerate(V.graph.graph_inputs.keys())
- }
- V.graph.mutated_input_idxs = [
- inp_names[name] for name in V.graph.mutated_inputs
- ]
- # copy users information onto the nodes
- for node in self.nodes:
- for buf in node.get_outputs():
- buf.set_users(name_to_users[buf.get_name()].items)
- for name in self.name_to_donated_buffer:
- self.name_to_donated_buffer[name].set_users(name_to_users[name].items)
- # For debug logging
- logbuf = IndentedBuffer()
- logbuf.splice("{")
- for key, value in name_to_users.items():
- with logbuf.indent():
- users = [v.get_name() for v in value.items]
- logbuf.splice(f"'{key}': {users},")
- logbuf.splice("}")
- str = logbuf.getrawvalue().rstrip()
- compute_dependencies_log.debug("BUFFER USER LIST\n")
- compute_dependencies_log.debug("===== AFTER SCHEDULING =====\n%s", str)
def insert_memory_check_nodes(self) -> None:
    """
    Insert a memory-check extern kernel node after every scheduler node.

    The memory timeline computed for ``self.nodes`` is used to work out, for
    each step, which buffers are expected to start and to end their lifetime
    there. Each check node receives those two name lists plus a flag that
    marks the final step. ``self.nodes`` is replaced by the interleaved list.
    """
    from .memory import (
        assign_memory_planning_info_for_scheduler_buffers,
        compute_memory_timeline,
        FreeableInputBuffer,
        get_freeable_input_buf,
    )

    input_names: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
    freeable_inputs: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
        self.nodes, input_names
    )

    if not torch._inductor.config.reorder_for_peak_memory:
        assign_memory_planning_info_for_scheduler_buffers(
            self.nodes, self.name_to_buf
        )

    output_names: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
    timeline, _, _ = compute_memory_timeline(
        self.nodes,
        freeable_inputs,
        output_names,
    )

    # One (newly alive, newly dead) pair of name lists per scheduler step.
    per_step: list[tuple[list[str], list[str]]] = [([], []) for _ in self.nodes]
    for info in timeline:
        # Skip zero-size buffers
        if info.size_alloc == 0 and info.size_free == 0:
            continue
        name = info.buffer.get_name()
        per_step[info.start_step][0].append(name)
        per_step[info.end_step][1].append(name)

    from torch._inductor.runtime.debug_utils import register_check_mem_op

    register_check_mem_op()

    def construct_mem_check_node(
        step_idx: int, is_final_step: bool
    ) -> ExternKernelSchedulerNode:
        """Build the check node that follows the node at ``step_idx``."""
        newly_alive, newly_dead = per_step[step_idx]

        def unflatten(tensor_args, constant_args):
            # Map the flat constant args back onto the op's keyword arguments.
            kwargs = {
                "alive": constant_args[0],
                "dead": constant_args[1],
                "is_final_step": constant_args[2],
            }
            return (tensor_args, kwargs)

        check = ir.MemoryCheckKernel(
            layout=NoneLayout(device=torch.device("cpu")),
            kernel=torch.ops._inductor_debug.check_memory_step.default,
            tensor_args=[],
            nontensor_args=[newly_alive, newly_dead, is_final_step],
            unflatten_args=unflatten,
        )
        check.operation_name = f"mem_check_{self.nodes[step_idx].get_name()}"
        return ExternKernelSchedulerNode(self, check)

    interleaved = []
    for idx, snode in enumerate(self.nodes):
        interleaved.append(snode)
        is_last = idx == len(self.nodes) - 1
        interleaved.append(construct_mem_check_node(idx, is_final_step=is_last))
    self.nodes = interleaved
def dead_node_elimination(self) -> None:
    """
    Remove any nodes without users.

    Does nothing when ``config.use_dce`` is false. A buffer is dead when all
    of its users are weak or belong to already-removed operations; a node is
    dead when it has no side effects and every one of its output buffers is
    dead. Dead buffers/operations are recorded in ``V.graph.removed_buffers``
    and ``V.graph.removed_operations``.
    """
    if not config.use_dce:
        return

    def user_is_droppable(user: NodeUser) -> bool:
        return user.is_weak or user.get_name() in V.graph.removed_operations

    # self.nodes is in topological order, so walking it backwards means all
    # users of a node have been visited (and possibly removed) before the
    # node itself is examined.
    survivors = []
    for snode in reversed(self.nodes):
        has_live_output = False
        for out in snode.get_outputs():
            if all(user_is_droppable(u) for u in out.users):
                log.debug("removed dead buffer: %s", out.get_name())
                V.graph.removed_buffers.add(out.get_name())
            else:
                has_live_output = True

        if snode.has_side_effects() or has_live_output:
            survivors.append(snode)
            continue

        # dead code
        log.debug("removed dead operation: %s", snode.get_name())
        V.graph.removed_operations.add(snode.get_name())
        # Detach the removed node from the user lists of the buffers it read.
        for read in snode.read_writes.reads:
            if read.name in self.name_to_buf:
                producer = self.name_to_buf[read.name]
                producer.users = [
                    u
                    for u in producer.users
                    if u.node.get_name() != snode.get_name()
                ]

    survivors.reverse()
    self.nodes = survivors

    # Prune any WeakDeps no longer needed
    for snode in self.nodes:
        snode.prune_weak_deps()
def mode_requires_synchronization(self, mode: Optional[str]) -> bool:
    """
    Check if store mode requires cross-thread synchronization.

    Returns False only for the ``None`` mode; every other mode currently
    needs synchronization.
    """
    if mode is None:
        return False
    return True
- def topological_sort_schedule(
- self, nodes: list[BaseSchedulerNode]
- ) -> list[BaseSchedulerNode]:
- """
- Ensure nodes is in topologically sorted order
- """
- seen = OrderedSet[BaseSchedulerNode]()
- name_to_node: dict[str, BaseSchedulerNode] = dict()
- result: list[BaseSchedulerNode] = []
- def visit(n: BaseSchedulerNode) -> None:
- if n not in seen:
- seen.add(n)
- for dep in sorted(n.unmet_dependencies, key=lambda d: d.name):
- # We only care about doing toposort within `nodes`
- if dep.name not in name_to_node:
- continue
- visit(name_to_node[dep.name])
- result.append(n)
- for node in nodes:
- for name in node.get_buffer_names():
- name_to_node[name] = node
- for node in nodes:
- visit(node)
- return result
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]:
    """
    Return the (fused) nodes that define the buffers ``snode`` still depends on.

    Each unmet dependency name is resolved to its buffer, then to the name of
    the operation defining that buffer, then to the fused node registered for
    that operation. Duplicates are removed while keeping first-seen order.

    Raises:
        RuntimeError: if ``snode`` is not one of the supported scheduler node
            types.
    """
    supported_types = (
        SchedulerNode,
        ExternKernelSchedulerNode,
        NopKernelSchedulerNode,
        FusedSchedulerNode,
        GroupedSchedulerNode,
    )
    if not isinstance(snode, supported_types):
        raise RuntimeError(
            f"get_unmet_dep_nodes is not implemented for {type(snode)}."
        )

    dep_names: OrderedSet[str] = OrderedSet(
        dep.name for dep in snode.unmet_dependencies
    )
    producers = OrderedSet(
        self.name_to_fused_node[self.name_to_buf[name].defining_op_name()]
        for name in dep_names
    )
    return list(producers)
def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]:
    """
    Sort nodes by their topological order, return a list of node lists.

    Every inner list holds the nodes whose dependencies are all satisfied by
    the nodes of the preceding lists. Asserts if some nodes can never reach
    zero outstanding dependencies.
    """
    # Outstanding dependency count per node, keyed in self.nodes order.
    pending = dict.fromkeys(self.nodes, 0)
    # Reverse edges: node -> nodes that depend on it.
    dependents: dict[Any, Any] = {}
    for snode in self.nodes:
        prereqs = self._get_unmet_dep_nodes(snode)
        pending[snode] = len(prereqs)
        for prereq in prereqs:
            dependents.setdefault(prereq, []).append(snode)

    levels = []
    ready = [n for n, degree in pending.items() if degree == 0]
    while ready:
        levels.append(ready)
        for finished in ready:
            for user in dependents.get(finished, ()):
                pending[user] -= 1
            del pending[finished]
        ready = [n for n, degree in pending.items() if degree == 0]

    assert not pending, "Topological sort failed!"
    return levels
- def compute_ancestors(self) -> None:
- """
- Populate each node.ancestors
- """
- # note self.nodes is topologically sorted
- name_to_ancestors: dict[str, OrderedSet[str]] = {}
- for node in self.nodes:
- ancestors: OrderedSet[str] = OrderedSet()
- for dep in node.unmet_dependencies:
- dep_node_name = self.name_to_buf[dep.name].defining_op_name()
- ancestors.add(dep_node_name)
- ancestors |= name_to_ancestors[dep_node_name]
- name_to_ancestors[node.get_name()] = ancestors
- node.ancestors = ancestors
- for order, node in enumerate(self.nodes):
- node.min_order = order
- node.max_order = order
def merge_loops(self) -> None:
    """
    Call ``merge_loops()`` on every eligible SchedulerNode.

    Only active when ``config.loop_ordering_after_fusion`` is set. Top-level
    nodes must be SchedulerNode/FusedSchedulerNode and either run on GPU or
    use the halide CPU backend; within them, template nodes and
    non-SchedulerNode members are left alone.
    """
    if not config.loop_ordering_after_fusion:
        return

    for node in self.nodes:
        if not isinstance(node, (SchedulerNode, FusedSchedulerNode)):
            continue
        # Even for CPU, if we are using the halide backend, we still need
        # the merge loops steps below
        if not (node.is_gpu() or config.cpu_backend == "halide"):
            continue

        for snode in node.get_nodes():
            # merge loops for the scheduler node
            if isinstance(snode, SchedulerNode) and not snode.is_template():
                snode.merge_loops()

                # Note that for CPU backend, merging loops will change
                # snode.group. It's fine for Triton backend.
                # But if we simplify update snode.group like this:
                #   group_fn = self.get_backend(snode.node.get_device()).group_fn
                #   snode.group = (snode.node.get_device(), group_fn(snode._sizes))
                # There is still an issue due to different snode in a
                # FusedSchedulerNode having different merged loops.
                # Skip CPU backend for now.
def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
    """
    Combine eligible nodes into FusedSchedulerNodes.

    Runs up to ten fusion rounds, stopping early once a round no longer
    reduces the node count or only one node remains. When loop ordering after
    fusion or loop index inversion in fusion is enabled, one additional
    reorder round is run afterwards.

    Returns:
        The node list after fusion.
    """
    with dynamo_timed(
        "Scheduler.fused_nodes", log_pt2_compile_event=True, log_waitcounter=True
    ):
        for round_idx in range(10):
            round_no = round_idx + 1
            count_before = len(nodes)
            fusion_log.debug(
                "===== attempting fusion (%d/10): %d nodes =====",
                round_no,
                count_before,
            )
            nodes = self.fuse_nodes_once(nodes, is_reorder_round=False)
            count_after = len(nodes)
            fusion_log.debug(
                "completed fusion round (%d/10): fused %d nodes into %d nodes\n",
                round_no,
                count_before,
                count_after,
            )
            # Converged (nothing fused) or nothing left to fuse.
            if count_after == count_before or count_after == 1:
                fusion_log.debug(
                    "===== fusion complete (%d iterations) =====", round_no
                )
                break

        wants_reorder_round = (
            config.loop_ordering_after_fusion
            or config.loop_index_inversion_in_fusion
        )
        if wants_reorder_round:
            nodes = self.fuse_nodes_once(nodes, is_reorder_round=True)
        return nodes
def process_grouped_nodes(self) -> None:
    """
    Unpack GroupedSchedulerNode into regular nodes.

    ``self.nodes`` is replaced by a new list in which every
    GroupedSchedulerNode is substituted, in place, by the nodes returned from
    its ``unpack()``; all other nodes are kept as they are.
    """
    flattened: list[BaseSchedulerNode] = []
    for snode in self.nodes:
        if isinstance(snode, GroupedSchedulerNode):
            flattened.extend(snode.unpack())
        else:
            flattened.append(snode)
    self.nodes = flattened
def benchmark_fused_nodes(
    self, nodes: Sequence[BaseSchedulerNode]
) -> tuple[float, str]:
    """
    Benchmark fused list of nodes and return the execution time
    in milliseconds on randomly generated inputs.

    The backend is chosen from the device of the first node, which also
    becomes ``self.current_device``. ``nodes`` must be non-empty.
    """
    assert len(nodes) > 0
    target_device = nodes[0].get_device()
    self.current_device = target_device
    device_backend = self.get_backend(target_device)
    timer = dynamo_timed(
        "benchmark_fused_nodes",
        log_pt2_compile_event=True,
        dynamo_compile_column_us="compile_time_autotune_time_us",
    )
    with timer:
        return device_backend.benchmark_fused_nodes(nodes)
def generate_kernel_code_from_nodes(
    self,
    nodes: Sequence[BaseSchedulerNode],
    benchmark_kernel: bool,
    hint_override: Optional[int] = None,
) -> str:
    """
    Ask the backend of the first node's device to generate kernel code for
    the given list of nodes and return it as a string.

    ``benchmark_kernel`` and ``hint_override`` are forwarded to the backend
    unchanged. The first node's device also becomes ``self.current_device``.
    ``nodes`` must be non-empty.
    """
    assert len(nodes) > 0
    target_device = nodes[0].get_device()
    self.current_device = target_device
    device_backend = self.get_backend(target_device)
    with dynamo_timed("generate_kernel_code_from_nodes"):
        return device_backend.generate_kernel_code_from_nodes(
            nodes, benchmark_kernel, hint_override=hint_override
        )
def benchmark_codegened_module(
    self, module: ModuleType, device: torch.device
) -> tuple[float, str]:
    """
    Benchmark an already code-generated module using the backend registered
    for ``device`` and return that backend's result.

    ``device`` also becomes ``self.current_device``.
    """
    self.current_device = device
    device_backend = self.get_backend(device)
    with dynamo_timed("benchmark_codegened_module"):
        return device_backend.benchmark_codegened_module(module)
- def _has_layout_conflict_for_template(
- self, multi_node: ir.MultiTemplateBuffer
- ) -> bool:
- """
- Check if selecting a Triton template would cause layout conflicts.
- Returns True if there's a conflict and we should fall back to ATen.
- Each input of ``multi_node`` that has a truthy ``layout`` attribute and an
- entry in ``V.graph.buffer_layout_constraints`` is compared against that
- entry. Side effect: an input whose layout is an ``ir.FlexibleLayout`` has
- ``freeze_layout_with_exact_strides`` called with the constraint's stride.
- Returns False when no constraints are registered or no input conflicts.
- """
- constraints = V.graph.buffer_layout_constraints
- # No registered constraints means nothing can conflict.
- if not constraints:
- return False
- log.debug("Node %s has constraints %s", multi_node, constraints)
- for inp in multi_node.inputs:
- # pyrefly: ignore [missing-attribute]
- inp_name = inp.get_name()
- # Inputs without a layout or without a registered constraint are skipped.
- if not getattr(inp, "layout", None) or inp_name not in constraints:
- continue
- layout = inp.layout
- expected_layout = constraints[inp_name]
- if isinstance(layout, ir.FlexibleLayout):
- # Freeze to the expected layout to avoid conflicts
- # pyrefly: ignore [missing-attribute]
- inp.freeze_layout_with_exact_strides(expected_layout.stride)
- # Re-read the layout, which the freeze call above may have replaced.
- layout = inp.layout
- if isinstance(layout, ir.FixedLayout) and expected_layout != layout:
- # Layout already frozen to a different layout - conflict
- log.warning(
- "Layout conflict detected for %s: template expects %s but layout is frozen to %s",
- inp_name,
- expected_layout,
- layout,
- )
- return True
- return False
- def finalize_multi_template_buffers(self) -> None:
- """
- Finalize a backing choice for MultiTemplateBuffers which did not already have a
- choice finalized through fusion. In the case of an extern choice, this will result
- in replacing the SchedulerNode.
- If a MultiTemplateBuffer did not have any fusion opportunities, finalizing a choice
- will force completion of compilation and benchmarking.
- Only SchedulerNodes whose ``node`` is an ``ir.MultiTemplateBuffer`` are touched.
- A Triton template choice is finalized on the buffer itself; any other choice has
- its output buffer built and swapped in through ``self._replace_node``.
- """
- for i, node in enumerate(self.nodes):
- if isinstance(node, SchedulerNode) and isinstance(
- node.node, ir.MultiTemplateBuffer
- ):
- multi_node = node.node
- # Pick the choice to finalize: the minimum-time choice, unless the test
- # config forces the first ExternKernelCaller found in the timings.
- if not config.test_configs.force_extern_kernel_in_multi_template:
- min_node_unfused, _ = multi_node.get_min_choice()
- else:
- # NOTE(review): next() is called without a default, so this raises
- # StopIteration when no ExternKernelCaller is present - confirm intended.
- min_node_unfused = next(
- (
- timing
- for timing in multi_node.choice_timings()
- if isinstance(
- timing,
- torch._inductor.select_algorithm.ExternKernelCaller,
- )
- ),
- )
- if isinstance(
- min_node_unfused,
- torch._inductor.ir.TritonTemplateCallerBase,
- ):
- # Check for layout conflicts before committing to Triton template
- if self._has_layout_conflict_for_template(multi_node):
- # Fall back to first ExternKernelCaller (ATen)
- for choice in multi_node.choice_timings():
- if isinstance(
- choice,
- torch._inductor.select_algorithm.ExternKernelCaller,
- ):
- min_node_unfused = choice
- break
- # NOTE(review): `choice` is the loop variable above; it would be unbound
- # if choice_timings() yielded nothing - confirm that cannot happen.
- assert isinstance(
- choice, torch._inductor.select_algorithm.ExternKernelCaller
- ), (
- "No extern kernel detected to fallback to when layout constraints fail for Triton templates"
- )
- # Re-checked because the fallback above may have replaced the Triton choice.
- if isinstance(
- min_node_unfused,
- torch._inductor.ir.TritonTemplateCallerBase,
- ):
- # pyrefly: ignore [unbound-name]
- if config.multi_kernel_hints:
- # One Triton caller per hint; the None key holds the choice picked above.
- callers: dict[Optional[int], TritonTemplateCallerBase] = {}
- callers[None] = min_node_unfused
- # pyrefly: ignore [unbound-name]
- for hint in config.multi_kernel_hints:
- timings = multi_node.choice_timings(hint_override=hint)
- triton_timings = {
- k: v
- for k, v in timings.items()
- if isinstance(k, TritonTemplateCallerBase)
- }
- # Fastest Triton choice for this hint.
- choice = min(triton_timings.items(), key=lambda x: x[1])[0]
- callers[hint] = choice
- node.node.finalize_as_triton_callers(callers)
- else:
- node.node.finalize_as_triton_caller(min_node_unfused)
- continue
- # The chosen caller was not finalized as a Triton caller above: build its
- # output buffer and substitute it for the multi-template node.
- with ir.IRNode.current_origins(multi_node.origins):
- out_tensorbox = min_node_unfused.output_node()
- out_storage = out_tensorbox.data # type: ignore[union-attr]
- assert isinstance(out_storage, ir.StorageBox)
- out_buffer = out_storage.data
- assert isinstance(out_buffer, ir.OperationBuffer)
- if multi_node.origin_node:
- assign_origin_node(out_tensorbox, multi_node.origin_node)
- # The replacement buffer takes over the multi-template node's layout.
- out_buffer.layout = multi_node.layout
- self._replace_node(out_buffer, multi_node, i, node)
def _replace_node(
    self,
    out_buffer: ir.OperationBuffer,
    multi_node: ir.MultiTemplateBuffer,
    i: int,
    node: SchedulerNode,
) -> None:
    """
    Swap the scheduler node at index ``i`` for a new one built from
    ``out_buffer``.

    The new node is registered under the old node's name in
    ``name_to_node`` and ``name_to_fused_node``, its dependencies are renamed
    to reflect mutation renames seen on the old node, its output buffers take
    over the old outputs' names in ``name_to_buf`` and their users, and the
    old node's ordering/ancestor/last-usage metadata is copied over.
    """
    _replace_operation_buffer(multi_node, out_buffer)
    replacement = self.create_scheduler_node(out_buffer)
    self.nodes[i] = replacement
    self.name_to_node[node.get_name()] = replacement
    self.name_to_fused_node[node.get_name()] = replacement

    # We need to reflect the mutation renames that were recorded in the original node
    renames = {}
    for dep in itertools.chain(node.read_writes.reads, node.unmet_dependencies):
        real_name = self.mutation_real_name.get(dep.name, None)
        if real_name:
            renames[real_name] = dep.name

    replacement.unmet_dependencies = OrderedSet(
        dep.rename(renames) for dep in replacement.unmet_dependencies
    )
    replacement.read_writes.reads = OrderedSet(
        dep.rename(renames) for dep in replacement.read_writes.reads
    )

    # Pair new outputs with old ones positionally and hand over name + users.
    for fresh_out, stale_out in zip(replacement.get_outputs(), node.get_outputs()):
        self.name_to_buf[stale_out.get_name()] = fresh_out
        fresh_out.users = stale_out.users

    replacement.min_order = node.min_order
    replacement.max_order = node.max_order
    replacement.ancestors = node.ancestors
    replacement.last_usage = node.last_usage
def _any_atomic_add(self, node_list: Sequence[BaseSchedulerNode]) -> bool:
    """
    Return True if any node in ``node_list`` has a ``node.data.scatter_mode``
    equal to ``"atomic_add"``.

    Nodes whose ``node`` is None, lacks a ``data`` attribute, or whose data
    lacks ``scatter_mode`` never match.
    """
    for snode in node_list:
        ir_node = snode.node
        if ir_node is None or not hasattr(ir_node, "data"):
            continue
        if getattr(ir_node.data, "scatter_mode", None) == "atomic_add":
            return True
    return False
def compile_kernel(
    self, nodes: Sequence[BaseSchedulerNode], hint_override: Optional[int] = None
) -> tuple[Optional[LambdaFuture], ModuleType]:
    """
    Generate benchmark kernel code for ``nodes`` and load it as a module.

    Returns:
        ``(future, module)``. ``future`` is the async Triton compilation
        handle when the async compiler uses a process pool, otherwise None.
    """
    src_code = self.generate_kernel_code_from_nodes(
        nodes, benchmark_kernel=True, hint_override=hint_override
    )
    mod = PyCodeCache.load(src_code)
    async_compile = torch._inductor.async_compile.AsyncCompile()

    if async_compile.use_process_pool():
        fut = async_compile.triton(kernel_name="triton_", source_code=src_code)
        assert isinstance(fut, LambdaFuture)
        return (fut, mod)

    # No process pool: there is no pending compilation to wait on.
    return (None, mod)
- def speedup_by_fusion(
- self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
- ) -> FusionResult:
- """
- Decide whether fusing node1 with node2 should be allowed.
- If config.benchmark_fusion is False and neither node is a
- MultiTemplateBuffer template, fusion is always allowed
- (FusionResult.fuse(True)). Otherwise fusion is allowed when compiling and
- benchmarking indicates it can bring a speedup; that decision may be
- deferred via FusionResult.from_callable until compilation has finished.
- """
- is_multi_template = any(
- n.is_template()
- and isinstance(n.get_template_node(), ir.MultiTemplateBuffer)
- for n in (node1, node2)
- )
- if not config.benchmark_fusion and not is_multi_template:
- return FusionResult.fuse(True)
- if (
- node1.is_template()
- and not isinstance(node1.get_template_node(), ir.TritonTemplateBuffer)
- or node1.is_foreach()
- or node2.is_foreach()
- ):
- # TODO support benchmarking epilogue fusion
- return FusionResult.fuse(True)
- node_list_1 = node1.get_nodes()
- device = node_list_1[0].get_device()
- assert device
- # don't support benchmark fusion for CPU C++ backend right now.
- if device.type == "cpu" and config.cpu_backend != "triton":
- return FusionResult.fuse(True)
- node_list_2 = node2.get_nodes()
- node_list_fused = list(itertools.chain(node_list_1, node_list_2))
- # We can not accurately benchmark kernel using atomic_add
- # due to how we generate random integer inputs.
- # Skip benchmarking them by allowing fusion.
- if self._any_atomic_add(node_list_fused):
- return FusionResult.fuse(True)
- from triton.compiler.errors import CompilationError
- why = WhyNoFuse(node1, node2)
- device = node_list_fused[0].get_device()
- assert device is not None
- # Debug-only helper: logs the measured speedup or slowdown of the fused kernel.
- def log_fusion(ms_fused: float, ms1: float, ms2: float) -> None:
- if fusion_log.isEnabledFor(logging.DEBUG):
- if ms_fused < ms1 + ms2:
- fusion_log.debug(
- "can fuse (benchmark): fusing %s with %s cause %sx speedup",
- node1.get_buffer_names(),
- node2.get_buffer_names(),
- green_text(f"{(ms1 + ms2) / ms_fused:.3f}"),
- )
- else:
- fusion_log.debug(
- "cannot fuse (benchmark): fusing %s with %s cause %sx slowdown",
- node1.get_buffer_names(),
- node2.get_buffer_names(),
- red_text(f"{ms_fused / (ms1 + ms2):.3f}"),
- )
- # MultiTemplateBuffer path: evaluate the template's choices fused with the other node.
- if is_multi_template and any(
- n.get_template_node() is not None for n in (node1, node2)
- ):
- # node1 being the template means node2 is fused as an epilogue; otherwise
- # node2 is the template and node1 is fused as a prologue.
- epilogue_fusion = node1.get_template_node() is not None
- multi_node = (
- node1.get_template_node()
- if epilogue_fusion
- else node2.get_template_node()
- )
- assert isinstance(multi_node, ir.MultiTemplateBuffer)
- # Check for layout conflicts before committing to Triton template
- if self._has_layout_conflict_for_template(multi_node):
- return FusionResult.fuse(False)
- hint_override_best_fusion_choice: dict[
- Optional[int], TritonTemplateCallerBase
- ] = {}
- future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
- # For every hint override: compile the TritonTemplateCaller choices fused with
- # the other node, benchmark them, and remember the fastest one for that hint.
- for hint_override in config.multi_kernel_hints:
- choice_timings = multi_node.choice_timings(hint_override)
- for choice, _ in sorted(choice_timings.items(), key=lambda x: x[1]):
- if not isinstance(
- choice, torch._inductor.select_algorithm.TritonTemplateCaller
- ):
- continue
- with multi_node.swap_as_triton_caller(choice):
- future_choices.append(
- (
- choice,
- *self.compile_kernel(
- node_list_fused, hint_override=choice.hint_override
- ),
- )
- )
- min_ms_fused = float("inf")
- ms_fused_choice: Optional[TritonTemplateCallerBase] = None
- new_timings = {}
- for choice, future, mod_fused in future_choices:
- try:
- if future is not None:
- future.result()
- except Exception as e:
- if fusion_log.isEnabledFor(logging.DEBUG):
- fusion_log.debug( # noqa: G200
- "Exception in compiling %s: %s",
- "prologue" if not epilogue_fusion else "epilogue",
- str(e),
- )
- continue
- with multi_node.swap_as_triton_caller(choice):
- ms_fused, path = self.benchmark_codegened_module(
- mod_fused, device
- )
- new_timings[choice] = ms_fused
- if ms_fused < min_ms_fused:
- min_ms_fused = ms_fused
- ms_fused_choice = choice
- multi_node._choice_timings[hint_override] = new_timings
- assert isinstance(ms_fused_choice, TritonTemplateCallerBase)
- hint_override_best_fusion_choice[hint_override] = ms_fused_choice
- bench_epilogue = config.benchmark_epilogue_fusion
- num_triton_callers = sum(
- isinstance(c, TritonTemplateCallerBase) for c in multi_node.choices
- )
- # Track if the choice timings can be retrieved async after compilation
- get_choice_timings_async = (
- use_pipelined_autotuning()
- and not bench_epilogue
- and num_triton_callers <= config.max_epilogue_benchmarked_choices
- )
- # ms1: unfused time of the template's best choice; ms2: time of the other node.
- ms1, ms2 = float("inf"), float("inf")
- min_choice: ir.ChoiceCaller | None = None
- if not get_choice_timings_async:
- # Eagerly compile and benchmark non-template nodes
- choice_timings = multi_node.choice_timings()
- min_choice, ms1 = multi_node.get_min_choice()
- choice_timings_iter = sorted(
- choice_timings.items(), key=operator.itemgetter(1)
- )
- else:
- # Use 0 for unfused time, won't be used as bench_epilogue
- # is guaranteed to be False here
- choice_timings_iter = [(c, 0) for c in multi_node.choices]
- if bench_epilogue:
- # Measure the non-template side on its own.
- ms2, path2 = (
- self.benchmark_fused_nodes(node_list_2)
- if epilogue_fusion
- else self.benchmark_fused_nodes(node_list_1)
- )
- else:
- # By default, don't do prologue fusion. Generally slower
- if not epilogue_fusion:
- return FusionResult.fuse(False)
- # Without benchmarking, rely on runtime estimates for the epilogue node.
- ms2 = node2._get_estimated_runtime()
- ms2_fused = _estimate_fused_epilogue_runtime(node1, node2, ms2)
- # Start compiling choices in parallel
- future_choices: list[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
- triton_choices = 0
- for choice, unfused_time in choice_timings_iter:
- if not isinstance(choice, TritonTemplateCallerBase):
- continue
- # For prologue fusion we check if the underlying template of the choice
- # supports all allowed prologue inputs. If not, we skip this choice in
- # the fusion benchmark.
- # TODO: Remove this check after all Triton templates support prologue fusion.
- # Currently, persistent+TMA Triton template does not due to the TMA-based loads.
- if (
- not epilogue_fusion
- and hasattr(choice, "allowed_prologue_inps")
- and choice.allowed_prologue_inps != multi_node.allowed_prologue_inps
- ):
- continue
- # Stop once a choice's unfused time alone is no better than the unfused baseline.
- if bench_epilogue and unfused_time >= ms1 + ms2:
- break
- triton_choices += 1
- # Cap the number of Triton choices that get compiled for the comparison.
- if triton_choices > config.max_epilogue_benchmarked_choices:
- break
- with multi_node.swap_as_triton_caller(choice):
- future_choices.append(
- (choice, *self.compile_kernel(node_list_fused))
- )
- # Nothing was compiled, so there is no fused candidate to evaluate.
- if len(future_choices) == 0:
- return FusionResult.fuse(False)
- # Deferred decision: waits for the compiled choices, selects a fused choice and
- # finalizes it on multi_node; returns True when fusion should go ahead.
- def benchmark_when_ready() -> bool:
- nonlocal choice_timings, future_choices, ms1, min_choice, multi_node
- min_ms_fused = float("inf")
- ms_fused_choice = None
- new_timings = {}
- if get_choice_timings_async:
- assert multi_node and isinstance(multi_node, ir.MultiTemplateBuffer)
- choice_timings = multi_node.choice_timings()
- min_choice, ms1 = multi_node.get_min_choice()
- # Visit the compiled choices in order of their unfused timing.
- future_choices = sorted(
- future_choices,
- key=lambda x: choice_timings[x[0]],
- )
- # Benchmark each choice after compilation completes
- for choice, future, mod_fused in future_choices:
- try:
- if future is not None:
- res = future.result()
- elif not bench_epilogue:
- res = mod_fused.triton_
- res.precompile()
- else:
- res = None
- # Ideally we would more narrowly catch Exceptions here but
- # triton will unpredictably error with valid prologue fusions
- except Exception as e:
- if fusion_log.isEnabledFor(logging.DEBUG):
- fusion_log.debug( # noqa: G200
- "Exception in compiling %s: %s",
- "prologue" if not epilogue_fusion else "epilogue",
- str(e),
- )
- continue
- if bench_epilogue:
- # pyrefly: ignore [missing-attribute]
- with multi_node.swap_as_triton_caller(choice):
- ms_fused, path = self.benchmark_codegened_module(
- mod_fused,
- # pyrefly: ignore [bad-argument-type]
- device,
- )
- new_timings[choice] = ms_fused
- if ms_fused < min_ms_fused:
- min_ms_fused = ms_fused
- ms_fused_choice = choice
- else:
- # Estimate-based check: the choice is the unfused minimum, or the
- # unfused total exceeds this choice's time plus the fused epilogue estimate.
- fusible_choice = (
- min_choice == choice
- or ms2 + ms1 > choice_timings[choice] + ms2_fused
- )
- # Accept the first fusible choice with a single launcher and at most 8 spills.
- if (
- res
- # pyrefly: ignore [missing-attribute]
- and len(res.launchers) == 1
- # pyrefly: ignore [bad-index]
- and res.launchers[0].n_spills <= 8
- and fusible_choice
- ):
- ms_fused_choice = choice
- break
- if bench_epilogue:
- log_fusion(min_ms_fused, ms1, ms2)
- if (
- not bench_epilogue or min_ms_fused < (ms1 + ms2)
- ) and ms_fused_choice is not None:
- if config.multi_kernel_hints:
- hint_override_best_fusion_choice[None] = ms_fused_choice
- # pyrefly: ignore [missing-attribute]
- multi_node.finalize_as_triton_callers(
- hint_override_best_fusion_choice
- )
- else:
- # pyrefly: ignore [missing-attribute]
- multi_node.finalize_as_triton_caller(ms_fused_choice)
- # pyrefly: ignore [missing-attribute]
- multi_node._choice_timings[None] = new_timings
- return True
- else:
- return False
- return FusionResult.from_callable(
- benchmark_when_ready, future_choices[0][1]
- )
- else:
- # Non-template path: compare node1 + node2 run separately against the fused kernel.
- # Start parallel compilation for all three kernels
- future_and_mod_l1 = self.compile_kernel(node_list_1)
- future_and_mod_l2 = self.compile_kernel(node_list_2)
- future_and_mod_l1_fused = self.compile_kernel(node_list_fused)
- # Deferred decision: benchmarks the three kernels once compiled and returns
- # True only when the fused kernel is faster than the two separate ones.
- def benchmark_when_ready() -> bool:
- from torch._inductor.runtime.triton_heuristics import (
- NoTritonConfigsError,
- )
- try:
- # Wait for all compilations to complete
- for fut in (
- future_and_mod_l1[0],
- future_and_mod_l2[0],
- future_and_mod_l1_fused[0],
- ):
- if fut is not None:
- fut.result()
- ms1, path1 = self.benchmark_codegened_module(
- future_and_mod_l1[1],
- # pyrefly: ignore [bad-argument-type]
- device,
- )
- # An infinite timing is reported as register spilling and blocks fusion.
- if math.isinf(ms1):
- why("register spilling of the first kernel")
- return False
- ms2, path2 = self.benchmark_codegened_module(
- future_and_mod_l2[1],
- # pyrefly: ignore [bad-argument-type]
- device,
- )
- if math.isinf(ms2):
- why("register spilling of the second kernel")
- return False
- ms_fused, path_fused = self.benchmark_codegened_module(
- future_and_mod_l1_fused[1],
- # pyrefly: ignore [bad-argument-type]
- device,
- )
- if math.isinf(ms_fused):
- why("register spilling of the fused kernel")
- return False
- log_fusion(ms_fused, ms1, ms2)
- # Record each slow fusion pair at most once in the "slow_fusion" metric table.
- if (
- is_metric_table_enabled("slow_fusion")
- and ms_fused >= ms1 + ms2
- and (path1, path2) not in self.logged_slow_fusion
- ):
- self.logged_slow_fusion.add((path1, path2))
- get_metric_table("slow_fusion").add_row(
- lambda: {
- "kernel1_path": path1,
- "kernel1_latency": ms1,
- "kernel2_path": path2,
- "kernel2_latency": ms2,
- "fused_kernel_path": path_fused,
- "fused_kernel_latency": ms_fused,
- "slow_down_ratio": ms_fused / (ms1 + ms2),
- }
- )
- return ms_fused < ms1 + ms2
- except NoTritonConfigsError:
- return False
- except CompilationError as e:
- # NOTE(review): a "Loop-carried variable" compilation error is treated as
- # fusable (returns True); the rationale is not visible here - confirm.
- if "Loop-carried variable" in str(e):
- return True
- raise
- return FusionResult.from_callable(
- callable_fn=benchmark_when_ready, future=future_and_mod_l1_fused[0]
- )
def get_fused_node(self, node: BaseSchedulerNode) -> BaseSchedulerNode:
    """
    Return the entry of ``self.name_to_fused_node`` keyed by the first name
    of ``node``.
    """
    first_name = node.get_first_name()
    return self.name_to_fused_node[first_name]
def fuse_two_nodes(
    self,
    node1: BaseSchedulerNode,
    node2: BaseSchedulerNode,
    fused_nodes: OrderedSet[BaseSchedulerNode],
) -> BaseSchedulerNode:
    """
    Fuse ``node1`` and ``node2`` through the backend of their device.

    Both inputs are removed from ``fused_nodes`` and replaced by the fused
    node, and ``self.name_to_fused_node`` is updated so that the name of every
    node contained in the fused node maps to it. Returns the fused node.
    """
    fusion_log.debug("fusing %s with %s", node1.get_name(), node2.get_name())

    device = node1.get_device()
    # Both nodes must live on the same device.
    assert node2.get_device() == device
    backend = self.get_backend(device)
    merged = backend.fuse(node1, node2)

    # Swap the two inputs for the merged node in the working set.
    for stale in (node1, node2):
        fused_nodes.remove(stale)
    fused_nodes.add(merged)

    name_map = {member.get_name(): merged for member in merged.get_nodes()}
    self.name_to_fused_node.update(name_map)
    return merged
def fuse_if_speedup(
    self,
    node1: BaseSchedulerNode,
    node2: BaseSchedulerNode,
    speedup_fn: Callable[[], bool],
    fused_nodes: OrderedSet[BaseSchedulerNode],
):
    """
    Fuse ``node1`` with ``node2`` when the fusion is legal, does not create a
    cycle, and ``speedup_fn()`` is truthy.

    The checks run in that order and stop at the first failure, so
    ``speedup_fn`` is only invoked for pairs that passed the two structural
    checks. Returns True if the nodes were fused, False otherwise.
    """
    if not self.can_fuse(node1, node2):
        return False
    if self.will_fusion_create_cycle(node1, node2):
        return False
    if not speedup_fn():
        return False

    self.fuse_two_nodes(node1, node2, fused_nodes)
    return True
def _evaluate_pending_template_fusions(
    self,
    template_fusion_candidates: dict[BaseSchedulerNode, list[PendingFusion]],
    fused_nodes: OrderedSet[BaseSchedulerNode],
) -> None:
    """
    Evaluate pending template fusions for a set of fusion candidate nodes.

    The keys of ``template_fusion_candidates`` are pointwise nodes that are
    potential epilogues or prologues; each maps to a queue of PendingFusions.
    Every pass of the outer loop takes the first queued fusion of each
    candidate. A candidate is dropped from the dict once its queue is empty
    or one of its fusions succeeded, so the dict is consumed in place.
    """
    while template_fusion_candidates:
        ready_futures: list[Future] = []
        owner_of_future: dict[
            Future, tuple[PendingFusion, BaseSchedulerNode]
        ] = {}
        finished: OrderedSet[BaseSchedulerNode] = OrderedSet()

        for candidate, queue in template_fusion_candidates.items():
            assert candidate in template_fusion_candidates and len(queue) >= 1
            pending = queue.pop(0)
            if not queue:
                finished.add(candidate)

            node1, node2 = pending.get_fusion_nodes()
            if node2 == candidate:
                assert is_epilogue_fusion(node1, node2)
                template_node = node1
            else:
                assert node1 == candidate
                assert is_prologue_fusion(node1, node2)
                template_node = node2

            # template node fused with same class of pointwise (prologue/epilogue)
            # move onto next candidate as not fusible
            # TODO (PaulZhang12): Does not support fusions of templates with
            # multiple potential epilogues
            if self.get_fused_node(template_node) is not template_node:
                continue

            if pending.future:
                # Async compile in flight: defer until the future completes.
                fut = pending.future.future
                assert fut is not None
                ready_futures.append(fut)
                owner_of_future[fut] = (pending, candidate)
            elif self.fuse_if_speedup(
                node1, node2, pending.callable_fn, fused_nodes
            ):
                # Non AsyncCompile path, fusion was performed right away
                finished.add(candidate)

        # Evaluate fusion candidates as async_compile completes
        for fut in as_completed(ready_futures):
            pending, owner = owner_of_future[fut]
            lhs = self.get_fused_node(pending.node1)
            rhs = self.get_fused_node(pending.node2)
            if self.fuse_if_speedup(lhs, rhs, pending.callable_fn, fused_nodes):
                finished.add(owner)

        for done in finished:
            template_fusion_candidates.pop(done)
def _try_fusion_pairs(
    self,
    possible_fusion_pairs: list[tuple[BaseSchedulerNode, BaseSchedulerNode]],
    pending_fusions: dict[BaseSchedulerNode, PendingFusion],
    template_fusion_nodes: dict[BaseSchedulerNode, list[PendingFusion]],
    fused_nodes: OrderedSet[BaseSchedulerNode],
    is_reorder_round: bool,
):
    """
    Walk ``possible_fusion_pairs`` in order and fuse, defer, or skip each pair.

    For every pair, ``self.speedup_by_fusion`` decides the outcome:
    - a result carrying a ``callable_fn`` is deferred: template fusions are
      queued in ``template_fusion_nodes`` (keyed by the pointwise node),
      other fusions are recorded in ``pending_fusions`` under both nodes;
    - otherwise the pair is fused immediately when ``should_fuse`` is set.

    ``pending_fusions``, ``template_fusion_nodes``, ``fused_nodes`` and
    ``self.seen_template_fusions`` are updated in place.
    """

    def resolve_pending_fusions(
        node1: BaseSchedulerNode,
        node2: BaseSchedulerNode,
    ) -> None:
        # Keep resolving until neither node (in its current fused form) is
        # part of a pending fusion anymore.
        while True:
            fused1 = self.get_fused_node(node1)
            fused2 = self.get_fused_node(node2)
            if fused1 not in pending_fusions and fused2 not in pending_fusions:
                return

            pending = pending_fusions.get(fused1, pending_fusions.get(fused2))
            assert pending is not None

            node_key1, node_key2 = pending.get_fusion_nodes()
            is_speedup = pending.callable_fn
            pending_fusions.pop(node_key1, None)
            pending_fusions.pop(node_key2, None)

            assert self.get_fused_node(node_key1) is node_key1
            assert self.get_fused_node(node_key2) is node_key2

            if is_speedup() and not self.will_fusion_create_cycle(node1, node2):
                self.fuse_two_nodes(node_key1, node_key2, fused_nodes)

    for first, second in possible_fusion_pairs:
        # if either node is in a pending fusion, resolve it.
        # since we iterate on potential fusions based on profitability
        # the first potential fusion should take precedence.
        resolve_pending_fusions(first, second)
        node1 = self.get_fused_node(first)
        node2 = self.get_fused_node(second)

        if (
            is_template_fusion(node1, node2)
            and (node1, node2) in self.seen_template_fusions
        ):
            continue
        if not self.can_fuse(node1, node2, is_reorder_round):
            continue
        if self.will_fusion_create_cycle(node1, node2):
            continue

        fusion_res = self.speedup_by_fusion(node1, node2)

        if fusion_res.callable_fn is None:
            # Decision is already known: fuse now or drop the pair.
            if fusion_res.should_fuse:
                self.fuse_two_nodes(node1, node2, fused_nodes)
            continue

        pending = PendingFusion(
            callable_fn=fusion_res.callable_fn,
            node1=node1,
            node2=node2,
            future=fusion_res.future,
        )

        if not is_template_fusion(node1, node2):
            pending_fusions[node1] = pending
            pending_fusions[node2] = pending
            continue

        pair = (node1, node2)
        assert pair not in self.seen_template_fusions
        self.seen_template_fusions.add(pair)
        template_pw_node = template_fusion_pw_node(node1, node2)
        template_fusion_nodes.setdefault(template_pw_node, []).append(pending)
def _finish_pending_fusions(
    self,
    fused_nodes: OrderedSet[BaseSchedulerNode],
    pending_fusions: dict[BaseSchedulerNode, PendingFusion],
):
    """
    Resolve the pending non-template fusions left in ``pending_fusions``.

    Entries whose ``callable_fn`` was already handled in this call are
    skipped, as are template fusions. Every remaining pair goes through
    ``self.fuse_if_speedup``.
    """
    # Resolve pending fusions for non templates in case of benchmark_kernel=True
    handled_fns: OrderedSet[Callable[[], bool]] = OrderedSet()

    for pending in pending_fusions.values():
        node_key1, node_key2 = pending.get_fusion_nodes()
        speedup_fn = pending.callable_fn

        if speedup_fn in handled_fns:
            continue
        if is_template_fusion(node_key1, node_key2):
            continue

        handled_fns.add(speedup_fn)
        assert self.get_fused_node(node_key1) is node_key1
        assert self.get_fused_node(node_key2) is node_key2
        self.fuse_if_speedup(node_key1, node_key2, speedup_fn, fused_nodes)
def _handle_template_overlap(
    self,
    possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]],
    deferred_prologue_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]],
):
    """
    Move prologue fusions that overlap with an epilogue fusion out of
    ``possible_fusions`` and into ``deferred_prologue_fusions``.

    Potentially a prologue fusion might have the same template as an epilogue;
    the prologue fusion therefore has to be evaluated on the potential
    fused template + epilogue.

    Both lists are updated IN PLACE and nothing is returned: the selected
    pairs are appended to ``deferred_prologue_fusions`` and removed from
    ``possible_fusions``, with the relative order of the remaining pairs
    preserved.
    """
    epilogue_template_nodes = OrderedSet(
        [n1 for n1, n2 in possible_fusions if is_epilogue_fusion(n1, n2)]
    )

    remaining_fusions = []
    for n1, n2 in possible_fusions:
        # NOTE(review): for an epilogue pair the template is n1, which is what
        # `epilogue_template_nodes` collects. Other code in this class treats
        # n2 as the template of a prologue pair, so `n2` may be the intended
        # operand here. Left unchanged; confirm against is_prologue_fusion.
        if is_prologue_fusion(n1, n2) and n1 in epilogue_template_nodes:
            deferred_prologue_fusions.append((n1, n2))
        else:
            remaining_fusions.append((n1, n2))

    # Slice-assign so the caller's list object is updated. Rebinding the local
    # name would silently discard the filtering, since nothing is returned.
    possible_fusions[:] = remaining_fusions
def fuse_nodes_once(
    self,
    nodes: list[BaseSchedulerNode],
    is_reorder_round: bool,
) -> list[BaseSchedulerNode]:
    """
    Combine eligible nodes into FusedSchedulerNodes.

    This relies on two key functions to control the logic:
        - self.can_fuse(): checks if a fusion is legal
        - self.score_fusion(): assigns priority to a given fusion

    Returns the resulting nodes, sorted by ``min_order`` and then passed
    through ``self.topological_sort_schedule``.
    """
    self.prune_redundant_deps(nodes)
    fused_nodes = OrderedSet(nodes)

    if fusion_log.isEnabledFor(logging.DEBUG):
        fusion_log.debug("fuse_nodes_once, candidates:")
        for candidate in fused_nodes:
            fusion_log.debug(" %s", candidate.debug_str_short())

    # These are potential fusions which we are async compiling,
    # and which we will benchmark profitability of.
    # Maps node -> PendingFusion (registered under both nodes of the pair).
    # Only used in the case of benchmark_kernel=True
    pending_fusions: dict[BaseSchedulerNode, PendingFusion] = {}
    # Pointwise node -> queue of pending template fusions involving it.
    template_fusion_nodes: dict[BaseSchedulerNode, list[PendingFusion]] = {}
    # Prologue fusions postponed until after the first round.
    deferred_prologue_fusions: list[
        tuple[BaseSchedulerNode, BaseSchedulerNode]
    ] = []

    possible_fusions = self.get_possible_fusions(nodes, is_reorder_round)

    autotuning = config.max_autotune_gemm or config.max_autotune
    if autotuning and config.prologue_fusion and config.epilogue_fusion:
        self._handle_template_overlap(possible_fusions, deferred_prologue_fusions)

    # First round: everything that was not deferred.
    self._try_fusion_pairs(
        possible_fusions,
        pending_fusions,
        template_fusion_nodes,
        fused_nodes,
        is_reorder_round,
    )
    self._finish_pending_fusions(fused_nodes, pending_fusions)
    self._evaluate_pending_template_fusions(template_fusion_nodes, fused_nodes)
    template_fusion_nodes.clear()

    # Second round: deferred prologue fusions, if any.
    if deferred_prologue_fusions:
        self._try_fusion_pairs(
            deferred_prologue_fusions,
            pending_fusions,
            template_fusion_nodes,
            fused_nodes,
            is_reorder_round,
        )
        self._evaluate_pending_template_fusions(template_fusion_nodes, fused_nodes)

    ordered = sorted(fused_nodes, key=lambda snode: snode.min_order)
    return self.topological_sort_schedule(ordered)
def create_combo_kernel_nodes(self, num_ck_nodes: Optional[int] = None) -> None:
    """
    Groups parallel nodes

    Each group produced by
    ``ForeachKernelSchedulerNode.group_nodes_for_combo_kernels`` that still
    has at least two combinable nodes, and that ``speedup_by_combo_kernel``
    accepts, is replaced in ``self.nodes`` by one ForeachKernelSchedulerNode.
    Grouping stops once more than ``num_ck_nodes`` groups have been created
    (no limit when it is None).
    """
    working_set = OrderedSet(self.nodes)
    num_nodes_orig = len(self.nodes)
    count = 0
    log.debug("ComboKernels: Generating with num_ck_nodes = %s...", num_ck_nodes)

    groups = ForeachKernelSchedulerNode.group_nodes_for_combo_kernels(self)
    for num, raw_group in enumerate(groups):
        node_list = ForeachKernelSchedulerNode.combinable_nodes(raw_group)
        if len(node_list) < 2:
            continue
        if num_ck_nodes is not None and count > num_ck_nodes:
            break
        if not self.speedup_by_combo_kernel(node_list):
            log.debug("ComboKernels: Not speeding up %d-th group", num)
            continue

        count += 1
        group_snode = ForeachKernelSchedulerNode(
            node_list[0].scheduler,
            node_list,
            use_custom_partition_algo=True,
            enable_autotune=config.combo_kernels_autotune > 0,
        )
        log.info(
            "ComboKernels: Combining %d nodes for %d-th group",
            len(node_list),
            num,
        )

        # Replace the members by the combined node in the working set.
        for member in node_list:
            working_set.remove(member)
        working_set.add(group_snode)

        name_map = {n.get_name(): group_snode for n in group_snode.get_nodes()}
        self.name_to_fused_node.update(name_map)

    self.nodes = sorted(working_set, key=lambda snode: snode.min_order)
    self.nodes = self.topological_sort_schedule(self.nodes)
    log.info(
        "Generated ComboKernel nodes: %d ComboKernels, totally %d -> %d nodes",
        count,
        num_nodes_orig,
        len(self.nodes),
    )
    self.prune_redundant_deps(self.nodes)
def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None:
    """
    Call ``prune_redundant_deps`` on every node in ``nodes``, passing it the
    scheduler's ``name_to_fused_node`` mapping.
    """
    for snode in nodes:
        snode.prune_redundant_deps(self.name_to_fused_node)
def get_possible_fusions(
    self,
    nodes: list[BaseSchedulerNode],
    is_reorder_round: bool,
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
    """
    Helper to find all legal fusion opportunities, sorted by self.score_fusion()

    Candidate pairs come from nodes that use a common buffer and, when
    ``config.aggressive_fusion`` is set, from nodes that share the same
    ``group`` attribute.
    """
    found: list[tuple[BaseSchedulerNode, BaseSchedulerNode]] = []
    seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]()

    def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None:
        for idx, lhs in enumerate(nodes):
            # Only look at a bounded window of nodes following `lhs`.
            start = idx + 1
            stop = start + config.max_fusion_buffer_group_pairwise_attempts
            for rhs in nodes[start:stop]:
                pair = (lhs, rhs)
                if pair in seen:
                    continue
                seen.add(pair)

                if self.can_fuse(lhs, rhs, is_reorder_round):
                    found.append(pair)
                elif (rhs.is_template() or rhs.is_foreach()) and self.can_fuse(
                    rhs, lhs, is_reorder_round
                ):
                    # foreach fusions and epilogue fusions are order dependent
                    found.append((rhs, lhs))

    # Group the fusable nodes by each buffer name they use.
    by_buffer = collections.defaultdict(list)
    for snode in nodes:
        if self.unfusable_node(snode):
            continue
        for buf_name in snode.used_buffer_names():
            by_buffer[buf_name].append(snode)
    for bucket in by_buffer.values():
        check_all_pairs(bucket)

    if config.aggressive_fusion:
        # Additionally pair up nodes that have the same (truthy) `group`.
        by_group = collections.defaultdict(list)
        for snode in nodes:
            group = getattr(snode, "group", None)
            if group:
                by_group[group].append(snode)
        for bucket in by_group.values():
            check_all_pairs(bucket)

    ranked = self.get_possible_fusions_with_highest_priority(found)
    ranked.sort(key=self.score_fusion_key, reverse=True)
    fusion_log.debug("found %d possible fusions", len(ranked))
    return ranked
def will_fusion_create_cycle(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """
    Finds whether there's a path from node1 to node2 (or vice-versa)
    caused indirectly by other fusions.

    When a cycle is found, the reason is reported through WhyNoFuse.
    """
    # The combined sets are built from the dict key views, which yields plain
    # (unordered) sets; ordering is irrelevant since only a boolean is returned.
    combined_names = (
        node1.get_operation_names()._dict.keys()
        | node2.get_operation_names()._dict.keys()
    )
    combined_ancestors = (
        node1.ancestors._dict.keys() | node2.ancestors._dict.keys()
    ) - combined_names

    visited = OrderedSet[FusedSchedulerNode]()

    def found_path(node: BaseSchedulerNode) -> bool:
        # only fused nodes can introduce new ancestors.
        if not isinstance(node, FusedSchedulerNode) or node in visited:
            return False
        visited.add(node)

        if node.get_operation_names().issubset(combined_ancestors):
            # All fusion outputs are in ancestors of node1 and node2, thus
            # cannot introduce new path:
            #
            # 1. if output is neither descendent of node1 or node2, the
            #    output cannot introduce a path
            # 2. due to [can_fuse]: if WLOG output is descendent of node1, it cannot be
            #    on path(node1->node2), hence it cannot be ancestor of node2
            # 3. due to [acyclic]: if WLOG output is descendent of node1, it cannot be
            #    ancestor of node1
            return False

        # continue DFS of new ancestors introduced by the fusion
        if combined_names & node.ancestors:
            return True
        for name in node.ancestors - combined_ancestors:
            if found_path(self.name_to_fused_node[name]):
                return True
        return False

    for name in combined_ancestors:
        if found_path(self.name_to_fused_node[name]):
            WhyNoFuse(node1, node2)("will create cycle")
            return True
    return False
def can_fusion_increase_peak_memory(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """
    Return true if fusing the two nodes can potentially increase peak memory.

    The implementation is more like a heuristic since we don't really know if we are at peak
    or not when trying to fuse these two nodes. The order of nodes may change later, which makes
    the peak memory estimation hard.

    Here is how we decide the LOWER BOUND of extra memory allocation if we fuse these 2 nodes:
    1. find all buffers read by each node with a single user. These buffers are supposed to
       be reused if we don't fuse these 2 nodes
    2. find the intersection of these buffers for the two nodes and sum the total buffer size.
       If we don't fuse these two nodes, we can at least avoid this much memory allocation.
       Note that the extra memory allocation is not necessarily causing peak memory increase.
       This is just a heuristic.

    We return true only if the saving for fusion can not trade off the extra memory allocation.
    """
    from .codegen.wrapper import buffer_reuse_key

    def _find_single_user_inputs(
        node: BaseSchedulerNode,
    ) -> list[ir.Buffer]:
        single_user_inputs = []
        for read_dep in node.read_writes.reads:
            sched_buf = self.name_to_buf.get(read_dep.name)
            if (
                sched_buf
                and len(sched_buf.users) == 1
                and sched_buf.node.has_tensor_output()
            ):
                single_user_inputs.append(sched_buf.node)
        return single_user_inputs

    # Check inputs that can be potentially reused
    lhs_reuse_keys = OrderedSet(
        buffer_reuse_key(buf) for buf in _find_single_user_inputs(node1)
    )
    rhs_reuse_keys = OrderedSet(
        buffer_reuse_key(buf) for buf in _find_single_user_inputs(node2)
    )

    memory_overhead = 0
    for reuse_key in lhs_reuse_keys.intersection(rhs_reuse_keys):
        try:
            size = int(reuse_key[2])
        except ValueError:
            # not an integer. Fallback is to fuse
            return False
        memory_overhead += size

    bw_saving = self.score_fusion_memory(node1, node2)

    # The factor 32 here is quite arbitrary.
    return bool(
        V.graph.sizevars.statically_known_gt(memory_overhead, 32 * bw_saving)
    )
def fusion_prevent_too_many_reads_and_writes(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode, threshold: int
) -> bool:
    """
    Return True when the fusion of ``node1`` and ``node2`` would touch more
    than ``threshold`` distinct I/O buffers.

    Buffers that become internal to the fused kernel are not counted:
    reads of ``node2`` that are written by ``node1``, and writes of ``node1``
    for which ``self.can_buffer_be_removed_through_fusion`` holds.
    """

    def dep_names(deps) -> OrderedSet[str]:
        return OrderedSet(dep.name for dep in deps)

    # Names of all nodes that will be in the fused node
    fused_node_names = OrderedSet(
        [snode.get_name() for snode in node1.get_nodes()]
        + [snode.get_name() for snode in node2.get_nodes()]
    )

    node1_reads = dep_names(node1.read_writes.reads)
    node1_writes = dep_names(node1.read_writes.writes)
    node2_reads = dep_names(node2.read_writes.reads)
    node2_writes = dep_names(node2.read_writes.writes)

    # node2 reads that are outputs of node1 become internal after fusion
    internal_reads = node2_reads & node1_writes

    # node1 writes that can be removed through fusion also become internal
    internal_writes: OrderedSet[str] = OrderedSet(
        write_dep.name
        for write_dep in node1.read_writes.writes
        if self.can_buffer_be_removed_through_fusion(
            write_dep.name, fused_node_names
        )
    )

    external_reads = (node1_reads | node2_reads) - internal_reads
    external_writes = (node1_writes | node2_writes) - internal_writes

    # Union of reads and writes, so a buffer that is both is counted once
    io_buffers = external_reads | external_writes
    return len(io_buffers) > threshold
def are_long_distant_nodes(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """
    This function prevents fusion for nodes that can increase memory
    footprint. This problem is more common in horizontal fusion, where nodes
    that are far apart in the original order get fused, lengthening the live
    intervals of tensors. This is very evident in models with activation
    checkpointing, where the recomputed nodes from different checkpointed
    regions get fused and significantly increase the memory footprint.

    The current attempt is a quick, possibly hacky, heuristic to prevent the
    fusion of nodes that are far away in the original order.

    A better but difficult to implement heuristic would be to use live
    intervals of the buffers, find region of peak pressure in the original
    program and prevent fusion that crosses that peak region. We might need
    special care or good approximation in this implementation, as fusion of
    node changes live intervals, and re-computing live intervals and peak
    memory after each fusion can introduce large compilation overhead.

    Returns True when the larger of the two min_order/max_order gaps between
    the nodes exceeds 64.
    """
    gap_1_to_2 = abs(node1.min_order - node2.max_order)
    gap_2_to_1 = abs(node2.min_order - node1.max_order)
    return max(gap_1_to_2, gap_2_to_1) > 64
def decide_fusion_fail_reason(
    self,
    node1: BaseSchedulerNode,
    node2: BaseSchedulerNode,
    common_buf_names: Union[tuple[str, ...], OrderedSet[str]],
) -> str:
    """
    Try to decide the reasons why fusion fails due to no shared memory even
    though there are common buffers.

    Returns the ``str()`` of a dict mapping each common buffer name to a
    short human-readable reason.
    """
    node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
    node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}

    def explain(buf_name: str) -> str:
        buf = V.graph.get_buffer(buf_name)
        lhs_dep = node1_name2dep[buf_name]
        rhs_dep = node2_name2dep[buf_name]

        if not (isinstance(lhs_dep, MemoryDep) and isinstance(rhs_dep, MemoryDep)):
            return f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"

        if lhs_dep.get_numel() != rhs_dep.get_numel():
            return (
                f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
            )

        # same numel but different MemoryDep.size. Should be broadcasting
        if sympy_product(lhs_dep.size) != sympy_product(rhs_dep.size):
            return "broadcast"

        lhs_off = lhs_dep.get_offset()
        rhs_off = rhs_dep.get_offset()
        if lhs_off != rhs_off:
            # One example is in transformer, we use a concatenated linear layer
            # to project Q/K/V and then split the result. The 3 splits will
            # point to the same buffer with different offsets.
            return f"different offset: {lhs_off} v.s. {rhs_off}"

        if (
            lhs_dep.normalize_with_stride_order()
            == rhs_dep.normalize_with_stride_order()
        ):
            return f"Mismatch loop orders: {lhs_dep} v.s. {rhs_dep}"

        # Add more rules here
        layout_str = ""
        if not isinstance(buf, ir.TorchBindObject):
            layout_str = f"Layout: {buf.layout}"
        return f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"

    reasons = {buf_name: explain(buf_name) for buf_name in common_buf_names}
    return str(reasons)
def shared_data_after_inverting_indexing(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> int:
    """
    Attempts to enable fusion between two nodes by inverting indexing patterns.

    This optimization targets cases where node1 has a contiguous write and
    node2 has a contiguous write but discontiguous read. By inverting the
    indexing in node2's read and write operations, we can make them compatible
    with node1 for potential fusion.

    On success this MUTATES ``node2``: its body's two indexing expressions are
    rewritten and its dependencies are refreshed. When any precondition fails
    the method returns before touching ``node2``.

    Args:
        node1: First scheduler node (source)
        node2: Second scheduler node (target for inversion)

    Returns:
        int: The fusion memory score after inversion if it was applied,
        -1 if the optimization is not applicable.
    """
    if not config.loop_index_inversion_in_fusion:
        return -1

    if any(n.is_cpu() for n in [node1, node2]):
        return -1

    # Check for shared buffers between nodes
    node1_buffer_names = node1.read_writes.buffer_names()
    node2_buffer_names = node2.read_writes.buffer_names()
    common_buffer_names = node1_buffer_names & node2_buffer_names

    if not common_buffer_names:
        return -1

    # only invert if node1 is single unmet dep
    node2_unmet_dependencies = OrderedSet(
        dep.name for dep in node2.unmet_dependencies
    )
    if node2_unmet_dependencies - node1_buffer_names:
        return -1

    if len(node2_unmet_dependencies) > 1:
        return -1

    # Currently only handle single read/write operations
    if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1:
        return -1

    node2_read = next(iter(node2.read_writes.reads))
    node2_write = next(iter(node2.read_writes.writes))

    if not isinstance(node2_read, MemoryDep) or not isinstance(
        node2_write, MemoryDep
    ):
        return -1

    node1_writes = {dep.name: dep for dep in node1.read_writes.writes}
    if node2_read.name not in node1_writes:
        return -1

    node1_write = node1_writes[node2_read.name]

    if not isinstance(node1_write, MemoryDep):
        return -1

    # We are checking for compatibility with the normalized node1 write
    # then modifying node2 reads/writes. since the node1 write will be just used
    # for compatibility, while node2 will be used in actual modification, just
    # normalize node1 not node2.
    node1_write = node1_write.normalize()

    # Bail out only when BOTH the index and the size differ.
    if (
        node1_write.index != node2_write.index
        and node1_write.size != node2_write.size
    ):
        return -1

    if node2_read.size != node2_write.size or len(node2_read.var_names) != 1:
        return -1

    # Verify we have exactly two indexing expressions (one read, one write)
    if len(node2._body.indexing_exprs) != 2:  # type: ignore[attr-defined]
        return -1

    # No subblocks allowed for this optimization
    if node2._body.subblocks:  # type: ignore[attr-defined]
        return -1

    assert (
        "index0" in node2._body.indexing_exprs  # type: ignore[attr-defined]
        and "index1" in node2._body.indexing_exprs  # type: ignore[attr-defined]
    )

    # Extract and verify single read expression
    node2_read_exprs = OrderedSet(node2._body.get_read_exprs())  # type: ignore[attr-defined]
    if len(node2_read_exprs) != 1:
        return -1

    read_expr = next(iter(node2_read_exprs))

    # Determine which index is for reading vs writing
    if read_expr == node2._body.indexing_exprs["index0"]:  # type: ignore[attr-defined]
        read_expr_index = "index0"
        write_expr_index = "index1"
    else:
        assert read_expr == node2._body.indexing_exprs["index1"]  # type: ignore[attr-defined]
        read_expr_index = "index1"
        write_expr_index = "index0"

    from torch._inductor.invert_expr_analysis import generate_inverse_formula

    index_vars = node2._body.vars[0]  # type: ignore[attr-defined]
    if len(index_vars) != 1:
        return -1

    # Simplify the read expression term by term before trying to invert it.
    simplified_terms = [
        V.graph.sizevars.combine_modular_indexing_pairs(term)
        for term in sympy.Add.make_args(read_expr)
    ]
    simplified_read_expr = sum(simplified_terms)

    inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0])

    # formula is not invertible
    if inverse_formula is None:
        return -1

    # === Apply Inversion ===

    # Swap the indexing expressions using the inverse formula: the read now
    # uses the old write expression and the write uses the inverse formula.
    node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[  # type: ignore[attr-defined]
        write_expr_index
    ]
    node2._body.indexing_exprs[write_expr_index] = inverse_formula  # type: ignore[attr-defined]

    # Refresh dependencies and calculate fusion score
    node2.refresh_dependencies(True, False)  # type: ignore[attr-defined]
    score = self.score_fusion_memory(node1, node2)

    assert isinstance(score, int)

    fusion_log.info("Shared memory after inversion: %d", score)
    return score
def shared_data_after_reordering_loop(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> int:
    """
    Right now just greedily reorder the loop of node1 to be compatible with node2,
    but ideally we should have some heuristics to reorder the loop for node2
    to be compatible with node1 if that's more efficient.

    Return the amount of shared data re-computed in this method.
    If no such recomputation happens, return -1 (not return 0 since 0 is a valid
    amount of shared data).
    """
    # TODO Don't do loop reordering for CPU for now.
    # Should debug more why it does not work for CPU codegen
    if not config.loop_ordering_after_fusion:
        return -1
    if node1.is_cpu() or node2.is_cpu():
        return -1

    # in some rare case, a template can be passed in.
    # Check test_interaction_with_multi_template in test_loop_ordering.py
    # and https://github.com/pytorch/pytorch/issues/165579
    if node1.is_template() or node2.is_template():
        return -1

    # Fast path: no common buffers.
    shared_names = (
        node1.read_writes.buffer_names() & node2.read_writes.buffer_names()
    )
    if not shared_names:
        return -1

    node1_name2dep = {dep.name: dep for dep in node1.read_writes.reads_and_writes()}
    node2_name2dep = {dep.name: dep for dep in node2.read_writes.reads_and_writes()}

    # Find the common buffers that have different loop orders
    candidates = []
    for name in shared_names:
        dep_a = node1_name2dep[name]
        dep_b = node2_name2dep[name]
        if dep_a.normalize_with_stride_order() != dep_b.normalize_with_stride_order():
            continue
        numel_hint = V.graph.sizevars.size_hint(dep_a.get_numel(), fallback=0)
        candidates.append((numel_hint, dep_a, dep_b))

    if not candidates:
        return -1

    # Pick the largest buffer to guide the loop reordering
    _numel, lhs_dep, rhs_dep = max(candidates, key=lambda entry: entry[0])

    if not (isinstance(lhs_dep, MemoryDep) and isinstance(rhs_dep, MemoryDep)):
        return -1

    if lhs_dep.num_vars != rhs_dep.num_vars:
        # this can happen due to we don't merge loops.
        # We can not do loop reordering in this case right now
        # Simply returning true if the two Deps are the same after
        # normalization (merging loops)
        if lhs_dep.normalize() == rhs_dep.normalize():
            return self.dep_size_hint(lhs_dep)
        return -1

    # Only reorder loops for pointwise for now
    reordered = False
    if not node1.is_reduction():
        reordered = node1.reorder_loops_by_dep_pair(lhs_dep, rhs_dep)
    elif not node2.is_reduction():
        reordered = node2.reorder_loops_by_dep_pair(rhs_dep, lhs_dep)
    else:
        loop_ordering_log.debug(
            "Don't reorder loops since both nodes are reductions: %s v.s. %s",
            node1.get_name(),
            node2.get_name(),
        )

    if not reordered:
        return -1
    return typing.cast(int, self.score_fusion_memory(node1, node2))
def unfusable_node(self, node: BaseSchedulerNode) -> bool:
    """
    Is this node unfusable under any conditions.

    True for extern-kernel and nop-kernel scheduler nodes, unless the node is
    a template or its IR node is an output of a multi-output template.
    """
    if not isinstance(node, (ExternKernelSchedulerNode, NopKernelSchedulerNode)):
        return False
    if node.is_template():
        return False
    return not is_output_of_multi_outputs_template(node.node)
def check_prologue_fusion_heuristics_fusable(
    self,
    prologue_node: BaseSchedulerNode,
    template_node: BaseSchedulerNode,
    why: WhyNoFuse,
) -> bool:
    """
    Heuristics to avoid benchmarking predictably slow prologue fusions

    Returns True when the prologue fusion should be attempted. Each rejection
    is reported through ``why`` before returning False.
    """
    # user opt into more aggressive prologue fusion, dont use heuristics
    if prologue_node.get_operation_names() <= V.graph.invoke_quant_ops:
        return True

    read_bytes = prologue_node.get_read_buffer_sizes()
    write_bytes = prologue_node.get_write_buffer_sizes()

    # Initially, only do fusions which will result in fewer memory accesses inside of the template to avoid
    # potential bad cache behavior and shared memory use.
    # we also want to avoid benchmarking reliably unprofitable fusions like downcasts from fp32 -> fp16 inside kernel.
    # allowing gathers by allowing increasing write_bytes by small factor
    # TODO - make configurable per input, for instance, bias can fuse fp32 -> fp16 profitably
    BYTES_THRESHOLD_MULTIPLIER = 1.1
    if read_bytes > (write_bytes * BYTES_THRESHOLD_MULTIPLIER):
        why("prologue fusion will not increase amount of bytes read in kernel")
        return False

    # we want to avoid attempting to fuse predictably unprofitable prologues
    # such as increasing the unaligned reads or writes.
    # TODO - would be nice to generalize this, however, we would need more explicit
    # knowledge of memory access patterns in the TritonTemplate in order to know
    # the stride order to check alignment.
    call_targets = []
    for snode in prologue_node.get_nodes():
        if snode.node is None:
            continue
        for origin in snode.node.get_origins():
            if origin.op == "call_function":
                call_targets.append(origin.target)
    origins = tuple(call_targets)

    if origins == (torch.ops.aten.constant_pad_nd.default,):
        why(
            "prologue fusion will not increase attempt to fuse in padding bc it increases unaligned reads"
        )
        return False

    # Low precision here means a floating point dtype of at most 2 bytes.
    template_dtype = template_node.get_template_node_or_throw().dtype
    template_is_low_prec_fp = (
        template_dtype.itemsize <= 2 and template_dtype.is_floating_point
    )
    if template_is_low_prec_fp and not prologue_node.can_codegen_in_low_precision():
        why(
            "prologue fusion that must be upcast to fp32 not profitable for low precision templates"
        )
        return False

    return True
def get_expand_dim_for_pointwise_nodes(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> Optional[tuple[int, SchedulerNode, sympy.Expr]]:
    """
    Fusing two small pointwise nodes significantly reduces kernel overhead
    and launch overhead. However, slightly different sizes would prevent fusion.
    Here, we decide if expanding sizes of one node is profitable by allowing
    fusion.

    Returns:
        ``None`` when expansion is not applicable; otherwise a tuple of
        (index of the single mismatching iteration dimension, the node whose
        size in that dimension is statically known to be smaller, the larger
        size that node should be expanded to).
    """
    # only support scheduler node
    if not isinstance(node1, SchedulerNode) or not isinstance(node2, SchedulerNode):
        return None

    # only support computed buffer
    if not (
        isinstance(node1.node, ir.ComputedBuffer)
        and isinstance(node2.node, ir.ComputedBuffer)
    ):
        return None

    # does not support mutation yet since relying on index mod to handle
    # out-of-boundary access.
    if node1.has_aliasing_or_mutation() or node2.has_aliasing_or_mutation():
        return None

    # skip halide which does not support mod for index
    if config.cpu_backend == "halide":
        return None

    # only support pointwise nodes with the same reduction size
    n1_sizes, n2_sizes = node1._sizes, node2._sizes
    n1_iter_sizes, n1_reduce_sizes = n1_sizes
    n2_iter_sizes, n2_reduce_sizes = n2_sizes
    if (
        node1.is_reduction()
        or node2.is_reduction()
        or n1_reduce_sizes != n2_reduce_sizes
        or len(n1_iter_sizes) != len(n2_iter_sizes)
    ):
        return None

    # only support nodes with 1 write for simplification
    if len(node1.read_writes.writes) > 1 or len(node2.read_writes.writes) > 1:
        return None

    # When memory access is small, reducing gpu kernel overhead is profitable over
    # slightly larger memory access.
    node1_write_memory = self.dep_size_hint(next(iter(node1.read_writes.writes)))
    # Each node's own write must be measured; node2 was previously measured
    # through node1's write, so its size was never checked.
    node2_write_memory = self.dep_size_hint(next(iter(node2.read_writes.writes)))
    if (
        max(node1_write_memory, node2_write_memory)
        > config.small_memory_access_threshold
    ):
        return None

    # does not support reinplace since `index % boundary` may lead to
    # race condition
    def has_reusable_buffer(node: BaseSchedulerNode) -> bool:
        for read in node.read_writes.reads:
            input_buf: Optional[Union[SchedulerBuffer, SchedulerDonatedBuffer]]
            if read.name in self.name_to_donated_buffer:
                input_buf = self.name_to_donated_buffer[read.name]
            else:
                input_buf = self.name_to_buf.get(read.name)
            if (
                input_buf
                and V.graph.wrapper_code.can_reuse(input_buf, node)
                and not isinstance(input_buf.defining_op, NopKernelSchedulerNode)
            ):
                return True
        return False

    if has_reusable_buffer(node1) or has_reusable_buffer(node2):
        return None

    # only support nodes with 1 mismatch dimension
    mismatch_dimensions = [
        idx
        for idx, (n1_size, n2_size) in enumerate(zip(n1_iter_sizes, n2_iter_sizes))
        if n1_size != n2_size
    ]
    if len(mismatch_dimensions) != 1:
        return None

    mismatch_dim = mismatch_dimensions[0]
    mismatch_size1, mismatch_size2 = (
        n1_iter_sizes[mismatch_dim],
        n2_iter_sizes[mismatch_dim],
    )
    # Only expand when the ordering of the two sizes is statically known.
    if V.graph.sizevars.statically_known_lt(mismatch_size1, mismatch_size2):
        return mismatch_dim, node1, mismatch_size2
    elif V.graph.sizevars.statically_known_lt(mismatch_size2, mismatch_size1):
        return mismatch_dim, node2, mismatch_size1
    else:
        return None
def can_fuse(
    self,
    node1: BaseSchedulerNode,
    node2: BaseSchedulerNode,
    can_reorder: bool = False,
    allow_mix_order_reduction: bool = True,
) -> bool:
    """
    Determine if it is possible to combine node1 and node2 into a
    single fused node.

    Args:
        node1: candidate that would come first in the fused node.
        node2: candidate that would come second in the fused node.
        can_reorder: when True (and ``config.loop_ordering_after_fusion`` is
            set), a low shared-data score may be recomputed after loop
            reordering.
        allow_mix_order_reduction: forwarded to the first
            ``score_fusion_memory`` call.

    Returns:
        True if the pair passes every legality and heuristic check below.
        Rejections made here are recorded through ``WhyNoFuse``.
    """
    if node1 is node2:
        return False

    if isinstance(node1, FusedMixOrderReductions):
        return node1.can_fuse_with(node2)
    if isinstance(node2, FusedMixOrderReductions):
        # We don't fuse something before a FusedMixOrderReductions
        # right now
        return False

    why = WhyNoFuse(node1, node2)

    if node1.is_template() and self.get_backend(
        node1.get_device()
    ).can_fuse_multi_outputs_template(node1, node2):
        return True

    if isinstance(node1, GroupedSchedulerNode) or isinstance(
        node2, GroupedSchedulerNode
    ):
        why("grouped node must not be fused with other nodes")
        return False
    if (
        isinstance(node1, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
        and not node1.is_template()
    ):
        why("node1 is extern or nop")
        return False
    if (
        isinstance(node2, (ExternKernelSchedulerNode, NopKernelSchedulerNode))
        and not node2.is_template()
    ):
        why("node2 is extern or nop")
        return False

    if node2.get_operation_names() & node1.ancestors:
        why("node1 must go before node2")
        return False

    if node2.is_template():
        # node1 would become a prologue of the template in node2.
        if not config.prologue_fusion:
            why("prologue fusion turned off")
            return False

        if node1.is_reduction() or node1.is_template():
            why("prologue fusion only supported for pointwise nodes")
            return False

        template = node2.get_template_node_or_throw()
        if not isinstance(template, ir.TritonTemplateBuffer):
            why("prologue fusion only supported for TritonTemplates")
            return False

        allowed_prologue_inps = template.get_allowed_prologue_inps()

        unsupported_prologue_args = (
            OrderedSet(inp.get_name() for inp in template.inputs)  # type: ignore[union-attr]
            - allowed_prologue_inps
        )

        if node1.get_buffer_names() & unsupported_prologue_args:
            why("prologue fusion not implemented for kernel for these inputs")
            return False

        # NOTE(review): this condition used to test node1 twice
        # (`node1... or node1...`). The duplicate is dropped without changing
        # behavior; confirm whether node2 was meant to be checked as well.
        if node1.has_aliasing_or_mutation():
            why("template prologue can only fuse functional pointwise nodes")
            return False

        # Every prologue node except the last may only be consumed inside
        # the prologue itself.
        prologue_nodes = node1.get_nodes()
        for node in prologue_nodes[:-1]:
            node_outs = node.get_outputs()
            for out in node_outs:
                if not all(user.node in prologue_nodes for user in out.users):
                    why("template prologue can only fuse nodes with a single use")
                    return False

        template_snodes = (
            [node2]
            if not isinstance(node2, FusedSchedulerNode)
            else [n for n in node2.snodes if n.is_template()]
        )
        assert len(template_snodes) == 1
        template_snode = template_snodes[0]

        # The last prologue node must have exactly one output, used exactly
        # once, by the template node.
        if not (
            len(prologue_nodes[-1].outputs) == 1
            and len(prologue_nodes[-1].outputs[0].users) == 1
            and prologue_nodes[-1].outputs[0].users[0].node is template_snode
        ):
            why(
                "template prologue can only fuse nodes with a single use into template"
            )
            return False

        if not self.check_prologue_fusion_heuristics_fusable(node1, node2, why):
            return False

    if node1.is_template() and (
        node2.has_aliasing_or_mutation()
        or node2.is_reduction()
        or not config.epilogue_fusion
    ):
        why("template epilogue not satisfied")
        return False

    if (node1.get_buffer_names() & V.graph.no_fuse_buffer_names) or (
        node2.get_buffer_names() & V.graph.no_fuse_buffer_names
    ):
        why("fusion for buffer explicit disabled")
        return False

    device = node1.get_device()
    device2 = node2.get_device()
    if device != device2:
        why("device mismatch (%s vs %s)", device, device2)
        return False
    del device2

    shared_data_score = self.score_fusion_memory(
        node1, node2, allow_mix_order_reduction=allow_mix_order_reduction
    )
    assert isinstance(shared_data_score, int)

    if (
        can_reorder
        and shared_data_score < config.score_fusion_memory_threshold
        and config.loop_ordering_after_fusion
    ):
        # A negative result leaves the original score in place.
        new_shared_data_score = self.shared_data_after_reordering_loop(node1, node2)
        if new_shared_data_score >= 0:
            shared_data_score = new_shared_data_score

    if config.expand_dimension_for_pointwise_nodes and (
        expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
    ):
        (expand_dim, smaller_node, expand_size) = expand_analysis
        smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
        shared_data_score = self.score_fusion_memory(node1, node2)
        assert isinstance(shared_data_score, int)

    if (
        config.loop_index_inversion_in_fusion
        and shared_data_score < config.score_fusion_memory_threshold
    ):
        new_shared_data_score = self.shared_data_after_inverting_indexing(
            node1, node2
        )
        if new_shared_data_score >= 0:
            shared_data_score = new_shared_data_score

    if loop_ordering_log.isEnabledFor(logging.DEBUG):
        loop_ordering_log.debug(
            "%s and %s has %s shared data",
            node1.get_name(),
            node2.get_name(),
            shared_data_score,
        )

    if not V.choices.can_fuse(self, node1, node2, shared_data_score):
        return False

    if node1.get_operation_names() & node2.ancestors:
        # node2 depends on node1 outputs
        return (
            self.can_fuse_vertical(node1, node2)
            and V.choices.can_fuse_vertical(self, node1, node2, shared_data_score)
            and self.get_backend(device).can_fuse_vertical(node1, node2)
        )
    else:  # nodes don't depend on each other, but may have common reads
        return V.choices.can_fuse_horizontal(
            self, node1, node2, shared_data_score
        ) and self.get_backend(device).can_fuse_horizontal(node1, node2)
def can_fuse_vertical(
    self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """
    Check if it is legal to fuse a consumer (node2) into a producer (node1).

    We can fuse them if all the reads of node2 either match
    corresponding writes in node1, or are written by nodes that can
    be scheduled before the fusion of node1 and node2.

    Returns:
        True when fusion is legal; otherwise the reason is recorded through
        ``WhyNoFuse`` and False is returned.
    """
    node1_buf_names = node1.get_buffer_names()
    why = WhyNoFuse(node1, node2)

    # Unmet dependencies of node2, grouped by (mutation-renamed) buffer name.
    # Weak deps that can be fused with node1 are dropped up front.
    remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list)
    for dep in node2.unmet_dependencies:
        name = self.mutation_renames.get(dep.name, dep.name)
        if isinstance(dep, WeakDep) and self.fusable_weak_dep(dep, node1, node2):
            continue
        remaining_deps_by_name[name].append(dep)

    # Drop every dependency that is satisfied by a matching write of node1.
    for cd in node1.read_writes.writes:
        if not isinstance(cd, MemoryDep):
            continue
        remaining = remaining_deps_by_name.get(
            self.mutation_renames.get(cd.name, cd.name)
        )
        if remaining:
            # Rebuild the list in place instead of calling remove() while
            # iterating it: removing during iteration skips the element that
            # follows each removal, which could leave a fusable dep behind.
            remaining[:] = [
                rd for rd in remaining if not self.fusable_read_and_write(rd, cd)
            ]

    remaining_deps = OrderedSet(
        dep.name
        for dep in itertools.chain.from_iterable(remaining_deps_by_name.values())
    )

    if remaining_deps & node1_buf_names:
        # MemoryDeps didn't match and read different locations of the same buffer.
        # Examples here include:
        #   - MemoryDep("foo", x) != MemoryDep("foo", x + 1)
        #   - MemoryDep("foo", x) != StarDep("foo")
        why("memory deps did not match")
        return False

    # Any other producer node2 still depends on must not itself descend
    # from node1, otherwise it would sit between the two fused nodes.
    node1_op_names = node1.get_operation_names()
    for name in remaining_deps:
        op_name = self.name_to_buf[name].defining_op_name()
        if node1_op_names & self.name_to_fused_node[op_name].ancestors:
            why("intermediate nodes between node1 & node2")
            return False

    return True
def fusable_weak_dep(
    self, weak_dep: WeakDep, node1: BaseSchedulerNode, node2: BaseSchedulerNode
) -> bool:
    """
    Decide whether ``weak_dep`` (a dependency of node2) may be ignored when
    fusing node2 into node1.

    A weak dep can be fused if and only if the fused operation acts inplace
    on the buffer being mutated, i.e. the same index is being read then
    mutated.
    """
    if weak_dep.name not in node1.get_buffer_names():
        return False

    # node2 must write the mutated buffer exactly once ...
    candidate_writes = [
        w for w in node2.read_writes.writes if w.name == weak_dep.mutating_buf
    ]
    if len(candidate_writes) != 1:
        return False

    # ... through a MemoryDep whose index has no TMP symbols.
    (write,) = candidate_writes
    if isinstance(write, StarDep):
        return False
    assert isinstance(write, MemoryDep)
    if free_symbol_is_type(write.index, SymT.TMP):
        return False

    real_name = self.mutation_real_name[weak_dep.mutating_buf]

    if isinstance(node1, ForeachKernelSchedulerNode):
        readers = node1.snodes
    else:
        readers = [node1]

    readers_touching_buf = 0
    for reader in readers:
        reads_of_buf = [r for r in reader.read_writes.reads if r.name == real_name]
        if not reads_of_buf:
            continue
        readers_touching_buf += 1
        # Every read of the buffer must hit exactly the index/size written.
        for r in reads_of_buf:
            if not (
                isinstance(r, MemoryDep)
                and not free_symbol_is_type(r.index, SymT.TMP)
                and r.index == write.index
                and r.size == write.size
            ):
                return False

    # At most one of the reading nodes may touch the buffer.
    return readers_touching_buf <= 1
- # StarDep doesn't match MemoryDep, different indices don't match
- # However, broadcasting sometimes strips dimensions, and if that's the case
- # we still can match unmet dep
- # if there's indirect indexing, don't match it
def fusable_read_and_write(self, read: Dep, write: MemoryDep) -> bool:
    """
    Report whether ``read`` is satisfied by ``write`` so that both can live
    in the same fused kernel.

    MemoryDep reads need the same (renamed) buffer, no TMP symbols in either
    index, a write mode that needs no synchronization, equal indices, and
    the write sizes as a prefix of the read sizes. StarDep reads only match
    when both modes are equal and not None and the renamed names agree.
    Any other kind of dep never matches.
    """
    if isinstance(read, MemoryDep):
        renamed_read = self.mutation_renames.get(read.name, read.name)
        if renamed_read != write.name:
            return False
        if free_symbol_is_type(read.index, SymT.TMP) or free_symbol_is_type(
            write.index, SymT.TMP
        ):
            return False

        if config.loop_ordering_after_fusion and read.num_vars != write.num_vars:
            # Need merge loops if we do loop ordering after fusion since
            # we have not merged the loops yet when creating the scheduler
            # nodes.
            read = read.normalize()
            write = write.normalize()

        # Operations like index_add_, scatter_add_, etc. require global
        # synchronization - all threads must complete writes before any reads.
        # These cannot be safely fused into the same kernel. Atomic modes and
        # TMA stores require synchronization barriers.
        if self.mode_requires_synchronization(write.mode):
            return False

        prefix_len = len(write.size)
        return (
            read.index == write.index
            and len(read.size) >= prefix_len
            and read.size[:prefix_len] == write.size
        )

    if isinstance(read, StarDep):
        renamed_read = self.mutation_renames.get(read.name, read.name)
        renamed_write = self.mutation_renames.get(write.name, write.name)
        return bool(
            read.mode == write.mode
            and write.mode is not None
            and renamed_read == renamed_write
        )

    return False
def dep_size_hint(self, dep: Dep, count_bytes: bool = True) -> int:
    """Forward to ``V.graph.get_dep_size_hint`` for ``dep`` and ``count_bytes``."""
    graph = V.graph
    return graph.get_dep_size_hint(dep, count_bytes)
def score_fusion_memory(
    self,
    node1: BaseSchedulerNode,
    node2: BaseSchedulerNode,
    count_bytes: bool = True,
    return_is_mix_order_reduction: bool = False,
    allow_mix_order_reduction: bool = True,
) -> int | tuple[int, bool]:
    """
    The first term in our fusion score that estimates number of saved
    memory operations.

    Args:
        node1, node2: the candidate pair.
        count_bytes: forwarded to ``dep_size_hint`` for every shared dep.
        return_is_mix_order_reduction: when True, return
            ``(score, is_mix_order_reduction)`` instead of just ``score``.
        allow_mix_order_reduction: when False, the mix-order-reduction
            scoring path is skipped.

    Returns:
        The sum of size hints of the deps shared by both nodes (or the
        mix-order-reduction score), optionally paired with a flag telling
        which path produced it.
    """

    def _construct_return_value(score, is_mix_order_reduction):
        return (
            (score, is_mix_order_reduction)
            if return_is_mix_order_reduction
            else score
        )

    if allow_mix_order_reduction and MixOrderReduction.can_fuse(node1, node2):
        # The fusion score for mix order reduction only count
        # numel so far. It's actually fine. This makes other fusions
        # sharing the same amount of numels go first; but make
        # fusions only share weight/bias go later.
        score = MixOrderReduction.get_fusion_score(node1, node2)
        return _construct_return_value(score, True)

    node1_dep_len = len(node1.read_writes.reads) + len(node1.read_writes.writes)
    node2_dep_len = len(node2.read_writes.reads) + len(node2.read_writes.writes)

    # optimization: iter over smaller set
    if min(node1_dep_len, node2_dep_len) * 4 < max(node1_dep_len, node2_dep_len):
        if node1_dep_len > node2_dep_len:
            node1, node2 = node2, node1

        deps = [
            dep
            for dep in node1.read_writes.reads | node1.read_writes.writes
            if dep in node2.read_writes.reads or dep in node2.read_writes.writes
        ]

        return _construct_return_value(
            sum(self.dep_size_hint(dep, count_bytes) for dep in deps), False
        )

    common_memory_deps = (node1.read_writes.reads | node1.read_writes.writes) & (
        node2.read_writes.reads | node2.read_writes.writes
    )
    # count_bytes must be honoured here as well, so that both code paths
    # score in the same unit.
    return _construct_return_value(
        sum(self.dep_size_hint(dep, count_bytes) for dep in common_memory_deps),
        False,
    )
def get_possible_fusions_with_highest_priority(
    self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
    """
    Keep only the candidate pairs in the best priority class.

    Each pair is bucketed by the integer priority reported by its device
    backend; the bucket with the numerically smallest priority is returned.
    An empty input is returned unchanged.
    """
    if not possible_fusions:
        return possible_fusions

    by_priority: dict[int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]]] = {}
    for node1, node2 in possible_fusions:
        assert node1.get_device() == node2.get_device()
        device = node1.get_device()
        priority = int(
            self.get_backend(device).get_fusion_pair_priority(node1, node2)
        )
        by_priority.setdefault(priority, []).append((node1, node2))

    # The smallest key is treated as the highest priority.
    best_bucket = by_priority[min(by_priority)]
    assert len(best_bucket) > 0
    return best_bucket
def score_fusion_key(
    self, nodes: tuple[BaseSchedulerNode, BaseSchedulerNode]
) -> Any:
    """
    Sort-key adapter (for ``list.sort(key=...)``): unpacks the pair and
    delegates to ``V.choices.score_fusion``.
    """
    first, second = nodes
    return V.choices.score_fusion(self, first, second)
def compute_last_usage(self) -> None:
    """
    Populate node.last_usage recursively (also for the nodes within a
    FusedSchedulerNode).

    Nodes are visited from last to first, seeded with the graph output
    names; each node's ``last_usage`` is folded into the running set before
    the preceding node is handled.
    """
    used_later = OrderedSet(V.graph.get_output_names())
    for snode in reversed(self.nodes):
        snode.set_last_usage(used_later, self.mutation_real_name)
        used_later.update(snode.last_usage)
def free_buffers(self) -> None:
    """
    Free any buffers that are no longer needed.

    Names queued in ``self.buffer_names_to_free`` are handled in sorted
    order, skipping ones already removed or already freed; the queue is
    cleared afterwards.
    """
    pending = (
        self.buffer_names_to_free
        - V.graph.removed_buffers
        - V.graph.wrapper_code.freed  # type: ignore[has-type]
    )
    for name in sorted(pending):
        if name in self.name_to_buf:
            buf = self.name_to_buf[name]
            if buf.can_free():
                V.graph.wrapper_code.codegen_free(buf.node)
            continue

        if name not in V.graph.graph_inputs:
            continue

        inp = V.graph.graph_inputs[name]
        if isinstance(inp, ir.TorchBindObject):
            V.graph.wrapper_code.codegen_free(inp)
            continue
        if isinstance(inp, ir.GeneratorState):
            # nothing is freed for generator-state inputs
            continue

        storage = inp.data
        assert isinstance(storage, ir.StorageBox) and storage.is_input_buffer()
        V.graph.wrapper_code.codegen_free(storage.data)

    self.buffer_names_to_free.clear()
def flush(self) -> None:
    """Flush every registered backend, then free buffers no longer needed."""
    backends = self.backends
    for device in backends:
        backends[device].flush()
    self.free_buffers()
def codegen_extern_call(self, scheduler_node: ExternKernelSchedulerNode) -> None:
    """
    Emit wrapper code for an extern kernel node and then free buffers that
    are no longer needed.
    """
    assert isinstance(scheduler_node, ExternKernelSchedulerNode)
    counters["inductor"]["extern_calls"] += 1

    # 'decide_inplace_update' stores the inplace update decisions in
    # the current kernel from where 'allocate' retrieve those decisions.
    # We have to make sure there is a non-NULL kernel handler to store
    # those inplace update decisions.
    placeholder_kernel = Kernel(increase_kernel_count=False)
    with V.set_kernel_handler(placeholder_kernel):
        scheduler_node.decide_inplace_update()
        scheduler_node.mark_run()

    node = scheduler_node.node
    assert isinstance(node, ir.ExternKernel), f"{type(node)=}"
    node.codegen(V.graph.wrapper_code)

    self.free_buffers()
def create_backend(self, device: torch.device) -> BaseScheduling:
    """
    Build the scheduling backend for ``device``.

    Raises:
        RuntimeError: no scheduling is registered for the device type.
        GPUTooOldForTriton: Triton is unavailable and the CUDA device has a
            major compute capability below 7.
        TritonMissing: Triton is unavailable on any other non-MPS GPU.
    """
    assert not is_gpu(device.type) or device.index is not None, (
        f"{device} should have been normalized in lowering"
    )
    V.graph.add_device_info(device)

    scheduling_cls = get_scheduling_for_device(device.type)
    if scheduling_cls is None:
        raise RuntimeError(f"Unsupported device type: {device.type}")

    if not has_triton():
        if device.type == "cuda":
            device_props = torch.cuda.get_device_properties(device)
            if device_props.major < 7:
                raise GPUTooOldForTriton(device_props, inspect.currentframe())
        if is_gpu(device.type) and device.type != "mps":
            raise TritonMissing(inspect.currentframe())

    return scheduling_cls(self)
def get_backend(self, device: Optional[torch.device]) -> BaseScheduling:
    """
    Return the backend cached for ``device``, creating and caching it via
    ``create_backend`` on first use. ``device`` must not be None.
    """
    assert device is not None
    if device in self.backends:
        return self.backends[device]
    self.backends[device] = self.create_backend(device)
    return self.backends[device]
def enter_context(self, node: BaseSchedulerNode) -> None:
    """
    Tell the wrapper codegen which origin fx node is being entered.

    Among all origins of the scheduler nodes in ``node``, the one with the
    largest index in its fx graph is passed to
    ``V.graph.wrapper_code.enter_context``. Nothing happens when there are
    no origins.
    """

    def get_order(fx_node: torch.fx.Node) -> int:
        # Index the whole owning graph the first time one of its nodes is seen.
        if fx_node not in self.origin_to_index:
            self.origin_to_index.update(
                {member: pos for pos, member in enumerate(fx_node.graph.nodes)}
            )
        return self.origin_to_index[fx_node]

    # Use a dict to have ordering
    ordered_origins: dict[tuple[int, torch.fx.Node], None] = {}
    for snode in node.get_nodes():
        if snode.node is None:
            continue
        for origin in snode.node.get_origins():
            ordered_origins[(get_order(origin), origin)] = None

    if not ordered_origins:
        return

    _, last = max(ordered_origins, key=operator.itemgetter(0))
    V.graph.wrapper_code.enter_context(last)
def can_buffer_be_removed_through_fusion(
    self, name: str, fused_node_names: OrderedSet[str]
) -> bool:
    """
    Report whether buffer ``name`` becomes unnecessary once the nodes in
    ``fused_node_names`` are fused.

    That requires the buffer to be known, every user to be weak or part of
    the fused set, and the name to take no part in mutation renaming.
    """
    try:
        users = self.name_to_buf[name].users
    except KeyError:
        return False

    for user in users:
        if not (user.is_weak or user.get_name() in fused_node_names):
            return False

    return name not in self.mutation_renames and name not in self.mutation_real_name
def should_partition(self, node: BaseSchedulerNode) -> Optional[str]:
    """
    Return the reason why we should partition the inductor graph on this node,
    or None if the node is cudagraphable.
    """
    # Allow users to manually specify if a node should be partitioned.
    # Can only do this for FallbackKernels.
    ir_node = node.node
    if isinstance(ir_node, torch._inductor.ir.FallbackKernel) and (
        op := ir_node.op_overload
    ):
        op_overload_packet_name, op_overload_name = get_op_names(op)
        custom_ops = config.custom_should_partition_ops
        if op_overload_packet_name in custom_ops or op_overload_name in custom_ops:
            assert isinstance(op, torch._ops.OpOverload)
            return f"custom partition op: {op_overload_name}"

    # When not using cudagraphs, keep all kernels in the `call` function
    # instead of graph partition functions, since graph partition only brings
    # benefit to cudagraph
    if (
        not torch._inductor.config.triton.cudagraphs
        and _unstable_customized_partition_wrapper.wrapper is None
    ):
        return "partition includes all ops when cudagraphs is disabled"

    # A fused node is partitioned as soon as any of its sub-nodes is.
    if isinstance(node, FusedSchedulerNode):
        for snode in node.snodes:
            if reason := self.should_partition(snode):
                return reason
        return None

    assert node.node is not None

    if not node.is_gpu():
        return f"{node.get_device()} ops"

    if isinstance(node.node, ir.DeviceCopy):
        return "DeviceCopy ops"

    if isinstance(node.node, ir.Conditional):
        return "Conditional ops"

    if getattr(node.node, "unbacked_bindings", None):
        return "unbacked binding ops"

    if is_cudagraph_unsafe_op(node.node):
        return "CUDAGraph-unsafe custom ops"

    if reason := self._uses_cudagraph_unsafe_unbacked_symint(node):
        return reason

    # Partition around nodes with dynamic shapes when
    # cudagraph_skip_dynamic_graphs is enabled
    if config.triton.cudagraph_skip_dynamic_graphs and get_scheduler_node_symbol_uses(
        node
    ):
        return "dynamic shape ops"

    return None
@cache_on_self
def _get_cudagraph_unsafe_unbacked_symints(self) -> OrderedSet[sympy.Symbol]:
    """
    Collect output unbacked symints from ops in config.cudagraph_unsafe_unbacked_ops.

    Only FallbackKernel nodes whose overload (or overload packet) name is
    listed in that config are inspected.
    """
    collected: OrderedSet[sympy.Symbol] = OrderedSet()
    unsafe_ops = config.cudagraph_unsafe_unbacked_ops
    if not unsafe_ops:
        return collected

    for snode in self.nodes:
        ir_node = snode.node
        if ir_node is None or not isinstance(
            ir_node, torch._inductor.ir.FallbackKernel
        ):
            continue

        op = ir_node.op_overload
        if op is None:
            continue

        packet_name, overload_name = get_op_names(op)
        if packet_name not in unsafe_ops and overload_name not in unsafe_ops:
            continue

        for sym in ir_node.get_unbacked_symbol_defs():
            simplified = V.graph.sizevars.simplify(sym)
            # keep only results that are still unbacked symbols
            if symbol_is_type(simplified, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)):
                collected.add(simplified)

    return collected
- def _uses_cudagraph_unsafe_unbacked_symint(
- self, node: BaseSchedulerNode
- ) -> Optional[str]:
- unsafe_symints = self._get_cudagraph_unsafe_unbacked_symints()
- if not unsafe_symints:
- return None
- node_symbols = get_scheduler_node_symbol_uses(node)
- for sym in node_symbols:
- simplified_sym = V.graph.sizevars.simplify(sym)
- for free_sym in simplified_sym.free_symbols:
- if free_sym in unsafe_symints:
- return f"uses cudagraph-unsafe unbacked symint: {free_sym}"
- return None
def get_name_to_nodes(
    self,
) -> dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]]:
    """
    Return a mapping from name strings to the corresponding graph inputs or
    base scheduler node outputs.

    Graph inputs are inserted first; scheduler node outputs follow and
    overwrite any entry that shares a name.
    """
    name_to_node: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]] = {}
    name_to_node.update(V.graph.graph_inputs)
    for snode in self.nodes:
        name_to_node.update(
            (buf_name, sched_buf.node)
            for buf_name, sched_buf in snode.outputs_by_name.items()
        )
    return name_to_node
def compute_graph_partition_maps(
    self,
    signatures: list[GraphPartitionSignature],
) -> None:
    """
    Compute, for each partition, a mapping from partition input/output
    indices to graph input/output indices, and store the result in
    ``V.graph.partition_maps``.

    A partition input/output that is not a graph input/output maps to None.
    """
    graph_input_index = {
        name: idx for idx, name in enumerate(V.graph.graph_inputs)
    }
    graph_output_index = {
        name: idx for idx, name in enumerate(V.graph.get_output_names())
    }

    V.graph.partition_maps = []
    for partition_id, signature in enumerate(signatures):
        if signature.skip_cudagraph:
            # Note: [Graph Partition Map for CUDAGraph]
            # number of partition map should be the same as the number of generated
            # partition functions. This assumption will be used when cudagraphify
            # each partition function.
            continue

        input_mapping = [
            graph_input_index.get(name) for name in signature.input_nodes
        ]
        output_mapping = [
            graph_output_index.get(out.get_name()) for out in signature.output_nodes
        ]

        V.graph.partition_maps.append(
            GraphPartitionMap(
                partition_id,
                input_mapping,
                output_mapping,
                signature.constant_names,
            )
        )
def get_graph_partition_symbol_inputs(
    self,
    partition: PartitionType,
    input_nodes: dict[str, Union[ir.IRNode, ir.TorchBindObject, sympy.Expr]],
) -> OrderedSet[sympy.Symbol]:
    """
    Returns all symbol inputs which are required to be in scope to successfully
    perform codegen for this graph partition, including:
    - free symbols used in partition nodes
    - free symbols in partition input/node shapes, strides, and offsets. This is needed
    for recording cudagraphs for tensors with dynamic shapes.

    The result is sorted by symbol name.
    """

    def get_input_node_symbols(
        node: Union[ir.IRNode, sympy.Expr, ir.TorchBindObject],
    ) -> OrderedSet[sympy.Symbol]:
        """
        Gets symbols used in input node shapes, strides, and offsets.
        """
        if isinstance(node, ir.TorchBindObject):
            # TorchBindObject does not involve dynamic shapes yet
            return OrderedSet()
        if isinstance(node, ir.IRNode):
            return get_layout_symints(node)
        # node cannot be sympy.Expr since node comes from read_writes and
        # read_writes does not contain sympy.Expr
        raise NotImplementedError(f"Unsupported input node type: {type(node)}")

    # Symbol kinds required for codegen. Symbols that are always internal to
    # kernels, such as SymT.TMP, SymT.INDEX, and SymT.R0_INDEX, are skipped.
    codegen_symbol_kinds = (
        SymT.SIZE,
        SymT.FLOAT,
        SymT.UNBACKED_INT,
        SymT.UNBACKED_FLOAT,
    )

    gathered: OrderedSet[sympy.Symbol] = OrderedSet()
    for snode in partition:
        gathered.update(get_scheduler_node_symbol_uses(snode))
    for input_node in input_nodes.values():
        gathered.update(get_input_node_symbols(input_node))

    candidate_symbols = OrderedSet(
        s for s in gathered if symbol_is_type(s, codegen_symbol_kinds)
    )

    res: OrderedSet[sympy.Symbol] = OrderedSet()
    for s in candidate_symbols:
        simplified_s = V.graph.sizevars.simplify(s)
        # use free_symbols only when s is simplified to an Integer or expr
        res.update(simplified_s.free_symbols)

    return OrderedSet(sorted(res, key=operator.attrgetter("name")))
def get_graph_partition_signature(
    self, partitions: list[PartitionType], skip_cudagraphs: list[bool]
) -> list[GraphPartitionSignature]:
    """
    Gets signature for each graph partition, including input nodes, output nodes, and
    whether deallocating an input within graph partition.

    Partitions are processed from last to first so that the inputs of later
    partitions become required outputs of earlier ones; the returned list is
    in the original partition order.
    """
    signatures = []

    unmet_output_names = OrderedSet(V.graph.get_output_names())
    name_to_node = self.get_name_to_nodes()

    def is_unallocated_buffer(buf_name: str) -> bool:
        """
        Checks if buf_name resolves to a NoneLayout buffer (following mutation_real_name).
        Buffers with NoneLayout are not allocated so graph partition should not
        take them as inputs or outputs.
        """
        current = buf_name
        while True:
            buf = self.name_to_buf.get(current, None)
            if buf is None or not isinstance(buf.node.layout, NoneLayout):
                return False
            # If there's a mutation real name, check the underlying buffer.
            # This handles both MutationOutput and other mutation ops like
            # IndexPutFallback that have NoneLayout but mutate real buffers.
            real_name = self.mutation_real_name.get(current, None)
            if not real_name:
                return True
            current = real_name

    for partition, skip_cudagraph in zip(
        reversed(partitions), reversed(skip_cudagraphs)
    ):
        output_names: OrderedSet[str] = OrderedSet()
        for snode in partition:
            output_names.update(snode.outputs_by_name.keys())

        returned_output_names = output_names.intersection(unmet_output_names)

        # all reads/writes are partition inputs except those generated
        # within the partition and tensor constants
        read_writes = dependencies.ReadWrites.merge_list(
            [snode.read_writes for snode in partition]
        )

        # WeakDep is fake dependency on unused buffer. It should not appear
        # in partition_input_names for inputs that are actually read or written.
        accessed_names = OrderedSet(
            [
                dep.name
                for dep in read_writes.reads | read_writes.writes
                if not isinstance(dep, WeakDep)
            ]
        )
        partition_input_names = OrderedSet(
            self.mutation_real_name.get(name, name)
            for name in accessed_names - output_names
        )

        buffer_names_to_free: OrderedSet[str] = OrderedSet()
        for snode in partition:
            buffer_names_to_free.update(snode.last_usage)

        # buffer_names_to_free may contain buffers allocated in previous
        # graph partitions. These buffers should also be a partition
        # input.
        partition_input_names.update(
            [
                name
                for name in (buffer_names_to_free - output_names)
                if name in name_to_node
            ]
        )

        # if an input tensor is not freed in the partition function, it should
        # also be returned as an output. This brings benefits to cudagraph
        # since the returned output tensor is a cudagraph managed tensor with
        # a static tensor address.
        input_nodes = {}
        input_deallocation = {}
        extra_output_names = []
        for name in partition_input_names:
            if name not in name_to_node:
                continue
            freed_here = name in buffer_names_to_free
            input_nodes[name] = name_to_node[name]
            input_deallocation[name] = freed_here
            if not freed_here:
                extra_output_names.append(name)

        returned_output_names.update(extra_output_names)

        returned_output_names = OrderedSet(
            self.mutation_real_name.get(name, name)
            for name in returned_output_names
        )

        output_nodes = [
            name_to_node[name]
            for name in returned_output_names
            if not is_unallocated_buffer(name)
        ]

        constant_names = [
            name for name in partition_input_names if name in V.graph.constants
        ]

        symbol_inputs = self.get_graph_partition_symbol_inputs(
            partition, input_nodes
        )

        signatures.append(
            GraphPartitionSignature(
                symbol_inputs,
                input_nodes,
                output_nodes,
                input_deallocation,
                skip_cudagraph,
                constant_names,
            )
        )

        # Inputs of this partition must be produced by earlier partitions.
        unmet_output_names = partition_input_names.union(
            unmet_output_names - returned_output_names
        )

    return list(reversed(signatures))
def clean_removed_buffer_from_partition_signatures(
    self, signature: GraphPartitionSignature
) -> GraphPartitionSignature:
    """
    Build a new partition signature that no longer mentions any buffer
    recorded in V.graph.removed_buffers.

    Inputs, deallocation flags, outputs and constant names are filtered;
    the symbol inputs and the skip_cudagraph flag are carried over as-is.
    See [Note: Removed Graph Partition Arguments]
    """

    def is_kept(name) -> bool:
        # Looked up on every query, exactly like the inline checks it replaces.
        return name not in V.graph.removed_buffers

    kept_inputs = {}
    for name, buffer in signature.input_nodes.items():
        if is_kept(name):
            kept_inputs[name] = buffer

    kept_deallocation = {}
    for name, val in signature.input_deallocation.items():
        if is_kept(name):
            kept_deallocation[name] = val

    kept_outputs = [
        node for node in signature.output_nodes if is_kept(node.maybe_get_name())
    ]
    kept_constants = [name for name in signature.constant_names if is_kept(name)]

    return GraphPartitionSignature(
        signature.symbol_inputs,
        kept_inputs,
        kept_outputs,
        kept_deallocation,
        signature.skip_cudagraph,
        kept_constants,
    )
def reorder_for_minimizing_partition(
    self,
    nodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
    """
    Reorder nodes to minimize the number of partitions via a bfs
    topological sort. This is the optimal reordering such that the
    number of partitions cannot be reduced further. This may be
    sub-optimal for other metrics such as peak memory. This does not
    change relative orders of two cudagraphable nodes, nor the
    relative order of two non_cudagraphable nodes.

    Dependencies are read from ``node.mpi_node.pred_nodes`` and
    ``node.mpi_node.succ_nodes``.

    Raises:
        RuntimeError: if some node could not be scheduled, i.e. its
            indegree never dropped to zero (for example because of a
            dependency cycle).
    """
    import heapq

    node_to_indegree: dict[BaseSchedulerNode, int] = {}
    # Both heaps are keyed by the node's original position so that nodes of
    # the same kind keep their relative order. Positions are unique, so the
    # node objects themselves are never compared.
    cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = []
    non_cudagraphable_nodes: list[tuple[int, BaseSchedulerNode]] = []
    node_to_index = {node: idx for idx, node in enumerate(nodes)}

    def insert_pending_nodes(node: BaseSchedulerNode) -> None:
        node_with_index = (node_to_index[node], node)
        if self.should_partition(node):
            heapq.heappush(non_cudagraphable_nodes, node_with_index)
        else:
            heapq.heappush(cudagraphable_nodes, node_with_index)

    def update_indegree(node: BaseSchedulerNode) -> None:
        for succ_node in node.mpi_node.succ_nodes:
            assert node_to_indegree[succ_node] > 0
            node_to_indegree[succ_node] -= 1
            if node_to_indegree[succ_node] == 0:
                insert_pending_nodes(succ_node)

    for node in nodes:
        node_to_indegree[node] = len(node.mpi_node.pred_nodes)
        if node_to_indegree[node] == 0:
            insert_pending_nodes(node)

    schedule: list[BaseSchedulerNode] = []
    # Every pass drains all ready non-cudagraphable nodes, then all ready
    # cudagraphable nodes. A node becomes ready only once, so the loop ends
    # after at most len(nodes) pops.
    while non_cudagraphable_nodes or cudagraphable_nodes:
        while non_cudagraphable_nodes:
            _, node = heapq.heappop(non_cudagraphable_nodes)
            schedule.append(node)
            update_indegree(node)

        while cudagraphable_nodes:
            _, node = heapq.heappop(cudagraphable_nodes)
            schedule.append(node)
            update_indegree(node)

    # The previous guard compared an iteration counter against len(nodes)
    # and could never fire; detect the real failure mode instead, which is
    # nodes that never became ready.
    if len(schedule) != len(nodes):
        raise RuntimeError(
            "Failed to schedule all nodes when reordering for minimizing "
            f"the num of partitions: scheduled {len(schedule)} of "
            f"{len(nodes)} nodes"
        )

    return schedule
def maybe_reorder_for_minimizing_partition(
    self,
    nodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
    """
    Try the partition-minimizing order and keep it only when its estimated
    peak memory stays within the budget derived from the current order;
    otherwise hand back ``nodes`` unchanged.
    """
    from .memory import estimate_peak_memory, prepare_planning_info

    graph_outputs = OrderedSet(V.graph.get_output_names())
    graph_inputs = OrderedSet(V.graph.graph_inputs.keys())

    default_peak_memory, name_to_freeable_input_buf = prepare_planning_info(
        nodes,
        self.name_to_buf,
        self.name_to_fused_node,
        graph_inputs,
        graph_outputs,
    )

    candidate = self.reorder_for_minimizing_partition(nodes)
    reorder_peak_memory, _ = estimate_peak_memory(
        candidate, name_to_freeable_input_buf, graph_outputs
    )

    # 1.1 here means 10% extra peak memory budget which is quite arbitrary
    memory_budget = default_peak_memory * 1.1
    return candidate if reorder_peak_memory < memory_budget else nodes
def reorder_for_partition_with_simple_dependency(
    self, nodes: list[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
    """
    Move partitioned nodes with trivial dependencies to the edges of the
    schedule while keeping every other node in place:

    1. a partitioned node without unmet dependencies goes to the front
    2. a partitioned node whose outputs are only used by OutputNode goes
       to the back
    3. everything else keeps its position in the middle

    Relative order inside each of the three groups is preserved.
    """
    front: list[BaseSchedulerNode] = []
    middle: list[BaseSchedulerNode] = []
    back: list[BaseSchedulerNode] = []

    def only_output_user(node: BaseSchedulerNode) -> bool:
        # Vacuously true when the node has no outputs or no users.
        return all(
            isinstance(use.node, OutputNode)
            for buf in node.get_outputs()
            for use in buf.users
        )

    for node in nodes:
        if self.should_partition(node) is None:
            middle.append(node)
        elif len(node.unmet_dependencies) == 0:
            front.append(node)
        elif only_output_user(node):
            back.append(node)
        else:
            middle.append(node)

    return front + middle + back
def graph_partition(
    self,
) -> tuple[list[PartitionType], list[GraphPartitionSignature]]:
    """
    Split self.nodes into runs of consecutive nodes that agree on whether
    self.should_partition() reports a reason, then compute the input/output
    signature of every run.

    Returns the partitions together with their signatures. As side effects,
    the partition maps are computed from the signatures and the result is
    passed to the debug logger.
    """
    partitions: list[PartitionType] = []
    skip_cudagraphs = []

    for node in self.nodes:
        node_should_partition = self.should_partition(node) is not None
        # Open a new partition for the very first node and whenever the
        # cudagraph-skip flag flips relative to the partition being built.
        if not partitions or skip_cudagraphs[-1] != node_should_partition:
            partitions.append([])
            skip_cudagraphs.append(node_should_partition)
        partitions[-1].append(node)

    signatures = self.get_graph_partition_signature(
        partitions=partitions, skip_cudagraphs=skip_cudagraphs
    )
    self.compute_graph_partition_maps(signatures)
    self._log_graph_partitions(partitions, signatures)

    return partitions, signatures
def _log_graph_partitions(
    self,
    partitions: list[PartitionType],
    signatures: list[GraphPartitionSignature],
) -> None:
    """
    Emit a debug-level summary of the partitioning: overall counts, one
    line per partition, and per-node details for the partitions that skip
    cudagraph. Does nothing unless debug logging is enabled on
    cudagraphs_log and the graph uses at least one GPU device.
    """
    if not cudagraphs_log.isEnabledFor(logging.DEBUG):
        return

    # Don't log partition reasons for CPU-only graphs since cudagraph
    # partitioning is not relevant when there are no GPU devices
    if not any(is_gpu(device) for device in V.graph.device_types):
        return

    non_cudagraphable_count = sum(1 for s in signatures if s.skip_cudagraph)
    cudagraphable_count = len(signatures) - non_cudagraphable_count
    cudagraphs_log.debug(
        "Created %d graph partitions: %d cudagraphable, %d non-cudagraphable",
        len(partitions),
        cudagraphable_count,
        non_cudagraphable_count,
    )

    for index, (partition, signature) in enumerate(zip(partitions, signatures)):
        if signature.skip_cudagraph:
            kind = "non-cudagraphable"
        else:
            kind = "cudagraphable"
        cudagraphs_log.debug(
            " Partition %d: %d nodes, %s, inputs=%d, outputs=%d",
            index,
            len(partition),
            kind,
            len(signature.input_nodes),
            len(signature.output_nodes),
        )
        if not signature.skip_cudagraph:
            continue
        # Log details for each non-cudagraphable node
        for node in partition:
            self._log_non_cudagraphable_node(node)
- def _log_non_cudagraphable_node(self, node: BaseSchedulerNode) -> None:
- """Log details for a non-cudagraphable node."""
- reason = self.should_partition(node)
- if not reason:
- return
- node_name = node.get_name()
- fx_node = node.node.get_origin_node() if node.node is not None else None
- parts = [f"reason={reason}"]
- ir_type = type(node.node).__name__
- parts.append(f"ir={ir_type}")
- if fx_node is not None:
- fx_str = f"{fx_node.target}({', '.join(str(a) for a in fx_node.args)})"
- parts.append(f"fx={fx_str}")
- cudagraphs_log.debug(" %s: %s", node_name, ", ".join(parts))
- # Log full stack trace if available
- if fx_node is not None:
- stack_trace = fx_node.meta.get("stack_trace", None)
- if stack_trace:
- for line in stack_trace.strip().split("\n"):
- cudagraphs_log.debug(" %s", line)
def codegen(self) -> None:
    """
    Entry point for code generation, timed under "Scheduler.codegen".

    Generates per-partition code when torch._inductor.config.graph_partition
    is set; otherwise generates code for self.nodes in one pass.
    """
    with dynamo_timed("Scheduler.codegen"):
        if torch._inductor.config.graph_partition:
            return self._codegen_partitions()
        return self._codegen(self.nodes)
def _codegen_partition_wrapper(
    self,
    partition: PartitionType,
    signature: GraphPartitionSignature,
) -> None:
    """
    Generate one partition as its own subgraph function and emit the call
    to it.

    The partition body is generated under a temporary subgraph wrapper; the
    resulting code is then registered as a launcher function, called with
    the partition signature, and its outputs are marked as allocated.
    """
    from .codegen.wrapper import SubgraphPythonWrapperCodegen

    parent_wrapper_code = V.graph.wrapper_code
    graph_partition_id = next(self._graph_partition_counter)
    subgraph_name = f"partition_{graph_partition_id}"

    with V.graph.set_current_wrapper_code():
        V.graph.init_wrapper_code(
            is_subgraph=True,
            subgraph_name=subgraph_name,
            parent_wrapper_code=parent_wrapper_code,
            partition_signatures=signature,
        )
        self._codegen(partition)

        # Note: [Removed Graph Partition Arguments]
        # Graph partition relies on node.read_writes to analyze the partition
        # inputs and outputs. However, during codegen, we may decide some buffers
        # are internal to a kernel (e.g., triton kernel) such that these buffers
        # are never actually defined. This information is collected during codegen
        # and recorded in V.graph.removed_buffers. So we cleanup signature and write
        # prefix (i.e., generating call function and return outputs) after we have
        # codegen the partition.
        assert isinstance(V.graph.wrapper_code, SubgraphPythonWrapperCodegen)
        signature = self.clean_removed_buffer_from_partition_signatures(signature)
        V.graph.wrapper_code.partition_signatures = signature
        V.graph.wrapper_code.write_prefix()

        graph_name = V.graph.name
        partition_code, _ = V.graph.wrapper_code.generate(V.graph.is_inference)

    V.graph.wrapper_code.define_subgraph_launcher_fn(graph_name, partition_code)
    V.graph.wrapper_code.codegen_partition_call(graph_partition_id, signature)

    output_names = [node.get_name() for node in signature.output_nodes]
    V.graph.wrapper_code.allocated.update(  # type: ignore[has-type]
        output_names
    )
def use_default_device_context(
    self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature]
) -> contextlib.AbstractContextManager[None]:
    """
    Return a context manager that wraps partition codegen in one shared
    device guard when that is possible.

    On entry the default device is (re)computed from the partitions, and a
    device-guard enter is emitted if that device needs one. On exit the
    matching guard exit is emitted and self.default_device_context is reset
    to None.
    """

    def shared_guard_needed() -> bool:
        # Re-evaluated on entry and on exit against the current value.
        default_device = self.default_device_context
        return bool(default_device and device_need_guard(default_device.type))

    @contextlib.contextmanager
    def ctx() -> Iterator[None]:
        self.update_graph_partition_default_device(partitions, signatures)

        if shared_guard_needed():
            assert self.default_device_context.index is not None, (
                "device should have an index"
            )
            V.graph.wrapper_code.codegen_device_guard_enter(
                self.default_device_context.index
            )

        try:
            yield
        finally:
            if shared_guard_needed():
                V.graph.wrapper_code.codegen_device_guard_exit()

            self.default_device_context = None

    return ctx()
def update_graph_partition_default_device(
    self, partitions: list[PartitionType], signatures: list[GraphPartitionSignature]
) -> None:
    """
    Set self.default_device_context to the device of the first cudagraph
    partition when every node of every cudagraph-skipping partition lives
    on that same device. In all other cases the attribute is left untouched.
    """
    # Note: [Graph Partition Device Contexts]
    # Entering a device context takes 60 microseconds and exiting a device
    # context takes 20 microseconds. If all graph partitions and
    # cudagraph-unsafe ops happen on the same device, we can share the
    # device context.
    if len(partitions) == 1 and not signatures[0].skip_cudagraph:
        # If there is only 1 cudagraph partition, the device context
        # should happen within the cudagraph partition, which
        # would be removed by cudagraph.
        return

    first_cudagraph_partition = next(
        (
            partition
            for partition, signature in zip(partitions, signatures)
            if not signature.skip_cudagraph
        ),
        None,
    )

    # all partitions skip cudagraph
    if first_cudagraph_partition is None:
        return

    cudagraph_partition_device = first_cudagraph_partition[0].get_device()
    assert cudagraph_partition_device is not None

    for partition, signature in zip(partitions, signatures):
        if not signature.skip_cudagraph:
            continue
        # A single node on another device rules out a shared context.
        if any(node.get_device() != cudagraph_partition_device for node in partition):
            return

    self.default_device_context = cudagraph_partition_device
def _codegen_partitions(self) -> None:
    """
    Partition the scheduled nodes and generate code partition by partition.

    Partitions that skip cudagraph are generated inline; every other
    partition is generated as a separate function through
    _codegen_partition_wrapper, so that further optimizations (e.g.,
    cudagraph) can be applied to each function individually.
    """
    partitions, signatures = self.graph_partition()

    partition_count = len(partitions)
    if partition_count > 1:
        counters["inductor"]["cudagraph_partitions"] += partition_count

    with self.use_default_device_context(partitions, signatures):
        for partition, signature in zip(partitions, signatures):
            assert len(partition) >= 1, (
                f"Each partition must have at least one node but found {len(partition)}"
            )

            if not signature.skip_cudagraph:
                self._codegen_partition_wrapper(partition, signature)
            else:
                self._codegen(partition)

    # The next counter value equals the number of ids handed out so far.
    num_partitions = next(self._graph_partition_counter)
    V.graph.wrapper_code.set_all_partition_names(num_partitions)

    # See [Note: Graph Partition Map for CUDAGraph]
    if num_partitions > 0:
        assert V.graph.partition_maps is not None
        assert num_partitions == len(V.graph.partition_maps), (
            f"Expect {num_partitions} partition maps but got {len(V.graph.partition_maps)}"
        )
def _codegen(self, nodes: list[BaseSchedulerNode]) -> None:
    """
    Generate code for ``nodes`` in the given order.

    For each node this switches the emitted device guard when the node's
    device differs from the current one, dispatches to the matching backend
    hook (template, extern call, combo kernel, mix-order reduction, regular
    node, or nop), records the node's buffers and operations as available,
    and flushes when the backend asks for it. Any device guard opened here
    is closed before the final flush.
    """
    if config.check_stack_no_cycles_TESTING_ONLY:
        import torch._dynamo.convert_frame

        frames = traceback.extract_stack()
        seen: OrderedSet[tuple[str, int | None]] = OrderedSet()
        for frame in reversed(frames):
            # This is where maybe_cprofile is
            reached_compile_inner = (
                frame.name == "_compile_inner"
                and frame.filename == torch._dynamo.convert_frame.__file__
            )
            if reached_compile_inner:
                break
            key = (frame.filename, frame.lineno)
            assert key not in seen, (
                f"Duplicate stack frame {frame.filename}:{frame.lineno}; "
                "did you add a decorator to one of the functions in this stack "
                "trace? If so, try using a context manager instead."
            )
            seen.add(key)

    self.current_device = self.default_device_context
    assert self.previous_node is None

    # pyrefly: ignore [unbound-name]
    if self.default_device_context and config.triton.autotune_at_compile_time:
        V.graph.wrapper_code.write_get_raw_stream_header()

    for node in nodes:
        if log.isEnabledFor(logging.DEBUG):
            try:
                log.debug(
                    "Generating code for node %s with estimated runtime %f",
                    node.get_name(),
                    node.get_estimated_runtime(),
                )
            except Exception:
                # Best effort: fall back to a line without the estimate.
                log.debug(
                    "Generating code for node %s with estimated runtime 0.0",
                    node.get_name(),
                )

        self.enter_context(node)

        if device := node.get_device():
            device_changed = device != self.current_device
            if device_changed or node.is_extern() or node.is_template():
                self.flush()
            if device != self.current_device:
                if self.current_device and device_need_guard(
                    self.current_device.type
                ):
                    V.graph.wrapper_code.codegen_device_guard_exit()
                self.current_device = device
                if device_need_guard(device.type):
                    assert device.index is not None, "device should have an index"
                    V.graph.wrapper_code.codegen_device_guard_enter(device.index)

        self.current_node = node
        self.buffer_names_to_free.update(node.last_usage)

        if node.is_template():
            prologue, template_node, epilogue = node.get_prologue_template_epilogue(
                list(node.get_nodes())
            )
            # pyrefly: ignore [unbound-name]
            self.get_backend(device).codegen_template(
                template_node, epilogue, prologue
            )
        elif node.is_extern():
            node = typing.cast(ExternKernelSchedulerNode, node)
            self.codegen_extern_call(node)
        elif node.is_foreach():
            node = typing.cast(ForeachKernelSchedulerNode, node)
            # pyrefly: ignore [unbound-name]
            backend_ = self.get_backend(device)
            from .codegen.cuda_combined_scheduling import CUDACombinedScheduling
            from .codegen.simd import SIMDScheduling

            if not isinstance(backend_, (SIMDScheduling, CUDACombinedScheduling)):
                raise AssertionError(f"{type(self)=}")
            backend_.codegen_combo_kernel(node)
        elif isinstance(node, FusedMixOrderReductions):
            # pyrefly: ignore [unbound-name]
            self.get_backend(device).codegen_mix_order_reduction(node)
        elif isinstance(node, (FusedSchedulerNode, SchedulerNode)):
            # pyrefly: ignore [unbound-name]
            self.get_backend(device).codegen_node(node)
        else:
            assert isinstance(node, NopKernelSchedulerNode)
            node.mark_run()

        # pyrefly: ignore [unbound-name]
        if config.triton.debug_sync_kernel:
            # pyrefly: ignore [unbound-name]
            self.get_backend(device).codegen_sync()

        self.available_buffer_names.update(node.get_buffer_names())
        self.completed_operations.update(node.get_operation_names())

        if not isinstance(node, NopKernelSchedulerNode):
            device = node.get_device()
            if (
                device is not None
                and device.type != "meta"
                and self.get_backend(device).ready_to_flush()
            ):
                self.flush()

        only_scheduler_nodes = all(
            isinstance(n, SchedulerNode) for n in node.get_nodes()
        )
        self.previous_node = node if only_scheduler_nodes else None

    if self.current_device != self.default_device_context:
        # when default_device_context is not None, we are codegen
        # for graph partitions and all nodes must be on
        # the same default device.
        assert self.current_device is not None
        if device_need_guard(self.current_device.type):
            # exit the outermost CUDA device guard. this is
            # important for nested indentation codegen-ing.
            V.graph.wrapper_code.codegen_device_guard_exit()

    self.previous_node = None
    self.flush()
def benchmark_combo_kernel(
    self, node_list: Sequence[BaseSchedulerNode], node_benchmark_results
) -> tuple[float, float, list[Optional[str]]]:
    """
    Delegate combo-kernel benchmarking of ``node_list`` to the backend of
    the first node's device and return the backend's result unchanged.

    Installs this scheduler as V.graph.scheduler and makes that device the
    current device before delegating.
    """
    target_device = node_list[0].get_device()
    V.graph.scheduler = self
    self.current_device = target_device
    assert target_device is not None
    return self.get_backend(target_device).benchmark_combo_kernel(
        node_list, node_benchmark_results
    )
def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool:
    """
    Decide whether combining ``nodes`` into a single combo kernel pays off.

    If config.benchmark_combo_kernel is False, always return True.
    Otherwise benchmark every sub-kernel on its own, benchmark the combo
    kernel, and return True when the combo kernel (with its clone time
    discounted) is faster than the sum of the sub-kernels, or when the
    kernels are too small to benchmark reliably.

    Returns False as soon as one sub-kernel benchmark reports an infinite
    time. A triton CompilationError mentioning "Loop-carried variable" is
    treated as "allow fusion"; any other CompilationError propagates.
    All nodes must be on the same device.
    """
    subkernel_nodes = nodes
    device = subkernel_nodes[0].get_device()

    assert all(node.get_device() == device for node in subkernel_nodes), (
        "All nodes in a combo kernel group must be on the same device"
    )

    if not config.benchmark_combo_kernel:
        return True

    from triton.compiler.errors import CompilationError

    # Accumulated time of the individually benchmarked sub-kernels.
    ms1 = 0.0
    node_benchmark_results = {}
    for i, snode in enumerate(subkernel_nodes):
        node_list = snode.get_nodes()
        # We can not accurately benchmark kernel using atomic_add
        # due to how we generate random integer inputs.
        if self._any_atomic_add(node_list):
            fusion_log.debug(
                "ComboKernel: benchmarking may not accurate due to atomic_add"
            )

        try:
            ms, path = self.benchmark_fused_nodes(node_list)
            node_benchmark_results[snode] = (ms, path)
            if math.isinf(ms):
                fusion_log.debug(
                    "ComboKernel benchmark: register spilling of %d-th subkernel",
                    i,
                )
                return False
        except CompilationError as e:
            # workaround triton issue: https://github.com/triton-lang/triton/issues/2151
            if "Loop-carried variable" in str(e):
                fusion_log.debug(
                    "ComboKernel benchmark: return True because of loop-carried variable"
                )
                return True  # allow fusion
            else:
                raise
        ms1 += ms

    try:
        ms2, ms2_clone, _path2_list = self.benchmark_combo_kernel(
            subkernel_nodes, node_benchmark_results
        )
    except CompilationError as e:
        # workaround triton issue: https://github.com/triton-lang/triton/issues/2151
        if "Loop-carried variable" in str(e):
            fusion_log.debug(
                "ComboKernel benchmark: return True because of loop-carried variable"
            )
            return True  # allow fusion
        else:
            raise

    # small kernels are very likely to have speedup but hard to benchmark. So we skip benchmarking.
    small_kernel = ms2 - ms2_clone < 0.3 or ms1 < 0.3
    # ms1 returned by benchmark_fused_nodes discounted clone time
    can_fuse = ms2 - ms2_clone < ms1 or small_kernel

    if fusion_log.isEnabledFor(logging.DEBUG):
        # Log the verdict that is actually returned, and never let a zero
        # combo time turn a debug message into a ZeroDivisionError.
        ratio = ms1 / ms2 if ms2 else math.inf
        if can_fuse:
            fusion_log.debug(
                "can fuse (benchmark): fusing causes %sx speedup",
                green_text(f"{ratio:.3f}"),
            )
        else:
            fusion_log.debug(
                "cannot fuse (benchmark): fusing causes %sx slowdown",
                red_text(f"{ratio:.3f}"),
            )

    return can_fuse
def get_buffer_layout(self, buf_name: str) -> ir.Layout:
    """
    Return the layout of the IR node behind the scheduler buffer named
    ``buf_name``. The name must be present in self.name_to_buf and the
    buffer must have an IR node attached.
    """
    ir_node = self.name_to_buf[buf_name].node
    assert ir_node is not None
    return ir_node.get_layout()
def update_zero_dim_cpu_tensor(self) -> None:
    """
    Record in V.graph.zero_dim_cpu_tensor_list the name of every buffer
    that is read by a GPU node, is known to V.graph.name_to_buffer, lives
    on "cpu", has a layout other than NoneLayout / MultiOutputLayout, and
    has an empty size list.
    """
    for node in self.nodes:
        if not node.is_gpu():
            continue
        for read in node.read_writes.reads:
            buffer = V.graph.name_to_buffer.get(read.name)
            is_zero_dim_cpu_buffer = (
                buffer
                and get_device_type(buffer) == "cpu"
                and not isinstance(buffer.layout, (NoneLayout, MultiOutputLayout))
                and buffer.get_size() == []
            )
            if is_zero_dim_cpu_buffer:
                V.graph.zero_dim_cpu_tensor_list.add(read.name)
class BaseScheduling:  # noqa: docstring_linter
    """
    Interface through which the scheduler talks to a code generation backend.

    Methods that raise NotImplementedError have no default behaviour here;
    the remaining methods provide a default that can be overridden.
    """

    def __init__(self, scheduler: Optional[Scheduler]):
        super().__init__()
        self.scheduler = scheduler

    def free_buffers_in_scheduler(self) -> None:
        """Ask the owning scheduler, if there is one, to free its buffers."""
        if self.scheduler:
            self.scheduler.free_buffers()

    def get_backend_features(self, device: torch.device) -> OrderedSet[BackendFeature]:
        """Return a set of .codegen.common.BackendFeature(); empty by default."""
        return OrderedSet()

    def can_fuse_vertical(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Tell whether node1 and node2 may be fused vertically.
        """
        raise NotImplementedError

    def can_fuse_horizontal(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        Tell whether node1 and node2 may be fused horizontally.
        """
        raise NotImplementedError

    def can_fuse_multi_outputs_template(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> bool:
        """
        A Multi-Output Template (referenced in #144012) is a template node
        with MultiOutputLayout whose output buffers are instances of
        MultiOutput. Implementations check whether node1 is such a template
        and node2 is one of its outputs, and whether the backend supports
        fusing them. The default answer is False.
        """
        return False

    def fuse(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> FusedSchedulerNode:
        """
        Fuse two nodes, picking the fused node type from the kind of inputs:
        foreach nodes first, then mix-order reductions, then an existing
        FusedMixOrderReductions on the left, and the generic fusion last.
        """
        if node1.is_foreach() or node2.is_foreach():
            return ForeachKernelSchedulerNode.fuse(node1, node2)
        if MixOrderReduction.are_mix_order_reductions(node1, node2):
            return FusedMixOrderReductions(node1, node2)
        if isinstance(node1, FusedMixOrderReductions):
            return node1.fuse_with(node2)
        return FusedSchedulerNode.fuse(node1, node2)

    def group_fn(
        self, sizes: Sequence[Sequence[sympy.Expr]]
    ) -> tuple[tuple[sympy.Expr, ...], ...]:
        """
        Process the iteration sizes in case a transformation needs to be applied.
        """
        raise NotImplementedError

    def codegen_template(
        self,
        template_node: BaseSchedulerNode,
        epilogue_nodes: Sequence[BaseSchedulerNode],
        prologue_nodes: Sequence[BaseSchedulerNode],
    ) -> Optional[str]:
        """
        Given a template node, generate a kernel.

        This function is only available for triton now. If the third-party backend behaves as a sub-class
        of TritonScheduling, it can override it or reuse it.
        """
        raise NotImplementedError

    def generate_kernel_code_from_nodes(
        self,
        nodes: Sequence[BaseSchedulerNode],
        benchmark_kernel: bool,
        hint_override: Optional[int] = None,
    ) -> str:
        """
        Generate a kernel given a list of pre-fused nodes.
        """
        raise NotImplementedError

    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
        """
        Generate a kernel for a single (possibly fused) scheduler node.
        """
        raise NotImplementedError

    def codegen_mix_order_reduction(self, node: FusedMixOrderReductions) -> None:
        """Generate code for a FusedMixOrderReductions node."""
        raise NotImplementedError

    def codegen_sync(self) -> None:
        """
        Generate synchronization code for the kernel. This method depends on the hardware characteristics.
        """
        raise NotImplementedError

    def ready_to_flush(self) -> bool:
        """
        Check whether the backend is requesting the scheduler to flush the generated kernel.
        If not supported, please return False.
        """
        return False

    def flush(self) -> None:
        """
        Flush the generated kernel and python wrapper code to the source code file.
        """
        raise NotImplementedError

    def benchmark_fused_nodes(
        self, nodes: Sequence[BaseSchedulerNode]
    ) -> tuple[float, str]:
        """
        Benchmark fused list of nodes and return the execution time
        in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError

    def benchmark_codegened_module(self, module: ModuleType) -> tuple[float, str]:
        """
        Benchmark a compiled module and return the execution time
        in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError

    def get_fusion_pair_priority(
        self, node1: BaseSchedulerNode, node2: BaseSchedulerNode
    ) -> int:
        """
        Return an unsigned integer which represents the priority of this fusion pair.
        The smaller is with higher priority. Every pair gets 0 by default.
        """
        return 0

    def benchmark_combo_kernel(
        self, node_list: Sequence[BaseSchedulerNode], node_benchmark_results
    ) -> tuple[float, float, list[Optional[str]]]:
        """
        Benchmark the list of nodes to combine and return the execution time
        and memory copy time in milliseconds on randomly generated inputs.
        """
        raise NotImplementedError

    def codegen_comment(
        self,
        node_schedule: Sequence[BaseSchedulerNode],
        kernel_name: Optional[str] = None,
    ) -> None:
        """
        When a kernel name is given, register provenance tracing for the
        kernel and write the resulting debug handle into the wrapper code.
        Without a kernel name this is a no-op.
        """
        if not kernel_name:
            return

        from torch._inductor.debug import set_kernel_post_grad_provenance_tracing

        debug_handle = set_kernel_post_grad_provenance_tracing(
            node_schedule,  # type: ignore[arg-type]
            kernel_name,
        )
        V.graph.wrapper_code.write_provenance_debug_handle(
            kernel_name, debug_handle
        )
|