lowering.py 255 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
777787779778077817782778377847785778677877788778977907791779277937794779577967797779877997800780178027803780478057806780778087809781078117812781378147815781678177818781978207821782278237824782578267827782878297830783178327833783478357836783778387839784078417842784378447845784678477848784978507851785278537854785578567857785878597860786178627863786478657866786778687869787078717872787378747875787678777878787978807881788278837884788578867887788878897890789178927893789478957896789778987899790079017902790379047905790679077908790979107911791279137914791579167917791879197920792179227923792479257926792779287929793079317932793379347935
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import functools
  6. import itertools
  7. import logging
  8. import math
  9. import operator
  10. import os
  11. import textwrap
  12. import warnings
  13. from collections import defaultdict
  14. from collections.abc import Callable, Collection, Iterable, Sequence
  15. from typing import Any, cast, Optional, TYPE_CHECKING, TypeGuard, TypeVar, Union
  16. from typing_extensions import ParamSpec
  17. from unittest.mock import patch
  18. import sympy
  19. import torch
  20. import torch.ao.quantization.fx._decomposed
  21. import torch.fx
  22. import torch.utils._pytree as pytree
  23. from torch._dynamo.utils import counters
  24. from torch._higher_order_ops.associative_scan import associative_scan_op
  25. from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
  26. from torch._library.fake_class_registry import FakeScriptObject
  27. from torch._library.utils import get_layout_constraint_tag
  28. from torch._prims_common import (
  29. canonicalize_dim,
  30. canonicalize_dims,
  31. check,
  32. dtype_to_type,
  33. elementwise_dtypes,
  34. ELEMENTWISE_TYPE_PROMOTION_KIND,
  35. get_computation_dtype,
  36. is_boolean_dtype,
  37. is_float_dtype,
  38. is_integer_dtype,
  39. Number,
  40. )
  41. from torch.fx.experimental.sym_node import magic_methods, method_to_operator
  42. from torch.fx.experimental.symbolic_shapes import (
  43. free_unbacked_symbols,
  44. has_free_unbacked_symbols,
  45. resolve_unbacked_bindings,
  46. )
  47. from torch.utils._ordered_set import OrderedSet
  48. from torch.utils._sympy.functions import (
  49. CeilDiv,
  50. FloorDiv,
  51. Identity,
  52. Mod,
  53. ModularIndexing,
  54. )
  55. from .._dynamo.utils import import_submodule
  56. from . import config, inductor_prims, ir, test_operators # NOQA: F401
  57. from .decomposition import decompositions, get_decompositions
  58. from .ir import (
  59. BaseView,
  60. DtypeView,
  61. ExpandView,
  62. IndexingConstant,
  63. IRNode,
  64. is_triton,
  65. MutableBox,
  66. OnlineSoftmaxReduction,
  67. ops_wrapper,
  68. PermuteView,
  69. Pointwise,
  70. Reduction,
  71. SqueezeView,
  72. TensorBox,
  73. validate_ir,
  74. View,
  75. )
  76. from .utils import (
  77. ceildiv,
  78. decode_device,
  79. is_dynamic,
  80. is_gpu,
  81. is_pointwise_use,
  82. is_view,
  83. needs_fallback_due_to_atomic_add_limitations,
  84. pad_listlike,
  85. register_op_dtype_propagation_rules,
  86. register_op_requires_libdevice_fp64,
  87. sympy_product,
  88. use_scatter_fallback,
  89. )
  90. from .virtualized import ops, V
  91. if TYPE_CHECKING:
  92. from .ops_handler import ReductionType
# Generic type / param-spec variables used in register_lowering's signature.
_T = TypeVar("_T")
_P = ParamSpec("_P")
# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
FALLBACK_ALLOW_LIST = OrderedSet(
    [
        "torchvision::roi_align",
        "aten::index_add",
    ]
)
log = logging.getLogger(__name__)
# Registry of lowerings keyed by op (or string). Filled by _register_lowering
# (register_lowering's default lowering_dict is this dict) and by
# _register_foreach_lowering.
lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
# Read this dict through maybe_layout_constraints(): it first checks the op's
# layout-constraint tag and only then falls back to the explicit entries stored
# here by add_layout_constraint.
_maybe_layout_constraints: dict[
    torch._ops.OpOverload, Optional[Callable[..., Any]]
] = {}
# Ops handled as fallbacks; _register_lowering skips its "out=" assertion when
# every registered overload is in this set (or in the _c10d_functional namespace).
fallbacks = OrderedSet[torch._ops.OpOverload]()
# Short aliases for the op namespaces used throughout this file.
aten = torch.ops.aten
tr_c10d = torch.ops.tr_c10d
prims = torch.ops.prims
# Overloads registered through add_needs_realized_inputs.
needs_realized_inputs = OrderedSet[torch._ops.OpOverload]()
# Ops lowered as foreach ops: _register_foreach_lowering adds to this set and
# cur_node_has_non_foreach_users() consults it.
foreach_ops = OrderedSet[torch._ops.OpOverload](
    [torch._higher_order_ops._foreach_map]  # type: ignore[list-item]
)
# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload
# so why is it in foreach_ops?
# make_foreach_pointwise sets realize_outputs when the current node targets one
# of these ops.
inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]()
# NOTE(review): not referenced in this part of the file; presumably maps a
# foreach op to its in-place counterpart — confirm against its users.
inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
quantized_decomposed = torch.ops.quantized_decomposed
  122. def cur_node_has_non_foreach_users() -> bool:
  123. for node in V.graph.current_node.users:
  124. for user in node.users:
  125. if not (user.op == "call_function" and (user.target in foreach_ops)):
  126. return True
  127. return False
  128. # group by device, whether any of the inputs are dynamic
  129. # note arg_pairs may or may not be a pair
  130. # foreach_map for example just passes output buffers here
  131. def group_foreach_args(
  132. arg_pairs: Iterable[Any],
  133. ) -> defaultdict[tuple[Any, bool], list[tuple[int, Any]]]:
  134. out = defaultdict(list)
  135. unpack_args = False
  136. for i, args in enumerate(arg_pairs):
  137. if not isinstance(args, Iterable):
  138. unpack_args = True
  139. args = (args,)
  140. use_foreach = (
  141. not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
  142. )
  143. device = None
  144. for t in args:
  145. if isinstance(t, TensorBox):
  146. device = t.data.get_device()
  147. break
  148. assert device is not None, "foreach op should have at least one tensor arg"
  149. if unpack_args:
  150. (args,) = args
  151. out[(device, use_foreach)].append((i, args))
  152. return out
  153. def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
  154. """Get layout constraints. Returns None if there are no layout constraints."""
  155. if not isinstance(fn, torch._ops.OpOverload):
  156. # Only OpOverloads have layout constraints.
  157. return None
  158. if maybe_layout_tag := get_layout_constraint_tag(fn, with_default=False):
  159. return tag_to_layout_constraint(maybe_layout_tag)
  160. if fn in _maybe_layout_constraints:
  161. return _maybe_layout_constraints[fn]
  162. return None
  163. def tag_to_layout_constraint(
  164. tag: torch._C.Tag,
  165. ) -> Optional[Callable[..., tuple[Any, Any]]]:
  166. if tag == torch._C.Tag.needs_exact_strides:
  167. return constrain_to_fake_tensors
  168. if tag == torch._C.Tag.needs_contiguous_strides: # type: ignore[attr-defined]
  169. return require_contiguous_strides
  170. if tag == torch._C.Tag.needs_fixed_stride_order:
  171. return constrain_to_fx_strides
  172. if tag == torch._C.Tag.flexible_layout:
  173. return None
  174. raise AssertionError(f"Unknown layout constraint tag: {tag}")
  175. def assert_nyi(cond: bool, msg: str) -> None:
  176. if not cond:
  177. raise NotImplementedError(f"inductor does not support {msg}")
  178. def add_needs_realized_inputs(
  179. fn: Union[
  180. Collection[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]],
  181. torch._ops.OpOverload,
  182. torch._ops.OpOverloadPacket,
  183. ],
  184. ) -> Optional[list[Any]]:
  185. if isinstance(fn, (list, set, tuple, OrderedSet)): # noqa: set_linter
  186. # pyrefly: ignore [bad-argument-type]
  187. return [add_needs_realized_inputs(x) for x in fn]
  188. if isinstance(fn, torch._ops.OpOverload):
  189. needs_realized_inputs.add(fn)
  190. elif isinstance(fn, torch._ops.OpOverloadPacket):
  191. needs_realized_inputs.update(
  192. getattr(fn, overload) for overload in fn.overloads()
  193. )
  194. return None
  195. def add_layout_constraint(
  196. fn: Union[torch._ops.OpOverloadPacket, torch._ops.OpOverload],
  197. constraint: Callable[..., tuple[Any, Any]],
  198. ) -> None:
  199. if isinstance(fn, torch._ops.OpOverloadPacket):
  200. for overload in fn.overloads():
  201. _maybe_layout_constraints[getattr(fn, overload)] = constraint
  202. else:
  203. _maybe_layout_constraints[fn] = constraint
# Register these ops in `needs_realized_inputs` at import time. For each
# OpOverloadPacket listed, add_needs_realized_inputs adds every overload.
# NOTE(review): the name suggests their lowerings expect realized (materialized)
# inputs — confirm against the consumers of needs_realized_inputs.
add_needs_realized_inputs(
    [
        aten.as_strided,
        aten.as_strided_copy,
        aten.avg_pool2d,
        aten.avg_pool2d_backward,
        aten.bmm,
        aten.convolution,
        aten.convolution_backward,
        aten.max_pool2d_with_indices,
        aten.max_pool3d_with_indices,
        aten.max_pool2d_with_indices_backward,
        aten.mm,
        aten.upsample_nearest2d,
        aten._upsample_nearest_exact2d,
        aten._int_mm,
    ]
)
  222. # TODO(jansel): ezyang says we won't need this in the future, try removing it
  223. # based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
  224. DTYPE_ID_LOOKUP = {
  225. 0: torch.uint8,
  226. 1: torch.int8,
  227. 2: torch.int16,
  228. 3: torch.int32,
  229. 4: torch.int64,
  230. 5: torch.float16,
  231. 6: torch.float32,
  232. 7: torch.float64,
  233. 8: torch.complex32,
  234. 9: torch.complex64,
  235. 10: torch.complex32,
  236. 11: torch.bool,
  237. 15: torch.bfloat16,
  238. # TODO(jansel): add quantized types?
  239. # _(c10::qint8, QInt8) /* 12 */
  240. # _(c10::quint8, QUInt8) /* 13 */
  241. # _(c10::qint32, QInt32) /* 14 */
  242. # _(c10::quint4x2, QUInt4x2) /* 16 */
  243. # _(c10::quint2x4, QUInt2x4) /* 17 */
  244. }
  245. def decode_dtype(dtype: Union[int, torch.dtype]) -> torch.dtype:
  246. if not isinstance(dtype, int):
  247. return dtype
  248. assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
  249. dtype = DTYPE_ID_LOOKUP[dtype]
  250. return dtype
  251. def is_integer_type(x: Any) -> TypeGuard[Union[TensorBox, sympy.Expr, int]]:
  252. if isinstance(x, TensorBox):
  253. return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  254. elif isinstance(x, sympy.Expr):
  255. return x.is_integer is True # type: ignore[attr-defined]
  256. else:
  257. return isinstance(x, int)
  258. def is_boolean_type(x: Any) -> TypeGuard[Union[TensorBox, bool]]:
  259. if isinstance(x, TensorBox):
  260. return is_boolean_dtype(x.get_dtype())
  261. else:
  262. return isinstance(x, bool)
  263. def get_promoted_dtype(
  264. *args: Any,
  265. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  266. return_compute_dtype: bool = False,
  267. ) -> torch.dtype:
  268. def construct_input(inp: Any) -> Any:
  269. if isinstance(inp, (Number, sympy.Basic)):
  270. return inp
  271. else:
  272. dim = len(inp.get_size())
  273. # construct a tmp tensor to feed into torch.result_type
  274. return torch.zeros([1] * dim, dtype=inp.get_dtype())
  275. inps = [construct_input(arg) for arg in args]
  276. compute_dtype, result_dtype = elementwise_dtypes(
  277. *inps, type_promotion_kind=type_promotion_kind
  278. )
  279. return compute_dtype if return_compute_dtype else result_dtype
  280. def get_overloads(aten_fn):
  281. if not isinstance(aten_fn, (list, tuple)):
  282. aten_fn = [aten_fn]
  283. else:
  284. aten_fn = list(aten_fn)
  285. for fn in list(aten_fn):
  286. if isinstance(fn, torch._ops.OpOverloadPacket):
  287. for overload in fn.overloads():
  288. other_fn = getattr(fn, overload)
  289. if other_fn not in lowerings:
  290. aten_fn.append(other_fn)
  291. return aten_fn
  292. def in_namespace(
  293. op: Union[Any, torch._ops.OpOverloadPacket, torch._ops.OpOverload], namespace: str
  294. ) -> bool:
  295. if isinstance(op, torch._ops.OpOverloadPacket):
  296. return namespace in op._qualified_op_name
  297. elif isinstance(op, torch._ops.OpOverload):
  298. return namespace in op.name()
  299. return False
def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox:
    """
    Copy ``x`` into a fresh buffer if it is a CPU scalar that is not on ``device``.

    Only inputs whose data is an ``ir.ReinterpretView`` and whose size has no
    free unbacked symbols are considered; anything else is returned unchanged.
    Such an input is copied when it is on a CPU device different from
    ``device`` and its size is ``[]`` or ``[1]``; otherwise ``x`` is returned.

    NOTE(review): the DeviceCopy targets ``x``'s *current* (CPU) device, not
    ``device``, so this materializes a standalone CPU buffer rather than moving
    the data — presumably intentional; confirm.
    """
    if not isinstance(x.data, ir.ReinterpretView) or has_free_unbacked_symbols(
        x.get_size()
    ):
        return x
    # concrete size hints; raises if a size has no hint
    size = [V.graph.sizevars.size_hint_or_throw(s) for s in x.get_size()]
    cur_device = x.get_device()
    if (
        cur_device is not None
        and cur_device.type == "cpu"
        and cur_device != device
        and (len(size) == 0 or (len(size) == 1 and size[0] == 1))
    ):
        return TensorBox(ir.StorageBox(ir.DeviceCopy.create(x, cur_device, False)))
    return x
def transform_args(
    args: list[Any],
    kwargs: dict[str, Any],
    broadcast: bool,
    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
    convert_input_to_bool: bool,
) -> tuple[list[Any], dict[str, Any]]:
    """
    Transforms arguments for broadcasting and type promotion

    Only ``TensorBox`` arguments (positional or keyword) drive the transform;
    if there are none, ``args``/``kwargs`` are returned unchanged.

    When ``type_promotion_kind`` or ``convert_input_to_bool`` is set:
      * the target dtype is ``torch.bool`` (if ``convert_input_to_bool``),
        otherwise ``get_promoted_dtype`` over the positional args that are
        numbers/sympy values or expose ``dtype`` plus kwargs exposing ``dtype``;
      * every TensorBox goes through ``maybe_copy_cpu_scalar`` against the
        device of the first TensorBox (positional first, then keyword);
      * every TensorBox is cast with ``to_dtype`` and every ``ir.Constant`` is
        rebuilt with the target dtype on that device.

    When ``broadcast`` is set, the TensorBoxes are broadcast together with
    ``broadcast_tensors`` and every ``ir.Constant`` is expanded to the
    broadcast size of the first result.

    NOTE: entries of the incoming ``args``/``kwargs`` may be reassigned in place
    (see the ``args[i] = ...`` / ``kwargs[k] = ...`` assignments), so callers
    should pass mutable containers.

    Returns:
        The transformed ``(args, kwargs)``.
    """
    args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
    kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
    # check that there's something to transform
    if not args_indices and not kwargs_indices:
        return args, kwargs

    if type_promotion_kind or convert_input_to_bool:
        if convert_input_to_bool:
            dtype = torch.bool
        else:
            # FIXME this is a crude approximation for promoting args
            promoting_args = [
                a
                for a in args
                if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype")
            ]
            # only consider tensor kwargs for promotion, for now
            promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
            dtype = get_promoted_dtype(
                *promoting_args,
                type_promotion_kind=type_promotion_kind,  # type: ignore[arg-type]
            )

        # device of the first TensorBox; used for CPU-scalar copies and constants
        device = (
            args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]]
        ).get_device()

        for i in args_indices:
            args[i] = maybe_copy_cpu_scalar(args[i], device)

        for k in kwargs_indices:
            kwargs[k] = maybe_copy_cpu_scalar(kwargs[k], device)

        # sometimes args are an immutable list so we can't mutate them
        # NOTE(review): the loops just above already assign into args/kwargs,
        # so this only avoids mutation for the promotion step itself.
        def promote(arg: Any) -> Any:
            if isinstance(arg, TensorBox):
                return to_dtype(arg, dtype)
            elif isinstance(arg, ir.Constant):
                return ir.Constant(value=arg.value, dtype=dtype, device=device)
            else:
                return arg

        args = [promote(a) for a in args]
        kwargs = {k: promote(v) for k, v in kwargs.items()}

    if broadcast:
        # broadcast positional TensorBoxes first, then keyword ones, in one call
        broadcasted = broadcast_tensors(
            *list(
                itertools.chain(
                    (args[i] for i in args_indices),
                    (kwargs[k] for k in kwargs_indices),
                )
            )
        )
        size = list(broadcasted[0].get_size())

        for i, x in zip(args_indices, broadcasted[: len(args_indices)]):
            args[i] = x

        for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]):
            kwargs[k] = x

        # scalar constants are expanded so every operand has the broadcast shape
        for i in range(len(args)):
            if isinstance(args[i], ir.Constant):
                args[i] = ExpandView.create(args[i], size)
        for k in kwargs:
            if isinstance(kwargs[k], ir.Constant):
                kwargs[k] = ExpandView.create(kwargs[k], size)

    return args, kwargs
  387. def _register_foreach_lowering(
  388. aten_fn: torch._ops.OpOverload, decomp_fn: Callable[..., Any]
  389. ) -> Callable[..., Any]:
  390. """
  391. Add a foreach lowering to lowerings dict.
  392. Arguments:
  393. aten_fn: torch.ops.aten.* fn we are lowering
  394. decomp_fn: alternate implementation on our IR
  395. broadcast: True to apply broadcasting to tensor inputs
  396. type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
  397. convert_input_to_bool: some logical ops require inputs are converted to bool
  398. """
  399. @functools.wraps(decomp_fn)
  400. def wrapped(*args: Any, **kwargs: Any) -> Any:
  401. out = decomp_fn(*args, **kwargs)
  402. validate_ir(out)
  403. return out
  404. aten_fns = get_overloads(aten_fn)
  405. foreach_ops.update(aten_fns)
  406. lowerings.update(dict.fromkeys(aten_fns, wrapped))
  407. return wrapped
def _register_lowering(
    aten_fn,
    decomp_fn: Callable[..., Any],
    broadcast: bool,
    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
    convert_input_to_bool: bool,
    lowering_dict: dict[Union[Callable[..., Any], str], Callable[..., Any]],
):
    """
    Add a lowering to lowerings dict

    The registered wrapper copies its args/kwargs, unpacks a single list/tuple
    positional argument, runs ``transform_args`` (broadcasting / promotion),
    re-packs the args if they were unpacked, calls ``decomp_fn`` and validates
    the resulting IR.

    Arguments:
        aten_fn: torch.ops.aten.* fn we are lowering (a single op or a list/tuple;
            packets are expanded with ``get_overloads``)
        decomp_fn: alternate implementation on our IR
        broadcast: True to apply broadcasting to tensor inputs
        type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
        convert_input_to_bool: some logical ops require inputs are converted to bool
        lowering_dict: dict the wrapper is registered in, keyed by each overload
    Returns:
        The wrapper that was registered.
    """

    @functools.wraps(decomp_fn)
    def wrapped(*args, **kwargs):
        args: list[Any] = list(args)
        kwargs: dict[str, Any] = dict(kwargs)
        unpacked = False
        # TODO maybe we need to use pytrees here
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            unpacked = True
            args = list(args[0])

        # NOTE: `aten_fn` here is read when the wrapper runs, i.e. after it has
        # been rebound below to the list returned by get_overloads.
        if not all(
            (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
        ):
            # explicitly assert for "out=" ops for better error messages
            assert not any(x == "out" for x in kwargs), "out= ops aren't yet supported"

        args, kwargs = transform_args(
            args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
        )

        if unpacked:
            args = [args]

        out = decomp_fn(*args, **kwargs)
        validate_ir(out)

        return out

    aten_fn = get_overloads(aten_fn)

    lowering_dict.update(dict.fromkeys(aten_fn, wrapped))
    return wrapped
  450. def register_lowering(
  451. aten_fn,
  452. broadcast=False,
  453. type_promotion_kind: Optional[
  454. ELEMENTWISE_TYPE_PROMOTION_KIND
  455. ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  456. convert_input_to_bool=False,
  457. lowering_dict=lowerings,
  458. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  459. """
  460. Shim to support decorator syntax.
  461. """
  462. return functools.partial(
  463. _register_lowering,
  464. aten_fn,
  465. broadcast=broadcast,
  466. type_promotion_kind=type_promotion_kind,
  467. convert_input_to_bool=convert_input_to_bool,
  468. lowering_dict=lowering_dict,
  469. )
  470. def broadcast_symbolic_shapes(a, b):
  471. """
  472. Broadcasting logic based on symbolic shapes.
  473. We give the shapes 0 and 1 concrete values, while all other shapes
  474. are symbolic sympy formulas.
  475. """
  476. b = tuple(b)
  477. if not a or a == b:
  478. return b
  479. output = []
  480. for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
  481. if V.graph.sizevars.is_size_one_or_false(y):
  482. output.append(x)
  483. elif V.graph.sizevars.is_size_one_or_false(x):
  484. output.append(y)
  485. else:
  486. V.graph.sizevars.check_equals(x, y)
  487. if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols):
  488. output.append(y) # prefer shorter formula
  489. else:
  490. output.append(x)
  491. return tuple(reversed(output))
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
    """
    Turn Python scalar / sympy inputs into IR constants.

    - If no input is an int, float or sympy value, ``inputs`` is returned as-is.
    - If every input is such a value, each becomes an ``ir.Constant`` (numbers)
      or ``ir.IndexingConstant`` (sympy) with dtype ``override_return_dtype``
      or, failing that, ``get_promoted_dtype`` of the inputs, on
      ``decode_device(None)``.
    - Otherwise the first TensorBox / ExpandView / ir.Constant input is the
      reference: each number or sympy input becomes a constant with its dtype
      and device, expanded to its size; other inputs pass through.

    Only one of ``override_return_dtype`` / ``type_promotion_kind`` may be
    given; with neither, DEFAULT promotion is used. Bools take the int path
    since bool subclasses int. In the mixed case, if no input matches the
    reference types, ``next`` raises StopIteration.
    """
    assert override_return_dtype is None or type_promotion_kind is None, (
        "only one of override_return_dtype or type_promotion_kind may be given"
    )

    if override_return_dtype is None and type_promotion_kind is None:
        type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT

    if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs):
        return inputs
    if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs):
        dtype = override_return_dtype or get_promoted_dtype(
            *inputs,
            # pyrefly: ignore [bad-argument-type]
            type_promotion_kind=type_promotion_kind,
        )

        def const_func(x):
            if isinstance(x, sympy.Basic):
                return ir.IndexingConstant(
                    index=x, dtype=dtype, device=decode_device(None)
                )
            else:
                return ir.Constant(value=x, dtype=dtype, device=decode_device(None))

        return [const_func(x) for x in inputs]
    # reference input supplying dtype, device and size for the new constants
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant)))
    out = []
    for x in inputs:
        if isinstance(x, (int, float)):
            out.append(
                ExpandView.create(
                    ir.Constant(
                        value=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
                    ),
                    list(ex.get_size()),
                )
            )
        elif isinstance(x, sympy.Basic):
            out.append(
                ExpandView.create(
                    IndexingConstant(
                        index=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
                    ),
                    list(ex.get_size()),
                )
            )
        else:
            out.append(x)

    return out
def make_pointwise(
    fn,
    override_return_dtype=None,
    override_device=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    triton_fallback=None,
):
    """
    Build a lowering that applies the scalar op ``fn`` element-wise.

    Arguments:
        fn: called inside the generated ``inner_fn`` with one loaded value per
            input; its result is the output element.
        override_return_dtype: output dtype; defaults to the first input's
            dtype. Also forwarded to ``promote_constants``.
        override_device: device of the result; otherwise the first input whose
            device ``is_gpu`` is used, else the first input's device.
        override_fn_when_input_bool: used instead of ``fn`` when the computed
            dtype (override or first input's dtype) is ``torch.bool``.
        allow_alpha: if True, ``inner`` accepts ``alpha`` and, when it is not
            None or 1, multiplies the last input by it; if False, ``alpha``
            must be None.
        triton_fallback: if given, ``inner`` delegates to it whenever any input
            is an IRNode for which ``is_triton`` holds.
    Returns:
        ``inner(*inputs, alpha=None)``, which builds a ``Pointwise`` node.
    """

    def inner(*inputs: TensorBox, alpha=None):
        if triton_fallback is not None and any(
            isinstance(inp, IRNode) and is_triton(inp) for inp in inputs
        ):
            assert not allow_alpha  # not implemented
            return triton_fallback(*inputs)

        inputs = promote_constants(inputs, override_return_dtype)
        if allow_alpha:
            if alpha is not None and alpha != 1:
                # pyrefly: ignore [bad-assignment]
                inputs = list(inputs)
                # pyrefly: ignore [unsupported-operation]
                inputs[-1] = mul(inputs[-1], alpha)
        else:
            assert alpha is None
        loaders = [x.make_loader() for x in inputs]
        ranges = inputs[0].get_size()
        dtype = override_return_dtype or inputs[0].get_dtype()

        # constants may have any rank; every other input must match inputs[0]'s rank
        for other in inputs[1:]:
            assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
                other.get_size()
            ), f"ndim mismatch {fn} {ranges} {other.get_size()}"

        # in tracing, we will annotate pointwise nodes that correspond to the output of
        # a pointwise node that would have been run in eager. intermediary pointwise nodes
        # during decompositions are not annotated.
        low_pr_fp = (torch.bfloat16, torch.float16)
        emulate_precision_casts = (
            V.graph is not None
            and getattr(V.graph, "current_node", None) is not None
            and V.graph.current_node.meta is not None
            and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False)
        )
        emulate_output_cast = emulate_precision_casts and dtype in low_pr_fp

        def inner_fn(index):
            assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
            if dtype == torch.bool and override_fn_when_input_bool is not None:
                return override_fn_when_input_bool(*[load(index) for load in loaders])
            else:
                inputs_loaded = []
                for inp_index, load in enumerate(loaders):
                    out = load(index)
                    inp_dtype = inputs[inp_index].get_dtype()
                    # round-trip low-precision inputs through their storage dtype
                    if emulate_precision_casts and inp_dtype in low_pr_fp:
                        downcast = ops.to_dtype(out, inp_dtype, use_compute_types=False)
                        out = ops.to_dtype(downcast, inp_dtype)
                    inputs_loaded.append(out)

                out = fn(*inputs_loaded)
                if emulate_output_cast:
                    # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here,
                    # then upcasting again, to emulate casts that eager would do.
                    downcast = ops.to_dtype(out, dtype, use_compute_types=False)
                    return ops.to_dtype(downcast, dtype)
                return out

        if not override_device:
            device = None
            # prefer the first GPU input's device, else the first input's
            for i in inputs:
                if is_gpu(i.get_device().type):
                    device = i.get_device()
                    break
            if not device:
                device = inputs[0].get_device()

        # pyrefly: ignore [unbound-name]
        device = override_device or device

        return Pointwise.create(
            device=device,  # type: ignore[arg-type]
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=ranges,
        )

    return inner
  616. def make_foreach_pointwise(pw_fn, allow_alpha=False):
  617. def inner(*inputs: list[list[TensorBox]], alpha=1):
  618. realize_outputs = (
  619. len(V.graph.current_node.users) == 0
  620. or V.graph.current_node.target in inplace_foreach_ops
  621. or cur_node_has_non_foreach_users()
  622. )
  623. a_list_input = None
  624. for input in inputs:
  625. if isinstance(input, (list, tuple)):
  626. a_list_input = input
  627. break
  628. assert a_list_input is not None, (
  629. "at least one input must be a list to a foreach op"
  630. )
  631. # broadcast scalar inputs to match length of list inputs
  632. broadcast_inputs = []
  633. for input in inputs:
  634. if not isinstance(input, (list, tuple)):
  635. broadcast_inputs.append([input] * len(a_list_input))
  636. else:
  637. # pyrefly: ignore [bad-argument-type]
  638. broadcast_inputs.append(input)
  639. groups = group_foreach_args(zip(*broadcast_inputs))
  640. def apply_fn(args):
  641. if allow_alpha:
  642. return pw_fn(*args, alpha=alpha)
  643. else:
  644. return pw_fn(*args)
  645. return foreach_group_loop(groups, len(a_list_input), apply_fn, realize_outputs)
  646. return inner
  647. def foreach_group_loop(groups, num_outputs, apply_fn, realize_outputs):
  648. """
  649. Common loop over grouped foreach arguments.
  650. Args:
  651. groups: Result of group_foreach_args - dict mapping (device, use_foreach) to groups
  652. num_outputs: Number of outputs to produce
  653. apply_fn: Function to apply to each set of args, returns the output
  654. realize_outputs: Whether to realize outputs for foreach fusion
  655. """
  656. outputs = [None] * num_outputs
  657. for (device, use_foreach), group in groups.items():
  658. operation_list: list[str] = []
  659. for output_ind, args in group:
  660. output = apply_fn(args)
  661. outputs[output_ind] = output
  662. if (
  663. V.graph.has_feature(device, BackendFeature.FOREACH)
  664. and use_foreach
  665. and realize_outputs
  666. ):
  667. output.realize()
  668. operation_list.append(output.get_operation_name())
  669. if operation_list:
  670. V.graph.register_operation_list(operation_list)
  671. assert all(x is not None for x in outputs)
  672. return outputs
  673. def to_dtype(x: TensorBox, dtype: torch.dtype, copy: bool = False):
  674. src_dtype = x.get_dtype()
  675. if src_dtype == dtype:
  676. return clone(x) if copy else x
  677. def _to_dtype(x):
  678. return ops.to_dtype(x, dtype, src_dtype=src_dtype)
  679. return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
  680. @register_lowering(torch._higher_order_ops._foreach_map, type_promotion_kind=None)
  681. def _foreach_map(subgraph, *args, **kwargs):
  682. """
  683. This lowers an invocation of foreach_map
  684. The way this works is that an arbitrary N-arg func is provided by the user, looped over by the
  685. polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args)
  686. and then traced into a subgraph by dynamo.
  687. This code allows us to inline the subgraph into the main graph lowering using the PontwiseSubgraphLowering.
  688. The graph outputs represent the vertically fused sequence of ops, and then register_operation_list
  689. below registers the buffers as horizontally fuseable in the scheduler.
  690. """
  691. from .subgraph_lowering import PointwiseSubgraphLowering
  692. inputs = args
  693. gm = subgraph.graph_module
  694. pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
  695. with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
  696. pw_subgraph.run(*inputs)
  697. sub_outputs = pw_subgraph.graph_outputs
  698. # group outputs by device and register as foreach
  699. assert sub_outputs # mypy lol
  700. groups = group_foreach_args(sub_outputs)
  701. outputs = [None] * len(sub_outputs)
  702. for (device, use_foreach), group in groups.items():
  703. operation_list: list[str] = []
  704. for (
  705. output_ind,
  706. output,
  707. ) in group:
  708. outputs[output_ind] = output
  709. if V.graph.has_feature(device, BackendFeature.FOREACH) and use_foreach:
  710. output.realize()
  711. operation_list.append(output.get_operation_name())
  712. if operation_list:
  713. V.graph.register_operation_list(operation_list)
  714. assert all(x is not None for x in outputs)
  715. return outputs
  716. @register_lowering(prims.convert_element_type, type_promotion_kind=None)
  717. def _convert_element_type(x: TensorBox, dtype: torch.dtype):
  718. if dtype.is_complex or x.get_dtype().is_complex:
  719. if x.get_size():
  720. # Decompose since aa aten fallback is more friendly for c++ codegen.
  721. # This decomposition doesn't work for empty tensor, which needs more investigation.
  722. dst = empty_like(x, dtype=dtype)
  723. ir.InplaceCopyFallback.create(dst, x)
  724. return dst
  725. else:
  726. return fallback_handler(
  727. prims.convert_element_type.default, add_to_fallback_set=False
  728. )(x, dtype)
  729. return to_dtype(x, dtype, copy=True)
  730. def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False):
  731. x_dtype = x.get_dtype()
  732. if x_dtype == dtype:
  733. return clone(x) if copy else x
  734. def _get_primitive_bitwidth(dtype):
  735. if dtype.is_floating_point:
  736. return torch.finfo(dtype).bits
  737. else:
  738. return torch.iinfo(dtype).bits
  739. src_bits = _get_primitive_bitwidth(x_dtype)
  740. dst_bits = _get_primitive_bitwidth(dtype)
  741. if src_bits != dst_bits:
  742. # fallback to aten eager implementation for differing bitwidths
  743. return fallback_handler(aten.view.dtype)(x, dtype)
  744. else:
  745. return TensorBox(DtypeView.create(x, dtype))
  746. @register_lowering(aten.view.dtype, type_promotion_kind=None)
  747. def _view_dtype(x: TensorBox, dtype: torch.dtype):
  748. if dtype.is_complex or x.get_dtype().is_complex:
  749. return TensorBox.create(
  750. ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
  751. )
  752. return to_dtype_bitcast(x, dtype)
  753. def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False):
  754. device = decode_device(device)
  755. if x.get_device() == device:
  756. return clone(x) if copy else x
  757. return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking))
  758. @register_lowering(prims.device_put, type_promotion_kind=None)
  759. def _device_put(x: TensorBox, device: torch.device, non_blocking=False):
  760. return to_device(x, device, copy=True, non_blocking=non_blocking)
  761. def register_pointwise(
  762. aten_fn,
  763. name=None,
  764. broadcast=True,
  765. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  766. convert_input_to_bool=False,
  767. override_return_dtype=None,
  768. override_fn_when_input_bool=None,
  769. allow_alpha=False,
  770. triton_fallback=None,
  771. ):
  772. """A pointwise function that maps ops.{name} to inputs"""
  773. name = name or aten_fn.__name__
  774. fn = ops_wrapper(name)
  775. register_op_dtype_propagation_rules(
  776. name, type_promotion_kind, override_return_dtype
  777. )
  778. if override_fn_when_input_bool is not None:
  779. override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)
  780. fn = make_pointwise(
  781. fn,
  782. override_return_dtype=override_return_dtype,
  783. override_fn_when_input_bool=override_fn_when_input_bool,
  784. allow_alpha=allow_alpha,
  785. triton_fallback=triton_fallback,
  786. )
  787. fn = register_lowering(
  788. aten_fn,
  789. broadcast=broadcast,
  790. type_promotion_kind=type_promotion_kind,
  791. convert_input_to_bool=convert_input_to_bool,
  792. )(fn)
  793. if hasattr(prims, name):
  794. register_lowering(
  795. getattr(prims, name),
  796. type_promotion_kind=None,
  797. convert_input_to_bool=convert_input_to_bool,
  798. )(fn)
  799. return fn
