| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
27762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766
77767786779678067816782678367846785678667876788678967906791679267936794679567966797679867996800680168026803680468056806680768086809681068116812681368146815681668176818681968206821682268236824682568266827682868296830683168326833683468356836683768386839684068416842684368446845684668476848684968506851685268536854685568566857685868596860686168626863686468656866686768686869687068716872687368746875687668776878687968806881688268836884688568866887688868896890689168926893689468956896689768986899690069016902690369046905690669076908690969106911691269136914691569166917691869196920692169226923692469256926692769286929693069316932693369346935693669376938693969406941694269436944694569466947694869496950695169526953695469556956695769586959696069616962696369646965696669676968696969706971697269736974697569766977697869796980698169826983698469856986698769886989699069916992699369946995699669976998699970007001700270037004700570067007700870097010701170127013701470157016701770187019702070217022702370247025702670277028702970307031703270337034703570367037703870397040704170427043704470457046704770487049705070517052705370547055705670577058705970607061706270637064706570667067706870697070707170727073707470757076707770787079708070817082708370847085708670877088708970907091709270937094709570967097709870997100710171027103710471057106710771087109711071117112711371147115711671177118711971207121712271237124712571267127712871297130713171327133713471357136713771387139714071417142714371447145714671477148714971507151715271537154715571567157715871597160716171627163716471657166716771687169717071717172717371747175717671777178717971807181718271837184718571867187718871897190719171927193719471957196719771987199720072017202720372047205720672077208720972107211721272137214721572167217721872197220722172227223722472257226722772287229723072317232723372347235723672377238723972407241724272437244724572467247724872497250725172527253725472557256725772587259726072617262726372647265726672677268726972707271727272737274727572767
27772787279728072817282728372847285728672877288728972907291729272937294729572967297729872997300730173027303730473057306730773087309731073117312731373147315731673177318731973207321732273237324732573267327732873297330733173327333733473357336733773387339734073417342734373447345734673477348734973507351735273537354735573567357735873597360736173627363736473657366736773687369737073717372737373747375737673777378737973807381738273837384738573867387738873897390739173927393739473957396739773987399740074017402740374047405740674077408740974107411741274137414741574167417741874197420742174227423742474257426742774287429743074317432743374347435743674377438743974407441744274437444744574467447744874497450745174527453745474557456745774587459746074617462746374647465746674677468746974707471747274737474747574767477747874797480748174827483748474857486748774887489749074917492749374947495749674977498749975007501750275037504750575067507750875097510751175127513751475157516751775187519752075217522752375247525752675277528752975307531753275337534753575367537753875397540754175427543754475457546754775487549755075517552755375547555755675577558755975607561756275637564756575667567756875697570757175727573757475757576757775787579758075817582758375847585758675877588758975907591759275937594759575967597759875997600760176027603760476057606760776087609761076117612761376147615761676177618761976207621762276237624762576267627762876297630763176327633763476357636763776387639764076417642764376447645764676477648764976507651765276537654765576567657765876597660766176627663766476657666766776687669767076717672767376747675767676777678767976807681768276837684768576867687768876897690769176927693769476957696769776987699770077017702770377047705770677077708770977107711771277137714771577167717771877197720772177227723772477257726772777287729773077317732773377347735773677377738773977407741774277437744774577467747774877497750775177527753775477557756775777587759776077617762776377647765776677677768776977707771777277737774777577767
77777787779778077817782778377847785778677877788778977907791779277937794779577967797779877997800780178027803780478057806780778087809781078117812781378147815781678177818781978207821782278237824782578267827782878297830783178327833783478357836783778387839784078417842784378447845784678477848784978507851785278537854785578567857785878597860786178627863786478657866786778687869787078717872787378747875787678777878787978807881788278837884788578867887788878897890789178927893789478957896789778987899790079017902790379047905790679077908790979107911791279137914791579167917791879197920792179227923792479257926792779287929793079317932793379347935 |
- # mypy: allow-untyped-defs
- from __future__ import annotations
- import contextlib
- import dataclasses
- import functools
- import itertools
- import logging
- import math
- import operator
- import os
- import textwrap
- import warnings
- from collections import defaultdict
- from collections.abc import Callable, Collection, Iterable, Sequence
- from typing import Any, cast, Optional, TYPE_CHECKING, TypeGuard, TypeVar, Union
- from typing_extensions import ParamSpec
- from unittest.mock import patch
- import sympy
- import torch
- import torch.ao.quantization.fx._decomposed
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._dynamo.utils import counters
- from torch._higher_order_ops.associative_scan import associative_scan_op
- from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
- from torch._library.fake_class_registry import FakeScriptObject
- from torch._library.utils import get_layout_constraint_tag
- from torch._prims_common import (
- canonicalize_dim,
- canonicalize_dims,
- check,
- dtype_to_type,
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- get_computation_dtype,
- is_boolean_dtype,
- is_float_dtype,
- is_integer_dtype,
- Number,
- )
- from torch.fx.experimental.sym_node import magic_methods, method_to_operator
- from torch.fx.experimental.symbolic_shapes import (
- free_unbacked_symbols,
- has_free_unbacked_symbols,
- resolve_unbacked_bindings,
- )
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._sympy.functions import (
- CeilDiv,
- FloorDiv,
- Identity,
- Mod,
- ModularIndexing,
- )
- from .._dynamo.utils import import_submodule
- from . import config, inductor_prims, ir, test_operators # NOQA: F401
- from .decomposition import decompositions, get_decompositions
- from .ir import (
- BaseView,
- DtypeView,
- ExpandView,
- IndexingConstant,
- IRNode,
- is_triton,
- MutableBox,
- OnlineSoftmaxReduction,
- ops_wrapper,
- PermuteView,
- Pointwise,
- Reduction,
- SqueezeView,
- TensorBox,
- validate_ir,
- View,
- )
- from .utils import (
- ceildiv,
- decode_device,
- is_dynamic,
- is_gpu,
- is_pointwise_use,
- is_view,
- needs_fallback_due_to_atomic_add_limitations,
- pad_listlike,
- register_op_dtype_propagation_rules,
- register_op_requires_libdevice_fp64,
- sympy_product,
- use_scatter_fallback,
- )
- from .virtualized import ops, V
- if TYPE_CHECKING:
- from .ops_handler import ReductionType
# Generic helpers used to type the decorator returned by register_lowering.
_T = TypeVar("_T")
_P = ParamSpec("_P")

# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
FALLBACK_ALLOW_LIST = OrderedSet(["torchvision::roi_align", "aten::index_add"])

log = logging.getLogger(__name__)

# Registry filled in by register_lowering: op (or string key) -> lowering function.
lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}

# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
_maybe_layout_constraints: dict[
    torch._ops.OpOverload, Optional[Callable[..., Any]]
] = {}

# Ops consulted by the lowering wrapper when deciding whether to reject `out=`.
fallbacks = OrderedSet[torch._ops.OpOverload]()

# Shorthand handles for the op namespaces used throughout this file.
aten = torch.ops.aten
tr_c10d = torch.ops.tr_c10d
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed

# Populated through add_needs_realized_inputs.
needs_realized_inputs = OrderedSet[torch._ops.OpOverload]()

# Ops treated as foreach ops; extended by _register_foreach_lowering.
foreach_ops = OrderedSet[torch._ops.OpOverload](
    [torch._higher_order_ops._foreach_map]  # type: ignore[list-item]
)
# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload
# so why is it in foreach_ops?

# Foreach ops whose outputs are always realized (see make_foreach_pointwise).
inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]()
# NOTE(review): presumably maps an out-of-place foreach op to its in-place
# variant; not populated in this part of the file -- confirm.
inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
def cur_node_has_non_foreach_users() -> bool:
    """Return True if a user of any user of the current node is not a foreach op call.

    A "foreach op call" is a ``call_function`` node whose target is in
    ``foreach_ops``.
    """
    return any(
        consumer.op != "call_function" or consumer.target not in foreach_ops
        for direct_user in V.graph.current_node.users
        for consumer in direct_user.users
    )
# group by device, whether any of the inputs are dynamic
# note arg_pairs may or may not be a pair
# foreach_map for example just passes output buffers here
def group_foreach_args(
    arg_pairs: Iterable[Any],
) -> defaultdict[tuple[Any, bool], list[tuple[int, Any]]]:
    """Bucket foreach arguments by ``(device, use_foreach)``.

    Args:
        arg_pairs: one entry per foreach element. An entry is either an
            iterable of arguments or a single (non-iterable) object.

    Returns:
        A defaultdict mapping ``(device, use_foreach)`` to a list of
        ``(index, entry)`` tuples, where ``entry`` is the original item from
        ``arg_pairs`` and ``index`` is its position.

    Raises:
        AssertionError: if an entry contains no ``TensorBox`` to take the
            device from.
    """
    out: defaultdict[tuple[Any, bool], list[tuple[int, Any]]] = defaultdict(list)
    for i, entry in enumerate(arg_pairs):
        # Decide per entry whether it must be wrapped for inspection. A single
        # sticky flag for the whole loop would wrongly try to unpack later
        # iterable entries as 1-tuples once a non-iterable entry had been seen.
        args = entry if isinstance(entry, Iterable) else (entry,)
        use_foreach = (
            not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
        )
        # The device of the first TensorBox argument decides the bucket.
        device = next(
            (t.data.get_device() for t in args if isinstance(t, TensorBox)), None
        )
        assert device is not None, "foreach op should have at least one tensor arg"
        # Always hand back the caller's original entry, wrapped or not.
        out[(device, use_foreach)].append((i, entry))
    return out
def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
    """Look up the layout constraint for ``fn``; ``None`` when there is none.

    A layout-constraint tag on the op takes precedence over an entry in
    ``_maybe_layout_constraints``.
    """
    # Only OpOverloads have layout constraints.
    if not isinstance(fn, torch._ops.OpOverload):
        return None
    layout_tag = get_layout_constraint_tag(fn, with_default=False)
    if layout_tag:
        return tag_to_layout_constraint(layout_tag)
    return _maybe_layout_constraints.get(fn)
def tag_to_layout_constraint(
    tag: torch._C.Tag,
) -> Optional[Callable[..., tuple[Any, Any]]]:
    """Translate a layout-constraint tag into its constraint function.

    Returns ``None`` for ``flexible_layout`` and raises ``AssertionError`` for
    a tag that is not one of the known layout tags.
    """
    Tag = torch._C.Tag
    if tag == Tag.needs_exact_strides:
        constraint = constrain_to_fake_tensors
    elif tag == Tag.needs_contiguous_strides:  # type: ignore[attr-defined]
        constraint = require_contiguous_strides
    elif tag == Tag.needs_fixed_stride_order:
        constraint = constrain_to_fx_strides
    elif tag == Tag.flexible_layout:
        constraint = None
    else:
        raise AssertionError(f"Unknown layout constraint tag: {tag}")
    return constraint
def assert_nyi(cond: bool, msg: str) -> None:
    """Raise ``NotImplementedError`` naming ``msg`` unless ``cond`` is truthy."""
    if cond:
        return
    raise NotImplementedError(f"inductor does not support {msg}")
def add_needs_realized_inputs(
    fn: Union[
        Collection[Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket]],
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
    ],
) -> Optional[list[Any]]:
    """Record ``fn`` in ``needs_realized_inputs``.

    A list/set/tuple/OrderedSet is processed element by element and the list
    of per-element results is returned. An overload packet registers each of
    its overloads; a single overload registers itself. Those cases (and any
    other input) return ``None``.
    """
    if isinstance(fn, (list, set, tuple, OrderedSet)):  # noqa: set_linter
        results = []
        for member in fn:
            # pyrefly: ignore [bad-argument-type]
            results.append(add_needs_realized_inputs(member))
        return results
    if isinstance(fn, torch._ops.OpOverloadPacket):
        for overload_name in fn.overloads():
            needs_realized_inputs.add(getattr(fn, overload_name))
    elif isinstance(fn, torch._ops.OpOverload):
        needs_realized_inputs.add(fn)
    return None
def add_layout_constraint(
    fn: Union[torch._ops.OpOverloadPacket, torch._ops.OpOverload],
    constraint: Callable[..., tuple[Any, Any]],
) -> None:
    """Store ``constraint`` in ``_maybe_layout_constraints`` for ``fn``.

    For an overload packet the constraint is stored for every overload.
    """
    if isinstance(fn, torch._ops.OpOverloadPacket):
        targets = [getattr(fn, overload_name) for overload_name in fn.overloads()]
    else:
        targets = [fn]
    for target in targets:
        _maybe_layout_constraints[target] = constraint
# Register these ops (every overload of each packet) in needs_realized_inputs.
add_needs_realized_inputs(
    (
        aten.as_strided,
        aten.as_strided_copy,
        aten.avg_pool2d,
        aten.avg_pool2d_backward,
        aten.bmm,
        aten.convolution,
        aten.convolution_backward,
        aten.max_pool2d_with_indices,
        aten.max_pool3d_with_indices,
        aten.max_pool2d_with_indices_backward,
        aten.mm,
        aten.upsample_nearest2d,
        aten._upsample_nearest_exact2d,
        aten._int_mm,
    )
)
# TODO(jansel): ezyang says we won't need this in the future, try removing it
# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
# Maps the integer id of a c10::ScalarType to the matching torch.dtype.
DTYPE_ID_LOOKUP = {
    0: torch.uint8,
    1: torch.int8,
    2: torch.int16,
    3: torch.int32,
    4: torch.int64,
    5: torch.float16,
    6: torch.float32,
    7: torch.float64,
    8: torch.complex32,  # ComplexHalf
    9: torch.complex64,  # ComplexFloat
    # ComplexDouble; was mistakenly complex32, duplicating id 8.
    10: torch.complex128,
    11: torch.bool,
    15: torch.bfloat16,
    # TODO(jansel): add quantized types?
    #  _(c10::qint8, QInt8) /* 12 */
    # _(c10::quint8, QUInt8) /* 13 */
    # _(c10::qint32, QInt32) /* 14 */
    # _(c10::quint4x2, QUInt4x2) /* 16 */
    # _(c10::quint2x4, QUInt2x4) /* 17 */
}
def decode_dtype(dtype: Union[int, torch.dtype]) -> torch.dtype:
    """Turn an integer dtype id into a ``torch.dtype`` via ``DTYPE_ID_LOOKUP``.

    Non-integer inputs are returned unchanged; an unknown id trips an assert.
    """
    if isinstance(dtype, int):
        assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
        return DTYPE_ID_LOOKUP[dtype]
    return dtype
def is_integer_type(x: Any) -> TypeGuard[Union[TensorBox, sympy.Expr, int]]:
    """Tell whether ``x`` is integer-valued.

    A TensorBox qualifies when its dtype is an integer or boolean dtype, a
    sympy expression when ``is_integer`` is exactly ``True``, and anything else
    when it is a Python ``int``.
    """
    if isinstance(x, TensorBox):
        dtype = x.get_dtype()
        return is_integer_dtype(dtype) or is_boolean_dtype(dtype)
    if isinstance(x, sympy.Expr):
        return x.is_integer is True  # type: ignore[attr-defined]
    return isinstance(x, int)
def is_boolean_type(x: Any) -> TypeGuard[Union[TensorBox, bool]]:
    """Tell whether ``x`` is a bool-dtype TensorBox or a Python ``bool``."""
    return (
        is_boolean_dtype(x.get_dtype())
        if isinstance(x, TensorBox)
        else isinstance(x, bool)
    )
def get_promoted_dtype(
    *args: Any,
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
    return_compute_dtype: bool = False,
) -> torch.dtype:
    """Compute the elementwise-promoted dtype of ``args``.

    Numbers and sympy objects are passed to ``elementwise_dtypes`` directly;
    everything else is replaced by a placeholder tensor with the same number
    of dimensions and dtype. Returns the compute dtype when
    ``return_compute_dtype`` is set, otherwise the result dtype.
    """

    def construct_input(inp: Any) -> Any:
        if isinstance(inp, (Number, sympy.Basic)):
            return inp
        # construct a tmp tensor to feed into torch.result_type
        ndim = len(inp.get_size())
        return torch.zeros([1] * ndim, dtype=inp.get_dtype())

    compute_dtype, result_dtype = elementwise_dtypes(
        *map(construct_input, args), type_promotion_kind=type_promotion_kind
    )
    if return_compute_dtype:
        return compute_dtype
    return result_dtype
def get_overloads(aten_fn):
    """Return ``aten_fn`` as a new list, extended with unregistered overloads.

    For every overload packet in the input, each of its overloads that is not
    already a key of ``lowerings`` is appended after the original entries.
    """
    requested = list(aten_fn) if isinstance(aten_fn, (list, tuple)) else [aten_fn]
    discovered = []
    for fn in requested:
        if not isinstance(fn, torch._ops.OpOverloadPacket):
            continue
        for overload_name in fn.overloads():
            candidate = getattr(fn, overload_name)
            if candidate not in lowerings:
                discovered.append(candidate)
    return requested + discovered
def in_namespace(
    op: Union[Any, torch._ops.OpOverloadPacket, torch._ops.OpOverload], namespace: str
) -> bool:
    """Check whether ``namespace`` occurs in the qualified name of ``op``.

    This is a substring test on the packet's ``_qualified_op_name`` or the
    overload's ``name()``. Objects of any other type yield ``False``.
    """
    if isinstance(op, torch._ops.OpOverloadPacket):
        qualified_name = op._qualified_op_name
    elif isinstance(op, torch._ops.OpOverload):
        qualified_name = op.name()
    else:
        return False
    return namespace in qualified_name
def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox:
    """
    Copy a cpu scalar if its device does not match the given `device`.

    Only inputs backed by an ``ir.ReinterpretView`` whose sizes have no free
    unbacked symbols are considered; everything else is returned unchanged.
    """
    if not isinstance(x.data, ir.ReinterpretView):
        return x
    if has_free_unbacked_symbols(x.get_size()):
        return x

    size = [V.graph.sizevars.size_hint_or_throw(s) for s in x.get_size()]
    cur_device = x.get_device()
    # "Scalar" here means zero-dimensional, or one-dimensional of length one.
    is_scalar_like = len(size) == 0 or (len(size) == 1 and size[0] == 1)
    if (
        cur_device is not None
        and cur_device.type == "cpu"
        and cur_device != device
        and is_scalar_like
    ):
        copied = ir.DeviceCopy.create(x, cur_device, False)
        return TensorBox(ir.StorageBox(copied))
    return x
def transform_args(
    args: list[Any],
    kwargs: dict[str, Any],
    broadcast: bool,
    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
    convert_input_to_bool: bool,
) -> tuple[list[Any], dict[str, Any]]:
    """
    Apply type promotion and broadcasting to the tensor arguments of a lowering.

    Returns the (possibly new) ``args`` and ``kwargs``. When neither contains a
    TensorBox they are returned untouched. Note that the input containers may
    be updated in place along the way.
    """
    tensor_positions = [i for i, a in enumerate(args) if isinstance(a, TensorBox)]
    tensor_keys = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
    # check that there's something to transform
    if not (tensor_positions or tensor_keys):
        return args, kwargs

    if type_promotion_kind or convert_input_to_bool:
        if convert_input_to_bool:
            dtype = torch.bool
        else:
            # FIXME this is a crude approximation for promoting args
            candidates = [
                a
                for a in args
                if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype")
            ]
            # only consider tensor kwargs for promotion, for now
            candidates += [v for v in kwargs.values() if hasattr(v, "dtype")]
            dtype = get_promoted_dtype(
                *candidates,
                type_promotion_kind=type_promotion_kind,  # type: ignore[arg-type]
            )

        # The first tensor argument determines the reference device.
        if tensor_positions:
            first_tensor = args[tensor_positions[0]]
        else:
            first_tensor = kwargs[tensor_keys[0]]
        device = first_tensor.get_device()

        for i in tensor_positions:
            args[i] = maybe_copy_cpu_scalar(args[i], device)
        for k in tensor_keys:
            kwargs[k] = maybe_copy_cpu_scalar(kwargs[k], device)

        # sometimes args are an immutable list so we can't mutate them
        def promote(arg: Any) -> Any:
            if isinstance(arg, TensorBox):
                return to_dtype(arg, dtype)
            if isinstance(arg, ir.Constant):
                return ir.Constant(value=arg.value, dtype=dtype, device=device)
            return arg

        args = list(map(promote, args))
        kwargs = {k: promote(v) for k, v in kwargs.items()}

    if broadcast:
        num_positional = len(tensor_positions)
        broadcasted = broadcast_tensors(
            *[args[i] for i in tensor_positions],
            *[kwargs[k] for k in tensor_keys],
        )
        size = list(broadcasted[0].get_size())
        # Write the broadcast tensors back: positional ones first, then keywords.
        for i, tensor in zip(tensor_positions, broadcasted[:num_positional]):
            args[i] = tensor
        for k, tensor in zip(tensor_keys, broadcasted[num_positional:]):
            kwargs[k] = tensor
        # Constants are expanded to the broadcast size as well.
        for i, arg in enumerate(args):
            if isinstance(arg, ir.Constant):
                args[i] = ExpandView.create(arg, size)
        for k, value in kwargs.items():
            if isinstance(value, ir.Constant):
                kwargs[k] = ExpandView.create(value, size)

    return args, kwargs
def _register_foreach_lowering(
    aten_fn: torch._ops.OpOverload, decomp_fn: Callable[..., Any]
) -> Callable[..., Any]:
    """
    Add a foreach lowering to lowerings dict.

    Arguments:
        aten_fn: torch.ops.aten.* fn we are lowering
        decomp_fn: alternate implementation on our IR

    Returns the wrapper that was registered; it calls ``decomp_fn`` and
    validates the resulting IR. The op and its overloads are also added to
    ``foreach_ops``.
    """

    @functools.wraps(decomp_fn)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        result = decomp_fn(*args, **kwargs)
        validate_ir(result)
        return result

    overloads = get_overloads(aten_fn)
    foreach_ops.update(overloads)
    for op in overloads:
        lowerings[op] = wrapped
    return wrapped
def _register_lowering(
    aten_fn,
    decomp_fn: Callable[..., Any],
    broadcast: bool,
    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
    convert_input_to_bool: bool,
    lowering_dict: dict[Union[Callable[..., Any], str], Callable[..., Any]],
):
    """
    Add a lowering to lowerings dict

    Arguments:
        aten_fn: torch.ops.aten.* fn we are lowering
        decomp_fn: alternate implementation on our IR
        broadcast: True to apply broadcasting to tensor inputs
        type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
        convert_input_to_bool: some logical ops require inputs are converted to bool
        lowering_dict: the dict the wrapper is stored in, keyed by each overload

    Returns the wrapper that was registered.
    """

    @functools.wraps(decomp_fn)
    def wrapped(*args, **kwargs):
        call_args: list[Any] = list(args)
        call_kwargs: dict[str, Any] = dict(kwargs)
        # TODO maybe we need to use pytrees here
        # A single list/tuple argument is transformed element-wise and re-packed.
        unpacked = len(call_args) == 1 and isinstance(call_args[0], (list, tuple))
        if unpacked:
            call_args = list(call_args[0])

        has_non_exempt_op = any(
            fn not in fallbacks and not in_namespace(fn, "_c10d_functional")
            for fn in overloads
        )
        if has_non_exempt_op:
            # explicitly assert for "out=" ops for better error messages
            assert "out" not in call_kwargs, "out= ops aren't yet supported"

        call_args, call_kwargs = transform_args(
            call_args,
            call_kwargs,
            broadcast,
            type_promotion_kind,
            convert_input_to_bool,
        )
        if unpacked:
            call_args = [call_args]

        out = decomp_fn(*call_args, **call_kwargs)
        validate_ir(out)
        return out

    overloads = get_overloads(aten_fn)
    for op in overloads:
        lowering_dict[op] = wrapped
    return wrapped