# Register dtype propagation for ops.ldexp (INT_TO_FLOAT promotion, no return
# dtype override). This is done directly rather than via register_pointwise,
# presumably because aten.ldexp gets the custom ldexp_lowering that follows.
register_op_dtype_propagation_rules(
    "ldexp",
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    override_return_dtype=None,
)
  805. @register_lowering(aten.ldexp, broadcast=True, type_promotion_kind=None)
  806. def ldexp_lowering(x: TensorBox, n: TensorBox):
  807. ldexp_fn = ops_wrapper("ldexp")
  808. x_dtype = x.get_dtype()
  809. n_dtype = n.get_dtype()
  810. x_is_float = x_dtype.is_floating_point
  811. n_is_int = not n_dtype.is_floating_point and n_dtype != torch.bool
  812. if x_is_float and n_is_int:
  813. # Use native ldexp
  814. def compute_ldexp(x, n):
  815. return ldexp_fn(x, n)
  816. return make_pointwise(compute_ldexp)(x, n)
  817. else:
  818. # Fall back to decomposition: x * pow(2, n)
  819. out_dtype = torch.float32 if is_integer_type(x) else x_dtype
  820. def compute_fallback(x, n):
  821. n_out_type = ops.to_dtype(n, out_dtype)
  822. two = ops.constant(2.0, out_dtype)
  823. pow_result = ops.pow(two, n_out_type)
  824. return ops.mul(x, pow_result)
  825. return make_pointwise(
  826. compute_fallback,
  827. override_return_dtype=out_dtype,
  828. )(x, n)
  829. def register_frexp():
  830. """A pointwise function that maps ops.frexp to inputs"""
  831. name = "frexp"
  832. frexp = ops_wrapper("frexp")
  833. def frexp0(*args, **kwargs):
  834. return frexp(*args, **kwargs)[0] # type: ignore[index]
  835. def frexp1(*args, **kwargs):
  836. return frexp(*args, **kwargs)[1] # type: ignore[index]
  837. pw_fns = [
  838. make_pointwise(frexp0),
  839. make_pointwise(frexp1, override_return_dtype=torch.int32),
  840. ]
  841. def fn(*args, **kwargs):
  842. return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs)
  843. fn = register_lowering(
  844. aten.frexp,
  845. )(fn)
  846. if hasattr(prims, name):
  847. register_lowering(
  848. getattr(prims, name),
  849. type_promotion_kind=None,
  850. )(fn)
  851. return fn
  852. register_frexp()
  853. def register_foreach_pointwise(
  854. aten_fn,
  855. pointwise_lowering_fn,
  856. allow_alpha=False,
  857. ):
  858. fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha)
  859. fn = _register_foreach_lowering(aten_fn, fn)
  860. return fn
  861. @register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
  862. def where(cond, a, b):
  863. def fn(*args):
  864. return ops.where(*args)
  865. if isinstance(a, (float, int)):
  866. a = constant_like(a)(b)
  867. if isinstance(b, (float, int)):
  868. b = constant_like(b)(a)
  869. args = [cond, a, b]
  870. dtype = get_promoted_dtype(
  871. args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  872. )
  873. indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
  874. for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
  875. args[i] = x
  876. for i in range(len(args)):
  877. if isinstance(args[i], ir.Constant):
  878. args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
  879. return make_pointwise(fn, override_return_dtype=dtype)(
  880. args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype)
  881. )
  882. @register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
  883. def broadcast_tensors(*inputs):
  884. if len(inputs) == 1:
  885. if isinstance(inputs[0], (list, tuple)):
  886. return broadcast_tensors(*inputs[0])
  887. return inputs
  888. target: list[sympy.Expr] = functools.reduce(
  889. broadcast_symbolic_shapes, (x.get_size() for x in inputs), ()
  890. )
  891. outputs = []
  892. for x in inputs:
  893. if (sizes := tuple(x.get_size())) == target:
  894. pass
  895. elif len(sizes) != len(target) or any(
  896. V.graph.sizevars.is_size_one_or_false(a)
  897. != V.graph.sizevars.is_size_one_or_false(b)
  898. for a, b in zip(sizes, target)
  899. ):
  900. x = expand(x, target)
  901. outputs.append(x)
  902. return outputs
  903. @register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
  904. def nop(x):
  905. return x # AOT autograd handles this for us
  906. if hasattr(aten, "lift_fresh"):
  907. register_lowering(aten.lift_fresh)(nop)
  908. @register_lowering(aten.squeeze, type_promotion_kind=None)
  909. def squeeze(x, dim=None):
  910. assert isinstance(x, TensorBox)
  911. if dim is None:
  912. return TensorBox(SqueezeView.create(x.data))
  913. dim = (
  914. V.graph.sizevars.guard_int(dim)
  915. if isinstance(dim, (int, sympy.Expr))
  916. else tuple(V.graph.sizevars.guard_int(d) for d in dim)
  917. )
  918. dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload]
  919. dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
  920. new_shape = []
  921. for d, s in enumerate(x.get_size()):
  922. if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))):
  923. new_shape.append(s)
  924. # squeeze does nothing if the size isn't 1
  925. return view(x, new_shape) if new_shape != x.get_size() else x
  926. @register_lowering(aten.squeeze_copy, type_promotion_kind=None)
  927. def squeeze_copy(x, dim=None):
  928. return clone(squeeze(x, dim))
  929. @register_lowering([aten.squeeze_])
  930. def squeeze_(x, dim=None):
  931. val = squeeze(x, dim)
  932. assert isinstance(x, TensorBox)
  933. assert isinstance(val, TensorBox)
  934. x.data = val.data
  935. return x
  936. @register_lowering(aten.isinf)
  937. def isinf(x):
  938. if is_integer_type(x):
  939. return full_like(x, False, dtype=torch.bool)
  940. fn = ops_wrapper("isinf")
  941. return make_pointwise(fn, override_return_dtype=torch.bool)(x)
  942. @register_lowering(aten.isnan)
  943. def isnan(x):
  944. if is_integer_type(x):
  945. return full_like(x, False, dtype=torch.bool)
  946. fn = ops_wrapper("isnan")
  947. return make_pointwise(fn, override_return_dtype=torch.bool)(x)
  948. @register_lowering(aten.ceil)
  949. def ceil(x):
  950. if is_integer_type(x):
  951. return clone(x)
  952. fn = ops_wrapper("ceil")
  953. return make_pointwise(fn)(x)
  954. @register_lowering(aten.floor)
  955. def floor(x):
  956. if is_integer_type(x):
  957. return clone(x)
  958. fn = ops_wrapper("floor")
  959. return make_pointwise(fn)(x)
  960. @register_lowering(aten.round.default)
  961. def round(x):
  962. if is_integer_type(x):
  963. return clone(x)
  964. else:
  965. fn = ops_wrapper("round")
  966. return make_pointwise(fn)(x)
  967. @register_lowering(aten.trunc)
  968. def trunc(x):
  969. if is_integer_type(x):
  970. return clone(x)
  971. fn = ops_wrapper("trunc")
  972. return make_pointwise(fn)(x)
  973. @register_lowering(aten.expand, type_promotion_kind=None)
  974. def expand(x, sizes):
  975. (x,) = promote_constants([x])
  976. if isinstance(x, ir.BaseConstant):
  977. return ExpandView.create(x, tuple(sizes))
  978. assert isinstance(x, TensorBox)
  979. assert isinstance(sizes, (list, tuple))
  980. if tuple(x.get_size()) == tuple(sizes):
  981. return x
  982. if not free_unbacked_symbols(x.get_size()):
  983. x_size_product = V.graph.sizevars.size_hint_or_throw(
  984. sympy_product(x.get_size())
  985. )
  986. # TODO: It would be better to realize the input if any of its sizes
  987. # are unbacked, because typically the size will be non-zero. However,
  988. # this cannot be done directly as below as we'll choke on the size_hint
  989. # here
  990. if x_size_product > 0 and not free_unbacked_symbols(sizes):
  991. # maybe realize input before broadcasting it
  992. x.mark_reuse(
  993. V.graph.sizevars.size_hint_or_throw(sympy_product(sizes))
  994. // x_size_product
  995. )
  996. return TensorBox(ExpandView.create(x.data, tuple(sizes)))
  997. @register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
  998. def broadcast_in_dim(a, shape, broadcast_dimensions):
  999. s = list(shape)
  1000. for broadcast_dimension in broadcast_dimensions:
  1001. s[broadcast_dimension] = -1
  1002. v = a
  1003. for idx, x in enumerate(s):
  1004. if x != -1:
  1005. v = unsqueeze(v, idx)
  1006. return expand(v, shape)
  1007. @register_lowering(aten.expand_as, type_promotion_kind=None)
  1008. def expand_as(x, y):
  1009. return expand(x, y.get_size())
  1010. @register_lowering(aten.repeat)
  1011. def repeat(x, repeats):
  1012. old_size = list(x.get_size())
  1013. if len(repeats) > len(old_size):
  1014. old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size
  1015. x = view(x, list(old_size))
  1016. assert len(repeats) == len(x.get_size())
  1017. new_size = list(x.get_size())
  1018. zero_tensor = False
  1019. for i in range(len(repeats)):
  1020. if repeats[i] == 0:
  1021. zero_tensor = True
  1022. new_size[i] = new_size[i] * repeats[i]
  1023. if zero_tensor:
  1024. return empty(new_size, dtype=x.get_dtype(), device=x.get_device())
  1025. if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
  1026. return clone(expand(x, new_size))
  1027. x_loader: Callable[[Any], Any]
  1028. def inner_fn(index):
  1029. assert len(index) == len(repeats)
  1030. index = list(index)
  1031. for i in range(len(repeats)):
  1032. if repeats[i] != 1:
  1033. if old_size[i] == 1:
  1034. index[i] = sympy.S.Zero
  1035. else:
  1036. index[i] = ModularIndexing(index[i], 1, old_size[i])
  1037. return x_loader(index)
  1038. if not free_unbacked_symbols(old_size) and not free_unbacked_symbols(new_size):
  1039. old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size))
  1040. if old_size_product > 0:
  1041. # maybe realize the input but skip for unbacked symints since it'll
  1042. # choke on the size hint.
  1043. x.mark_reuse(
  1044. V.graph.sizevars.size_hint_or_throw(sympy_product(new_size))
  1045. // old_size_product
  1046. )
  1047. x_loader = x.make_loader()
  1048. return Pointwise.create(
  1049. device=x.get_device(),
  1050. dtype=x.get_dtype(),
  1051. inner_fn=inner_fn,
  1052. ranges=list(new_size),
  1053. )
  1054. @register_lowering(aten._unsafe_view, type_promotion_kind=None)
  1055. @register_lowering(aten.view, type_promotion_kind=None)
  1056. @register_lowering(aten.reshape, type_promotion_kind=None)
  1057. def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox:
  1058. return TensorBox(View.create(x.data, sizes))
  1059. @register_lowering(aten.permute, type_promotion_kind=None)
  1060. def permute(x, dims):
  1061. assert isinstance(x, TensorBox)
  1062. assert isinstance(dims, (list, tuple))
  1063. return TensorBox(PermuteView.create(x.data, tuple(dims)))
@register_lowering(aten.slice, type_promotion_kind=None)
def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
    """
    Lowers a slice call, creating ExternKernels for the output size & storage offset symbols,
    if the indices are unbacked and appropriate semantics aren't known.
    If they are known (indices are static/backed/unbacked with info), a SliceView is created.
    """
    from torch.fx.experimental.symbolic_shapes import (
        CallMethodKey,
        resolve_unbacked_bindings,
    )

    assert isinstance(x, TensorBox)
    dim = _validate_dim(x, dim, 0)
    size = x.get_size()[dim]
    step = sympy.expand(step)
    # NOTE(review): sympy.expand presumably sympifies plain ints into sympy
    # Integers, making the isinstance() check always true so `step > 0` is
    # never enforced here -- confirm whether the positivity check is intended.
    assert isinstance(step, sympy.Expr) or step > 0, step

    # maybe apply slice optimization
    # Full-range slice (start == 0, step == 1, end statically >= size) is a no-op.
    try:
        if (
            start == 0
            and V.graph.sizevars.statically_known_leq(size, end)
            and step == 1
        ):
            return x
    except TypeError:
        # Presumably raised by comparisons on some symbolic/None bounds; the
        # optimization is simply skipped in that case.
        pass

    # try to avoid dynamic (unbacked) slice
    # Returns `index` normalized into [0, size] when guard_or_false can prove
    # which case applies (in range, negative-in-range, above size, below -size),
    # `default` when index is None, and None when nothing can be proven.
    def compute_slice_index(index, size, default=None):
        if index is None:
            return default

        fn = lambda x: V.graph.sizevars.guard_or_false(x)  # noqa: E731
        index = sympy.expand(index)
        size = sympy.expand(size)
        if fn(sympy.Ge(index, 0)) and fn(sympy.Le(index, size)):
            return index
        elif fn(sympy.Lt(index, 0)) and fn(sympy.Ge(index, -size)):
            return index + size
        elif fn(sympy.Gt(index, size)):
            return size
        elif fn(sympy.Lt(index, -size)):
            return 0
        return None

    start_index, end_index = None, None
    ambiguous_slice = clamp
    if ambiguous_slice:
        start_index = compute_slice_index(start, size, 0)
        end_index = compute_slice_index(end, size, size)
        if start_index is not None and end_index is not None:
            start, end = start_index, end_index
            ambiguous_slice = False

    # ambiguous_slice=False means we know what semantics this slice call follows,
    # and don't need to generate an extern kernel to represent the output size.
    # This is assumed True for clamp=False
    # (meant to follow standard indexing semantics: 0 <= index < size)
    if not ambiguous_slice:
        return TensorBox(
            ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
        )  # go to SliceView/ReinterpretView

    # unbacked territory: create DynamicSlice ExternKernel
    # clamp is True, unbacked start / end
    assert clamp
    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
    )
    assert unbacked_bindings is not None
    assert len(unbacked_bindings) <= 2, unbacked_bindings
    # Pick out the unbacked symbols bound to the output's size along `dim`
    # and to its storage offset (either may be absent).
    sym_size, sym_storage = None, None
    for sym, keypath in unbacked_bindings.items():
        if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
            sym_size = sym
        elif keypath == (CallMethodKey("storage_offset"),):
            sym_storage = sym

    # At least one of the bounds could not be resolved (otherwise we returned above).
    assert start_index is None or end_index is None
    # Register a buffer that computes the output size for `sym_size` at runtime.
    b_size = ir.DynamicSliceSize(
        sym_size,
        start,
        end,
        step,
        x.get_size()[dim],
    )
    b_size.name = V.graph.register_buffer(b_size)
    V.graph.register_operation(b_size)
    new_size = sym_size

    if x.maybe_get_layout() is None:
        # realize tensor before accessing layout
        x.realize()

    if start_index is not None:
        # we shouldn't have allocated storage offset symbol if start index was determinable
        assert sym_storage is None
        new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]
    else:
        # Start is unresolved too: register a buffer that computes the storage
        # offset for `sym_storage` at runtime.
        b_storage = ir.DynamicSelectStorageOffset(
            sym_storage,
            start,
            x.get_layout().offset,
            x.get_stride()[dim],
            x.get_size()[dim],
            clamp=True,
        )
        b_storage.name = V.graph.register_buffer(b_storage)
        V.graph.register_operation(b_storage)
        new_storage_offset = sym_storage

    # Build the result as a strided view: size along `dim` replaced, stride
    # along `dim` scaled by `step`.
    new_sizes = list(x.get_size())
    new_strides = list(x.get_stride())
    new_sizes[dim] = new_size
    new_strides[dim] *= step
    return as_strided(x, new_sizes, new_strides, new_storage_offset)
  1171. @register_lowering(aten.as_strided, type_promotion_kind=None)
  1172. def as_strided(x, size, stride, storage_offset=None):
  1173. new_device = None
  1174. new_dtype = None
  1175. if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
  1176. # Note: Merging views
  1177. # When we use as_strided, we can rewrite the size/stride/offset
  1178. # of the incoming buffer x. If x is a view, we would overwrite
  1179. # its metadata. Except for dtype, which we need to propagate.
  1180. # Technically device is not needed because it is not possible
  1181. # to have a cross-device view today.
  1182. new_device = x.get_device()
  1183. new_dtype = x.dtype
  1184. x = x.data.unwrap_view()
  1185. x.realize()
  1186. if not ir.is_storage_and_layout(x):
  1187. raise NotImplementedError(f"unrealized as_strided({x}, ...)")
  1188. storage, old_layout = ir.as_storage_and_layout(x)
  1189. new_layout = ir.FixedLayout(
  1190. new_device if new_device else old_layout.device,
  1191. new_dtype if new_dtype else old_layout.dtype,
  1192. [sympy.expand(s) for s in size],
  1193. [sympy.expand(s) for s in stride],
  1194. sympy.expand(storage_offset or 0),
  1195. )
  1196. return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout))
  1197. @register_lowering(aten.as_strided_, type_promotion_kind=None)
  1198. def as_strided_(x, size, stride, storage_offset=None):
  1199. assert isinstance(x, TensorBox)
  1200. x.data = as_strided(x, size, stride, storage_offset).data
  1201. return x
  1202. @register_lowering(aten.as_strided_copy, type_promotion_kind=None)
  1203. def as_strided_copy(x, size, stride, storage_offset=None):
  1204. result = as_strided(x, size, stride, storage_offset)
  1205. return clone(result)
def pointwise_cat(inputs, dim=0):
    """Concatenate ``inputs`` along ``dim`` as a single Pointwise node.

    Each output element loads from the input whose [start, end) range along
    ``dim`` contains the element's coordinate, using masked loads combined by
    a chain of ``ops.where``. Output device/dtype come from ``inputs[0]``.
    """
    # (inclusive, exclusive)
    # inputs_ranges[i] is the half-open range of output coordinates along
    # `dim` that input i occupies.
    inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = []
    prev_end = 0
    for inp in inputs:
        inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))  # type: ignore[arg-type]
        prev_end = inputs_ranges[-1][-1]  # type: ignore[assignment]

    inputs_loaders = [inp.make_loader() for inp in inputs]

    def inner_fn(idx):
        idx_dim = ops.index_expr(idx[dim], torch.int64)

        masks = []
        masked_loads = []
        for i in range(len(inputs)):
            start = (
                ops.constant(0, torch.int64)
                if i == 0
                else ops.index_expr(inputs_ranges[i][0], torch.int64)
            )
            end = ops.index_expr(inputs_ranges[i][1], torch.int64)

            start_cond = ops.ge(idx_dim, start)
            end_cond = ops.lt(idx_dim, end)
            # The first input only needs the upper bound and the last input only
            # the lower bound; interior inputs need both.
            if i == 0:
                mask = end_cond
            elif i == len(inputs) - 1:
                mask = start_cond
            else:
                mask = ops.and_(start_cond, end_cond)

            masks.append(mask)
            idx_load = list(idx)

            # if we're concatting [4], [2]
            # when we index the second tensor for 5 we want to index 5 - 4
            # Use Identity to prevent expansion of index * stride to keep expression
            # in same int bitwidth as shape
            idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0])

            # NOTE(review): the lambda late-binds `i` and `idx_load`; this is
            # presumably safe because ops.masked invokes the body while this
            # iteration is current -- confirm before refactoring.
            masked_loads.append(
                ops.masked(
                    mask,
                    lambda: inputs_loaders[i](idx_load),
                    0.0,  # this value should be unused
                ),
            )

        # Fold from the last input backwards so the earliest input whose mask
        # matches wins.
        next_val = masked_loads[-1]
        for i in range((len(inputs)) - 2, -1, -1):
            next_val = ops.where(
                masks[i],
                masked_loads[i],
                next_val,
            )
        return next_val

    new_size = list(inputs[0].get_size())
    new_size[dim] = inputs_ranges[-1][-1]

    return Pointwise.create(
        device=inputs[0].get_device(),
        dtype=inputs[0].get_dtype(),
        inner_fn=inner_fn,
        ranges=new_size,
    )
  1263. @register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None)
  1264. def quantized_decomposed_quantize_per_channel(
  1265. input: TensorBox,
  1266. scales: TensorBox,
  1267. zero_points: TensorBox,
  1268. axis: int,
  1269. quant_min: int,
  1270. quant_max: int,
  1271. dtype: torch.dtype,
  1272. ) -> TensorBox:
  1273. assert len(scales.get_size()) == 1, "expect scales 1 dim"
  1274. assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
  1275. if input.get_dtype() == torch.bfloat16:
  1276. input = to_dtype(input, torch.float32)
  1277. assert input.get_dtype() == torch.float32, (
  1278. f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
  1279. )
  1280. assert axis < len(input.get_size()), (
  1281. f"Expecting axis to be < {len(input.get_size())}"
  1282. )
  1283. input_loader = input.make_loader()
  1284. scales_loader = scales.make_loader()
  1285. zero_points_loader = zero_points.make_loader()
  1286. def inner_fn(idx):
  1287. channel_idx = (idx[axis],)
  1288. input = input_loader(idx)
  1289. scale = scales_loader(channel_idx)
  1290. zero_point = zero_points_loader(channel_idx)
  1291. qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
  1292. if scales.dtype != torch.float32:
  1293. scale = ops.to_dtype(scale, torch.float32)
  1294. if zero_points.dtype != torch.int32:
  1295. zero_point = ops.to_dtype(zero_point, torch.int32)
  1296. inv_scale = ops.reciprocal(scale)
  1297. val = ops.round(input * inv_scale) + zero_point
  1298. clamped = ops.maximum(qmin, ops.minimum(qmax, val))
  1299. return ops.to_dtype(clamped, dtype)
  1300. return Pointwise.create(
  1301. device=input.get_device(),
  1302. dtype=dtype,
  1303. inner_fn=inner_fn,
  1304. ranges=input.get_size(),
  1305. )
  1306. def _assert_async(cond, msg):
  1307. cond.realize()
  1308. cond = to_dtype(cond, torch.bool)
  1309. def inner_fn(index):
  1310. with ir.ComputedBuffer.force_realize():
  1311. return ops.device_assert_async(cond.make_loader()(index), msg)
  1312. assertion_op = Pointwise.create(
  1313. device=cond.get_device(),
  1314. dtype=cond.get_dtype(),
  1315. inner_fn=inner_fn,
  1316. ranges=list(cond.get_size()),
  1317. )
  1318. assertion_op.realize()
  1319. return assertion_op
  1320. @register_lowering(aten._assert_async.msg)
  1321. def lower_assert_async(cond, msg):
  1322. return _assert_async(cond, msg)
  1323. @register_lowering(aten._functional_assert_async.msg)
  1324. def lower_assert_functional_async(cond, msg):
  1325. return _assert_async(cond, msg)
  1326. @register_lowering(
  1327. quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
  1328. )
  1329. def quantized_decomposed_dequantize_per_channel(
  1330. input: TensorBox,
  1331. scales: TensorBox,
  1332. zero_points: TensorBox,
  1333. axis: int,
  1334. quant_min: int,
  1335. quant_max: int,
  1336. dtype: torch.dtype,
  1337. *,
  1338. out_dtype: Optional[torch.dtype] = None,
  1339. ) -> TensorBox:
  1340. assert len(scales.get_size()) == 1, "expect scales 1 dim"
  1341. assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
  1342. assert input.get_dtype() == dtype, (
  1343. f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
  1344. )
  1345. assert axis < len(input.get_size()), (
  1346. f"Expecting axis to be < {len(input.get_size())}"
  1347. )
  1348. if out_dtype is None:
  1349. out_dtype = torch.float32
  1350. input_loader = input.make_loader()
  1351. scales_loader = scales.make_loader()
  1352. zero_points_loader = zero_points.make_loader()
  1353. def inner_fn(idx):
  1354. channel_idx = (idx[axis],)
  1355. input = input_loader(idx)
  1356. scale = scales_loader(channel_idx)
  1357. zero_point = zero_points_loader(channel_idx)
  1358. if scales.dtype != torch.float32:
  1359. scale = ops.to_dtype(scale, torch.float32)
  1360. if zero_points.dtype != torch.float32:
  1361. zero_point = ops.to_dtype(zero_point, torch.float32)
  1362. val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
  1363. val = ops.to_dtype(val, out_dtype)
  1364. return val
  1365. return Pointwise.create(
  1366. device=input.get_device(),
  1367. dtype=out_dtype,
  1368. inner_fn=inner_fn,
  1369. ranges=input.get_size(),
  1370. )
  1371. @register_lowering(
  1372. quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
  1373. )
  1374. def quantized_decomposed_quantize_per_tensor_default(
  1375. input: TensorBox,
  1376. scale: float,
  1377. zero_point: int,
  1378. quant_min: int,
  1379. quant_max: int,
  1380. dtype: torch.dtype,
  1381. ) -> TensorBox:
  1382. if input.get_dtype() == torch.bfloat16:
  1383. input = to_dtype(input, torch.float32)
  1384. assert input.get_dtype() == torch.float32, (
  1385. f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
  1386. )
  1387. input_loader = input.make_loader()
  1388. def inner_fn(idx, scale, zero_point):
  1389. input = input_loader(idx)
  1390. inv_scale, zero_point = _create_constants(
  1391. 1.0 / scale, zero_point, dtype=torch.float32
  1392. )
  1393. val = ops.round(input * inv_scale) + zero_point
  1394. qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
  1395. clamped = ops.minimum(ops.maximum(val, qmin), qmax)
  1396. return ops.to_dtype(clamped, dtype)
  1397. return Pointwise.create(
  1398. device=input.get_device(),
  1399. dtype=dtype,
  1400. inner_fn=functools.partial(
  1401. inner_fn, scale=float(scale), zero_point=int(zero_point)
  1402. ),
  1403. ranges=input.get_size(),
  1404. )
  1405. @register_lowering(
  1406. quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
  1407. )
  1408. def quantized_decomposed_dequantize_per_tensor_default(
  1409. input: TensorBox,
  1410. scale: float,
  1411. zero_point: int,
  1412. quant_min: int,
  1413. quant_max: int,
  1414. dtype: torch.dtype,
  1415. *,
  1416. out_dtype: Optional[torch.dtype] = None,
  1417. ) -> TensorBox:
  1418. assert input.get_dtype() == dtype, (
  1419. f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
  1420. )
  1421. if out_dtype is None:
  1422. out_dtype = torch.float32
  1423. input_loader = input.make_loader()
  1424. def inner_fn(idx, scale, zero_point):
  1425. input = input_loader(idx)
  1426. scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32)
  1427. val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
  1428. val = ops.to_dtype(val, out_dtype)
  1429. return val
  1430. return Pointwise.create(
  1431. device=input.get_device(),
  1432. dtype=out_dtype,
  1433. inner_fn=functools.partial(
  1434. inner_fn, scale=float(scale), zero_point=int(zero_point)
  1435. ),
  1436. ranges=input.get_size(),
  1437. )
  1438. @register_lowering(
  1439. quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
  1440. )
  1441. def quantized_decomposed_quantize_per_tensor_tensor(
  1442. input: TensorBox,
  1443. scale: TensorBox,
  1444. zero_point: TensorBox,
  1445. quant_min: int,
  1446. quant_max: int,
  1447. dtype: torch.dtype,
  1448. ) -> TensorBox:
  1449. if input.get_dtype() == torch.bfloat16:
  1450. input = to_dtype(input, torch.float32)
  1451. assert input.get_dtype() == torch.float32, (
  1452. f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
  1453. )
  1454. assert len(scale.get_size()) == 0 or (
  1455. len(scale.get_size()) == 1 and scale.get_size()[0] == 1
  1456. ), "expect scale as scalar tensor"
  1457. assert len(zero_point.get_size()) == 0 or (
  1458. len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
  1459. ), "expect zero_point as scalar tensor"
  1460. input_loader = input.make_loader()
  1461. scale_loader = scale.make_loader()
  1462. zero_point_loader = zero_point.make_loader()
  1463. def inner_fn(idx):
  1464. input = input_loader(idx)
  1465. _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
  1466. _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
  1467. if scale.dtype != torch.float32:
  1468. _scale = ops.to_dtype(_scale, torch.float32)
  1469. if zero_point.dtype != torch.float32:
  1470. _zero_point = ops.to_dtype(_zero_point, torch.float32)
  1471. val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
  1472. qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
  1473. clamped = ops.minimum(ops.maximum(val, qmin), qmax)
  1474. return ops.to_dtype(clamped, dtype)
  1475. return Pointwise.create(
  1476. device=input.get_device(),
  1477. dtype=dtype,
  1478. inner_fn=inner_fn,
  1479. ranges=input.get_size(),
  1480. )
  1481. @register_lowering(
  1482. quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
  1483. )
  1484. def quantized_decomposed_dequantize_per_tensor_tensor(
  1485. input: TensorBox,
  1486. scale: TensorBox,
  1487. zero_point: TensorBox,
  1488. quant_min: int,
  1489. quant_max: int,
  1490. dtype: torch.dtype,
  1491. *,
  1492. out_dtype: Optional[torch.dtype] = None,
  1493. ) -> TensorBox:
  1494. assert len(scale.get_size()) == 0 or (
  1495. len(scale.get_size()) == 1 and scale.get_size()[0] == 1
  1496. ), "expect scale as scalar tensor"
  1497. assert len(zero_point.get_size()) == 0 or (
  1498. len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
  1499. ), "expect zero_point as scalar tensor"
  1500. assert input.get_dtype() == dtype, (
  1501. f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
  1502. )
  1503. if out_dtype is None:
  1504. out_dtype = torch.float32
  1505. input_loader = input.make_loader()
  1506. scale_loader = scale.make_loader()
  1507. zero_point_loader = zero_point.make_loader()
  1508. def inner_fn(idx):
  1509. input = input_loader(idx)
  1510. _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
  1511. _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
  1512. if scale.dtype != torch.float32:
  1513. _scale = ops.to_dtype(_scale, torch.float32)
  1514. if zero_point.dtype != torch.float32:
  1515. _zero_point = ops.to_dtype(_zero_point, torch.float32)
  1516. val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale
  1517. val = ops.to_dtype(val, out_dtype)
  1518. return val
  1519. return Pointwise.create(
  1520. device=input.get_device(),
  1521. dtype=out_dtype,
  1522. inner_fn=inner_fn,
  1523. ranges=input.get_size(),
  1524. )
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    """Lower ``aten.cat``.

    int8/uint8 inputs on CPU go to the eager fallback. A single input is
    cloned. Otherwise inputs are promoted to a common dtype, and either
    ``pointwise_cat`` (a fusable Pointwise node) or an ``ir.ConcatKernel`` is
    emitted, based on device, input count/complexity, and how the result is used.
    """
    cpu_device = inputs[0].get_device().type == "cpu"
    if cpu_device and all(
        input.get_dtype() in [torch.int8, torch.uint8] for input in inputs
    ):
        # TODO <leslie> Remove this fallback when we support vectorization
        # code gen with uint8 data type directly.
        for input in inputs:
            input.realize()
        if all(len(input.get_size()) == 4 for input in inputs):
            inputs, _ = require_channels_last(aten.cat, *inputs)
        return fallback_handler(aten.cat.default)(inputs, dim)

    if len(inputs) == 1:
        return clone(inputs[0])

    dim = _validate_dim(inputs[0], dim, 0)
    dtype = get_promoted_dtype(
        *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    inputs = [to_dtype(inp, dtype) for inp in inputs]

    # Strip TensorBox/StorageBox wrappers (and one view level) to reach the IR node.
    def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode:
        if isinstance(x, TensorBox):
            if isinstance(x.data, ir.BaseView):
                return x.data.unwrap_view()
            else:
                return x.data

        if isinstance(x, ir.StorageBox):
            return x.data

        return x

    def is_reduction(t):
        return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction)

    # True if t is a reduction, or a Pointwise that (transitively) reads one.
    def can_fuse_reduction(t):
        if isinstance(t, (TensorBox, ir.StorageBox)):
            return can_fuse_reduction(unwrap_tensor(t))
        return (
            is_reduction(t)
            or isinstance(t, ir.Pointwise)
            and any(
                can_fuse_reduction(V.graph.get_buffer(read))
                for read in t.get_read_names()
            )
        )

    # fusing reductions into computed concat buffer can cause regressions.
    fusable_reduction = any(can_fuse_reduction(t) for t in inputs)

    def should_lower_cat_input(x) -> bool:
        # Unrealized inputs will not be storage and layouts, and we don't want to realize
        # them in case we want to fuse
        if ir.is_storage_and_layout(x):
            storage, _ = ir.as_storage_and_layout(x, freeze=False)
            return not ir.ConcatKernel.can_realize_into_without_copy(storage)

        if isinstance(x, (TensorBox, ir.StorageBox)):
            return should_lower_cat_input(unwrap_tensor(x))

        if isinstance(x, ir.Pointwise):
            return True

        return False

    if config.force_pointwise_cat:
        return pointwise_cat(inputs, dim)

    # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it.
    # We will revisit this later after enabling vectorization on index_expr.
    if cpu_device:
        return TensorBox(ir.ConcatKernel.create(inputs, dim))

    # Number of inner-function ops in x plus, recursively, in the Pointwise
    # nodes it reads; non-Pointwise nodes count as 0.
    def op_count(x):
        if isinstance(x, (TensorBox, ir.StorageBox)):
            return op_count(unwrap_tensor(x))

        # this will correspond to a direct memory read
        if not isinstance(x, ir.Pointwise):
            return 0

        count = x.inner_fn_opcount().num_ops
        for read in x.get_read_names():
            count += op_count(V.graph.get_buffer(read))

        return count

    # as the number of inputs increases, the possibility of register spilling also increases;
    # past a certain threshold of inputs we only fuse if the input kernels
    # are simple
    # not sure if we want to expose to users via config since logic may change in future
    MAX_COMPLEX_POINTWISE_CAT = 8
    MAX_SIMPLE_OP_COUNT = 2

    def additional_pointwise_ops(op: torch._ops.OpOverload):
        return op in (aten.cat.default, aten.constant_pad_nd.default)

    if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
        (len(inputs) <= config.max_pointwise_cat_inputs)
        and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
    ):
        pointwise_uses = all(
            is_pointwise_use(use, additional_pointwise_ops)
            for use in V.current_node.users
        )
        # fuse in case we will be used in a pointwise node, and there are any inputs we
        # can prevent materialization of.
        fuse_pointwise_use = (
            any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses
        )

        # horizontal fuse in case all inputs will require a copy kernel anyway.
        # only horizontally fuse pointwise kernels
        # NOTE(review): the `not any(can_fuse_reduction(...))` term recomputes
        # `fusable_reduction`, which is also checked below; presumably redundant.
        horizontal_fuse_cat = all(
            should_lower_cat_input(inp) for inp in inputs
        ) and not any(can_fuse_reduction(t) for t in inputs)
        if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction):
            return pointwise_cat(inputs, dim)

    return TensorBox(ir.ConcatKernel.create(inputs, dim))
  1625. @register_lowering(aten.diagonal, type_promotion_kind=None)
  1626. def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
  1627. original_shape = input.get_size()
  1628. num_dims = len(original_shape)
  1629. dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
  1630. dim2 = canonicalize_dim(idx=dim2, rank=num_dims)
  1631. check(
  1632. dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
  1633. )
  1634. offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0))
  1635. if offset_negative:
  1636. diag_size = V.graph.sizevars.evaluate_max(
  1637. V.graph.sizevars.evaluate_min(
  1638. original_shape[dim1] + offset, original_shape[dim2]
  1639. ),
  1640. 0, # type: ignore[arg-type]
  1641. )
  1642. else:
  1643. diag_size = V.graph.sizevars.evaluate_max(
  1644. V.graph.sizevars.evaluate_min(
  1645. original_shape[dim1], original_shape[dim2] - offset
  1646. ),
  1647. 0, # type: ignore[arg-type]
  1648. )
  1649. base_idx = (0, 0)
  1650. if offset_negative:
  1651. base_idx = (-offset, 0)
  1652. else:
  1653. base_idx = (0, offset)
  1654. sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)]
  1655. sizes.append(diag_size)
  1656. def reindexer(idx):
  1657. diag_idx = idx[-1]
  1658. original_idx = [0] * len(original_shape)
  1659. cur_dim = 0
  1660. for d in range(num_dims):
  1661. if d == dim1:
  1662. original_idx[d] = diag_idx + base_idx[0]
  1663. elif d == dim2:
  1664. original_idx[d] = diag_idx + base_idx[1]
  1665. else:
  1666. original_idx[d] = idx[cur_dim]
  1667. cur_dim += 1
  1668. assert cur_dim == len(original_shape) - 2
  1669. return original_idx
  1670. return TensorBox(ir.GenericView.create(input, sizes, reindexer))
  1671. @register_lowering(aten.diagonal_copy, type_promotion_kind=None)
  1672. def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
  1673. return clone(diagonal(input, offset, dim1, dim2))
  1674. @register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
  1675. def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
  1676. output = clone(input)
  1677. target = diagonal(output, offset, dim1, dim2)
  1678. mutate_to(target, src)
  1679. return output
  1680. @register_lowering(aten.select, type_promotion_kind=None)
  1681. def select(x, dim, idx):
  1682. idx = sympy.expand(idx)
  1683. size = sympy.expand(x.get_size()[dim])
  1684. actual_index = None
  1685. if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
  1686. actual_index = idx + size
  1687. elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
  1688. actual_index = idx
  1689. if actual_index is not None:
  1690. if has_free_unbacked_symbols(idx):
  1691. # Inductor could generate incorrect views for tensors with unbacked symbols here;
  1692. # Squeeze operations are translated to views, resulting in incorrect strides.
  1693. # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
  1694. # we use as_strided instead.
  1695. # Removing this branch will cause test_unbacked_select_index_with_check to fail.
  1696. # before accessing size, stride, and offset we need to realize.
  1697. x.realize()
  1698. new_size = x.get_size()
  1699. new_stride = x.get_stride()
  1700. new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
  1701. del new_size[dim]
  1702. del new_stride[dim]
  1703. return as_strided(x, new_size, new_stride, new_storage_offset)
  1704. else:
  1705. # no need to clamp, this function handles negative indexing itself
  1706. slice_result = slice_(x, dim, actual_index, actual_index + 1, clamp=False)
  1707. return squeeze(slice_result, dim)
  1708. # Unbacked Semantics:
  1709. # When the index idx is unbacked (e.g., u0), we compute the index dynamically
  1710. # during the lowering of the select operation using DynamicSelectStorageOffset.
  1711. unbacked_bindings = resolve_unbacked_bindings(
  1712. V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
  1713. )
  1714. assert unbacked_bindings is not None
  1715. assert len(unbacked_bindings) == 1, unbacked_bindings
  1716. unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
  1717. # before accessing size, stride, and offset we need to realize.
  1718. x.realize()
  1719. new_size = x.get_size()
  1720. new_stride = x.get_stride()
  1721. new_storage_offset = unbacked_offset_sym
  1722. buffer = ir.DynamicSelectStorageOffset(
  1723. unbacked_offset_sym,
  1724. idx,
  1725. x.get_layout().offset,
  1726. new_stride[dim],
  1727. x.get_size()[dim],
  1728. clamp=False,
  1729. )
  1730. buffer.name = V.graph.register_buffer(buffer)
  1731. V.graph.register_operation(buffer)
  1732. del new_size[dim]
  1733. del new_stride[dim]
  1734. return as_strided(x, new_size, new_stride, new_storage_offset)
  1735. @register_lowering(aten.split, type_promotion_kind=None)
  1736. def split(x, sizes, dim=0):
  1737. dim = _validate_dim(x, dim, 0)
  1738. sizes_ = sizes
  1739. # If sizes is an integer (or a SymInt), we turn it into a list of sizes
  1740. # by computing what the actual size of each chunk should be.
  1741. if not isinstance(sizes, (list, tuple)):
  1742. x_size = x.get_size()[dim]
  1743. chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
  1744. sizes_ = [sizes] * chunks
  1745. # The last chunk might have a smaller size than the rest.
  1746. sizes_[-1] = x_size - (chunks - 1) * sizes
  1747. # From this point, we assume that the sum of the sizes of all chunks
  1748. # equals the size of the base tensor.
  1749. result = []
  1750. start = 0
  1751. for size in sizes_:
  1752. end = start + size
  1753. # No need for clamping here, since we compute the exact
  1754. # start and end values.
  1755. result.append(slice_(x, dim, start, end, clamp=False))
  1756. start = end
  1757. return result
  1758. @register_lowering(aten.split_with_sizes, type_promotion_kind=None)
  1759. def split_with_sizes(x, sizes, dim=0):
  1760. return split(x, sizes, dim)
  1761. @register_lowering(aten.unbind, type_promotion_kind=None)
  1762. def unbind(x, dim=0):
  1763. dim = _validate_dim(x, dim, 0)
  1764. x_size = V.graph.sizevars.guard_int(x.get_size()[dim])
  1765. result = [select(x, dim, i) for i in range(x_size)]
  1766. return result
  1767. @register_lowering(aten.unfold, type_promotion_kind=None)
  1768. def unfold(x, dimension, size, step):
  1769. sizes = x.get_size()
  1770. ndim = len(sizes)
  1771. dim = canonicalize_dim(ndim, dimension)
  1772. if ndim == 0:
  1773. return slice_(unsqueeze(x, 0), end=size, clamp=False)
  1774. dim_size = sizes[dim]
  1775. sizevars = V.graph.sizevars
  1776. sizevars.check_leq(size, dim_size)
  1777. sizevars.check_lt(0, step) # type: ignore[arg-type]
  1778. new_dim_size = FloorDiv(dim_size - size, step) + 1
  1779. if sizevars.size_hint_or_throw(dim_size) > 0:
  1780. x.mark_reuse(
  1781. sizevars.size_hint_or_throw(CeilDiv(new_dim_size * size, dim_size))
  1782. )
  1783. out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size]
  1784. def reindexer(idx):
  1785. dim_idx = idx[-1] + idx[dim] * step
  1786. return (*idx[:dim], dim_idx, *idx[dim + 1 : -1])
  1787. return TensorBox(ir.GenericView.create(x, out_size, reindexer))
  1788. @register_lowering(aten.unsqueeze, type_promotion_kind=None)
  1789. def unsqueeze(x, dim):
  1790. dim = _validate_dim(x, dim, 1)
  1791. new_shape = list(x.get_size())
  1792. new_shape.insert(dim, sympy.S.One)
  1793. return view(x, new_shape)
  1794. @register_lowering(aten.unsqueeze_, type_promotion_kind=None)
  1795. def unsqueeze_(x, dim):
  1796. val = unsqueeze(x, dim)
  1797. assert isinstance(x, TensorBox)
  1798. assert isinstance(val, TensorBox)
  1799. x.data = val.data
  1800. return x
  1801. def _validate_dim(x, dim, offset=0):
  1802. dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim))
  1803. ndim = len(x.get_size())
  1804. if dim < 0:
  1805. dim += ndim + offset
  1806. assert 0 <= dim < ndim + offset
  1807. return dim
  1808. @register_lowering(aten.glu)
  1809. def glu(x, dim=-1):
  1810. dim = _validate_dim(x, dim, 0)
  1811. # TODO: don't guard on static shape here
  1812. new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
  1813. # no need to clamp, index is int based on input size
  1814. a = slice_(x, dim, 0, new_len, clamp=False)
  1815. b = slice_(x, dim, new_len, new_len * 2, clamp=False)
  1816. return mul(a, sigmoid(b))
  1817. def fallback_handler(kernel, add_to_fallback_set=True):
  1818. if add_to_fallback_set:
  1819. fallbacks.add(kernel)
  1820. def handler(*args, **kwargs):
  1821. def wrap_tensors(x):
  1822. return TensorBox.create(x) if isinstance(x, ir.IRNode) else x
  1823. return pytree.tree_map(
  1824. wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs)
  1825. )
  1826. # This lets us detect that a lowering is a fallback handler.
  1827. handler._is_fallback_handler = True # type: ignore[attr-defined]
  1828. return handler
  1829. @functools.cache
  1830. def _warn_complex_not_supported():
  1831. warnings.warn(
  1832. "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."
  1833. )
  1834. # There are some types (CPU) which we accept as input but not as
  1835. # output.
  1836. def unsupported_input_tensor(t: torch.Tensor, node=None):
  1837. "Do not support reading or writing to this tensor"
  1838. if t.is_complex():
  1839. # Complex views are supported with IR ComplexView
  1840. _warn_complex_not_supported()
  1841. return True
  1842. if t.is_meta:
  1843. return True
  1844. if t.is_sparse:
  1845. return True
  1846. if t.dtype == torch.float8_e8m0fnu:
  1847. if not node:
  1848. return True
  1849. # allow bitcast, views, memory movement, but not arithmetic
  1850. # TODO: delete once triton adds native support
  1851. return not (
  1852. isinstance(node.target, torch._ops.OpOverload)
  1853. and node.target
  1854. in (
  1855. aten.view.dtype,
  1856. aten.cat.default,
  1857. aten.clone.default,
  1858. aten._scaled_mm.default,
  1859. )
  1860. or (isinstance(node.target, torch._ops.OpOverload) and is_view(node.target))
  1861. )
  1862. return False
  1863. def unsupported_output_tensor(t: torch.Tensor, node=None):
  1864. "Do not support writing tensor but can read from it"
  1865. supported_complex_views = (
  1866. aten.view.dtype,
  1867. torch.ops.prims.convert_element_type.default,
  1868. )
  1869. if node is not None and node.target in supported_complex_views and t.is_complex():
  1870. return False
  1871. if unsupported_input_tensor(t, node):
  1872. return True
  1873. return t.is_cpu and config.disable_cpp_codegen
  1874. def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True):
  1875. # Custom fallback lowering
  1876. if node.target is aten.view_as_complex.default:
  1877. return False
  1878. if node.op == "placeholder":
  1879. return False
  1880. # We should be able to remove this special case once `disable_cpp_codegen` is killed.
  1881. if node.target is aten.lift_fresh_copy.default:
  1882. return False
  1883. def check_skip_condition(inp_out_node, is_output):
  1884. if not isinstance(inp_out_node, torch.fx.Node):
  1885. return False
  1886. if "val" not in inp_out_node.meta:
  1887. return False
  1888. for meta in pytree.tree_leaves(inp_out_node.meta["val"]):
  1889. if not isinstance(meta, torch._subclasses.FakeTensor):
  1890. continue
  1891. if is_output:
  1892. if unsupported_output_tensor(meta, node):
  1893. return True
  1894. else:
  1895. if unsupported_input_tensor(meta, node):
  1896. return True
  1897. return False
  1898. # only skip codegen if there is a cpu output, not input
  1899. for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs):
  1900. if check_skip_condition(arg, is_output=False):
  1901. return True
  1902. return check_skip_condition(node, is_output=True)
  1903. def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
  1904. # When emulate_precision_casts is enabled, we skip decomposing addcmul ops
  1905. # to use the inductor lowering which preserves FMA semantics.
  1906. # For _foreach_addcdiv, we use the native CUDA kernel.
  1907. skip_decomp_for_precision = config.emulate_precision_casts and op in {
  1908. aten.addcmul,
  1909. aten._foreach_addcmul.Scalar,
  1910. aten._foreach_addcdiv.Scalar,
  1911. }
  1912. assert op not in decompositions or override_decomp or skip_decomp_for_precision, (
  1913. f"both a fallback and a decomp for same op: {op}"
  1914. )
  1915. if (
  1916. warn
  1917. and bool(os.getenv("CI"))
  1918. and get_decompositions([op])
  1919. # if fallback_random, we allow not decomposing random
  1920. and not (
  1921. config.fallback_random
  1922. and op in torch._decomp.decompositions_for_rng.extra_random_decomps
  1923. )
  1924. and not override_decomp
  1925. ):
  1926. # Note: 'warn' is holdover from when this was a warning, but for ops that previously
  1927. # set warn=False we do not want a CI error.
  1928. # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
  1929. # likely to be triggered preferentially on one CI config over another.
  1930. if torch._dynamo.config.suppress_errors:
  1931. torch._dynamo.config.suppress_errors = False
  1932. log.warning(
  1933. "A make_fallback error occurred in suppress_errors config,"
  1934. " and suppress_errors is being disabled to surface it."
  1935. )
  1936. raise AssertionError(
  1937. f"make_fallback({op}): a decomposition exists, we should switch to it."
  1938. " To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
  1939. " or inductor_decompositions, and delete the corresponding `make_fallback` line."
  1940. " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
  1941. )
  1942. def register_fallback(op_overload):
  1943. add_needs_realized_inputs(op_overload)
  1944. if layout_constraint is not None:
  1945. add_layout_constraint(op_overload, layout_constraint)
  1946. return register_lowering(op_overload, type_promotion_kind=None)(
  1947. fallback_handler(op_overload)
  1948. )
  1949. if isinstance(op, torch._ops.OpOverloadPacket):
  1950. for ol in op.overloads():
  1951. op_overload = getattr(op, ol)
  1952. register_fallback(op_overload)
  1953. elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
  1954. register_fallback(op)
  1955. else:
  1956. raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
  1957. def philox_rand_offset(shape):
  1958. """
  1959. TorchInductor offset calculation differs from PyTorch eager offset
  1960. calculation for random ops (tl.rand vs torch.rand). In future, we should
  1961. strive for same impl for tl.rand and torch.rand.
  1962. """
  1963. numel = 1
  1964. for s in shape:
  1965. numel = numel * s
  1966. return tensor(numel, dtype=torch.int64)
  1967. @register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None)
  1968. def philox_rand(size, seed, offset, stride, device, dtype):
  1969. # stride arg is optional and will be used in future for distributed random
  1970. # ops. Currently, its unused.
  1971. random_pos = ir.FixedLayout(
  1972. device,
  1973. dtype,
  1974. size,
  1975. ir.FlexibleLayout.contiguous_strides(size),
  1976. ).make_indexer()
  1977. seed_loader = seed.make_loader()
  1978. offset_loader = offset.make_loader()
  1979. def inner_fn(index):
  1980. # Both seed and offset in the philox_rand op are tensors.
  1981. # torch seed and offsets are of type int64, but tl.rand accepts int32
  1982. seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32)
  1983. offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32)
  1984. # Get the offset'd position
  1985. rand_index_expr = ops.add(
  1986. ops.index_expr(random_pos(index), torch.int32), offset_index_expr
  1987. )
  1988. result = ops.rand(
  1989. seed_index_expr,
  1990. rand_index_expr,
  1991. )
  1992. return ops.to_dtype(result, dtype)
  1993. random_values_node = Pointwise.create(
  1994. device=device,
  1995. dtype=dtype,
  1996. inner_fn=inner_fn,
  1997. ranges=list(size),
  1998. )
  1999. offset_node = philox_rand_offset(size)
  2000. return random_values_node, offset_node
  2001. @register_lowering(aten.native_dropout, type_promotion_kind=None)
  2002. def native_dropout(x, p, train):
  2003. if config.fallback_random:
  2004. return pytree.tree_map(
  2005. TensorBox.create,
  2006. ir.FallbackKernel.create(aten.native_dropout.default, x, p, train),
  2007. )
  2008. else:
  2009. raise AssertionError("should be handled in replace_random.py")
  2010. @register_lowering(aten.bernoulli_, type_promotion_kind=None)
  2011. def bernoulli_(x, *args):
  2012. assert config.fallback_random or x.get_device() == torch.device("cpu"), (
  2013. "this should be handled in decomps unless config.fallback_random or the device is CPU"
  2014. )
  2015. x.realize()
  2016. op_overload = (
  2017. aten.bernoulli_.float
  2018. if len(args) == 0 or isinstance(args[0], float)
  2019. else aten.bernoulli_.Tensor
  2020. )
  2021. ir.InplaceBernoulliFallback(op_overload, x, *args)
  2022. return x
  2023. @register_lowering(aten.bernoulli.p, type_promotion_kind=None)
  2024. def bernoulli_p(x, *args):
  2025. assert config.fallback_random or x.get_device() == torch.device("cpu"), (
  2026. "this should be handled in decomps unless config.fallback_random or the device is CPU"
  2027. )
  2028. return bernoulli_(clone(x), *args)
  2029. # This shouldn't be called in general
  2030. @register_lowering(aten._foobar)
  2031. def _foobar(_):
  2032. raise AssertionError
  2033. @functools.lru_cache(1)
  2034. def _warn_triton_random(salt):
  2035. log.info("using triton random, expect difference from eager")
  2036. def warn_triton_random():
  2037. # only warn once per graph
  2038. _warn_triton_random(V.graph.creation_time)
  2039. fallback_rand_default = fallback_handler(aten.rand.default)
  2040. fallback_rand_generator = fallback_handler(aten.rand.generator)
  2041. fallback_randn_default = fallback_handler(aten.randn.default)
  2042. fallback_randn_generator = fallback_handler(aten.randn.generator)
  2043. make_fallback(aten.randint)
  2044. # TODO: mlazos reevaluate if we want to codegen something different
  2045. make_fallback(torch.ops.streams.record_event.default)
  2046. make_fallback(torch.ops.streams.wait_event.default)
  2047. @register_lowering(aten.rand)
  2048. def rand(*args, **kwargs):
  2049. if kwargs.get("generator") is not None:
  2050. return fallback_rand_generator(*args, **kwargs)
  2051. elif config.fallback_random:
  2052. kwargs.pop("generator", None)
  2053. return fallback_rand_default(*args, **kwargs)
  2054. raise AssertionError("should have been handled in replace_random.py")
  2055. @register_lowering(aten.randn)
  2056. def randn(*args, **kwargs):
  2057. if kwargs.get("generator") is not None:
  2058. return fallback_randn_generator(*args, **kwargs)
  2059. elif config.fallback_random:
  2060. kwargs.pop("generator", None)
  2061. return fallback_randn_default(*args, **kwargs)
  2062. raise AssertionError("should have been handled in replace_random.py")
  2063. @register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None)
  2064. def inductor_force_stride_order(input_tensor, stride):
  2065. stride_order = ir.get_stride_order(stride)
  2066. return ir.ExternKernel.require_stride_order(input_tensor, stride_order)
  2067. @register_lowering(inductor_prims.seed, type_promotion_kind=None)
  2068. def inductor_seed(device: torch.device):
  2069. raise AssertionError("should be handled in fuse_seed_creation_pass()")
  2070. @register_lowering(inductor_prims.seeds, type_promotion_kind=None)
  2071. def inductor_seeds(count, device):
  2072. warn_triton_random()
  2073. return TensorBox.create(ir.RandomSeeds(count, decode_device(device)))
  2074. @register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None)
  2075. def inductor_lookup_seed(seeds, index):
  2076. def inner_fn(_):
  2077. return ops.load_seed(seeds.get_name(), index)
  2078. return Pointwise.create(
  2079. device=seeds.get_device(),
  2080. dtype=seeds.get_dtype(),
  2081. inner_fn=inner_fn,
  2082. ranges=[],
  2083. )
  2084. @register_lowering(inductor_prims.random, type_promotion_kind=None)
  2085. def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0):
  2086. assert not config.fallback_random
  2087. assert mode in ("rand", "randn")
  2088. size = [*size]
  2089. dtype = torch.float32
  2090. device = seed.get_device_or_error()
  2091. random_pos = ir.FixedLayout(
  2092. device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
  2093. ).make_indexer()
  2094. seed_loader = seed.make_loader()
  2095. def inner_fn(index):
  2096. return getattr(ops, mode)(
  2097. seed_loader([]),
  2098. ops.index_expr(random_pos(index), torch.int32),
  2099. )
  2100. result = Pointwise.create(
  2101. device=device,
  2102. dtype=dtype,
  2103. inner_fn=inner_fn,
  2104. ranges=[*size],
  2105. )
  2106. result.realize()
  2107. return result
  2108. @register_lowering(inductor_prims.randint, type_promotion_kind=None)
  2109. def inductor_randint(
  2110. low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0
  2111. ):
  2112. assert not config.fallback_random
  2113. size = [*size]
  2114. dtype = torch.int64
  2115. device = seed.get_device_or_error()
  2116. random_pos = ir.FixedLayout(
  2117. device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
  2118. ).make_indexer()
  2119. seed_loader = seed.make_loader()
  2120. def inner_fn(index):
  2121. return ops.randint64(
  2122. seed_loader([]),
  2123. ops.index_expr(random_pos(index), torch.int32),
  2124. ops.index_expr(low, torch.int64),
  2125. ops.index_expr(high, torch.int64),
  2126. )
  2127. return Pointwise.create(
  2128. device=device,
  2129. dtype=dtype,
  2130. inner_fn=inner_fn,
  2131. ranges=[*size],
  2132. )
  2133. def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
  2134. # Calculate the maximum offset for the boundaries tensor
  2135. # For a strided tensor, this is sum((size[i] - 1) * stride[i]) + stride[-1]
  2136. # This ensures the mask check in bucketize_binary_search works correctly
  2137. # for both contiguous and non-contiguous tensors.
  2138. size = tb.get_size()
  2139. stride = tb.get_stride()
  2140. max_offset = sum((s - 1) * st for s, st in zip(size, stride)) + stride[-1]
  2141. return (
  2142. tb.get_name(),
  2143. size[-1],
  2144. max_offset,
  2145. stride[-1],
  2146. )
  2147. def _sorter_helper(tb: TensorBox) -> tuple[str, sympy.Expr]:
  2148. return tb.get_name(), tb.get_stride()[-1]
  2149. @register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None)
  2150. def searchsorted(
  2151. sorted_sequence: TensorBox,
  2152. self: TensorBox,
  2153. *,
  2154. out_int32: bool = False,
  2155. right: bool = False,
  2156. side: Optional[str] = None,
  2157. sorter: Optional[TensorBox] = None,
  2158. ) -> TensorBox:
  2159. validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731
  2160. tb, BackendFeature.BUCKETIZE
  2161. )
  2162. if (
  2163. not validate_bucketize(sorted_sequence)
  2164. or not validate_bucketize(self)
  2165. or (sorter is not None and not validate_bucketize(sorter))
  2166. ):
  2167. return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)(
  2168. sorted_sequence,
  2169. self,
  2170. out_int32=out_int32,
  2171. right=right,
  2172. side=side,
  2173. sorter=sorter,
  2174. )
  2175. # If side is present, override the value of right if needed. This assumes that
  2176. # validation of the two options being non-contradictory is already done by the
  2177. # searchsorted meta-function.
  2178. if side is not None and side == "right":
  2179. right = True
  2180. index_dtype = torch.int32 if out_int32 else torch.int64
  2181. values_loader = self.make_loader()
  2182. # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to
  2183. # realize it into global memory; or in other words, we can't guarantee that
  2184. # sorted_sequence.get_name() (used below) will exist unless we call
  2185. # sorted_sequence.realize().
  2186. sorted_sequence.realize()
  2187. if sorter is not None:
  2188. sorter.realize()
  2189. if len(sorted_sequence.get_size()) == 1:
  2190. def inner_fn(idx):
  2191. val = values_loader(idx)
  2192. return ops.bucketize(
  2193. val,
  2194. _boundaries_helper(sorted_sequence),
  2195. 0,
  2196. index_dtype,
  2197. right,
  2198. sorter=None if sorter is None else _sorter_helper(sorter),
  2199. sorter_indices=None if sorter is None else 0,
  2200. )
  2201. else:
  2202. def inner_fn(idx):
  2203. val = values_loader(idx)
  2204. # Get index to the beginning of the sorted sequence within a flattened
  2205. # version of the array.
  2206. def get_flattened_index(tb: TensorBox):
  2207. strides = tb.get_stride()
  2208. return ops.index_expr(
  2209. functools.reduce(
  2210. operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1]))
  2211. ),
  2212. index_dtype,
  2213. )
  2214. return ops.bucketize(
  2215. val,
  2216. _boundaries_helper(sorted_sequence),
  2217. get_flattened_index(sorted_sequence),
  2218. index_dtype,
  2219. right,
  2220. sorter=None if sorter is None else _sorter_helper(sorter),
  2221. sorter_indices=None if sorter is None else get_flattened_index(sorter),
  2222. )
  2223. device = self.get_device()
  2224. result = Pointwise.create(
  2225. device=device,
  2226. dtype=index_dtype,
  2227. inner_fn=inner_fn,
  2228. ranges=self.shape,
  2229. )
  2230. # see [NOTE: inductor bucketize realize]
  2231. result.realize()
  2232. return result
  2233. @register_lowering(
  2234. aten.bucketize.Tensor, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
  2235. )
  2236. def bucketize(
  2237. input: TensorBox,
  2238. boundaries: TensorBox,
  2239. *,
  2240. out_int32: bool = False,
  2241. right: bool = False,
  2242. ):
  2243. assert len(boundaries.get_size()) == 1
  2244. if not (
  2245. V.graph.has_feature(input, BackendFeature.BUCKETIZE)
  2246. and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE)
  2247. ):
  2248. return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)(
  2249. input, boundaries, out_int32=out_int32, right=right
  2250. )
  2251. # The entire boundaries tensor needs to be used by ops.bucketize, so we
  2252. # need to realize it into global memory; or in other words, we can't
  2253. # guarantee that boundaries.get_name() (used below) will exist unless
  2254. # we call boundaries.realize().
  2255. boundaries.realize()
  2256. device = input.get_device()
  2257. input_loader = input.make_loader()
  2258. index_dtype = torch.int32 if out_int32 else torch.int64
  2259. def inner_fn(index):
  2260. val = input_loader(index)
  2261. indices = ops.bucketize(
  2262. val,
  2263. _boundaries_helper(boundaries),
  2264. 0,
  2265. index_dtype,
  2266. right,
  2267. )
  2268. return indices
  2269. result = Pointwise.create(
  2270. device=device,
  2271. dtype=index_dtype,
  2272. inner_fn=inner_fn,
  2273. ranges=input.get_size(),
  2274. )
  2275. # [NOTE: inductor bucketize realize]
  2276. # bucketize_binary_search is relatively expensive, so we don't want to re-compute
  2277. # it unnecessarily. If we run bucketize() and then broadcast the result, we don't
  2278. # want this to be fused into a large number of duplicate bucketize() computations
  2279. # for each of the elements in the result.
  2280. #
  2281. # If no broadcasting occurs, fusions can still occur in scheduler.py
  2282. result.realize()
  2283. return result
  2284. def require_dense(_, *args, **kwargs):
  2285. args, kwargs = pytree.tree_map_only(
  2286. ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
  2287. )
  2288. return args, kwargs
  2289. def require_contiguous(_, *args, **kwargs):
  2290. args, kwargs = pytree.tree_map_only(
  2291. ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
  2292. )
  2293. return args, kwargs
  2294. def require_contiguous_strides(_, *args, **kwargs):
  2295. # TODO: combine this with require_contiguous after
  2296. # https://github.com/pytorch/pytorch/pull/148235 lands.
  2297. args, kwargs = pytree.tree_map_only(
  2298. ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
  2299. )
  2300. return args, kwargs
  2301. def require_channels_last(_, *args, **kwargs):
  2302. args, kwargs = pytree.tree_map_only(
  2303. ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
  2304. )
  2305. return args, kwargs
  2306. def constrain_to_fake_tensor(arg, fake_arg):
  2307. if fake_arg is None:
  2308. return arg
  2309. if isinstance(fake_arg, FakeScriptObject):
  2310. return arg
  2311. if isinstance(arg, ir.IRNode):
  2312. return ir.ExternKernel.require_exact_strides(arg, fake_arg.stride())
  2313. if isinstance(arg, dict):
  2314. return {key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg}
  2315. elif isinstance(arg, (tuple, list)):
  2316. return type(arg)(
  2317. constrain_to_fake_tensor(a, f_a) for (a, f_a) in zip(arg, fake_arg)
  2318. )
  2319. return arg
  2320. def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
  2321. args = tuple(
  2322. constrain_to_fake_tensor(arg, fake_arg)
  2323. for arg, fake_arg in zip(args, fake_args)
  2324. )
  2325. kwargs = {k: constrain_to_fake_tensor(v, fake_kwargs[k]) for k, v in kwargs.items()}
  2326. return args, kwargs
  2327. def constrain_to_fx_strides(fx_node, *args, **kwargs):
  2328. def apply_constraint(arg, fx_arg):
  2329. if isinstance(arg, ir.IRNode):
  2330. stride_order = ir.get_stride_order(
  2331. fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env
  2332. )
  2333. return ir.ExternKernel.require_stride_order(arg, stride_order)
  2334. if isinstance(arg, dict):
  2335. return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
  2336. return arg
  2337. args = tuple(
  2338. apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
  2339. )
  2340. kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
  2341. return args, kwargs
  2342. def sdpa_constraint(fx_node, *args, **kwargs):
  2343. # sdpa requires dense last dimension]
  2344. def apply_constraint(idx, arg, fx_arg):
  2345. if not isinstance(arg, ir.IRNode):
  2346. return arg
  2347. meta_val = fx_arg.meta["val"]
  2348. meta_stride_expr = [
  2349. s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride()
  2350. ]
  2351. shape_env = V.graph.sizevars.shape_env
  2352. stride_order = ir.get_stride_order(meta_val.stride(), shape_env)
  2353. if stride_order and stride_order[-1] != 0:
  2354. # contiguous stride order
  2355. stride_order = list(reversed(range(len(arg.get_size()))))
  2356. if (
  2357. fx_node.target
  2358. == aten._scaled_dot_product_efficient_attention_backward.default
  2359. and idx in (0, 5)
  2360. ):
  2361. assert len(stride_order) == 4
  2362. # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default
  2363. # are for out and gradient_out. They have to be in
  2364. # (3, 1, 2, 0) stride order. Otherwise the kernel will crash.
  2365. # Check https://github.com/pytorch/pytorch/issues/138772
  2366. stride_order = (3, 1, 2, 0)
  2367. if not meta_val.is_cuda:
  2368. return ir.ExternKernel.require_stride_order(arg, stride_order)
  2369. # This is the minimum alignment required by SDPA kernels for attention_bias.
  2370. # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
  2371. ALIGNMENT = 8
  2372. # effn_attn_fwd does requires dense last dim, not just alignment
  2373. effn_attn_fwd_bias = (
  2374. fx_node.target
  2375. == torch.ops.aten._scaled_dot_product_efficient_attention.default
  2376. and idx == 3
  2377. )
  2378. assert isinstance(arg, TensorBox)
  2379. if len(arg.get_size()) not in (3, 4):
  2380. return arg
  2381. is_aligned_tensor = ir.is_aligned_realized_tensor(arg, ALIGNMENT)
  2382. if is_aligned_tensor:
  2383. return ir.try_match_insignificant_strides(
  2384. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2385. )
  2386. if (
  2387. isinstance(arg, IRNode)
  2388. and arg.maybe_get_stride() is not None
  2389. and is_aligned_tensor
  2390. ):
  2391. return ir.try_match_insignificant_strides(
  2392. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2393. )
  2394. if effn_attn_fwd_bias:
  2395. out_size = list(arg.get_size())
  2396. expanded_dims = []
  2397. # We require a dense last dimension, but the other strides
  2398. # can be expanded, which results in a smaller tensor
  2399. maybe_stride = arg.maybe_get_stride()
  2400. for i in range(len(arg.get_size()) - 1):
  2401. if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
  2402. maybe_stride is not None
  2403. and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
  2404. ):
  2405. expanded_dims.append(i)
  2406. # Now, pad strides to alignment
  2407. out_strides = [-1] * len(out_size)
  2408. out_strides[-1] = 1
  2409. stride = 1
  2410. for i in range(len(out_size) - 2, -1, -1):
  2411. if out_strides[i + 1] != 0:
  2412. stride = stride * out_size[i + 1]
  2413. # the expanded dims still need to be aligned, if they are,
  2414. # we can make them expanded by setting the stride equal to 0
  2415. if i in expanded_dims:
  2416. if V.graph.sizevars.statically_known_equals(
  2417. out_strides[i + 1] % ALIGNMENT, 0
  2418. ):
  2419. out_strides[i] = 0
  2420. continue
  2421. if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0):
  2422. stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT
  2423. out_strides[i] = stride
  2424. return ir.ExternKernel.require_exact_strides(arg, out_strides)
  2425. if is_aligned_tensor:
  2426. return ir.try_match_insignificant_strides(
  2427. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2428. )
  2429. if (
  2430. isinstance(arg, IRNode)
  2431. and arg.maybe_get_stride() is not None
  2432. and is_aligned_tensor
  2433. ):
  2434. return ir.try_match_insignificant_strides(
  2435. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2436. )
  2437. def is_aligned(x):
  2438. return V.graph.sizevars.guard_or_false(
  2439. sympy.Eq(Mod(x.get_size()[-1], ALIGNMENT), 0)
  2440. )
  2441. if isinstance(arg.data, ir.BaseView):
  2442. if not is_aligned(arg):
  2443. if is_aligned(arg.unwrap_view()):
  2444. return ir.try_match_insignificant_strides(
  2445. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2446. )
  2447. return ir.ExternKernel.require_stride_order(arg, stride_order)
  2448. args = tuple(
  2449. apply_constraint(idx, arg, fx_arg)
  2450. for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args))
  2451. )
  2452. kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()}
  2453. return args, kwargs
  2454. # WIP
  2455. make_fallback(aten._adaptive_avg_pool3d) # @isuruf
  2456. make_fallback(aten.adaptive_max_pool3d) # @isuruf
  2457. make_fallback(aten._scaled_dot_product_attention_math_for_mps) # @malfet
  2458. # 1) Easy
  2459. make_fallback(aten.uniform, warn=False)
  2460. make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py)
  2461. make_fallback(aten._pdist_forward, require_contiguous) # Has decomp. Needs benchmarks
  2462. make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl?
  2463. make_fallback(aten._fused_rms_norm, warn=False) # (MPS-only and faster than decomp)
  2464. if torch.xpu.is_available():
  2465. make_fallback(
  2466. aten.embedding_dense_backward, warn=False
  2467. ) # (XPU-only and faster than decomp)
  2468. if torch.mtia._is_compiled():
  2469. make_fallback(
  2470. aten.native_layer_norm, warn=False
  2471. ) # (MTIA-only and faster than decomp)
  2472. # 1.5) Easy or Impossible
  2473. make_fallback(aten._cdist_forward) # p=2 should be feasible
  2474. make_fallback(aten._cdist_backward)
  2475. # 2) Medium
  2476. make_fallback(aten._trilinear)
  2477. # 3) Difficult
  2478. # Scans
  2479. # See the discussion at
  2480. # https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19
  2481. make_fallback(aten.segment_reduce.default)
  2482. make_fallback(aten._segment_reduce_backward.default)
  2483. # Histogram (need to implement Histogram IR)
  2484. make_fallback(aten.histc)
  2485. make_fallback(aten.histogram.bin_ct)
  2486. make_fallback(aten._histogramdd_bin_edges.default)
  2487. make_fallback(aten._histogramdd_from_bin_cts.default)
  2488. # Need templated kernel
  2489. make_fallback(aten.addbmm)
  2490. make_fallback(aten._addmm_activation, warn=False)
  2491. make_fallback(aten._grouped_mm, require_dense)
  2492. # Need templated kernel. Probably impossible to write efficiently
  2493. make_fallback(aten.convolution_backward, constrain_to_fx_strides)
  2494. make_fallback(aten._cudnn_rnn, require_dense)
  2495. make_fallback(aten._cudnn_rnn_backward, require_contiguous)
  2496. make_fallback(aten.miopen_rnn, require_dense)
  2497. make_fallback(aten.miopen_rnn_backward, require_contiguous)
  2498. # Haven't checked but sound difficult / impossible
  2499. make_fallback(aten._embedding_bag, require_contiguous)
  2500. make_fallback(aten._embedding_bag_forward_only, require_contiguous)
  2501. make_fallback(aten._embedding_bag_backward)
  2502. make_fallback(aten._embedding_bag_per_sample_weights_backward)
  2503. make_fallback(aten._embedding_bag_per_sample_weights_backward)
  2504. make_fallback(aten._fused_moving_avg_obs_fq_helper)
  2505. make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
  2506. # 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
  2507. make_fallback(aten.max_pool3d_with_indices_backward)
  2508. make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
  2509. make_fallback(aten._adaptive_avg_pool3d_backward)
  2510. make_fallback(aten.adaptive_max_pool2d_backward)
  2511. make_fallback(aten.adaptive_max_pool3d_backward)
  2512. make_fallback(aten.fractional_max_pool2d_backward)
  2513. make_fallback(aten.fractional_max_pool3d_backward)
  2514. make_fallback(aten.replication_pad1d_backward)
  2515. make_fallback(aten.replication_pad2d_backward)
  2516. make_fallback(aten.upsample_linear1d_backward)
  2517. make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
  2518. make_fallback(aten.upsample_trilinear3d_backward)
  2519. make_fallback(aten.grid_sampler_2d_backward)
  2520. make_fallback(aten._pdist_backward, require_contiguous)
  2521. # 5) Impossible (missing triton/CPU features)
  2522. # Sorting / Sorting-like
  2523. make_fallback(aten.sort)
  2524. make_fallback(aten.sort.stable)
  2525. make_fallback(aten.kthvalue)
  2526. make_fallback(aten.topk)
  2527. make_fallback(aten.mode)
  2528. make_fallback(aten.median)
  2529. make_fallback(aten.nanmedian)
  2530. make_fallback(aten.randperm)
  2531. # see: https://github.com/pytorch/pytorch/pull/121354
  2532. make_fallback(aten.resize_)
  2533. make_fallback(aten.resize_as_)
  2534. # Linalg
  2535. make_fallback(aten._linalg_det)
  2536. make_fallback(aten.linalg_householder_product)
  2537. make_fallback(aten.linalg_inv_ex)
  2538. make_fallback(aten.linalg_ldl_factor_ex)
  2539. make_fallback(aten.linalg_ldl_solve)
  2540. make_fallback(aten.linalg_lu)
  2541. make_fallback(aten.linalg_lu_factor_ex)
  2542. make_fallback(aten.linalg_lu_solve)
  2543. make_fallback(aten.linalg_matrix_exp)
  2544. make_fallback(aten.linalg_qr)
  2545. make_fallback(aten._linalg_slogdet)
  2546. make_fallback(aten._linalg_solve_ex)
  2547. make_fallback(aten.linalg_solve_triangular)
  2548. make_fallback(aten._linalg_svd)
  2549. make_fallback(aten.lu_unpack)
  2550. make_fallback(aten.ormqr)
  2551. make_fallback(aten._linalg_check_errors)
  2552. make_fallback(aten.linalg_pinv.atol_rtol_tensor)
  2553. make_fallback(aten._linalg_eigh)
  2554. make_fallback(aten.triangular_solve)
  2555. make_fallback(aten.linalg_cholesky_ex)
  2556. make_fallback(aten.cholesky_inverse)
  2557. make_fallback(aten.cholesky_solve)
  2558. make_fallback(aten.geqrf)
  2559. make_fallback(aten._fft_r2c) # needs complex as well
  2560. # Data dependent (are these necessary?)
  2561. make_fallback(aten.nonzero.default)
  2562. # Misc
  2563. make_fallback(aten.gcd.default, warn=False)
  2564. make_fallback(aten._thnn_fused_lstm_cell, require_dense)
  2565. make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
  2566. make_fallback(torch._prims.rng_prims.run_with_rng_state)
  2567. make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)
  2568. # Implemented / Half implemented
  2569. # Scans. Implemented for CUDA, missing CPU
  2570. make_fallback(aten.masked_scatter)
  2571. make_fallback(aten.masked_scatter_backward)
  2572. # Complex number support
  2573. make_fallback(aten.view_as_complex, require_contiguous)
  2574. make_fallback(aten.angle) # needs complex
  2575. # Needs efficentzerotensor
  2576. make_fallback(aten._efficientzerotensor)
  2577. # Needs Sparse
  2578. make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
  2579. make_fallback(aten.to_sparse)
  2580. make_fallback(aten._to_sparse)
  2581. # Needs dimname support
  2582. make_fallback(aten.zeros.names)
  2583. # 6) Pattern-matched
  2584. make_fallback(
  2585. aten._scaled_dot_product_efficient_attention.default,
  2586. sdpa_constraint,
  2587. warn=False,
  2588. )
  2589. make_fallback(
  2590. aten._scaled_dot_product_efficient_attention_backward.default,
  2591. sdpa_constraint,
  2592. warn=False,
  2593. )
  2594. make_fallback(
  2595. aten._scaled_dot_product_flash_attention.default,
  2596. sdpa_constraint,
  2597. warn=False,
  2598. )
  2599. make_fallback(
  2600. aten._scaled_dot_product_flash_attention.quantized,
  2601. warn=False,
  2602. )
  2603. make_fallback(
  2604. aten._scaled_dot_product_flash_attention_backward.default,
  2605. sdpa_constraint,
  2606. warn=False,
  2607. )
  2608. make_fallback(
  2609. aten._scaled_dot_product_cudnn_attention.default,
  2610. sdpa_constraint,
  2611. warn=False,
  2612. )
  2613. make_fallback(
  2614. aten._scaled_dot_product_cudnn_attention_backward.default,
  2615. sdpa_constraint,
  2616. warn=False,
  2617. )
  2618. make_fallback(
  2619. aten._scaled_dot_product_flash_attention_for_cpu.default,
  2620. sdpa_constraint,
  2621. warn=False,
  2622. )
  2623. make_fallback(
  2624. aten._scaled_dot_product_flash_attention_for_cpu_backward.default,
  2625. sdpa_constraint,
  2626. warn=False,
  2627. )
  2628. make_fallback(
  2629. aten._scaled_dot_product_fused_attention_overrideable.default,
  2630. sdpa_constraint,
  2631. warn=False,
  2632. )
  2633. make_fallback(
  2634. aten._scaled_dot_product_fused_attention_overrideable_backward.default,
  2635. sdpa_constraint,
  2636. warn=False,
  2637. )
  2638. make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
  2639. make_fallback(aten._flash_attention_forward.quantized)
  2640. make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
  2641. make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
  2642. make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
  2643. # index_reduce requires fallback when use_scatter_fallback(...) returns True
  2644. make_fallback(aten.index_reduce)
  2645. make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
  2646. make_fallback(aten._weight_norm_interface_backward.default, require_contiguous)
  2647. # Register with type_promotion_kind None.
  2648. # For example, fp16.copy_(fp32) should **not** promote the first input's dtype.
  2649. @register_lowering(aten.copy, type_promotion_kind=None)
  2650. def copy(self, src, non_blocking=False):
  2651. if not isinstance(src, ir.IRNode):
  2652. src = tensor(src, dtype=self.get_dtype(), device=self.get_device())
  2653. x = src
  2654. if self.get_device() != src.get_device():
  2655. x = to_device(x, self.get_device())
  2656. if self.get_dtype() != src.get_dtype():
  2657. x = to_dtype(x, self.get_dtype())
  2658. if self.get_size() != src.get_size():
  2659. out = expand(x, self.get_size())
  2660. return clone(out)
  2661. return clone(x)
  2662. @register_lowering(aten.clone)
  2663. def clone(x, *, memory_format=None):
  2664. # TODO(jansel): memory format
  2665. return Pointwise.create(
  2666. device=x.get_device(),
  2667. dtype=x.get_dtype(),
  2668. inner_fn=x.make_loader(),
  2669. ranges=list(x.get_size()),
  2670. )
  2671. def clone_preserve_reinterpret_view(x):
  2672. reinterpret_view_layouts = []
  2673. if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView):
  2674. x = x.data # unwrap TensorBox
  2675. # pyrefly: ignore [bad-assignment]
  2676. while isinstance(x, ir.ReinterpretView):
  2677. reinterpret_view_layouts.append(x.get_layout())
  2678. x = x.data
  2679. x = TensorBox(x)
  2680. x = clone(x)
  2681. if reinterpret_view_layouts:
  2682. x = x.data # unwrap TensorBox
  2683. for layout in reinterpret_view_layouts[::-1]:
  2684. x = ir.ReinterpretView(data=x, layout=layout)
  2685. x = TensorBox(x)
  2686. return x
  2687. if hasattr(aten, "lift_fresh_copy"):
  2688. register_lowering(aten.lift_fresh_copy)(clone)
  2689. @register_lowering(prims.iota)
  2690. def iota(
  2691. length,
  2692. *,
  2693. start,
  2694. step,
  2695. dtype,
  2696. device,
  2697. requires_grad,
  2698. ):
  2699. def fn(index):
  2700. return ops.index_expr(step * index[0] + start, dtype=dtype)
  2701. return Pointwise.create(
  2702. device=decode_device(device),
  2703. dtype=dtype,
  2704. inner_fn=fn,
  2705. ranges=[length],
  2706. )
  2707. @register_lowering(aten.select_scatter, type_promotion_kind=None)
  2708. def select_scatter(x, src, dim: int, index: int):
  2709. src = to_dtype(src, x.get_dtype())
  2710. x_loader = x.make_loader()
  2711. dim = _validate_dim(x, dim, 0)
  2712. if V.graph.sizevars.guard_or_false(sympy.Lt(index, 0)):
  2713. index = index + x.get_size()[dim]
  2714. elif V.graph.sizevars.guard_or_false(sympy.Ge(index, 0)):
  2715. pass
  2716. else:
  2717. # unbacked index
  2718. return fallback_handler(aten.select_scatter.default)(x, src, dim, index)
  2719. V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type]
  2720. V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
  2721. src = expand(unsqueeze(src, dim), x.get_size())
  2722. src_loader = src.make_loader()
  2723. def inner_fn(idx):
  2724. return ops.where(
  2725. ops.eq(
  2726. ops.index_expr(idx[dim], torch.int32),
  2727. ops.index_expr(index, torch.int32),
  2728. ),
  2729. src_loader(idx),
  2730. x_loader(idx),
  2731. )
  2732. return Pointwise.create(
  2733. device=x.get_device(),
  2734. dtype=x.get_dtype(),
  2735. inner_fn=inner_fn,
  2736. ranges=list(x.get_size()),
  2737. )
  2738. @register_lowering(aten.slice_scatter, type_promotion_kind=None)
  2739. def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
  2740. src = to_dtype(src, x.get_dtype())
  2741. x_loader = x.make_loader()
  2742. dim = _validate_dim(x, dim, 0)
  2743. dim_size = x.get_size()[dim]
  2744. # pyrefly: ignore [bad-argument-type]
  2745. start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
  2746. src_size = list(x.get_size())
  2747. src_size[dim] = FloorDiv(end - start + (step - 1), step)
  2748. src = expand(src, src_size)
  2749. src_loader = src.make_loader()
  2750. def inner_fn(idx):
  2751. if start == 0 and end == dim_size and step == 1:
  2752. # selecting every element is the same as just src.clone()
  2753. return src_loader(idx)
  2754. idx_dim = ops.index_expr(idx[dim], torch.int64)
  2755. src_idx = list(idx)
  2756. src_idx[dim] = FloorDiv(idx[dim] - start, step)
  2757. mask = []
  2758. if start != 0:
  2759. mask.append(
  2760. ops.ge(
  2761. idx_dim,
  2762. ops.index_expr(sympy.expand(start), torch.int64),
  2763. )
  2764. )
  2765. if end != dim_size:
  2766. mask.append(
  2767. ops.lt(
  2768. idx_dim,
  2769. ops.index_expr(sympy.expand(end), torch.int64),
  2770. )
  2771. )
  2772. if step != 1:
  2773. mask.append(
  2774. ops.eq(
  2775. ops.index_expr(
  2776. ModularIndexing(idx[dim] - start, 1, step), torch.int64
  2777. ),
  2778. ops.constant(0, torch.int64),
  2779. )
  2780. )
  2781. assert mask
  2782. mask = functools.reduce(ops.and_, mask)
  2783. src_val = ops.masked(
  2784. mask,
  2785. lambda: src_loader(src_idx),
  2786. 0 if is_integer_type(x) else 0.0,
  2787. )
  2788. return ops.where(
  2789. mask,
  2790. src_val,
  2791. x_loader(idx),
  2792. )
  2793. return Pointwise.create(
  2794. device=x.get_device(),
  2795. dtype=x.get_dtype(),
  2796. inner_fn=inner_fn,
  2797. ranges=list(x.get_size()),
  2798. )
  2799. def _unwrap(x):
  2800. if isinstance(x, (list, tuple)) and len(x) > 0:
  2801. return _unwrap(x[0])
  2802. return x
  2803. @register_lowering([torch.tensor, aten.scalar_tensor])
  2804. def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
  2805. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2806. assert_nyi(not pin_memory, "pin_memory")
  2807. if isinstance(_unwrap(data), int):
  2808. dtype = dtype or torch.int64
  2809. else:
  2810. dtype = dtype or torch.get_default_dtype()
  2811. ranges: list[sympy.Expr] = []
  2812. if isinstance(data, sympy.Basic):
  2813. def inner_fn(index):
  2814. return ops.index_expr(data, dtype)
  2815. elif isinstance(data, (float, int)):
  2816. def inner_fn(index):
  2817. return ops.constant(data, dtype)
  2818. elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
  2819. # inline small tensors
  2820. ranges.append(sympy.Integer(len(data)))
  2821. def inner_fn(index):
  2822. def binary_search(start, end):
  2823. assert start < end
  2824. if end - start == 1:
  2825. return ops.constant(data[start], dtype)
  2826. mid = (end - start) // 2 + start
  2827. return ops.where(
  2828. ops.lt(
  2829. ops.index_expr(index[0], torch.int64),
  2830. ops.constant(mid, torch.int64),
  2831. ),
  2832. binary_search(start, mid),
  2833. binary_search(mid, end),
  2834. )
  2835. if len(data) == 0:
  2836. return ops.constant(0, dtype)
  2837. return binary_search(0, len(data))
  2838. else:
  2839. return V.graph.add_tensor_constant(
  2840. torch.tensor(data, dtype=dtype, device=device)
  2841. )
  2842. return Pointwise.create(
  2843. device=decode_device(device),
  2844. dtype=dtype,
  2845. inner_fn=inner_fn,
  2846. ranges=ranges,
  2847. )
  2848. @register_lowering(torch.as_tensor)
  2849. def as_tensor(data, dtype=None, device=None):
  2850. if isinstance(data, TensorBox):
  2851. if dtype is not None:
  2852. data = to_dtype(data, dtype)
  2853. if device is not None:
  2854. data = to_device(data, device)
  2855. return data
  2856. return tensor(data, dtype=dtype, device=device)
  2857. @register_lowering(torch.LongTensor)
  2858. def long_tensor(data):
  2859. return tensor(data, dtype=torch.int64)
  2860. @register_lowering(aten._local_scalar_dense)
  2861. def _local_scalar_dense(data):
  2862. # This is interesting! Most lowerings return tensors, so you can just
  2863. # return the buffer you allocated and it will get used (or not used, if
  2864. # it's dead.) But _local_scalar_dense (aka item) returns an int,
  2865. # not a Tensor, so you would have a type mismatch if you return a buffer;
  2866. # we are obligated to return a sympy expression instead. However,
  2867. # we need to actually codegen the .item() call somehow. We do this
  2868. # by registering a faux buffer for the DynamicScalar IR node, which is
  2869. # solely responsible for generating this .item(). The buffer is
  2870. # not used for anything (notice we discard it); at codegen time,
  2871. # the "buffer" just gets assigned None.
  2872. unbacked_bindings = resolve_unbacked_bindings(
  2873. V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
  2874. )
  2875. assert unbacked_bindings is not None
  2876. assert len(unbacked_bindings) == 1, unbacked_bindings
  2877. # NB: Have to be very careful here. V.graph.current_node.meta["val"]
  2878. # seemingly also contains a symbol which you want to do binding for,
  2879. # but it actually isn't. In particular, if we have later performed
  2880. # a deferred runtime assert saying that u0 == s0, you will actually
  2881. # see s0 from expr! This is bad because we need to actually generate
  2882. # the assert that says u0 == s0, so we need to know where to get u0
  2883. # from (this call). In particular, we must use unbacked_bindings, which
  2884. # is guaranteed to have the original, unreplaced symbol in question.
  2885. #
  2886. # NB2: Another thing we have to be very careful about are symbol bindings
  2887. # that require nontrivial refinement, e.g., when you have a binding site
  2888. # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division
  2889. # in order to appropriately bind u0. This is communicated via the keypath
  2890. # in unbacked_bindings, and we need to hold onto it in order to generate
  2891. # code appropriately for this case.
  2892. binding_sym, keypath = next(iter(unbacked_bindings.items()))
  2893. buffer = ir.DynamicScalar(binding_sym, keypath, data)
  2894. buffer.name = V.graph.register_buffer(buffer)
  2895. V.graph.register_operation(buffer)
  2896. # NB: the replaced expr is OK to use directly downstream, we want
  2897. # simplifications in this case!
  2898. val = V.graph.current_node.meta["val"]
  2899. if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
  2900. return val.node.expr
  2901. else:
  2902. return sympy.sympify(val)
  2903. @register_lowering(aten._assert_scalar)
  2904. def _assert_scalar(data, msg):
  2905. # NB: These will be handled at codegen time
  2906. # Not sure if we are guaranteed to be able to serve out truth from the
  2907. # deferred_runtime_asserts, TODO: try this assert out
  2908. # See [NOTE] Codegen runtime asserts in Inductor
  2909. # assert bool(data.scalar), data
  2910. return None
  2911. @register_lowering(aten._assert_tensor_metadata)
  2912. def _assert_tensor_metadata(
  2913. a, size=None, stride=None, dtype=None, *, device=None, layout=None
  2914. ):
  2915. return None
  2916. def _full(fill_value, device, dtype, size):
  2917. value = fill_value
  2918. if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
  2919. value = value.value
  2920. if isinstance(value, (int, float)):
  2921. def inner_fn(index):
  2922. return ops.constant(value, dtype)
  2923. elif isinstance(value, sympy.Basic):
  2924. def inner_fn(index):
  2925. return ops.index_expr(value, dtype)
  2926. else:
  2927. assert len(value.get_size()) == 0
  2928. value_loader = value.make_loader()
  2929. def inner_fn(index):
  2930. return value_loader([])
  2931. return Pointwise.create(
  2932. device=device,
  2933. dtype=dtype,
  2934. inner_fn=inner_fn,
  2935. ranges=list(size),
  2936. )
  2937. def full_like(x, fill_value, **kwargs):
  2938. return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
  2939. def tensor_constructor(fill_value):
  2940. # torch.zeros, torch.ones, etc
  2941. def inner(
  2942. *size,
  2943. names=None,
  2944. dtype=None,
  2945. device=None,
  2946. layout=None,
  2947. pin_memory=False,
  2948. memory_format=None,
  2949. ):
  2950. assert_nyi(names is None, "named tensors")
  2951. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2952. assert_nyi(not pin_memory, "pin_memory")
  2953. device = decode_device(device)
  2954. dtype = dtype or torch.get_default_dtype()
  2955. if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
  2956. size = tuple(size[0])
  2957. # See https://github.com/pytorch/pytorch/issues/118102
  2958. # All sizes at lowering time should be sympy.Symbol, not SymInt!
  2959. for s in size:
  2960. assert not isinstance(s, torch.SymInt)
  2961. size = [sympy.expand(s) for s in size]
  2962. return _full(fill_value, device, dtype, size)
  2963. return inner
  2964. @register_lowering([torch.empty, aten.empty])
  2965. def empty(
  2966. *size,
  2967. names=None,
  2968. dtype=None,
  2969. layout=None,
  2970. device=None,
  2971. pin_memory=None,
  2972. memory_format=None,
  2973. ):
  2974. assert_nyi(names is None, "named tensors")
  2975. device = decode_device(device)
  2976. if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
  2977. size = tuple(size[0])
  2978. return empty_strided(
  2979. size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  2980. )
  2981. def create_tensor_like(creation_fn):
  2982. """
  2983. Shim to convert X_like(...) into X(...). For example zeros_like() into zeros().
  2984. """
  2985. def _constant_like(
  2986. x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None
  2987. ):
  2988. assert_nyi(not pin_memory, "pin_memory")
  2989. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2990. if dtype is None:
  2991. dtype = x.get_dtype()
  2992. else:
  2993. dtype = decode_dtype(dtype)
  2994. device = device or x.get_device()
  2995. size = list(x.get_size())
  2996. return creation_fn(
  2997. size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
  2998. )
  2999. return _constant_like
  3000. def constant_like(fill_value):
  3001. return create_tensor_like(tensor_constructor(fill_value))
  3002. empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
  3003. ones_like = create_tensor_like(tensor_constructor(1))
  3004. zeros_like = create_tensor_like(tensor_constructor(0))
  3005. def new_constant(fill_value):
  3006. def _new_constant(
  3007. x, size, *, dtype=None, layout=None, device=None, pin_memory=None
  3008. ):
  3009. assert isinstance(size, (list, tuple))
  3010. assert_nyi(not pin_memory, "pin_memory")
  3011. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  3012. # pyrefly: ignore [bad-argument-type]
  3013. dtype = decode_dtype(dtype) or x.get_dtype()
  3014. device = device or x.get_device()
  3015. size = [sympy.Integer(s) for s in size]
  3016. return _full(fill_value, decode_device(device), dtype, size)
  3017. return _new_constant
  3018. @register_lowering(aten.new_empty)
  3019. def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
  3020. if dtype is None:
  3021. dtype = x.get_dtype()
  3022. if device is None:
  3023. device = x.get_device()
  3024. return empty_strided(
  3025. size,
  3026. None,
  3027. dtype=dtype,
  3028. layout=layout,
  3029. device=decode_device(device),
  3030. pin_memory=pin_memory,
  3031. )
  3032. @register_lowering(aten.empty_strided)
  3033. def empty_strided(
  3034. size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
  3035. ):
  3036. assert isinstance(size, (list, tuple))
  3037. assert isinstance(stride, (list, tuple, type(None)))
  3038. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  3039. # pyrefly: ignore [bad-argument-type]
  3040. dtype = decode_dtype(dtype) or torch.get_default_dtype()
  3041. device = device or torch.tensor(0.0).device
  3042. device = decode_device(device)
  3043. pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
  3044. pointwise.realize()
  3045. buffer = pointwise.data.data
  3046. # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
  3047. buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size))
  3048. assert isinstance(buffer, ir.ComputedBuffer)
  3049. size = [sympy.expand(s) for s in size]
  3050. stride = (
  3051. [sympy.expand(s) for s in stride]
  3052. if stride
  3053. else ir.FlexibleLayout.contiguous_strides(size)
  3054. )
  3055. buffer.layout = ir.FixedLayout(
  3056. device=device,
  3057. dtype=dtype,
  3058. size=size,
  3059. stride=stride,
  3060. is_pinned=pin_memory or False,
  3061. )
  3062. return pointwise
  3063. @register_lowering(aten.new_empty_strided)
  3064. def new_empty_strided(
  3065. x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
  3066. ):
  3067. if dtype is None:
  3068. dtype = x.get_dtype()
  3069. if device is None:
  3070. device = x.get_device()
  3071. return empty_strided(
  3072. size,
  3073. stride,
  3074. dtype=dtype,
  3075. layout=layout,
  3076. device=decode_device(device),
  3077. pin_memory=pin_memory,
  3078. )
  3079. @register_lowering(prims.copy_strided.default)
  3080. def copy_strided(x, stride):
  3081. stride = [V.graph.sizevars.size_hint_or_throw(s) for s in stride]
  3082. stride_order = sorted(range(len(stride)), key=stride.__getitem__)
  3083. return ir.ExternKernel.require_stride_order(x, stride_order)
  3084. @register_lowering([torch.full, aten.full])
  3085. def full(size, fill_value, **kwargs):
  3086. assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition"
  3087. return tensor_constructor(fill_value)(size, **kwargs)
  3088. @register_lowering(aten.gather, type_promotion_kind=None)
  3089. def gather(x, dim, index, sparse_grad=False):
  3090. # sparse_grad doesn't affect forward computation,
  3091. # and backward tracing is taken care of by AOT Autograd
  3092. assert isinstance(x, TensorBox)
  3093. if index.get_numel() == 0:
  3094. # Empty index case. Return an empty array with the same shape
  3095. return new_empty(x, index.get_size())
  3096. size = x.get_size()
  3097. offset = len(size) == 0
  3098. dim = _validate_dim(x, dim, offset)
  3099. if offset:
  3100. x = expand(x, [1])
  3101. size = [1]
  3102. x_loader = x.make_loader()
  3103. index_loader = index.make_loader()
  3104. def fn(idx):
  3105. idx = list(idx)
  3106. gather_idx = ops.indirect_indexing(index_loader(idx), size[dim])
  3107. if len(idx) == 0:
  3108. idx = [gather_idx]
  3109. else:
  3110. idx[dim] = gather_idx
  3111. return x_loader(idx)
  3112. return Pointwise.create(
  3113. device=x.get_device(),
  3114. dtype=x.get_dtype(),
  3115. inner_fn=fn,
  3116. ranges=index.get_size(),
  3117. )
  3118. @register_lowering(aten.embedding, type_promotion_kind=None)
  3119. def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
  3120. if sparse:
  3121. return fallback_handler(aten.embedding.default)(
  3122. weight, indices, padding_idx, scale_grad_by_freq, sparse
  3123. )
  3124. assert not sparse
  3125. assert isinstance(weight, TensorBox)
  3126. assert isinstance(indices, TensorBox)
  3127. assert "int" in str(indices.get_dtype())
  3128. weight_loader = weight.make_loader()
  3129. indices_loader = indices.make_loader()
  3130. indices_ndim = len(indices.get_size())
  3131. weight_size = weight.get_size()
  3132. new_size = [*indices.get_size(), *weight_size[1:]]
  3133. def fn(idx):
  3134. assert len(idx) == len(new_size), f"{idx} != {new_size}"
  3135. var_index = indices_loader(idx[:indices_ndim])
  3136. weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [
  3137. *idx[indices_ndim:]
  3138. ]
  3139. return weight_loader(weight_idx)
  3140. return Pointwise.create(
  3141. device=weight.get_device(),
  3142. dtype=weight.get_dtype(),
  3143. inner_fn=fn,
  3144. ranges=new_size,
  3145. )
  3146. def check_and_broadcast_indices(indices, device):
  3147. assert all(
  3148. i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
  3149. for i in indices
  3150. if i is not None
  3151. ), (
  3152. f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
  3153. )
  3154. if any(
  3155. i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
  3156. ):
  3157. raise NotImplementedError("Fallback for bool indices")
  3158. valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
  3159. assert len(valid_idxs) > 0, "requires at least 1 non-None index"
  3160. new_indices = [None] * len(indices)
  3161. for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
  3162. # Eager allows indices to be CPU tensor when running on CUDA
  3163. # FIXME: Calling to_device(x, device) should work but
  3164. # test_advancedindex_mixed_cpu_devices still fails
  3165. if x.get_device() != device:
  3166. raise NotImplementedError("Fallback when indices is on a different device")
  3167. new_indices[i] = x
  3168. return new_indices, valid_idxs