def register_lowering(
    aten_fn,
    broadcast=False,
    type_promotion_kind: Optional[
        ELEMENTWISE_TYPE_PROMOTION_KIND
    ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
    lowering_dict=lowerings,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Shim to support decorator syntax.

    Binds everything except the decomposition function, so the result can be
    applied to that function as a decorator.
    """
    options = dict(
        broadcast=broadcast,
        type_promotion_kind=type_promotion_kind,
        convert_input_to_bool=convert_input_to_bool,
        lowering_dict=lowering_dict,
    )
    return functools.partial(_register_lowering, aten_fn, **options)
def broadcast_symbolic_shapes(a, b):
    """
    Broadcasting logic based on symbolic shapes.

    We give the shapes 0 and 1 concrete values, while all other shapes
    are symbolic sympy formulas. Returns the broadcast shape as a tuple.
    """
    b = tuple(b)
    if not a or a == b:
        return b

    sizevars = V.graph.sizevars
    result_reversed = []
    # Walk both shapes from the trailing dimension, padding the shorter with 1.
    for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
        if sizevars.is_size_one_or_false(y):
            result_reversed.append(x)
            continue
        if sizevars.is_size_one_or_false(x):
            result_reversed.append(y)
            continue
        sizevars.check_equals(x, y)
        y_is_shorter = len(sympy.expand(y).free_symbols) < len(
            sympy.expand(x).free_symbols
        )
        # prefer shorter formula
        result_reversed.append(y if y_is_shorter else x)
    result_reversed.reverse()
    return tuple(result_reversed)
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
    """Turn Python/sympy scalars in ``inputs`` into IR constants.

    - No scalars present: ``inputs`` is returned as is.
    - Only scalars present: each becomes an ``ir.Constant`` /
      ``ir.IndexingConstant`` with ``override_return_dtype`` or the promoted
      dtype.
    - Mixed: scalars become constants expanded to the size of the first
      TensorBox/ExpandView/ir.Constant input, using its dtype and device.
    """
    assert override_return_dtype is None or type_promotion_kind is None, (
        "only one of override_return_dtype or type_promotion_kind may be given"
    )
    if override_return_dtype is None and type_promotion_kind is None:
        type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT

    if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs):
        return inputs

    if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs):
        dtype = override_return_dtype or get_promoted_dtype(
            *inputs,
            # pyrefly: ignore [bad-argument-type]
            type_promotion_kind=type_promotion_kind,
        )

        def const_func(x):
            if isinstance(x, sympy.Basic):
                return ir.IndexingConstant(
                    index=x, dtype=dtype, device=decode_device(None)
                )
            return ir.Constant(value=x, dtype=dtype, device=decode_device(None))

        return [const_func(x) for x in inputs]

    # Reference input that supplies dtype, device and size for the constants.
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant)))

    def lift(x):
        if isinstance(x, (int, float)):
            const = ir.Constant(
                value=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
            )
        elif isinstance(x, sympy.Basic):
            const = IndexingConstant(
                index=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
            )
        else:
            return x
        return ExpandView.create(const, list(ex.get_size()))

    return [lift(x) for x in inputs]
def make_pointwise(
    fn,
    override_return_dtype=None,
    override_device=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    triton_fallback=None,
):
    """Build a lowering that applies ``fn`` elementwise as a ``Pointwise`` node.

    Arguments:
        fn: computes the output value from the loaded input values
        override_return_dtype: output dtype; defaults to the first input's dtype
        override_device: output device; otherwise the first GPU input's device,
            falling back to the first input's device
        override_fn_when_input_bool: used instead of ``fn`` when the output
            dtype is ``torch.bool``
        allow_alpha: accept an ``alpha`` keyword that scales the last input
        triton_fallback: called instead when any input is a triton IR node
    """

    def inner(*inputs: TensorBox, alpha=None):
        if triton_fallback is not None and any(
            isinstance(inp, IRNode) and is_triton(inp) for inp in inputs
        ):
            assert not allow_alpha  # not implemented
            return triton_fallback(*inputs)

        inputs = promote_constants(inputs, override_return_dtype)
        if not allow_alpha:
            assert alpha is None
        elif alpha is not None and alpha != 1:
            # pyrefly: ignore [bad-assignment]
            inputs = list(inputs)
            # pyrefly: ignore [unsupported-operation]
            inputs[-1] = mul(inputs[-1], alpha)

        loaders = [x.make_loader() for x in inputs]
        first = inputs[0]
        ranges = first.get_size()
        dtype = override_return_dtype or first.get_dtype()

        for other in inputs[1:]:
            assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
                other.get_size()
            ), f"ndim mismatch {fn} {ranges} {other.get_size()}"

        # in tracing, we will annotate pointwise nodes that correspond to the output of
        # a pointwise node that would have been run in eager. intermediary pointwise nodes
        # during decompositions are not annotated.
        low_pr_fp = (torch.bfloat16, torch.float16)
        emulate_precision_casts = (
            V.graph is not None
            and getattr(V.graph, "current_node", None) is not None
            and V.graph.current_node.meta is not None
            and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False)
        )
        emulate_output_cast = emulate_precision_casts and dtype in low_pr_fp

        def inner_fn(index):
            assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
            if dtype == torch.bool and override_fn_when_input_bool is not None:
                return override_fn_when_input_bool(*[load(index) for load in loaders])

            inputs_loaded = []
            for inp_index, load in enumerate(loaders):
                value = load(index)
                inp_dtype = inputs[inp_index].get_dtype()
                if emulate_precision_casts and inp_dtype in low_pr_fp:
                    narrowed = ops.to_dtype(value, inp_dtype, use_compute_types=False)
                    value = ops.to_dtype(narrowed, inp_dtype)
                inputs_loaded.append(value)

            out = fn(*inputs_loaded)
            if not emulate_output_cast:
                return out
            # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here,
            # then upcasting again, to emulate casts that eager would do.
            downcast = ops.to_dtype(out, dtype, use_compute_types=False)
            return ops.to_dtype(downcast, dtype)

        if override_device:
            device = override_device
        else:
            # Prefer the device of the first GPU input, else the first input's.
            device = next(
                (i.get_device() for i in inputs if is_gpu(i.get_device().type)),
                None,
            )
            if not device:
                device = inputs[0].get_device()

        return Pointwise.create(
            device=device,  # type: ignore[arg-type]
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=ranges,
        )

    return inner
def make_foreach_pointwise(pw_fn, allow_alpha=False):
    """Lift the pointwise lowering ``pw_fn`` to operate on lists of tensors.

    The returned function needs at least one list/tuple input. Non-list inputs
    are repeated to the length of the first list input, the per-element
    arguments are grouped with ``group_foreach_args`` and ``pw_fn`` is applied
    to each element (forwarding ``alpha`` when ``allow_alpha`` is set).
    """

    def inner(*inputs: list[list[TensorBox]], alpha=1):
        cur_node = V.graph.current_node
        realize_outputs = (
            len(cur_node.users) == 0
            or V.graph.current_node.target in inplace_foreach_ops
            or cur_node_has_non_foreach_users()
        )

        a_list_input = next(
            (inp for inp in inputs if isinstance(inp, (list, tuple))), None
        )
        assert a_list_input is not None, (
            "at least one input must be a list to a foreach op"
        )

        # broadcast scalar inputs to match length of list inputs
        num_elements = len(a_list_input)
        broadcast_inputs = [
            inp if isinstance(inp, (list, tuple)) else [inp] * num_elements
            for inp in inputs
        ]
        groups = group_foreach_args(zip(*broadcast_inputs))

        def apply_fn(args):
            if not allow_alpha:
                return pw_fn(*args)
            return pw_fn(*args, alpha=alpha)

        return foreach_group_loop(groups, len(a_list_input), apply_fn, realize_outputs)

    return inner
def foreach_group_loop(groups, num_outputs, apply_fn, realize_outputs):
    """
    Common loop over grouped foreach arguments.

    Args:
        groups: Result of group_foreach_args - dict mapping (device, use_foreach) to groups
        num_outputs: Number of outputs to produce
        apply_fn: Function to apply to each set of args, returns the output
        realize_outputs: Whether to realize outputs for foreach fusion

    Returns:
        The list of outputs, placed at the index recorded for each group entry.
    """
    outputs = [None] * num_outputs
    for (device, use_foreach), group in groups.items():
        realized_names: list[str] = []
        for output_ind, args in group:
            result = apply_fn(args)
            outputs[output_ind] = result
            if not (
                V.graph.has_feature(device, BackendFeature.FOREACH)
                and use_foreach
                and realize_outputs
            ):
                continue
            result.realize()
            realized_names.append(result.get_operation_name())
        # Only groups that realized something are registered with the graph.
        if realized_names:
            V.graph.register_operation_list(realized_names)
    assert all(x is not None for x in outputs)
    return outputs
def to_dtype(x: TensorBox, dtype: torch.dtype, copy: bool = False):
    """Convert ``x`` to ``dtype`` with a pointwise cast.

    When ``x`` already has that dtype it is returned as is, or cloned if
    ``copy`` is set.
    """
    src_dtype = x.get_dtype()
    if src_dtype != dtype:

        def _to_dtype(x):
            return ops.to_dtype(x, dtype, src_dtype=src_dtype)

        return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
    if copy:
        return clone(x)
    return x
@register_lowering(torch._higher_order_ops._foreach_map, type_promotion_kind=None)
def _foreach_map(subgraph, *args, **kwargs):
    """
    This lowers an invocation of foreach_map

    The way this works is that an arbitrary N-arg func is provided by the user, looped over by the
    polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args)
    and then traced into a subgraph by dynamo.
    This code allows us to inline the subgraph into the main graph lowering using the
    PointwiseSubgraphLowering. The graph outputs represent the vertically fused sequence of ops,
    and the operation lists registered through foreach_group_loop mark the buffers as
    horizontally fuseable in the scheduler.
    """
    from .subgraph_lowering import PointwiseSubgraphLowering

    pw_subgraph = PointwiseSubgraphLowering(
        subgraph.graph_module, root_graph_lowering=V.graph
    )
    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
        pw_subgraph.run(*args)

    sub_outputs = pw_subgraph.graph_outputs
    # group outputs by device and register as foreach
    assert sub_outputs  # mypy lol
    # Each group entry is already the output itself, so the per-entry function
    # is the identity and outputs are always eligible for realization.
    return foreach_group_loop(
        group_foreach_args(sub_outputs),
        len(sub_outputs),
        lambda output: output,
        True,
    )
@register_lowering(prims.convert_element_type, type_promotion_kind=None)
def _convert_element_type(x: TensorBox, dtype: torch.dtype):
    """Lower ``prims.convert_element_type``.

    Non-complex conversions are a copying ``to_dtype``. When either dtype is
    complex, the result is produced through an in-place copy fallback (or the
    op's fallback handler when ``x.get_size()`` is empty).
    """
    involves_complex = dtype.is_complex or x.get_dtype().is_complex
    if not involves_complex:
        return to_dtype(x, dtype, copy=True)

    if not x.get_size():
        return fallback_handler(
            prims.convert_element_type.default, add_to_fallback_set=False
        )(x, dtype)

    # Decompose since an aten fallback is more friendly for c++ codegen.
    # This decomposition doesn't work for empty tensor, which needs more investigation.
    dst = empty_like(x, dtype=dtype)
    ir.InplaceCopyFallback.create(dst, x)
    return dst
def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False):
    """Reinterpret the bits of ``x`` as ``dtype``.

    Same dtype: returns ``x`` (cloned when ``copy`` is set). Equal bit widths:
    returns a ``DtypeView``. Different bit widths: uses the ``aten.view.dtype``
    fallback.
    """
    x_dtype = x.get_dtype()
    if x_dtype == dtype:
        return clone(x) if copy else x

    def _get_primitive_bitwidth(dtype):
        info = torch.finfo if dtype.is_floating_point else torch.iinfo
        return info(dtype).bits

    src_bits = _get_primitive_bitwidth(x_dtype)
    dst_bits = _get_primitive_bitwidth(dtype)
    if src_bits == dst_bits:
        return TensorBox(DtypeView.create(x, dtype))
    # fallback to aten eager implementation for differing bitwidths
    return fallback_handler(aten.view.dtype)(x, dtype)
@register_lowering(aten.view.dtype, type_promotion_kind=None)
def _view_dtype(x: TensorBox, dtype: torch.dtype):
    """Lower ``aten.view.dtype``: a ``ComplexView`` if complex is involved, else a bitcast."""
    involves_complex = dtype.is_complex or x.get_dtype().is_complex
    if not involves_complex:
        return to_dtype_bitcast(x, dtype)
    complex_view = ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
    return TensorBox.create(complex_view)
def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False):
    """Move ``x`` to ``device`` through a ``DeviceCopy``.

    If ``x`` is already on the decoded device it is returned as is, or cloned
    when ``copy`` is set.
    """
    device = decode_device(device)
    if x.get_device() != device:
        return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking))
    if copy:
        return clone(x)
    return x
@register_lowering(prims.device_put, type_promotion_kind=None)
def _device_put(x: TensorBox, device: torch.device, non_blocking=False):
    """Lower ``prims.device_put`` as a copying ``to_device``."""
    return to_device(
        x,
        device,
        copy=True,
        non_blocking=non_blocking,
    )
def register_pointwise(
    aten_fn,
    name=None,
    broadcast=True,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
    override_return_dtype=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    triton_fallback=None,
):
    """A pointwise function that maps ops.{name} to inputs

    ``name`` defaults to ``aten_fn.__name__``. The lowering is registered for
    ``aten_fn`` and, if ``prims`` has an attribute called ``name``, for that op
    too (without type promotion). Returns the registered lowering.
    """
    name = name or aten_fn.__name__
    op_fn = ops_wrapper(name)
    register_op_dtype_propagation_rules(
        name, type_promotion_kind, override_return_dtype
    )

    # The bool override is given as an op name and resolved the same way.
    if override_fn_when_input_bool is not None:
        override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)

    pointwise_fn = make_pointwise(
        op_fn,
        override_return_dtype=override_return_dtype,
        override_fn_when_input_bool=override_fn_when_input_bool,
        allow_alpha=allow_alpha,
        triton_fallback=triton_fallback,
    )
    register_for_aten = register_lowering(
        aten_fn,
        broadcast=broadcast,
        type_promotion_kind=type_promotion_kind,
        convert_input_to_bool=convert_input_to_bool,
    )
    fn = register_for_aten(pointwise_fn)

    if hasattr(prims, name):
        register_for_prims = register_lowering(
            getattr(prims, name),
            type_promotion_kind=None,
            convert_input_to_bool=convert_input_to_bool,
        )
        register_for_prims(fn)
    return fn
# ldexp has a hand-written lowering below, so its dtype rule is registered directly.
register_op_dtype_propagation_rules(
    "ldexp",
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    override_return_dtype=None,
)


@register_lowering(aten.ldexp, broadcast=True, type_promotion_kind=None)
def ldexp_lowering(x: TensorBox, n: TensorBox):
    """Lower ``aten.ldexp``.

    Uses ``ops.ldexp`` when ``x`` is floating point and ``n`` is a non-bool
    integer; otherwise computes ``x * pow(2, n)``.
    """
    ldexp_fn = ops_wrapper("ldexp")

    x_dtype = x.get_dtype()
    n_dtype = n.get_dtype()
    x_is_float = x_dtype.is_floating_point
    n_is_int = not n_dtype.is_floating_point and n_dtype != torch.bool

    if not (x_is_float and n_is_int):
        # Fall back to decomposition: x * pow(2, n)
        out_dtype = torch.float32 if is_integer_type(x) else x_dtype

        def compute_fallback(x, n):
            exponent = ops.to_dtype(n, out_dtype)
            base = ops.constant(2.0, out_dtype)
            return ops.mul(x, ops.pow(base, exponent))

        return make_pointwise(
            compute_fallback,
            override_return_dtype=out_dtype,
        )(x, n)

    # Use native ldexp
    def compute_ldexp(x, n):
        return ldexp_fn(x, n)

    return make_pointwise(compute_ldexp)(x, n)
def register_frexp():
    """A pointwise function that maps ops.frexp to inputs

    The lowering returns two pointwise results built from element 0 and
    element 1 of ``ops.frexp``; the second is produced as ``torch.int32``.
    """
    name = "frexp"
    frexp = ops_wrapper("frexp")

    def frexp0(*args, **kwargs):
        return frexp(*args, **kwargs)[0]  # type: ignore[index]

    def frexp1(*args, **kwargs):
        return frexp(*args, **kwargs)[1]  # type: ignore[index]

    first_pw = make_pointwise(frexp0)
    second_pw = make_pointwise(frexp1, override_return_dtype=torch.int32)

    def fn(*args, **kwargs):
        return first_pw(*args, **kwargs), second_pw(*args, **kwargs)

    fn = register_lowering(aten.frexp)(fn)

    if hasattr(prims, name):
        register_lowering(getattr(prims, name), type_promotion_kind=None)(fn)
    return fn


register_frexp()
def register_foreach_pointwise(
    aten_fn,
    pointwise_lowering_fn,
    allow_alpha=False,
):
    """Register a foreach lowering for ``aten_fn`` built from a pointwise lowering."""
    foreach_fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha)
    return _register_foreach_lowering(aten_fn, foreach_fn)
@register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
def where(cond, a, b):
    """Lower ``aten.where``.

    Python scalars in ``a``/``b`` are turned into constants shaped like the
    other operand, tensor operands are broadcast together, and both value
    operands are cast to their promoted dtype before selecting.
    """

    def fn(*args):
        return ops.where(*args)

    if isinstance(a, (float, int)):
        a = constant_like(a)(b)
    if isinstance(b, (float, int)):
        b = constant_like(b)(a)

    operands = [cond, a, b]
    dtype = get_promoted_dtype(
        a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )

    tensor_positions = [i for i, x in enumerate(operands) if isinstance(x, TensorBox)]
    broadcasted = broadcast_tensors(*[operands[i] for i in tensor_positions])
    for i, tensor in zip(tensor_positions, broadcasted):
        operands[i] = tensor

    # Constants take the size of the first (already broadcast) tensor operand.
    for i, operand in enumerate(operands):
        if isinstance(operand, ir.Constant):
            operands[i] = ExpandView.create(
                operand, list(operands[tensor_positions[0]].get_size())
            )

    cond_arg, a_arg, b_arg = operands
    return make_pointwise(fn, override_return_dtype=dtype)(
        cond_arg, to_dtype(a_arg, dtype), to_dtype(b_arg, dtype)
    )
@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
def broadcast_tensors(*inputs):
    """Expand ``inputs`` to their common broadcast shape.

    A single list/tuple argument is unpacked first; a single non-sequence
    argument is returned as the ``inputs`` tuple. Otherwise returns a list
    with one (possibly expanded) entry per input.
    """
    if len(inputs) == 1:
        (only,) = inputs
        if isinstance(only, (list, tuple)):
            return broadcast_tensors(*only)
        return inputs

    target: list[sympy.Expr] = functools.reduce(
        broadcast_symbolic_shapes, (x.get_size() for x in inputs), ()
    )
    is_one = V.graph.sizevars.is_size_one_or_false

    def needs_expand(sizes):
        # Identical sizes never need expanding; otherwise expand on a rank
        # mismatch or when a dimension is size-one on only one side.
        if sizes == target:
            return False
        return len(sizes) != len(target) or any(
            is_one(a) != is_one(b) for a, b in zip(sizes, target)
        )

    return [
        expand(x, target) if needs_expand(tuple(x.get_size())) else x for x in inputs
    ]
@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
def nop(x):
    """Identity lowering: the input is returned untouched."""
    # AOT autograd handles this for us
    return x


# lift_fresh is registered only when this build of aten exposes it.
if hasattr(aten, "lift_fresh"):
    register_lowering(aten.lift_fresh)(nop)
@register_lowering(aten.squeeze, type_promotion_kind=None)
def squeeze(x, dim=None):
    """Lower ``aten.squeeze``.

    Without ``dim`` a ``SqueezeView`` of the whole tensor is returned. With
    ``dim`` (one dimension or several), only listed dimensions whose size is
    known to equal 1 are dropped; if nothing is dropped ``x`` is returned.
    """
    assert isinstance(x, TensorBox)
    if dim is None:
        return TensorBox(SqueezeView.create(x.data))

    sizevars = V.graph.sizevars
    if isinstance(dim, (int, sympy.Expr)):
        dim = sizevars.guard_int(dim)
    else:
        dim = tuple(sizevars.guard_int(d) for d in dim)
    dim = canonicalize_dims(len(x.get_size()), dim)  # type: ignore[call-overload]
    dims = OrderedSet(dim if isinstance(dim, tuple) else (dim,))

    # squeeze does nothing if the size isn't 1
    new_shape = [
        s
        for d, s in enumerate(x.get_size())
        if d not in dims or not V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))
    ]
    if new_shape != x.get_size():
        return view(x, new_shape)
    return x
@register_lowering(aten.squeeze_copy, type_promotion_kind=None)
def squeeze_copy(x, dim=None):
    """Lower ``aten.squeeze_copy`` as a clone of the squeezed input."""
    squeezed = squeeze(x, dim)
    return clone(squeezed)
@register_lowering([aten.squeeze_])
def squeeze_(x, dim=None):
    """Lower in-place squeeze by pointing ``x.data`` at the squeezed result's data."""
    squeezed = squeeze(x, dim)
    assert isinstance(x, TensorBox)
    assert isinstance(squeezed, TensorBox)
    x.data = squeezed.data
    return x
@register_lowering(aten.isinf)
def isinf(x):
    """Lower ``aten.isinf``; integer-typed inputs give an all-False bool result."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("isinf"), override_return_dtype=torch.bool)(x)
    return full_like(x, False, dtype=torch.bool)
@register_lowering(aten.isnan)
def isnan(x):
    """Lower ``aten.isnan``; integer-typed inputs give an all-False bool result."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("isnan"), override_return_dtype=torch.bool)(x)
    return full_like(x, False, dtype=torch.bool)
@register_lowering(aten.ceil)
def ceil(x):
    """Lower ``aten.ceil``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("ceil"))(x)
    return clone(x)
@register_lowering(aten.floor)
def floor(x):
    """Lower ``aten.floor``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("floor"))(x)
    return clone(x)
@register_lowering(aten.round.default)
def round(x):
    """Lower ``aten.round.default``; integer-typed inputs are simply cloned."""
    if is_integer_type(x):
        return clone(x)
    return make_pointwise(ops_wrapper("round"))(x)
@register_lowering(aten.trunc)
def trunc(x):
    """Lower ``aten.trunc``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("trunc"))(x)
    return clone(x)
@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
    """Lower ``aten.expand`` to an ``ExpandView`` of ``x`` with shape ``sizes``.

    Returns ``x`` itself when it already has that shape. When the sizes
    involved have no free unbacked symbols and ``x`` is non-empty, ``x`` is
    first given a reuse hint equal to the expansion factor.
    """
    (x,) = promote_constants([x])
    if isinstance(x, ir.BaseConstant):
        return ExpandView.create(x, tuple(sizes))

    assert isinstance(x, TensorBox)
    assert isinstance(sizes, (list, tuple))
    if tuple(x.get_size()) == tuple(sizes):
        return x

    size_hint = V.graph.sizevars.size_hint_or_throw
    if not free_unbacked_symbols(x.get_size()):
        x_size_product = size_hint(sympy_product(x.get_size()))
        # TODO: It would be better to realize the input if any of its sizes
        # are unbacked, because typically the size will be non-zero.  However,
        # this cannot be done directly as below as we'll choke on the size_hint
        # here
        if x_size_product > 0 and not free_unbacked_symbols(sizes):
            # maybe realize input before broadcasting it
            reuse_factor = size_hint(sympy_product(sizes)) // x_size_product
            x.mark_reuse(reuse_factor)

    return TensorBox(ExpandView.create(x.data, tuple(sizes)))
@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
def broadcast_in_dim(a, shape, broadcast_dimensions):
    """Lower ``prims.broadcast_in_dim`` as a series of unsqueezes plus an expand.

    Output positions listed in ``broadcast_dimensions`` are the ones ``a``
    already provides; every other position gets a new size-1 dim inserted
    before the final ``expand`` to ``shape``.
    """
    # Mark the positions that come from `a` with the sentinel -1.
    marked = list(shape)
    for mapped_dim in broadcast_dimensions:
        marked[mapped_dim] = -1

    result = a
    for position, entry in enumerate(marked):
        if entry != -1:
            result = unsqueeze(result, position)
    return expand(result, shape)
@register_lowering(aten.expand_as, type_promotion_kind=None)
def expand_as(x, y):
    """Lower ``aten.expand_as`` by expanding ``x`` to the size of ``y``."""
    target_size = y.get_size()
    return expand(x, target_size)
@register_lowering(aten.repeat)
def repeat(x, repeats):
    """Lower ``aten.repeat``.

    The input is left-padded with size-1 dims to match ``len(repeats)``. A zero
    repeat yields an empty tensor, and pure broadcasts become ``expand`` plus
    ``clone``. Everything else is a pointwise kernel that reads the source with
    modular indexing.
    """
    old_size = list(x.get_size())
    rank_gap = len(repeats) - len(old_size)
    if rank_gap > 0:
        old_size = [sympy.S.One] * rank_gap + old_size
        x = view(x, list(old_size))
    assert len(repeats) == len(x.get_size())

    new_size = list(x.get_size())
    has_zero_repeat = False
    for axis, count in enumerate(repeats):
        if count == 0:
            has_zero_repeat = True
        new_size[axis] = new_size[axis] * count

    if has_zero_repeat:
        return empty(new_size, dtype=x.get_dtype(), device=x.get_device())
    if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
        # Every dim is either not repeated or has extent 1: a plain broadcast.
        return clone(expand(x, new_size))

    x_loader: Callable[[Any], Any]

    def inner_fn(index):
        assert len(index) == len(repeats)
        source_index = list(index)
        for axis, (count, extent) in enumerate(zip(repeats, old_size)):
            if count != 1:
                if extent == 1:
                    source_index[axis] = sympy.S.Zero
                else:
                    # Wrap the output coordinate back into the source extent.
                    source_index[axis] = ModularIndexing(source_index[axis], 1, extent)
        return x_loader(source_index)

    if not free_unbacked_symbols(old_size) and not free_unbacked_symbols(new_size):
        old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size))
        if old_size_product > 0:
            # maybe realize the input but skip for unbacked symints since it'll
            # choke on the size hint.
            new_size_product = V.graph.sizevars.size_hint_or_throw(
                sympy_product(new_size)
            )
            x.mark_reuse(new_size_product // old_size_product)

    # The loader is created only after the potential realization above.
    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(new_size),
    )
@register_lowering(aten._unsafe_view, type_promotion_kind=None)
@register_lowering(aten.view, type_promotion_kind=None)
@register_lowering(aten.reshape, type_promotion_kind=None)
def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox:
    """Lower ``aten.view`` / ``aten.reshape`` / ``aten._unsafe_view`` to a ``View``."""
    reshaped = View.create(x.data, sizes)
    return TensorBox(reshaped)
@register_lowering(aten.permute, type_promotion_kind=None)
def permute(x, dims):
    """Lower ``aten.permute`` to a ``PermuteView`` with the given dim order."""
    assert isinstance(x, TensorBox)
    assert isinstance(dims, (list, tuple))
    permuted = PermuteView.create(x.data, tuple(dims))
    return TensorBox(permuted)
@register_lowering(aten.slice, type_promotion_kind=None)
def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
    """
    Lower ``aten.slice`` along ``dim``.

    If the semantics of ``start``/``end`` can be established (static, backed, or
    unbacked indices with enough known facts), a SliceView is created. Otherwise
    (``clamp=True`` with undecidable unbacked indices) ExternKernels are
    registered for the output size and, when needed, the storage offset
    symbols, and the result is built through ``as_strided``.
    """
    from torch.fx.experimental.symbolic_shapes import (
        CallMethodKey,
        resolve_unbacked_bindings,
    )

    assert isinstance(x, TensorBox)
    dim = _validate_dim(x, dim, 0)
    size = x.get_size()[dim]
    step = sympy.expand(step)
    assert isinstance(step, sympy.Expr) or step > 0, step

    # Fast path: a slice covering the whole dim with unit step is the input itself.
    try:
        covers_whole_dim = (
            start == 0
            and V.graph.sizevars.statically_known_leq(size, end)
            and step == 1
        )
    except TypeError:
        covers_whole_dim = False
    if covers_whole_dim:
        return x

    # try to avoid dynamic (unbacked) slice
    def compute_slice_index(index, size, default=None):
        """Normalize ``index`` against ``size`` when provable; ``None`` if undecidable."""
        if index is None:
            return default

        provable = V.graph.sizevars.guard_or_false
        index = sympy.expand(index)
        size = sympy.expand(size)

        if provable(sympy.Ge(index, 0)) and provable(sympy.Le(index, size)):
            return index
        if provable(sympy.Lt(index, 0)) and provable(sympy.Ge(index, -size)):
            return index + size
        if provable(sympy.Gt(index, size)):
            return size
        if provable(sympy.Lt(index, -size)):
            return 0
        return None

    start_index = None
    end_index = None
    ambiguous_slice = clamp
    if ambiguous_slice:
        start_index = compute_slice_index(start, size, 0)
        end_index = compute_slice_index(end, size, size)
        if not (start_index is None or end_index is None):
            start, end = start_index, end_index
            ambiguous_slice = False

    # ambiguous_slice=False means we know what semantics this slice call follows,
    # and don't need to generate an extern kernel to represent the output size.
    # This is assumed True for clamp=False
    # (meant to follow standard indexing semantics: 0 <= index < size)
    if not ambiguous_slice:
        # go to SliceView/ReinterpretView
        slice_view = ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
        return TensorBox(slice_view)

    # unbacked territory: create DynamicSlice ExternKernel
    # clamp is True, unbacked start / end
    assert clamp
    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
    )
    assert unbacked_bindings is not None
    assert len(unbacked_bindings) <= 2, unbacked_bindings

    # Pick out the symbols bound to the output size on `dim` and to the
    # output storage offset.
    sym_size, sym_storage = None, None
    for sym, keypath in unbacked_bindings.items():
        if keypath == (CallMethodKey("size"), pytree.SequenceKey(dim)):
            sym_size = sym
        elif keypath == (CallMethodKey("storage_offset"),):
            sym_storage = sym

    assert start_index is None or end_index is None
    b_size = ir.DynamicSliceSize(
        sym_size,
        start,
        end,
        step,
        x.get_size()[dim],
    )
    b_size.name = V.graph.register_buffer(b_size)
    V.graph.register_operation(b_size)
    new_size = sym_size

    if x.maybe_get_layout() is None:
        # realize tensor before accessing layout
        x.realize()

    if start_index is None:
        # The start is not determinable: compute the offset at runtime.
        b_storage = ir.DynamicSelectStorageOffset(
            sym_storage,
            start,
            x.get_layout().offset,
            x.get_stride()[dim],
            x.get_size()[dim],
            clamp=True,
        )
        b_storage.name = V.graph.register_buffer(b_storage)
        V.graph.register_operation(b_storage)
        new_storage_offset = sym_storage
    else:
        # we shouldn't have allocated storage offset symbol if start index was determinable
        assert sym_storage is None
        new_storage_offset = x.get_layout().offset + start_index * x.get_stride()[dim]

    new_sizes = list(x.get_size())
    new_strides = list(x.get_stride())
    new_sizes[dim] = new_size
    new_strides[dim] *= step
    return as_strided(x, new_sizes, new_strides, new_storage_offset)
@register_lowering(aten.as_strided, type_promotion_kind=None)
def as_strided(x, size, stride, storage_offset=None):
    """Lower ``aten.as_strided`` to a ``ReinterpretView`` with a fixed layout.

    Raises:
        NotImplementedError: if ``x`` is not a storage-and-layout after
            ``realize()``.
    """
    view_device = None
    view_dtype = None
    if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
        # Note: Merging views
        # When we use as_strided, we can rewrite the size/stride/offset
        # of the incoming buffer x. If x is a view, we would overwrite
        # its metadata. Except for dtype, which we need to propagate.
        # Technically device is not needed because it is not possible
        # to have a cross-device view today.
        view_device = x.get_device()
        view_dtype = x.dtype
        x = x.data.unwrap_view()

    x.realize()
    if not ir.is_storage_and_layout(x):
        raise NotImplementedError(f"unrealized as_strided({x}, ...)")

    storage, old_layout = ir.as_storage_and_layout(x)
    layout_device = view_device or old_layout.device
    layout_dtype = view_dtype or old_layout.dtype
    expanded_size = [sympy.expand(s) for s in size]
    expanded_stride = [sympy.expand(s) for s in stride]
    expanded_offset = sympy.expand(storage_offset or 0)
    new_layout = ir.FixedLayout(
        layout_device,
        layout_dtype,
        expanded_size,
        expanded_stride,
        expanded_offset,
    )
    return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout))
@register_lowering(aten.as_strided_, type_promotion_kind=None)
def as_strided_(x, size, stride, storage_offset=None):
    """Lower in-place ``aten.as_strided_`` by swapping ``x``'s data for the new view."""
    assert isinstance(x, TensorBox)
    restrided = as_strided(x, size, stride, storage_offset)
    x.data = restrided.data
    return x
@register_lowering(aten.as_strided_copy, type_promotion_kind=None)
def as_strided_copy(x, size, stride, storage_offset=None):
    """Lower ``aten.as_strided_copy``: build the strided view, then clone it."""
    return clone(as_strided(x, size, stride, storage_offset))
def pointwise_cat(inputs, dim=0):
    """Concatenate ``inputs`` along ``dim`` as a single ``Pointwise`` kernel.

    Each output element is produced by masked loads from every input, and the
    value from the input whose index range contains the output coordinate is
    selected with a chain of ``where`` ops. Device and dtype are taken from
    ``inputs[0]``.
    """
    # Half-open (inclusive, exclusive) range each input occupies along `dim`.
    inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = []
    running_end = 0
    for inp in inputs:
        next_end = running_end + inp.get_size()[dim]
        inputs_ranges.append((running_end, next_end))  # type: ignore[arg-type]
        running_end = next_end  # type: ignore[assignment]

    inputs_loaders = [inp.make_loader() for inp in inputs]
    last_position = len(inputs) - 1

    def inner_fn(idx):
        idx_dim = ops.index_expr(idx[dim], torch.int64)

        masks = []
        masked_loads = []
        for i in range(len(inputs)):
            range_start, range_end = inputs_ranges[i]
            if i == 0:
                start = ops.constant(0, torch.int64)
            else:
                start = ops.index_expr(range_start, torch.int64)
            end = ops.index_expr(range_end, torch.int64)

            start_cond = ops.ge(idx_dim, start)
            end_cond = ops.lt(idx_dim, end)
            # The first and last inputs only need a one-sided bound check.
            if i == 0:
                mask = end_cond
            elif i == last_position:
                mask = start_cond
            else:
                mask = ops.and_(start_cond, end_cond)
            masks.append(mask)

            idx_load = list(idx)
            # if we're concatting [4], [2]
            # when we index the second tensor for 5 we want to index 5 - 4
            # Use Identity to prevent expansion of index * stride to keep expression
            # in same int bitwidth as shape
            idx_load[dim] = Identity(idx_load[dim] - range_start)
            masked_loads.append(
                ops.masked(
                    mask,
                    lambda: inputs_loaders[i](idx_load),
                    0.0,  # this value should be unused
                ),
            )

        # Fold from the last input backwards: earlier inputs take precedence.
        selected = masked_loads[-1]
        for mask, load in zip(reversed(masks[:-1]), reversed(masked_loads[:-1])):
            selected = ops.where(mask, load, selected)
        return selected

    new_size = list(inputs[0].get_size())
    new_size[dim] = inputs_ranges[-1][-1]

    return Pointwise.create(
        device=inputs[0].get_device(),
        dtype=inputs[0].get_dtype(),
        inner_fn=inner_fn,
        ranges=new_size,
    )
@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None)
def quantized_decomposed_quantize_per_channel(
    input: TensorBox,
    scales: TensorBox,
    zero_points: TensorBox,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> TensorBox:
    """Lower ``quantized_decomposed.quantize_per_channel`` to a pointwise kernel.

    Computes ``clamp(round(input / scale) + zero_point, quant_min, quant_max)``
    in float32, with ``scale`` / ``zero_point`` looked up by the coordinate on
    ``axis``, and casts the result to ``dtype``. bfloat16 inputs are upcast to
    float32 first.
    """
    assert len(scales.get_size()) == 1, "expect scales 1 dim"
    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"

    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )
    assert axis < len(input.get_size()), (
        f"Expecting axis to be < {len(input.get_size())}"
    )

    load_input = input.make_loader()
    load_scale = scales.make_loader()
    load_zero_point = zero_points.make_loader()

    def inner_fn(idx):
        channel = (idx[axis],)

        value = load_input(idx)
        scale = load_scale(channel)
        zero_point = load_zero_point(channel)
        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)

        if scales.dtype != torch.float32:
            scale = ops.to_dtype(scale, torch.float32)
        if zero_points.dtype != torch.int32:
            zero_point = ops.to_dtype(zero_point, torch.int32)

        quantized = ops.round(value * ops.reciprocal(scale)) + zero_point
        clamped = ops.maximum(qmin, ops.minimum(qmax, quantized))
        return ops.to_dtype(clamped, dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
def _assert_async(cond, msg):
    """Build and realize a pointwise op that device-asserts every element of ``cond``.

    ``cond`` is realized and cast to bool first; ``msg`` is forwarded to
    ``ops.device_assert_async``. Returns the realized assertion op.
    """
    cond.realize()
    cond = to_dtype(cond, torch.bool)

    def inner_fn(index):
        with ir.ComputedBuffer.force_realize():
            load_cond = cond.make_loader()
            return ops.device_assert_async(load_cond(index), msg)

    assertion_op = Pointwise.create(
        device=cond.get_device(),
        dtype=cond.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(cond.get_size()),
    )
    assertion_op.realize()
    return assertion_op
@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
    """Lower ``aten._assert_async.msg`` through the shared ``_assert_async`` helper."""
    assertion = _assert_async(cond, msg)
    return assertion
@register_lowering(aten._functional_assert_async.msg)
def lower_assert_functional_async(cond, msg):
    """Lower ``aten._functional_assert_async.msg`` through ``_assert_async``."""
    assertion = _assert_async(cond, msg)
    return assertion
@register_lowering(
    quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorBox,
    scales: TensorBox,
    zero_points: TensorBox,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> TensorBox:
    """Lower ``quantized_decomposed.dequantize_per_channel`` to a pointwise kernel.

    Computes ``(input - zero_point) * scale`` in float32, with ``scale`` /
    ``zero_point`` looked up by the coordinate on ``axis``, then casts to
    ``out_dtype`` (float32 when not given). ``quant_min`` / ``quant_max`` are
    not used by this lowering.
    """
    assert len(scales.get_size()) == 1, "expect scales 1 dim"
    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )
    assert axis < len(input.get_size()), (
        f"Expecting axis to be < {len(input.get_size())}"
    )

    if out_dtype is None:
        out_dtype = torch.float32

    load_input = input.make_loader()
    load_scale = scales.make_loader()
    load_zero_point = zero_points.make_loader()

    def inner_fn(idx):
        channel = (idx[axis],)

        value = load_input(idx)
        scale = load_scale(channel)
        zero_point = load_zero_point(channel)

        if scales.dtype != torch.float32:
            scale = ops.to_dtype(scale, torch.float32)
        if zero_points.dtype != torch.float32:
            zero_point = ops.to_dtype(zero_point, torch.float32)

        shifted = ops.sub(ops.to_dtype(value, torch.float32), zero_point)
        return ops.to_dtype(shifted * scale, out_dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=out_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_default(
    input: TensorBox,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> TensorBox:
    """Lower ``quantized_decomposed.quantize_per_tensor.default`` (scalar params).

    Computes ``clamp(round(input * (1 / scale)) + zero_point, quant_min,
    quant_max)`` in float32 and casts to ``dtype``. bfloat16 inputs are upcast
    to float32 first.
    """
    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )

    load_input = input.make_loader()

    def inner_fn(idx, scale, zero_point):
        value = load_input(idx)
        inv_scale_c, zero_point_c = _create_constants(
            1.0 / scale, zero_point, dtype=torch.float32
        )
        quantized = ops.round(value * inv_scale_c) + zero_point_c
        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
        clamped = ops.minimum(ops.maximum(quantized, qmin), qmax)
        return ops.to_dtype(clamped, dtype)

    # Bind the scalar parameters as plain Python float/int.
    bound_inner_fn = functools.partial(
        inner_fn, scale=float(scale), zero_point=int(zero_point)
    )
    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=bound_inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_default(
    input: TensorBox,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> TensorBox:
    """Lower ``quantized_decomposed.dequantize_per_tensor.default`` (scalar params).

    Computes ``(input - zero_point) * scale`` in float32 and casts to
    ``out_dtype`` (float32 when not given). ``quant_min`` / ``quant_max`` are
    not used by this lowering.
    """
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )

    if out_dtype is None:
        out_dtype = torch.float32

    load_input = input.make_loader()

    def inner_fn(idx, scale, zero_point):
        value = load_input(idx)
        scale_c, zero_point_c = _create_constants(
            scale, zero_point, dtype=torch.float32
        )
        shifted = ops.sub(ops.to_dtype(value, torch.float32), zero_point_c)
        return ops.to_dtype(shifted * scale_c, out_dtype)

    # Bind the scalar parameters as plain Python float/int.
    bound_inner_fn = functools.partial(
        inner_fn, scale=float(scale), zero_point=int(zero_point)
    )
    return Pointwise.create(
        device=input.get_device(),
        dtype=out_dtype,
        inner_fn=bound_inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_tensor(
    input: TensorBox,
    scale: TensorBox,
    zero_point: TensorBox,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> TensorBox:
    """Lower ``quantized_decomposed.quantize_per_tensor.tensor`` (tensor params).

    Computes ``clamp(round(input / scale) + zero_point, quant_min, quant_max)``
    in float32 and casts to ``dtype``. bfloat16 inputs are upcast to float32
    first.

    Args:
        input: float32 (or bfloat16) tensor to quantize.
        scale: scalar tensor, either 0-d or of shape ``[1]``.
        zero_point: scalar tensor, either 0-d or of shape ``[1]``.
        quant_min: lower clamp bound.
        quant_max: upper clamp bound.
        dtype: dtype of the quantized output.
    """
    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )
    assert len(scale.get_size()) == 0 or (
        len(scale.get_size()) == 1 and scale.get_size()[0] == 1
    ), "expect scale as scalar tensor"
    assert len(zero_point.get_size()) == 0 or (
        len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
    ), "expect zero_point as scalar tensor"

    input_loader = input.make_loader()
    scale_loader = scale.make_loader()
    zero_point_loader = zero_point.make_loader()

    def inner_fn(idx):
        value = input_loader(idx)
        # Each parameter is indexed according to ITS OWN rank: scale and
        # zero_point may independently be 0-d or shape [1].
        _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
        _zero_point = zero_point_loader(
            (0,) if len(zero_point.get_size()) == 1 else ()
        )
        if scale.dtype != torch.float32:
            _scale = ops.to_dtype(_scale, torch.float32)
        if zero_point.dtype != torch.float32:
            _zero_point = ops.to_dtype(_zero_point, torch.float32)
        val = ops.round(value * ops.reciprocal(_scale)) + _zero_point
        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
        return ops.to_dtype(clamped, dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_tensor(
    input: TensorBox,
    scale: TensorBox,
    zero_point: TensorBox,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> TensorBox:
    """Lower ``quantized_decomposed.dequantize_per_tensor.tensor`` (tensor params).

    Computes ``(input - zero_point) * scale`` in float32 and casts to
    ``out_dtype`` (float32 when not given).

    Args:
        input: quantized tensor; its dtype must equal ``dtype``.
        scale: scalar tensor, either 0-d or of shape ``[1]``.
        zero_point: scalar tensor, either 0-d or of shape ``[1]``.
        quant_min: not used by this lowering.
        quant_max: not used by this lowering.
        dtype: expected dtype of ``input``.
        out_dtype: dtype of the dequantized output.
    """
    assert len(scale.get_size()) == 0 or (
        len(scale.get_size()) == 1 and scale.get_size()[0] == 1
    ), "expect scale as scalar tensor"
    assert len(zero_point.get_size()) == 0 or (
        len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
    ), "expect zero_point as scalar tensor"
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )

    if out_dtype is None:
        out_dtype = torch.float32

    input_loader = input.make_loader()
    scale_loader = scale.make_loader()
    zero_point_loader = zero_point.make_loader()

    def inner_fn(idx):
        value = input_loader(idx)
        # Each parameter is indexed according to ITS OWN rank: scale and
        # zero_point may independently be 0-d or shape [1].
        _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
        _zero_point = zero_point_loader(
            (0,) if len(zero_point.get_size()) == 1 else ()
        )
        if scale.dtype != torch.float32:
            _scale = ops.to_dtype(_scale, torch.float32)
        if zero_point.dtype != torch.float32:
            _zero_point = ops.to_dtype(_zero_point, torch.float32)
        val = ops.sub(ops.to_dtype(value, torch.float32), _zero_point) * _scale
        val = ops.to_dtype(val, out_dtype)
        return val

    return Pointwise.create(
        device=input.get_device(),
        dtype=out_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    """Lower ``aten.cat``.

    Strategy, in order:
      * CPU with all-int8/uint8 inputs: realize inputs and fall back to the
        ATen kernel (channels-last is required when every input is 4-D).
      * A single input: clone it.
      * Otherwise promote all inputs to a common dtype, then
          - use ``pointwise_cat`` if ``config.force_pointwise_cat`` is set,
          - use ``ConcatKernel`` on CPU,
          - on other devices use ``pointwise_cat`` when the heuristics below
            consider fusion profitable, else ``ConcatKernel``.
    """
    cpu_device = inputs[0].get_device().type == "cpu"
    if cpu_device and all(
        inp.get_dtype() in [torch.int8, torch.uint8] for inp in inputs
    ):
        # TODO <leslie> Remove this fallback when we support vectorization
        # code gen with uint8 data type directly.
        for inp in inputs:
            inp.realize()
        if all(len(inp.get_size()) == 4 for inp in inputs):
            inputs, _ = require_channels_last(aten.cat, *inputs)
        return fallback_handler(aten.cat.default)(inputs, dim)

    if len(inputs) == 1:
        return clone(inputs[0])

    dim = _validate_dim(inputs[0], dim, 0)
    dtype = get_promoted_dtype(
        *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    inputs = [to_dtype(inp, dtype) for inp in inputs]

    def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode:
        """Strip the TensorBox/StorageBox (and view) wrappers from ``x``."""
        if isinstance(x, TensorBox):
            if isinstance(x.data, ir.BaseView):
                return x.data.unwrap_view()
            else:
                return x.data

        if isinstance(x, ir.StorageBox):
            return x.data

        return x

    def is_reduction(t):
        return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction)

    def can_fuse_reduction(t):
        """True if ``t`` is a reduction or a pointwise that (transitively) reads one."""
        if isinstance(t, (TensorBox, ir.StorageBox)):
            return can_fuse_reduction(unwrap_tensor(t))
        return is_reduction(t) or (
            isinstance(t, ir.Pointwise)
            and any(
                can_fuse_reduction(V.graph.get_buffer(read))
                for read in t.get_read_names()
            )
        )

    # fusing reductions into computed concat buffer can cause regressions.
    fusable_reduction = any(can_fuse_reduction(t) for t in inputs)

    def should_lower_cat_input(x) -> bool:
        # Unrealized inputs will not be storage and layouts, and we don't want
        # to realize them in case we want to fuse
        if ir.is_storage_and_layout(x):
            storage, _ = ir.as_storage_and_layout(x, freeze=False)
            return not ir.ConcatKernel.can_realize_into_without_copy(storage)

        if isinstance(x, (TensorBox, ir.StorageBox)):
            return should_lower_cat_input(unwrap_tensor(x))

        if isinstance(x, ir.Pointwise):
            return True

        return False

    if config.force_pointwise_cat:
        return pointwise_cat(inputs, dim)

    # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it.
    # We will revisit this later after enabling vectorization on index_expr.
    if cpu_device:
        return TensorBox(ir.ConcatKernel.create(inputs, dim))

    def op_count(x):
        """Count pointwise ops feeding ``x``, following its reads recursively."""
        if isinstance(x, (TensorBox, ir.StorageBox)):
            return op_count(unwrap_tensor(x))

        # this will correspond to a direct memory read
        if not isinstance(x, ir.Pointwise):
            return 0

        count = x.inner_fn_opcount().num_ops
        for read in x.get_read_names():
            count += op_count(V.graph.get_buffer(read))

        return count

    # as the number of inputs increases, the possibility of register spilling
    # also increases. Past a certain threshold of inputs we only fuse if the
    # input kernels are simple.
    # not sure if we want to expose to users via config since logic may change in future
    MAX_COMPLEX_POINTWISE_CAT = 8
    MAX_SIMPLE_OP_COUNT = 2

    def additional_pointwise_ops(op: torch._ops.OpOverload):
        return op in (aten.cat.default, aten.constant_pad_nd.default)

    if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
        (len(inputs) <= config.max_pointwise_cat_inputs)
        and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
    ):
        pointwise_uses = all(
            is_pointwise_use(use, additional_pointwise_ops)
            for use in V.current_node.users
        )
        # fuse in case we will be used in a pointwise node, and there are any
        # inputs we can prevent materialization of.
        fuse_pointwise_use = (
            any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses
        )

        # horizontal fuse in case all inputs will require a copy kernel anyway.
        # only horizontally fuse pointwise kernels. `fusable_reduction` was
        # computed above over the same inputs, so it is reused here rather
        # than walking the inputs again.
        horizontal_fuse_cat = (
            all(should_lower_cat_input(inp) for inp in inputs)
            and not fusable_reduction
        )
        if fuse_pointwise_use or horizontal_fuse_cat:
            return pointwise_cat(inputs, dim)

    return TensorBox(ir.ConcatKernel.create(inputs, dim))
@register_lowering(aten.diagonal, type_promotion_kind=None)
def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
    """Lower ``aten.diagonal`` to a ``GenericView`` over ``input``.

    The output keeps every dim except ``dim1`` / ``dim2`` (in their original
    order) and appends the diagonal as the last dim.
    """
    original_shape = input.get_size()
    num_dims = len(original_shape)
    dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
    dim2 = canonicalize_dim(idx=dim2, rank=num_dims)

    check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    sizevars = V.graph.sizevars
    offset_negative = sizevars.evaluate_expr(sympy.Lt(offset, 0))
    if offset_negative:
        # Negative offset: the diagonal starts `-offset` entries along dim1.
        unclamped_size = sizevars.evaluate_min(
            original_shape[dim1] + offset, original_shape[dim2]
        )
        base_idx = (-offset, 0)
    else:
        # Non-negative offset: the diagonal starts `offset` entries along dim2.
        unclamped_size = sizevars.evaluate_min(
            original_shape[dim1], original_shape[dim2] - offset
        )
        base_idx = (0, offset)
    diag_size = sizevars.evaluate_max(unclamped_size, 0)  # type: ignore[arg-type]

    sizes = [
        extent for d, extent in enumerate(original_shape) if d not in (dim1, dim2)
    ]
    sizes.append(diag_size)

    def reindexer(idx):
        """Map an output index to the corresponding index into ``input``."""
        diag_idx = idx[-1]
        source_idx = [0] * len(original_shape)
        consumed = 0
        for d in range(num_dims):
            if d == dim1:
                source_idx[d] = diag_idx + base_idx[0]
            elif d == dim2:
                source_idx[d] = diag_idx + base_idx[1]
            else:
                source_idx[d] = idx[consumed]
                consumed += 1

        assert consumed == len(original_shape) - 2
        return source_idx

    return TensorBox(ir.GenericView.create(input, sizes, reindexer))
@register_lowering(aten.diagonal_copy, type_promotion_kind=None)
def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
    """Lower ``aten.diagonal_copy``: take the diagonal view, then clone it."""
    diagonal_view = diagonal(input, offset, dim1, dim2)
    return clone(diagonal_view)
@register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
    """Lower ``aten.diagonal_scatter``.

    Clones ``input``, passes the clone's diagonal view and ``src`` to
    ``mutate_to``, and returns the clone.
    """
    output = clone(input)
    mutate_to(diagonal(output, offset, dim1, dim2), src)
    return output
@register_lowering(aten.select, type_promotion_kind=None)
def select(x, dim, idx):
    """Lower ``aten.select``: pick index ``idx`` along ``dim`` and drop that dim.

    When the sign of ``idx`` can be decided, the result is either a
    slice-and-squeeze or an ``as_strided`` with a computed offset (the latter
    if ``idx`` contains unbacked symbols). Otherwise the storage offset is
    computed at runtime through a ``DynamicSelectStorageOffset`` buffer.
    """
    idx = sympy.expand(idx)
    size = sympy.expand(x.get_size()[dim])

    # Resolve negative indexing if the sign of idx is decidable.
    actual_index = None
    if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
        actual_index = idx + size
    elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
        actual_index = idx

    if actual_index is not None:
        if not has_free_unbacked_symbols(idx):
            # no need to clamp, this function handles negative indexing itself
            sliced = slice_(x, dim, actual_index, actual_index + 1, clamp=False)
            return squeeze(sliced, dim)

        # Inductor could generate incorrect views for tensors with unbacked symbols here;
        # Squeeze operations are translated to views, resulting in incorrect strides.
        # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
        # we use as_strided instead.
        # Removing this branch will cause test_unbacked_select_index_with_check to fail.

        # before accessing size, stride, and offset we need to realize.
        x.realize()
        remaining_size = x.get_size()
        remaining_stride = x.get_stride()
        selected_offset = (
            x.get_layout().offset + remaining_stride[dim] * actual_index
        )
        del remaining_size[dim]
        del remaining_stride[dim]
        return as_strided(x, remaining_size, remaining_stride, selected_offset)

    # Unbacked Semantics:
    # When the index idx is unbacked (e.g., u0), we compute the index dynamically
    # during the lowering of the select operation using DynamicSelectStorageOffset.
    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
    )
    assert unbacked_bindings is not None
    assert len(unbacked_bindings) == 1, unbacked_bindings
    unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))

    # before accessing size, stride, and offset we need to realize.
    x.realize()
    remaining_size = x.get_size()
    remaining_stride = x.get_stride()

    offset_buffer = ir.DynamicSelectStorageOffset(
        unbacked_offset_sym,
        idx,
        x.get_layout().offset,
        remaining_stride[dim],
        x.get_size()[dim],
        clamp=False,
    )
    offset_buffer.name = V.graph.register_buffer(offset_buffer)
    V.graph.register_operation(offset_buffer)

    del remaining_size[dim]
    del remaining_stride[dim]
    return as_strided(x, remaining_size, remaining_stride, unbacked_offset_sym)
@register_lowering(aten.split, type_promotion_kind=None)
def split(x, sizes, dim=0):
    """Lower ``aten.split`` into a list of slices of ``x`` along ``dim``.

    Args:
        x: tensor to split.
        sizes: either an explicit list/tuple of chunk sizes, or a single
            integer (or SymInt) chunk size; in the latter case the last chunk
            may be smaller than the rest.
        dim: dimension to split along (negative values allowed).

    Returns:
        A list with one slice per chunk.
    """
    dim = _validate_dim(x, dim, 0)
    sizes_ = sizes

    # If sizes is an integer (or a SymInt), we turn it into a list of sizes
    # by computing what the actual size of each chunk should be.
    if not isinstance(sizes, (list, tuple)):
        x_size = x.get_size()[dim]
        chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
        # An empty dimension gives chunks == 0, and indexing the empty list
        # below would raise IndexError. Eager returns a single empty chunk in
        # that case, so always produce at least one chunk.
        chunks = max(chunks, 1)
        sizes_ = [sizes] * chunks
        # The last chunk might have a smaller size than the rest.
        sizes_[-1] = x_size - (chunks - 1) * sizes

    # From this point, we assume that the sum of the sizes of all chunks
    # equals the size of the base tensor.
    result = []
    start = 0
    for size in sizes_:
        end = start + size
        # No need for clamping here, since we compute the exact
        # start and end values.
        result.append(slice_(x, dim, start, end, clamp=False))
        start = end
    return result
@register_lowering(aten.split_with_sizes, type_promotion_kind=None)
def split_with_sizes(x, sizes, dim=0):
    """Lower ``aten.split_with_sizes`` by delegating to ``split``."""
    pieces = split(x, sizes, dim)
    return pieces
@register_lowering(aten.unbind, type_promotion_kind=None)
def unbind(x, dim=0):
    """Lower ``aten.unbind`` into one ``select`` per index along ``dim``.

    The extent of ``dim`` is passed through ``guard_int`` so the number of
    outputs is a concrete integer.
    """
    dim = _validate_dim(x, dim, 0)
    extent = V.graph.sizevars.guard_int(x.get_size()[dim])
    return [select(x, dim, position) for position in range(extent)]
@register_lowering(aten.unfold, type_promotion_kind=None)
def unfold(x, dimension, size, step):
    """Lower ``aten.unfold`` to a ``GenericView``.

    The output replaces ``dimension`` with the number of windows and appends a
    trailing dim of length ``size`` holding each window's elements. A 0-d
    input is handled by unsqueezing and slicing.
    """
    sizes = x.get_size()
    ndim = len(sizes)
    dim = canonicalize_dim(ndim, dimension)

    if ndim == 0:
        return slice_(unsqueeze(x, 0), end=size, clamp=False)

    dim_size = sizes[dim]
    sizevars = V.graph.sizevars
    sizevars.check_leq(size, dim_size)
    sizevars.check_lt(0, step)  # type: ignore[arg-type]

    window_count = FloorDiv(dim_size - size, step) + 1
    if sizevars.size_hint_or_throw(dim_size) > 0:
        # Approximate how many times each input element is read.
        reuse_hint = sizevars.size_hint_or_throw(
            CeilDiv(window_count * size, dim_size)
        )
        x.mark_reuse(reuse_hint)

    out_size = [*sizes[:dim], window_count, *sizes[dim + 1 :], size]

    def reindexer(idx):
        # Source coordinate = offset within the window + window index * step.
        source_pos = idx[-1] + idx[dim] * step
        return (*idx[:dim], source_pos, *idx[dim + 1 : -1])

    return TensorBox(ir.GenericView.create(x, out_size, reindexer))
@register_lowering(aten.unsqueeze, type_promotion_kind=None)
def unsqueeze(x, dim):
    """Lower ``aten.unsqueeze`` as a view with a size-1 dim inserted at ``dim``."""
    # offset=1: one past the last existing dim is a valid insertion point.
    dim = _validate_dim(x, dim, 1)
    current_shape = list(x.get_size())
    new_shape = [*current_shape[:dim], sympy.S.One, *current_shape[dim:]]
    return view(x, new_shape)
@register_lowering(aten.unsqueeze_, type_promotion_kind=None)
def unsqueeze_(x, dim):
    """Lower in-place ``aten.unsqueeze_`` by re-pointing ``x`` at the unsqueezed data."""
    unsqueezed = unsqueeze(x, dim)
    # Both sides must be boxes so the wrapped data can be swapped in place.
    assert isinstance(x, TensorBox)
    assert isinstance(unsqueezed, TensorBox)
    x.data = unsqueezed.data
    return x
def _validate_dim(x, dim, offset=0):
    """Evaluate ``dim`` and normalize it against the rank of ``x``.

    Negative values wrap around ``rank + offset``; the result is asserted to
    lie in ``[0, rank + offset)``. Callers in this file pass ``offset=1`` when
    one-past-the-end is a valid position (e.g. ``unsqueeze``).
    """
    dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim))
    upper_bound = len(x.get_size()) + offset
    if dim < 0:
        dim += upper_bound
    assert 0 <= dim < upper_bound
    return dim