def index_output_size_and_inner_fn(
    x_size,
    indices,
    tensor_indices,
    tensor_size,
    indices_loaders,
    indexed_size,
    x_loader,
    check,
    wrap_neg=True,
):
    """Compute the output size of an advanced-indexing op and an index mapper.

    Args:
        x_size: size of the tensor being indexed.
        indices: per-dimension entries, ``None`` for a full slice or an index
            tensor (callers pass the already-broadcast list returned by
            ``check_and_broadcast_indices``).
        tensor_indices: positions in ``indices`` that hold tensors, in
            increasing order.
        tensor_size: common (broadcast) size of the index tensors.
        indices_loaders: per-position loaders for the index tensors, ``None``
            where ``indices`` has ``None``.
        indexed_size: size of each dimension of ``x`` covered by ``indices``;
            used as the bound passed to ``ops.indirect_indexing``.
        x_loader: if ``None``, ``fn`` returns the computed input index list;
            otherwise ``fn`` returns ``x_loader(index_list)``.
        check: forwarded to ``ops.indirect_indexing`` (presumably toggles
            bounds checking).
        wrap_neg: forwarded to ``ops.indirect_indexing`` (presumably toggles
            wrapping of negative indices).

    Returns:
        ``(output_size, fn)`` where ``fn`` maps an output index to the
        corresponding index into ``x`` (or to the loaded value, see
        ``x_loader``).
    """
    # Note that behavior of indexing differs when there are non consecutive
    # tensors. In this case, the broadcast tensor-index dims are moved to the
    # beginning of the output.
    #
    # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
    #         x = torch.tensor([1, 2])
    # Then a[:, x, :, x, :] has shape (2, 3, 5, 7): because the tensor indices
    # are separated by a slice (x, :, x), the broadcast index dim of size 2 is
    # pulled to the front.
    non_consecutive_tensors = False
    for previous, current in itertools.pairwise(tensor_indices):
        if current - previous != 1:
            non_consecutive_tensors = True

    # Sizes of the sliced (None) dims, followed by the dims of x beyond
    # len(indices) (the slice start equals len(indices) when every non-None
    # entry is listed in tensor_indices).
    output_size = [x_size[i] for i, val in enumerate(indices) if val is None]
    output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]]

    # Insert the broadcast index dims: at the front for non-consecutive tensor
    # indices, otherwise at the position of the first tensor index.
    first_tensor_index = tensor_indices[0]
    if non_consecutive_tensors:
        output_size = tensor_size + output_size
    else:
        output_size = (
            output_size[:first_tensor_index]
            + tensor_size
            + output_size[first_tensor_index:]
        )

    def fn(idx):
        assert len(idx) == len(output_size)
        assert len(indices_loaders) == len(indexed_size)

        # Number of output dims occupied by the broadcast index tensors.
        rank = len(tensor_size)
        new_index = []
        first_tensor_index = tensor_indices[0]
        # Output position where the `rank` broadcast index dims begin.
        start_offset = 0 if non_consecutive_tensors else first_tensor_index
        # Cursor into `idx` for the sliced (None) dims.
        next_idx = 0
        for i in range(tensor_indices[-1] + 1):
            # Skip over the broadcast index dims once we reach them.
            if i == start_offset:
                next_idx += rank
            if indices[i] is None:
                assert next_idx < len(idx)
                new_index.append(idx[next_idx])
                next_idx += 1
            else:
                loader = indices_loaders[i]
                assert loader is not None
                size = indexed_size[i]
                # Read the index tensor at the broadcast coordinates and turn
                # its value into an indirect index into dim i of x.
                new_index.append(
                    ops.indirect_indexing(
                        loader(idx[start_offset : start_offset + rank]),
                        size,
                        check=check,
                        wrap_neg=wrap_neg,
                    )
                )
        # Remaining output dims map one-to-one onto the trailing dims of x.
        new_index = [
            *new_index,
            *idx[next_idx:],
        ]
        return new_index if x_loader is None else x_loader(new_index)

    return output_size, fn
  3235. def index_impl(x, indices, check):
  3236. output_size, inner_fn, _ = index_impl_helper(x, indices, check)
  3237. return Pointwise.create(
  3238. device=x.get_device(),
  3239. dtype=x.get_dtype(),
  3240. inner_fn=inner_fn,
  3241. ranges=output_size,
  3242. )
  3243. def index_impl_helper(x, indices, check, wrap_neg=True):
  3244. assert isinstance(indices, (list, tuple))
  3245. x_loader = x.make_loader()
  3246. indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device())
  3247. assert len(tensor_indices) > 0, "Must have at least one valid idx"
  3248. indices_loaders = [i.make_loader() if i is not None else None for i in indices]
  3249. # no guards on output size, all the guards are set in broadcast_tensors
  3250. # We can use the first one since they are all required to be the same size
  3251. tensor_size = list(indices[tensor_indices[0]].get_size())
  3252. x_size = x.get_size()
  3253. indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None]
  3254. if check and 0 in indexed_size and 0 not in tensor_size:
  3255. raise IndexError("index is out of bounds for dimension with size 0")
  3256. indexed_size = [x_size[i] for i in range(len(indices))]
  3257. output_size, index_inner_fn = index_output_size_and_inner_fn(
  3258. x_size,
  3259. indices,
  3260. tensor_indices,
  3261. tensor_size,
  3262. indices_loaders,
  3263. indexed_size,
  3264. None,
  3265. check=check,
  3266. wrap_neg=wrap_neg,
  3267. )
  3268. def inner_fn(idx):
  3269. return x_loader(index_inner_fn(idx))
  3270. return output_size, inner_fn, index_inner_fn
  3271. @register_lowering(aten.index, type_promotion_kind=None)
  3272. def index(x, indices):
  3273. try:
  3274. return index_impl(x, indices, check=True)
  3275. except NotImplementedError:
  3276. # Fallback to ATen for boolean indexing
  3277. x.realize()
  3278. return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)(
  3279. x, indices
  3280. )
  3281. @register_lowering(aten._unsafe_index, type_promotion_kind=None)
  3282. def _unsafe_index(x, indices):
  3283. return index_impl(x, indices, check=False)
  3284. # All the indexing decompositions are written in terms of index, index_put, and index_put_
  3285. # We cannot have this lowering as a decomposition as it introduces
  3286. # mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead
  3287. # code elimination and common subexpression elimination optimizations, which
  3288. # assume graphs to be side-effect free. More details at
  3289. # https://github.com/pytorch/torchdynamo/issues/1235
  3290. # and
  3291. # https://github.com/pytorch/torchdynamo/issues/1863
  3292. @register_lowering(aten.index_put, type_promotion_kind=None)
  3293. def index_put(x, indices, values, accumulate=False):
  3294. return index_put_impl_(
  3295. clone(x), indices, values, accumulate, check=True, may_realize=False
  3296. )
  3297. @register_lowering(aten._unsafe_index_put)
  3298. def _unsafe_index_put(x, indices, values, accumulate=False):
  3299. return index_put_impl_(
  3300. clone(x), indices, values, accumulate, check=False, may_realize=False
  3301. )
  3302. def index_put_as_masked_fill(self, indices, value, accumulate):
  3303. if value.get_device() != self.get_device():
  3304. value = to_device(value, self.get_device())
  3305. if accumulate:
  3306. value = add(self, value)
  3307. return mutate_to(self, where(indices[0], value, self))
  3308. def index_put_fallback(self, indices, values, accumulate):
  3309. from .utils import _fx_node_is_input_dependent_cudagraph_unsafe
  3310. op_overload = getattr(aten.index_put_, V.graph.current_node.target._overloadname) # type: ignore[union-attr]
  3311. # Check if any index is a boolean tensor - if so, mark as cudagraph-unsafe
  3312. # because boolean indices trigger .nonzero() during CUDA graph capture
  3313. # When graph_partition is enabled, skip - partitioning handles this
  3314. fx_node = V.graph.current_node
  3315. if (
  3316. not config.graph_partition
  3317. and fx_node is not None
  3318. and _fx_node_is_input_dependent_cudagraph_unsafe(fx_node)
  3319. ):
  3320. msg = "index_put_ fallback with boolean indexing is not compatible with CUDA graphs"
  3321. if stack_trace := fx_node.meta.get("stack_trace", None):
  3322. msg = f"{msg} Found from : \n {stack_trace}"
  3323. V.graph.disable_cudagraphs_reason = msg
  3324. ir.IndexPutFallback(op_overload, self, indices, values, accumulate)
  3325. return self
  3326. @register_lowering(aten.index_put_, type_promotion_kind=None)
  3327. def index_put_(self, indices, values, accumulate=False):
  3328. return index_put_impl_(
  3329. self, indices, values, accumulate, check=True, may_realize=True
  3330. )
  3331. @register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None)
  3332. def _unsafe_index_put_(self, indices, values, accumulate=False):
  3333. return index_put_impl_(
  3334. self, indices, values, accumulate, check=False, may_realize=True
  3335. )
def index_put_impl_(self, indices, values, accumulate, check, may_realize=False):
    """Scatter ``values`` into ``self`` at ``indices`` in place and return ``self``.

    Dispatch order (first match wins):
      1. One bool/uint8 index and a single-element ``values``: masked fill via
         ``index_put_as_masked_fill``.
      2. Deterministic-algorithms mode, any bool/uint8 index, or
         ``accumulate`` with a dtype flagged by
         ``needs_fallback_due_to_atomic_add_limitations``: ATen fallback via
         ``index_put_fallback``.
      3. A ``NotImplementedError`` from ``check_and_broadcast_indices``: ATen
         fallback.
      4. Otherwise an ``ir.Scatter`` (``atomic_add`` when ``accumulate``) is
         registered as a mutation of ``self``.

    Args:
        self: destination tensor, mutated in place.
        indices: per-dimension ``None`` or index tensor.
        values: source values; on the native path they are cast to ``self``'s
            dtype and expanded to the indexing output size.
        accumulate: add into ``self`` instead of overwriting.
        check: forwarded to ``index_output_size_and_inner_fn`` (and from there
            to ``ops.indirect_indexing``).
        may_realize: if true, ``values`` is realized into its own buffer when
            it reads from ``self``, unless every index is recognised as coming
            from ``aten.randperm``.
    """
    if may_realize:

        def indice_slice_from_randperm(indice):
            # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660
            # For this specific pattern, indices is unique as coming from torch.randperm.
            # However, as the content of the indices is unknown, we have to check this specific pattern.
            if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView):
                indice = indice.data.unwrap_view()
                return (
                    isinstance(indice, ir.StorageBox)
                    and isinstance(indice.data, ir.ExternKernel)
                    and getattr(indice.data, "fx_node", None)
                    and indice.data.fx_node.target is torch.ops.aten.randperm.default
                )
            return False

        if ir.try_get_name(self) in values.get_read_names() and not all(
            indice_slice_from_randperm(indice) for indice in indices
        ):
            # Fix issue: https://github.com/pytorch/pytorch/issues/138908
            # When self and values have memory overlapping, indices may
            # contain duplicate values, potentially causing incorrect results since
            # the load of `values` might contain modified value from the store of `self`.
            # To address this, store values in a temporary buffer in such cases.
            values.realize()

    # Dispatch to masked fill for single boolean index with single value.
    # NOTE(review): indices[0].get_dtype() assumes indices[0] is not None.
    if (
        values.get_numel() == 1
        and len(indices) == 1
        and indices[0].get_dtype() in (torch.bool, torch.uint8)
    ):
        mask = indices[0]
        # Append trailing size-1 dims so the mask has self's rank.
        for _ in range(len(mask.get_size()), len(self.get_size())):
            mask = unsqueeze(mask, -1)
        return index_put_as_masked_fill(self, [mask], values, accumulate)

    # Fallback in torch deterministic mode
    if torch.are_deterministic_algorithms_enabled():
        return index_put_fallback(self, indices, values, accumulate)

    # Fallback if there is a boolean index
    for index in indices:
        if index is not None and index.get_dtype() in (torch.bool, torch.uint8):
            return index_put_fallback(self, indices, values, accumulate)

    x_size = self.get_size()
    x_ndim = len(x_size)

    if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
        # self is a 0-d (scalar) tensor: run the fallback on a 1-d view and
        # restore the 0-d view afterwards.
        if x_ndim == 0:
            self = view(self, [1])
        self = index_put_fallback(self, indices, values, accumulate)
        if x_ndim == 0:
            self = view(self, [])
        return self

    values = to_dtype(values, self.get_dtype())

    try:
        # check_and_broadcast_indices raises NotImplementedError for bool/uint8
        # indices (already routed to the fallback above) or for an index on a
        # different device, so in practice this catches the device-mismatch
        # case. (The previous comment here mentioned uint32, which that helper
        # rejects with an AssertionError instead.)
        indices, tensor_indices = check_and_broadcast_indices(
            indices, self.get_device()
        )
    except NotImplementedError:
        return index_put_fallback(self, indices, values, accumulate)

    indices_loaders = [i.make_loader() if i is not None else None for i in indices]

    assert isinstance(self, TensorBox)
    self.realize()

    # self is a 0-d (scalar) tensor: index into a 1-d view of it.
    if x_ndim == 0:
        self = view(self, [1])

    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())
    indexed_size = [x_size[i] for i in range(len(indices))]

    # expected_vals_size is the iteration space of the scatter; inner_fn maps
    # each point of it to the destination index in self.
    expected_vals_size, inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
    )

    values = expand(values, expected_vals_size)
    # all guards are set above during broadcast_tensors and expand

    device = self.get_device()
    assert device is not None
    scatter = ir.Scatter(
        device=device,
        dtype=self.get_dtype(),
        inner_fn=values.make_loader(),
        ranges=expected_vals_size,  # iter_ranges,
        output_indexer=inner_fn,
        scatter_mode="atomic_add" if accumulate else None,
    )
    # Register the scatter as a buffer whose layout mutates self in place.
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if x_ndim == 0:
        self = view(self, [])
    return self
# ATen fallback handlers for the masked-index ops, created once at import time
# with add_to_fallback_set=False. They are not referenced within this excerpt;
# presumably used elsewhere in the file (NOTE(review): confirm callers).
fallback__unsafe_masked_index = fallback_handler(
    aten._unsafe_masked_index.default, add_to_fallback_set=False
)
fallback__unsafe_masked_index_put_accumulate = fallback_handler(
    aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False
)
  3442. @register_lowering(aten._unsafe_masked_index, type_promotion_kind=None)
  3443. def _unsafe_masked_index(self, mask, indices, fill):
  3444. ranges, _, _unsafe_index_fn = index_impl_helper(
  3445. self, indices, check=False, wrap_neg=False
  3446. )
  3447. mask_loader = mask.make_loader()
  3448. self_loader = self.make_loader()
  3449. def inner_fn(idx):
  3450. if mask.dtype != torch.bool:
  3451. mask_val = ops.to_dtype(mask_loader(idx), torch.bool)
  3452. else:
  3453. mask_val = mask_loader(idx)
  3454. return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill)
  3455. return Pointwise.create(
  3456. device=self.get_device(),
  3457. dtype=self.get_dtype(),
  3458. inner_fn=inner_fn,
  3459. ranges=ranges,
  3460. )
  3461. @register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None)
  3462. def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
  3463. masked_value = where(mask, values, 0)
  3464. shape = x.get_size()
  3465. clamped_indices = [
  3466. clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None
  3467. for i in range(len(indices))
  3468. ]
  3469. # TODO: use a masked store for this. currently only triton
  3470. # supports masked stores and cpp backend does not.
  3471. return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True)
  3472. @make_pointwise
  3473. def clamp(a, min, max):
  3474. return ops.maximum(min, ops.minimum(max, a))
  3475. @register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
  3476. def as_strided_scatter(self, src, size, stride, storage_offset=None):
  3477. output = clone(self)
  3478. output_view = as_strided(output, size, stride, storage_offset)
  3479. copy_(output_view, src)
  3480. return output
  3481. @register_lowering(aten.scatter, type_promotion_kind=None)
  3482. def scatter(x, dim: int, index, src, **kwargs):
  3483. return scatter_(clone(x), dim, index, src, **kwargs)
  3484. def scatter_fallback(
  3485. op_overload: torch._ops.OpOverload,
  3486. self,
  3487. dim: int,
  3488. index,
  3489. src,
  3490. *,
  3491. reduce: Optional[str] = None,
  3492. include_self: bool = True,
  3493. ):
  3494. src_is_tensor = isinstance(src, TensorBox)
  3495. if use_scatter_fallback(
  3496. op_overload,
  3497. reduce,
  3498. self.get_dtype(),
  3499. cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src)),
  3500. # pyrefly: ignore [missing-attribute]
  3501. src.get_device().type if src_is_tensor else "not impl",
  3502. src_is_tensor,
  3503. ):
  3504. ir.ScatterFallback(
  3505. op_overload,
  3506. self,
  3507. dim,
  3508. index,
  3509. src,
  3510. reduce=reduce,
  3511. include_self=include_self,
  3512. )
  3513. return self
  3514. return None
  3515. @register_lowering(aten.scatter_, type_promotion_kind=None)
  3516. def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
  3517. assert reduce in (None, "add", "multiply")
  3518. if reduce is None:
  3519. op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr]
  3520. fallback_result = scatter_fallback(
  3521. op_overload, self, dim, index, src, reduce=reduce
  3522. )
  3523. if fallback_result is not None:
  3524. return fallback_result
  3525. if reduce == "add":
  3526. reduce = "sum"
  3527. elif reduce == "multiply":
  3528. reduce = "prod"
  3529. return scatter_reduce_(self, dim, index, src, reduce)
  3530. @register_lowering(aten.scatter_add, type_promotion_kind=None)
  3531. def scatter_add(x, dim: int, index, src):
  3532. return scatter_add_(clone(x), dim, index, src)
  3533. @register_lowering(aten.scatter_add_, type_promotion_kind=None)
  3534. def scatter_add_(x, dim: int, index, src):
  3535. return scatter_reduce_(x, dim, index, src, "sum")
  3536. @register_lowering(aten.scatter_reduce, type_promotion_kind=None)
  3537. def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
  3538. return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs)
  3539. @register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
  3540. def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
  3541. assert reduce in (None, "sum", "prod", "mean", "amax", "amin")
  3542. assert (
  3543. len(aten.scatter_reduce_.overloads()) == 1
  3544. and "two" in aten.scatter_reduce_.overloads()
  3545. ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_"
  3546. if isinstance(src, Number):
  3547. src = full_like(self, src)
  3548. fallback_result = scatter_fallback(
  3549. aten.scatter_reduce_.two,
  3550. self,
  3551. dim,
  3552. index,
  3553. src,
  3554. reduce=reduce,
  3555. include_self=include_self,
  3556. )
  3557. if fallback_result:
  3558. return fallback_result
  3559. assert isinstance(self, TensorBox)
  3560. assert "int" in str(index.get_dtype())
  3561. ndim = len(self.get_size())
  3562. if ndim == 0:
  3563. self = view(self, [1])
  3564. if isinstance(src, TensorBox) and len(src.get_size()) == 0:
  3565. src = view(src, [1])
  3566. if isinstance(index, TensorBox) and len(index.get_size()) == 0:
  3567. index = view(index, [1])
  3568. if index.get_numel() == 0:
  3569. return self
  3570. dim = _validate_dim(self, dim)
  3571. self.realize()
  3572. index_loader = index.make_loader()
  3573. src_loader = src.make_loader() if isinstance(src, TensorBox) else None
  3574. def output_indexer(idx):
  3575. # self is captured from the end of the function, so it may have 0 dim
  3576. shape = self.get_size()
  3577. ndim = len(shape)
  3578. indirect_idx = list(idx)
  3579. indirect_idx[dim] = ops.indirect_indexing(
  3580. index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
  3581. )
  3582. return indirect_idx
  3583. def fn(idx):
  3584. if src_loader:
  3585. return src_loader(idx)
  3586. else:
  3587. # src is a scalar
  3588. # pyrefly: ignore [bad-argument-type]
  3589. return ops.constant(src, self.get_dtype())
  3590. def backend_reduce_str(reduce):
  3591. if reduce == "sum":
  3592. return "atomic_add"
  3593. else:
  3594. # TODO: Need to support more reduction type
  3595. assert reduce is None
  3596. return None
  3597. device = self.get_device()
  3598. assert device is not None
  3599. if not include_self:
  3600. # zero out the corresponding elements first
  3601. zero_out = ir.Scatter(
  3602. device=device,
  3603. dtype=self.get_dtype(),
  3604. inner_fn=lambda index: ops.constant(0, self.get_dtype()),
  3605. ranges=index.get_size(),
  3606. output_indexer=output_indexer,
  3607. scatter_mode=None,
  3608. )
  3609. buffer = ir.ComputedBuffer(
  3610. name=None,
  3611. layout=ir.MutationLayoutSHOULDREMOVE(self),
  3612. data=zero_out,
  3613. )
  3614. buffer.name = V.graph.register_buffer(buffer)
  3615. V.graph.register_operation(buffer)
  3616. # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
  3617. # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
  3618. # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
  3619. scatter = ir.Scatter(
  3620. device=device,
  3621. dtype=self.get_dtype(),
  3622. inner_fn=fn,
  3623. ranges=index.get_size(),
  3624. output_indexer=output_indexer,
  3625. scatter_mode=backend_reduce_str(reduce),
  3626. )
  3627. buffer = ir.ComputedBuffer(
  3628. name=None,
  3629. layout=ir.MutationLayoutSHOULDREMOVE(self),
  3630. data=scatter,
  3631. )
  3632. buffer.name = V.graph.register_buffer(buffer)
  3633. V.graph.register_operation(buffer)
  3634. if ndim == 0:
  3635. self = view(self, [])
  3636. return self
  3637. def upsample_nearestnd(
  3638. x,
  3639. output_size,
  3640. scales_x: tuple[Optional[float], ...],
  3641. n: int = 2,
  3642. exact: bool = False,
  3643. ):
  3644. x.realize_hint() # elements are reused
  3645. x_loader = x.make_loader()
  3646. i_sizes = x.get_size()[-n:]
  3647. batch = x.get_size()[:-n]
  3648. i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes]
  3649. assert len(scales_x) == n
  3650. o_sizes = output_size
  3651. inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)]
  3652. for i, scale in enumerate(scales_x):
  3653. if scale is not None:
  3654. inv_scales[i] = 1.0 / scale
  3655. def scale_fn(x, scale, size):
  3656. # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5)
  3657. # = floor(scale * (output_index + 0.5))
  3658. # Nearest: input_index = floor(scale * output_index)
  3659. x = ops.index_expr(x, torch.float32)
  3660. if exact:
  3661. x = ops.add(x, ops.constant(0.5, torch.float32))
  3662. x = ops.mul(x, ops.constant(scale, torch.float32))
  3663. x = ops.to_dtype(x, torch.int32)
  3664. return ops.indirect_indexing(x, size, check=False)
  3665. def fn(idx):
  3666. x = idx[-n:]
  3667. b = idx[:-n]
  3668. return x_loader(
  3669. [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]]
  3670. )
  3671. return Pointwise.create(
  3672. device=x.get_device(),
  3673. dtype=x.get_dtype(),
  3674. inner_fn=fn,
  3675. ranges=[*batch, *o_sizes],
  3676. )
  3677. @register_lowering(aten.upsample_nearest1d.default)
  3678. def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
  3679. return upsample_nearestnd(x, output_size, (scales,), n=1)
  3680. @register_lowering(aten._upsample_nearest_exact1d.default)
  3681. def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None):
  3682. return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True)
  3683. @register_lowering(aten.upsample_nearest2d.default)
  3684. def upsample_nearest2d(
  3685. x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
  3686. ):
  3687. return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2)
  3688. @register_lowering(aten._upsample_nearest_exact2d.default)
  3689. def _upsample_nearest_exact2d(
  3690. x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
  3691. ):
  3692. return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True)
  3693. @register_lowering(aten.upsample_nearest3d.default)
  3694. def upsample_nearest3d(
  3695. x,
  3696. output_size,
  3697. scales_d: Optional[float] = None,
  3698. scales_h: Optional[float] = None,
  3699. scales_w: Optional[float] = None,
  3700. ):
  3701. return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3)
  3702. @register_lowering(aten._upsample_nearest_exact3d.default)
  3703. def _upsample_nearest_exact3d(
  3704. x,
  3705. output_size,
  3706. scales_d: Optional[float] = None,
  3707. scales_h: Optional[float] = None,
  3708. scales_w: Optional[float] = None,
  3709. ):
  3710. return upsample_nearestnd(
  3711. x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True
  3712. )
  3713. def _create_constants(*args, dtype):
  3714. return tuple(ops.constant(a, dtype) for a in args)
  3715. @register_lowering(prims.rev.default)
  3716. def rev(x, dims):
  3717. # note - dims pre-canonicalized
  3718. x_loader = x.make_loader()
  3719. sizes = x.get_size()
  3720. def loader(idx):
  3721. idx = list(idx)
  3722. assert len(idx) == len(sizes)
  3723. for dim in dims:
  3724. idx[dim] = (sizes[dim] - 1) - idx[dim]
  3725. return x_loader(idx)
  3726. return Pointwise.create(
  3727. device=x.get_device(),
  3728. dtype=x.get_dtype(),
  3729. inner_fn=loader,
  3730. ranges=sizes,
  3731. )
  3732. def inplace_constant_pad_nd(
  3733. x: TensorBox, padding: Sequence[int], fill_value: float
  3734. ) -> Optional[TensorBox]:
  3735. """
  3736. This optimization changes the semantics of padding from 'clone'
  3737. style to 'view' style.
  3738. Thanks to functionalization, this change can still maintain numerical
  3739. correctness.
  3740. """
  3741. def _padding_can_be_fused():
  3742. """
  3743. Conservatively check if padding can be fused with downstream op.
  3744. 1. if the downstream op is a sum, then there is little benefit to
  3745. do inplace padding
  3746. 2. if the downstream op is a matmul, doing inplace padding can
  3747. save membw.
  3748. """
  3749. current_node = V.graph.current_node
  3750. if current_node is None:
  3751. return True # be conservative
  3752. users = tuple(current_node.users)
  3753. if len(users) == 1 and users[0].target in (
  3754. aten.mm.default,
  3755. aten.addmm.default,
  3756. ):
  3757. return False
  3758. return True # be conservative
  3759. if _padding_can_be_fused():
  3760. return None
  3761. # Only handle 2D case for now
  3762. if len(padding) != 4 or len(x.get_size()) != 2:
  3763. return None
  3764. # No harm to realize since we already know that
  3765. # the op can not be fused into the single user.
  3766. # It need to be realized later anyways.
  3767. x.realize()
  3768. # If x is a view (e.g. a SliceView), realizing it just realizing the
  3769. # underlying storage. x itself is still a view.
  3770. if (
  3771. not isinstance(x, ir.TensorBox)
  3772. or not isinstance(x.data, ir.StorageBox)
  3773. or not (
  3774. isinstance(x.data.data, ir.ComputedBuffer)
  3775. or (
  3776. config.can_inplace_pad_graph_input
  3777. and isinstance(x.data.data, ir.InputBuffer)
  3778. )
  3779. )
  3780. or not x.data.data.name
  3781. ):
  3782. return None
  3783. x.freeze_layout()
  3784. _, layout = ir.as_storage_and_layout(x)
  3785. strides = layout.stride
  3786. if strides[1] != 1:
  3787. return None
  3788. if padding[0] != 0 or padding[2] != 0 or padding[3] != 0:
  3789. return None
  3790. npad = padding[1]
  3791. if npad == 0:
  3792. return None
  3793. stride0 = strides[0]
  3794. rowsize = layout.size[1]
  3795. if stride0 < rowsize + npad:
  3796. return None
  3797. bufname = x.data.data.name
  3798. padded_size = [layout.size[0], layout.size[1] + npad]
  3799. V.graph.buffer_to_padded_size[bufname] = padded_size
  3800. resized_x = as_strided(
  3801. x,
  3802. padded_size,
  3803. layout.stride,
  3804. layout.offset,
  3805. )
  3806. sliced_x = slice_(resized_x, dim=1, start=rowsize, end=rowsize + npad, clamp=False)
  3807. fill_(sliced_x, fill_value)
  3808. counters["inductor"]["inplace_padding"] += 1
  3809. return resized_x
  3810. @register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
  3811. def constant_pad_nd(x, padding, fill_value=0):
  3812. assert (len(padding) % 2) == 0
  3813. if all(p == 0 for p in padding):
  3814. return clone(x)
  3815. if config.inplace_padding:
  3816. out = inplace_constant_pad_nd(x, padding, fill_value)
  3817. if out:
  3818. return out
  3819. # fall through if can not inplace the padding
  3820. sizes = x.get_size()
  3821. bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
  3822. n = len(sizes) - len(bounds)
  3823. # if padding is a complicated expression, hoist it
  3824. bounds_precomp: list[tuple[sympy.Symbol, Any]] = []
  3825. for l, h in bounds:
  3826. bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type]
  3827. output_size = list(sizes[:n])
  3828. mask_sizes = []
  3829. for (low, high), size in zip(bounds, sizes[n:]):
  3830. mask_sizes.append(size)
  3831. output_size.append(sympy.expand(size + low + high))
  3832. assert len(output_size) == len(sizes)
  3833. fill_value = dtype_to_type(x.get_dtype())(fill_value)
  3834. def mask(index):
  3835. mask = []
  3836. for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
  3837. if low != 0:
  3838. mask.append(range_mask_low(idx, 0))
  3839. if high != 0:
  3840. mask.append(range_mask_high(idx, length))
  3841. mask = functools.reduce(ops.and_, mask)
  3842. return ops.masked(mask, lambda: x_loader(index), fill_value)
  3843. def offset_fn(index):
  3844. new_index = list(index[:n])
  3845. for idx, (low, _high) in zip(index[n:], bounds_precomp):
  3846. new_index.append(idx - low)
  3847. assert len(new_index) == len(index)
  3848. return mask(new_index)
  3849. x_loader = x.make_loader()
  3850. return Pointwise.create(
  3851. device=x.get_device(),
  3852. dtype=x.get_dtype(),
  3853. inner_fn=offset_fn,
  3854. ranges=output_size,
  3855. )
  3856. def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]):
  3857. return ops.ge(
  3858. ops.index_expr(i, torch.int64),
  3859. ops.index_expr(sympy.Integer(low), torch.int64),
  3860. )
  3861. def range_mask_high(i: sympy.Expr, high: sympy.Expr):
  3862. return ops.lt(
  3863. ops.index_expr(i, torch.int64),
  3864. ops.index_expr(high, torch.int64),
  3865. )
  3866. def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr):
  3867. return ops.and_(
  3868. range_mask_low(i, low),
  3869. range_mask_high(i, high),
  3870. )
  3871. def constant_boundary_condition(
  3872. x, fill_value, padding=None, pad_fill_value=1.0, dim=None
  3873. ):
  3874. h = x.get_size()[-dim:]
  3875. x_loader = x.make_loader()
  3876. # pyrefly: ignore [unsupported-operation]
  3877. padding_h = padding or [0] * dim
  3878. def load(index):
  3879. prefix = index[:-dim]
  3880. ih = index[-dim:]
  3881. mask = functools.reduce(
  3882. ops.and_,
  3883. # pyrefly: ignore [no-matching-overload]
  3884. [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)],
  3885. )
  3886. return (
  3887. ops.masked(
  3888. mask,
  3889. lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)(
  3890. [*prefix, *ih]
  3891. ),
  3892. fill_value,
  3893. )
  3894. if padding
  3895. else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value)
  3896. )
  3897. return load
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
    """Compute the pooled output length of spatial dimension ``i``.

    Args:
        x: input length of dimension ``i`` (int or sympy expression).
        i: which entry of ``kernel_size``/``stride``/``padding``/``dilation``
            to use.
        ceil_mode: whether to round the window count up.
        dilation: per-dim dilation; defaults to all ones.

    Returns:
        ``(x_out, ceil_mode)``. ``ceil_mode`` is reset to ``False`` when ceil
        rounding turns out to give the same length as floor rounding (a guard
        is recorded for that).
    """
    if dilation is None:
        dilation = [1] * len(padding)

    # Floor mode: floor((x + 2p - d(k-1) - 1) / s) + 1,
    # written as floor((x + 2p - d(k-1) + (s - 1)) / s).
    x_out = FloorDiv(
        x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1),
        stride[i],
    )

    if ceil_mode:
        # Ceil mode: ceil((x + 2p - d(k-1) - 1) / s) + 1,
        # written as floor((x + 2p - d(k-1) + 2(s - 1)) / s).
        x_alt = FloorDiv(
            x
            + 2 * padding[i]
            - dilation[i] * (kernel_size[i] - 1)
            + 2 * (stride[i] - 1),
            stride[i],
        )
        # If the last window would start at or past x + padding (decided on
        # the size hint), drop it and record the corresponding guard.
        if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
            # Sliding windows must start within the input or left padding
            x_alt -= 1  # type: ignore[assignment]
            V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i])  # type: ignore[arg-type]
        if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
            # ceil mode is actually a no-op, let's guard on that
            V.graph.sizevars.check_equals(x_out, x_alt)
            ceil_mode = False
        else:
            x_out = x_alt
    return x_out, ceil_mode
  3924. def should_fallback_max_pool_with_indices(kernel_size, *, n_dim):
  3925. kernel_size = pad_listlike(kernel_size, n_dim)
  3926. window_size = functools.reduce(operator.mul, kernel_size)
  3927. return window_size > 25
  3928. def max_pool_checks(
  3929. x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None
  3930. ):
  3931. if padding == 0:
  3932. padding = [0] * n_dim
  3933. if dilation == 1:
  3934. dilation = [1] * n_dim
  3935. if not stride:
  3936. stride = kernel_size
  3937. kernel_size = pad_listlike(kernel_size, n_dim)
  3938. stride = pad_listlike(stride, n_dim)
  3939. padding = pad_listlike(padding, n_dim)
  3940. dilation = pad_listlike(dilation, n_dim)
  3941. assert isinstance(x, TensorBox)
  3942. assert len(kernel_size) == n_dim
  3943. assert len(stride) == n_dim
  3944. assert len(padding) == n_dim
  3945. assert len(dilation) == n_dim
  3946. assert len(x.get_size()) in (n_dim + 1, n_dim + 2)
  3947. use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim)
  3948. if assert_fallback is not None:
  3949. assert use_fallback == assert_fallback
  3950. return kernel_size, stride, padding, dilation, use_fallback
def _max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    *,
    n_dim,
):
    """Build max-pool values and in-window argmax offsets over the last ``n_dim`` axes.

    Args:
        x: input TensorBox.
        kernel_size, stride, padding, dilation: per-axis lists of length
            ``n_dim`` (as normalized by ``max_pool_checks``).
        ceil_mode: forwarded to ``pooling_size`` per axis.
        n_dim: number of trailing spatial axes.

    Returns:
        ``(result, offsets)``:
        - ``result`` is a "max" reduction over each window.
        - ``offsets`` is the matching int64 "argmax" reduction, i.e. the
          flattened position inside the ``kernel_size`` window, not an input
          index.
        Each is realized only if it was not unrolled into pointwise code.
    """
    x.realize_hint()
    batch = x.shape[:-n_dim]
    dhw = x.shape[-n_dim:]

    dhw_out, ceil_mode = zip(
        *[
            pooling_size(
                dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation
            )
            for d in range(n_dim)
        ]
    )

    dtype = x.dtype
    # Identity for max: out-of-range reads return this so they never win.
    min_value = (
        False
        if dtype is torch.bool
        else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min)
    )

    new_size = list(batch) + list(dhw_out)
    # Bounds-checked loads are only set up when a window can leave the input.
    if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation):
        x_loader = constant_boundary_condition(x, min_value, dim=n_dim)
    else:
        x_loader = x.make_loader()

    def fn_inner(idx, reduction_idx):
        prefix = idx[:-n_dim]
        bh = idx[-n_dim:]
        # Input coordinate = out * stride + k * dilation - padding.
        ih = [
            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
            for i in range(n_dim)
        ]
        return x_loader([*prefix, *ih])

    result = Reduction.create(
        reduction_type="max",
        input_node=x,
        device=x.get_device(),
        dst_dtype=dtype,
        src_dtype=dtype,
        inner_fn=fn_inner,
        ranges=new_size,
        reduction_ranges=kernel_size,
    )
    offsets = Reduction.create(
        reduction_type="argmax",
        input_node=x,
        device=x.get_device(),
        dst_dtype=torch.int64,
        src_dtype=dtype,
        inner_fn=fn_inner,
        ranges=new_size,
        reduction_ranges=kernel_size,
    )
    if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined, union-attr]
        # Only realize if reduction isn't unrolled
        result.realize()
    if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined, union-attr]
        # Only realize if reduction isn't unrolled
        offsets.realize()

    return result, offsets
@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None)
def _low_memory_max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode=False,
):
    """Lower the low-memory max-pool prim: pooled values plus int8 window offsets.

    Returns ``(result, offsets)``. ``offsets`` holds, for each output element,
    the argmax position inside its kernel window, cast to int8. The lowering
    for ``prims._low_memory_max_pool_offsets_to_indices`` turns such offsets
    into flattened input indices.
    """
    n_dim = len(kernel_size)

    # assert we are not on a fallback path, the inductor decomp should have guaranteed this
    kernel_size, stride, padding, dilation, _ = max_pool_checks(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        n_dim,
        assert_fallback=False,
    )

    # Non-fallback means the window has at most 25 elements, so reductions of
    # that size are allowed to unroll.
    with config.patch(unroll_reductions_threshold=25):
        result, offsets = _max_pool_with_offsets(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            n_dim=n_dim,
        )
        # int8 is enough: an offset indexes a window of at most 25 elements.
        return result, to_dtype(offsets, torch.int8)