@register_lowering(aten.glu)
def glu(x, dim=-1):
    """Lower ``aten.glu``: ``first_half * sigmoid(second_half)`` along ``dim``."""
    dim = _validate_dim(x, dim, 0)
    # TODO: don't guard on static shape here
    half = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
    # no need to clamp, index is int based on input size
    first_half = slice_(x, dim, 0, half, clamp=False)
    second_half = slice_(x, dim, half, half * 2, clamp=False)
    return mul(first_half, sigmoid(second_half))
def fallback_handler(kernel, add_to_fallback_set=True):
    """Return a lowering that dispatches ``kernel`` through ``ir.FallbackKernel``.

    When ``add_to_fallback_set`` is true, ``kernel`` is also added to
    ``fallbacks``. The returned callable is tagged with
    ``_is_fallback_handler = True``.
    """
    if add_to_fallback_set:
        fallbacks.add(kernel)

    def handler(*args, **kwargs):
        def wrap_tensors(node):
            # Box IR nodes; pass every other leaf through untouched.
            if isinstance(node, ir.IRNode):
                return TensorBox.create(node)
            return node

        fallback_result = ir.FallbackKernel.create(kernel, *args, **kwargs)
        return pytree.tree_map(wrap_tensors, fallback_result)

    # This lets us detect that a lowering is a fallback handler.
    handler._is_fallback_handler = True  # type: ignore[attr-defined]
    return handler
@functools.cache
def _warn_complex_not_supported():
    """Warn that complex operators are not code-generated.

    ``functools.cache`` makes the body run only on the first call.
    """
    message = "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."
    warnings.warn(message)
# There are some types (CPU) which we accept as input but not as
# output.
def unsupported_input_tensor(t: torch.Tensor, node=None):
    """Return True if tensor ``t`` is treated as unsupported here.

    Complex (after emitting a one-time warning), meta and sparse tensors are
    always unsupported. ``float8_e8m0fnu`` tensors are unsupported unless
    ``node`` targets one of a few allowed ops or a view op.
    """
    if t.is_complex():
        # Complex views are supported with IR ComplexView
        _warn_complex_not_supported()
        return True

    if t.is_meta or t.is_sparse:
        return True

    if t.dtype != torch.float8_e8m0fnu:
        return False

    if not node:
        return True

    # allow bitcast, views, memory movement, but not arithmetic
    # TODO: delete once triton adds native support
    target = node.target
    if not isinstance(target, torch._ops.OpOverload):
        return True
    allowed = target in (
        aten.view.dtype,
        aten.cat.default,
        aten.clone.default,
        aten._scaled_mm.default,
    ) or is_view(target)
    return not allowed
def unsupported_output_tensor(t: torch.Tensor, node=None):
    """Return True when ``t`` cannot be written by generated code.

    Complex outputs of the allow-listed view/convert ops are accepted.
    Otherwise anything rejected by ``unsupported_input_tensor`` is rejected
    here too, as are CPU tensors while ``config.disable_cpp_codegen`` is set.
    """
    complex_view_ops = (
        aten.view.dtype,
        torch.ops.prims.convert_element_type.default,
    )
    is_complex_view = (
        node is not None and node.target in complex_view_ops and t.is_complex()
    )
    if is_complex_view:
        return False

    if unsupported_input_tensor(t, node):
        return True

    return t.is_cpu and config.disable_cpp_codegen
def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True):
    """Decide whether ``node`` must fall back because of an unsupported tensor type.

    Returns True when any FakeTensor among the node's inputs fails
    ``unsupported_input_tensor`` or any FakeTensor in its own output fails
    ``unsupported_output_tensor``. ``allow_cpu_inputs`` is accepted but not
    consulted here.
    """
    exempt = (
        # Custom fallback lowering
        node.target is aten.view_as_complex.default
        or node.op == "placeholder"
        # We should be able to remove this special case once `disable_cpp_codegen` is killed.
        or node.target is aten.lift_fresh_copy.default
    )
    if exempt:
        return False

    def check_skip_condition(inp_out_node, is_output):
        # Only fx nodes carrying fake-tensor metadata can be inspected.
        if not isinstance(inp_out_node, torch.fx.Node):
            return False
        if "val" not in inp_out_node.meta:
            return False

        is_unsupported = (
            unsupported_output_tensor if is_output else unsupported_input_tensor
        )
        return any(
            isinstance(leaf, torch._subclasses.FakeTensor) and is_unsupported(leaf, node)
            for leaf in pytree.tree_leaves(inp_out_node.meta["val"])
        )

    # only skip codegen if there is a cpu output, not input
    input_leaves = pytree.arg_tree_leaves(*node.args, **node.kwargs)
    if any(check_skip_condition(leaf, is_output=False) for leaf in input_leaves):
        return True

    return check_skip_condition(node, is_output=True)
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
    """Register ``op`` so it is lowered through the generic fallback handler.

    Args:
        op: an ``OpOverloadPacket`` (every overload is registered) or a single
            ``OpOverload`` / ``HigherOrderOperator``.
        layout_constraint: optional constraint handed to
            ``add_layout_constraint`` for each registered overload.
        warn: when true and the ``CI`` environment variable is non-empty, an
            existing decomposition for ``op`` is an error.
        override_decomp: permit ``op`` to have both a decomposition and a
            fallback, and skip the CI check.

    Raises:
        AssertionError: ``op`` has a decomposition that is not explicitly
            permitted.
        RuntimeError: ``op`` is not a supported operator type.
    """
    # When emulate_precision_casts is enabled, we skip decomposing addcmul ops
    # to use the inductor lowering which preserves FMA semantics.
    # For _foreach_addcdiv, we use the native CUDA kernel.
    precision_exempt = config.emulate_precision_casts and op in {
        aten.addcmul,
        aten._foreach_addcmul.Scalar,
        aten._foreach_addcdiv.Scalar,
    }
    assert op not in decompositions or override_decomp or precision_exempt, (
        f"both a fallback and a decomp for same op: {op}"
    )

    if warn and bool(os.getenv("CI")) and get_decompositions([op]):
        # if fallback_random, we allow not decomposing random
        random_exempt = (
            config.fallback_random
            and op in torch._decomp.decompositions_for_rng.extra_random_decomps
        )
        if not random_exempt and not override_decomp:
            # Note: 'warn' is holdover from when this was a warning, but for ops that previously
            # set warn=False we do not want a CI error.
            # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
            # likely to be triggered preferentially on one CI config over another.
            if torch._dynamo.config.suppress_errors:
                torch._dynamo.config.suppress_errors = False
                log.warning(
                    "A make_fallback error occurred in suppress_errors config,"
                    " and suppress_errors is being disabled to surface it."
                )
            raise AssertionError(
                f"make_fallback({op}): a decomposition exists, we should switch to it."
                " To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
                " or inductor_decompositions, and delete the corresponding `make_fallback` line."
                " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
            )

    def _register_one(overload):
        add_needs_realized_inputs(overload)
        if layout_constraint is not None:
            add_layout_constraint(overload, layout_constraint)
        decorator = register_lowering(overload, type_promotion_kind=None)
        return decorator(fallback_handler(overload))

    if isinstance(op, torch._ops.OpOverloadPacket):
        # Lazily resolve each overload so lookup and registration interleave.
        targets = (getattr(op, overload_name) for overload_name in op.overloads())
    elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
        targets = (op,)
    else:
        raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")

    for target in targets:
        _register_one(target)
def philox_rand_offset(shape):
    """
    Return the product of ``shape`` as an int64 tensor.

    TorchInductor offset calculation differs from PyTorch eager offset
    calculation for random ops (tl.rand vs torch.rand). In future, we should
    strive for same impl for tl.rand and torch.rand.
    """
    element_count = functools.reduce(operator.mul, shape, 1)
    return tensor(element_count, dtype=torch.int64)
@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None)
def philox_rand(size, seed, offset, stride, device, dtype):
    """Lower ``rngprims.philox_rand`` to a pointwise ``ops.rand`` kernel.

    Returns a pair ``(random_values, offset)`` where the second element comes
    from ``philox_rand_offset(size)``. ``stride`` is optional and will be used
    in future for distributed random ops; currently it is unused.
    """
    contiguous_layout = ir.FixedLayout(
        device,
        dtype,
        size,
        ir.FlexibleLayout.contiguous_strides(size),
    )
    linear_position = contiguous_layout.make_indexer()
    load_seed = seed.make_loader()
    load_offset = offset.make_loader()

    def inner_fn(index):
        # Both seed and offset in the philox_rand op are tensors.
        # torch seed and offsets are of type int64, but tl.rand accepts int32
        seed_i32 = ops.to_dtype(load_seed([]), torch.int32)
        offset_i32 = ops.to_dtype(load_offset([]), torch.int32)

        # Shift the element's linear position by the offset.
        position_i32 = ops.index_expr(linear_position(index), torch.int32)
        shifted_position = ops.add(position_i32, offset_i32)

        sample = ops.rand(seed_i32, shifted_position)
        return ops.to_dtype(sample, dtype)

    random_values_node = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )
    return random_values_node, philox_rand_offset(size)
@register_lowering(aten.native_dropout, type_promotion_kind=None)
def native_dropout(x, p, train):
    """Lower ``aten.native_dropout`` via a fallback kernel under ``fallback_random``.

    Raises:
        AssertionError: when ``config.fallback_random`` is off, since the op is
            expected to have been rewritten before lowering.
    """
    if not config.fallback_random:
        raise AssertionError("should be handled in replace_random.py")

    fallback_result = ir.FallbackKernel.create(
        aten.native_dropout.default, x, p, train
    )
    return pytree.tree_map(TensorBox.create, fallback_result)
@register_lowering(aten.bernoulli_, type_promotion_kind=None)
def bernoulli_(x, *args):
    """Lower in-place ``aten.bernoulli_`` through ``ir.InplaceBernoulliFallback``.

    The ``float`` overload is selected when no extra argument is given or the
    first one is a Python float; otherwise the ``Tensor`` overload is used.
    ``x`` is realized first and then returned.
    """
    assert config.fallback_random or x.get_device() == torch.device("cpu"), (
        "this should be handled in decomps unless config.fallback_random or the device is CPU"
    )
    x.realize()
    if not args or isinstance(args[0], float):
        overload = aten.bernoulli_.float
    else:
        overload = aten.bernoulli_.Tensor
    # Constructed for its side effect; the result is not used here.
    ir.InplaceBernoulliFallback(overload, x, *args)
    return x
@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
def bernoulli_p(x, *args):
    """Lower out-of-place ``aten.bernoulli.p`` as ``bernoulli_`` applied to a clone."""
    assert config.fallback_random or x.get_device() == torch.device("cpu"), (
        "this should be handled in decomps unless config.fallback_random or the device is CPU"
    )
    fresh_copy = clone(x)
    return bernoulli_(fresh_copy, *args)
# This shouldn't be called in general
@register_lowering(aten._foobar)
def _foobar(_):
    """Placeholder lowering for ``aten._foobar``; always raises ``AssertionError``."""
    raise AssertionError()
@functools.lru_cache(maxsize=1)
def _warn_triton_random(salt):
    """Log the triton-random notice; the size-1 cache suppresses repeats for the same ``salt``."""
    log.info("using triton random, expect difference from eager")
def warn_triton_random():
    """Log the triton-random notice, keyed on the current graph's creation time."""
    # only warn once per graph
    graph_salt = V.graph.creation_time
    _warn_triton_random(graph_salt)
# Fallback handlers for the aten.rand / aten.randn overloads, built once at
# import time so the `rand` and `randn` lowerings can dispatch to them.
fallback_rand_default = fallback_handler(aten.rand.default)
fallback_rand_generator = fallback_handler(aten.rand.generator)
fallback_randn_default = fallback_handler(aten.randn.default)
fallback_randn_generator = fallback_handler(aten.randn.generator)
# randint has no dedicated lowering here; register every overload as a fallback.
make_fallback(aten.randint)

# Stream event ops are registered as fallbacks.
# TODO: mlazos reevaluate if we want to codegen something different
make_fallback(torch.ops.streams.record_event.default)
make_fallback(torch.ops.streams.wait_event.default)
@register_lowering(aten.rand)
def rand(*args, **kwargs):
    """Lower ``aten.rand`` to a fallback kernel.

    An explicit generator selects the ``generator`` overload. Otherwise the
    ``default`` overload is used, but only under ``config.fallback_random``.

    Raises:
        AssertionError: no generator was given and ``fallback_random`` is off.
    """
    generator = kwargs.get("generator")
    if generator is not None:
        return fallback_rand_generator(*args, **kwargs)

    if not config.fallback_random:
        raise AssertionError("should have been handled in replace_random.py")

    # Drop a possible ``generator=None`` entry before calling the default overload.
    kwargs.pop("generator", None)
    return fallback_rand_default(*args, **kwargs)
@register_lowering(aten.randn)
def randn(*args, **kwargs):
    """Lower ``aten.randn`` to a fallback kernel.

    An explicit generator selects the ``generator`` overload. Otherwise the
    ``default`` overload is used, but only under ``config.fallback_random``.

    Raises:
        AssertionError: no generator was given and ``fallback_random`` is off.
    """
    generator = kwargs.get("generator")
    if generator is not None:
        return fallback_randn_generator(*args, **kwargs)

    if not config.fallback_random:
        raise AssertionError("should have been handled in replace_random.py")

    # Drop a possible ``generator=None`` entry before calling the default overload.
    kwargs.pop("generator", None)
    return fallback_randn_default(*args, **kwargs)
@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None)
def inductor_force_stride_order(input_tensor, stride):
    """Require ``input_tensor`` to follow the stride order implied by ``stride``."""
    return ir.ExternKernel.require_stride_order(
        input_tensor, ir.get_stride_order(stride)
    )
@register_lowering(inductor_prims.seed, type_promotion_kind=None)
def inductor_seed(device: torch.device):
    """Unconditionally raise: this prim is expected to be rewritten before lowering."""
    message = "should be handled in fuse_seed_creation_pass()"
    raise AssertionError(message)
@register_lowering(inductor_prims.seeds, type_promotion_kind=None)
def inductor_seeds(count, device):
    """Create an ``ir.RandomSeeds`` buffer of ``count`` seeds on ``device``."""
    warn_triton_random()
    seeds_node = ir.RandomSeeds(count, decode_device(device))
    return TensorBox.create(seeds_node)
@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None)
def inductor_lookup_seed(seeds, index):
    """Produce a zero-dimensional pointwise node that loads seed ``index`` from ``seeds``."""

    def inner_fn(_):
        # The buffer name is looked up when the body runs, not when it is defined.
        buffer_name = seeds.get_name()
        return ops.load_seed(buffer_name, index)

    seeds_device = seeds.get_device()
    seeds_dtype = seeds.get_dtype()
    return Pointwise.create(
        device=seeds_device,
        dtype=seeds_dtype,
        inner_fn=inner_fn,
        ranges=[],
    )
@register_lowering(inductor_prims.random, type_promotion_kind=None)
def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0):
    """Lower ``inductor_prims.random`` to a realized float32 pointwise kernel.

    ``mode`` names the ``ops`` method to call (``"rand"`` or ``"randn"``). Each
    element is drawn from the loaded seed and its linear position in a
    contiguous layout of ``size`` starting at ``offset``.
    """
    assert not config.fallback_random
    assert mode in ("rand", "randn")
    size = list(size)
    dtype = torch.float32
    device = seed.get_device_or_error()
    contiguous_layout = ir.FixedLayout(
        device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
    )
    linear_position = contiguous_layout.make_indexer()
    load_seed = seed.make_loader()

    def inner_fn(index):
        sampler = getattr(ops, mode)
        seed_value = load_seed([])
        position = ops.index_expr(linear_position(index), torch.int32)
        return sampler(seed_value, position)

    result = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )
    result.realize()
    return result
@register_lowering(inductor_prims.randint, type_promotion_kind=None)
def inductor_randint(
    low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0
):
    """Lower ``inductor_prims.randint`` to an int64 pointwise ``ops.randint64`` kernel.

    Each element is drawn from the loaded seed, its linear position in a
    contiguous layout of ``size`` starting at ``offset``, and the ``low`` /
    ``high`` bounds.
    """
    assert not config.fallback_random
    size = list(size)
    dtype = torch.int64
    device = seed.get_device_or_error()
    contiguous_layout = ir.FixedLayout(
        device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
    )
    linear_position = contiguous_layout.make_indexer()
    load_seed = seed.make_loader()

    def inner_fn(index):
        seed_value = load_seed([])
        position = ops.index_expr(linear_position(index), torch.int32)
        low_bound = ops.index_expr(low, torch.int64)
        high_bound = ops.index_expr(high, torch.int64)
        return ops.randint64(seed_value, position, low_bound, high_bound)

    return Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )
def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
    """Describe a boundaries tensor for ``ops.bucketize``.

    Returns ``(name, last_dim_size, max_offset, last_dim_stride)``.

    ``max_offset`` is ``sum((size[i] - 1) * stride[i]) + stride[-1]``. This
    keeps the mask check in bucketize_binary_search correct for both
    contiguous and non-contiguous tensors.
    """
    sizes = tb.get_size()
    strides = tb.get_stride()
    last_stride = strides[-1]
    furthest_element = sum(
        (extent - 1) * step for extent, step in zip(sizes, strides)
    )
    max_offset = furthest_element + last_stride
    return (tb.get_name(), sizes[-1], max_offset, last_stride)
def _sorter_helper(tb: TensorBox) -> tuple[str, sympy.Expr]:
    """Return ``(name, last_dim_stride)`` for a sorter tensor."""
    sorter_name = tb.get_name()
    last_stride = tb.get_stride()[-1]
    return sorter_name, last_stride
@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None)
def searchsorted(
    sorted_sequence: TensorBox,
    self: TensorBox,
    *,
    out_int32: bool = False,
    right: bool = False,
    side: Optional[str] = None,
    sorter: Optional[TensorBox] = None,
) -> TensorBox:
    """Lower ``aten.searchsorted.Tensor`` to a realized pointwise ``ops.bucketize`` kernel.

    A fallback kernel is used instead when ``sorted_sequence``, ``self`` or
    ``sorter`` lacks ``BackendFeature.BUCKETIZE``. The result dtype is int32
    when ``out_int32`` is set and int64 otherwise.
    """

    def _supports_bucketize(tb):
        return V.graph.has_feature(tb, BackendFeature.BUCKETIZE)

    tensor_inputs = [sorted_sequence, self]
    if sorter is not None:
        tensor_inputs.append(sorter)
    if not all(_supports_bucketize(tb) for tb in tensor_inputs):
        fallback = fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)
        return fallback(
            sorted_sequence,
            self,
            out_int32=out_int32,
            right=right,
            side=side,
            sorter=sorter,
        )

    # If side is present, override the value of right if needed. This assumes that
    # validation of the two options being non-contradictory is already done by the
    # searchsorted meta-function.
    if side == "right":
        right = True

    index_dtype = torch.int32 if out_int32 else torch.int64
    values_loader = self.make_loader()

    # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to
    # realize it into global memory; or in other words, we can't guarantee that
    # sorted_sequence.get_name() (used below) will exist unless we call
    # sorted_sequence.realize().
    sorted_sequence.realize()
    if sorter is not None:
        sorter.realize()

    is_one_dimensional = len(sorted_sequence.get_size()) == 1

    def _sequence_start(tb, idx):
        # A 1-D sequence always starts at offset 0. Otherwise return the index
        # of the beginning of the relevant sequence within a flattened version
        # of the array.
        if is_one_dimensional:
            return 0
        strides = tb.get_stride()
        flat_offset = functools.reduce(
            operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1]))
        )
        return ops.index_expr(flat_offset, index_dtype)

    def inner_fn(idx):
        val = values_loader(idx)
        boundaries = _boundaries_helper(sorted_sequence)
        boundary_indices = _sequence_start(sorted_sequence, idx)
        if sorter is None:
            sorter_info = None
            sorter_indices = None
        else:
            sorter_info = _sorter_helper(sorter)
            sorter_indices = _sequence_start(sorter, idx)
        return ops.bucketize(
            val,
            boundaries,
            boundary_indices,
            index_dtype,
            right,
            sorter=sorter_info,
            sorter_indices=sorter_indices,
        )

    device = self.get_device()
    result = Pointwise.create(
        device=device,
        dtype=index_dtype,
        inner_fn=inner_fn,
        ranges=self.shape,
    )
    # see [NOTE: inductor bucketize realize]
    result.realize()

    return result
@register_lowering(
    aten.bucketize.Tensor, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
)
def bucketize(
    input: TensorBox,
    boundaries: TensorBox,
    *,
    out_int32: bool = False,
    right: bool = False,
):
    """Lower ``aten.bucketize.Tensor`` to a realized pointwise ``ops.bucketize`` kernel.

    ``boundaries`` must be one-dimensional. A fallback kernel is used when
    either tensor lacks ``BackendFeature.BUCKETIZE``. The result dtype is int32
    when ``out_int32`` is set and int64 otherwise.
    """
    assert len(boundaries.get_size()) == 1

    if not V.graph.has_feature(
        input, BackendFeature.BUCKETIZE
    ) or not V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE):
        fallback = fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)
        return fallback(input, boundaries, out_int32=out_int32, right=right)

    # The entire boundaries tensor needs to be used by ops.bucketize, so we
    # need to realize it into global memory; or in other words, we can't
    # guarantee that boundaries.get_name() (used below) will exist unless
    # we call boundaries.realize().
    boundaries.realize()
    device = input.get_device()
    load_input = input.make_loader()

    index_dtype = torch.int32 if out_int32 else torch.int64

    def inner_fn(index):
        value = load_input(index)
        return ops.bucketize(
            value,
            _boundaries_helper(boundaries),
            0,
            index_dtype,
            right,
        )

    result = Pointwise.create(
        device=device,
        dtype=index_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )

    # [NOTE: inductor bucketize realize]
    # bucketize_binary_search is relatively expensive, so we don't want to re-compute
    # it unnecessarily. If we run bucketize() and then broadcast the result, we don't
    # want this to be fused into a large number of duplicate bucketize() computations
    # for each of the elements in the result.
    #
    # If no broadcasting occurs, fusions can still occur in scheduler.py
    result.realize()

    return result
def require_dense(_, *args, **kwargs):
    """Layout constraint: apply ``ExternKernel.require_stride1`` to every IRNode argument."""
    constrained_args, constrained_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
    )
    return constrained_args, constrained_kwargs
def require_contiguous(_, *args, **kwargs):
    """Layout constraint: apply ``ExternKernel.require_contiguous`` to every IRNode argument."""
    constrained_args, constrained_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
    )
    return constrained_args, constrained_kwargs
def require_contiguous_strides(_, *args, **kwargs):
    """Layout constraint: apply ``ExternKernel.require_contiguous_strides`` to every IRNode argument."""
    # TODO: combine this with require_contiguous after
    # https://github.com/pytorch/pytorch/pull/148235 lands.
    constrained_args, constrained_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
    )
    return constrained_args, constrained_kwargs
def require_channels_last(_, *args, **kwargs):
    """Layout constraint: apply ``ExternKernel.require_channels_last`` to every IRNode argument."""
    constrained_args, constrained_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
    )
    return constrained_args, constrained_kwargs
def constrain_to_fake_tensor(arg, fake_arg):
    """Constrain ``arg`` to the exact strides of its fake counterpart ``fake_arg``.

    IRNodes get ``require_exact_strides``. Dicts, tuples and lists are
    processed element-wise against the matching entries of ``fake_arg``.
    Everything else, including any ``arg`` whose ``fake_arg`` is ``None`` or a
    ``FakeScriptObject``, is returned unchanged.
    """
    if fake_arg is None or isinstance(fake_arg, FakeScriptObject):
        return arg

    if isinstance(arg, ir.IRNode):
        return ir.ExternKernel.require_exact_strides(arg, fake_arg.stride())

    if isinstance(arg, dict):
        constrained = {}
        for key in arg:
            constrained[key] = constrain_to_fake_tensor(arg[key], fake_arg[key])
        return constrained

    if isinstance(arg, (tuple, list)):
        container = type(arg)
        return container(
            constrain_to_fake_tensor(item, fake_item)
            for item, fake_item in zip(arg, fake_arg)
        )

    return arg
def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
    """Apply ``constrain_to_fake_tensor`` to every positional and keyword argument.

    Returns the constrained ``(args, kwargs)`` pair, with ``args`` as a tuple.
    """
    constrained_args = tuple(
        constrain_to_fake_tensor(value, fake_value)
        for value, fake_value in zip(args, fake_args)
    )
    constrained_kwargs = {}
    for name, value in kwargs.items():
        constrained_kwargs[name] = constrain_to_fake_tensor(value, fake_kwargs[name])
    return constrained_args, constrained_kwargs
def constrain_to_fx_strides(fx_node, *args, **kwargs):
    """Layout constraint: match each IRNode argument to its fx argument's stride order.

    The stride order is taken from the ``meta["val"]`` of the corresponding
    entry in ``fx_node.args`` / ``fx_node.kwargs``. Dict arguments are handled
    per key and all other values pass through untouched.
    """

    def _match_fx_order(value, fx_value):
        if isinstance(value, ir.IRNode):
            fx_strides = fx_value.meta["val"].stride()
            order = ir.get_stride_order(fx_strides, V.graph.sizevars.shape_env)
            return ir.ExternKernel.require_stride_order(value, order)
        if isinstance(value, dict):
            return {
                key: _match_fx_order(value[key], fx_value[key]) for key in value
            }
        return value

    constrained_args = tuple(
        _match_fx_order(value, fx_value)
        for value, fx_value in zip(args, fx_node.args)
    )
    constrained_kwargs = {}
    for name, value in kwargs.items():
        constrained_kwargs[name] = _match_fx_order(value, fx_node.kwargs[name])
    return constrained_args, constrained_kwargs
def sdpa_constraint(fx_node, *args, **kwargs):
    """Layout constraint for the scaled-dot-product-attention fallbacks.

    Each IRNode argument is constrained against the strides recorded in the
    matching fx argument's ``meta["val"]``. Non-IRNode arguments are returned
    unchanged. Keyword arguments are processed with ``idx == -1``, so the
    position-specific rules below never apply to them.

    Returns:
        The constrained ``(args, kwargs)`` pair, with ``args`` as a tuple.
    """

    # sdpa requires dense last dimension
    def apply_constraint(idx, arg, fx_arg):
        if not isinstance(arg, ir.IRNode):
            return arg

        meta_val = fx_arg.meta["val"]
        meta_stride_expr = [
            s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride()
        ]
        shape_env = V.graph.sizevars.shape_env
        stride_order = ir.get_stride_order(meta_val.stride(), shape_env)

        if stride_order and stride_order[-1] != 0:
            # contiguous stride order
            stride_order = list(reversed(range(len(arg.get_size()))))

        if (
            fx_node.target
            == aten._scaled_dot_product_efficient_attention_backward.default
            and idx in (0, 5)
        ):
            assert len(stride_order) == 4
            # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default
            # are for out and gradient_out. They have to be in
            # (3, 1, 2, 0) stride order. Otherwise the kernel will crash.
            # Check https://github.com/pytorch/pytorch/issues/138772
            stride_order = (3, 1, 2, 0)

        if not meta_val.is_cuda:
            return ir.ExternKernel.require_stride_order(arg, stride_order)

        # This is the minimum alignment required by SDPA kernels for attention_bias.
        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
        ALIGNMENT = 8

        # effn_attn_fwd requires a dense last dim, not just alignment
        effn_attn_fwd_bias = (
            fx_node.target
            == torch.ops.aten._scaled_dot_product_efficient_attention.default
            and idx == 3
        )

        assert isinstance(arg, TensorBox)
        if len(arg.get_size()) not in (3, 4):
            return arg

        # Already-aligned tensors only need their insignificant strides matched.
        # Every later step runs only for tensors that are not aligned.
        if ir.is_aligned_realized_tensor(arg, ALIGNMENT):
            return ir.try_match_insignificant_strides(
                ir.ExternKernel.realize_input(arg), meta_stride_expr
            )

        if effn_attn_fwd_bias:
            out_size = list(arg.get_size())

            expanded_dims = []
            # We require a dense last dimension, but the other strides
            # can be expanded, which results in a smaller tensor
            maybe_stride = arg.maybe_get_stride()
            for i in range(len(arg.get_size()) - 1):
                if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
                    maybe_stride is not None
                    and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
                ):
                    expanded_dims.append(i)

            # Now, pad strides to alignment
            out_strides = [-1] * len(out_size)
            out_strides[-1] = 1
            stride = 1
            for i in range(len(out_size) - 2, -1, -1):
                if out_strides[i + 1] != 0:
                    stride = stride * out_size[i + 1]

                # the expanded dims still need to be aligned, if they are,
                # we can make them expanded by setting the stride equal to 0
                if i in expanded_dims:
                    if V.graph.sizevars.statically_known_equals(
                        out_strides[i + 1] % ALIGNMENT, 0
                    ):
                        out_strides[i] = 0
                        continue

                if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0):
                    stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT

                out_strides[i] = stride

            return ir.ExternKernel.require_exact_strides(arg, out_strides)

        def is_aligned(x):
            return V.graph.sizevars.guard_or_false(
                sympy.Eq(Mod(x.get_size()[-1], ALIGNMENT), 0)
            )

        # A view whose own last dimension is unaligned but whose unwrapped
        # view is aligned is realized and stride-matched, not reordered.
        if isinstance(arg.data, ir.BaseView):
            if not is_aligned(arg):
                if is_aligned(arg.unwrap_view()):
                    return ir.try_match_insignificant_strides(
                        ir.ExternKernel.realize_input(arg), meta_stride_expr
                    )

        return ir.ExternKernel.require_stride_order(arg, stride_order)

    args = tuple(
        apply_constraint(idx, arg, fx_arg)
        for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args))
    )
    kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()}
    return args, kwargs