def _pool_offsets_to_indices(
    offsets: TensorBox,
    kernel_size: Sequence[Union[int, torch.SymInt]],
    input_size: Sequence[Union[int, torch.SymInt]],
    increments_to_index: Callable[
        [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
        torch._inductor.virtualized.OpsValue,
    ],
) -> TensorBox:
    """Turn in-window argmax offsets into int64 indices flattened over the input.

    Args:
        offsets: per-output flattened position inside the ``kernel_size`` window.
        kernel_size: window extent per spatial axis; ``len(kernel_size)`` is
            the number of spatial axes ``n_dim``.
        input_size: input shape; its last ``n_dim`` entries define the
            flattening.
        increments_to_index: ``(output_idx, window_offset_nd) -> input coords``
            for the ``n_dim`` spatial axes.

    Returns:
        An int64 Pointwise node, shaped like ``offsets``, holding the
        flattened spatial index of each selected element.
    """
    n_dim = len(kernel_size)
    offsets_loader = offsets.make_loader()
    # Number of positions in one window; the upper bound for the indirect index.
    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

    def offsets_to_indices(idx):
        offset = offsets_loader(idx)
        offset_sympy = ops.indirect_indexing(offset, window_size)
        # Unflatten the window offset into per-axis kernel positions.
        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
        idhw = increments_to_index(idx, reduction_idx)
        return ops.index_expr(
            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
        )

    indices = Pointwise.create(
        device=offsets.get_device(),
        dtype=torch.int64,
        inner_fn=offsets_to_indices,
        ranges=offsets.get_size(),
    )
    return indices
  4076. @register_lowering(
  4077. prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
  4078. )
  4079. def _low_memory_max_pool_offsets_to_indices(
  4080. offsets, kernel_size, input_size, stride, padding, dilation
  4081. ):
  4082. # TODO: Generalize to other max pooling flavors
  4083. n_dim = len(kernel_size)
  4084. def increments_to_index(idx, reduction_idx):
  4085. bh = idx[-n_dim:]
  4086. return [
  4087. (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
  4088. for i in range(n_dim)
  4089. ]
  4090. return _pool_offsets_to_indices(
  4091. offsets, kernel_size, input_size, increments_to_index
  4092. )
  4093. def _max_pool_with_indices(
  4094. x,
  4095. kernel_size,
  4096. stride,
  4097. padding,
  4098. dilation,
  4099. ceil_mode,
  4100. n_dim,
  4101. ):
  4102. kernel_size, stride, padding, dilation, _ = max_pool_checks(
  4103. x, kernel_size, stride, padding, dilation, n_dim=n_dim
  4104. )
  4105. out, offsets = _max_pool_with_offsets(
  4106. x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim
  4107. )
  4108. indices = _low_memory_max_pool_offsets_to_indices(
  4109. offsets,
  4110. kernel_size,
  4111. x.shape[-n_dim:],
  4112. stride,
  4113. padding,
  4114. dilation,
  4115. )
  4116. return out, indices
  4117. # Fallback when we do not decompose to the low-memory path.
  4118. @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
  4119. def max_pool2d_with_indices(
  4120. x,
  4121. kernel_size,
  4122. stride=None,
  4123. padding=0,
  4124. dilation=1,
  4125. ceil_mode=False,
  4126. ):
  4127. return _max_pool_with_indices(
  4128. x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=2
  4129. )
  4130. # Fallback when we do not decompose to the low-memory path.
  4131. @register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None)
  4132. def max_pool3d_with_indices(
  4133. x,
  4134. kernel_size,
  4135. stride=None,
  4136. padding=0,
  4137. dilation=1,
  4138. ceil_mode=False,
  4139. ):
  4140. return _max_pool_with_indices(
  4141. x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=3
  4142. )
# ATen fallback for max_pool2d_with_indices_backward. The lowering below uses
# it when dilation != 1 or when more than 25 pooled windows can cover one
# input element.
# NOTE(review): add_to_fallback_set=False presumably keeps the op lowered
# rather than registered as a blanket fallback; confirm in fallback_handler.
fallback_max_pool2d_with_indices_backward = fallback_handler(
    aten.max_pool2d_with_indices_backward.default,
    add_to_fallback_set=False,
)
@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
def max_pool2d_with_indices_backward(
    grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Lower ``max_pool2d_with_indices_backward`` as a gather over pooled outputs.

    For each input position ``(h, w)``:
    - Enumerate the pooled windows that can contain it (at most
      ``h_window_size * w_window_size`` candidates).
    - Sum ``grad_output`` over those candidates whose saved ``indices`` entry
      equals the position's flattened index ``h * width + w``.

    Falls back to ATen when dilation != 1 or when the candidate count exceeds
    25. The result is made channels-last when ``x`` or ``grad_output`` has
    stride 1 on axis 1.
    """
    if padding == 0:
        padding = [0, 0]
    if dilation == 1:
        dilation = [1, 1]
    if not stride:
        stride = kernel_size

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(dilation) == 2
    assert len(x.get_size()) in (3, 4)

    # we will read this many times, so make sure it is computed
    grad_output.realize_hint()
    gO_stride = grad_output.maybe_get_stride()
    x_stride: Optional[Sequence[Any]]
    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):  # type: ignore[attr-defined]
        # x is an unrealized Pointwise. Wrap it in a temporary ComputedBuffer
        # with a FlexibleLayout and let it choose a layout, to learn the
        # stride x would get. Only used for the channels-last decision below.
        data = x.data.data  # type: ignore[attr-defined]
        device = data.get_device()
        assert device is not None
        x_buffer = ir.ComputedBuffer(
            name=None,
            layout=ir.FlexibleLayout(
                device=device,
                dtype=data.get_dtype(),
                size=data.get_size(),
            ),
            data=data,
        )
        x_buffer.decide_layout()
        x_stride = x_buffer.get_stride()
    else:
        x_stride = x.maybe_get_stride()
    is_channels_last = (x_stride is not None and x_stride[1] == 1) or (
        gO_stride is not None and gO_stride[1] == 1
    )
    if any(d != 1 for d in dilation):
        # dilation NYI
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    *_batch, _height, width = x.get_size()
    *_, pooled_height, pooled_width = grad_output.get_size()

    indices_loader = indices.make_loader()
    grad_loader = grad_output.make_loader()
    new_size = list(x.get_size())

    # Max number of pooled windows (per axis) that can contain one input
    # coordinate, evaluated over coordinates in [0, 2 * kernel).
    h_window_size = max(
        max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1)
        for h in range(kernel_size[0] * 2)
    )
    w_window_size = max(
        max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1)
        for w in range(kernel_size[1] * 2)
    )

    window_size = h_window_size * w_window_size

    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    indices_size = indices.get_size()

    def fn(idx):
        *prefix, h, w = idx
        # Flattened position of (h, w) in the unpadded plane, compared against
        # the saved indices.
        index_test = ops.index_expr(h * width + w, torch.int32)
        h = h + padding[0]
        w = w + padding[1]
        # Pooled rows/cols whose window contains (h, w): [start, end).
        phstart = ops.index_expr(
            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

        # Clamp to the pooled output extent.
        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_ in range(h_window_size):
            for pw_ in range(w_window_size):
                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
                # Clamp to end - 1 so the load stays in bounds; candidates past
                # the end are excluded by the mask below.
                grad_index = [
                    *prefix,
                    ops.indirect_indexing(
                        ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))),
                        indices_size[-2],
                        check=False,
                    ),
                    ops.indirect_indexing(
                        ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))),
                        indices_size[-1],
                        check=False,
                    ),
                ]

                index_actual = indices_loader(grad_index)
                grad_part = grad_loader(grad_index)
                check = ops.eq(index_actual, index_test)

                if gradient is None:
                    # don't need mask for 0, 0
                    gradient = ops.where(
                        check, grad_part, ops.constant(0.0, torch.float32)
                    )
                else:
                    mask = ops.and_(
                        ops.and_(
                            ops.lt(ph, phend),
                            ops.lt(pw, pwend),
                        ),
                        check,
                    )
                    gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
        assert gradient is not None
        return gradient

    out = Pointwise.create(
        device=grad_output.get_device(),
        dtype=grad_output.get_dtype(),
        inner_fn=fn,
        ranges=new_size,
    )
    if is_channels_last:
        return ir.ExternKernel.require_channels_last(out)
    else:
        return out
  4276. def pad_adaptive_loader(x, pad_val=0.0):
  4277. x_loader = x.make_loader()
  4278. def load(prefix, increments, start_indices, end_indices):
  4279. ih, iw = increments
  4280. h_start_index, w_start_index = start_indices
  4281. h_end_index, w_end_index = end_indices
  4282. mask = ops.and_(
  4283. ops.lt(
  4284. ops.index_expr(h_start_index + ih, torch.int64),
  4285. ops.index_expr(h_end_index, torch.int64),
  4286. ),
  4287. ops.lt(
  4288. ops.index_expr(w_start_index + iw, torch.int64),
  4289. ops.index_expr(w_end_index, torch.int64),
  4290. ),
  4291. )
  4292. return ops.masked(
  4293. mask,
  4294. lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
  4295. pad_val,
  4296. )
  4297. return load
  4298. def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
  4299. h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
  4300. h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
  4301. w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
  4302. w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
  4303. return h_start_index, h_end_index, w_start_index, w_end_index
  4304. def _adaptive_pooling_fn(
  4305. start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
  4306. ):
  4307. h_in, w_in = in_sizes
  4308. h_out, w_out = out_sizes
  4309. (
  4310. h_start_index_fn,
  4311. h_end_index_fn,
  4312. w_start_index_fn,
  4313. w_end_index_fn,
  4314. ) = compute_indices_adaptive_pooling(
  4315. start_index, end_index, h_in, w_in, h_out, w_out
  4316. )
  4317. def fn(idx, loader):
  4318. *prefix, bh, bw = idx
  4319. h_start_index = h_start_index_fn(bh)
  4320. h_end_index = h_end_index_fn(bh)
  4321. w_start_index = w_start_index_fn(bw)
  4322. w_end_index = w_end_index_fn(bw)
  4323. result = None
  4324. for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
  4325. val = loader(
  4326. prefix,
  4327. [ih, iw],
  4328. [h_start_index, w_start_index],
  4329. [h_end_index, w_end_index],
  4330. )
  4331. if result is None:
  4332. result = val
  4333. else:
  4334. result = pooling_fn(val, result)
  4335. return result
  4336. return fn
  4337. def _adaptive_pooling_fn_with_idx(
  4338. start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
  4339. ):
  4340. h_in, w_in = in_sizes
  4341. h_out, w_out = out_sizes
  4342. (
  4343. h_start_index_fn,
  4344. h_end_index_fn,
  4345. w_start_index_fn,
  4346. w_end_index_fn,
  4347. ) = compute_indices_adaptive_pooling(
  4348. start_index, end_index, h_in, w_in, h_out, w_out
  4349. )
  4350. def fn(idx, loader):
  4351. *prefix, bh, bw = idx
  4352. h_start_index = h_start_index_fn(bh)
  4353. h_end_index = h_end_index_fn(bh)
  4354. w_start_index = w_start_index_fn(bw)
  4355. w_end_index = w_end_index_fn(bw)
  4356. maxval = None
  4357. maxindex = None
  4358. for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
  4359. val = loader(
  4360. prefix,
  4361. [ih, iw],
  4362. [h_start_index, w_start_index],
  4363. [h_end_index, w_end_index],
  4364. )
  4365. index = ops.index_expr(
  4366. (h_start_index + ih) * w_in + w_start_index + iw, torch.int64
  4367. )
  4368. if maxindex is None:
  4369. maxindex = index
  4370. else:
  4371. maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
  4372. if maxval is None:
  4373. maxval = val
  4374. else:
  4375. maxval = pooling_fn(val, maxval)
  4376. return maxindex
  4377. return fn
# ATen fallback used by the _adaptive_avg_pool2d lowering below when a window
# can exceed 25 elements.
fallback_adaptive_avg_pool2d = fallback_handler(
    aten._adaptive_avg_pool2d.default, add_to_fallback_set=False
)
@register_lowering(aten._adaptive_avg_pool2d)
def _adaptive_avg_pool2d(x, output_size):
    """Lower ``aten._adaptive_avg_pool2d`` over the last two axes of ``x``.

    Output row ``i`` averages input rows
    ``[floor(i * h_in / h_out), ceil((i + 1) * h_in / h_out))``; columns
    work the same way.

    Shortcuts:
    - Identical sizes: ``clone``.
    - A zero output dim: ``empty``.
    - Evenly divisible sizes: ``avg_pool2d`` with the integer kernel.
    Windows that can exceed 25 elements go to the ATen fallback.

    Raises:
        RuntimeError: for int64 input (not supported in eager).
    """
    if x.get_dtype() == torch.int64:
        # not supported in eager
        raise RuntimeError("'adaptive_avg_pool2d' not implemented for 'Long'")
    assert isinstance(x, TensorBox)
    assert len(output_size) == 2
    x.realize_hint()

    *batch, h_in, w_in = x.get_size()

    # Window bounds below are computed from concrete ints.
    h_in = V.graph.sizevars.guard_int(h_in)
    w_in = V.graph.sizevars.guard_int(w_in)

    h_out, w_out = output_size

    # no-op if the same input and output
    if h_in == h_out and w_in == w_out:
        return clone(x)

    if h_out == 0 or w_out == 0:
        o_size = [*batch, h_out, w_out]
        return empty(o_size, dtype=x.get_dtype(), device=x.get_device())
    if h_in % h_out == 0 and w_in % w_out == 0:
        kernel_size = [FloorDiv(h_in, h_out), FloorDiv(w_in, w_out)]
        return avg_pool2d(x, kernel_size)

    # Upper bound on any window's extent: ceil((in + out - 1) / out).
    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)

    new_size = list(batch) + [h_out, w_out]
    dtype = x.get_dtype()

    window_size = h_kernel_max * w_kernel_max
    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_adaptive_avg_pool2d(x, output_size)

    def start_index(index, out_dim, inp_dim):
        # floor(index * inp_dim / out_dim)
        return FloorDiv((index * inp_dim), out_dim)

    def end_index(index, out_dim, inp_dim):
        # ceil((index + 1) * inp_dim / out_dim), exclusive
        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)

    fn_sum = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.add,
    )

    ones_loader = pad_adaptive_loader(ones_like(x))

    def fn(idx):
        # Average = sum of in-window values / count of in-window elements
        # (sum over a masked ones tensor); out-of-window slots read 0 in both.
        return ops.truediv(
            fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader)
        )

    rv = Pointwise.create(
        device=x.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
    # TODO: should we force these to be realized?
    return rv
# ATen fallback used by the adaptive_max_pool2d lowering below when a window
# can exceed 25 elements.
fallback_adaptive_max_pool2d = fallback_handler(
    aten.adaptive_max_pool2d.default, add_to_fallback_set=False
)
  4438. @register_lowering(aten.adaptive_max_pool2d)
  4439. def adaptive_max_pool2d(x, output_size):
  4440. if x.get_dtype() == torch.int64:
  4441. # not supported in eager
  4442. raise RuntimeError("adaptive_max_pool2d not implemented for Long")
  4443. assert isinstance(x, TensorBox)
  4444. assert len(output_size) == 2
  4445. x.realize_hint()
  4446. *batch, h_in, w_in = x.get_size()
  4447. h_in = V.graph.sizevars.guard_int(h_in)
  4448. w_in = V.graph.sizevars.guard_int(w_in)
  4449. h_out, w_out = output_size
  4450. if h_out == 0 or w_out == 0:
  4451. o_size = [*batch, h_out, w_out]
  4452. return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty(
  4453. o_size, dtype=torch.int64, device=x.get_device()
  4454. )
  4455. if h_in % h_out == 0 and w_in % w_out == 0:
  4456. # This is handled by a decomposition
  4457. raise ValueError
  4458. h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
  4459. w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
  4460. new_size = list(batch) + [h_out, w_out]
  4461. dtype = x.get_dtype()
  4462. window_size = h_kernel_max * w_kernel_max
  4463. if window_size > 25:
  4464. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  4465. return fallback_adaptive_max_pool2d(x, output_size)
  4466. def start_index(index, out_dim, inp_dim):
  4467. return FloorDiv((index * inp_dim), out_dim)
  4468. def end_index(index, out_dim, inp_dim):
  4469. return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
  4470. inner_func_max_val = _adaptive_pooling_fn(
  4471. start_index=start_index,
  4472. end_index=end_index,
  4473. kernel_maxes=[h_kernel_max, w_kernel_max],
  4474. in_sizes=[h_in, w_in],
  4475. out_sizes=[h_out, w_out],
  4476. pooling_fn=ops.maximum,
  4477. )
  4478. inner_func_max_idx = _adaptive_pooling_fn_with_idx(
  4479. start_index=start_index,
  4480. end_index=end_index,
  4481. kernel_maxes=[h_kernel_max, w_kernel_max],
  4482. in_sizes=[h_in, w_in],
  4483. out_sizes=[h_out, w_out],
  4484. pooling_fn=ops.maximum,
  4485. )
  4486. def inner_fn_max_val(idx):
  4487. return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
  4488. def inner_fn_max_idx(idx):
  4489. return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
  4490. rv = Pointwise.create(
  4491. device=x.get_device(),
  4492. dtype=dtype,
  4493. inner_fn=inner_fn_max_val,
  4494. ranges=new_size,
  4495. )
  4496. ri = Pointwise.create(
  4497. device=x.get_device(),
  4498. dtype=torch.int64,
  4499. inner_fn=inner_fn_max_idx,
  4500. ranges=new_size,
  4501. )
  4502. return rv, ri
  4503. def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
  4504. out_sz = out_sz[dim]
  4505. in_sz = in_sz[dim]
  4506. kernel_sz = kernel_sz[dim]
  4507. samples_loader = samples.make_loader()
  4508. def load(prefix, i):
  4509. # Handle indexing for samples tensor correctly for different input dimensions
  4510. # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where:
  4511. # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W)
  4512. # - C=num_channels
  4513. # - 2 for the two spatial dimensions (height, width)
  4514. samples_shape = samples.get_size()
  4515. if len(samples_shape) == 3: # Expected: (N, C, 2)
  4516. if len(prefix) == 1:
  4517. # 3D input case: prefix=(channel,), samples=(1, C, 2)
  4518. # Access: samples[0, channel, dim]
  4519. sample = samples_loader([0, prefix[0], ndims - 1 - dim])
  4520. elif len(prefix) >= 2:
  4521. # 4D+ input case: prefix=(batch, channel, ...), samples=(batch, C, 2)
  4522. # Access: samples[batch, channel, dim]
  4523. sample = samples_loader([prefix[0], prefix[1], ndims - 1 - dim])
  4524. else:
  4525. # Edge case - shouldn't happen for valid fractional pooling
  4526. sample = samples_loader([0, 0, ndims - 1 - dim])
  4527. else:
  4528. # Fallback for unexpected tensor shapes
  4529. sample = samples_loader([*prefix, ndims - 1 - dim])
  4530. i_expr = ops.index_expr(i, samples.get_dtype())
  4531. diff = ops.index_expr(in_sz - kernel_sz, torch.int64)
  4532. out_sz_expr = ops.index_expr(out_sz - 1, torch.int64)
  4533. alpha = ops.truediv(
  4534. ops.to_dtype(diff, torch.float64), ops.to_dtype(out_sz_expr, torch.float64)
  4535. )
  4536. alpha = ops.where(ops.eq(out_sz_expr, 0), 0, alpha)
  4537. seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
  4538. seq_i = ops.to_dtype(seq_i, torch.int64)
  4539. mask = ops.lt(i_expr, out_sz_expr)
  4540. return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))
  4541. return load
  4542. @register_lowering(aten.fractional_max_pool2d)
  4543. def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
  4544. return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)
  4545. @register_lowering(aten.fractional_max_pool3d)
  4546. def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
  4547. return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)
def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
    """Shared lowering for fractional_max_pool{2,3}d.

    Window starts along each spatial axis come from
    ``_fractional_pooling_offsets``, driven by ``random_samples``. Each window
    spans ``kernel_size`` elements from its start.

    Returns:
        ``(values, indices)``. ``indices`` are int64 positions flattened over
        the last ``n_dim`` axes of ``x``.
    """
    x.realize_hint()
    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

    # The reductions below are built under this threshold, so small windows
    # may be unrolled (see the realize checks at the end).
    with config.patch(unroll_reductions_threshold=25):
        dhw_index_fn = [
            _fractional_pooling_offsets(
                samples=random_samples,
                in_sz=inp_dhw,
                out_sz=output_size,
                kernel_sz=kernel_size,
                ndims=n_dim,
                dim=d,
            )
            for d in range(n_dim)
        ]

        x_loader = x.make_loader()

        def fn_inner(idx, reduction_idx):
            prefix = idx[:-n_dim]
            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

        def increments_to_index(idx, reduction_idx):
            # Input coordinate = sampled window start + offset inside the window.
            prefix = idx[:-n_dim]
            bdhw = idx[-n_dim:]
            return [
                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
                for d in range(n_dim)
            ]

        new_size = list(batch) + list(output_size)
        dtype = x.get_dtype()
        result = Reduction.create(
            reduction_type="max",
            input_node=x,
            device=x.get_device(),
            dst_dtype=dtype,
            src_dtype=dtype,
            inner_fn=fn_inner,
            ranges=new_size,
            reduction_ranges=kernel_size,
        )
        # In-window argmax position, later converted to input indices.
        offsets = Reduction.create(
            reduction_type="argmax",
            input_node=x,
            device=x.get_device(),
            dst_dtype=torch.int64,
            src_dtype=dtype,
            inner_fn=fn_inner,
            ranges=new_size,
            reduction_ranges=kernel_size,
        )
        assert isinstance(result, TensorBox), result
        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
            # Only realize if reduction isn't unrolled
            result.realize()
        assert isinstance(offsets, TensorBox), offsets
        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
            # Only realize if reduction isn't unrolled
            offsets.realize()

        indices = _pool_offsets_to_indices(
            offsets, kernel_size, x.shape, increments_to_index
        )
        return result, indices
@register_lowering(aten.upsample_nearest2d_backward.default)
def upsample_nearest2d_backward(
    x, output_size=None, input_size=None, scales_h=None, scales_w=None
):
    """Lower the backward of 2-D nearest upsampling as a windowed sum.

    ``x`` is the incoming gradient, with spatial size ``inp_h x inp_w``.
    ``input_size`` is the forward input's shape, which is also the result
    shape (spatial ``out_h x out_w``). Result row ``i`` sums gradient rows
    ``[ceil(i * inp_h / out_h), ceil((i + 1) * inp_h / out_h))``; columns
    work the same way. Evenly divisible sizes use ``avg_pool2d`` with
    ``divisor_override=1``, i.e. a sum pool.

    NOTE(review): ``output_size``, ``scales_h`` and ``scales_w`` are not read.
    Windows are derived from the size ratio only. Confirm this matches eager
    when explicit scales that differ from that ratio are passed.
    """
    x.realize_hint()

    *_batch, inp_h, inp_w = x.get_size()
    inp_h = V.graph.sizevars.guard_int(inp_h)
    inp_w = V.graph.sizevars.guard_int(inp_w)

    # pyrefly: ignore [not-iterable]
    *_batch, out_h, out_w = input_size

    if inp_h % out_h == 0 and inp_w % out_w == 0:
        return avg_pool2d(
            x, [FloorDiv(inp_h, out_h), FloorDiv(inp_w, out_w)], divisor_override=1
        )

    h_kernel_max = ceildiv(inp_h, out_h)
    w_kernel_max = ceildiv(inp_w, out_w)

    def start_index(index, out_dim, inp_dim):
        # ceil(index * inp_dim / out_dim)
        return CeilDiv(index * inp_dim, sympy.sympify(out_dim))

    def end_index(index, out_dim, inp_dim):
        # Exclusive end = start of the next window.
        return start_index((index + 1), out_dim, inp_dim)

    fn_sum = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[inp_h, inp_w],
        out_sizes=[out_h, out_w],
        pooling_fn=ops.add,
    )

    def fn(idx):
        return fn_sum(idx, pad_adaptive_loader(x))

    rv = Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        # pyrefly: ignore [no-matching-overload]
        ranges=list(input_size),
    )

    return rv
  4646. @register_lowering(aten.avg_pool2d, type_promotion_kind=None)
  4647. def avg_pool2d(
  4648. x,
  4649. kernel_size,
  4650. stride=(),
  4651. padding=0,
  4652. ceil_mode=False,
  4653. count_include_pad=True,
  4654. divisor_override=None,
  4655. ):
  4656. return _avg_poolnd(
  4657. x,
  4658. kernel_size,
  4659. stride,
  4660. padding,
  4661. ceil_mode,
  4662. count_include_pad,
  4663. divisor_override,
  4664. dim=2,
  4665. )
  4666. @register_lowering(aten.avg_pool3d, type_promotion_kind=None)
  4667. def avg_pool3d(
  4668. x,
  4669. kernel_size,
  4670. stride=(),
  4671. padding=0,
  4672. ceil_mode=False,
  4673. count_include_pad=True,
  4674. divisor_override=None,
  4675. ):
  4676. return _avg_poolnd(
  4677. x,
  4678. kernel_size,
  4679. stride,
  4680. padding,
  4681. ceil_mode,
  4682. count_include_pad,
  4683. divisor_override,
  4684. dim=3,
  4685. )
  4686. fallbacks_avg_poolnd = [
  4687. fallback_handler(aten.avg_pool1d.default, add_to_fallback_set=False),
  4688. fallback_handler(aten.avg_pool2d.default, add_to_fallback_set=False),
  4689. fallback_handler(aten.avg_pool3d.default, add_to_fallback_set=False),
  4690. ]
def _avg_poolnd(
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
    dim,
):
    """Shared lowering for avg_pool{2,3}d over the last ``dim`` axes of ``x``.

    Sums each window with a "sum" reduction in the promoted compute dtype,
    then divides by one of:
    - ``divisor_override``, if truthy;
    - otherwise the full window size, when no padding/ceil-mode boundary is
      involved;
    - otherwise a per-output count from ``fn_count``: the window clipped to
      the padded input, or to the unpadded input when ``count_include_pad``
      is False.

    Falls back to ATen when the window exceeds 25 elements and some kernel
    size is statically known to differ from its stride. The result is cast
    back to ``x``'s dtype.
    """
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0] * dim
    kernel_size = pad_listlike(kernel_size, dim)
    stride = pad_listlike(stride, dim)
    padding = pad_listlike(padding, dim)

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == dim
    assert len(stride) == dim
    assert len(padding) == dim
    assert len(x.get_size()) in (dim + 1, dim + 2)

    x.realize_hint()
    batch = x.get_size()[:-dim]
    h = x.get_size()[-dim:]

    h_out, ceil_modes = zip(
        *[
            pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode)
            for i in range(dim)
        ]
    )

    # Out-of-range reads contribute 0 to the sum when windows can leave x.
    if any(padding) or any(ceil_modes):
        x_loader = constant_boundary_condition(x, 0.0, dim=dim)
        had_padding = True
    else:
        x_loader = x.make_loader()
        had_padding = False

    new_size = list(batch) + list(h_out)
    dtype = x.get_dtype()
    # compute in higher-precision until scaling
    output_dtype = get_promoted_dtype(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        return_compute_dtype=True,
    )

    def fn_inner(idx, reduction_idx):
        prefix = idx[:-dim]
        bh = idx[-dim:]
        ih = reduction_idx
        # Input coordinate = out * stride + k - padding.
        ih = [bh[i] * stride[i] + ih[i] - padding[i] for i in range(dim)]
        return x_loader([*prefix, *ih])

    window_size = functools.reduce(operator.mul, kernel_size)
    if window_size > 25 and any(
        V.graph.sizevars.statically_known_true(sympy.Ne(k, s))
        for k, s in zip(kernel_size, stride)
    ):
        fallback = fallbacks_avg_poolnd[dim - 1]
        return fallback(
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    # TODO: remove this when #100331 is merged. We only do this
    # for window_size <=25 to avoid performance regressions compared
    # to the previous algorithm which unrolled manually for <=25
    context = (
        config.patch(unroll_reductions_threshold=25)
        if window_size <= 25
        else contextlib.nullcontext()
    )

    device = x.get_device()
    assert device is not None
    with context:
        rv = Reduction.create(
            reduction_type="sum",
            input_node=x,
            device=device,
            dst_dtype=output_dtype,
            src_dtype=dtype,
            inner_fn=fn_inner,
            ranges=new_size,
            reduction_ranges=kernel_size,
        )
    if hasattr(rv.data, "data") and isinstance(rv.data.data, Reduction):
        # Only realize if reduction isn't unrolled
        rv.realize()

    if not had_padding or divisor_override:
        divisor = divisor_override if divisor_override else window_size
        result = div_prim(rv, divisor)
    else:

        def fn_count(idx):
            # Number of elements averaged for this output position.
            bh = idx[-dim:]

            divide_factors = []
            for i in range(dim):
                hstart = bh[i] * stride[i] - padding[i]
                # Window clipped to the padded extent.
                hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i])
                if not count_include_pad:
                    # Further clip to the real (unpadded) input.
                    hstart = sympy.Max(hstart, 0)
                    hend = sympy.Min(hend, h[i])
                factor = ops.index_expr(hend - hstart, torch.int32)
                divide_factors.append(factor)
            return functools.reduce(ops.mul, divide_factors)

        divide_factor = Pointwise.create(
            device=x.get_device(),
            dtype=dtype,
            inner_fn=fn_count,
            ranges=new_size,
        )
        result = div_prim(rv, divide_factor)

    return to_dtype(result, dtype)
  4805. fallback_avg_pool2d_backward = fallback_handler(
  4806. aten.avg_pool2d_backward.default, add_to_fallback_set=False
  4807. )
  4808. @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
  4809. def avg_pool2d_backward(
  4810. grad_output,
  4811. x,
  4812. kernel_size,
  4813. stride,
  4814. padding,
  4815. ceil_mode,
  4816. count_include_pad,
  4817. divisor_override=None,
  4818. ):
  4819. assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
  4820. if not stride:
  4821. stride = kernel_size
  4822. if not padding:
  4823. padding = [0, 0]
  4824. assert isinstance(grad_output, TensorBox)
  4825. assert isinstance(x, TensorBox)
  4826. assert len(kernel_size) == 2
  4827. assert len(stride) == 2
  4828. assert len(padding) == 2
  4829. assert len(x.get_size()) in (3, 4)
  4830. grad_output.realize_hint() # we will read this many times, so make sure it is computed
  4831. *_, height, width = x.get_size()
  4832. _h_out, ceil_mode1 = pooling_size(
  4833. height, 0, kernel_size, stride, padding, ceil_mode
  4834. )
  4835. _w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode)
  4836. grad_loader = grad_output.make_loader()
  4837. had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2
  4838. *_, pooled_height, pooled_width = grad_output.get_size()
  4839. new_size = list(x.get_size())
  4840. dtype = x.get_dtype()
  4841. h_window_size = max(
  4842. max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1)
  4843. for h in range(kernel_size[0] * 2)
  4844. )
  4845. w_window_size = max(
  4846. max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1)
  4847. for w in range(kernel_size[1] * 2)
  4848. )
  4849. window_size = h_window_size * w_window_size
  4850. if window_size > 25:
  4851. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  4852. return fallback_avg_pool2d_backward(
  4853. grad_output,
  4854. x,
  4855. kernel_size,
  4856. stride,
  4857. padding,
  4858. ceil_mode,
  4859. count_include_pad,
  4860. divisor_override,
  4861. )
  4862. def compute_pool_size_without_padding(ph, pw):
  4863. """
  4864. This computes the scaling factor that we will divide an element
  4865. by when `count_include_pad=False`
  4866. """
  4867. stride_h = ops.constant(stride[0], torch.int32)
  4868. stride_w = ops.constant(stride[1], torch.int32)
  4869. pad_h = ops.constant(padding[0], torch.int32)
  4870. pad_w = ops.constant(padding[1], torch.int32)
  4871. kernel_h = ops.constant(kernel_size[0], torch.int32)
  4872. kernel_w = ops.constant(kernel_size[1], torch.int32)
  4873. hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
  4874. wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
  4875. hend = ops.minimum(
  4876. ops.add(hstart, kernel_h),
  4877. ops.add(ops.index_expr(height, torch.int32), pad_h),
  4878. )
  4879. wend = ops.minimum(
  4880. ops.add(wstart, kernel_w),
  4881. ops.add(ops.index_expr(width, torch.int32), pad_w),
  4882. )
  4883. hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
  4884. wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
  4885. hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
  4886. wend = ops.minimum(wend, ops.index_expr(width, torch.int32))
  4887. divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))
  4888. return divide_factor
  4889. def fn(idx):
  4890. *prefix, h, w = idx
  4891. h = h + padding[0]
  4892. w = w + padding[1]
  4893. phstart = ops.index_expr(
  4894. FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
  4895. )
  4896. pwstart = ops.index_expr(
  4897. FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
  4898. )
  4899. phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
  4900. pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)
  4901. phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
  4902. pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
  4903. phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
  4904. pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
  4905. gradient = None
  4906. for ph_ in range(h_window_size):
  4907. for pw_ in range(w_window_size):
  4908. ph = ops.add(phstart, ops.constant(ph_, torch.int32))
  4909. pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
  4910. if divisor_override is not None:
  4911. scale = divisor_override
  4912. elif count_include_pad or not had_padding:
  4913. scale = kernel_size[0] * kernel_size[1]
  4914. else:
  4915. scale = compute_pool_size_without_padding(ph, pw)
  4916. part = ops.truediv(
  4917. grad_loader(
  4918. [
  4919. *prefix,
  4920. ops.indirect_indexing(
  4921. ops.minimum(
  4922. ph, ops.sub(phend, ops.constant(1, torch.int32))
  4923. ),
  4924. pooled_height,
  4925. check=False,
  4926. ),
  4927. ops.indirect_indexing(
  4928. ops.minimum(
  4929. pw, ops.sub(pwend, ops.constant(1, torch.int32))
  4930. ),
  4931. pooled_width,
  4932. check=False,
  4933. ),
  4934. ]
  4935. ),
  4936. scale,
  4937. )
  4938. mask = ops.and_(
  4939. ops.lt(ph, phend),
  4940. ops.lt(pw, pwend),
  4941. )
  4942. if gradient is None:
  4943. gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
  4944. else:
  4945. gradient = ops.where(mask, ops.add(gradient, part), gradient)
  4946. assert gradient is not None
  4947. return gradient
  4948. rv = Pointwise.create(
  4949. device=grad_output.get_device(),
  4950. dtype=dtype,
  4951. inner_fn=fn,
  4952. ranges=new_size,
  4953. )
  4954. return rv
  4955. fallback_avg_pool3d_backward = fallback_handler(
  4956. aten.avg_pool3d_backward.default, add_to_fallback_set=False
  4957. )
  4958. @register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
  4959. def avg_pool3d_backward(
  4960. grad_output,
  4961. x,
  4962. kernel_size,
  4963. stride,
  4964. padding,
  4965. ceil_mode,
  4966. count_include_pad,
  4967. divisor_override=None,
  4968. ):
  4969. assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
  4970. if not stride:
  4971. stride = kernel_size
  4972. if not padding:
  4973. padding = [0, 0, 0]
  4974. assert isinstance(grad_output, TensorBox)
  4975. assert isinstance(x, TensorBox)
  4976. assert len(kernel_size) == 3
  4977. assert len(stride) == 3
  4978. assert len(padding) == 3
  4979. assert len(x.get_size()) in (4, 5)
  4980. grad_output.realize_hint()
  4981. *_batch, depth, height, width = x.get_size()
  4982. _d_out, ceil_mode_d = pooling_size(
  4983. depth, 0, kernel_size, stride, padding, ceil_mode
  4984. )
  4985. _h_out, ceil_mode_h = pooling_size(
  4986. height, 1, kernel_size, stride, padding, ceil_mode
  4987. )
  4988. _w_out, ceil_mode_w = pooling_size(
  4989. width, 2, kernel_size, stride, padding, ceil_mode
  4990. )
  4991. grad_loader = grad_output.make_loader()
  4992. had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
  4993. *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
  4994. new_size = list(x.get_size())
  4995. dtype = x.get_dtype()
  4996. d_window_size, h_window_size, w_window_size = (
  4997. max(
  4998. max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
  4999. for d in range(kernel_size[i] * 2)
  5000. )
  5001. for i in range(3)
  5002. )
  5003. window_size = d_window_size * h_window_size * w_window_size
  5004. if window_size > 125:
  5005. # Kernel size too big. Results in hard-to-optimize Triton code.
  5006. return fallback_avg_pool3d_backward(
  5007. grad_output,
  5008. x,
  5009. kernel_size,
  5010. stride,
  5011. padding,
  5012. ceil_mode,
  5013. count_include_pad,
  5014. divisor_override,
  5015. )
  5016. def compute_pool_size_without_padding(pd, ph, pw):
  5017. stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
  5018. pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
  5019. kernel_d, kernel_h, kernel_w = (
  5020. ops.constant(k, torch.int32) for k in kernel_size
  5021. )
  5022. dstart, hstart, wstart = (
  5023. ops.sub(ops.mul(p, s), pad)
  5024. for p, s, pad in zip(
  5025. [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
  5026. )
  5027. )
  5028. dend, hend, wend = (
  5029. ops.minimum(
  5030. ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
  5031. )
  5032. for start, k, dim, pad in zip(
  5033. [dstart, hstart, wstart],
  5034. [kernel_d, kernel_h, kernel_w],
  5035. [depth, height, width],
  5036. [pad_d, pad_h, pad_w],
  5037. )
  5038. )
  5039. dstart, hstart, wstart = (
  5040. ops.maximum(start, ops.constant(0, torch.int32))
  5041. for start in [dstart, hstart, wstart]
  5042. )
  5043. dend, hend, wend = (
  5044. ops.minimum(end, ops.index_expr(dim, torch.int32))
  5045. for end, dim in zip([dend, hend, wend], [depth, height, width])
  5046. )
  5047. divide_factor = ops.mul(
  5048. ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
  5049. )
  5050. return divide_factor
  5051. def fn(idx):
  5052. *prefix, d, h, w = idx
  5053. d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
  5054. pdstart, phstart, pwstart = (
  5055. ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
  5056. for v, k, s in zip([d, h, w], kernel_size, stride)
  5057. )
  5058. pdend, phend, pwend = (
  5059. ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
  5060. for v, s in zip([d, h, w], stride)
  5061. )
  5062. pdstart, phstart, pwstart = (
  5063. ops.maximum(pstart, ops.constant(0, torch.int32))
  5064. for pstart in [pdstart, phstart, pwstart]
  5065. )
  5066. pdend, phend, pwend = (
  5067. ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
  5068. for pend, pooled_dim in zip(
  5069. [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
  5070. )
  5071. )
  5072. gradient = None
  5073. # Iterate over the 3D region to accumulate gradients
  5074. for pd_ in range(d_window_size):
  5075. for ph_ in range(h_window_size):
  5076. for pw_ in range(w_window_size):
  5077. pd, ph, pw = (
  5078. ops.add(pstart, ops.constant(p_, torch.int32))
  5079. for pstart, p_ in zip(
  5080. [pdstart, phstart, pwstart], [pd_, ph_, pw_]
  5081. )
  5082. )
  5083. if divisor_override is not None:
  5084. scale = divisor_override
  5085. elif count_include_pad or not had_padding:
  5086. scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
  5087. else:
  5088. scale = compute_pool_size_without_padding(pd, ph, pw)
  5089. part = ops.truediv(
  5090. grad_loader(
  5091. [
  5092. *prefix,
  5093. ops.indirect_indexing(
  5094. ops.minimum(
  5095. pd, ops.sub(pdend, ops.constant(1, torch.int32))
  5096. ),
  5097. pooled_depth,
  5098. check=False,
  5099. ),
  5100. ops.indirect_indexing(
  5101. ops.minimum(
  5102. ph, ops.sub(phend, ops.constant(1, torch.int32))
  5103. ),
  5104. pooled_height,
  5105. check=False,
  5106. ),
  5107. ops.indirect_indexing(
  5108. ops.minimum(
  5109. pw, ops.sub(pwend, ops.constant(1, torch.int32))
  5110. ),
  5111. pooled_width,
  5112. check=False,
  5113. ),
  5114. ]
  5115. ),
  5116. scale,
  5117. )
  5118. mask = ops.and_(
  5119. ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
  5120. ops.lt(pw, pwend),
  5121. )
  5122. if gradient is None:
  5123. gradient = ops.where(
  5124. mask, part, ops.constant(0.0, torch.float32)
  5125. )
  5126. else:
  5127. gradient = ops.where(mask, ops.add(gradient, part), gradient)
  5128. assert gradient is not None
  5129. return gradient
  5130. rv = Pointwise.create(
  5131. device=grad_output.get_device(),
  5132. dtype=dtype,
  5133. inner_fn=fn,
  5134. ranges=new_size,
  5135. )
  5136. return rv
  5137. def _validate_reduction_axis(x, axis):
  5138. size = x.get_size()
  5139. if isinstance(axis, int):
  5140. axis = [axis]
  5141. elif not axis:
  5142. axis = range(len(size))
  5143. if len(size) == 0:
  5144. assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}"
  5145. return []
  5146. axis = list(axis)
  5147. for i in range(len(axis)):
  5148. if axis[i] < 0:
  5149. axis[i] += len(size) if len(size) else 1
  5150. assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0)
  5151. assert len(OrderedSet(axis)) == len(axis), "reduction axis not unique"
  5152. return axis
def _make_reduction_inner(
    x, *, axis, keepdims, dtype, override_return_dtype, reduction_type=None
):
    """
    Build the keyword arguments for a ``Reduction.create`` call over ``axis``.

    Args:
        x: input node; cast with ``to_dtype`` first when ``dtype`` is given.
        axis: reduction axis spec, normalized by ``_validate_reduction_axis``.
        keepdims: if True, reduced dims stay in the output size as 1.
        dtype: optional dtype to cast ``x`` to before reducing.
        override_return_dtype: used as ``dst_dtype`` when truthy; otherwise the
            (possibly cast) input dtype is used.
        reduction_type: only consulted to decide whether argmax/argmin should
            also return a row-major linear index over the reduced dims.

    Returns:
        dict with ``device``, ``dst_dtype``, ``src_dtype``, ``inner_fn``,
        ``ranges`` (output size) and ``reduction_ranges`` (reduced dim sizes).
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    size = x.get_size()
    axis = OrderedSet[int](_validate_reduction_axis(x, axis))

    # Partition dims into kept (output) and reduced ones, preserving dim order.
    kept_sizes = []
    kept_idx = []
    reduced_sizes = []
    reduced_idx = []
    for i in range(len(size)):
        if i in axis:
            reduced_idx.append(i)
            reduced_sizes.append(size[i])
        else:
            kept_idx.append(i)
            kept_sizes.append(size[i])

    # For argmax/argmin compute logical indices when the tensor has non-contiguous layout.
    should_compute_logical_index = False
    if (
        reduction_type in ("argmax", "argmin")
        and len(reduced_sizes) > 1
        and is_triton(x)
    ):
        if isinstance(x.data, PermuteView):
            should_compute_logical_index = True
        elif isinstance(x.data, ir.ReinterpretView) or (
            isinstance(x.data, ir.StorageBox) and isinstance(x.data.data, ir.Buffer)
        ):
            layout = x.get_layout()
            should_compute_logical_index = (
                layout.is_transposed() or not layout.is_contiguous()
            )

    def loader(index, reduction_index):
        # Rebuild a full input index from the output index (kept dims) and the
        # reduction index (reduced dims).
        assert len(reduction_index) == len(reduced_idx)
        if keepdims:
            # With keepdims the output index covers every input dim; keep only
            # the entries for non-reduced dims.
            assert len(index) == len(size)
            index = [index[i] for i in kept_idx]
        assert len(index) == len(kept_idx)
        new_index = [None] * (len(index) + len(reduction_index))
        for idx, var in itertools.chain(
            zip(kept_idx, index), zip(reduced_idx, reduction_index)
        ):
            new_index[idx] = var
        # inner_loader is assigned below, before this closure is returned.
        value = inner_loader(new_index)

        # For argmax/argmin, return tuple with logical linear index if needed
        if should_compute_logical_index:
            rindex = [sympy.expand(i) for i in reduction_index]

            # Compute linear index in row-major order
            # For reduction_ranges = [4, 6]: linear_index = r0 * 6 + r1
            linear_idx = rindex[0]
            for i in range(1, len(rindex)):
                linear_idx = linear_idx * reduced_sizes[i] + rindex[i]

            return (value, ops.index_expr(linear_idx, torch.int64))

        return value

    # Output size: reduced dims become 1 with keepdims, otherwise they are dropped.
    if keepdims:
        new_size = list(size)
        for i in reduced_idx:
            new_size[i] = sympy.S.One
    else:
        new_size = kept_sizes

    inner_loader = x.make_loader()
    return dict(
        device=x.get_device(),
        dst_dtype=override_return_dtype or x.get_dtype(),
        src_dtype=x.get_dtype(),
        inner_fn=loader,
        ranges=new_size,
        reduction_ranges=reduced_sizes,
    )
  5224. def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
  5225. def inner(x, axis=None, keepdims=False, *, dtype=None):
  5226. kwargs = _make_reduction_inner(
  5227. x,
  5228. axis=axis,
  5229. keepdims=keepdims,
  5230. dtype=dtype,
  5231. override_return_dtype=override_return_dtype,
  5232. reduction_type=reduction_type,
  5233. )
  5234. result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
  5235. if isinstance(
  5236. result.data.data, # type: ignore[attr-defined, attr-type, union-attr]
  5237. Reduction,
  5238. ): # Only realize if reduction isn't unrolled
  5239. result.realize()
  5240. return result
  5241. return inner
  5242. def _make_scan_inner(x, *, axis, dtype):
  5243. if dtype is not None:
  5244. x = to_dtype(x, dtype)
  5245. axis = _validate_dim(x, axis)
  5246. return dict(
  5247. device=x.get_device(),
  5248. dtypes=(x.get_dtype(),),
  5249. inner_fns=(x.make_loader(),),
  5250. size=x.get_size(),
  5251. axis=axis,
  5252. )
  5253. @register_lowering(aten.mean)
  5254. def mean(x, axis=None, keepdim=False, *, dtype=None):
  5255. if dtype is not None:
  5256. x = to_dtype(x, dtype)
  5257. size = x.get_size()
  5258. axis = _validate_reduction_axis(x, axis)
  5259. # compute in higher-precision until end of mean lowering
  5260. output_dtype = x.get_dtype()
  5261. if output_dtype in (torch.float16, torch.bfloat16):
  5262. x = to_dtype(x, torch.float)
  5263. sum_result = sum_(x, axis, keepdim)
  5264. denom = sympy_product(size[i] for i in axis)
  5265. denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device())
  5266. denom = ExpandView.create(denom, list(sum_result.get_size()))
  5267. return to_dtype(div(sum_result, denom), output_dtype)
  5268. def var_mean_sum_(x, axis, correction, keepdim, return_mean):
  5269. if correction is None:
  5270. correction = 1
  5271. size = x.get_size()
  5272. axis = _validate_reduction_axis(x, axis)
  5273. x_mean = mean(x, axis, keepdim=True)
  5274. if return_mean:
  5275. x_mean.realize()
  5276. diffs = square(sub(x, x_mean))
  5277. sum_result = sum_(diffs, axis, keepdim)
  5278. denom = sympy_product(size[i] for i in axis)
  5279. if correction:
  5280. denom = sympy.Max(denom - correction, 0)
  5281. denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device())
  5282. denom = ExpandView.create(denom, list(sum_result.get_size()))
  5283. x_var = div(sum_result, denom)
  5284. if not return_mean:
  5285. return (x_var,)
  5286. x_mean = x_mean if keepdim else squeeze(x_mean, axis)
  5287. return x_var, x_mean
  5288. def use_two_step_variance(x, axis, keepdim):
  5289. # Instead of unrolling welford, just unroll the simpler two-step var
  5290. axis = _validate_reduction_axis(x, axis)
  5291. kwargs = _make_reduction_inner(
  5292. x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
  5293. )
  5294. ranges = kwargs["ranges"]
  5295. reduction_numel = sympy_product(kwargs["reduction_ranges"])
  5296. return (
  5297. isinstance(reduction_numel, sympy.Integer)
  5298. and int(reduction_numel) < config.unroll_reductions_threshold
  5299. and sympy_product(ranges) != 1
  5300. )
  5301. def var_mean_welford_(x, axis, *, correction, keepdim, return_mean):
  5302. if correction is None:
  5303. correction = 1
  5304. kwargs = _make_reduction_inner(
  5305. x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
  5306. )
  5307. loader = kwargs.pop("inner_fn")
  5308. kwargs.pop("dst_dtype")
  5309. kwargs.pop("src_dtype")
  5310. mean, m2, _ = ir.WelfordReduction.create(
  5311. inner_fns=(loader,),
  5312. reduction_type="welford_reduce",
  5313. dtype=x.get_dtype(),
  5314. **kwargs,
  5315. )
  5316. m2.realize()
  5317. dtype = x.get_dtype()
  5318. size = x.get_size()
  5319. axis = _validate_reduction_axis(x, axis)
  5320. rnumel = sympy_product(size[i] for i in axis)
  5321. def get_constant_or_index_expr(x, dtype):
  5322. if isinstance(x, sympy.Expr) and not x.is_number:
  5323. return ops.to_dtype(ops.index_expr(x, torch.int64), dtype)
  5324. return ops.constant(x, dtype)
  5325. def scale_fn(data):
  5326. c = get_constant_or_index_expr(correction, dtype)
  5327. N = get_constant_or_index_expr(rnumel, dtype)
  5328. zero = ops.constant(0, dtype)
  5329. return data / ops.maximum(zero, N - c)
  5330. var = make_pointwise(scale_fn)(m2)
  5331. if return_mean:
  5332. mean.realize()
  5333. return var, mean
  5334. return (var,)
  5335. def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
  5336. out_dtype = x.get_dtype()
  5337. compute_dtype = get_computation_dtype(out_dtype)
  5338. x = to_dtype(x, compute_dtype, copy=False)
  5339. kwargs = dict(
  5340. x=x,
  5341. axis=axis,
  5342. correction=correction,
  5343. keepdim=keepdim,
  5344. return_mean=return_mean,
  5345. )
  5346. output = (
  5347. var_mean_sum_(**kwargs)
  5348. if use_two_step_variance(x, axis=axis, keepdim=keepdim)
  5349. else var_mean_welford_(**kwargs)
  5350. )
  5351. output = tuple(to_dtype(x, out_dtype, copy=False) for x in output)
  5352. return output[0] if not return_mean else output
  5353. @register_lowering([aten.var, prims.var])
  5354. def var_(x, axis=None, *, correction=None, keepdim=False):
  5355. return var_mean_helper_(
  5356. x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
  5357. )
  5358. @register_lowering(aten.var_mean)
  5359. def var_mean(x, axis=None, *, correction=None, keepdim=False):
  5360. return var_mean_helper_(
  5361. x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
  5362. )
  5363. def pow_recursive(x, y, dtype):
  5364. if y < 0:
  5365. return pow_recursive(ops.reciprocal(x), -y, dtype)
  5366. if y == 0:
  5367. return ops.constant(1, dtype)
  5368. if y == 1:
  5369. return x
  5370. result = pow_recursive(x, y // 2, dtype)
  5371. result = ops.mul(result, result)
  5372. if (y % 2) == 1:
  5373. result = ops.mul(result, x)
  5374. return result
  5375. @make_pointwise
  5376. def pow_native(a, b):
  5377. return ops.pow(a, b)
  5378. fallback_pow_tensor_tensor = fallback_handler(
  5379. aten.pow.Tensor_Tensor, add_to_fallback_set=False
  5380. )
  5381. fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False)
  5382. fallback_pow_tensor_scalar = fallback_handler(
  5383. aten.pow.Tensor_Scalar, add_to_fallback_set=False
  5384. )
  5385. @register_lowering(aten.pow, broadcast=True)
  5386. def pow(a, b):
  5387. if isinstance(b, float) and b.is_integer():
  5388. return pow(a, int(b))
  5389. elif isinstance(b, float) and b == 0.5:
  5390. return sqrt(a)
  5391. elif isinstance(b, int) and b == 1:
  5392. return clone(a)
  5393. # Type promotion ensures all tensor arguments have the same type
  5394. dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox))
  5395. is_integer_pow = is_integer_dtype(dtype)
  5396. # Optimize away small fixed powers, or for integers avoid falling back to ATen
  5397. embed_exponent = isinstance(b, int) and (
  5398. -32 < b < 32 or (is_integer_pow and b >= 0)
  5399. )
  5400. if embed_exponent:
  5401. loader = a.make_loader()
  5402. def fn(idx):
  5403. return pow_recursive(loader(idx), b, a.get_dtype())
  5404. return Pointwise.create(
  5405. device=a.get_device(),
  5406. dtype=a.get_dtype(),
  5407. inner_fn=fn,
  5408. ranges=a.get_size(),
  5409. )
  5410. if isinstance(a, Number):
  5411. if a == 1:
  5412. return full_like(b, 1)
  5413. if a == 2 and is_float_dtype(b.get_dtype()):
  5414. return exp2(b)
  5415. if is_integer_pow:
  5416. # ops.pow doesn't work for integers
  5417. if isinstance(a, Number):
  5418. return fallback_pow_scalar(a, b)
  5419. elif isinstance(b, Number):
  5420. return fallback_pow_tensor_scalar(a, b)
  5421. else:
  5422. return fallback_pow_tensor_tensor(a, b)
  5423. return pow_native(a, b)
def mutate_to(changed, val, unsafe_alias=False):
    """
    Make ``changed`` hold the value of ``val``; used by in-place lowerings
    such as ``fill_`` and ``copy_``.

    Args:
        changed: destination; a ``TensorBox`` (its ``.data`` is used) or the
            underlying node itself.
        val: new value. A ``TensorBox`` is unwrapped. Anything that is then not
            an ``ir.StorageBox`` (e.g. a view) is first copied into a fresh
            ``Pointwise`` with ``changed``'s device, dtype and size.
        unsafe_alias: forwarded to ``ir.MutationLayoutSHOULDREMOVE.realize_into``.

    Returns:
        ``changed`` (the same object passed in).
    """
    if isinstance(changed, TensorBox):
        changed_data = changed.data
    else:
        changed_data = changed
    if isinstance(val, TensorBox):
        val = val.data

    if not isinstance(val, ir.StorageBox):
        # introduce a copy to handle views
        node = Pointwise.create(
            device=changed.get_device(),
            dtype=changed.get_dtype(),
            inner_fn=val.make_loader(),
            ranges=changed.get_size(),
        )
        assert isinstance(node, (BaseView, MutableBox))
        val = node.data
        assert isinstance(val, ir.StorageBox)

    if isinstance(changed_data, ir.StorageBox) and not (
        changed_data.is_input_buffer()
        # In AOTI, module parameters and buffers are not lifted as graph inputs
        or changed_data.is_module_buffer()
        or isinstance(changed_data.data, ir.NopKernel)
    ):
        # Fast path, just swing the data pointer
        val.realize()
        changed_data.data = val.data
        return changed

    # Otherwise (destination is not a StorageBox, or is an input buffer,
    # module buffer or NopKernel) write val into the destination.
    ir.MutationLayoutSHOULDREMOVE.realize_into(
        val, changed_data, unsafe_alias=unsafe_alias
    )
    return changed
  5456. @register_lowering(aten.fill_)
  5457. def fill_(x, fill_value):
  5458. return mutate_to(x, full_like(x, fill_value))
  5459. @register_lowering(aten.copy_, type_promotion_kind=None)
  5460. def copy_(dst, src, non_blocking=False):
  5461. if dst is src:
  5462. # dst.copy_(dst) can happen from the reinplacing pass
  5463. return dst
  5464. src = to_device(src, dst.get_device())
  5465. src = to_dtype(src, dst.get_dtype())
  5466. src = expand(src, dst.get_size())
  5467. return mutate_to(dst, src)
  5468. @make_pointwise
  5469. def floordiv(a, b):
  5470. return ops.floordiv(a, b)
  5471. @make_pointwise
  5472. def truncdiv(a, b):
  5473. return ops.truncdiv(a, b)
  5474. @register_lowering(aten.div, broadcast=True)
  5475. def div_mode(a, b, rounding_mode=None):
  5476. both_integer = is_integer_type(a) and is_integer_type(b)
  5477. both_boolean = is_boolean_type(a) and is_boolean_type(b)
  5478. # floordiv and truncdiv need special handling for integer tensors on Triton,
  5479. # see the discussion at https://github.com/triton-lang/triton/issues/605
  5480. if rounding_mode == "floor":
  5481. assert not both_boolean, "floordiv operands can not be boolean at the same time"
  5482. return floordiv(a, b) if both_integer else floor(div(a, b))
  5483. if rounding_mode == "trunc":
  5484. assert not both_boolean, "truncdiv operands can not be boolean at the same time"
  5485. return truncdiv(a, b) if both_integer else trunc(div(a, b))
  5486. return div(a, b)
  5487. @register_lowering([aten.mul], broadcast=True)
  5488. def mul(a, b):
  5489. both_bool = is_boolean_type(a) and is_boolean_type(b)
  5490. if both_bool:
  5491. return logical_and(a, b)
  5492. else:
  5493. fn = ops_wrapper(aten.mul.__name__)
  5494. return make_pointwise(fn)(a, b)
  5495. def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
  5496. """Try convert an arbitrary IR node into an ir.Constant value"""
  5497. # First try unwrapping the IRNode to see if it is already an ir.Constant
  5498. # Optional step, but avoids unnecessary inner_fn evaluation.
  5499. if isinstance(x, ir.MutableBox):
  5500. return get_constant_value(x.data)
  5501. if isinstance(x, ir.BaseView):
  5502. return get_constant_value(x.unwrap_view())
  5503. if isinstance(x, ir.Constant):
  5504. return x
  5505. # If the unwrapped node is not an ir.Constant, try evaluating inner_fn
  5506. # to see if the returned value is from an `ops.constant` call
  5507. if not isinstance(x, ir.Loops):
  5508. return None
  5509. handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
  5510. with (
  5511. V.set_ops_handler(handler),
  5512. patch.object(ir.FlexibleLayout, "allow_indexing", True),
  5513. ):
  5514. out = x.inner_fn(*x.inner_fn_args())
  5515. assert isinstance(out, torch._inductor.virtualized.OpsValue)
  5516. if isinstance(out.value, ir.Constant):
  5517. return out.value
  5518. return None
  5519. # NOTE: prims.div maps to a / b in C, so performs truncation division on
  5520. # integer inputs and true division for floating and complex inputs.
  5521. @register_lowering([prims.div], broadcast=True)
  5522. def div_prim(a, b):
  5523. is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b])
  5524. if is_integral:
  5525. return truncdiv(a, b)
  5526. # Disable CPU optimization to avoid precision issues.
  5527. # see https://github.com/pytorch/pytorch/issues/157959
  5528. if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu":
  5529. # Replace divide by constant with multiply by reciprocal
  5530. if divisor.value == 0:
  5531. reciprocal = math.copysign(float("inf"), divisor.value)
  5532. else:
  5533. reciprocal = 1.0 / divisor.value
  5534. return mul(a, reciprocal)
  5535. def fn(*args):
  5536. return ops.truediv(*args)
  5537. return make_pointwise(fn)(a, b)
  5538. @register_lowering(
  5539. [aten.true_divide, aten.div.Tensor],
  5540. broadcast=True,
  5541. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  5542. )
  5543. def div(a, b):
  5544. a, b = promote_constants(
  5545. (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  5546. )
  5547. return div_prim(a, b)
  5548. @register_lowering([aten.fmod, prims.fmod], broadcast=True)
  5549. def fmod(a, b):
  5550. is_integral = is_boolean_type(a) or is_integer_type(a)
  5551. if is_integral:
  5552. def fn(a, b):
  5553. return ops.mod(a, b)
  5554. else:
  5555. def fn(a, b):
  5556. return ops.fmod(a, b)
  5557. return make_pointwise(fn)(a, b)
  5558. @register_lowering([aten.sum, prims.sum])
  5559. def sum_(x, axis=None, keepdims=False, *, dtype=None):
  5560. if (
  5561. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5562. ) and dtype is None:
  5563. dtype = torch.int64
  5564. fn = make_reduction("sum", override_return_dtype=dtype)
  5565. return fn(x, axis, keepdims, dtype=dtype)
  5566. fallback_cumsum = fallback_handler(aten.cumsum.default)
  5567. fallback_cumprod = fallback_handler(aten.cumprod.default)
  5568. fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
  5569. fallback_cummax = fallback_handler(aten.cummax.default)
  5570. fallback_cummin = fallback_handler(aten.cummin.default)
  5571. @register_lowering(aten.cumsum)
  5572. def cumsum(x, axis=None, dtype=None):
  5573. if (
  5574. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5575. ) and dtype is None:
  5576. dtype = torch.int64
  5577. if len(x.get_size()) == 0:
  5578. assert axis in [0, -1]
  5579. dtype = dtype or x.get_dtype()
  5580. return to_dtype(x, dtype, copy=True)
  5581. def combine_fn(a_tuple, b_tuple):
  5582. (a,) = a_tuple
  5583. (b,) = b_tuple
  5584. return (ops.add(a, b),)
  5585. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5586. (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
  5587. if result is None:
  5588. return fallback_cumsum(x, dim=axis, dtype=dtype)
  5589. return result
  5590. @register_lowering(aten.cumprod)
  5591. def cumprod(x, axis=None, dtype=None):
  5592. if (
  5593. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5594. ) and dtype is None:
  5595. dtype = torch.int64
  5596. if len(x.get_size()) == 0:
  5597. assert axis in [0, -1]
  5598. dtype = dtype or x.get_dtype()
  5599. return to_dtype(x, dtype, copy=True)
  5600. def combine_fn(a_tuple, b_tuple):
  5601. (a,) = a_tuple
  5602. (b,) = b_tuple
  5603. return (ops.mul(a, b),)
  5604. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5605. (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
  5606. if result is None:
  5607. return fallback_cumprod(x, dim=axis, dtype=dtype)
  5608. return result
  5609. @register_lowering(aten.logcumsumexp)
  5610. def logcumsumexp(x, dim):
  5611. def log_add_exp_helper(a_tuple, b_tuple):
  5612. (a,) = a_tuple
  5613. (b,) = b_tuple
  5614. min_v = ops.minimum(a, b)
  5615. max_v = ops.maximum(a, b)
  5616. mask = (min_v != max_v) | (~ops.isinf(min_v))
  5617. return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),)
  5618. dtype = x.get_dtype()
  5619. if len(x.get_size()) == 0:
  5620. assert dim in [0, -1]
  5621. return clone(x)
  5622. kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
  5623. (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper)
  5624. if result is None:
  5625. return fallback_logcumsumexp(x, dim=dim)
  5626. return result
  5627. @register_lowering(aten.cummax, type_promotion_kind=None)
  5628. def cummax(x, axis=None):
  5629. if len(x.get_size()) == 0:
  5630. assert axis in [0, -1]
  5631. return clone(x), empty_like(x, dtype=torch.int64)
  5632. dtype = x.get_dtype()
  5633. combine_fn = ir.get_reduction_combine_fn(
  5634. "argmax", dtype=dtype, arg_break_ties_left=False
  5635. )
  5636. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5637. kwargs["dtypes"] = (dtype, torch.int64)
  5638. kwargs["inner_fns"] = (
  5639. x.make_loader(),
  5640. lambda idx: ops.index_expr(idx[axis], torch.int64),
  5641. )
  5642. values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type]
  5643. if values is None:
  5644. return fallback_cummax(x, dim=axis)
  5645. return values, indices
  5646. @register_lowering(aten.cummin, type_promotion_kind=None)
  5647. def cummin(x, axis=None):
  5648. if len(x.get_size()) == 0:
  5649. assert axis in [0, -1]
  5650. return clone(x), empty_like(x, dtype=torch.int64)
  5651. dtype = x.get_dtype()
  5652. combine_fn = ir.get_reduction_combine_fn(
  5653. "argmin", dtype=dtype, arg_break_ties_left=False
  5654. )
  5655. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5656. kwargs["dtypes"] = (dtype, torch.int64)
  5657. kwargs["inner_fns"] = (
  5658. x.make_loader(),
  5659. lambda idx: ops.index_expr(idx[axis], torch.int64),
  5660. )
  5661. values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type]
  5662. if values is None:
  5663. return fallback_cummin(x, dim=axis)
  5664. return values, indices
  5665. @register_lowering(aten.prod)
  5666. def prod(x, axis=None, keepdims=False, *, dtype=None):
  5667. if (
  5668. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5669. ) and dtype is None:
  5670. dtype = torch.int64
  5671. fn = make_reduction("prod", override_return_dtype=dtype)
  5672. return fn(x, axis, keepdims, dtype=dtype)
  5673. @register_lowering(aten.any)
  5674. def reduce_any(x, dim=None, keepdim=False):
  5675. x = to_dtype(x, torch.bool)
  5676. return make_reduction("any")(x, axis=dim, keepdims=keepdim)
  5677. @register_lowering(aten.max, type_promotion_kind=None)
  5678. def reduce_max(x, dim=None, keepdim=False):
  5679. if dim is not None:
  5680. return (
  5681. reduce_amax(x, axis=dim, keepdims=keepdim),
  5682. reduce_argmax(x, axis=dim, keepdims=keepdim),
  5683. )
  5684. return reduce_amax(x, axis=None, keepdims=keepdim)
  5685. @register_lowering(aten.min, type_promotion_kind=None)
  5686. def reduce_min(x, dim=None, keepdim=False):
  5687. if dim is not None:
  5688. return (
  5689. reduce_amin(x, axis=dim, keepdims=keepdim),
  5690. reduce_argmin(x, axis=dim, keepdims=keepdim),
  5691. )
  5692. return reduce_amin(x, axis=None, keepdims=keepdim)
  5693. register_lowering(prims.xor_sum)(make_reduction("xor_sum"))
  5694. reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
  5695. reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
  5696. reduce_argmax = register_lowering(aten.argmax)(
  5697. make_reduction("argmax", override_return_dtype=torch.int64)
  5698. )
  5699. reduce_argmin = register_lowering(aten.argmin)(
  5700. make_reduction("argmin", override_return_dtype=torch.int64)
  5701. )
  5702. add = register_pointwise(
  5703. aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
  5704. )
  5705. sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False)
  5706. @register_lowering(aten.sort.stable, type_promotion_kind=None)
  5707. def sort_stable(x, *, stable=None, dim=-1, descending=False):
  5708. if stable is None:
  5709. stable = False
  5710. shape = x.get_size()
  5711. device = x.get_device()
  5712. dim = canonicalize_dim(len(shape), dim)
  5713. if len(shape) == 0:
  5714. return clone(x), _full(0, device, torch.int64, shape)
  5715. dim_size = shape[dim] if len(shape) else 1
  5716. if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max):
  5717. return sort_fallback(x, stable=stable, dim=dim, descending=descending)
  5718. indices = iota(
  5719. dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False
  5720. )
  5721. view_shape = [1] * len(shape)
  5722. if len(shape):
  5723. view_shape[dim] = dim_size
  5724. indices = view(indices, view_shape)
  5725. indices = expand(indices, shape)
  5726. values, indices = ir.Sort.create(
  5727. device=device,
  5728. dtypes=(x.dtype, indices.dtype),
  5729. inner_fns=(x.make_loader(), indices.make_loader()),
  5730. size=shape,
  5731. axis=dim,
  5732. stable=stable,
  5733. descending=descending,
  5734. )
  5735. if values is None:
  5736. return sort_fallback(x, stable=stable, dim=dim, descending=descending)
  5737. assert indices is not None
  5738. return values, to_dtype(indices, torch.int64)
  5739. @register_lowering(aten.sort.default, type_promotion_kind=None)
  5740. def sort(x, dim=-1, descending=False):
  5741. return sort_stable(x, stable=False, dim=dim, descending=descending)
  5742. def register_pointwise_numeric(op, name=None, triton_fallback=None):
  5743. return register_pointwise(
  5744. op,
  5745. name=name,
  5746. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  5747. triton_fallback=triton_fallback,
  5748. )
  5749. def register_pointwise_numeric_ldf64(op: torch._ops.OpOverloadPacket):
  5750. register_op_requires_libdevice_fp64(op.__name__)
  5751. return register_pointwise(
  5752. op,
  5753. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  5754. )
  5755. rsqrt = register_pointwise_numeric(aten.rsqrt)
  5756. exp = register_pointwise_numeric_ldf64(aten.exp)
  5757. exp2 = register_pointwise_numeric(aten.exp2)
  5758. expm1 = register_pointwise_numeric(aten.expm1)
  5759. relu = register_pointwise(aten.relu)
  5760. sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
  5761. sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
  5762. square = register_pointwise(aten.square)
  5763. sub = register_pointwise(aten.sub, allow_alpha=True)
  5764. @register_lowering(aten.addcmul, broadcast=True)
  5765. def addcmul(self, tensor1, tensor2, *, value=1):
  5766. """
  5767. Computes self + value * tensor1 * tensor2 using FMA for better precision.
  5768. Matches eager CUDA kernel order: self + value * (tensor1 * tensor2)
  5769. This is computed as: fma(value, tensor1 * tensor2, self)
  5770. Note: FMA is only used for floating-point types on non-AMD GPUs. For integer types,
  5771. we fall back to regular arithmetic since FMA doesn't support integers.
  5772. For floating-point types, we use mul_rn (round-to-nearest multiplication)
  5773. to force rounding of the product before the FMA. This prevents Triton's
  5774. compiler from fusing the multiplication with the FMA, matching eager's
  5775. rounding behavior.
  5776. When emulate_precision_casts is False, we return NotImplemented to use the
  5777. decomposition instead.
  5778. """
  5779. if not config.emulate_precision_casts:
  5780. return NotImplemented
  5781. dtype = get_promoted_dtype(
  5782. self,
  5783. tensor1,
  5784. tensor2,
  5785. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  5786. )
  5787. self_loader = self.make_loader()
  5788. t1_loader = tensor1.make_loader()
  5789. t2_loader = tensor2.make_loader()
  5790. # FMA is only available for floating-point types on non-AMD GPUs
  5791. use_fma = dtype.is_floating_point and not torch.version.hip
  5792. def inner_fn(idx):
  5793. self_val = self_loader(idx)
  5794. t1_val = t1_loader(idx)
  5795. t2_val = t2_loader(idx)
  5796. if value == 1 and use_fma:
  5797. return ops.fma(t1_val, t2_val, self_val)
  5798. # Match eager order: self + value * (tensor1 * tensor2)
  5799. # Compute tensor1 * tensor2 first
  5800. if use_fma:
  5801. # Use mul_rn to force rounding of the product, preventing Triton
  5802. # from fusing t1*t2 with the subsequent FMA
  5803. t1_times_t2 = ops.mul_rn(t1_val, t2_val)
  5804. else:
  5805. t1_times_t2 = ops.mul(t1_val, t2_val)
  5806. # Use index_expr for sympy expressions (e.g., from .item()), constant otherwise
  5807. if isinstance(value, sympy.Basic):
  5808. value_expr = ops.index_expr(value, dtype)
  5809. else:
  5810. value_expr = ops.constant(value, dtype)
  5811. if use_fma:
  5812. # Use FMA for floating-point types for better precision
  5813. return ops.fma(value_expr, t1_times_t2, self_val)
  5814. else:
  5815. # Fall back to regular arithmetic for integer types
  5816. return ops.add(self_val, ops.mul(value_expr, t1_times_t2))
  5817. return Pointwise.create(
  5818. device=self.get_device(),
  5819. dtype=dtype,
  5820. inner_fn=inner_fn,
  5821. ranges=self.get_size(),
  5822. )
  5823. def _foreach_addcmul_scalar(self, tensor1, tensor2, value=1):
  5824. """
  5825. Foreach version of addcmul with scalar value parameter.
  5826. Uses foreach_group_loop for consistent grouping behavior.
  5827. When emulate_precision_casts is False, we return NotImplemented to use the
  5828. decomposition instead.
  5829. """
  5830. if not config.emulate_precision_casts:
  5831. return NotImplemented
  5832. realize_outputs = (
  5833. len(V.graph.current_node.users) == 0
  5834. or V.graph.current_node.target in inplace_foreach_ops
  5835. or cur_node_has_non_foreach_users()
  5836. )
  5837. groups = group_foreach_args(zip(self, tensor1, tensor2))
  5838. def apply_fn(args):
  5839. return addcmul(*args, value=value)
  5840. return foreach_group_loop(groups, len(self), apply_fn, realize_outputs)
  5841. _register_foreach_lowering(aten._foreach_addcmul.Scalar, _foreach_addcmul_scalar)
  5842. register_pointwise_numeric_ldf64(aten.cos)
  5843. register_pointwise_numeric_ldf64(aten.sin)
  5844. abs = register_pointwise(aten.abs)
  5845. bitwise_and = register_pointwise(aten.bitwise_and)
  5846. bitwise_left_shift = register_pointwise(aten.bitwise_left_shift)
  5847. bitwise_not = register_pointwise(
  5848. aten.bitwise_not, override_fn_when_input_bool="logical_not"
  5849. )
  5850. bitwise_or = register_pointwise(aten.bitwise_or)
  5851. bitwise_right_shift = register_pointwise(aten.bitwise_right_shift)
  5852. bitwise_xor = register_pointwise(aten.bitwise_xor)
  5853. register_pointwise_numeric(aten.lgamma)
  5854. erf = register_pointwise_numeric(aten.erf)
  5855. register_lowering(
  5856. aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  5857. )(erf)
  5858. register_pointwise_numeric(aten.log1p)
  5859. register_pointwise_numeric(aten.tan)
  5860. register_pointwise_numeric(aten.tanh)
  5861. register_pointwise_numeric_ldf64(aten.log)
  5862. logical_and = register_pointwise(
  5863. aten.logical_and,
  5864. type_promotion_kind=None,
  5865. convert_input_to_bool=True,
  5866. override_return_dtype=torch.bool,
  5867. )
  5868. logical_not = register_pointwise(
  5869. aten.logical_not,
  5870. type_promotion_kind=None,
  5871. convert_input_to_bool=True,
  5872. override_return_dtype=torch.bool,
  5873. )
  5874. logical_or = register_pointwise(
  5875. aten.logical_or,
  5876. type_promotion_kind=None,
  5877. convert_input_to_bool=True,
  5878. override_return_dtype=torch.bool,
  5879. )
  5880. logical_xor = register_pointwise(
  5881. aten.logical_xor,
  5882. type_promotion_kind=None,
  5883. convert_input_to_bool=True,
  5884. override_return_dtype=torch.bool,
  5885. )
  5886. maximum = register_pointwise(aten.maximum)
  5887. minimum = register_pointwise(aten.minimum)
  5888. register_lowering(aten.clamp_min)(maximum)
  5889. register_lowering(aten.clamp_max)(minimum)
  5890. neg = register_pointwise(aten.neg)
  5891. abs = register_pointwise(aten.abs)
  5892. reciprocal = register_pointwise_numeric(aten.reciprocal)
  5893. register_pointwise(aten.remainder)
  5894. sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity")
  5895. register_pointwise(aten.ceil)
  5896. register_pointwise(aten.signbit, override_return_dtype=torch.bool)
  5897. register_lowering(aten._neg_view)(neg)
  5898. register_pointwise(aten.le, override_return_dtype=torch.bool)
  5899. register_pointwise(aten.lt, override_return_dtype=torch.bool)
  5900. register_pointwise(aten.ge, override_return_dtype=torch.bool)
  5901. gt = register_pointwise(aten.gt, override_return_dtype=torch.bool)
  5902. register_pointwise(aten.eq, override_return_dtype=torch.bool)
  5903. register_pointwise(aten.ne, override_return_dtype=torch.bool)
  5904. register_pointwise_numeric(aten.cosh)
  5905. register_pointwise_numeric(aten.sinh)
  5906. register_pointwise_numeric(aten.acos)
  5907. register_pointwise_numeric(aten.acosh)
  5908. register_pointwise_numeric(aten.asin)
  5909. register_pointwise_numeric(aten.asinh)
  5910. register_pointwise_numeric(aten.atan2)
  5911. register_pointwise_numeric(aten.atan)
  5912. register_pointwise_numeric(aten.atanh)
  5913. register_pointwise_numeric(aten.copysign)
  5914. register_pointwise_numeric(aten.erfc)
  5915. register_pointwise_numeric(aten.erfinv)
  5916. register_pointwise_numeric(aten.hypot)
  5917. register_pointwise_numeric(aten.log10)
  5918. register_pointwise_numeric(aten.log2)
  5919. register_pointwise_numeric(aten.nextafter)
  5920. from .codegen.common import BackendFeature, pointwise_overrides_data
  5921. def _get_pointwise_overrides(ns, name):
  5922. data = pointwise_overrides_data[name]
  5923. op = getattr(ns, data.name, None)
  5924. if op is None:
  5925. return
  5926. def make_triton_fallback(op):
  5927. if data.triton is None:
  5928. return fallback_handler(op)
  5929. if isinstance(op, torch._ops.OpOverloadPacket):
  5930. for olname in op.overloads():
  5931. ol = getattr(op, olname)
  5932. yield ol, data.type_promotion_kind, make_triton_fallback(ol)
  5933. else:
  5934. yield op, data.type_promotion_kind, make_triton_fallback(op)
  5935. for name in pointwise_overrides_data:
  5936. for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
  5937. aten, name
  5938. ):
  5939. register_pointwise(
  5940. op,
  5941. name=name,
  5942. type_promotion_kind=type_promotion_kind,
  5943. triton_fallback=triton_fallback,
  5944. )
  5945. for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
  5946. prims, name
  5947. ):
  5948. register_pointwise(
  5949. op,
  5950. name=name,
  5951. type_promotion_kind=type_promotion_kind,
  5952. triton_fallback=triton_fallback,
  5953. )
  5954. foreach_add_list = register_foreach_pointwise(
  5955. aten._foreach_add.List, add, allow_alpha=True
  5956. )
  5957. foreach_add_scalar = register_foreach_pointwise(
  5958. aten._foreach_add.Scalar, add, allow_alpha=True
  5959. )
  5960. register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
  5961. foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
  5962. register_foreach_pointwise(aten._foreach_mul.Tensor, mul)
  5963. foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
  5964. register_foreach_pointwise(aten._foreach_sub.List, sub)
  5965. register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
  5966. register_foreach_pointwise(aten._foreach_neg.default, neg)
  5967. register_foreach_pointwise(aten._foreach_abs.default, abs)
  5968. register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
  5969. register_foreach_pointwise(aten._foreach_pow.List, pow)
  5970. register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
  5971. foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
  5972. register_foreach_pointwise(aten._foreach_div.Tensor, div)
  5973. foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
  5974. register_foreach_pointwise(aten._foreach_sqrt, sqrt)
  5975. register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)
  5976. register_foreach_pointwise(aten._foreach_maximum.List, maximum)
  5977. register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
  5978. register_foreach_pointwise(aten._foreach_minimum.List, minimum)
  5979. register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum)
  5980. register_foreach_pointwise(aten._foreach_clamp_min.List, maximum)
  5981. register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum)
  5982. register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
  5983. register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
  5984. register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
  5985. register_foreach_pointwise(aten._foreach_sign, sign)
  5986. foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)
  5987. # these are only encountered as outputs of the graph
  5988. # reinplacing epilogue copies improves compile time
  5989. # by removing extra buffers sent to the scheduler.
  5990. def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
  5991. inplaceable_foreach_ops[outplace_aten_op] = aten_op
  5992. inplace_foreach_ops.add(aten_op)
  5993. def fn(*args, **kwargs):
  5994. results = outplace_op(*args, **kwargs)
  5995. mut_results = []
  5996. for arg, result in zip(args[0], results):
  5997. mut_results.append(mutate_to(arg, result, unsafe_alias=True))
  5998. return mut_results
  5999. _register_foreach_lowering(aten_op, fn)
  6000. register_foreach_inplace(
  6001. aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list
  6002. )
  6003. register_foreach_inplace(
  6004. aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar
  6005. )
  6006. register_foreach_inplace(
  6007. aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list
  6008. )
  6009. register_foreach_inplace(
  6010. aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar
  6011. )
  6012. register_foreach_inplace(
  6013. aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list
  6014. )
  6015. register_foreach_inplace(
  6016. aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
  6017. )
  6018. register_foreach_inplace(
  6019. aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
  6020. )
  6021. def register_inplace(aten_op, outplace_op):
  6022. @register_lowering(aten_op, type_promotion_kind=None)
  6023. def fn(*args, **kwargs):
  6024. result = outplace_op(*args, **kwargs)
  6025. result = to_dtype(result, args[0].get_dtype())
  6026. return mutate_to(args[0], result)
  6027. return fn
  6028. register_inplace(aten.add_, add)
  6029. register_inplace(aten.bitwise_and_, bitwise_and)
  6030. register_inplace(aten.bitwise_left_shift_, bitwise_left_shift)
  6031. register_inplace(aten.bitwise_not_, bitwise_not)
  6032. register_inplace(aten.bitwise_or_, bitwise_or)
  6033. register_inplace(aten.bitwise_right_shift_, bitwise_right_shift)
  6034. register_inplace(aten.bitwise_xor_, bitwise_xor)
  6035. register_inplace(aten.mul_, mul)
  6036. register_inplace(aten.div_.Tensor, div)
  6037. register_inplace(aten.div_.Tensor_mode, div_mode)
  6038. register_inplace(aten.logical_and_, logical_and)
  6039. register_inplace(aten.logical_not_, logical_not)
  6040. register_inplace(aten.logical_or_, logical_or)
  6041. register_inplace(aten.logical_xor_, logical_xor)
  6042. register_inplace(aten.sub_, sub)
  6043. register_inplace(aten.relu_, relu)
  6044. register_inplace(aten.sigmoid_, sigmoid)
  6045. register_lowering(aten.__and__)(bitwise_and)
  6046. register_lowering(aten.__lshift__)(bitwise_left_shift)
  6047. register_lowering(aten.__or__)(bitwise_or)
  6048. register_lowering(aten.__rshift__)(bitwise_right_shift)
  6049. register_lowering(aten.__xor__)(bitwise_xor)
  6050. register_inplace(aten.__iand__, aten.__and__)
  6051. register_inplace(aten.__ilshift__, aten.__lshift__)
  6052. register_inplace(aten.__ior__, aten.__or__)
  6053. register_inplace(aten.__irshift__, aten.__rshift__)
  6054. register_inplace(aten.__ixor__, aten.__xor__)
  6055. @register_lowering(aten.sym_constrain_range)
  6056. def sym_constrain_range(a, min=None, max=None):
  6057. return None
  6058. @register_lowering(aten.sym_size.int)
  6059. def sym_size(a, dim):
  6060. val = V.graph.current_node.meta["val"]
  6061. if isinstance(val, torch.SymInt):
  6062. return val.node.expr
  6063. else:
  6064. return int(val)
  6065. @register_lowering(aten.sym_stride.int)
  6066. def sym_stride(a, dim):
  6067. val = V.graph.current_node.meta["val"]
  6068. if isinstance(val, torch.SymInt):
  6069. return val.node.expr
  6070. else:
  6071. return int(val)
  6072. @register_lowering(aten.sym_numel)
  6073. def sym_numel(a):
  6074. return a.get_numel()
  6075. for method, func in magic_methods.items():
  6076. register_lowering(method_to_operator(method))(func) # type: ignore[arg-type]
  6077. @register_lowering(torch.sym_sum)
  6078. def sym_sum(args):
  6079. return sympy.Add(*args)
  6080. @register_lowering(aten._foobar)
  6081. def foobar(self, *args, **kwargs):
  6082. raise NotImplementedError("Helpful for debugging")
  6083. @register_lowering(torch.ops._inductor_test.realize)
  6084. def _realize(x):
  6085. x.realize()
  6086. return clone(x)
  6087. @register_lowering(torch.ops.inductor.resize_storage_bytes_)
  6088. def resize_storage_bytes_(variable, new_size):
  6089. variable.realize()
  6090. ir.ResizeStorageBytes(variable, new_size)
  6091. return variable
  6092. @register_lowering(torch.ops.aten.set_.source_Tensor)
  6093. def set__source_tensor(self, source_tensor):
  6094. self.realize()
  6095. source_tensor.realize()
  6096. return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor))
  6097. if hasattr(torch.ops.fsdp, "copy_"):
  6098. @register_lowering(torch.ops.fsdp.copy_.default)
  6099. def fsdp_copy_(dst, src):
  6100. if dst is src:
  6101. # dst.copy_(dst) can happen from the reinplacing pass
  6102. return dst
  6103. src = to_device(src, dst.get_device())
  6104. src = to_dtype(src, dst.get_dtype())
  6105. src = expand(src, dst.get_size())
  6106. return mutate_to(dst, src)
  6107. @register_lowering(torch.ops.aten.resize)
  6108. def resize(x, size, *, memory_format=None):
  6109. assert isinstance(x, TensorBox)
  6110. assert isinstance(size, (list, tuple))
  6111. if memory_format is None:
  6112. memory_format = torch.contiguous_format
  6113. if memory_format == torch.preserve_format:
  6114. raise RuntimeError(f"unsupported memory format: {memory_format}")
  6115. if memory_format == torch.channels_last:
  6116. assert len(size) == 4
  6117. if memory_format == torch.channels_last_3d:
  6118. assert len(size) == 5
  6119. old_numel = x.get_numel()
  6120. dtype = x.get_dtype()
  6121. device = x.get_device_or_error()
  6122. if isinstance(x.data, ir.BaseView):
  6123. x.data = x.data.unwrap_view()
  6124. if (
  6125. torch.are_deterministic_algorithms_enabled()
  6126. and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined]
  6127. ):
  6128. if is_float_dtype(dtype):
  6129. uninitialized_val = float("nan")
  6130. elif is_integer_dtype(dtype):
  6131. uninitialized_val = torch.iinfo(dtype).max
  6132. else:
  6133. uninitialized_val = True
  6134. else:
  6135. # using zero as that is what empty does
  6136. uninitialized_val = 0.0
  6137. if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type]
  6138. return full(size, uninitialized_val, dtype=dtype, device=device)
  6139. x_flat = as_strided(
  6140. x,
  6141. [
  6142. old_numel,
  6143. ],
  6144. [
  6145. 1,
  6146. ],
  6147. )
  6148. flat_loader = x_flat.make_loader()
  6149. out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format)
  6150. out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer()
  6151. def inner_fn(idx):
  6152. flat_index = out_indexer(idx)
  6153. flat_index_expr = ops.index_expr(flat_index, torch.int64)
  6154. limit = ops.index_expr(old_numel, torch.int64)
  6155. mask = ops.lt(flat_index_expr, limit)
  6156. return ops.masked(mask, lambda: flat_loader([flat_index]), uninitialized_val)
  6157. out = Pointwise.create(
  6158. device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
  6159. )
  6160. return out
  6161. from torch._higher_order_ops.auto_functionalize import auto_functionalized
  6162. make_fallback(auto_functionalized)
  6163. @register_lowering(triton_kernel_wrapper_mutation)
  6164. def triton_kernel_wrap_(
  6165. *,
  6166. kernel_idx,
  6167. constant_args_idx,
  6168. grid,
  6169. tma_descriptor_metadata,
  6170. kwargs,
  6171. ):
  6172. from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
  6173. constant_args = kernel_side_table.get_constant_args(constant_args_idx)
  6174. ir.UserDefinedTritonKernel(
  6175. kernel_idx=kernel_idx,
  6176. grid=grid,
  6177. tma_descriptor_metadata=tma_descriptor_metadata,
  6178. kernel_args={**kwargs, **constant_args},
  6179. )
  6180. return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)}
  6181. @register_lowering(torch.ops.higher_order.cond, type_promotion_kind=None)
  6182. def cond(
  6183. pred, true_fn, false_fn, operands
  6184. ) -> list[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]]:
  6185. # TODO: when graph_partition is enabled, skip - partitioning handles control flow
  6186. # we run into memory cleanup issue
  6187. if any(isinstance(x, IRNode) and is_triton(x) for x in [pred, *operands]):
  6188. msg = "control flow operator: torch.cond."
  6189. if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
  6190. msg = f"{msg} Found from : \n {stack_trace}"
  6191. V.graph.disable_cudagraphs_reason = msg
  6192. result = ir.Conditional.create(pred, true_fn, false_fn, operands)
  6193. return list(map(TensorBox.create, result)) # pyrefly: ignore no-matching-overload
  6194. @register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None)
  6195. def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False):
  6196. # TODO: when graph_partition is enabled, skip - partitioning handles control flow
  6197. # we run into memory cleanup issue
  6198. if not config.graph_partition and any(
  6199. isinstance(x, IRNode) and is_triton(x)
  6200. for x in carried_inputs + additional_inputs
  6201. ):
  6202. msg = "control flow operator: torch.while_loop."
  6203. if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
  6204. msg = f"{msg} Found from : \n {stack_trace}"
  6205. V.graph.disable_cudagraphs_reason = msg
  6206. result = ir.WhileLoop.create(
  6207. cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
  6208. )
  6209. assert isinstance(result, Sequence)
  6210. return list(map(ir.WhileLoop._maybe_wrap_as_tensor_box, result))
  6211. register_lowering(
  6212. torch.ops.higher_order.while_loop_stack_output, type_promotion_kind=None
  6213. )(functools.partial(while_loop, stack_output=True))
  6214. @register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
  6215. def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
  6216. result = ir.InvokeSubgraph.create(subgraph_fn, *operands)
  6217. return list(map(TensorBox.create, result)) # type: ignore[call-overload]
  6218. def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]):
  6219. """Process nodes from a FX graph by executing them through V.graph.
  6220. This is a common pattern for executing a subgraph's nodes:
  6221. - Placeholder nodes are mapped to the provided args
  6222. - Output nodes return their result
  6223. - Other nodes are executed via V.graph.run_node
  6224. """
  6225. output = None
  6226. for i, node in enumerate(graph_module.graph.nodes):
  6227. if node.op == "placeholder":
  6228. assert node not in V.graph.env
  6229. V.graph.env[node] = args[i]
  6230. continue
  6231. elif node.op == "output":
  6232. output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
  6233. output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs)
  6234. else:
  6235. assert node not in V.graph.env
  6236. # Track current node for error diagnostics; restore after run_node to handle nested calls correctly
  6237. saved_current_node = V.graph.current_node
  6238. try:
  6239. V.graph.current_node = node
  6240. V.graph.env[node] = V.graph.run_node(node)
  6241. finally:
  6242. V.graph.current_node = saved_current_node
  6243. if output is None:
  6244. raise RuntimeError("No output node found in graph")
  6245. return output
  6246. # Import the control_deps_op HOP for lowering
  6247. from torch._inductor.fx_passes.control_dependencies import control_deps
  6248. @register_lowering(control_deps, type_promotion_kind=None)
  6249. def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
  6250. """
  6251. Lower control_deps_op by ensuring dependencies are realized and tracking them.
  6252. The control_deps_op HOP makes dependencies explicit in the graph. During lowering:
  6253. 1. Realize all additional dependencies to ensure they're computed
  6254. 2. Execute the target operation normally
  6255. 3. Track the dependencies for the scheduler
  6256. """
  6257. # Realize all additional dependencies
  6258. dep_names = []
  6259. for dep in additional_deps:
  6260. if not isinstance(dep, IRNode):
  6261. continue
  6262. dep.realize()
  6263. dep_names.append(dep.get_name())
  6264. original_args = V.graph.current_node.args
  6265. arg_offset = 2 # first two args (additional_deps, subgraph)
  6266. assert len(args) + arg_offset == len(original_args)
  6267. operation_len = len(V.graph.operations)
  6268. assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args)
  6269. # Process subgraph nodes using the shared helper
  6270. output = process_subgraph_nodes(subgraph_fn.graph_module, list(args))
  6271. assert output is not None and additional_deps
  6272. # some operators, like wait_tensor, just return their input,
  6273. # so its more robust to add dep to the operation itself,
  6274. # otherwise you can have a cycle of
  6275. # a = coll
  6276. # b = control_deps(a, mm, ...)
  6277. # c = control_deps(b, wait, ...)
  6278. # if c == a, then you have a cycle.
  6279. for op in V.graph.operations[operation_len:]:
  6280. for dep_name in dep_names:
  6281. op_name = op.operation_name
  6282. assert op_name is not None
  6283. V.graph.additional_buffer_deps[op_name].add(dep_name)
  6284. return output
  6285. @register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
  6286. def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
  6287. output = None
  6288. quant_options = V.graph.current_node.meta.get("quant_options", None)
  6289. assert quant_options is not None
  6290. for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
  6291. if node.op == "placeholder":
  6292. V.graph.env[node] = operands[i]
  6293. continue
  6294. # todo getattr
  6295. elif node.op == "output":
  6296. args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
  6297. for v in itertools.chain(args, kwargs.values()):
  6298. v.realize()
  6299. if quant_options.codegen_low_precision:
  6300. V.graph.low_precision_codegen_ops.add(v.get_operation_name())
  6301. V.graph.invoke_quant_ops.add(v.get_operation_name())
  6302. output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
  6303. else:
  6304. V.graph.env[node] = V.graph.run_node(node)
  6305. return output
  6306. @register_lowering(associative_scan_op, type_promotion_kind=None)
  6307. def associative_scan(
  6308. combine_fn: ir.Subgraph, xs, additional_inputs: tuple[torch.Tensor]
  6309. ):
  6310. from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph
  6311. if len(additional_inputs) > 0:
  6312. raise RuntimeError(
  6313. "Unable to generate code for associative_scan op, because there are lifted arguments"
  6314. )
  6315. subgraph_inputs = [
  6316. InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
  6317. for x in itertools.chain(xs, xs)
  6318. ]
  6319. lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated]
  6320. def wrapped_combine_fn(lhs, rhs):
  6321. return lowered_combine_fn(
  6322. *pytree.tree_leaves(lhs),
  6323. *pytree.tree_leaves(rhs),
  6324. )
  6325. kwargs = _make_scan_inner(xs[0], axis=0, dtype=None)
  6326. kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
  6327. kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)
  6328. result = ir.Scan.create(
  6329. combine_fn=wrapped_combine_fn,
  6330. can_fallback_to_aten=False,
  6331. **kwargs,
  6332. )
  6333. if result[0] is None:
  6334. raise RuntimeError("Unable to generate code for associative_scan op")
  6335. return result
  6336. @register_lowering(torch.ops.prims._sink_tokens.default)
  6337. def _sink_tokens(tokens):
  6338. return None
  6339. @register_lowering(torch.ops.prims._make_token.default)
  6340. def _make_token():
  6341. return None