# Fallback registrations. Each `make_fallback` call routes the op through the
# generic fallback handler. A second positional argument is a layout
# constraint applied to the op's inputs, and `warn=False` opts the op out of
# the CI-only "a decomposition exists" check.

# WIP
make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
make_fallback(aten.adaptive_max_pool3d)  # @isuruf
make_fallback(aten._scaled_dot_product_attention_math_for_mps)  # @malfet

# 1) Easy
make_fallback(aten.uniform, warn=False)
make_fallback(aten.exponential.default, warn=False)  # (fails accuracy on test_torch.py)
make_fallback(aten._pdist_forward, require_contiguous)  # Has decomp. Needs benchmarks
make_fallback(aten.soft_margin_loss_backward, warn=False)  # py_impl?
make_fallback(aten._fused_rms_norm, warn=False)  # (MPS-only and faster than decomp)
# Registered only when an XPU device is available.
if torch.xpu.is_available():
    make_fallback(
        aten.embedding_dense_backward, warn=False
    )  # (XPU-only and faster than decomp)
# Registered only when torch was compiled with MTIA support.
if torch.mtia._is_compiled():
    make_fallback(
        aten.native_layer_norm, warn=False
    )  # (MTIA-only and faster than decomp)

# 1.5) Easy or Impossible
make_fallback(aten._cdist_forward)  # p=2 should be feasible
make_fallback(aten._cdist_backward)

# 2) Medium
make_fallback(aten._trilinear)
# 3) Difficult
# Scans
# See the discussion at
# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19
make_fallback(aten.segment_reduce.default)
make_fallback(aten._segment_reduce_backward.default)

# Histogram (need to implement Histogram IR)
make_fallback(aten.histc)
make_fallback(aten.histogram.bin_ct)
make_fallback(aten._histogramdd_bin_edges.default)
make_fallback(aten._histogramdd_from_bin_cts.default)

# Need templated kernel
make_fallback(aten.addbmm)
make_fallback(aten._addmm_activation, warn=False)

make_fallback(aten._grouped_mm, require_dense)

# Need templated kernel. Probably impossible to write efficiently
make_fallback(aten.convolution_backward, constrain_to_fx_strides)
make_fallback(aten._cudnn_rnn, require_dense)
make_fallback(aten._cudnn_rnn_backward, require_contiguous)
make_fallback(aten.miopen_rnn, require_dense)
make_fallback(aten.miopen_rnn_backward, require_contiguous)

# Haven't checked but sound difficult / impossible
make_fallback(aten._embedding_bag, require_contiguous)
make_fallback(aten._embedding_bag_forward_only, require_contiguous)
make_fallback(aten._embedding_bag_backward)
# Registered once; an identical second registration was redundant.
make_fallback(aten._embedding_bag_per_sample_weights_backward)
make_fallback(aten._fused_moving_avg_obs_fq_helper)
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
# Backward ops registered as fallbacks. A second argument, where present, is
# the layout constraint applied to the op's inputs.
make_fallback(aten.max_pool3d_with_indices_backward)
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
make_fallback(aten._adaptive_avg_pool3d_backward)
make_fallback(aten.adaptive_max_pool2d_backward)
make_fallback(aten.adaptive_max_pool3d_backward)
make_fallback(aten.fractional_max_pool2d_backward)
make_fallback(aten.fractional_max_pool3d_backward)
make_fallback(aten.replication_pad1d_backward)
make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten.upsample_trilinear3d_backward)
make_fallback(aten.grid_sampler_2d_backward)
make_fallback(aten._pdist_backward, require_contiguous)
# 5) Impossible (missing triton/CPU features)

# Sorting / Sorting-like
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten.kthvalue)
make_fallback(aten.topk)
make_fallback(aten.mode)
make_fallback(aten.median)
make_fallback(aten.nanmedian)
make_fallback(aten.randperm)
# see: https://github.com/pytorch/pytorch/pull/121354
make_fallback(aten.resize_)
make_fallback(aten.resize_as_)

# Linalg
make_fallback(aten._linalg_det)
make_fallback(aten.linalg_householder_product)
make_fallback(aten.linalg_inv_ex)
make_fallback(aten.linalg_ldl_factor_ex)
make_fallback(aten.linalg_ldl_solve)
make_fallback(aten.linalg_lu)
make_fallback(aten.linalg_lu_factor_ex)
make_fallback(aten.linalg_lu_solve)
make_fallback(aten.linalg_matrix_exp)
make_fallback(aten.linalg_qr)
make_fallback(aten._linalg_slogdet)
make_fallback(aten._linalg_solve_ex)
make_fallback(aten.linalg_solve_triangular)
make_fallback(aten._linalg_svd)
make_fallback(aten.lu_unpack)
make_fallback(aten.ormqr)
make_fallback(aten._linalg_check_errors)
make_fallback(aten.linalg_pinv.atol_rtol_tensor)
make_fallback(aten._linalg_eigh)
make_fallback(aten.triangular_solve)
make_fallback(aten.linalg_cholesky_ex)
make_fallback(aten.cholesky_inverse)
make_fallback(aten.cholesky_solve)
make_fallback(aten.geqrf)
make_fallback(aten._fft_r2c)  # needs complex as well

# Data dependent (are these necessary?)
make_fallback(aten.nonzero.default)

# Misc
make_fallback(aten.gcd.default, warn=False)
make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
make_fallback(torch._prims.rng_prims.run_with_rng_state)
make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)

# Implemented / Half implemented
# Scans. Implemented for CUDA, missing CPU
make_fallback(aten.masked_scatter)
make_fallback(aten.masked_scatter_backward)

# Complex number support
make_fallback(aten.view_as_complex, require_contiguous)
make_fallback(aten.angle)  # needs complex

# Needs efficientzerotensor
make_fallback(aten._efficientzerotensor)

# Needs Sparse
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
make_fallback(aten.to_sparse)
make_fallback(aten._to_sparse)

# Needs dimname support
make_fallback(aten.zeros.names)
# 6) Pattern-matched
# Attention kernels. Most are registered with `sdpa_constraint` as their
# layout constraint; the `.quantized` overloads are registered without one.
make_fallback(
    aten._scaled_dot_product_efficient_attention.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_efficient_attention_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention.quantized,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_cudnn_attention.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_cudnn_attention_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention_for_cpu.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention_for_cpu_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_fused_attention_overrideable.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_fused_attention_overrideable_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
make_fallback(aten._flash_attention_forward.quantized)
make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)

# index_reduce requires fallback when use_scatter_fallback(...) returns True
make_fallback(aten.index_reduce)
# override_decomp=True permits this fallback even though a decomposition exists.
make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)

make_fallback(aten._weight_norm_interface_backward.default, require_contiguous)
# Register with type_promotion_kind None.
# For example, fp16.copy_(fp32) should **not** promote the first input's dtype.
@register_lowering(aten.copy, type_promotion_kind=None)
def copy(self, src, non_blocking=False):
    """Lower ``aten.copy``: return a clone of ``src`` converted to match ``self``.

    A non-IRNode ``src`` is first turned into a tensor with ``self``'s dtype
    and device. The value is then moved to ``self``'s device, cast to its
    dtype and expanded to its size, each only when it differs.
    ``non_blocking`` is accepted but not used.
    """
    if not isinstance(src, ir.IRNode):
        src = tensor(src, dtype=self.get_dtype(), device=self.get_device())

    converted = src
    if self.get_device() != src.get_device():
        converted = to_device(converted, self.get_device())
    if self.get_dtype() != src.get_dtype():
        converted = to_dtype(converted, self.get_dtype())
    if self.get_size() != src.get_size():
        converted = expand(converted, self.get_size())

    return clone(converted)
@register_lowering(aten.clone)
def clone(x, *, memory_format=None):
    """Lower ``aten.clone`` to a pointwise copy with ``x``'s device, dtype and size.

    ``memory_format`` is currently ignored.
    """
    # TODO(jansel): memory format
    source_device = x.get_device()
    source_dtype = x.get_dtype()
    load_source = x.make_loader()
    source_size = list(x.get_size())
    return Pointwise.create(
        device=source_device,
        dtype=source_dtype,
        inner_fn=load_source,
        ranges=source_size,
    )
def clone_preserve_reinterpret_view(x):
    """Clone ``x``, re-applying any chain of ``ReinterpretView`` layouts on top.

    When ``x`` is a ``TensorBox`` over one or more nested ``ReinterpretView``
    nodes, the views are peeled off, the innermost data is cloned, and the
    recorded layouts are re-applied innermost-first. Otherwise this is a plain
    ``clone``.
    """
    peeled_layouts = []
    if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView):
        x = x.data  # unwrap TensorBox
        # pyrefly: ignore [bad-assignment]
        while isinstance(x, ir.ReinterpretView):
            peeled_layouts.append(x.get_layout())
            x = x.data
        x = TensorBox(x)

    x = clone(x)

    if not peeled_layouts:
        return x

    rebuilt = x.data  # unwrap TensorBox
    # Layouts were recorded outermost-first; rebuild from the inside out.
    for layout in reversed(peeled_layouts):
        rebuilt = ir.ReinterpretView(data=rebuilt, layout=layout)
    return TensorBox(rebuilt)
# Reuse the `clone` lowering for aten.lift_fresh_copy when this torch build
# defines that op.
if hasattr(aten, "lift_fresh_copy"):
    register_lowering(aten.lift_fresh_copy)(clone)
@register_lowering(prims.iota)
def iota(
    length,
    *,
    start,
    step,
    dtype,
    device,
    requires_grad,
):
    """Lower ``prims.iota`` to a 1-D pointwise kernel yielding ``start + step * i``.

    ``requires_grad`` is accepted but not used.
    """

    def fn(index):
        position = index[0]
        return ops.index_expr(step * position + start, dtype=dtype)

    target_device = decode_device(device)
    return Pointwise.create(
        device=target_device,
        dtype=dtype,
        inner_fn=fn,
        ranges=[length],
    )
@register_lowering(aten.select_scatter, type_promotion_kind=None)
def select_scatter(x, src, dim: int, index: int):
    """Lower ``aten.select_scatter``: ``x`` with ``src`` written at ``index`` along ``dim``.

    A provably negative ``index`` is wrapped by the size of ``dim``. If its
    sign cannot be decided, the op falls back to
    ``aten.select_scatter.default``.
    """
    src = to_dtype(src, x.get_dtype())
    load_x = x.make_loader()
    dim = _validate_dim(x, dim, 0)

    sizevars = V.graph.sizevars
    if sizevars.guard_or_false(sympy.Lt(index, 0)):
        index = index + x.get_size()[dim]
    elif not sizevars.guard_or_false(sympy.Ge(index, 0)):
        # unbacked index
        return fallback_handler(aten.select_scatter.default)(x, src, dim, index)

    V.graph.sizevars.check_leq(0, index)  # type: ignore[arg-type]
    V.graph.sizevars.check_lt(index, x.get_size()[dim])  # type: ignore[arg-type]

    # Broadcast src over the scattered dimension so both loaders share indices.
    src = expand(unsqueeze(src, dim), x.get_size())
    load_src = src.make_loader()

    def inner_fn(idx):
        at_target = ops.eq(
            ops.index_expr(idx[dim], torch.int32),
            ops.index_expr(index, torch.int32),
        )
        return ops.where(at_target, load_src(idx), load_x(idx))

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(x.get_size()),
    )
@register_lowering(aten.slice_scatter, type_promotion_kind=None)
def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
    """Lower ``aten.slice_scatter``: embed ``src`` into the slice ``start:end:step`` of ``x``.

    ``start``/``end`` are normalized through ``ir.SliceView.normalize_start_end``.
    The generated pointwise kernel reads ``src`` for positions along ``dim``
    that fall inside the slice and ``x`` for every other position.
    """
    src = to_dtype(src, x.get_dtype())
    x_loader = x.make_loader()
    dim = _validate_dim(x, dim, 0)
    dim_size = x.get_size()[dim]

    # pyrefly: ignore [bad-argument-type]
    start, end = ir.SliceView.normalize_start_end(x, dim, start, end)

    # Expected src shape: x's shape with ceil((end - start) / step) along dim.
    expected_src_size = list(x.get_size())
    expected_src_size[dim] = FloorDiv(end - start + (step - 1), step)
    src = expand(src, expected_src_size)
    src_loader = src.make_loader()

    def inner_fn(idx):
        covers_whole_dim = start == 0 and end == dim_size and step == 1
        if covers_whole_dim:
            # selecting every element is the same as just src.clone()
            return src_loader(idx)

        position = ops.index_expr(idx[dim], torch.int64)
        src_idx = list(idx)
        src_idx[dim] = FloorDiv(idx[dim] - start, step)

        # Collect only the bounds/stride conditions that are not trivially true.
        conditions = []
        if start != 0:
            lower = ops.index_expr(sympy.expand(start), torch.int64)
            conditions.append(ops.ge(position, lower))
        if end != dim_size:
            upper = ops.index_expr(sympy.expand(end), torch.int64)
            conditions.append(ops.lt(position, upper))
        if step != 1:
            remainder = ops.index_expr(
                ModularIndexing(idx[dim] - start, 1, step), torch.int64
            )
            conditions.append(ops.eq(remainder, ops.constant(0, torch.int64)))
        assert conditions

        in_slice = functools.reduce(ops.and_, conditions)
        fill = 0 if is_integer_type(x) else 0.0
        src_val = ops.masked(in_slice, lambda: src_loader(src_idx), fill)
        return ops.where(in_slice, src_val, x_loader(idx))

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(x.get_size()),
    )
- def _unwrap(x):
- if isinstance(x, (list, tuple)) and len(x) > 0:
- return _unwrap(x[0])
- return x
@register_lowering([torch.tensor, aten.scalar_tensor])
def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
    """Lower ``torch.tensor`` / ``aten.scalar_tensor`` for constant data.

    * sympy expressions and Python scalars become 0-d pointwise constants;
    * an empty sequence, or a sequence of at most 8 entries whose first entry
      is a Python number, is inlined as a tree of ``ops.where`` selections;
    * anything else is materialized with ``torch.tensor`` and registered as a
      graph constant.

    When ``dtype`` is not given it defaults to ``torch.int64`` if the first
    leaf of ``data`` is an ``int`` and to the default dtype otherwise.
    """
    assert_nyi(layout in (None, torch.strided), f"layout={layout}")
    assert_nyi(not pin_memory, "pin_memory")
    dtype = dtype or (
        torch.int64 if isinstance(_unwrap(data), int) else torch.get_default_dtype()
    )

    ranges: list[sympy.Expr] = []

    if isinstance(data, sympy.Basic):

        def inner_fn(index):
            return ops.index_expr(data, dtype)

    elif isinstance(data, (float, int)):

        def inner_fn(index):
            return ops.constant(data, dtype)

    elif len(data) == 0 or (isinstance(data[0], (float, int)) and len(data) <= 8):
        # inline small tensors
        ranges.append(sympy.Integer(len(data)))

        def inner_fn(index):
            def select(lo, hi):
                # Pick data[index[0]] by halving the half-open range [lo, hi).
                assert lo < hi
                if hi - lo == 1:
                    return ops.constant(data[lo], dtype)
                mid = (hi - lo) // 2 + lo
                in_lower_half = ops.lt(
                    ops.index_expr(index[0], torch.int64),
                    ops.constant(mid, torch.int64),
                )
                return ops.where(in_lower_half, select(lo, mid), select(mid, hi))

            if len(data) == 0:
                return ops.constant(0, dtype)
            return select(0, len(data))

    else:
        materialized = torch.tensor(data, dtype=dtype, device=device)
        return V.graph.add_tensor_constant(materialized)

    return Pointwise.create(
        device=decode_device(device),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=ranges,
    )
@register_lowering(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
    """Lower ``torch.as_tensor``.

    A ``TensorBox`` input is converted to ``dtype`` and then moved to
    ``device`` when those are given; any other input is handed to ``tensor``.
    """
    if not isinstance(data, TensorBox):
        return tensor(data, dtype=dtype, device=device)

    converted = data
    if dtype is not None:
        converted = to_dtype(converted, dtype)
    if device is not None:
        converted = to_device(converted, device)
    return converted
@register_lowering(torch.LongTensor)
def long_tensor(data):
    """Lower ``torch.LongTensor(data)`` as ``tensor`` with dtype ``torch.int64``."""
    int64_dtype = torch.int64
    return tensor(data, dtype=int64_dtype)
@register_lowering(aten._local_scalar_dense)
def _local_scalar_dense(data):
    """Lower ``aten._local_scalar_dense`` (``.item()``) to a sympy expression.

    Registers an ``ir.DynamicScalar`` node that binds the unbacked symbol for
    this call and returns the expression stored in the current node's
    ``meta["val"]``.
    """
    # This is interesting!  Most lowerings return tensors, so you can just
    # return the buffer you allocated and it will get used (or not used, if
    # it's dead.)  But _local_scalar_dense (aka item) returns an int,
    # not a Tensor, so you would have a type mismatch if you return a buffer;
    # we are obligated to return a sympy expression instead.  However,
    # we need to actually codegen the .item() call somehow.  We do this
    # by registering a faux buffer for the DynamicScalar IR node, which is
    # solely responsible for generating this .item().  The buffer is
    # not used for anything (notice we discard it); at codegen time,
    # the "buffer" just gets assigned None.
    current_node = V.graph.current_node
    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, current_node.meta["unbacked_bindings"]
    )
    assert unbacked_bindings is not None
    assert len(unbacked_bindings) == 1, unbacked_bindings
    # NB: Have to be very careful here.  V.graph.current_node.meta["val"]
    # seemingly also contains a symbol which you want to do binding for,
    # but it actually isn't.  In particular, if we have later performed
    # a deferred runtime assert saying that u0 == s0, you will actually
    # see s0 from expr!  This is bad because we need to actually generate
    # the assert that says u0 == s0, so we need to know where to get u0
    # from (this call).  In particular, we must use unbacked_bindings, which
    # is guaranteed to have the original, unreplaced symbol in question.
    #
    # NB2: Another thing we have to be very careful about are symbol bindings
    # that require nontrivial refinement, e.g., when you have a binding site
    # x: Sym(u0 * 4) = y.item().  Here, the code generation must do a division
    # in order to appropriately bind u0.  This is communicated via the keypath
    # in unbacked_bindings, and we need to hold onto it in order to generate
    # code appropriately for this case.
    binding_sym, keypath = next(iter(unbacked_bindings.items()))
    scalar_node = ir.DynamicScalar(binding_sym, keypath, data)
    scalar_node.name = V.graph.register_buffer(scalar_node)
    V.graph.register_operation(scalar_node)

    # NB: the replaced expr is OK to use directly downstream, we want
    # simplifications in this case!
    val = V.graph.current_node.meta["val"]
    is_symbolic = isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool))
    return val.node.expr if is_symbolic else sympy.sympify(val)
@register_lowering(aten._assert_scalar)
def _assert_scalar(data, msg):
    """Lower ``aten._assert_scalar`` to nothing; both arguments are ignored here."""
    # NB: These will be handled at codegen time
    # Not sure if we are guaranteed to be able to serve out truth from the
    # deferred_runtime_asserts, TODO: try this assert out
    # See [NOTE] Codegen runtime asserts in Inductor
    # assert bool(data.scalar), data
    return None
@register_lowering(aten._assert_tensor_metadata)
def _assert_tensor_metadata(
    a, size=None, stride=None, dtype=None, *, device=None, layout=None
):
    """Lower ``aten._assert_tensor_metadata`` to nothing; no check is emitted here."""
    return None
def _full(fill_value, device, dtype, size):
    """Create a pointwise node of shape ``size`` filled with ``fill_value``.

    ``fill_value`` may be a Python number, a sympy expression, an object
    exposing the fill through a ``value`` attribute, or a 0-dimensional IR
    node whose single element is loaded for every output position.
    """
    value = fill_value
    if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
        value = value.value

    if isinstance(value, (int, float)):

        def fill_fn(index):
            return ops.constant(value, dtype)

    elif isinstance(value, sympy.Basic):

        def fill_fn(index):
            return ops.index_expr(value, dtype)

    else:
        # Only 0-d IR nodes are supported as a dynamic fill value.
        assert len(value.get_size()) == 0
        value_loader = value.make_loader()

        def fill_fn(index):
            return value_loader([])

    return Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=fill_fn,
        ranges=list(size),
    )
def full_like(x, fill_value, **kwargs):
    """Create a tensor shaped like ``x`` and filled with ``fill_value``.

    ``kwargs`` are forwarded to the ``*_like`` shim built by
    ``create_tensor_like``.
    """
    like_constructor = create_tensor_like(tensor_constructor(fill_value))
    return like_constructor(x, **kwargs)
def tensor_constructor(fill_value):
    """Return a ``torch.zeros``/``torch.ones``-style constructor for ``fill_value``.

    The returned callable accepts the size either as varargs or as a single
    list/tuple/``torch.Size`` and produces a ``_full`` node.  Named tensors,
    non-strided layouts and pinned memory are rejected through ``assert_nyi``.
    """

    # torch.zeros, torch.ones, etc
    def inner(
        *size,
        names=None,
        dtype=None,
        device=None,
        layout=None,
        pin_memory=False,
        memory_format=None,
    ):
        assert_nyi(names is None, "named tensors")
        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
        assert_nyi(not pin_memory, "pin_memory")
        device = decode_device(device)
        dtype = dtype or torch.get_default_dtype()

        packed_as_sequence = len(size) == 1 and isinstance(
            size[0], (list, tuple, torch.Size)
        )
        if packed_as_sequence:
            size = tuple(size[0])

        # See https://github.com/pytorch/pytorch/issues/118102
        # All sizes at lowering time should be sympy.Symbol, not SymInt!
        assert not any(isinstance(s, torch.SymInt) for s in size)
        expanded_size = [sympy.expand(s) for s in size]
        return _full(fill_value, device, dtype, expanded_size)

    return inner