@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None)
def with_effects(token, op, *args, **kwargs):
    """
    We lower the operator directly, and then we add StarDep dependencies to all
    the newly created nodes in the graph.

    Args:
        token: effect token; it is returned unchanged as the first element of
            the output tuple.
        op: the effectful operator to lower.
        *args, **kwargs: forwarded to ``lowerings[op]`` when a lowering is
            registered, otherwise to ``ir.FallbackKernel.create``.

    Returns:
        ``(token, result)`` when ``op``'s schema has zero or one return, and
        ``(token, *result)`` when it has several. If the schema lookup raises
        RuntimeError, a tuple/list ``result`` is splatted and any other result
        is returned as ``(token, result)``.

    Side effects:
        Every operation appended to ``V.graph.operations`` while lowering gets
        ``has_side_effects`` patched to return True. When an effect type is
        known, each new operation also gets a StarDep on the operation
        previously recorded in ``V.graph.effectful_ops`` for that effect type,
        and the record is advanced to the last new operation.
    """
    from torch._higher_order_ops.effects import _get_effect, _get_schema

    # Get effect type
    effect_type = _get_effect(op)
    # invoke_subgraph has no statically registered effect; fall back to the
    # effects cached for this particular subgraph identifier, if any.
    if effect_type is None and op is torch.ops.higher_order.invoke_subgraph:
        from torch._guards import InvokeSubgraphCache, TracingContext

        tracing_ctx = TracingContext.try_get()
        if tracing_ctx:
            invoke_subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache(
                torch.ops.higher_order.invoke_subgraph
            )
            if invoke_subgraph_cache:
                assert isinstance(invoke_subgraph_cache, InvokeSubgraphCache)
                # args[1] is identifier
                effects = invoke_subgraph_cache.get_effects(args[1])
                if effects:
                    assert len(effects) == 1, "Multiple effects NYI"
                    effect_type = next(iter(effects))

    # Track operations before
    # (snapshot the length so the slice [operation_len:] below is exactly the
    # set of operations created by this lowering)
    operation_len = len(V.graph.operations)

    # Lower the op
    if op in lowerings:
        result = lowerings[op](*args, **kwargs)
        # Realize so that we can get the ops to show up in V.graph.operations
        pytree.tree_map_only(TensorBox, lambda a: a.realize(), result)
    else:
        # No registered lowering: emit a FallbackKernel and wrap any IRNode
        # outputs in TensorBox so callers see the usual lowering result type.

        def wrap_tensors(x):
            return TensorBox.create(x) if isinstance(x, ir.IRNode) else x

        result = pytree.tree_map(
            wrap_tensors, ir.FallbackKernel.create(op, *args, **kwargs)
        )

    # Get all the operations created during the lowering above, and add StarDeps
    # to the previous node with the same effect
    assert len(V.graph.operations[operation_len:]) > 0, (
        f"No operation nodes were generated when lowering effectful operator {op}."
    )

    if effect_type:
        # Read once before the loop: every new op depends on the op recorded
        # *before* this lowering, not on its siblings from this lowering.
        prev_effect_buffer = V.graph.effectful_ops.get(effect_type)
        for new_op in V.graph.operations[operation_len:]:
            # Patch has_side_effects to return True
            new_op.has_side_effects = lambda: True  # pyrefly: ignore[missing-attribute]
            if prev_effect_buffer:
                op_name = new_op.get_name()  # pyrefly: ignore[missing-attribute]
                V.graph.additional_star_deps[op_name].add(prev_effect_buffer.get_name())
            # Update the effectful ops chain to point to the latest operation
            V.graph.effectful_ops[effect_type] = (
                new_op  # pyrefly: ignore[unsupported-operation]
            )

    # Resolve op's schema (to learn how many values it returns) from example
    # values of the arguments.
    try:

        def convert_ir_to_value(a):
            # Map IR arguments to concrete values usable by _get_schema:
            # TorchBind objects -> their value; TensorBox -> the example tensor
            # of the wrapped buffer when available; anything else unchanged.
            if isinstance(a, ir.TorchBindObject):
                return a.get_value()
            elif isinstance(a, TensorBox):
                # TensorBox wraps StorageBox, which wraps the actual buffer
                # We need to get the example tensor from the inner buffer
                try:
                    storage = a.data
                    if hasattr(storage, "data") and hasattr(
                        storage.data, "get_example"
                    ):
                        return storage.data.get_example()
                except (AttributeError, NotImplementedError):
                    pass
                # Fall back to returning the TensorBox itself if get_example fails
                return a
            return a

        schema_args, schema_kwargs = pytree.tree_map(
            convert_ir_to_value, (args, kwargs)
        )
        schema = _get_schema(op, schema_args, schema_kwargs)
    except RuntimeError as e:
        # Schema unavailable: infer the output arity from the result's shape.
        error_msg = str(e)
        log.warning(
            "Failed to get schema for %s: %s. Assuming list output", op, error_msg
        )
        if isinstance(result, (tuple, list)):
            return (token, *result)
        else:
            return (token, result)

    # Zero or one schema return: pass result through as a single element;
    # multiple returns: splat them after the token.
    if len(schema.returns) == 0:
        return (token, result)
    elif len(schema.returns) == 1:
        return (token, result)
    else:
        return (token, *result)
# Register the lowerings provided by the sibling comm_lowering module. These
# calls run at import time of this module, so the registrations are in place as
# soon as the lowering table is imported.
from .comm_lowering import register_comm_lowerings, register_symm_mem_lowerings

register_comm_lowerings()
register_symm_mem_lowerings()
  6435. @register_lowering(inductor_prims.prepare_softmax_online, type_promotion_kind=None)
  6436. def prepare_softmax_online(x, dim):
  6437. """
  6438. Lowering inductor_prims.prepare_softmax_online to compute max/sum in one pass if no split is needed.
  6439. """
  6440. kwargs = _make_reduction_inner(
  6441. x, axis=dim, keepdims=True, dtype=None, override_return_dtype=None
  6442. )
  6443. reduction_ranges = kwargs["reduction_ranges"]
  6444. rnumel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
  6445. hint, num_split = ir.Reduction.num_splits(
  6446. **kwargs,
  6447. reduction_type="online_softmax_reduce", # type: ignore[arg-type]
  6448. reduction_numel=rnumel,
  6449. )
  6450. if num_split == 1 and V.graph.sizevars.statically_known_geq(
  6451. rnumel, config.unroll_reductions_threshold
  6452. ):
  6453. max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
  6454. input_node=x, num_output=2, reduction_hint=hint, **kwargs
  6455. )
  6456. return max_tensor, sum_tensor
  6457. else:
  6458. # Note: [Split online_softmax_reduce]
  6459. # We don't split reduction for online_softmax_reduce for now.
  6460. # On one hand, supporting split reduction makes things complex since
  6461. # the split out reuctions requires 2 inputs rather than one.
  6462. # On the other hand, during training the online_softmax_reduce should
  6463. # usually don't requires a split due to large batch size
  6464. # (more specifically batch size times sequence length).
  6465. # We should support split reduction if we find legit use cases to
  6466. # motivate the work.
  6467. #
  6468. # TODO: does inference need split online_softmax_reduce?
  6469. warnings.warn(
  6470. textwrap.dedent(
  6471. """
  6472. Online softmax is disabled on the fly since Inductor decides to
  6473. split the reduction. Cut an issue to PyTorch if this is an
  6474. important use case and you want to speed it up with online
  6475. softmax.
  6476. """
  6477. )
  6478. )
  6479. amax = reduce_amax(x, dim, keepdims=True)
  6480. exp = lowerings[aten.exp](sub(x, amax))
  6481. xsum = sum_(exp, dim, keepdims=True)
  6482. return amax, xsum
  6483. def _is_sm100_or_later():
  6484. """Check if we're on SM100+ hardware (Blackwell)."""
  6485. return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)
  6486. @register_lowering(inductor_prims.cvt_e8m0_rceil, type_promotion_kind=None)
  6487. def cvt_e8m0_rceil_lowering(inp):
  6488. """
  6489. Lowering for cvt_e8m0_rceil. Uses PTX cvt.rp.satfinite.ue8m0x2.f32 on SM100+.
  6490. The PTX instruction takes 2 float32 and outputs 2 e8m0 packed in uint16.
  6491. Currently we pass 0.0 as the second input and only use the low byte result.
  6492. """
  6493. # TODO: Optimize to process pairs (pack=2) by creating a custom Pointwise
  6494. # that loads adjacent elements, applies PTX to both, and uses a follow-up
  6495. # kernel to extract the packed uint16 results as uint8.
  6496. if not _is_sm100_or_later():
  6497. raise NotImplementedError(
  6498. "cvt_e8m0_rceil requires SM100+ (Blackwell) for PTX instruction support"
  6499. )
  6500. dtype = inp.get_dtype()
  6501. if dtype not in (torch.float32, torch.float16, torch.bfloat16):
  6502. raise ValueError(
  6503. f"cvt_e8m0_rceil requires float32, float16, or bfloat16 input, got {dtype}"
  6504. )
  6505. # Upcast bf16/fp16 to float32 for PTX instruction
  6506. if dtype != torch.float32:
  6507. inp = to_dtype(inp, torch.float32)
  6508. fn = functools.partial(
  6509. ops.inline_asm_elementwise,
  6510. asm="cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;",
  6511. constraints="=h,r",
  6512. dtype=torch.uint16,
  6513. is_pure=True,
  6514. pack=1,
  6515. )
  6516. result = make_pointwise(fn)(inp)
  6517. return to_dtype(result, torch.uint8)
# populate lowerings defined in kernel/*
# (import_submodule is presumably importing every module under kernel/ so
# their register_lowering decorators run — confirm against its definition)
from . import kernel

import_submodule(kernel)

# Register lowerings that live in sibling modules. These calls execute at
# import time of this module, after all lowerings defined above exist.
from . import quantized_lowerings

quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()

from . import mkldnn_lowerings

mkldnn_lowerings.register_onednn_fusion_ops()

from . import jagged_lowerings

jagged_lowerings.register_jagged_ops()
  6528. @contextlib.contextmanager
  6529. def force_fallback(op: torch._ops.OpOverload):
  6530. """
  6531. A context manager to force fallback an op. Used in unit test
  6532. for FallbackKernel.
  6533. """
  6534. assert isinstance(op, torch._ops.OpOverload), (
  6535. "Only OpOverload to make the clean up easier"
  6536. )
  6537. old_handler = lowerings.get(op)
  6538. try:
  6539. register_lowering(op)(fallback_handler(op))
  6540. yield
  6541. finally:
  6542. if old_handler:
  6543. lowerings[op] = old_handler
  6544. else:
  6545. lowerings.pop(op)