@register_lowering([torch.empty, aten.empty])
def empty(
    *size,
    names=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Lower ``torch.empty`` / ``aten.empty`` via ``empty_strided``.

    The size may be given as varargs or as one list/tuple/``torch.Size``.
    No stride is passed on; ``memory_format`` is accepted but unused.
    """
    assert_nyi(names is None, "named tensors")
    device = decode_device(device)

    packed_as_sequence = len(size) == 1 and isinstance(
        size[0], (list, tuple, torch.Size)
    )
    if packed_as_sequence:
        size = tuple(size[0])

    return empty_strided(
        size,
        None,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
    )
def create_tensor_like(creation_fn):
    """
    Shim to convert X_like(...) into X(...).  For example zeros_like() into zeros().

    The returned function takes a reference tensor ``x`` and calls
    ``creation_fn`` with ``x``'s size, falling back to ``x``'s dtype and device
    when they are not given explicitly.  ``memory_format`` is accepted but not
    forwarded.
    """

    def _constant_like(
        x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None
    ):
        assert_nyi(not pin_memory, "pin_memory")
        assert_nyi(layout in (None, torch.strided), f"layout={layout}")

        resolved_dtype = x.get_dtype() if dtype is None else decode_dtype(dtype)
        resolved_device = device or x.get_device()
        like_size = list(x.get_size())
        return creation_fn(
            like_size,
            dtype=resolved_dtype,
            device=resolved_device,
            layout=layout,
            pin_memory=pin_memory,
        )

    return _constant_like
def constant_like(fill_value):
    """Return an ``X_like``-style constructor that fills with ``fill_value``."""
    constructor = tensor_constructor(fill_value)
    return create_tensor_like(constructor)
# *_like constructors built from the shims above; only empty_like is
# registered as a lowering on this line.
empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
ones_like = create_tensor_like(tensor_constructor(1))
zeros_like = create_tensor_like(tensor_constructor(0))
def new_constant(fill_value):
    """Return a ``new_*``-style constructor that fills with ``fill_value``.

    The returned function builds a tensor of the requested ``size`` and takes
    dtype/device from the reference tensor ``x`` when they are not supplied.
    """

    def _new_constant(
        x, size, *, dtype=None, layout=None, device=None, pin_memory=None
    ):
        assert isinstance(size, (list, tuple))
        assert_nyi(not pin_memory, "pin_memory")
        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
        # pyrefly: ignore [bad-argument-type]
        resolved_dtype = decode_dtype(dtype) or x.get_dtype()
        resolved_device = device or x.get_device()
        int_size = [sympy.Integer(s) for s in size]
        return _full(
            fill_value, decode_device(resolved_device), resolved_dtype, int_size
        )

    return _new_constant
@register_lowering(aten.new_empty)
def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Lower ``aten.new_empty``: an ``empty_strided`` tensor without an explicit stride.

    dtype and device default to those of the reference tensor ``x``.
    """
    resolved_dtype = x.get_dtype() if dtype is None else dtype
    resolved_device = x.get_device() if device is None else device
    return empty_strided(
        size,
        None,
        dtype=resolved_dtype,
        layout=layout,
        device=decode_device(resolved_device),
        pin_memory=pin_memory,
    )
@register_lowering(aten.empty_strided)
def empty_strided(
    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Lower ``aten.empty_strided`` to a realized buffer that computes nothing.

    Args:
        size: list/tuple of sizes for the result.
        stride: list/tuple of strides, or ``None``/empty for contiguous strides.
        dtype: result dtype; the default dtype is used when not given.
        layout: only ``None`` or ``torch.strided`` is supported.
        device: result device; defaults to the device of a fresh CPU-default
            ``torch.tensor(0.0)``.
        pin_memory: recorded on the fixed layout as ``is_pinned``.

    Returns:
        The realized pointwise ``TensorBox`` whose buffer carries a
        ``FixedLayout`` with the requested size and stride.
    """
    assert isinstance(size, (list, tuple))
    assert isinstance(stride, (list, tuple, type(None)))
    assert_nyi(layout in (None, torch.strided), f"layout={layout}")
    # pyrefly: ignore [bad-argument-type]
    dtype = decode_dtype(dtype) or torch.get_default_dtype()
    device = device or torch.tensor(0.0).device
    device = decode_device(device)

    pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
    pointwise.realize()
    buffer = pointwise.data.data
    # Check the buffer type before touching its fields, so an unexpected node
    # is reported without having been partially modified.
    assert isinstance(buffer, ir.ComputedBuffer)
    # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
    buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size))

    size = [sympy.expand(s) for s in size]
    stride = (
        [sympy.expand(s) for s in stride]
        if stride
        else ir.FlexibleLayout.contiguous_strides(size)
    )
    buffer.layout = ir.FixedLayout(
        device=device,
        dtype=dtype,
        size=size,
        stride=stride,
        is_pinned=pin_memory or False,
    )
    return pointwise
@register_lowering(aten.new_empty_strided)
def new_empty_strided(
    x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Lower ``aten.new_empty_strided`` via ``empty_strided``.

    dtype and device default to those of the reference tensor ``x``.
    """
    resolved_dtype = x.get_dtype() if dtype is None else dtype
    resolved_device = x.get_device() if device is None else device
    return empty_strided(
        size,
        stride,
        dtype=resolved_dtype,
        layout=layout,
        device=decode_device(resolved_device),
        pin_memory=pin_memory,
    )
@register_lowering(prims.copy_strided.default)
def copy_strided(x, stride):
    """Lower ``prims.copy_strided`` by requiring a stride order on ``x``.

    The strides are resolved to concrete hints and the dimensions are ranked
    from smallest to largest stride; that ranking is passed to
    ``ir.ExternKernel.require_stride_order``.
    """
    sizevars = V.graph.sizevars
    stride_hints = [sizevars.size_hint_or_throw(s) for s in stride]
    stride_order = sorted(range(len(stride_hints)), key=lambda d: stride_hints[d])
    return ir.ExternKernel.require_stride_order(x, stride_order)
@register_lowering([torch.full, aten.full])
def full(size, fill_value, **kwargs):
    """Lower ``torch.full`` / ``aten.full``; ``dtype`` must be present in ``kwargs``."""
    assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition"
    constructor = tensor_constructor(fill_value)
    return constructor(size, **kwargs)
@register_lowering(aten.gather, type_promotion_kind=None)
def gather(x, dim, index, sparse_grad=False):
    """Lower ``aten.gather`` to a pointwise kernel over ``index``'s shape.

    Each output element loads ``x`` at the output coordinate with the entry
    along ``dim`` replaced by the value read from ``index``.  A 0-d ``x`` is
    treated as a 1-element tensor.  An empty ``index`` yields an empty tensor
    of ``index``'s shape.
    """
    # sparse_grad doesn't affect forward computation,
    # and backward tracing is taken care of by AOT Autograd
    assert isinstance(x, TensorBox)
    if index.get_numel() == 0:
        # Empty index case. Return an empty array with the same shape
        return new_empty(x, index.get_size())

    size = x.get_size()
    is_scalar = len(size) == 0
    dim = _validate_dim(x, dim, is_scalar)

    if is_scalar:
        x = expand(x, [1])
        size = [1]

    x_loader = x.make_loader()
    index_loader = index.make_loader()

    def load_gathered(idx):
        coords = list(idx)
        gathered = ops.indirect_indexing(index_loader(coords), size[dim])
        if len(coords) == 0:
            coords = [gathered]
        else:
            coords[dim] = gathered
        return x_loader(coords)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=load_gathered,
        ranges=index.get_size(),
    )
@register_lowering(aten.embedding, type_promotion_kind=None)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Lower ``aten.embedding`` to an indirect row lookup into ``weight``.

    The output shape is ``indices``' shape followed by ``weight``'s trailing
    dimensions.  ``sparse=True`` is delegated to the ATen fallback.
    """
    if sparse:
        aten_embedding = fallback_handler(aten.embedding.default)
        return aten_embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)

    assert not sparse
    assert isinstance(weight, TensorBox)
    assert isinstance(indices, TensorBox)
    assert "int" in str(indices.get_dtype())

    weight_loader = weight.make_loader()
    indices_loader = indices.make_loader()
    indices_ndim = len(indices.get_size())
    weight_size = weight.get_size()
    new_size = [*indices.get_size(), *weight_size[1:]]

    def lookup(idx):
        assert len(idx) == len(new_size), f"{idx} != {new_size}"
        # Leading coordinates address `indices`; the rest address a weight row.
        row_value = indices_loader(idx[:indices_ndim])
        row = ops.indirect_indexing(row_value, weight_size[0])
        return weight_loader([row, *idx[indices_ndim:]])

    return Pointwise.create(
        device=weight.get_device(),
        dtype=weight.get_dtype(),
        inner_fn=lookup,
        ranges=new_size,
    )
def check_and_broadcast_indices(indices, device):
    """Validate advanced-indexing tensors and broadcast them against each other.

    Returns ``(new_indices, valid_idxs)`` where ``new_indices`` has the same
    length as ``indices`` with every ``TensorBox`` replaced by its broadcast
    version (other slots are ``None``), and ``valid_idxs`` lists the positions
    of the ``TensorBox`` entries.

    Raises:
        NotImplementedError: for bool/uint8 indices, or when an index lives on
            a device other than ``device``; callers use this to fall back.
    """
    present_dtypes = [i.get_dtype() for i in indices if i is not None]
    allowed_dtypes = (torch.int64, torch.int32, torch.bool, torch.uint8)
    assert all(dt in allowed_dtypes for dt in present_dtypes), (
        f"indices must be int64, byte or bool. Got {present_dtypes}"
    )
    if any(dt in (torch.bool, torch.uint8) for dt in present_dtypes):
        raise NotImplementedError("Fallback for bool indices")

    valid_idxs = [pos for pos, x in enumerate(indices) if isinstance(x, TensorBox)]
    assert len(valid_idxs) > 0, "requires at least 1 non-None index"

    new_indices = [None] * len(indices)
    broadcasted = broadcast_tensors(*[indices[pos] for pos in valid_idxs])
    for pos, x in zip(valid_idxs, broadcasted):
        # Eager allows indices to be CPU tensor when running on CUDA
        # FIXME: Calling to_device(x, device) should work but
        # test_advancedindex_mixed_cpu_devices still fails
        if x.get_device() != device:
            raise NotImplementedError("Fallback when indices is on a different device")
        new_indices[pos] = x
    return new_indices, valid_idxs
- def index_output_size_and_inner_fn(
- x_size,
- indices,
- tensor_indices,
- tensor_size,
- indices_loaders,
- indexed_size,
- x_loader,
- check,
- wrap_neg=True,
- ):
- # Note that behavior of indexing differs when there are non consecutive
- # tensors. In this case, the tensor index is pulled to the beginning.
- #
- # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
- # x = torch.tensor[1,2]
- # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will
- # be pulled to the front.
- non_consecutive_tensors = False
- for previous, current in itertools.pairwise(tensor_indices):
- if current - previous != 1:
- non_consecutive_tensors = True
- output_size = [x_size[i] for i, val in enumerate(indices) if val is None]
- output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]]
- first_tensor_index = tensor_indices[0]
- if non_consecutive_tensors:
- output_size = tensor_size + output_size
- else:
- output_size = (
- output_size[:first_tensor_index]
- + tensor_size
- + output_size[first_tensor_index:]
- )
- def fn(idx):
- assert len(idx) == len(output_size)
- assert len(indices_loaders) == len(indexed_size)
- rank = len(tensor_size)
- new_index = []
- first_tensor_index = tensor_indices[0]
- start_offset = 0 if non_consecutive_tensors else first_tensor_index
- next_idx = 0
- for i in range(tensor_indices[-1] + 1):
- if i == start_offset:
- next_idx += rank
- if indices[i] is None:
- assert next_idx < len(idx)
- new_index.append(idx[next_idx])
- next_idx += 1
- else:
- loader = indices_loaders[i]
- assert loader is not None
- size = indexed_size[i]
- new_index.append(
- ops.indirect_indexing(
- loader(idx[start_offset : start_offset + rank]),
- size,
- check=check,
- wrap_neg=wrap_neg,
- )
- )
- new_index = [
- *new_index,
- *idx[next_idx:],
- ]
- return new_index if x_loader is None else x_loader(new_index)
- return output_size, fn
def index_impl(x, indices, check):
    """Build the pointwise node for ``x[indices]``; ``check`` is forwarded to the helper."""
    output_size, load_indexed, _ = index_impl_helper(x, indices, check)
    node_kwargs = dict(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=load_indexed,
        ranges=output_size,
    )
    return Pointwise.create(**node_kwargs)
def index_impl_helper(x, indices, check, wrap_neg=True):
    """Prepare advanced indexing of ``x`` by ``indices``.

    Returns ``(output_size, inner_fn, index_inner_fn)``: the result shape, a
    function loading the indexed value for an output coordinate, and the
    underlying function that only maps the output coordinate into ``x``.

    Raises:
        IndexError: when ``check`` is set and a dimension of size 0 is indexed
            by index tensors whose shape contains no 0.
    """
    assert isinstance(indices, (list, tuple))
    x_loader = x.make_loader()
    indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device())
    assert len(tensor_indices) > 0, "Must have at least one valid idx"

    indices_loaders = [
        None if ind is None else ind.make_loader() for ind in indices
    ]
    # no guards on output size, all the guards are set in broadcast_tensors

    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())

    x_size = x.get_size()

    sizes_hit_by_tensors = [
        x_size[pos] for pos, ind in enumerate(indices) if ind is not None
    ]
    if check and 0 in sizes_hit_by_tensors and 0 not in tensor_size:
        raise IndexError("index is out of bounds for dimension with size 0")

    indexed_size = [x_size[pos] for pos, _ in enumerate(indices)]
    output_size, index_inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
        wrap_neg=wrap_neg,
    )

    def inner_fn(idx):
        source_index = index_inner_fn(idx)
        return x_loader(source_index)

    return output_size, inner_fn, index_inner_fn
@register_lowering(aten.index, type_promotion_kind=None)
def index(x, indices):
    """Lower ``aten.index``, falling back to ATen when the lowering is not implemented."""
    try:
        lowered = index_impl(x, indices, check=True)
    except NotImplementedError:
        # Fallback to ATen for boolean indexing
        x.realize()
        aten_index = fallback_handler(aten.index.Tensor, add_to_fallback_set=False)
        return aten_index(x, indices)
    return lowered
@register_lowering(aten._unsafe_index, type_promotion_kind=None)
def _unsafe_index(x, indices):
    """Lower ``aten._unsafe_index``: same as ``index`` but with ``check=False``."""
    unchecked = index_impl(x, indices, check=False)
    return unchecked
- # All the indexing decompositions are written in terms of index, index_put, and index_put_
- # We cannot have this lowering as a decomposition as it introduces
- # mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead
- # code elimination and common subexpression elimination optimizations, which
- # assume graphs to be side-effect free. More details at
- # https://github.com/pytorch/torchdynamo/issues/1235
- # and
- # https://github.com/pytorch/torchdynamo/issues/1863
@register_lowering(aten.index_put, type_promotion_kind=None)
def index_put(x, indices, values, accumulate=False):
    """Lower functional ``aten.index_put`` by mutating a clone of ``x``."""
    target = clone(x)
    return index_put_impl_(
        target, indices, values, accumulate, check=True, may_realize=False
    )
@register_lowering(aten._unsafe_index_put)
def _unsafe_index_put(x, indices, values, accumulate=False):
    """Lower ``aten._unsafe_index_put`` by mutating a clone of ``x`` with ``check=False``."""
    target = clone(x)
    return index_put_impl_(
        target, indices, values, accumulate, check=False, may_realize=False
    )
def index_put_as_masked_fill(self, indices, value, accumulate):
    """Implement ``index_put_`` with one boolean mask as a ``where`` written back to ``self``.

    ``value`` is moved to ``self``'s device if needed; with ``accumulate`` it
    is first added to ``self``.  Only ``indices[0]`` is used as the mask.
    """
    target_device = self.get_device()
    if value.get_device() != target_device:
        value = to_device(value, self.get_device())
    if accumulate:
        value = add(self, value)
    mask = indices[0]
    return mutate_to(self, where(mask, value, self))
def index_put_fallback(self, indices, values, accumulate):
    """Emit an ``ir.IndexPutFallback`` for the current node's ``index_put_`` overload.

    When graph partitioning is disabled and the current FX node is reported as
    input-dependent cudagraph-unsafe, a reason for disabling CUDA graphs is
    recorded on the graph.  Returns ``self``.
    """
    from .utils import _fx_node_is_input_dependent_cudagraph_unsafe

    op_overload = getattr(aten.index_put_, V.graph.current_node.target._overloadname)  # type: ignore[union-attr]

    # Check if any index is a boolean tensor - if so, mark as cudagraph-unsafe
    # because boolean indices trigger .nonzero() during CUDA graph capture
    # When graph_partition is enabled, skip - partitioning handles this
    fx_node = V.graph.current_node
    cudagraph_unsafe = (
        not config.graph_partition
        and fx_node is not None
        and _fx_node_is_input_dependent_cudagraph_unsafe(fx_node)
    )
    if cudagraph_unsafe:
        msg = "index_put_ fallback with boolean indexing is not compatible with CUDA graphs"
        stack_trace = fx_node.meta.get("stack_trace", None)
        if stack_trace:
            msg = f"{msg} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = msg

    ir.IndexPutFallback(op_overload, self, indices, values, accumulate)
    return self
@register_lowering(aten.index_put_, type_promotion_kind=None)
def index_put_(self, indices, values, accumulate=False):
    """Lower in-place ``aten.index_put_`` (bounds-checked; ``values`` may be realized)."""
    return index_put_impl_(
        self,
        indices,
        values,
        accumulate,
        check=True,
        may_realize=True,
    )
@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None)
def _unsafe_index_put_(self, indices, values, accumulate=False):
    """Lower in-place ``_unsafe_index_put_`` (``check=False``; ``values`` may be realized)."""
    return index_put_impl_(
        self,
        indices,
        values,
        accumulate,
        check=False,
        may_realize=True,
    )
def index_put_impl_(self, indices, values, accumulate, check, may_realize=False):
    """Shared implementation of the ``index_put`` family; mutates and returns ``self``.

    Strategy, in order:
      1. optionally realize ``values`` when it reads from ``self``;
      2. single boolean mask with a single value -> masked fill;
      3. deterministic mode, boolean indices, atomic-add limitations, or
         unsupported indices -> ``index_put_fallback``;
      4. otherwise register an ``ir.Scatter`` that writes into ``self``.
    """
    bool_like_dtypes = (torch.bool, torch.uint8)

    if may_realize:

        def indice_slice_from_randperm(indice):
            # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660
            # For this specific pattern, indices is unique as coming from torch.randperm.
            # However, as the content of the indices is unknown, we have to check this specific pattern.
            is_view = isinstance(indice, TensorBox) and isinstance(
                indice.data, ir.BaseView
            )
            if not is_view:
                return False
            indice = indice.data.unwrap_view()
            return (
                isinstance(indice, ir.StorageBox)
                and isinstance(indice.data, ir.ExternKernel)
                and getattr(indice.data, "fx_node", None)
                and indice.data.fx_node.target is torch.ops.aten.randperm.default
            )

        if ir.try_get_name(self) in values.get_read_names() and not all(
            indice_slice_from_randperm(indice) for indice in indices
        ):
            # Fix issue: https://github.com/pytorch/pytorch/issues/138908
            # When self and values have memory overlapping, indices may
            # contain duplicate values, potentially causing incorrect results since
            # the load of `values` might contain modified value from the store of `self`.
            # To address this, store values in a temporary buffer in such cases.
            values.realize()

    # Dispatch to masked fill for single boolean index with single value
    single_bool_mask = (
        values.get_numel() == 1
        and len(indices) == 1
        and indices[0].get_dtype() in bool_like_dtypes
    )
    if single_bool_mask:
        mask = indices[0]
        # Append trailing unit dims until the mask has self's rank.
        for _ in range(len(mask.get_size()), len(self.get_size())):
            mask = unsqueeze(mask, -1)
        return index_put_as_masked_fill(self, [mask], values, accumulate)

    # Fallback in torch deterministic mode
    if torch.are_deterministic_algorithms_enabled():
        return index_put_fallback(self, indices, values, accumulate)

    # Fallback if there is a boolean index
    if any(
        index is not None and index.get_dtype() in bool_like_dtypes
        for index in indices
    ):
        return index_put_fallback(self, indices, values, accumulate)

    x_size = self.get_size()
    x_ndim = len(x_size)
    self_is_scalar = x_ndim == 0

    if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
        # self is an scalar Tensor
        if self_is_scalar:
            self = view(self, [1])
        self = index_put_fallback(self, indices, values, accumulate)
        if self_is_scalar:
            self = view(self, [])
        return self

    values = to_dtype(values, self.get_dtype())

    try:
        # Note that code will only get here when dtype is uint32
        indices, tensor_indices = check_and_broadcast_indices(
            indices, self.get_device()
        )
    except NotImplementedError:
        return index_put_fallback(self, indices, values, accumulate)

    indices_loaders = [
        None if ind is None else ind.make_loader() for ind in indices
    ]

    assert isinstance(self, TensorBox)
    self.realize()

    # self is an scalar Tensor
    if self_is_scalar:
        self = view(self, [1])

    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())
    indexed_size = [x_size[pos] for pos, _ in enumerate(indices)]

    expected_vals_size, output_indexer = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
    )

    values = expand(values, expected_vals_size)
    # all guards are set above during broadcast_tensors and expand

    device = self.get_device()
    assert device is not None
    scatter = ir.Scatter(
        device=device,
        dtype=self.get_dtype(),
        inner_fn=values.make_loader(),
        ranges=expected_vals_size,  # iter_ranges,
        output_indexer=output_indexer,
        scatter_mode="atomic_add" if accumulate else None,
    )
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if self_is_scalar:
        self = view(self, [])
    return self
# ATen fallback handlers for the masked-index ops, created with
# add_to_fallback_set=False.
fallback__unsafe_masked_index = fallback_handler(
    aten._unsafe_masked_index.default, add_to_fallback_set=False
)

fallback__unsafe_masked_index_put_accumulate = fallback_handler(
    aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False
)
@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None)
def _unsafe_masked_index(self, mask, indices, fill):
    """Lower ``aten._unsafe_masked_index``.

    Where ``mask`` is true the value ``self[indices]`` is loaded (without
    bounds checks or negative-index wrapping); elsewhere ``fill`` is produced.
    """
    ranges, _, map_to_source = index_impl_helper(
        self, indices, check=False, wrap_neg=False
    )
    mask_loader = mask.make_loader()
    self_loader = self.make_loader()

    def inner_fn(idx):
        mask_val = mask_loader(idx)
        if mask.dtype != torch.bool:
            mask_val = ops.to_dtype(mask_val, torch.bool)

        def load_selected():
            return self_loader(map_to_source(idx))

        return ops.masked(mask_val, load_selected, fill)

    return Pointwise.create(
        device=self.get_device(),
        dtype=self.get_dtype(),
        inner_fn=inner_fn,
        ranges=ranges,
    )
@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None)
def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
    """Lower ``aten._unsafe_masked_index_put_accumulate``.

    Values at masked-out positions are replaced by 0 and every index tensor is
    clamped into ``[-size, size - 1]`` for its dimension, so the accumulate can
    run as a plain ``_unsafe_index_put`` with ``accumulate=True``.

    Args:
        x: tensor to accumulate into (a clone is mutated by ``_unsafe_index_put``).
        mask: selects which entries of ``values`` contribute.
        indices: sequence of index tensors; entries may be ``None``.
        values: values to accumulate.
    """
    masked_value = where(mask, values, 0)
    shape = x.get_size()
    # Test for None explicitly: a missing index is the only case to skip, and
    # the truthiness of an IR node is not a meaningful signal.
    clamped_indices = [
        clamp(index, -shape[i], shape[i] - 1) if index is not None else None
        for i, index in enumerate(indices)
    ]
    # TODO: use a masked store for this. currently only triton
    # supports masked stores and cpp backend does not.
    return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True)
@make_pointwise
def clamp(a, min, max):
    """Pointwise clamp of ``a``: ``maximum(min, minimum(max, a))``."""
    capped_above = ops.minimum(max, a)
    return ops.maximum(min, capped_above)
@register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
def as_strided_scatter(self, src, size, stride, storage_offset=None):
    """Lower ``aten.as_strided_scatter``.

    Clones ``self``, copies ``src`` into an ``as_strided`` view of the clone,
    and returns the clone.
    """
    result = clone(self)
    strided_window = as_strided(result, size, stride, storage_offset)
    copy_(strided_window, src)
    return result
@register_lowering(aten.scatter, type_promotion_kind=None)
def scatter(x, dim: int, index, src, **kwargs):
    """Lower functional ``aten.scatter`` by applying ``scatter_`` to a clone of ``x``."""
    target = clone(x)
    return scatter_(target, dim, index, src, **kwargs)
def scatter_fallback(
    op_overload: torch._ops.OpOverload,
    self,
    dim: int,
    index,
    src,
    *,
    reduce: Optional[str] = None,
    include_self: bool = True,
):
    """Emit an ``ir.ScatterFallback`` when ``use_scatter_fallback`` asks for it.

    Returns ``self`` if the fallback node was created and ``None`` otherwise,
    so callers can continue with the native lowering.
    """
    src_is_tensor = isinstance(src, TensorBox)
    if src_is_tensor:
        src_dtype = src.get_dtype()
        # pyrefly: ignore [missing-attribute]
        src_device_type = src.get_device().type
    else:
        # Scalar src: pass its Python type and a placeholder device string.
        src_dtype = type(src)
        src_device_type = "not impl"

    needs_fallback = use_scatter_fallback(
        op_overload,
        reduce,
        self.get_dtype(),
        cast(torch.dtype, src_dtype),
        src_device_type,
        src_is_tensor,
    )
    if not needs_fallback:
        return None

    ir.ScatterFallback(
        op_overload,
        self,
        dim,
        index,
        src,
        reduce=reduce,
        include_self=include_self,
    )
    return self
@register_lowering(aten.scatter_, type_promotion_kind=None)
def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
    """Lower in-place ``aten.scatter_``.

    Without ``reduce`` the fallback is tried first.  ``"add"``/``"multiply"``
    are renamed to ``"sum"``/``"prod"`` before delegating to
    ``scatter_reduce_``.
    """
    assert reduce in (None, "add", "multiply")

    if reduce is None:
        op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname)  # type: ignore[union-attr]
        fallback_result = scatter_fallback(
            op_overload, self, dim, index, src, reduce=reduce
        )
        if fallback_result is not None:
            return fallback_result

    # Translate scatter_'s reduce names into scatter_reduce_'s names.
    renamed = {"add": "sum", "multiply": "prod"}
    reduce = renamed.get(reduce, reduce)
    return scatter_reduce_(self, dim, index, src, reduce)
@register_lowering(aten.scatter_add, type_promotion_kind=None)
def scatter_add(x, dim: int, index, src):
    """Lower functional ``aten.scatter_add`` by applying ``scatter_add_`` to a clone of ``x``."""
    target = clone(x)
    return scatter_add_(target, dim, index, src)
@register_lowering(aten.scatter_add_, type_promotion_kind=None)
def scatter_add_(x, dim: int, index, src):
    """Lower in-place ``aten.scatter_add_`` as ``scatter_reduce_`` with ``"sum"``."""
    reduction = "sum"
    return scatter_reduce_(x, dim, index, src, reduction)
@register_lowering(aten.scatter_reduce, type_promotion_kind=None)
def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
    """Lower functional ``aten.scatter_reduce`` by applying ``scatter_reduce_`` to a clone of ``x``."""
    target = clone(x)
    return scatter_reduce_(target, dim, index, src, reduction_type, **kwargs)
@register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
    """Lower in-place ``aten.scatter_reduce_``; mutates and returns ``self``.

    Args:
        self: destination tensor.
        dim: dimension along which ``index`` selects destinations.
        index: integer index tensor.
        src: source tensor or Python number (numbers are expanded with
            ``full_like``).
        reduce: one of ``None``, ``"sum"``, ``"prod"``, ``"mean"``, ``"amax"``,
            ``"amin"``.  Only ``None`` and ``"sum"`` are handled natively here;
            other values must have been taken by the fallback.
        include_self: when false, the addressed elements of ``self`` are
            zeroed before scattering.

    Returns:
        ``self`` (or the fallback's result).  A 0-dimensional ``self`` is
        temporarily viewed as ``[1]`` and viewed back to ``[]`` on return.
    """
    assert reduce in (None, "sum", "prod", "mean", "amax", "amin")
    assert (
        len(aten.scatter_reduce_.overloads()) == 1
        and "two" in aten.scatter_reduce_.overloads()
    ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_"

    if isinstance(src, Number):
        src = full_like(self, src)

    fallback_result = scatter_fallback(
        aten.scatter_reduce_.two,
        self,
        dim,
        index,
        src,
        reduce=reduce,
        include_self=include_self,
    )
    # scatter_fallback signals "no fallback" with None; test for that exactly,
    # as scatter_ does, instead of relying on the truthiness of an IR node.
    if fallback_result is not None:
        return fallback_result

    assert isinstance(self, TensorBox)
    assert "int" in str(index.get_dtype())

    ndim = len(self.get_size())
    if ndim == 0:
        self = view(self, [1])

    if isinstance(src, TensorBox) and len(src.get_size()) == 0:
        src = view(src, [1])

    if isinstance(index, TensorBox) and len(index.get_size()) == 0:
        index = view(index, [1])

    if index.get_numel() == 0:
        # Nothing to scatter.  Undo the temporary [1] view so a 0-d input is
        # returned with its original shape, like the normal exit path below.
        if ndim == 0:
            self = view(self, [])
        return self

    dim = _validate_dim(self, dim)

    self.realize()
    index_loader = index.make_loader()
    src_loader = src.make_loader() if isinstance(src, TensorBox) else None

    def output_indexer(idx):
        # self is captured from the end of the function, so it may have 0 dim
        shape = self.get_size()
        ndim = len(shape)
        indirect_idx = list(idx)
        indirect_idx[dim] = ops.indirect_indexing(
            index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
        )
        return indirect_idx

    def fn(idx):
        if src_loader:
            return src_loader(idx)
        else:
            # src is a scalar
            # pyrefly: ignore [bad-argument-type]
            return ops.constant(src, self.get_dtype())

    def backend_reduce_str(reduce):
        if reduce == "sum":
            return "atomic_add"
        else:
            # TODO: Need to support more reduction type
            assert reduce is None
            return None

    device = self.get_device()
    assert device is not None

    if not include_self:
        # zero out the corresponding elements first
        zero_out = ir.Scatter(
            device=device,
            dtype=self.get_dtype(),
            inner_fn=lambda index: ops.constant(0, self.get_dtype()),
            ranges=index.get_size(),
            output_indexer=output_indexer,
            scatter_mode=None,
        )
        buffer = ir.ComputedBuffer(
            name=None,
            layout=ir.MutationLayoutSHOULDREMOVE(self),
            data=zero_out,
        )
        buffer.name = V.graph.register_buffer(buffer)
        V.graph.register_operation(buffer)

    # self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
    # self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
    # self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
    scatter = ir.Scatter(
        device=device,
        dtype=self.get_dtype(),
        inner_fn=fn,
        ranges=index.get_size(),
        output_indexer=output_indexer,
        scatter_mode=backend_reduce_str(reduce),
    )
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if ndim == 0:
        self = view(self, [])
    return self
def upsample_nearestnd(
    x,
    output_size,
    scales_x: tuple[Optional[float], ...],
    n: int = 2,
    exact: bool = False,
):
    """Build a pointwise nearest-neighbour upsample over the last ``n`` dims of ``x``.

    Args:
        x: input node; its trailing ``n`` dimensions are resampled.
        output_size: target sizes for those ``n`` dimensions.
        scales_x: one optional scale per resampled dimension. A non-None
            entry replaces the ``input / output`` ratio for that dimension
            with ``1 / scale``.
        n: number of trailing dimensions to resample.
        exact: when True, sample at ``floor(ratio * (out_idx + 0.5))``
            instead of ``floor(ratio * out_idx)``.

    Returns:
        The result of ``Pointwise.create`` with ranges ``[*batch, *output_size]``.
    """
    x.realize_hint()  # each input element is read by several outputs
    load_input = x.make_loader()
    full_size = x.get_size()
    batch_dims = full_size[:-n]
    in_sizes = [V.graph.sizevars.guard_int(extent) for extent in full_size[-n:]]

    assert len(scales_x) == n
    out_sizes = output_size

    # Default ratio is input/output; an explicit scale overrides it per axis.
    ratios = [in_sz / out_sz for in_sz, out_sz in zip(in_sizes, out_sizes)]
    for axis, user_scale in enumerate(scales_x):
        if user_scale is None:
            continue
        ratios[axis] = 1.0 / user_scale

    def source_index(out_idx, ratio, in_size):
        # Nearest Exact: input_index = round(ratio * (output_index + 0.5) - 0.5)
        #                            = floor(ratio * (output_index + 0.5))
        # Nearest:       input_index = floor(ratio * output_index)
        pos = ops.index_expr(out_idx, torch.float32)
        if exact:
            pos = ops.add(pos, ops.constant(0.5, torch.float32))
        scaled = ops.mul(pos, ops.constant(ratio, torch.float32))
        as_int = ops.to_dtype(scaled, torch.int32)
        return ops.indirect_indexing(as_int, in_size, check=False)

    def inner(idx):
        spatial = idx[-n:]
        lead = idx[:-n]
        mapped = [
            source_index(out_idx, ratio, in_size)
            for out_idx, ratio, in_size in zip(spatial, ratios, in_sizes)
        ]
        return load_input([*lead, *mapped])

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner,
        ranges=[*batch_dims, *out_sizes],
    )
@register_lowering(aten.upsample_nearest1d.default)
def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
    """Lower ``aten.upsample_nearest1d`` through ``upsample_nearestnd`` with one spatial dim."""
    per_dim_scales = (scales,)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=1)
@register_lowering(aten._upsample_nearest_exact1d.default)
def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None):
    """Lower ``aten._upsample_nearest_exact1d``: one spatial dim, half-pixel ("exact") sampling."""
    per_dim_scales = (scales,)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=1, exact=True)
@register_lowering(aten.upsample_nearest2d.default)
def upsample_nearest2d(
    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
):
    """Lower ``aten.upsample_nearest2d`` through ``upsample_nearestnd`` with two spatial dims."""
    per_dim_scales = (scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=2)
@register_lowering(aten._upsample_nearest_exact2d.default)
def _upsample_nearest_exact2d(
    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
):
    """Lower ``aten._upsample_nearest_exact2d``: two spatial dims, half-pixel ("exact") sampling."""
    per_dim_scales = (scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=2, exact=True)
@register_lowering(aten.upsample_nearest3d.default)
def upsample_nearest3d(
    x,
    output_size,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Lower ``aten.upsample_nearest3d`` through ``upsample_nearestnd`` with three spatial dims."""
    per_dim_scales = (scales_d, scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=3)
@register_lowering(aten._upsample_nearest_exact3d.default)
def _upsample_nearest_exact3d(
    x,
    output_size,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Lower ``aten._upsample_nearest_exact3d``: three spatial dims, half-pixel ("exact") sampling."""
    per_dim_scales = (scales_d, scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=3, exact=True)
def _create_constants(*args, dtype):
    """Return a tuple with ``ops.constant(value, dtype)`` for each positional value, in order."""
    constants = []
    for value in args:
        constants.append(ops.constant(value, dtype))
    return tuple(constants)
@register_lowering(prims.rev.default)
def rev(x, dims):
    """Lower ``prims.rev``: reverse ``x`` along each dimension listed in ``dims``.

    ``dims`` is expected to be canonicalized already (used directly as list
    indices). The result is a pointwise node with the same sizes as ``x``.
    """
    load_input = x.make_loader()
    shape = x.get_size()

    def flipped_load(idx):
        mirrored = list(idx)
        assert len(mirrored) == len(shape)
        for axis in dims:
            # position p along a reversed axis reads from (size - 1) - p
            last = shape[axis] - 1
            mirrored[axis] = last - mirrored[axis]
        return load_input(mirrored)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=flipped_load,
        ranges=shape,
    )
def inplace_constant_pad_nd(
    x: TensorBox, padding: Sequence[int], fill_value: float
) -> Optional[TensorBox]:
    """
    Try to implement a constant pad by reusing slack already present in x's
    buffer instead of producing a padded copy.

    This optimization changes the semantics of padding from 'clone'
    style to 'view' style.

    Thanks to functionalization, this change can still maintain numerical
    correctness.

    Only one shape of problem is handled: a 2D ``x`` with unit stride in
    dim 1, padded solely on the high side of dim 1 (``padding[1]``), whose
    row stride already leaves room for the extra columns.

    Returns:
        A strided view of ``x`` with the padded size whose new columns have
        been filled with ``fill_value``, or ``None`` when any precondition
        fails (the caller then uses the generic lowering).
    """

    def _padding_can_be_fused():
        """
        Conservatively check if padding can be fused with downstream op.
        1. if the downstream op is a sum, then there is little benefit to
           do inplace padding
        2. if the downstream op is a matmul, doing inplace padding can
           save membw.

        Returns False only when the current node has exactly one user and
        that user is ``aten.mm`` / ``aten.addmm``; True otherwise.
        """
        current_node = V.graph.current_node
        if current_node is None:
            return True  # be conservative
        users = tuple(current_node.users)
        if len(users) == 1 and users[0].target in (
            aten.mm.default,
            aten.addmm.default,
        ):
            return False

        return True  # be conservative

    # If the pad could fuse into its consumer, leave it to the generic path.
    if _padding_can_be_fused():
        return None

    # Only handle 2D case for now
    if len(padding) != 4 or len(x.get_size()) != 2:
        return None

    # No harm in realizing here since we already know that
    # the op cannot be fused into its single user.
    # It needs to be realized later anyway.
    x.realize()

    # If x is a view (e.g. a SliceView), realizing it just realizes the
    # underlying storage; x itself is still a view. Such inputs, and buffers
    # without a name, are rejected here.
    if (
        not isinstance(x, ir.TensorBox)
        or not isinstance(x.data, ir.StorageBox)
        or not (
            isinstance(x.data.data, ir.ComputedBuffer)
            or (
                config.can_inplace_pad_graph_input
                and isinstance(x.data.data, ir.InputBuffer)
            )
        )
        or not x.data.data.name
    ):
        return None
    x.freeze_layout()

    _, layout = ir.as_storage_and_layout(x)
    strides = layout.stride
    # Rows must be contiguous along dim 1.
    if strides[1] != 1:
        return None

    # Only padding[1] may be non-zero: no padding on the low side of dim 1
    # and none at all on dim 0.
    if padding[0] != 0 or padding[2] != 0 or padding[3] != 0:
        return None

    npad = padding[1]
    if npad == 0:
        return None

    stride0 = strides[0]
    rowsize = layout.size[1]

    # The existing row stride must already have room for the extra columns.
    if stride0 < rowsize + npad:
        return None

    bufname = x.data.data.name
    padded_size = [layout.size[0], layout.size[1] + npad]
    # Record the padded size for this buffer on the graph.
    V.graph.buffer_to_padded_size[bufname] = padded_size
    # Same strides and offset as before, only the logical size grows.
    resized_x = as_strided(
        x,
        padded_size,
        layout.stride,
        layout.offset,
    )

    # Fill only the newly exposed columns [rowsize, rowsize + npad).
    sliced_x = slice_(resized_x, dim=1, start=rowsize, end=rowsize + npad, clamp=False)
    fill_(sliced_x, fill_value)

    counters["inductor"]["inplace_padding"] += 1
    return resized_x
@register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
def constant_pad_nd(x, padding, fill_value=0):
    """Lower ``aten.constant_pad_nd`` to a masked pointwise load.

    ``padding`` is a flat list of ``(low, high)`` pairs; the first pair
    applies to the last dimension of ``x``. All-zero padding returns a clone.
    When ``config.inplace_padding`` is set, the in-place variant is tried
    first and used if it returns a result.
    """
    assert (len(padding) % 2) == 0
    if all(p == 0 for p in padding):
        return clone(x)

    if config.inplace_padding:
        inplace_result = inplace_constant_pad_nd(x, padding, fill_value)
        if inplace_result:
            return inplace_result
        # fall through if can not inplace the padding

    in_sizes = x.get_size()

    # Flip the (low, high) pairs so they line up with the trailing dims in
    # natural order; the leading n dims are not padded.
    bounds = list(zip(padding[::2], padding[1::2]))[::-1]
    n = len(in_sizes) - len(bounds)

    # if padding is a complicated expression, hoist it
    bounds_precomp: list[tuple[sympy.Symbol, Any]] = [
        (V.graph.sizevars.lookup_precomputed_size(lo), hi)  # type: ignore[arg-type]
        for lo, hi in bounds
    ]

    padded_dims = list(zip(bounds, in_sizes[n:]))
    mask_sizes = [size for _bound, size in padded_dims]
    output_size = [
        *in_sizes[:n],
        *(sympy.expand(size + lo + hi) for (lo, hi), size in padded_dims),
    ]
    assert len(output_size) == len(in_sizes)
    fill_value = dtype_to_type(x.get_dtype())(fill_value)

    def masked_load(index):
        # One bound check per padded side; unpadded sides need no check.
        checks = []
        for idx, (lo, hi), length in zip(index[n:], bounds, mask_sizes):
            if lo != 0:
                checks.append(range_mask_low(idx, 0))
            if hi != 0:
                checks.append(range_mask_high(idx, length))
        in_bounds = functools.reduce(ops.and_, checks)
        return ops.masked(in_bounds, lambda: x_loader(index), fill_value)

    def shifted_load(index):
        # Translate output coordinates back into input coordinates.
        shifted = [
            *index[:n],
            *(idx - lo for idx, (lo, _hi) in zip(index[n:], bounds_precomp)),
        ]
        assert len(shifted) == len(index)
        return masked_load(shifted)

    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=shifted_load,
        ranges=output_size,
    )
def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]):
    """Return the op value for ``i >= low``, both evaluated as int64 index expressions.

    Args:
        i: index expression to test.
        low: inclusive lower bound, either a plain number or a sympy expression.
    """
    # sympy.Integer() only accepts numeric input and raises TypeError on a
    # symbolic expression, so pass an Expr through untouched and wrap only
    # plain numbers.
    low_expr = low if isinstance(low, sympy.Expr) else sympy.Integer(low)
    return ops.ge(
        ops.index_expr(i, torch.int64),
        ops.index_expr(low_expr, torch.int64),
    )
def range_mask_high(i: sympy.Expr, high: sympy.Expr):
    """Return the op value for ``i < high``, both evaluated as int64 index expressions."""
    position = ops.index_expr(i, torch.int64)
    upper = ops.index_expr(high, torch.int64)
    return ops.lt(position, upper)
def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr):
    """Return the op value for ``low <= i < high``. Note the argument order: ``high`` before ``low``."""
    above_low = range_mask_low(i, low)
    below_high = range_mask_high(i, high)
    return ops.and_(above_low, below_high)
def constant_boundary_condition(
    x, fill_value, padding=None, pad_fill_value=1.0, dim=None
):
    """Return a loader for ``x`` that yields a constant outside the valid range.

    The last ``dim`` index components are range-checked. Without ``padding``,
    positions in ``[0, size)`` read ``x`` and everything else yields
    ``fill_value``. With ``padding``, the checked range widens to
    ``[-pad, size + pad)``: inside it the value comes from a nested
    un-padded loader whose fill is ``pad_fill_value``, and outside it the
    result is ``fill_value``.
    """
    spatial_sizes = x.get_size()[-dim:]
    load_input = x.make_loader()
    # pyrefly: ignore [unsupported-operation]
    pad = padding or [0] * dim

    def load(index):
        lead = index[:-dim]
        pos = index[-dim:]

        per_dim = []
        for d in range(dim):
            # pyrefly: ignore [no-matching-overload]
            per_dim.append(range_mask(pos[d], spatial_sizes[d] + pad[d], -pad[d]))
        in_range = functools.reduce(ops.and_, per_dim)

        if padding:

            def body():
                inner = constant_boundary_condition(x, pad_fill_value, dim=dim)
                return inner([*lead, *pos])

        else:

            def body():
                return load_input([*lead, *pos])

        return ops.masked(in_range, body, fill_value)

    return load
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
    """Compute the pooled output extent for dimension ``i``.

    Args:
        x: input extent along the dimension.
        i: index into the per-dimension ``kernel_size``/``stride``/``padding``/``dilation``.
        ceil_mode: whether the ceil-rounded output size is requested.
        dilation: per-dimension dilation; defaults to all ones.

    Returns:
        ``(out_size, ceil_mode)``. The returned ``ceil_mode`` is reset to
        False when the ceil-mode size equals the floor-mode size by size
        hint; that equality is then recorded via ``check_equals``.
    """
    if dilation is None:
        dilation = [1] * len(padding)

    step = stride[i]
    # Extent left after removing the (dilated) kernel span.
    span = x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1)
    x_out = FloorDiv(span + (step - 1), step)

    if not ceil_mode:
        return x_out, ceil_mode

    x_alt = FloorDiv(span + 2 * (step - 1), step)
    sizevars = V.graph.sizevars
    if sizevars.size_hint((x_alt - 1) * step - x - padding[i]) >= 0:
        # Sliding windows must start within the input or left padding
        x_alt -= 1  # type: ignore[assignment]
        sizevars.check_leq(0, x_alt * step - x - padding[i])  # type: ignore[arg-type]

    if sizevars.size_hint(x_out - x_alt) == 0:
        # ceil mode is actually a no-op, lets guard on that
        sizevars.check_equals(x_out, x_alt)
        return x_out, False
    return x_alt, ceil_mode
def should_fallback_max_pool_with_indices(kernel_size, *, n_dim):
    """Return True when the pooling window holds more than 25 elements."""
    per_dim = pad_listlike(kernel_size, n_dim)
    window_elems = functools.reduce(operator.mul, per_dim)
    return window_elems > 25
def max_pool_checks(
    x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None
):
    """Normalize and validate max-pool arguments.

    Scalar ``padding == 0`` / ``dilation == 1`` and an empty ``stride`` are
    replaced by their per-dimension defaults (stride defaults to the kernel
    size), then every argument is run through ``pad_listlike``.

    Returns:
        ``(kernel_size, stride, padding, dilation, use_fallback)``. When
        ``assert_fallback`` is not None the computed ``use_fallback`` must
        equal it.
    """
    if padding == 0:
        padding = [0] * n_dim
    if dilation == 1:
        dilation = [1] * n_dim
    if not stride:
        stride = kernel_size

    kernel_size, stride, padding, dilation = (
        pad_listlike(arg, n_dim) for arg in (kernel_size, stride, padding, dilation)
    )

    assert isinstance(x, TensorBox)
    for per_dim in (kernel_size, stride, padding, dilation):
        assert len(per_dim) == n_dim
    # input is either (C, *spatial) or (N, C, *spatial)
    assert len(x.get_size()) in (n_dim + 1, n_dim + 2)

    use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim)
    if assert_fallback is not None:
        assert use_fallback == assert_fallback

    return kernel_size, stride, padding, dilation, use_fallback
def _max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    *,
    n_dim,
):
    """Build max and argmax reductions over each pooling window of ``x``.

    Returns:
        ``(result, offsets)``. ``result`` is the window maximum in ``x``'s
        dtype; ``offsets`` is the int64 argmax over the kernel window. Each
        is realized only if it is still a ``Reduction`` (i.e. not unrolled).
    """
    x.realize_hint()
    leading = x.shape[:-n_dim]
    spatial = x.shape[-n_dim:]

    per_dim = [
        pooling_size(
            spatial[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation
        )
        for d in range(n_dim)
    ]
    out_spatial, ceil_flags = zip(*per_dim)

    dtype = x.dtype
    # Out-of-bounds reads must never win the max.
    if dtype is torch.bool:
        lowest = False
    elif dtype.is_floating_point:
        lowest = float("-inf")
    else:
        lowest = torch.iinfo(dtype).min

    out_size = [*leading, *out_spatial]
    needs_bounds_check = (
        any(padding) or any(ceil_flags) or any(d > 1 for d in dilation)
    )
    if needs_bounds_check:
        load = constant_boundary_condition(x, lowest, dim=n_dim)
    else:
        load = x.make_loader()

    def window_load(idx, reduction_idx):
        lead = idx[:-n_dim]
        out_pos = idx[-n_dim:]
        in_pos = []
        for d in range(n_dim):
            in_pos.append(
                (out_pos[d] * stride[d]) + (reduction_idx[d] * dilation[d]) - padding[d]
            )
        return load([*lead, *in_pos])

    shared = dict(
        input_node=x,
        device=x.get_device(),
        src_dtype=dtype,
        inner_fn=window_load,
        ranges=out_size,
        reduction_ranges=kernel_size,
    )
    result = Reduction.create(reduction_type="max", dst_dtype=dtype, **shared)
    offsets = Reduction.create(
        reduction_type="argmax", dst_dtype=torch.int64, **shared
    )

    for node in (result, offsets):
        if isinstance(node.data.data, Reduction):  # type: ignore[attr-defined, union-attr]
            # Only realize if reduction isn't unrolled
            node.realize()

    return result, offsets
@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None)
def _low_memory_max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode=False,
):
    """Lower the low-memory max-pool prim: window maxima plus int8 in-window offsets."""
    n_dim = len(kernel_size)

    # assert we are not on a fallback path, the inductor decomp should have guaranteed this
    checked = max_pool_checks(
        x, kernel_size, stride, padding, dilation, n_dim, assert_fallback=False
    )
    kernel_size, stride, padding, dilation = checked[:4]

    with config.patch(unroll_reductions_threshold=25):
        pooled, offsets = _max_pool_with_offsets(
            x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim
        )
        return pooled, to_dtype(offsets, torch.int8)
def _pool_offsets_to_indices(
    offsets: TensorBox,
    kernel_size: Sequence[Union[int, torch.SymInt]],
    input_size: Sequence[Union[int, torch.SymInt]],
    increments_to_index: Callable[
        [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
        torch._inductor.virtualized.OpsValue,
    ],
) -> TensorBox:
    """Convert per-window offsets into flat int64 indices into the input's spatial dims.

    Args:
        offsets: flattened position of the selected element inside each window.
        kernel_size: window shape used to un-flatten an offset.
        input_size: input sizes; only the last ``len(kernel_size)`` entries are used.
        increments_to_index: maps ``(output idx, n-d window position)`` to
            the n-d input position.
    """
    n_dim = len(kernel_size)
    load_offset = offsets.make_loader()
    window_numel = sympy.sympify(functools.reduce(operator.mul, kernel_size))

    def to_flat_index(idx):
        window_pos = ops.indirect_indexing(load_offset(idx), window_numel)
        nd_pos = inductor_prims._flattened_index_to_nd(window_pos, kernel_size)
        input_pos = increments_to_index(idx, nd_pos)
        flat = inductor_prims._flatten_index(input_pos, input_size[-n_dim:])
        return ops.index_expr(flat, torch.int64)

    return Pointwise.create(
        device=offsets.get_device(),
        dtype=torch.int64,
        inner_fn=to_flat_index,
        ranges=offsets.get_size(),
    )
@register_lowering(
    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool_offsets_to_indices(
    offsets, kernel_size, input_size, stride, padding, dilation
):
    """Lower the prim that turns max-pool window offsets into flat input indices."""
    # TODO: Generalize to other max pooling flavors
    n_dim = len(kernel_size)

    def increments_to_index(idx, reduction_idx):
        out_pos = idx[-n_dim:]
        coords = []
        for d in range(n_dim):
            # window origin + dilated step inside the window - left padding
            coords.append(
                (out_pos[d] * stride[d]) + (reduction_idx[d] * dilation[d]) - padding[d]
            )
        return coords

    return _pool_offsets_to_indices(
        offsets, kernel_size, input_size, increments_to_index
    )
def _max_pool_with_indices(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    n_dim,
):
    """Shared body of the 2D/3D ``max_pool*_with_indices`` lowerings.

    Returns ``(out, indices)`` where ``indices`` are flat positions within
    the input's last ``n_dim`` dimensions.
    """
    normalized = max_pool_checks(
        x, kernel_size, stride, padding, dilation, n_dim=n_dim
    )
    kernel_size, stride, padding, dilation = normalized[:4]

    out, offsets = _max_pool_with_offsets(
        x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim
    )

    spatial_shape = x.shape[-n_dim:]
    indices = _low_memory_max_pool_offsets_to_indices(
        offsets, kernel_size, spatial_shape, stride, padding, dilation
    )
    return out, indices
# Fallback when we do not decompose to the low-memory path.
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
def max_pool2d_with_indices(
    x,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Lower ``aten.max_pool2d_with_indices`` via the shared n-d helper with two spatial dims."""
    return _max_pool_with_indices(
        x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
        n_dim=2,
    )
# Fallback when we do not decompose to the low-memory path.
@register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None)
def max_pool3d_with_indices(
    x,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """Lower ``aten.max_pool3d_with_indices`` via the shared n-d helper with three spatial dims."""
    return _max_pool_with_indices(
        x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
        n_dim=3,
    )
# Fallback handler for aten.max_pool2d_with_indices_backward. The lowering of
# that op in this file calls it when dilation != 1 or when the unrolled
# window would exceed 25 elements.
fallback_max_pool2d_with_indices_backward = fallback_handler(
    aten.max_pool2d_with_indices_backward.default,
    add_to_fallback_set=False,
)
@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
def max_pool2d_with_indices_backward(
    grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Lower ``aten.max_pool2d_with_indices_backward`` as a pointwise gather.

    For every input position, the pooled outputs whose window may cover it
    are visited (a fixed, unrolled number of candidates) and ``grad_output``
    is accumulated from those whose recorded ``indices`` entry equals this
    position's flat index ``h * width + w``.

    Falls back to the ATen kernel when any dilation is not 1 or when the
    number of unrolled candidates exceeds 25. If either ``x`` or
    ``grad_output`` has stride 1 in dim 1, the result is passed through
    ``require_channels_last``.
    """
    # Expand scalar/empty defaults to per-dimension lists.
    if padding == 0:
        padding = [0, 0]
    if dilation == 1:
        dilation = [1, 1]
    if not stride:
        stride = kernel_size

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(dilation) == 2
    assert len(x.get_size()) in (3, 4)

    # we will read this many times, so make sure it is computed
    grad_output.realize_hint()
    gO_stride = grad_output.maybe_get_stride()
    x_stride: Optional[Sequence[Any]]
    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):  # type: ignore[attr-defined]
        # x is an unrealized pointwise op: wrap it in a throwaway buffer with
        # a flexible layout just to find out which strides would be chosen.
        data = x.data.data  # type: ignore[attr-defined]
        device = data.get_device()
        assert device is not None
        x_buffer = ir.ComputedBuffer(
            name=None,
            layout=ir.FlexibleLayout(
                device=device,
                dtype=data.get_dtype(),
                size=data.get_size(),
            ),
            data=data,
        )
        x_buffer.decide_layout()
        x_stride = x_buffer.get_stride()
    else:
        x_stride = x.maybe_get_stride()

    # Treated as channels-last when dim 1 has unit stride in x or grad_output.
    is_channels_last = (x_stride is not None and x_stride[1] == 1) or (
        gO_stride is not None and gO_stride[1] == 1
    )
    if any(d != 1 for d in dilation):
        # dilation NYI
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    *_batch, _height, width = x.get_size()
    *_, pooled_height, pooled_width = grad_output.get_size()

    indices_loader = indices.make_loader()
    grad_loader = grad_output.make_loader()
    new_size = list(x.get_size())

    # Largest number of pooled rows (resp. columns) visited for one input
    # position; evaluated over positions 0 .. 2 * kernel - 1 and at least 1.
    h_window_size = max(
        max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1)
        for h in range(kernel_size[0] * 2)
    )
    w_window_size = max(
        max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1)
        for w in range(kernel_size[1] * 2)
    )

    window_size = h_window_size * w_window_size

    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    indices_size = indices.get_size()

    def fn(idx):
        *prefix, h, w = idx
        # Flat position of this input element, compared against stored indices.
        index_test = ops.index_expr(h * width + w, torch.int32)
        # Shift into padded coordinates.
        h = h + padding[0]
        w = w + padding[1]
        # [phstart, phend) x [pwstart, pwend): pooled outputs whose window
        # covers (h, w), clamped to the pooled output extent below.
        phstart = ops.index_expr(
            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_ in range(h_window_size):
            for pw_ in range(w_window_size):
                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
                # Clamp to the last valid candidate so the load stays in
                # range; candidates past the end are masked out below.
                grad_index = [
                    *prefix,
                    ops.indirect_indexing(
                        ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))),
                        indices_size[-2],
                        check=False,
                    ),
                    ops.indirect_indexing(
                        ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))),
                        indices_size[-1],
                        check=False,
                    ),
                ]

                index_actual = indices_loader(grad_index)
                grad_part = grad_loader(grad_index)
                check = ops.eq(index_actual, index_test)

                if gradient is None:
                    # don't need mask for 0, 0
                    gradient = ops.where(
                        check, grad_part, ops.constant(0.0, torch.float32)
                    )
                else:
                    # Accumulate only for in-range candidates whose argmax
                    # is this input position.
                    mask = ops.and_(
                        ops.and_(
                            ops.lt(ph, phend),
                            ops.lt(pw, pwend),
                        ),
                        check,
                    )
                    gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
        assert gradient is not None
        return gradient

    out = Pointwise.create(
        device=grad_output.get_device(),
        dtype=grad_output.get_dtype(),
        inner_fn=fn,
        ranges=new_size,
    )
    if is_channels_last:
        return ir.ExternKernel.require_channels_last(out)
    else:
        return out
def pad_adaptive_loader(x, pad_val=0.0):
    """Return a loader for adaptive-pooling windows that pads past the window end.

    The returned ``load(prefix, increments, start_indices, end_indices)``
    reads ``x`` at ``start + increment`` in the last two dims when that
    position is strictly below the matching end index in both dims, and
    yields ``pad_val`` otherwise.
    """
    load_input = x.make_loader()

    def load(prefix, increments, start_indices, end_indices):
        ih, iw = increments
        h_start, w_start = start_indices
        h_end, w_end = end_indices

        row_in_window = ops.lt(
            ops.index_expr(h_start + ih, torch.int64),
            ops.index_expr(h_end, torch.int64),
        )
        col_in_window = ops.lt(
            ops.index_expr(w_start + iw, torch.int64),
            ops.index_expr(w_end, torch.int64),
        )
        in_window = ops.and_(row_in_window, col_in_window)

        def read():
            return load_input([*prefix, h_start + ih, w_start + iw])

        return ops.masked(in_window, read, pad_val)

    return load
def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
    """Bind the start/end index functions to the height and width extents.

    Returns:
        ``(h_start, h_end, w_start, w_end)``: ``functools.partial`` objects
        with ``out_dim``/``inp_dim`` pre-filled, each taking the output index.
    """
    bound = [
        functools.partial(index_fn, out_dim=out_dim, inp_dim=inp_dim)
        for out_dim, inp_dim in ((h_out, h_in), (w_out, w_in))
        for index_fn in (start_index, end_index)
    ]
    return tuple(bound)
def _adaptive_pooling_fn(
    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
    """Build ``fn(idx, loader)`` that folds an adaptive-pooling window with ``pooling_fn``.

    The window is unrolled over ``kernel_maxes[0] x kernel_maxes[1]``
    increments. ``loader`` receives the prefix, the increment pair and the
    window start/end indices, and decides what out-of-window positions yield.
    """
    h_in, w_in = in_sizes
    h_out, w_out = out_sizes

    row_start, row_end, col_start, col_end = compute_indices_adaptive_pooling(
        start_index, end_index, h_in, w_in, h_out, w_out
    )

    def fn(idx, loader):
        *prefix, bh, bw = idx

        h_lo = row_start(bh)
        h_hi = row_end(bh)
        w_lo = col_start(bw)
        w_hi = col_end(bw)

        h_steps = range(kernel_maxes[0])
        w_steps = range(kernel_maxes[1])

        acc = None
        for ih in h_steps:
            for iw in w_steps:
                val = loader(prefix, [ih, iw], [h_lo, w_lo], [h_hi, w_hi])
                acc = val if acc is None else pooling_fn(val, acc)
        return acc

    return fn
def _adaptive_pooling_fn_with_idx(
    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
    """Build ``fn(idx, loader)`` returning the flat input index of the window's best value.

    A candidate replaces the current best only when its value is strictly
    greater (``ops.gt``). The running value is folded with ``pooling_fn``.
    The index is ``row * w_in + col`` as int64.
    """
    h_in, w_in = in_sizes
    h_out, w_out = out_sizes

    row_start, row_end, col_start, col_end = compute_indices_adaptive_pooling(
        start_index, end_index, h_in, w_in, h_out, w_out
    )

    def fn(idx, loader):
        *prefix, bh, bw = idx

        h_lo = row_start(bh)
        h_hi = row_end(bh)
        w_lo = col_start(bw)
        w_hi = col_end(bw)

        h_steps = range(kernel_maxes[0])
        w_steps = range(kernel_maxes[1])

        best_val = None
        best_index = None
        for ih in h_steps:
            for iw in w_steps:
                val = loader(prefix, [ih, iw], [h_lo, w_lo], [h_hi, w_hi])
                index = ops.index_expr((h_lo + ih) * w_in + w_lo + iw, torch.int64)

                # Update the index first: it compares against the previous best value.
                best_index = (
                    index
                    if best_index is None
                    else ops.where(ops.gt(val, best_val), index, best_index)
                )
                best_val = val if best_val is None else pooling_fn(val, best_val)
        return best_index

    return fn
# Fallback handler for aten._adaptive_avg_pool2d; the lowering of that op in
# this file calls it when the unrolled window would exceed 25 elements.
fallback_adaptive_avg_pool2d = fallback_handler(
    aten._adaptive_avg_pool2d.default, add_to_fallback_set=False
)
@register_lowering(aten._adaptive_avg_pool2d)
def _adaptive_avg_pool2d(x, output_size):
    """Lower ``aten._adaptive_avg_pool2d``.

    Shortcuts, in order: same input/output size returns a clone; a zero
    output dim returns an empty tensor; evenly divisible sizes use
    ``avg_pool2d``; a window bound above 25 elements uses the ATen fallback.
    Otherwise each output is the window sum divided by the number of
    in-window elements.

    Raises:
        RuntimeError: if ``x`` is int64.
    """
    if x.get_dtype() == torch.int64:
        # not supported in eager
        raise RuntimeError("'adaptive_avg_pool2d' not implemented for 'Long'")
    assert isinstance(x, TensorBox)
    assert len(output_size) == 2
    x.realize_hint()

    *batch, raw_h, raw_w = x.get_size()
    h_in = V.graph.sizevars.guard_int(raw_h)
    w_in = V.graph.sizevars.guard_int(raw_w)
    h_out, w_out = output_size

    # no-op if the same input and output
    if h_in == h_out and w_in == w_out:
        return clone(x)

    if h_out == 0 or w_out == 0:
        return empty(
            [*batch, h_out, w_out], dtype=x.get_dtype(), device=x.get_device()
        )

    if h_in % h_out == 0 and w_in % w_out == 0:
        return avg_pool2d(x, [FloorDiv(h_in, h_out), FloorDiv(w_in, w_out)])

    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
    if h_kernel_max * w_kernel_max > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_adaptive_avg_pool2d(x, output_size)

    def start_index(index, out_dim, inp_dim):
        return FloorDiv((index * inp_dim), out_dim)

    def end_index(index, out_dim, inp_dim):
        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)

    window_sum = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.add,
    )

    # Summing a tensor of ones over the same window counts its valid elements.
    ones_loader = pad_adaptive_loader(ones_like(x))

    def mean_fn(idx):
        total = window_sum(idx, pad_adaptive_loader(x))
        count = window_sum(idx, ones_loader)
        return ops.truediv(total, count)

    # TODO: should we force these to be realized?
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=mean_fn,
        ranges=[*batch, h_out, w_out],
    )
# Fallback handler for aten.adaptive_max_pool2d; the lowering of that op in
# this file calls it when the unrolled window would exceed 25 elements.
fallback_adaptive_max_pool2d = fallback_handler(
    aten.adaptive_max_pool2d.default, add_to_fallback_set=False
)
@register_lowering(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(x, output_size):
    """Lower ``aten.adaptive_max_pool2d`` to ``(values, indices)``.

    A zero output dim returns a pair of empty tensors. A window bound above
    25 elements uses the ATen fallback. Otherwise both outputs are unrolled
    pointwise ops over the adaptive window, with ``-inf`` padding.

    Raises:
        RuntimeError: if ``x`` is int64.
        ValueError: if the input sizes are evenly divisible by the output
            sizes (that case is handled by a decomposition).
    """
    if x.get_dtype() == torch.int64:
        # not supported in eager
        raise RuntimeError("adaptive_max_pool2d not implemented for Long")
    assert isinstance(x, TensorBox)
    assert len(output_size) == 2
    x.realize_hint()

    *batch, raw_h, raw_w = x.get_size()
    h_in = V.graph.sizevars.guard_int(raw_h)
    w_in = V.graph.sizevars.guard_int(raw_w)
    h_out, w_out = output_size

    if h_out == 0 or w_out == 0:
        o_size = [*batch, h_out, w_out]
        values = empty(o_size, dtype=x.get_dtype(), device=x.get_device())
        positions = empty(o_size, dtype=torch.int64, device=x.get_device())
        return values, positions

    if h_in % h_out == 0 and w_in % w_out == 0:
        # This is handled by a decomposition
        raise ValueError

    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
    if h_kernel_max * w_kernel_max > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_adaptive_max_pool2d(x, output_size)

    def start_index(index, out_dim, inp_dim):
        return FloorDiv((index * inp_dim), out_dim)

    def end_index(index, out_dim, inp_dim):
        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)

    window_max = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.maximum,
    )
    window_argmax = _adaptive_pooling_fn_with_idx(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.maximum,
    )

    def values_fn(idx):
        return window_max(idx, pad_adaptive_loader(x, float("-inf")))

    def indices_fn(idx):
        return window_argmax(idx, pad_adaptive_loader(x, float("-inf")))

    out_size = [*batch, h_out, w_out]
    rv = Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=values_fn,
        ranges=out_size,
    )
    ri = Pointwise.create(
        device=x.get_device(),
        dtype=torch.int64,
        inner_fn=indices_fn,
        ranges=out_size,
    )
    return rv, ri
def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
    """Build a loader giving the window start along spatial dimension ``dim``.

    Args:
        samples: random-samples tensor, loaded at last index ``ndims - 1 - dim``.
        in_sz, out_sz, kernel_sz: per-dimension sizes; only entry ``dim`` is used.
        dim: spatial dimension this loader handles.
        ndims: number of spatial dimensions.

    Returns:
        ``load(prefix, i)``: an indirect index (bounded by the input size)
        for the start of output window ``i``. Every ``i`` below
        ``out_sz - 1`` gets ``trunc((i + sample) * alpha) - trunc(sample * alpha)``;
        the remaining window is pinned to ``in_sz - kernel_sz``.
    """
    out_sz = out_sz[dim]
    in_sz = in_sz[dim]
    kernel_sz = kernel_sz[dim]
    samples_loader = samples.make_loader()

    def load(prefix, i):
        # Handle indexing for samples tensor correctly for different input dimensions
        # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where:
        # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W)
        # - C=num_channels
        # - 2 for the two spatial dimensions (height, width)
        # NOTE(review): this helper is also used with ndims=3; presumably the
        # last dim is then 3 -- confirm against the fractional_max_pool3d caller.
        samples_shape = samples.get_size()

        if len(samples_shape) == 3:  # Expected: (N, C, 2)
            if len(prefix) == 1:
                # 3D input case: prefix=(channel,), samples=(1, C, 2)
                # Access: samples[0, channel, dim]
                sample = samples_loader([0, prefix[0], ndims - 1 - dim])
            elif len(prefix) >= 2:
                # 4D+ input case: prefix=(batch, channel, ...), samples=(batch, C, 2)
                # Access: samples[batch, channel, dim]
                sample = samples_loader([prefix[0], prefix[1], ndims - 1 - dim])
            else:
                # Edge case - shouldn't happen for valid fractional pooling
                sample = samples_loader([0, 0, ndims - 1 - dim])
        else:
            # Fallback for unexpected tensor shapes
            sample = samples_loader([*prefix, ndims - 1 - dim])
        i_expr = ops.index_expr(i, samples.get_dtype())
        diff = ops.index_expr(in_sz - kernel_sz, torch.int64)
        out_sz_expr = ops.index_expr(out_sz - 1, torch.int64)
        # alpha = (in_sz - kernel_sz) / (out_sz - 1), computed in float64
        alpha = ops.truediv(
            ops.to_dtype(diff, torch.float64), ops.to_dtype(out_sz_expr, torch.float64)
        )
        # A single output window (out_sz - 1 == 0) would divide by zero; use 0.
        alpha = ops.where(ops.eq(out_sz_expr, 0), 0, alpha)
        seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
        seq_i = ops.to_dtype(seq_i, torch.int64)
        # The last window (i == out_sz - 1) always starts at in_sz - kernel_sz.
        mask = ops.lt(i_expr, out_sz_expr)
        return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))

    return load
@register_lowering(aten.fractional_max_pool2d)
def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
    """Lower ``aten.fractional_max_pool2d`` via the shared helper with two spatial dims."""
    return _fractional_max_pool(
        x,
        kernel_size=kernel_size,
        output_size=output_size,
        random_samples=random_samples,
        n_dim=2,
    )
@register_lowering(aten.fractional_max_pool3d)
def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
    """Lower ``aten.fractional_max_pool3d`` via the shared helper with three spatial dims."""
    return _fractional_max_pool(
        x,
        kernel_size=kernel_size,
        output_size=output_size,
        random_samples=random_samples,
        n_dim=3,
    )
def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
    """Shared lowering for fractional max pooling over the last ``n_dim`` dims.

    Window starts come from ``_fractional_pooling_offsets``. A max and an
    argmax reduction run over each window, and the argmax offsets are turned
    into flat input indices.

    Returns:
        ``(result, indices)``.
    """
    x.realize_hint()
    batch = x.shape[:-n_dim]
    inp_dhw = x.shape[-n_dim:]

    with config.patch(unroll_reductions_threshold=25):
        # One window-start function per spatial dimension.
        start_fns = []
        for d in range(n_dim):
            start_fns.append(
                _fractional_pooling_offsets(
                    samples=random_samples,
                    in_sz=inp_dhw,
                    out_sz=output_size,
                    kernel_sz=kernel_size,
                    ndims=n_dim,
                    dim=d,
                )
            )

        x_loader = x.make_loader()

        def increments_to_index(idx, reduction_idx):
            lead = idx[:-n_dim]
            out_pos = idx[-n_dim:]
            return [
                start_fns[d](lead, out_pos[d]) + reduction_idx[d]
                for d in range(n_dim)
            ]

        def fn_inner(idx, reduction_idx):
            lead = idx[:-n_dim]
            return x_loader([*lead, *increments_to_index(idx, reduction_idx)])

        dtype = x.get_dtype()
        shared = dict(
            input_node=x,
            device=x.get_device(),
            src_dtype=dtype,
            inner_fn=fn_inner,
            ranges=[*batch, *output_size],
            reduction_ranges=kernel_size,
        )
        result = Reduction.create(reduction_type="max", dst_dtype=dtype, **shared)
        offsets = Reduction.create(
            reduction_type="argmax", dst_dtype=torch.int64, **shared
        )

        for node in (result, offsets):
            assert isinstance(node, TensorBox), node
            if isinstance(node.data.data, Reduction):  # type: ignore[attr-defined]
                # Only realize if reduction isn't unrolled
                node.realize()

        indices = _pool_offsets_to_indices(
            offsets, kernel_size, x.shape, increments_to_index
        )
        return result, indices
@register_lowering(aten.upsample_nearest2d_backward.default)
def upsample_nearest2d_backward(
    x, output_size=None, input_size=None, scales_h=None, scales_w=None
):
    """Lower ``aten.upsample_nearest2d_backward``.

    ``x`` is the incoming gradient and ``input_size`` is the shape of the
    produced gradient. When the gradient's spatial sizes divide evenly by
    the target's, this becomes a sum pool (``avg_pool2d`` with
    ``divisor_override=1``). Otherwise each output sums ``x`` over an
    unrolled window bounded by ceil-divided start/end indices.
    """
    x.realize_hint()

    *_lead, raw_h, raw_w = x.get_size()
    inp_h = V.graph.sizevars.guard_int(raw_h)
    inp_w = V.graph.sizevars.guard_int(raw_w)

    # pyrefly: ignore [not-iterable]
    *_lead, out_h, out_w = input_size

    if inp_h % out_h == 0 and inp_w % out_w == 0:
        window = [FloorDiv(inp_h, out_h), FloorDiv(inp_w, out_w)]
        return avg_pool2d(x, window, divisor_override=1)

    h_kernel_max = ceildiv(inp_h, out_h)
    w_kernel_max = ceildiv(inp_w, out_w)

    def start_index(index, out_dim, inp_dim):
        return CeilDiv(index * inp_dim, sympy.sympify(out_dim))

    def end_index(index, out_dim, inp_dim):
        # a window ends where the next one starts
        return start_index((index + 1), out_dim, inp_dim)

    window_sum = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[inp_h, inp_w],
        out_sizes=[out_h, out_w],
        pooling_fn=ops.add,
    )

    def grad_fn(idx):
        return window_sum(idx, pad_adaptive_loader(x))

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=grad_fn,
        # pyrefly: ignore [no-matching-overload]
        ranges=list(input_size),
    )
@register_lowering(aten.avg_pool2d, type_promotion_kind=None)
def avg_pool2d(
    x,
    kernel_size,
    stride=(),
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Lower ``aten.avg_pool2d`` by forwarding to ``_avg_poolnd`` with ``dim=2``."""
    return _avg_poolnd(
        x=x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        ceil_mode=ceil_mode,
        count_include_pad=count_include_pad,
        divisor_override=divisor_override,
        dim=2,
    )
@register_lowering(aten.avg_pool3d, type_promotion_kind=None)
def avg_pool3d(
    x,
    kernel_size,
    stride=(),
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """Lower ``aten.avg_pool3d`` by forwarding to ``_avg_poolnd`` with ``dim=3``."""
    return _avg_poolnd(
        x=x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        ceil_mode=ceil_mode,
        count_include_pad=count_include_pad,
        divisor_override=divisor_override,
        dim=3,
    )
# ATen fallback handlers for average pooling, ordered so that entry
# ``dim - 1`` serves the ``dim``-dimensional op (1d, 2d, 3d).
fallbacks_avg_poolnd = [
    fallback_handler(pool_op, add_to_fallback_set=False)
    for pool_op in (
        aten.avg_pool1d.default,
        aten.avg_pool2d.default,
        aten.avg_pool3d.default,
    )
]
def _avg_poolnd(
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
    dim,
):
    """Shared lowering for ``dim``-dimensional average pooling.

    Builds a "sum" ``Reduction`` over each pooling window and divides it by
    either a constant (``divisor_override`` or the full window size) or a
    per-output-element count computed by ``fn_count``. Windows larger than
    25 elements whose kernel size is statically known to differ from the
    stride in some dimension are routed to the ATen fallback instead.

    Args:
        x: Input ``TensorBox`` with ``dim + 1`` or ``dim + 2`` dimensions.
        kernel_size, stride, padding: Per-dimension pooling parameters;
            each is normalized to length ``dim`` via ``pad_listlike``. A
            falsy ``stride`` defaults to ``kernel_size`` and a falsy
            ``padding`` to all zeros.
        ceil_mode: Forwarded to ``pooling_size`` and to the fallback.
        count_include_pad: When false, the divisor computed by ``fn_count``
            is clamped to the unpadded input extent.
        divisor_override: If truthy, used as the divisor instead of any
            window-size based value.
        dim: Number of trailing spatial dimensions being pooled.

    Returns:
        The pooled result converted back to the dtype of ``x``.
    """
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0] * dim
    kernel_size = pad_listlike(kernel_size, dim)
    stride = pad_listlike(stride, dim)
    padding = pad_listlike(padding, dim)

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == dim
    assert len(stride) == dim
    assert len(padding) == dim
    assert len(x.get_size()) in (dim + 1, dim + 2)

    x.realize_hint()
    # Leading (non-pooled) dims and trailing (pooled) dims of the input.
    batch = x.get_size()[:-dim]
    h = x.get_size()[-dim:]

    # pooling_size returns a pair per dimension; transpose the list of pairs
    # into a tuple of output sizes and a tuple of per-dimension flags.
    h_out, ceil_modes = zip(
        *[
            pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode)
            for i in range(dim)
        ]
    )

    if any(padding) or any(ceil_modes):
        # Windows may reach outside the input; such reads yield 0.0.
        x_loader = constant_boundary_condition(x, 0.0, dim=dim)
        had_padding = True
    else:
        x_loader = x.make_loader()
        had_padding = False

    new_size = list(batch) + list(h_out)
    dtype = x.get_dtype()
    # compute in higher-precision until scaling
    output_dtype = get_promoted_dtype(
        x,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
        return_compute_dtype=True,
    )

    def fn_inner(idx, reduction_idx):
        # Map (output position, offset inside the window) to an input index.
        prefix = idx[:-dim]
        bh = idx[-dim:]
        ih = reduction_idx
        ih = [bh[i] * stride[i] + ih[i] - padding[i] for i in range(dim)]
        return x_loader([*prefix, *ih])

    window_size = functools.reduce(operator.mul, kernel_size)
    if window_size > 25 and any(
        V.graph.sizevars.statically_known_true(sympy.Ne(k, s))
        for k, s in zip(kernel_size, stride)
    ):
        # Large window with kernel != stride in some dimension: use ATen.
        fallback = fallbacks_avg_poolnd[dim - 1]
        return fallback(
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    # TODO: remove this when #100331 is merged. We only do this
    # for window_size <=25 to avoid performance regressions compared
    # to the previous algorithm which unrolled manually for <=25
    context = (
        config.patch(unroll_reductions_threshold=25)
        if window_size <= 25
        else contextlib.nullcontext()
    )

    device = x.get_device()
    assert device is not None

    with context:
        rv = Reduction.create(
            reduction_type="sum",
            input_node=x,
            device=device,
            dst_dtype=output_dtype,
            src_dtype=dtype,
            inner_fn=fn_inner,
            ranges=new_size,
            reduction_ranges=kernel_size,
        )
    if hasattr(rv.data, "data") and isinstance(rv.data.data, Reduction):
        # Only realize if reduction isn't unrolled
        rv.realize()

    if not had_padding or divisor_override:
        # Every window has the same divisor, so divide by a constant.
        divisor = divisor_override if divisor_override else window_size
        result = div_prim(rv, divisor)
    else:

        def fn_count(idx):
            # Per-output divisor: product over dimensions of the window
            # extent, clipped to the padded input (and additionally to the
            # unpadded input when count_include_pad is false).
            bh = idx[-dim:]

            divide_factors = []
            for i in range(dim):
                hstart = bh[i] * stride[i] - padding[i]
                hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i])
                if not count_include_pad:
                    hstart = sympy.Max(hstart, 0)
                    hend = sympy.Min(hend, h[i])
                factor = ops.index_expr(hend - hstart, torch.int32)
                divide_factors.append(factor)
            return functools.reduce(ops.mul, divide_factors)

        divide_factor = Pointwise.create(
            device=x.get_device(),
            dtype=dtype,
            inner_fn=fn_count,
            ranges=new_size,
        )
        result = div_prim(rv, divide_factor)

    return to_dtype(result, dtype)
# ATen fallback for avg_pool2d_backward, used by the lowering below when the
# unrolled window would be too large.
fallback_avg_pool2d_backward = fallback_handler(
    aten.avg_pool2d_backward.default, add_to_fallback_set=False
)
@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
def avg_pool2d_backward(
    grad_output,
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Lower ``aten.avg_pool2d_backward`` to an unrolled gather.

    Produces a ``Pointwise`` node with the shape and dtype of ``x``. For
    each input position the body visits up to
    ``h_window_size * w_window_size`` positions of ``grad_output``, divides
    each loaded value by the pooling divisor and accumulates the ones that
    pass the range mask. When that window count exceeds 25 the ATen
    fallback is used instead.

    ``x`` is only consulted for its size and dtype. A falsy ``stride``
    defaults to ``kernel_size`` and a falsy ``padding`` to ``[0, 0]``.
    """
    assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0, 0]

    assert isinstance(grad_output, TensorBox)
    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(x.get_size()) in (3, 4)

    grad_output.realize_hint()  # we will read this many times, so make sure it is computed

    *_, height, width = x.get_size()

    # Only the second element of each pooling_size result is used here.
    _h_out, ceil_mode1 = pooling_size(
        height, 0, kernel_size, stride, padding, ceil_mode
    )
    _w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode)

    grad_loader = grad_output.make_loader()

    had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2

    *_, pooled_height, pooled_width = grad_output.get_size()
    new_size = list(x.get_size())
    dtype = x.get_dtype()

    # Static unroll counts: the maximum over offsets in [0, 2 * kernel) of
    # FloorDiv(o, stride) - max(0, FloorDiv(o - kernel, stride)), at least 1.
    # Presumably the most pooling windows that can overlap one input
    # row/column -- TODO confirm.
    h_window_size = max(
        max(FloorDiv(h, stride[0]) - max(0, FloorDiv(h - kernel_size[0], stride[0])), 1)
        for h in range(kernel_size[0] * 2)
    )
    w_window_size = max(
        max(FloorDiv(w, stride[1]) - max(0, FloorDiv(w - kernel_size[1], stride[1])), 1)
        for w in range(kernel_size[1] * 2)
    )

    window_size = h_window_size * w_window_size
    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_avg_pool2d_backward(
            grad_output,
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    def compute_pool_size_without_padding(ph, pw):
        """
        This computes the scaling factor that we will divide an element
        by when `count_include_pad=False`
        """
        stride_h = ops.constant(stride[0], torch.int32)
        stride_w = ops.constant(stride[1], torch.int32)
        pad_h = ops.constant(padding[0], torch.int32)
        pad_w = ops.constant(padding[1], torch.int32)
        kernel_h = ops.constant(kernel_size[0], torch.int32)
        kernel_w = ops.constant(kernel_size[1], torch.int32)
        # Window start/end in input coordinates, first clipped to the padded
        # extent and then to the unpadded extent [0, height) x [0, width).
        hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
        wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
        hend = ops.minimum(
            ops.add(hstart, kernel_h),
            ops.add(ops.index_expr(height, torch.int32), pad_h),
        )
        wend = ops.minimum(
            ops.add(wstart, kernel_w),
            ops.add(ops.index_expr(width, torch.int32), pad_w),
        )
        hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
        wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
        hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
        wend = ops.minimum(wend, ops.index_expr(width, torch.int32))
        divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))
        return divide_factor

    def fn(idx):
        *prefix, h, w = idx
        # Shift into padded input coordinates.
        h = h + padding[0]
        w = w + padding[1]
        # Range [start, end) of pooled indices visited for this input
        # position, clamped to [0, pooled_height] x [0, pooled_width].
        phstart = ops.index_expr(
            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_ in range(h_window_size):
            for pw_ in range(w_window_size):
                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))

                # Divisor selection: explicit override, else the full kernel
                # area, else the per-window area without padding.
                if divisor_override is not None:
                    scale = divisor_override
                elif count_include_pad or not had_padding:
                    scale = kernel_size[0] * kernel_size[1]
                else:
                    scale = compute_pool_size_without_padding(ph, pw)

                # The load index is clamped to end - 1; iterations that fall
                # outside [start, end) are discarded by ``mask`` below.
                part = ops.truediv(
                    grad_loader(
                        [
                            *prefix,
                            ops.indirect_indexing(
                                ops.minimum(
                                    ph, ops.sub(phend, ops.constant(1, torch.int32))
                                ),
                                pooled_height,
                                check=False,
                            ),
                            ops.indirect_indexing(
                                ops.minimum(
                                    pw, ops.sub(pwend, ops.constant(1, torch.int32))
                                ),
                                pooled_width,
                                check=False,
                            ),
                        ]
                    ),
                    scale,
                )

                mask = ops.and_(
                    ops.lt(ph, phend),
                    ops.lt(pw, pwend),
                )
                if gradient is None:
                    gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
                else:
                    gradient = ops.where(mask, ops.add(gradient, part), gradient)
        assert gradient is not None
        return gradient

    rv = Pointwise.create(
        device=grad_output.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
    return rv
# ATen fallback for avg_pool3d_backward, used by the lowering below when the
# unrolled window would be too large.
fallback_avg_pool3d_backward = fallback_handler(
    aten.avg_pool3d_backward.default, add_to_fallback_set=False
)
@register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
def avg_pool3d_backward(
    grad_output,
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Lower ``aten.avg_pool3d_backward`` to an unrolled gather.

    Produces a ``Pointwise`` node with the shape and dtype of ``x``. For
    each input position the body visits up to
    ``d_window_size * h_window_size * w_window_size`` positions of
    ``grad_output``, divides each loaded value by the pooling divisor and
    accumulates the ones that pass the range mask. When that window count
    exceeds 125 the ATen fallback is used instead.

    ``x`` is only consulted for its size and dtype. A falsy ``stride``
    defaults to ``kernel_size`` and a falsy ``padding`` to ``[0, 0, 0]``.
    """
    assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0, 0, 0]

    assert isinstance(grad_output, TensorBox)
    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 3
    assert len(stride) == 3
    assert len(padding) == 3
    assert len(x.get_size()) in (4, 5)

    grad_output.realize_hint()

    *_batch, depth, height, width = x.get_size()

    # Only the second element of each pooling_size result is used here.
    _d_out, ceil_mode_d = pooling_size(
        depth, 0, kernel_size, stride, padding, ceil_mode
    )
    _h_out, ceil_mode_h = pooling_size(
        height, 1, kernel_size, stride, padding, ceil_mode
    )
    _w_out, ceil_mode_w = pooling_size(
        width, 2, kernel_size, stride, padding, ceil_mode
    )

    grad_loader = grad_output.make_loader()
    had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w

    *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
    new_size = list(x.get_size())
    dtype = x.get_dtype()

    # Static unroll counts per dimension: the maximum over offsets in
    # [0, 2 * kernel) of o // stride - max(0, (o - kernel) // stride), at
    # least 1. Presumably the most pooling windows that can overlap one
    # input position along that dimension -- TODO confirm.
    d_window_size, h_window_size, w_window_size = (
        max(
            max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
            for d in range(kernel_size[i] * 2)
        )
        for i in range(3)
    )

    window_size = d_window_size * h_window_size * w_window_size
    if window_size > 125:
        # Kernel size too big. Results in hard-to-optimize Triton code.
        return fallback_avg_pool3d_backward(
            grad_output,
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    def compute_pool_size_without_padding(pd, ph, pw):
        # Volume of the pooling window at pooled index (pd, ph, pw) after
        # clipping to the padded extent and then to the unpadded input.
        stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
        pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
        kernel_d, kernel_h, kernel_w = (
            ops.constant(k, torch.int32) for k in kernel_size
        )

        dstart, hstart, wstart = (
            ops.sub(ops.mul(p, s), pad)
            for p, s, pad in zip(
                [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
            )
        )
        dend, hend, wend = (
            ops.minimum(
                ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
            )
            for start, k, dim, pad in zip(
                [dstart, hstart, wstart],
                [kernel_d, kernel_h, kernel_w],
                [depth, height, width],
                [pad_d, pad_h, pad_w],
            )
        )
        dstart, hstart, wstart = (
            ops.maximum(start, ops.constant(0, torch.int32))
            for start in [dstart, hstart, wstart]
        )
        dend, hend, wend = (
            ops.minimum(end, ops.index_expr(dim, torch.int32))
            for end, dim in zip([dend, hend, wend], [depth, height, width])
        )
        divide_factor = ops.mul(
            ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
        )
        return divide_factor

    def fn(idx):
        *prefix, d, h, w = idx
        # Shift into padded input coordinates.
        d, h, w = (v + pad for v, pad in zip([d, h, w], padding))

        # Range [start, end) of pooled indices visited for this input
        # position, clamped to [0, pooled size] in each dimension.
        pdstart, phstart, pwstart = (
            ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
            for v, k, s in zip([d, h, w], kernel_size, stride)
        )

        pdend, phend, pwend = (
            ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
            for v, s in zip([d, h, w], stride)
        )

        pdstart, phstart, pwstart = (
            ops.maximum(pstart, ops.constant(0, torch.int32))
            for pstart in [pdstart, phstart, pwstart]
        )
        pdend, phend, pwend = (
            ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
            for pend, pooled_dim in zip(
                [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
            )
        )

        gradient = None
        # Iterate over the 3D region to accumulate gradients
        for pd_ in range(d_window_size):
            for ph_ in range(h_window_size):
                for pw_ in range(w_window_size):
                    pd, ph, pw = (
                        ops.add(pstart, ops.constant(p_, torch.int32))
                        for pstart, p_ in zip(
                            [pdstart, phstart, pwstart], [pd_, ph_, pw_]
                        )
                    )

                    # Divisor selection: explicit override, else the full
                    # kernel volume, else the per-window volume without
                    # padding.
                    if divisor_override is not None:
                        scale = divisor_override
                    elif count_include_pad or not had_padding:
                        scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
                    else:
                        scale = compute_pool_size_without_padding(pd, ph, pw)

                    # The load index is clamped to end - 1; iterations that
                    # fall outside [start, end) are discarded by ``mask``.
                    part = ops.truediv(
                        grad_loader(
                            [
                                *prefix,
                                ops.indirect_indexing(
                                    ops.minimum(
                                        pd, ops.sub(pdend, ops.constant(1, torch.int32))
                                    ),
                                    pooled_depth,
                                    check=False,
                                ),
                                ops.indirect_indexing(
                                    ops.minimum(
                                        ph, ops.sub(phend, ops.constant(1, torch.int32))
                                    ),
                                    pooled_height,
                                    check=False,
                                ),
                                ops.indirect_indexing(
                                    ops.minimum(
                                        pw, ops.sub(pwend, ops.constant(1, torch.int32))
                                    ),
                                    pooled_width,
                                    check=False,
                                ),
                            ]
                        ),
                        scale,
                    )

                    mask = ops.and_(
                        ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
                        ops.lt(pw, pwend),
                    )
                    if gradient is None:
                        gradient = ops.where(
                            mask, part, ops.constant(0.0, torch.float32)
                        )
                    else:
                        gradient = ops.where(mask, ops.add(gradient, part), gradient)
        assert gradient is not None
        return gradient

    rv = Pointwise.create(
        device=grad_output.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
    return rv
def _validate_reduction_axis(x, axis):
    """Normalize ``axis`` into a list of unique, non-negative dimensions.

    An ``int`` becomes a one-element list and a falsy ``axis`` selects every
    dimension of ``x``. For a zero-dimensional ``x`` only ``()``, ``(0,)``
    and ``(-1,)`` are accepted and an empty list is returned. Out-of-range
    or duplicated axes trip an assertion.
    """
    rank = len(x.get_size())
    if isinstance(axis, int):
        axis = [axis]
    elif not axis:
        axis = range(rank)

    if rank == 0:
        assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}"
        return []

    normalized = []
    for dim in axis:
        if dim < 0:
            # Negative axes count from the end.
            dim += rank
        assert 0 <= dim < rank
        normalized.append(dim)

    unique_count = len(OrderedSet(normalized))
    assert unique_count == len(normalized), "reduction axis not unique"
    return normalized
def _make_reduction_inner(
    x, *, axis, keepdims, dtype, override_return_dtype, reduction_type=None
):
    """Assemble the keyword arguments for ``Reduction.create``.

    Splits the dimensions of ``x`` into kept and reduced ones, builds a
    loader that maps (output index, reduction index) back to a full input
    index, and returns a dict with ``device``, ``dst_dtype``, ``src_dtype``,
    ``inner_fn``, ``ranges`` and ``reduction_ranges``.

    For multi-axis ``argmax``/``argmin`` on Triton with a permuted,
    transposed or non-contiguous input, the loader returns a
    ``(value, linear_index)`` tuple, where ``linear_index`` is the row-major
    position inside the reduced dimensions.
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    size = x.get_size()
    axis = OrderedSet[int](_validate_reduction_axis(x, axis))

    all_dims = range(len(size))
    kept_idx = [d for d in all_dims if d not in axis]
    reduced_idx = [d for d in all_dims if d in axis]
    kept_sizes = [size[d] for d in kept_idx]
    reduced_sizes = [size[d] for d in reduced_idx]

    # For argmax/argmin compute logical indices when the tensor has non-contiguous layout.
    should_compute_logical_index = False
    is_multi_axis_arg_reduction = (
        reduction_type in ("argmax", "argmin") and len(reduced_sizes) > 1
    )
    if is_multi_axis_arg_reduction and is_triton(x):
        if isinstance(x.data, PermuteView):
            should_compute_logical_index = True
        elif isinstance(x.data, ir.ReinterpretView) or (
            isinstance(x.data, ir.StorageBox) and isinstance(x.data.data, ir.Buffer)
        ):
            layout = x.get_layout()
            should_compute_logical_index = (
                layout.is_transposed() or not layout.is_contiguous()
            )

    def loader(index, reduction_index):
        assert len(reduction_index) == len(reduced_idx)
        if keepdims:
            # With keepdims the output index still has one entry per input
            # dimension; drop the entries of the reduced dimensions.
            assert len(index) == len(size)
            index = [index[d] for d in kept_idx]
        assert len(index) == len(kept_idx)

        full_index = [None] * (len(index) + len(reduction_index))
        for pos, var in zip(kept_idx, index):
            full_index[pos] = var
        for pos, var in zip(reduced_idx, reduction_index):
            full_index[pos] = var
        value = inner_loader(full_index)

        if not should_compute_logical_index:
            return value

        # Row-major linear index over the reduced dimensions, e.g. for
        # reduction_ranges = [4, 6]: linear_index = r0 * 6 + r1
        rindex = [sympy.expand(r) for r in reduction_index]
        linear_idx = rindex[0]
        for extent, r in zip(reduced_sizes[1:], rindex[1:]):
            linear_idx = linear_idx * extent + r
        return (value, ops.index_expr(linear_idx, torch.int64))

    if keepdims:
        new_size = [sympy.S.One if d in axis else s for d, s in enumerate(size)]
    else:
        new_size = kept_sizes

    inner_loader = x.make_loader()
    return {
        "device": x.get_device(),
        "dst_dtype": override_return_dtype or x.get_dtype(),
        "src_dtype": x.get_dtype(),
        "inner_fn": loader,
        "ranges": new_size,
        "reduction_ranges": reduced_sizes,
    }
def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
    """Return a lowering function that performs ``reduction_type`` over ``axis``.

    The returned callable builds the reduction through
    ``_make_reduction_inner`` and ``Reduction.create`` and realizes the
    result only when it is still a ``Reduction`` node.
    """

    def inner(x, axis=None, keepdims=False, *, dtype=None):
        create_kwargs = _make_reduction_inner(
            x,
            axis=axis,
            keepdims=keepdims,
            dtype=dtype,
            override_return_dtype=override_return_dtype,
            reduction_type=reduction_type,
        )
        reduced = Reduction.create(
            reduction_type=reduction_type, input_node=x, **create_kwargs
        )
        # Only realize if reduction isn't unrolled
        still_a_reduction = isinstance(
            reduced.data.data,  # type: ignore[attr-defined, attr-type, union-attr]
            Reduction,
        )
        if still_a_reduction:
            reduced.realize()
        return reduced

    return inner
def _make_scan_inner(x, *, axis, dtype):
    """Assemble the keyword arguments for a single-input ``ir.Scan.create``.

    ``x`` is first converted to ``dtype`` when one is given, and ``axis`` is
    validated through ``_validate_dim``.
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    scan_axis = _validate_dim(x, axis)

    return {
        "device": x.get_device(),
        "dtypes": (x.get_dtype(),),
        "inner_fns": (x.make_loader(),),
        "size": x.get_size(),
        "axis": scan_axis,
    }
@register_lowering(aten.mean)
def mean(x, axis=None, keepdim=False, *, dtype=None):
    """Lower ``aten.mean`` as a sum divided by the number of reduced elements.

    float16 and bfloat16 inputs are summed in float32 and converted back to
    the original dtype at the end.
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    size = x.get_size()
    axis = _validate_reduction_axis(x, axis)

    # compute in higher-precision until end of mean lowering
    output_dtype = x.get_dtype()
    if output_dtype in (torch.float16, torch.bfloat16):
        x = to_dtype(x, torch.float)

    total = sum_(x, axis, keepdim)

    # Element count of the reduced dims, broadcast to the shape of the sum.
    count = sympy_product(size[i] for i in axis)
    count_node = ir.IndexingConstant(
        index=count, dtype=x.get_dtype(), device=x.get_device()
    )
    count_node = ExpandView.create(count_node, list(total.get_size()))

    return to_dtype(div(total, count_node), output_dtype)
def var_mean_sum_(x, axis, correction, keepdim, return_mean):
    """Two-step variance: compute the mean, then sum the squared deviations.

    The divisor is the number of reduced elements minus ``correction``
    (``correction`` defaults to 1 when ``None``), clamped below at 0.
    Returns ``(var,)`` or ``(var, mean)`` depending on ``return_mean``.
    """
    if correction is None:
        correction = 1

    size = x.get_size()
    axis = _validate_reduction_axis(x, axis)

    x_mean = mean(x, axis, keepdim=True)
    if return_mean:
        x_mean.realize()

    squared_dev = square(sub(x, x_mean))
    dev_sum = sum_(squared_dev, axis, keepdim)

    count = sympy_product(size[i] for i in axis)
    if correction:
        count = sympy.Max(count - correction, 0)
    count_node = ir.IndexingConstant(
        index=count, dtype=x.get_dtype(), device=x.get_device()
    )
    count_node = ExpandView.create(count_node, list(dev_sum.get_size()))
    x_var = div(dev_sum, count_node)

    if return_mean:
        # The mean was computed with keepdim=True; drop the reduced dims if
        # the caller did not ask for them.
        if not keepdim:
            x_mean = squeeze(x_mean, axis)
        return x_var, x_mean
    return (x_var,)
def use_two_step_variance(x, axis, keepdim):
    """Decide whether variance should use the two-step sum formulation.

    True only when the reduction element count is a concrete integer below
    ``config.unroll_reductions_threshold`` and the product of the output
    ranges is not 1.
    """
    # Instead of unrolling welford, just unroll the simpler two-step var
    axis = _validate_reduction_axis(x, axis)
    reduction_kwargs = _make_reduction_inner(
        x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
    )
    out_ranges = reduction_kwargs["ranges"]
    reduction_numel = sympy_product(reduction_kwargs["reduction_ranges"])

    if not isinstance(reduction_numel, sympy.Integer):
        return False
    if not (int(reduction_numel) < config.unroll_reductions_threshold):
        return False
    return sympy_product(out_ranges) != 1
def var_mean_welford_(x, axis, *, correction, keepdim, return_mean):
    """Variance (and optionally mean) via ``ir.WelfordReduction``.

    The second output of the Welford reduction is divided by
    ``max(0, N - correction)``, where ``N`` is the number of reduced
    elements and ``correction`` defaults to 1 when ``None``. Returns
    ``(var,)`` or ``(var, mean)`` depending on ``return_mean``.
    """
    if correction is None:
        correction = 1

    create_kwargs = _make_reduction_inner(
        x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
    )
    loader = create_kwargs.pop("inner_fn")
    # WelfordReduction.create is given a single ``dtype`` instead.
    del create_kwargs["dst_dtype"]
    del create_kwargs["src_dtype"]

    welford_mean, welford_m2, _weight = ir.WelfordReduction.create(
        inner_fns=(loader,),
        reduction_type="welford_reduce",
        dtype=x.get_dtype(),
        **create_kwargs,
    )
    welford_m2.realize()

    dtype = x.get_dtype()
    size = x.get_size()
    axis = _validate_reduction_axis(x, axis)
    rnumel = sympy_product(size[i] for i in axis)

    def get_constant_or_index_expr(value, value_dtype):
        # Symbolic (non-numeric) expressions must go through index_expr;
        # plain numbers can be embedded as constants.
        is_symbolic = isinstance(value, sympy.Expr) and not value.is_number
        if is_symbolic:
            return ops.to_dtype(ops.index_expr(value, torch.int64), value_dtype)
        return ops.constant(value, value_dtype)

    def scale_fn(data):
        c = get_constant_or_index_expr(correction, dtype)
        N = get_constant_or_index_expr(rnumel, dtype)
        zero = ops.constant(0, dtype)
        return data / ops.maximum(zero, N - c)

    var = make_pointwise(scale_fn)(welford_m2)

    if not return_mean:
        return (var,)
    welford_mean.realize()
    return var, welford_mean
def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
    """Shared implementation behind the ``var`` and ``var_mean`` lowerings.

    Converts ``x`` to its computation dtype, picks the two-step or the
    Welford formulation, and converts every result back to the original
    dtype. Returns the variance alone when ``return_mean`` is false,
    otherwise the full tuple.
    """
    out_dtype = x.get_dtype()
    compute_dtype = get_computation_dtype(out_dtype)
    x = to_dtype(x, compute_dtype, copy=False)

    if use_two_step_variance(x, axis=axis, keepdim=keepdim):
        impl = var_mean_sum_
    else:
        impl = var_mean_welford_
    results = impl(
        x=x,
        axis=axis,
        correction=correction,
        keepdim=keepdim,
        return_mean=return_mean,
    )

    results = tuple(to_dtype(r, out_dtype, copy=False) for r in results)
    if return_mean:
        return results
    return results[0]
@register_lowering([aten.var, prims.var])
def var_(x, axis=None, *, correction=None, keepdim=False):
    """Lower ``aten.var`` / ``prims.var``; returns the variance only."""
    return var_mean_helper_(
        x,
        return_mean=False,
        axis=axis,
        correction=correction,
        keepdim=keepdim,
    )
@register_lowering(aten.var_mean)
def var_mean(x, axis=None, *, correction=None, keepdim=False):
    """Lower ``aten.var_mean``; returns the variance together with the mean."""
    return var_mean_helper_(
        x,
        return_mean=True,
        axis=axis,
        correction=correction,
        keepdim=keepdim,
    )
def pow_recursive(x, y, dtype):
    """Raise ``x`` to the integer power ``y`` by repeated squaring.

    Negative exponents are handled by taking the reciprocal of ``x`` first.
    ``y == 0`` yields the constant 1 of the given ``dtype`` and ``y == 1``
    returns ``x`` unchanged.
    """
    if y < 0:
        return pow_recursive(ops.reciprocal(x), -y, dtype)
    if y == 0:
        return ops.constant(1, dtype)
    if y == 1:
        return x

    half_power = pow_recursive(x, y // 2, dtype)
    squared = ops.mul(half_power, half_power)
    # An odd exponent needs one extra factor of x.
    return ops.mul(squared, x) if (y % 2) == 1 else squared
@make_pointwise
def pow_native(a, b):
    """Pointwise power that forwards directly to ``ops.pow``."""
    return ops.pow(a, b)
# ATen fallbacks for the three aten.pow overloads; the pow lowering below
# selects among them for integer dtypes.
fallback_pow_tensor_tensor = fallback_handler(
    aten.pow.Tensor_Tensor, add_to_fallback_set=False
)
fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False)
fallback_pow_tensor_scalar = fallback_handler(
    aten.pow.Tensor_Scalar, add_to_fallback_set=False
)
@register_lowering(aten.pow, broadcast=True)
def pow(a, b):
    """Lower ``aten.pow``.

    Special cases, in order: integral float exponents are re-dispatched as
    ints, an exponent of 0.5 becomes ``sqrt``, an exponent of 1 becomes
    ``clone``, small (or non-negative integer-dtype) int exponents are
    expanded via ``pow_recursive``, a base of 1 becomes ``full_like`` and a
    base of 2 with a float exponent tensor becomes ``exp2``. Remaining
    integer-dtype cases go to the ATen fallbacks; everything else uses
    ``pow_native``.
    """
    if isinstance(b, float):
        if b.is_integer():
            return pow(a, int(b))
        if b == 0.5:
            return sqrt(a)
    elif isinstance(b, int) and b == 1:
        return clone(a)

    # Type promotion ensures all tensor arguments have the same type
    dtype = next(t.get_dtype() for t in (a, b) if isinstance(t, ir.TensorBox))
    is_integer_pow = is_integer_dtype(dtype)

    # Optimize away small fixed powers, or for integers avoid falling back to ATen
    embed_exponent = isinstance(b, int) and (
        -32 < b < 32 or (is_integer_pow and b >= 0)
    )
    if embed_exponent:
        base_loader = a.make_loader()

        def fn(idx):
            return pow_recursive(base_loader(idx), b, a.get_dtype())

        return Pointwise.create(
            device=a.get_device(),
            dtype=a.get_dtype(),
            inner_fn=fn,
            ranges=a.get_size(),
        )

    base_is_number = isinstance(a, Number)
    if base_is_number:
        if a == 1:
            return full_like(b, 1)
        if a == 2 and is_float_dtype(b.get_dtype()):
            return exp2(b)

    if not is_integer_pow:
        return pow_native(a, b)

    # ops.pow doesn't work for integers
    if isinstance(a, Number):
        return fallback_pow_scalar(a, b)
    if isinstance(b, Number):
        return fallback_pow_tensor_scalar(a, b)
    return fallback_pow_tensor_tensor(a, b)
def mutate_to(changed, val, unsafe_alias=False):
    """Make ``changed`` hold the contents of ``val`` and return ``changed``.

    When ``changed`` is backed by a ``StorageBox`` that is neither an input
    buffer, a module buffer nor a ``NopKernel``, its data pointer is simply
    swapped to the realized ``val``. Otherwise the value is realized into
    the existing storage via ``MutationLayoutSHOULDREMOVE.realize_into``.
    """
    changed_data = changed.data if isinstance(changed, TensorBox) else changed
    if isinstance(val, TensorBox):
        val = val.data

    if not isinstance(val, ir.StorageBox):
        # introduce a copy to handle views
        copied = Pointwise.create(
            device=changed.get_device(),
            dtype=changed.get_dtype(),
            inner_fn=val.make_loader(),
            ranges=changed.get_size(),
        )
        assert isinstance(copied, (BaseView, MutableBox))
        val = copied.data
    assert isinstance(val, ir.StorageBox)

    can_swap_pointer = isinstance(changed_data, ir.StorageBox) and not (
        changed_data.is_input_buffer()
        # In AOTI, module parameters and buffers are not lifted as graph inputs
        or changed_data.is_module_buffer()
        or isinstance(changed_data.data, ir.NopKernel)
    )
    if not can_swap_pointer:
        ir.MutationLayoutSHOULDREMOVE.realize_into(
            val, changed_data, unsafe_alias=unsafe_alias
        )
        return changed

    # Fast path, just swing the data pointer
    val.realize()
    changed_data.data = val.data
    return changed
@register_lowering(aten.fill_)
def fill_(x, fill_value):
    """Lower ``aten.fill_`` by mutating ``x`` to a ``full_like`` tensor."""
    filled = full_like(x, fill_value)
    return mutate_to(x, filled)
@register_lowering(aten.copy_, type_promotion_kind=None)
def copy_(dst, src, non_blocking=False):
    """Lower ``aten.copy_``.

    ``src`` is moved to the device of ``dst``, converted to its dtype and
    expanded to its size before ``dst`` is mutated. ``non_blocking`` is
    accepted but not read.
    """
    if dst is src:
        # dst.copy_(dst) can happen from the reinplacing pass
        return dst

    on_device = to_device(src, dst.get_device())
    converted = to_dtype(on_device, dst.get_dtype())
    broadcast = expand(converted, dst.get_size())
    return mutate_to(dst, broadcast)
@make_pointwise
def floordiv(a, b):
    """Pointwise floor division that forwards directly to ``ops.floordiv``."""
    return ops.floordiv(a, b)
@make_pointwise
def truncdiv(a, b):
    """Pointwise truncating division that forwards directly to ``ops.truncdiv``."""
    return ops.truncdiv(a, b)
@register_lowering(aten.div, broadcast=True)
def div_mode(a, b, rounding_mode=None):
    """Lower ``aten.div`` with an optional ``rounding_mode``.

    ``"floor"`` and ``"trunc"`` use the dedicated integer ops when both
    operands are integers and otherwise round the true quotient. Any other
    ``rounding_mode`` yields plain ``div``.
    """
    both_integer = is_integer_type(a) and is_integer_type(b)
    both_boolean = is_boolean_type(a) and is_boolean_type(b)

    if rounding_mode not in ("floor", "trunc"):
        return div(a, b)

    # floordiv and truncdiv need special handling for integer tensors on Triton,
    # see the discussion at https://github.com/triton-lang/triton/issues/605
    if rounding_mode == "floor":
        assert not both_boolean, "floordiv operands can not be boolean at the same time"
        if both_integer:
            return floordiv(a, b)
        return floor(div(a, b))

    assert not both_boolean, "truncdiv operands can not be boolean at the same time"
    if both_integer:
        return truncdiv(a, b)
    return trunc(div(a, b))
@register_lowering([aten.mul], broadcast=True)
def mul(a, b):
    """Lower ``aten.mul``; two boolean operands become ``logical_and``."""
    if is_boolean_type(a) and is_boolean_type(b):
        return logical_and(a, b)

    pointwise_mul = make_pointwise(ops_wrapper(aten.mul.__name__))
    return pointwise_mul(a, b)
def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
    """Try convert an arbitrary IR node into an ir.Constant value

    Returns the ``ir.Constant`` when one can be extracted and ``None``
    otherwise.
    """
    # First peel off boxes and views to see whether the node already is an
    # ir.Constant. Optional step, but avoids unnecessary inner_fn evaluation.
    node = x
    while True:
        if isinstance(node, ir.MutableBox):
            node = node.data
        elif isinstance(node, ir.BaseView):
            node = node.unwrap_view()
        else:
            break
    if isinstance(node, ir.Constant):
        return node

    # If the unwrapped node is not an ir.Constant, try evaluating inner_fn
    # to see if the returned value is from an `ops.constant` call
    if not isinstance(node, ir.Loops):
        return None

    extractor = torch._inductor.ops_handler.ExtractConstantsHandler(node.get_device())
    with V.set_ops_handler(extractor):
        with patch.object(ir.FlexibleLayout, "allow_indexing", True):
            out = node.inner_fn(*node.inner_fn_args())

    assert isinstance(out, torch._inductor.virtualized.OpsValue)
    return out.value if isinstance(out.value, ir.Constant) else None
# NOTE: prims.div maps to a / b in C, so performs truncation division on
# integer inputs and true division for floating and complex inputs.
@register_lowering([prims.div], broadcast=True)
def div_prim(a, b):
    """Lower ``prims.div``.

    Integer/boolean operands use ``truncdiv``. On non-CPU devices a constant
    divisor is folded into a multiplication by its reciprocal. Everything
    else becomes a pointwise ``ops.truediv``.
    """
    is_integral = all(is_boolean_type(t) or is_integer_type(t) for t in [a, b])
    if is_integral:
        return truncdiv(a, b)

    # Disable CPU optimization to avoid precision issues.
    # see https://github.com/pytorch/pytorch/issues/157959
    divisor = get_constant_value(b)
    if divisor is not None and a.get_device().type != "cpu":
        # Replace divide by constant with multiply by reciprocal
        divisor_value = divisor.value
        reciprocal = (
            math.copysign(float("inf"), divisor_value)
            if divisor_value == 0
            else 1.0 / divisor_value
        )
        return mul(a, reciprocal)

    def fn(*args):
        return ops.truediv(*args)

    return make_pointwise(fn)(a, b)
@register_lowering(
    [aten.true_divide, aten.div.Tensor],
    broadcast=True,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def div(a, b):
    """Lower true division: promote constants int-to-float, then ``div_prim``."""
    promoted_a, promoted_b = promote_constants(
        (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    )
    return div_prim(promoted_a, promoted_b)
@register_lowering([aten.fmod, prims.fmod], broadcast=True)
def fmod(a, b):
    """Lower ``fmod``: ``ops.mod`` for integer/boolean ``a``, else ``ops.fmod``."""
    is_integral = is_boolean_type(a) or is_integer_type(a)

    def fn(lhs, rhs):
        # The choice is fixed by the dtype class of the first operand.
        if is_integral:
            return ops.mod(lhs, rhs)
        return ops.fmod(lhs, rhs)

    return make_pointwise(fn)(a, b)
@register_lowering([aten.sum, prims.sum])
def sum_(x, axis=None, keepdims=False, *, dtype=None):
    """Lower ``sum``; integer and boolean inputs accumulate in int64 by default."""
    if dtype is None:
        src_dtype = x.get_dtype()
        if is_integer_dtype(src_dtype) or is_boolean_dtype(src_dtype):
            dtype = torch.int64

    reducer = make_reduction("sum", override_return_dtype=dtype)
    return reducer(x, axis, keepdims, dtype=dtype)
# ATen fallbacks for the scan ops; each is used by the matching lowering
# below when ir.Scan.create returns None.
fallback_cumsum = fallback_handler(aten.cumsum.default)
fallback_cumprod = fallback_handler(aten.cumprod.default)
fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
fallback_cummax = fallback_handler(aten.cummax.default)
fallback_cummin = fallback_handler(aten.cummin.default)
@register_lowering(aten.cumsum)
def cumsum(x, axis=None, dtype=None):
    """Lower ``aten.cumsum`` as an ``ir.Scan`` with addition.

    Integer and boolean inputs default to int64. A zero-dimensional input
    is just a dtype-converted copy. Falls back to ATen when the scan cannot
    be created.
    """
    if dtype is None:
        src_dtype = x.get_dtype()
        if is_integer_dtype(src_dtype) or is_boolean_dtype(src_dtype):
            dtype = torch.int64

    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        dtype = dtype or x.get_dtype()
        return to_dtype(x, dtype, copy=True)

    def combine_fn(a_tuple, b_tuple):
        [lhs] = a_tuple
        [rhs] = b_tuple
        return (ops.add(lhs, rhs),)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    (result,) = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)
    if result is not None:
        return result
    return fallback_cumsum(x, dim=axis, dtype=dtype)
@register_lowering(aten.cumprod)
def cumprod(x, axis=None, dtype=None):
    """Lower ``aten.cumprod`` as an ``ir.Scan`` with multiplication.

    Integer and boolean inputs default to int64. A zero-dimensional input
    is just a dtype-converted copy. Falls back to ATen when the scan cannot
    be created.
    """
    if dtype is None:
        src_dtype = x.get_dtype()
        if is_integer_dtype(src_dtype) or is_boolean_dtype(src_dtype):
            dtype = torch.int64

    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        dtype = dtype or x.get_dtype()
        return to_dtype(x, dtype, copy=True)

    def combine_fn(a_tuple, b_tuple):
        [lhs] = a_tuple
        [rhs] = b_tuple
        return (ops.mul(lhs, rhs),)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    (result,) = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)
    if result is not None:
        return result
    return fallback_cumprod(x, dim=axis, dtype=dtype)
@register_lowering(aten.logcumsumexp)
def logcumsumexp(x, dim):
    """Lower ``aten.logcumsumexp`` as an ``ir.Scan`` with a log-add-exp combine.

    A zero-dimensional input is cloned. Falls back to ATen when the scan
    cannot be created.
    """

    def log_add_exp_helper(a_tuple, b_tuple):
        [lhs] = a_tuple
        [rhs] = b_tuple
        smaller = ops.minimum(lhs, rhs)
        larger = ops.maximum(lhs, rhs)
        # When both operands are the same infinity, the first operand is
        # passed through instead of evaluating the formula.
        use_formula = (smaller != larger) | (~ops.isinf(smaller))
        combined = ops.log1p(ops.exp(smaller - larger)) + larger
        return (ops.where(use_formula, combined, lhs),)

    dtype = x.get_dtype()
    if len(x.get_size()) == 0:
        assert dim in [0, -1]
        return clone(x)

    scan_kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
    (result,) = ir.Scan.create(**scan_kwargs, combine_fn=log_add_exp_helper)
    if result is not None:
        return result
    return fallback_logcumsumexp(x, dim=dim)
@register_lowering(aten.cummax, type_promotion_kind=None)
def cummax(x, axis=None):
    """Lower ``aten.cummax`` as a two-output ``ir.Scan`` (values, int64 indices).

    A zero-dimensional input returns a clone plus an ``empty_like`` int64
    tensor. Falls back to ATen when the scan cannot be created.
    """
    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        return clone(x), empty_like(x, dtype=torch.int64)

    dtype = x.get_dtype()
    combine_fn = ir.get_reduction_combine_fn(
        "argmax", dtype=dtype, arg_break_ties_left=False
    )

    def position_loader(idx):
        # The second scan input is the position along the scanned axis.
        return ops.index_expr(idx[axis], torch.int64)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    scan_kwargs["dtypes"] = (dtype, torch.int64)
    scan_kwargs["inner_fns"] = (x.make_loader(), position_loader)

    values, indices = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)  # type: ignore[arg-type]
    if values is not None:
        return values, indices
    return fallback_cummax(x, dim=axis)
@register_lowering(aten.cummin, type_promotion_kind=None)
def cummin(x, axis=None):
    """Lower ``aten.cummin`` to a two-output ``ir.Scan`` (values, int64 indices).

    A zero-dimensional input is returned as a clone paired with an
    ``empty_like`` int64 tensor.  When ``ir.Scan.create`` yields no values the
    call is routed to ``fallback_cummin``.
    """
    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        return clone(x), empty_like(x, dtype=torch.int64)

    value_dtype = x.get_dtype()
    argmin_combine = ir.get_reduction_combine_fn(
        "argmin", dtype=value_dtype, arg_break_ties_left=False
    )

    def position_along_axis(idx):
        # Second scan input: each element's own position along the scanned axis.
        return ops.index_expr(idx[axis], torch.int64)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=value_dtype)
    scan_kwargs["dtypes"] = (value_dtype, torch.int64)
    scan_kwargs["inner_fns"] = (x.make_loader(), position_along_axis)

    values, indices = ir.Scan.create(**scan_kwargs, combine_fn=argmin_combine)  # type: ignore[arg-type]
    if values is None:
        return fallback_cummin(x, dim=axis)
    return values, indices
@register_lowering(aten.prod)
def prod(x, axis=None, keepdims=False, *, dtype=None):
    """Lower ``aten.prod`` through ``make_reduction("prod")``.

    Integer and boolean inputs default to an ``int64`` result when ``dtype``
    is not given.
    """
    input_is_integral = is_integer_dtype(x.get_dtype()) or is_boolean_dtype(
        x.get_dtype()
    )
    if input_is_integral and dtype is None:
        dtype = torch.int64
    reduce_prod = make_reduction("prod", override_return_dtype=dtype)
    return reduce_prod(x, axis, keepdims, dtype=dtype)
@register_lowering(aten.any)
def reduce_any(x, dim=None, keepdim=False):
    """Lower ``aten.any``: cast the input to bool, then run the "any" reduction."""
    as_bool = to_dtype(x, torch.bool)
    any_reduction = make_reduction("any")
    return any_reduction(as_bool, axis=dim, keepdims=keepdim)
@register_lowering(aten.max, type_promotion_kind=None)
def reduce_max(x, dim=None, keepdim=False):
    """Lower ``aten.max``.

    Without ``dim`` only the maximum is returned; with ``dim`` the result is a
    ``(values, indices)`` pair built from ``reduce_amax`` and ``reduce_argmax``.
    """
    if dim is None:
        return reduce_amax(x, axis=None, keepdims=keepdim)
    values = reduce_amax(x, axis=dim, keepdims=keepdim)
    indices = reduce_argmax(x, axis=dim, keepdims=keepdim)
    return (values, indices)
@register_lowering(aten.min, type_promotion_kind=None)
def reduce_min(x, dim=None, keepdim=False):
    """Lower ``aten.min``.

    Without ``dim`` only the minimum is returned; with ``dim`` the result is a
    ``(values, indices)`` pair built from ``reduce_amin`` and ``reduce_argmin``.
    """
    if dim is None:
        return reduce_amin(x, axis=None, keepdims=keepdim)
    values = reduce_amin(x, axis=dim, keepdims=keepdim)
    indices = reduce_argmin(x, axis=dim, keepdims=keepdim)
    return (values, indices)
# Reductions that need no custom wrapper are registered straight from
# make_reduction.
register_lowering(prims.xor_sum)(make_reduction("xor_sum"))
reduce_amax, reduce_amin = (
    register_lowering(aten_op)(make_reduction(kind))
    for aten_op, kind in ((aten.amax, "max"), (aten.amin, "min"))
)
# The arg-reductions have their return dtype overridden to int64.
reduce_argmax, reduce_argmin = (
    register_lowering(aten_op)(
        make_reduction(kind, override_return_dtype=torch.int64)
    )
    for aten_op, kind in ((aten.argmax, "argmax"), (aten.argmin, "argmin"))
)

add = register_pointwise(
    aten.add,
    allow_alpha=True,
    override_fn_when_input_bool="logical_or",
)

# Fallback used by sort_stable when it cannot lower the sort itself.
sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False)
@register_lowering(aten.sort.stable, type_promotion_kind=None)
def sort_stable(x, *, stable=None, dim=-1, descending=False):
    """Lower ``aten.sort.stable`` to ``ir.Sort``, returning ``(values, indices)``.

    Indices are generated as int16 and cast to int64 on return, so the
    lowering is only used when the sorted dimension is statically known to be
    smaller than the int16 maximum.  Otherwise, or when ``ir.Sort.create``
    yields no values, the call is routed to ``sort_fallback``.  A
    zero-dimensional input returns a clone and a zero index.
    """
    if stable is None:
        stable = False

    shape = x.get_size()
    device = x.get_device()
    dim = canonicalize_dim(len(shape), dim)
    if len(shape) == 0:
        return clone(x), _full(0, device, torch.int64, shape)

    # The rank is at least 1 from here on, so shape[dim] is always valid and
    # no further empty-shape guards are needed.
    dim_size = shape[dim]
    if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max):
        return sort_fallback(x, stable=stable, dim=dim, descending=descending)

    # Build the positions 0..dim_size-1 along `dim` and broadcast them to the
    # full input shape so they are sorted alongside the values.
    indices = iota(
        dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False
    )
    view_shape = [1] * len(shape)
    view_shape[dim] = dim_size
    indices = view(indices, view_shape)
    indices = expand(indices, shape)

    values, indices = ir.Sort.create(
        device=device,
        dtypes=(x.dtype, indices.dtype),
        inner_fns=(x.make_loader(), indices.make_loader()),
        size=shape,
        axis=dim,
        stable=stable,
        descending=descending,
    )
    if values is None:
        return sort_fallback(x, stable=stable, dim=dim, descending=descending)

    assert indices is not None
    return values, to_dtype(indices, torch.int64)
@register_lowering(aten.sort.default, type_promotion_kind=None)
def sort(x, dim=-1, descending=False):
    """Lower ``aten.sort.default`` by delegating to ``sort_stable`` with ``stable=False``."""
    return sort_stable(x, dim=dim, descending=descending, stable=False)
def register_pointwise_numeric(op, name=None, triton_fallback=None):
    """Register ``op`` as a pointwise lowering using INT_TO_FLOAT type promotion."""
    int_to_float = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    return register_pointwise(
        op,
        name=name,
        type_promotion_kind=int_to_float,
        triton_fallback=triton_fallback,
    )
def register_pointwise_numeric_ldf64(op: torch._ops.OpOverloadPacket):
    """Register ``op`` as an INT_TO_FLOAT pointwise lowering.

    The op's name is first passed to ``register_op_requires_libdevice_fp64``.
    """
    register_op_requires_libdevice_fp64(op.__name__)
    int_to_float = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    return register_pointwise(op, type_promotion_kind=int_to_float)
# INT_TO_FLOAT pointwise ops.
rsqrt, exp2, expm1 = (
    register_pointwise_numeric(aten_op)
    for aten_op in (aten.rsqrt, aten.exp2, aten.expm1)
)
# INT_TO_FLOAT pointwise ops registered through the libdevice-fp64 helper.
exp, sigmoid, sqrt = (
    register_pointwise_numeric_ldf64(aten_op)
    for aten_op in (aten.exp, aten.sigmoid, aten.sqrt)
)
# Pointwise ops with the default registration settings.
relu = register_pointwise(aten.relu)
square = register_pointwise(aten.square)
sub = register_pointwise(aten.sub, allow_alpha=True)
@register_lowering(aten.addcmul, broadcast=True)
def addcmul(self, tensor1, tensor2, *, value=1):
    """
    Lower ``aten.addcmul`` as ``self + value * (tensor1 * tensor2)``.

    The evaluation order is chosen to match the eager CUDA kernel.  For
    floating-point result dtypes on non-AMD builds the final step is an FMA,
    ``fma(value, tensor1 * tensor2, self)``, and the inner product uses
    ``mul_rn`` so that it is rounded before the FMA rather than being fused
    into it by Triton.  With ``value == 1`` the FMA is applied to the two
    tensors directly.  All other cases (integer dtypes, AMD) use plain
    multiply and add.

    Returns ``NotImplemented`` when ``config.emulate_precision_casts`` is
    False so that the decomposition is used instead.
    """
    if not config.emulate_precision_casts:
        return NotImplemented

    dtype = get_promoted_dtype(
        self,
        tensor1,
        tensor2,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )

    self_loader = self.make_loader()
    t1_loader = tensor1.make_loader()
    t2_loader = tensor2.make_loader()

    # FMA is restricted to floating-point dtypes on non-AMD GPUs.
    use_fma = dtype.is_floating_point and not torch.version.hip

    def inner_fn(idx):
        addend = self_loader(idx)
        lhs = t1_loader(idx)
        rhs = t2_loader(idx)

        if value == 1 and use_fma:
            return ops.fma(lhs, rhs, addend)

        # Eager order: form tensor1 * tensor2 first.  mul_rn forces the
        # product to be rounded so it cannot be fused with the FMA below.
        multiply = ops.mul_rn if use_fma else ops.mul
        product = multiply(lhs, rhs)

        # Sympy expressions (e.g. from .item()) go through index_expr;
        # everything else becomes a constant.
        make_scalar = ops.index_expr if isinstance(value, sympy.Basic) else ops.constant
        scale = make_scalar(value, dtype)

        if use_fma:
            return ops.fma(scale, product, addend)
        # Integer (or AMD) path: ordinary multiply followed by add.
        return ops.add(addend, ops.mul(scale, product))

    return Pointwise.create(
        device=self.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=self.get_size(),
    )
def _foreach_addcmul_scalar(self, tensor1, tensor2, value=1):
    """
    Foreach variant of ``addcmul`` taking a scalar ``value``.

    The per-tensor argument triples are grouped with ``group_foreach_args``
    and each group is lowered through ``foreach_group_loop`` by calling
    ``addcmul``.  Returns ``NotImplemented`` when
    ``config.emulate_precision_casts`` is False so that the decomposition is
    used instead.
    """
    if not config.emulate_precision_casts:
        return NotImplemented

    node = V.graph.current_node
    realize_outputs = (
        len(node.users) == 0
        or node.target in inplace_foreach_ops
        or cur_node_has_non_foreach_users()
    )

    grouped_args = group_foreach_args(zip(self, tensor1, tensor2))

    def lower_one(tensors):
        return addcmul(*tensors, value=value)

    return foreach_group_loop(grouped_args, len(self), lower_one, realize_outputs)


_register_foreach_lowering(aten._foreach_addcmul.Scalar, _foreach_addcmul_scalar)
# Trigonometric ops registered through the libdevice-fp64 helper.
register_pointwise_numeric_ldf64(aten.cos)
register_pointwise_numeric_ldf64(aten.sin)

abs = register_pointwise(aten.abs)

# Bitwise ops; bitwise_not on bool inputs is overridden with "logical_not".
(
    bitwise_and,
    bitwise_left_shift,
    bitwise_or,
    bitwise_right_shift,
    bitwise_xor,
) = (
    register_pointwise(aten_op)
    for aten_op in (
        aten.bitwise_and,
        aten.bitwise_left_shift,
        aten.bitwise_or,
        aten.bitwise_right_shift,
        aten.bitwise_xor,
    )
)
bitwise_not = register_pointwise(
    aten.bitwise_not,
    override_fn_when_input_bool="logical_not",
)

register_pointwise_numeric(aten.lgamma)

# aten.special_erf shares the erf lowering.
erf = register_pointwise_numeric(aten.erf)
register_lowering(
    aten.special_erf,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)(erf)

register_pointwise_numeric(aten.log1p)
register_pointwise_numeric(aten.tan)
register_pointwise_numeric(aten.tanh)
register_pointwise_numeric_ldf64(aten.log)
# The four logical ops share one configuration: no type promotion, inputs
# converted to bool, and a bool return dtype.
logical_and, logical_not, logical_or, logical_xor = (
    register_pointwise(
        aten_op,
        type_promotion_kind=None,
        convert_input_to_bool=True,
        override_return_dtype=torch.bool,
    )
    for aten_op in (
        aten.logical_and,
        aten.logical_not,
        aten.logical_or,
        aten.logical_xor,
    )
)
# clamp_min / clamp_max reuse the maximum / minimum lowerings.
maximum = register_pointwise(aten.maximum)
minimum = register_pointwise(aten.minimum)
register_lowering(aten.clamp_min)(maximum)
register_lowering(aten.clamp_max)(minimum)

neg = register_pointwise(aten.neg)
abs = register_pointwise(aten.abs)
reciprocal = register_pointwise_numeric(aten.reciprocal)
register_pointwise(aten.remainder)
sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity")
register_pointwise(aten.ceil)
register_pointwise(aten.signbit, override_return_dtype=torch.bool)
# aten._neg_view shares the neg lowering.
register_lowering(aten._neg_view)(neg)

# Comparisons always return bool.
register_pointwise(aten.le, override_return_dtype=torch.bool)
register_pointwise(aten.lt, override_return_dtype=torch.bool)
register_pointwise(aten.ge, override_return_dtype=torch.bool)
gt = register_pointwise(aten.gt, override_return_dtype=torch.bool)
register_pointwise(aten.eq, override_return_dtype=torch.bool)
register_pointwise(aten.ne, override_return_dtype=torch.bool)

# Remaining INT_TO_FLOAT pointwise ops, registered in the listed order.
for _numeric_op in (
    aten.cosh,
    aten.sinh,
    aten.acos,
    aten.acosh,
    aten.asin,
    aten.asinh,
    aten.atan2,
    aten.atan,
    aten.atanh,
    aten.copysign,
    aten.erfc,
    aten.erfinv,
    aten.hypot,
    aten.log10,
    aten.log2,
    aten.nextafter,
):
    register_pointwise_numeric(_numeric_op)
- from .codegen.common import BackendFeature, pointwise_overrides_data
def _get_pointwise_overrides(ns, name):
    """Yield ``(op, type_promotion_kind, triton_fallback)`` for one override entry.

    The entry is looked up in ``pointwise_overrides_data`` by ``name``.
    Nothing is yielded when namespace ``ns`` has no attribute with the
    entry's op name.  An ``OpOverloadPacket`` is expanded into one tuple per
    overload; any other op yields a single tuple.
    """
    entry = pointwise_overrides_data[name]
    target = getattr(ns, entry.name, None)
    if target is None:
        return

    def make_triton_fallback(fn):
        # A fallback handler is built only for entries without a Triton
        # implementation; all other entries get None.
        return fallback_handler(fn) if entry.triton is None else None

    if isinstance(target, torch._ops.OpOverloadPacket):
        candidates = (
            getattr(target, overload_name) for overload_name in target.overloads()
        )
    else:
        candidates = (target,)

    for candidate in candidates:
        yield candidate, entry.type_promotion_kind, make_triton_fallback(candidate)
# Register every pointwise override, first for the aten namespace and then
# for prims.
for name in pointwise_overrides_data:
    for _namespace in (aten, prims):
        for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
            _namespace, name
        ):
            register_pointwise(
                op,
                name=name,
                type_promotion_kind=type_promotion_kind,
                triton_fallback=triton_fallback,
            )
# Foreach lowerings whose result is bound to a name.
foreach_add_list = register_foreach_pointwise(
    aten._foreach_add.List, add, allow_alpha=True
)
foreach_add_scalar = register_foreach_pointwise(
    aten._foreach_add.Scalar, add, allow_alpha=True
)
register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)

foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
register_foreach_pointwise(aten._foreach_mul.Tensor, mul)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)

foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
register_foreach_pointwise(aten._foreach_div.Tensor, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)

# Foreach lowerings that are only registered, as (foreach op, pointwise
# lowering) pairs.
for _foreach_op, _pointwise_fn in (
    (aten._foreach_sub.List, sub),
    (aten._foreach_sub.Scalar, sub),
    (aten._foreach_neg.default, neg),
    (aten._foreach_abs.default, abs),
    (aten._foreach_pow.Scalar, pow),
    (aten._foreach_pow.List, pow),
    (aten._foreach_pow.ScalarAndTensor, pow),
    (aten._foreach_sqrt, sqrt),
    (aten._foreach_rsqrt, rsqrt),
    (aten._foreach_maximum.List, maximum),
    (aten._foreach_maximum.Scalar, maximum),
    (aten._foreach_minimum.List, minimum),
    (aten._foreach_minimum.Scalar, minimum),
    (aten._foreach_clamp_min.List, maximum),
    (aten._foreach_clamp_min.Scalar, maximum),
    (aten._foreach_clamp_max.List, minimum),
    (aten._foreach_clamp_max.Scalar, minimum),
    (aten._foreach_reciprocal, reciprocal),
    (aten._foreach_sign, sign),
):
    register_foreach_pointwise(_foreach_op, _pointwise_fn)

foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)
# These are only encountered as outputs of the graph.  Reinplacing epilogue
# copies improves compile time by removing extra buffers sent to the
# scheduler.
def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
    """Register the in-place foreach op ``aten_op`` on top of ``outplace_op``.

    Records the out-of-place to in-place mapping in
    ``inplaceable_foreach_ops`` and adds ``aten_op`` to
    ``inplace_foreach_ops``.  The registered lowering runs ``outplace_op`` and
    writes each result back into the matching element of the first argument
    with ``mutate_to(..., unsafe_alias=True)``.
    """
    inplaceable_foreach_ops[outplace_aten_op] = aten_op
    inplace_foreach_ops.add(aten_op)

    def fn(*args, **kwargs):
        results = outplace_op(*args, **kwargs)
        return [
            mutate_to(dest, value, unsafe_alias=True)
            for dest, value in zip(args[0], results)
        ]

    _register_foreach_lowering(aten_op, fn)
# (in-place foreach op, out-of-place foreach op, out-of-place lowering)
for _inplace_op, _outplace_aten_op, _outplace_lowering in (
    (aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list),
    (aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar),
    (aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list),
    (aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar),
    (aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list),
    (aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar),
    (aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy),
):
    register_foreach_inplace(_inplace_op, _outplace_aten_op, _outplace_lowering)
def register_inplace(aten_op, outplace_op):
    """Register an in-place lowering for ``aten_op`` built from ``outplace_op``.

    The generated lowering calls ``outplace_op`` on its arguments, casts the
    result to the dtype of the first argument and writes it back into that
    argument with ``mutate_to``.  The lowering function is returned.
    """

    @register_lowering(aten_op, type_promotion_kind=None)
    def fn(*args, **kwargs):
        computed = outplace_op(*args, **kwargs)
        target = args[0]
        return mutate_to(target, to_dtype(computed, target.get_dtype()))

    return fn
# (in-place aten op, out-of-place lowering it is built from)
for _inplace_op, _outplace_lowering in (
    (aten.add_, add),
    (aten.bitwise_and_, bitwise_and),
    (aten.bitwise_left_shift_, bitwise_left_shift),
    (aten.bitwise_not_, bitwise_not),
    (aten.bitwise_or_, bitwise_or),
    (aten.bitwise_right_shift_, bitwise_right_shift),
    (aten.bitwise_xor_, bitwise_xor),
    (aten.mul_, mul),
    (aten.div_.Tensor, div),
    (aten.div_.Tensor_mode, div_mode),
    (aten.logical_and_, logical_and),
    (aten.logical_not_, logical_not),
    (aten.logical_or_, logical_or),
    (aten.logical_xor_, logical_xor),
    (aten.sub_, sub),
    (aten.relu_, relu),
    (aten.sigmoid_, sigmoid),
):
    register_inplace(_inplace_op, _outplace_lowering)
# The Python-operator spellings of the bitwise ops share the bitwise lowerings.
register_lowering(aten.__and__)(bitwise_and)
register_lowering(aten.__lshift__)(bitwise_left_shift)
register_lowering(aten.__or__)(bitwise_or)
register_lowering(aten.__rshift__)(bitwise_right_shift)
register_lowering(aten.__xor__)(bitwise_xor)

# register_inplace calls its second argument directly on the lowering's own
# arguments, so it must be given the lowering function (as every other
# register_inplace call does), not the raw aten op.
register_inplace(aten.__iand__, bitwise_and)
register_inplace(aten.__ilshift__, bitwise_left_shift)
register_inplace(aten.__ior__, bitwise_or)
register_inplace(aten.__irshift__, bitwise_right_shift)
register_inplace(aten.__ixor__, bitwise_xor)
@register_lowering(aten.sym_constrain_range)
def sym_constrain_range(a, min=None, max=None):
    """Lower ``aten.sym_constrain_range`` to a no-op; all arguments are ignored."""
    return None
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
    """Lower ``aten.sym_size.int`` from the current FX node's ``meta["val"]``.

    A ``torch.SymInt`` value yields its underlying expression; anything else
    is converted with ``int``.  The arguments themselves are not inspected.
    """
    meta_val = V.graph.current_node.meta["val"]
    return meta_val.node.expr if isinstance(meta_val, torch.SymInt) else int(meta_val)
@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
    """Lower ``aten.sym_stride.int`` from the current FX node's ``meta["val"]``.

    A ``torch.SymInt`` value yields its underlying expression; anything else
    is converted with ``int``.  The arguments themselves are not inspected.
    """
    meta_val = V.graph.current_node.meta["val"]
    return meta_val.node.expr if isinstance(meta_val, torch.SymInt) else int(meta_val)
@register_lowering(aten.sym_numel)
def sym_numel(a):
    """Lower ``aten.sym_numel`` to ``a.get_numel()``."""
    numel = a.get_numel()
    return numel
# Register each magic method's implementation under the operator that
# method_to_operator maps it to.
for method, func in magic_methods.items():
    register_lowering(
        method_to_operator(method)
    )(func)  # type: ignore[arg-type]
@register_lowering(torch.sym_sum)
def sym_sum(args):
    """Lower ``torch.sym_sum`` to a single ``sympy.Add`` over ``args``."""
    total = sympy.Add(*args)
    return total
@register_lowering(aten._foobar)
def foobar(self, *args, **kwargs):
    """Lowering for ``aten._foobar`` that always raises ``NotImplementedError``."""
    message = "Helpful for debugging"
    raise NotImplementedError(message)
@register_lowering(torch.ops._inductor_test.realize)
def _realize(x):
    """Lower ``_inductor_test.realize``: realize ``x`` and return a clone of it."""
    x.realize()
    realized_copy = clone(x)
    return realized_copy
@register_lowering(torch.ops.inductor.resize_storage_bytes_)
def resize_storage_bytes_(variable, new_size):
    """Lower ``inductor.resize_storage_bytes_``.

    Realizes ``variable``, constructs an ``ir.ResizeStorageBytes`` node for it
    (the node object itself is not kept here) and returns ``variable``.
    """
    variable.realize()
    # Only construction matters; the instance is intentionally discarded.
    _ = ir.ResizeStorageBytes(variable, new_size)
    return variable
@register_lowering(torch.ops.aten.set_.source_Tensor)
def set__source_tensor(self, source_tensor):
    """Lower ``aten.set_.source_Tensor`` to an ``ir.SetSourceTensorKernel``.

    Both tensors are realized before the kernel node is created; the result
    is wrapped in a ``TensorBox``.
    """
    for operand in (self, source_tensor):
        operand.realize()
    kernel = ir.SetSourceTensorKernel(self, source_tensor)
    return TensorBox.create(kernel)
if hasattr(torch.ops.fsdp, "copy_"):

    @register_lowering(torch.ops.fsdp.copy_.default)
    def fsdp_copy_(dst, src):
        """Lower ``fsdp.copy_``: copy ``src`` into ``dst`` and return the mutated ``dst``.

        ``src`` is moved to ``dst``'s device, cast to its dtype and expanded
        to its size before the write.
        """
        if dst is src:
            # dst.copy_(dst) can happen from the reinplacing pass
            return dst
        on_device = to_device(src, dst.get_device())
        converted = to_dtype(on_device, dst.get_dtype())
        broadcast = expand(converted, dst.get_size())
        return mutate_to(dst, broadcast)
@register_lowering(torch.ops.aten.resize)
def resize(x, size, *, memory_format=None):
    """Lower ``aten.resize`` to a pointwise copy into a tensor of shape ``size``.

    Output elements whose flat offset (under the strides chosen for
    ``memory_format``) is below the input's element count are read from the
    flattened input; the rest are filled with a placeholder value.  The
    placeholder is 0.0, or, when deterministic algorithms with
    ``fill_uninitialized_memory`` are enabled, NaN for float dtypes, the dtype
    maximum for integer dtypes and ``True`` otherwise.

    Raises:
        RuntimeError: if ``memory_format`` is ``torch.preserve_format``.
    """
    assert isinstance(x, TensorBox)
    assert isinstance(size, (list, tuple))

    if memory_format is None:
        memory_format = torch.contiguous_format
    if memory_format == torch.preserve_format:
        raise RuntimeError(f"unsupported memory format: {memory_format}")
    # The channels-last formats are tied to a fixed rank.
    if memory_format == torch.channels_last:
        assert len(size) == 4
    if memory_format == torch.channels_last_3d:
        assert len(size) == 5

    old_numel = x.get_numel()
    dtype = x.get_dtype()
    device = x.get_device_or_error()

    if isinstance(x.data, ir.BaseView):
        x.data = x.data.unwrap_view()

    deterministic_fill = (
        torch.are_deterministic_algorithms_enabled()
        and torch.utils.deterministic.fill_uninitialized_memory  # type: ignore[attr-defined]
    )
    if not deterministic_fill:
        # using zero as that is what empty does
        uninitialized_val = 0.0
    elif is_float_dtype(dtype):
        uninitialized_val = float("nan")
    elif is_integer_dtype(dtype):
        uninitialized_val = torch.iinfo(dtype).max
    else:
        uninitialized_val = True

    if V.graph.sizevars.statically_known_equals(old_numel, 0):  # type: ignore[arg-type]
        # Nothing to copy from: the whole output is the placeholder value.
        return full(size, uninitialized_val, dtype=dtype, device=device)

    flat_input = as_strided(x, [old_numel], [1])
    flat_loader = flat_input.make_loader()

    out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format)
    out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer()

    def inner_fn(idx):
        flat_index = out_indexer(idx)
        flat_index_expr = ops.index_expr(flat_index, torch.int64)
        limit = ops.index_expr(old_numel, torch.int64)
        in_bounds = ops.lt(flat_index_expr, limit)
        # Load from the input only where the offset is in range.
        return ops.masked(
            in_bounds, lambda: flat_loader([flat_index]), uninitialized_val
        )

    return Pointwise.create(
        device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
    )
- from torch._higher_order_ops.auto_functionalize import auto_functionalized
# auto_functionalized has no dedicated lowering; it is registered as a fallback.
make_fallback(
    auto_functionalized,
)
@register_lowering(triton_kernel_wrapper_mutation)
def triton_kernel_wrap_(
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata,
    kwargs,
):
    """Lower ``triton_kernel_wrapper_mutation`` to an ``ir.UserDefinedTritonKernel``.

    The kernel receives ``kwargs`` merged with the constant arguments looked
    up in the kernel side table (constant arguments win on key clashes).
    Returns the subset of ``kwargs`` whose values are ``TensorBox`` instances.
    """
    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

    kernel_args = dict(kwargs)
    kernel_args.update(kernel_side_table.get_constant_args(constant_args_idx))
    ir.UserDefinedTritonKernel(
        kernel_idx=kernel_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kernel_args=kernel_args,
    )
    tensor_kwargs = {}
    for key, val in kwargs.items():
        if isinstance(val, TensorBox):
            tensor_kwargs[key] = val
    return tensor_kwargs
@register_lowering(torch.ops.higher_order.cond, type_promotion_kind=None)
def cond(
    pred, true_fn, false_fn, operands
) -> list[Union[ir.TensorBox, ir.ShapeAsConstantBuffer]]:
    """Lower ``torch.cond`` to ``ir.Conditional`` and box each result.

    If the predicate or any operand is a Triton IR node, a reason for
    disabling cudagraphs is recorded on the graph first.
    """
    # TODO: when graph_partition is enabled, skip - partitioning handles control flow
    # we run into memory cleanup issue
    touches_triton = any(
        isinstance(x, IRNode) and is_triton(x) for x in [pred, *operands]
    )
    if touches_triton:
        msg = "control flow operator: torch.cond."
        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
            msg = f"{msg} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = msg

    result = ir.Conditional.create(pred, true_fn, false_fn, operands)
    return [TensorBox.create(item) for item in result]  # pyrefly: ignore no-matching-overload
@register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False):
    """Lower ``torch.while_loop`` to ``ir.WhileLoop`` and box each result.

    Unless ``config.graph_partition`` is set, a reason for disabling
    cudagraphs is recorded when any carried or additional input is a Triton
    IR node.
    """
    # TODO: when graph_partition is enabled, skip - partitioning handles control flow
    # we run into memory cleanup issue
    if not config.graph_partition and any(
        isinstance(x, IRNode) and is_triton(x)
        for x in carried_inputs + additional_inputs
    ):
        msg = "control flow operator: torch.while_loop."
        if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
            msg = f"{msg} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = msg

    loop_outputs = ir.WhileLoop.create(
        cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
    )
    assert isinstance(loop_outputs, Sequence)
    return [ir.WhileLoop._maybe_wrap_as_tensor_box(item) for item in loop_outputs]
# while_loop_stack_output reuses the while_loop lowering with stack_output=True.
register_lowering(
    torch.ops.higher_order.while_loop_stack_output,
    type_promotion_kind=None,
)(
    functools.partial(while_loop, stack_output=True)
)
@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
    """Lower ``invoke_subgraph`` to ``ir.InvokeSubgraph`` and box each result.

    ``identifier`` is accepted but not used by this lowering.
    """
    outputs = ir.InvokeSubgraph.create(subgraph_fn, *operands)
    return [TensorBox.create(item) for item in outputs]  # type: ignore[call-overload]
def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]):
    """Run the nodes of an FX graph through ``V.graph`` and return its output.

    Node handling:
    - placeholder: bound in ``V.graph.env`` to ``args[i]``, where ``i`` is the
      node's position in the graph
    - output: its arguments are fetched from the environment and passed to
      ``torch.fx.Interpreter.output``
    - anything else: executed with ``V.graph.run_node``

    Raises:
        RuntimeError: if the graph contains no output node.
    """
    env = V.graph.env
    output = None

    for position, node in enumerate(graph_module.graph.nodes):
        if node.op == "placeholder":
            assert node not in env
            env[node] = args[position]
            continue

        if node.op == "output":
            output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
            output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs)
            continue

        assert node not in env
        # current_node is tracked for error diagnostics and restored after
        # run_node so that nested calls see the right node.
        previous_node = V.graph.current_node
        try:
            V.graph.current_node = node
            env[node] = V.graph.run_node(node)
        finally:
            V.graph.current_node = previous_node

    if output is None:
        raise RuntimeError("No output node found in graph")
    return output
- # Import the control_deps_op HOP for lowering
- from torch._inductor.fx_passes.control_dependencies import control_deps
@register_lowering(control_deps, type_promotion_kind=None)
def control_deps_op_lowering(additional_deps, subgraph_fn, *args):
    """
    Lower ``control_deps`` by realizing its dependencies and recording them.

    Steps:
    1. Realize every IR-node entry of ``additional_deps`` and collect its name
    2. Lower the wrapped subgraph via ``process_subgraph_nodes``
    3. Add each collected name to ``V.graph.additional_buffer_deps`` for every
       operation created in step 2
    """
    # Step 1: realize the dependencies; non-IR entries are skipped.
    dep_names = []
    for dep in additional_deps:
        if isinstance(dep, IRNode):
            dep.realize()
            dep_names.append(dep.get_name())

    # The FX node's first two args are (additional_deps, subgraph).
    arg_offset = 2
    original_args = V.graph.current_node.args
    assert len(args) + arg_offset == len(original_args)

    ops_before = len(V.graph.operations)
    placeholders = subgraph_fn.graph_module.graph.find_nodes(op="placeholder")
    assert len(placeholders) == len(args)

    # Step 2: lower the subgraph using the shared helper.
    output = process_subgraph_nodes(subgraph_fn.graph_module, list(args))
    assert output is not None and additional_deps

    # Step 3: attach the deps to the operations themselves.  Some operators,
    # like wait_tensor, just return their input, so attaching to the output
    # could create a cycle:
    #   a = coll
    #   b = control_deps(a, mm, ...)
    #   c = control_deps(b, wait, ...)
    # If c == a, that is a cycle.
    for new_op, dep_name in itertools.product(
        V.graph.operations[ops_before:], dep_names
    ):
        op_name = new_op.operation_name
        assert op_name is not None
        V.graph.additional_buffer_deps[op_name].add(dep_name)

    return output
@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
    """Lower ``invoke_quant`` by running its subgraph node by node.

    Placeholders are bound to ``operands`` by node position.  Every value
    feeding the output node is realized and its operation name is added to
    ``V.graph.invoke_quant_ops`` (and to ``low_precision_codegen_ops`` when the
    node's ``quant_options`` request low-precision codegen).  Returns the
    subgraph output, or ``None`` if the graph has no output node.
    """
    quant_options = V.graph.current_node.meta.get("quant_options", None)
    assert quant_options is not None

    output = None
    for position, node in enumerate(subgraph_fn.graph_module.graph.nodes):
        if node.op == "placeholder":
            V.graph.env[node] = operands[position]
            continue

        # todo getattr
        if node.op != "output":
            V.graph.env[node] = V.graph.run_node(node)
            continue

        args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
        for value in itertools.chain(args, kwargs.values()):
            value.realize()
            if quant_options.codegen_low_precision:
                V.graph.low_precision_codegen_ops.add(value.get_operation_name())
            V.graph.invoke_quant_ops.add(value.get_operation_name())
        output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)

    return output
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(
    combine_fn: ir.Subgraph, xs, additional_inputs: tuple[torch.Tensor]
):
    """Lower ``associative_scan`` to an ``ir.Scan`` over axis 0 of ``xs``.

    ``combine_fn`` is lowered as a pointwise subgraph taking the leaves of the
    left operand followed by the leaves of the right operand.

    Raises:
        RuntimeError: if ``additional_inputs`` is non-empty, or if
            ``ir.Scan.create`` yields ``None`` as its first result.
    """
    from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

    if len(additional_inputs) > 0:
        raise RuntimeError(
            "Unable to generate code for associative_scan op, because there are lifted arguments"
        )

    # One descriptor per input for the left operand, then again for the right.
    descriptors = []
    for x in itertools.chain(xs, xs):
        descriptors.append(InputDescriptor(dtype=x.get_dtype(), device=x.get_device()))
    lowered_combine_fn = lower_pointwise_subgraph(combine_fn, descriptors)  # type: ignore[var-annotated]

    def wrapped_combine_fn(lhs, rhs):
        flat_lhs = pytree.tree_leaves(lhs)
        flat_rhs = pytree.tree_leaves(rhs)
        return lowered_combine_fn(*flat_lhs, *flat_rhs)

    scan_kwargs = _make_scan_inner(xs[0], axis=0, dtype=None)
    scan_kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
    scan_kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)
    scanned = ir.Scan.create(
        combine_fn=wrapped_combine_fn,
        can_fallback_to_aten=False,
        **scan_kwargs,
    )
    if scanned[0] is None:
        raise RuntimeError("Unable to generate code for associative_scan op")
    return scanned
@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
    """Lower ``prims._sink_tokens`` to a no-op; ``tokens`` is ignored."""
    return None
@register_lowering(torch.ops.prims._make_token.default)
def _make_token():
    """Lower ``prims._make_token`` to ``None``."""
    return None
@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None)
def with_effects(token, op, *args, **kwargs):
    """
    Lower an effectful operator and chain it behind earlier ops of the same effect.

    ``op`` is lowered directly (through its registered lowering, or as a
    ``FallbackKernel`` when it has none).  Every operation created while doing
    so is marked as having side effects and given a star dependency on the
    previously recorded operation with the same effect type.  The return value
    is ``token`` followed by the op's result(s).
    """
    from torch._higher_order_ops.effects import _get_effect, _get_schema

    # Determine the effect type; invoke_subgraph may carry it in the
    # tracing context's cache instead.
    effect_type = _get_effect(op)
    if effect_type is None and op is torch.ops.higher_order.invoke_subgraph:
        from torch._guards import InvokeSubgraphCache, TracingContext

        tracing_ctx = TracingContext.try_get()
        if tracing_ctx:
            subgraph_cache = tracing_ctx.hop_dispatch_set_cache.get_cache(
                torch.ops.higher_order.invoke_subgraph
            )
            if subgraph_cache:
                assert isinstance(subgraph_cache, InvokeSubgraphCache)
                # args[1] is identifier
                effects = subgraph_cache.get_effects(args[1])
                if effects:
                    assert len(effects) == 1, "Multiple effects NYI"
                    effect_type = next(iter(effects))

    # Remember how many operations existed before lowering.
    ops_before = len(V.graph.operations)

    if op in lowerings:
        result = lowerings[op](*args, **kwargs)
        # Realize so that the ops show up in V.graph.operations.
        pytree.tree_map_only(TensorBox, lambda box: box.realize(), result)
    else:

        def wrap_tensors(x):
            return TensorBox.create(x) if isinstance(x, ir.IRNode) else x

        result = pytree.tree_map(
            wrap_tensors, ir.FallbackKernel.create(op, *args, **kwargs)
        )

    # Everything appended to V.graph.operations since ops_before came from
    # the lowering above; those are the operations that get StarDeps.
    assert len(V.graph.operations[ops_before:]) > 0, (
        f"No operation nodes were generated when lowering effectful operator {op}."
    )
    if effect_type:
        previous_effect_op = V.graph.effectful_ops.get(effect_type)

        for new_op in V.graph.operations[ops_before:]:
            # Patch has_side_effects to return True
            new_op.has_side_effects = lambda: True  # pyrefly: ignore[missing-attribute]
            if previous_effect_op:
                V.graph.additional_star_deps[
                    new_op.get_name()  # pyrefly: ignore[missing-attribute]
                ].add(previous_effect_op.get_name())
        # The chain for this effect now ends at the latest operation.
        V.graph.effectful_ops[effect_type] = (
            new_op  # pyrefly: ignore[unsupported-operation]
        )

    def convert_ir_to_value(a):
        if isinstance(a, ir.TorchBindObject):
            return a.get_value()
        if isinstance(a, TensorBox):
            # TensorBox wraps StorageBox, which wraps the actual buffer; the
            # example tensor is taken from that inner buffer when available.
            try:
                storage = a.data
                if hasattr(storage, "data") and hasattr(storage.data, "get_example"):
                    return storage.data.get_example()
            except (AttributeError, NotImplementedError):
                pass
            # No example available: keep the TensorBox itself.
        return a

    try:
        schema_args, schema_kwargs = pytree.tree_map(
            convert_ir_to_value, (args, kwargs)
        )
        schema = _get_schema(op, schema_args, schema_kwargs)
    except RuntimeError as e:
        error_msg = str(e)
        log.warning(
            "Failed to get schema for %s: %s. Assuming list output", op, error_msg
        )
        if isinstance(result, (tuple, list)):
            return (token, *result)
        return (token, result)

    # Zero or one declared return: the result is passed through as a single
    # item.  Several returns: the result is spliced after the token.
    if len(schema.returns) in (0, 1):
        return (token, result)
    return (token, *result)
- from .comm_lowering import register_comm_lowerings, register_symm_mem_lowerings
# Register the lowerings provided by the comm_lowering module.
for _register_comm in (register_comm_lowerings, register_symm_mem_lowerings):
    _register_comm()
@register_lowering(inductor_prims.prepare_softmax_online, type_promotion_kind=None)
def prepare_softmax_online(x, dim):
    """
    Lower ``inductor_prims.prepare_softmax_online`` to ``(max, sum)`` over ``dim``.

    When the reduction needs no split and is statically known to be at least
    ``config.unroll_reductions_threshold`` elements long, both outputs come
    from a single ``OnlineSoftmaxReduction``.  Otherwise a warning is emitted
    and they are computed separately as ``amax`` and ``sum(exp(x - amax))``.
    """
    kwargs = _make_reduction_inner(
        x, axis=dim, keepdims=True, dtype=None, override_return_dtype=None
    )

    rnumel = V.graph.sizevars.simplify(sympy_product(kwargs["reduction_ranges"]))
    hint, num_split = ir.Reduction.num_splits(
        **kwargs,
        reduction_type="online_softmax_reduce",  # type: ignore[arg-type]
        reduction_numel=rnumel,
    )

    single_pass = num_split == 1 and V.graph.sizevars.statically_known_geq(
        rnumel, config.unroll_reductions_threshold
    )
    if single_pass:
        max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
            input_node=x, num_output=2, reduction_hint=hint, **kwargs
        )
        return max_tensor, sum_tensor

    # Note: [Split online_softmax_reduce]
    # Split reductions are not supported for online_softmax_reduce for now.
    # On one hand, supporting them makes things complex since the split-out
    # reductions require 2 inputs rather than one.
    # On the other hand, during training online_softmax_reduce usually
    # doesn't require a split due to the large batch size
    # (more specifically batch size times sequence length).
    # Split reduction should be supported once there are legitimate use cases
    # to motivate the work.
    #
    # TODO: does inference need split online_softmax_reduce?
    warnings.warn(
        textwrap.dedent(
            """
            Online softmax is disabled on the fly since Inductor decides to
            split the reduction. Cut an issue to PyTorch if this is an
            important use case and you want to speed it up with online
            softmax.
            """
        )
    )
    amax = reduce_amax(x, dim, keepdims=True)
    shifted_exp = lowerings[aten.exp](sub(x, amax))
    xsum = sum_(shifted_exp, dim, keepdims=True)
    return amax, xsum
- def _is_sm100_or_later():
- """Check if we're on SM100+ hardware (Blackwell)."""
- return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (10, 0)
@register_lowering(inductor_prims.cvt_e8m0_rceil, type_promotion_kind=None)
def cvt_e8m0_rceil_lowering(inp):
    """
    Lowering for cvt_e8m0_rceil, built on PTX cvt.rp.satfinite.ue8m0x2.f32 (SM100+).

    The PTX instruction consumes 2 float32 values and produces 2 e8m0 values
    packed in a uint16. For now 0.0 is passed as the second input and only the
    low byte of the result is used.

    Raises:
        NotImplementedError: if ``_is_sm100_or_later()`` is false.
        ValueError: if the input dtype is not float32, float16 or bfloat16.
    """
    # TODO: Optimize to process pairs (pack=2) by creating a custom Pointwise
    # that loads adjacent elements, applies PTX to both, and uses a follow-up
    # kernel to extract the packed uint16 results as uint8.
    if not _is_sm100_or_later():
        raise NotImplementedError(
            "cvt_e8m0_rceil requires SM100+ (Blackwell) for PTX instruction support"
        )

    src_dtype = inp.get_dtype()
    accepted_dtypes = (torch.float32, torch.float16, torch.bfloat16)
    if src_dtype not in accepted_dtypes:
        raise ValueError(
            f"cvt_e8m0_rceil requires float32, float16, or bfloat16 input, got {src_dtype}"
        )

    # The PTX instruction takes float32 operands, so bf16/fp16 is upcast first.
    fp32_inp = inp if src_dtype == torch.float32 else to_dtype(inp, torch.float32)

    emit_cvt = functools.partial(
        ops.inline_asm_elementwise,
        asm="cvt.rp.satfinite.ue8m0x2.f32 $0, 0.0, $1;",
        constraints="=h,r",
        dtype=torch.uint16,
        is_pure=True,
        pack=1,
    )
    packed = make_pointwise(emit_cvt)(fp32_inp)
    # Convert the uint16 result of the instruction down to uint8.
    return to_dtype(packed, torch.uint8)
- # populate lowerings defined in kernel/*
# import_submodule is handed the kernel subpackage imported on the line below.
- from . import kernel
- import_submodule(kernel)
# Each module below is imported at this point and its register_* function(s)
# are called immediately, i.e. as a side effect of importing this module.
# NOTE(review): the relative order of these calls may matter -- confirm before
# reordering.
- from . import quantized_lowerings
- quantized_lowerings.register_quantized_ops()
- quantized_lowerings.register_woq_mm_ops()
- from . import mkldnn_lowerings
- mkldnn_lowerings.register_onednn_fusion_ops()
- from . import jagged_lowerings
- jagged_lowerings.register_jagged_ops()
@contextlib.contextmanager
def force_fallback(op: torch._ops.OpOverload):
    """
    A context manager to force fallback an op. Used in unit test
    for FallbackKernel.

    While the context is active, ``lowerings[op]`` is the handler produced by
    ``fallback_handler(op)``. On exit the previous entry is restored, or the
    entry is removed if there was none before.

    Args:
        op: the operator overload to force onto the fallback path. Must be a
            ``torch._ops.OpOverload``.
    """
    assert isinstance(op, torch._ops.OpOverload), (
        "Only OpOverload to make the clean up easier"
    )
    old_handler = lowerings.get(op)
    try:
        register_lowering(op)(fallback_handler(op))
        yield
    finally:
        if old_handler is not None:
            lowerings[op] = old_handler
        else:
            # Tolerate a missing entry: if registration itself raised, nothing
            # was installed, and a KeyError here would replace the real error.
            lowerings.pop(op, None)
|