_meta_registrations.py 277 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799
  1. # mypy: allow-untyped-defs
  2. import math
  3. from collections.abc import Callable, Sequence
  4. from enum import Enum
  5. from functools import wraps
  6. from typing import TypeVar
  7. from typing_extensions import ParamSpec
  8. import torch
  9. import torch._prims_common as utils
  10. from torch import SymBool, SymFloat, Tensor
  11. from torch._decomp import (
  12. _add_op_to_registry,
  13. _convert_out_params,
  14. global_decomposition_table,
  15. meta_table,
  16. )
  17. from torch._ops import OpOverload
  18. from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
  19. from torch._prims_common import (
  20. BoolLike,
  21. corresponding_complex_dtype,
  22. corresponding_real_dtype,
  23. elementwise_dtypes,
  24. ELEMENTWISE_TYPE_PROMOTION_KIND,
  25. FloatLike,
  26. IntLike,
  27. make_contiguous_strides_for,
  28. Number,
  29. NumberType,
  30. suggest_memory_format,
  31. sym_min,
  32. TensorLike,
  33. )
  34. from torch._prims_common.wrappers import (
  35. _maybe_convert_to_dtype,
  36. _maybe_resize_out,
  37. _resize_output_check,
  38. _safe_copy_out,
  39. out_wrapper,
  40. )
  41. from torch._refs import _broadcast_shapes, _maybe_broadcast
  42. from torch.fx.experimental import _config as exp_config
  43. from torch.nn.functional import ScalingType, SwizzleType
  44. from torch.utils import _pytree as pytree
# Generic type variables used to annotate the `register_meta` decorator.
_T = TypeVar("_T")
_P = ParamSpec("_P")

# Shorthand for the ATen operator namespace.
aten = torch.ops.aten

# Library handle for implementations of aten ops under the "Meta" key. As the
# name says, do not use it directly; register through `register_meta` instead.
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")

# Integer codes 0, 1 and 2 for the sum / mean / max modes.
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
  50. def ceil_div(a, b):
  51. return (a + b - 1) // b
  52. def round_up(x, y):
  53. """Rounds up x to nearest multiple of y"""
  54. return ((x + y - 1) // y) * y
  55. def register_meta(op) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  56. def wrapper(fn):
  57. fn = _convert_out_params(fn)
  58. def register(op):
  59. _add_op_to_registry(meta_table, op, fn)
  60. pytree.tree_map_(register, op)
  61. return fn
  62. return wrapper
  63. def elementwise_meta(
  64. *args,
  65. type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
  66. ):
  67. # Perform type promotion, as this is expected from prim_metafunction
  68. _, result_dtype = utils.elementwise_dtypes(
  69. *args,
  70. type_promotion_kind=type_promotion,
  71. )
  72. args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]
  73. # Broadcast
  74. args = _maybe_broadcast(*args)
  75. # Perform prim checks
  76. return _prim_elementwise_meta(
  77. *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
  78. )
  79. def toRealValueType(dtype):
  80. from_complex = {
  81. torch.complex32: torch.half,
  82. torch.cfloat: torch.float,
  83. torch.cdouble: torch.double,
  84. }
  85. return from_complex.get(dtype, dtype)
  86. def check_inplace_broadcast(self_shape, *args_shape):
  87. broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
  88. torch._check(
  89. broadcasted_shape == self_shape,
  90. lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
  91. )
  92. @register_meta([aten.linspace, aten.logspace])
  93. @out_wrapper()
  94. def meta_linspace_logspace(
  95. start,
  96. end,
  97. steps,
  98. base=None,
  99. dtype=None,
  100. device=None,
  101. layout=torch.strided,
  102. pin_memory=False,
  103. requires_grad=False,
  104. ):
  105. if isinstance(start, torch.Tensor):
  106. torch._check(
  107. start.dim() == 0,
  108. lambda: "linspace only supports 0-dimensional start and end tensors",
  109. )
  110. if isinstance(end, torch.Tensor):
  111. torch._check(
  112. end.dim() == 0,
  113. lambda: "linspace only supports 0-dimensional start and end tensors",
  114. )
  115. if any(isinstance(arg, complex) for arg in (start, end, steps)):
  116. default_complex_dtype = utils.corresponding_complex_dtype(
  117. torch.get_default_dtype()
  118. )
  119. if dtype is None:
  120. dtype = default_complex_dtype
  121. else:
  122. torch._check(
  123. utils.is_complex_dtype(dtype),
  124. lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
  125. )
  126. else:
  127. dtype = dtype or torch.get_default_dtype()
  128. if not isinstance(dtype, torch.dtype):
  129. raise AssertionError(f"dtype must be torch.dtype, got {type(dtype)}")
  130. # steps does not participate in the computation of the dtype
  131. torch._check_type(
  132. isinstance(steps, IntLike),
  133. lambda: f"received an invalid combination of arguments - got \
  134. ({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
  135. )
  136. if not isinstance(steps, IntLike):
  137. raise AssertionError(f"steps must be IntLike, got {type(steps)}")
  138. torch._check(steps >= 0, lambda: "number of steps must be non-negative")
  139. return torch.empty(
  140. (steps,), # type: ignore[arg-type]
  141. dtype=dtype,
  142. layout=layout,
  143. device="meta",
  144. pin_memory=pin_memory,
  145. requires_grad=requires_grad,
  146. )
  147. @register_meta([aten.take.default, aten.take.out])
  148. @out_wrapper()
  149. def meta_take(self, index):
  150. # Type and device checks
  151. torch._check(
  152. index.dtype == torch.long,
  153. lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
  154. )
  155. # Index checks
  156. torch._check_index(
  157. not (self.numel() == 0 and index.numel() != 0),
  158. lambda: "take(): tried to take from an empty tensor",
  159. )
  160. return self.new_empty(index.shape)
  161. @register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
  162. @out_wrapper()
  163. def linalg_cross(self, other, *, dim=-1):
  164. x_d = self.ndim
  165. y_d = other.ndim
  166. torch._check(
  167. x_d == y_d,
  168. lambda: "linalg.cross: inputs must have the same number of dimensions.",
  169. )
  170. torch._check(
  171. self.size(dim) == 3 and other.size(dim) == 3,
  172. lambda: (
  173. f"linalg.cross: inputs dimension {dim} must have length 3. "
  174. f"Got {self.size(dim)} and {other.size(dim)}"
  175. ),
  176. )
  177. out_shape = _broadcast_shapes(self.shape, other.shape)
  178. return self.new_empty(out_shape)
  179. @register_meta(aten.linalg_matrix_exp)
  180. @out_wrapper()
  181. def linalg_matrix_exp(self):
  182. squareCheckInputs(self, "linalg.matrix_exp")
  183. checkFloatingOrComplex(self, "linalg.matrix_exp")
  184. return torch.empty_like(self, memory_format=torch.contiguous_format)
  185. @register_meta(
  186. [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
  187. )
  188. @out_wrapper("values", "indices")
  189. def cummaxmin(self, dim):
  190. values = torch.empty(self.shape, device=self.device, dtype=self.dtype)
  191. indices = torch.empty(self.shape, device=self.device, dtype=torch.int64)
  192. if self.numel() != 0 and self.ndim != 0:
  193. # Checks that dim is within bounds
  194. maybe_wrap_dim(dim, self.ndim)
  195. return values, indices
  196. @register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
  197. @out_wrapper()
  198. def logcumsumexp(self, dim):
  199. # Checks that dim is within bounds
  200. maybe_wrap_dim(dim, self.ndim)
  201. return torch.empty_like(self, memory_format=torch.contiguous_format)
  202. # Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
  203. # and aten/src/ATen/cuda/SpectralOps.cpp
  204. #
  205. # Although the actual FFT launch is different, all the permuting code appears
  206. # to be the same
  207. def _exec_fft(out, self, out_sizes, dim, *, forward):
  208. ndim = self.ndim
  209. signal_ndim = len(dim)
  210. batch_dims = ndim - signal_ndim
  211. # Permute dimensions so batch dimensions come first, and in stride order
  212. dim_permute = list(range(ndim))
  213. is_transformed_dim = [False for _ in range(ndim)]
  214. for d in dim:
  215. is_transformed_dim[d] = True
  216. # std::partition
  217. left, right = [], []
  218. for d in dim_permute:
  219. if not is_transformed_dim[d]:
  220. left.append(d)
  221. else:
  222. right.append(d)
  223. dim_permute = left + right
  224. batch_end = len(left)
  225. self_strides = self.stride()
  226. tmp = dim_permute[:batch_end]
  227. tmp.sort(key=lambda x: self_strides[x], reverse=True)
  228. dim_permute = tmp + dim_permute[batch_end:]
  229. input = self.permute(dim_permute)
  230. # Collapse batch dimensions into a single dimension
  231. batched_sizes = [-1] + list(input.shape[batch_dims:])
  232. input = input.reshape(batched_sizes)
  233. batch_size = input.size(0)
  234. batched_sizes[0] = batch_size
  235. batched_out_sizes = list(batched_sizes)
  236. for i in range(len(dim)):
  237. batched_out_sizes[i + 1] = out_sizes[dim[i]]
  238. out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)
  239. # Inplace reshaping to original batch shape and inverting the dimension permutation
  240. out_strides = [0 for _ in range(ndim)]
  241. batch_numel = 1
  242. i = batch_dims - 1
  243. while i >= 0:
  244. out_strides[dim_permute[i]] = batch_numel * out.stride(0)
  245. batch_numel *= out_sizes[dim_permute[i]]
  246. i -= 1
  247. for i in range(batch_dims, ndim):
  248. out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
  249. out.as_strided_(out_sizes, out_strides, out.storage_offset())
  250. return out
  251. def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
  252. sorted_dims = list(dim)
  253. self_strides = self.stride()
  254. sorted_dims[: len(sorted_dims) - int(exclude_last)].sort(
  255. key=lambda i: self_strides[i]
  256. )
  257. return sorted_dims
  258. # See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
  259. # and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
  260. @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
  261. @out_wrapper()
  262. def meta_fft_c2c(self, dim, normalization, forward):
  263. torch._check(self.dtype.is_complex)
  264. if not dim:
  265. return self.clone()
  266. sorted_dims = _sort_dims(self, dim)
  267. out = self.new_empty(self.size())
  268. return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)
  269. cufft_max_ndim = 3
  270. def use_optimized_cufft_path(dim: list[int]):
  271. if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1):
  272. return False
  273. else:
  274. return True
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for the real-to-complex FFT.

    Requires a floating-point input.  The output has the corresponding complex
    dtype and the input's sizes, except that with ``onesided`` the last
    transformed dim is shrunk to ``n // 2 + 1``.  On "cuda" / "xpu" devices the
    output is additionally restrided through ``_exec_fft``; on other devices a
    plain empty tensor is returned.  ``normalization`` is not used here.
    """
    torch._check(self.dtype.is_floating_point)
    input_sizes = list(self.size())
    out_sizes = list(input_sizes)
    last_dim = dim[-1]
    last_dim_halfsize = input_sizes[last_dim] // 2 + 1
    # Sizes of the one-sided result, used for intermediates even when the
    # final output is two-sided.
    onesided_sizes = list(input_sizes)
    onesided_sizes[last_dim] = last_dim_halfsize

    if onesided:
        out_sizes[last_dim] = last_dim_halfsize

    if device_hint(self) == "cuda" or device_hint(self) == "xpu":
        # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
        # _fft_r2c_xpu in torch-xpu-ops/src/ATen/native/xpu/SpectralOps.cpp
        output = self.new_empty(
            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
        )

        working_tensor = self
        if device_hint(self) == "cuda" and use_optimized_cufft_path(dim):
            _exec_fft(output, working_tensor, out_sizes, dim, forward=True)
        else:
            # First do the R2C transform on the last dimension
            target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
            _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
            if len(dim) > 1:
                # Second buffer; it is swapped with `output` on each pass below.
                working_tensor = self.new_empty(
                    out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
                )

            # Then any remaining C2C transforms
            sorted_dims = dim[:-1]
            while sorted_dims:
                # Swap roles: the previous output becomes this pass's input.
                output, working_tensor = working_tensor, output
                strides = working_tensor.stride()
                sorted_dims.sort(
                    key=lambda i: strides[i], reverse=True
                ) # NB reverse! Not sure if this is og bug
                # Transform at most cufft_max_ndim dims per pass, taken from
                # the tail of the sorted list.
                max_dims = min(cufft_max_ndim, len(sorted_dims))
                last_dims = sorted_dims[len(sorted_dims) - max_dims :]
                _exec_fft(
                    output, working_tensor, onesided_sizes, last_dims, forward=True
                )
                sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]

        if not onesided:
            # The passes above may have left a one-sided `output`; if so,
            # resize the other buffer to the full size and return that.
            if output.size(last_dim) != out_sizes[last_dim]:
                working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
                output = working_tensor

        return output

    else:
        return self.new_empty(
            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
        )
  327. @register_meta(aten.randperm.generator_out)
  328. def meta_randperm(n, *, generator=None, out):
  329. return _maybe_resize_out(out, torch.Size([n]))
  330. @register_meta(aten.randperm.default)
  331. def meta_randperm_default(
  332. n,
  333. *,
  334. dtype=torch.long,
  335. layout=None,
  336. device=None,
  337. pin_memory=None,
  338. ):
  339. return torch.empty(
  340. n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  341. )
  342. @register_meta([aten.randint.default, aten.randint.out])
  343. @out_wrapper()
  344. def meta_randint(
  345. high,
  346. size,
  347. *,
  348. dtype=torch.long,
  349. layout=None,
  350. device=None,
  351. pin_memory=None,
  352. ):
  353. low = 0
  354. torch._check(
  355. high > low,
  356. lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
  357. )
  358. return torch.empty(
  359. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  360. )
  361. @register_meta([aten.randint.low, aten.randint.low_out])
  362. @out_wrapper()
  363. def meta_randint_low(
  364. low,
  365. high,
  366. size,
  367. *,
  368. dtype=torch.long,
  369. layout=None,
  370. device=None,
  371. pin_memory=None,
  372. ):
  373. torch._check(
  374. high > low,
  375. lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
  376. )
  377. return torch.empty(
  378. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  379. )
  380. @register_meta([aten.rand.default, aten.rand.out])
  381. @out_wrapper()
  382. def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
  383. return torch.empty(
  384. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  385. )
  386. @register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
  387. @out_wrapper()
  388. def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
  389. # _fft_c2r_mkl
  390. torch._check(self.dtype.is_complex)
  391. if device_hint(self) == "cuda":
  392. out_sizes = list(self.size())
  393. out_sizes[dim[-1]] = lastdim
  394. output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
  395. if use_optimized_cufft_path(dim):
  396. return _exec_fft(
  397. output,
  398. self.clone(memory_format=torch.contiguous_format),
  399. out_sizes,
  400. dim,
  401. forward=False,
  402. )
  403. else:
  404. # First complete any C2C transforms
  405. if len(dim) > 1:
  406. temp = meta_fft_c2c(self, dim[:-1], 0, lastdim) # fft_norm_mode::none
  407. else:
  408. temp = self.clone(memory_format=torch.contiguous_format)
  409. return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)
  410. else:
  411. input = self
  412. if len(dim) > 1:
  413. c2c_dims = dim[:-1]
  414. input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
  415. dim = dim[-1:]
  416. out_sizes = list(input.size())
  417. out_sizes[dim[-1]] = lastdim
  418. out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
  419. return _exec_fft(out, input, out_sizes, dim, forward=False)
  420. @register_meta(aten.copy_.default)
  421. def meta_copy_(self, src, non_blocking=False):
  422. # This code simulates the original decomp from inductor,
  423. # which runs most of the meta checks that we care about.
  424. # In theory, we should make this more robust by carefully
  425. # auditing our C++ copy_() kernel and copying the checks here.
  426. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  427. # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are
  428. # calling an actual copy_, you'll get that automatically
  429. # https://github.com/pytorch/pytorch/issues/122477
  430. if (
  431. not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1
  432. ): # 1 == MemOverlap::Yes
  433. raise RuntimeError(
  434. "more than one element of the written-to tensor refers to a single memory location"
  435. )
  436. if isinstance(src, Tensor):
  437. intermediate = src.to(self, non_blocking)
  438. if self.size() != intermediate.size():
  439. aten.expand_copy.default(intermediate, self.size())
  440. return self
  441. def inferUnsqueezeGeometry(tensor, dim):
  442. result_sizes = list(tensor.size())
  443. result_strides = list(tensor.stride())
  444. # pyrefly: ignore [unsupported-operation]
  445. new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
  446. # pyrefly: ignore [bad-argument-type]
  447. result_sizes.insert(dim, 1)
  448. # pyrefly: ignore [bad-argument-type]
  449. result_strides.insert(dim, new_stride)
  450. return result_sizes, result_strides
  451. @register_meta(aten.unsqueeze_.default)
  452. def meta_unsqueeze_(self, dim):
  453. dim = maybe_wrap_dim(dim, self.dim() + 1)
  454. g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
  455. self.as_strided_(g_sizes, g_strides)
  456. return self
  457. @register_meta(aten._sparse_semi_structured_linear)
  458. def meta_sparse_structured_linear(
  459. input: Tensor,
  460. weight: Tensor,
  461. _meta: Tensor,
  462. bias: Tensor | None = None,
  463. _activation_opt: str | None = None,
  464. out_dtype: torch.dtype | None = None,
  465. ):
  466. output_sizes = list(input.shape)
  467. if bias is not None:
  468. if weight.size(0) != bias.size(0):
  469. raise AssertionError(
  470. f"output size mismatch: weight.size(0)={weight.size(0)} != bias.size(0)={bias.size(0)}"
  471. )
  472. if weight.size(1) != input.size(-1) / 2:
  473. raise AssertionError(
  474. f"weight.size(1)={weight.size(1)} != input.size(-1)/2={input.size(-1) / 2}"
  475. )
  476. output_sizes[-1] = weight.size(0)
  477. # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
  478. # We assume that we have already squashed the inputs into a 2-D tensor
  479. # Then, as the output is transposed, we need to propagate the transposed
  480. # stride information to the output tensor
  481. if len(input.shape) != 2:
  482. raise AssertionError(
  483. f"we can only handle the squashed input case, got {len(input.shape)}D input"
  484. )
  485. transposed_strides = (1, input.size(0))
  486. if out_dtype is not None:
  487. if not (input.dtype == torch.int8 and out_dtype == torch.int32):
  488. raise AssertionError(
  489. f"out_dtype is only supported for i8i8->i32 linear operator, got input.dtype={input.dtype}, out_dtype={out_dtype}"
  490. )
  491. output = input.new_empty(
  492. output_sizes,
  493. dtype=input.dtype if out_dtype is None else out_dtype,
  494. ).as_strided(output_sizes, transposed_strides)
  495. return output
  496. @register_meta(aten._sparse_semi_structured_mm)
  497. def meta_sparse_structured_mm(
  498. mat1: Tensor,
  499. mat1_meta: Tensor,
  500. mat2: Tensor,
  501. out_dtype: torch.dtype | None = None,
  502. ):
  503. if len(mat1.shape) != 2:
  504. raise AssertionError(f"mat1 must be 2D, got {len(mat1.shape)}D")
  505. if len(mat1_meta.shape) != 2:
  506. raise AssertionError(f"mat1_meta must be 2D, got {len(mat1_meta.shape)}D")
  507. if len(mat2.shape) != 2:
  508. raise AssertionError(f"mat2 must be 2D, got {len(mat2.shape)}D")
  509. if mat1.size(1) != mat2.size(0) / 2:
  510. raise AssertionError(
  511. f"mat1.size(1)={mat1.size(1)} != mat2.size(0)/2={mat2.size(0) / 2}"
  512. )
  513. output_sizes = [mat1.size(0), mat2.size(1)]
  514. if out_dtype is not None:
  515. if not (mat2.dtype == torch.int8 and out_dtype == torch.int32):
  516. raise AssertionError(
  517. f"out_dtype is only supported for i8i8->i32 linear operator, got mat2.dtype={mat2.dtype}, out_dtype={out_dtype}"
  518. )
  519. output = mat2.new_empty(
  520. output_sizes,
  521. dtype=mat2.dtype if out_dtype is None else out_dtype,
  522. )
  523. return output
  524. @register_meta(aten._sparse_semi_structured_addmm)
  525. def meta_sparse_structured_addmm(
  526. input: Tensor,
  527. mat1: Tensor,
  528. mat1_meta: Tensor,
  529. mat2: Tensor,
  530. *,
  531. alpha=1,
  532. beta=1,
  533. out_dtype: torch.dtype | None = None,
  534. ):
  535. if len(input.shape) != 1:
  536. raise AssertionError(
  537. f"only input broadcasted to columns of mat1 * mat2 product is supported, got {len(input.shape)}D input"
  538. )
  539. if len(mat1.shape) != 2:
  540. raise AssertionError(f"mat1 must be 2D, got {len(mat1.shape)}D")
  541. if len(mat1_meta.shape) != 2:
  542. raise AssertionError(f"mat1_meta must be 2D, got {len(mat1_meta.shape)}D")
  543. if len(mat2.shape) != 2:
  544. raise AssertionError(f"mat2 must be 2D, got {len(mat2.shape)}D")
  545. if input.size(0) != mat1.size(0):
  546. raise AssertionError(
  547. f"only input broadcasted to columns of mat1 * mat2 product is supported, "
  548. f"input.size(0)={input.size(0)} != mat1.size(0)={mat1.size(0)}"
  549. )
  550. if mat1.size(1) != mat2.size(0) / 2:
  551. raise AssertionError(
  552. f"mat1.size(1)={mat1.size(1)} != mat2.size(0)/2={mat2.size(0) / 2}"
  553. )
  554. output_sizes = [mat1.size(0), mat2.size(1)]
  555. if out_dtype is not None:
  556. if not (mat2.dtype == torch.int8 and out_dtype == torch.int32):
  557. raise AssertionError(
  558. f"out_dtype is only supported for i8i8->i32 linear operator, got mat2.dtype={mat2.dtype}, out_dtype={out_dtype}"
  559. )
  560. output = mat2.new_empty(
  561. output_sizes,
  562. dtype=mat2.dtype if out_dtype is None else out_dtype,
  563. )
  564. return output
  565. @register_meta(aten._cslt_sparse_mm)
  566. def meta__cslt_sparse_mm(
  567. compressed_A: torch.Tensor,
  568. dense_B: torch.Tensor,
  569. bias: Tensor | None = None,
  570. alpha: Tensor | None = None,
  571. out_dtype: torch.dtype | None = None,
  572. transpose_result: bool = False,
  573. alg_id: int = 0,
  574. split_k: int = 1,
  575. split_k_mode: int = -1,
  576. ):
  577. if dense_B.dtype not in {
  578. torch.float32,
  579. torch.float16,
  580. torch.bfloat16,
  581. torch.int8,
  582. torch.float8_e4m3fn,
  583. }:
  584. raise AssertionError(
  585. f"_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3, got {dense_B.dtype}"
  586. )
  587. if compressed_A.dtype != dense_B.dtype:
  588. raise AssertionError(
  589. f"inputs must have the same dtype, got {compressed_A.dtype} and {dense_B.dtype}"
  590. )
  591. if len(dense_B.shape) != 2:
  592. raise AssertionError(
  593. f"_cslt_sparse_mm only supports 2d inputs, got {len(dense_B.shape)}D"
  594. )
  595. is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
  596. if is_8bit_input_type:
  597. if dense_B.is_contiguous():
  598. raise AssertionError("dense input must be transposed for 8bit dtypes")
  599. n = dense_B.size(1)
  600. m = compressed_A.size(0)
  601. if bias is not None:
  602. if m != bias.size(0):
  603. raise AssertionError(
  604. f"bias size mismatch: m={m} != bias.size(0)={bias.size(0)}"
  605. )
  606. if out_dtype is not None:
  607. if not (
  608. is_8bit_input_type
  609. and out_dtype
  610. in {
  611. torch.float16,
  612. torch.bfloat16,
  613. torch.int32,
  614. torch.float8_e4m3fn,
  615. }
  616. ):
  617. raise AssertionError(
  618. f"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
  619. )
  620. output_shape = (n, m) if transpose_result else (m, n)
  621. return dense_B.new_empty(output_shape, dtype=out_dtype)
  622. @register_meta(aten.index_reduce.default)
  623. def meta_index_reduce(
  624. self: Tensor,
  625. dim: int,
  626. index: Tensor,
  627. source: torch.Tensor,
  628. reduce: str,
  629. *,
  630. include_self: bool = True,
  631. ) -> Tensor:
  632. return torch.empty_like(self, memory_format=torch.contiguous_format)
@register_meta(aten.index_reduce_.default)
def meta_index_reduce_(
    self: Tensor,
    dim: int,
    index: Tensor,
    source: torch.Tensor,
    reduce: str,
    *,
    include_self: bool = True,
) -> Tensor:
    """Meta kernel for in-place ``index_reduce_``: returns ``self`` unchanged.

    The remaining arguments are ignored.
    """
    return self
  644. # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
  645. @out_wrapper()
  646. @register_meta(aten.index_select.default)
  647. def meta_index_select(self, dim, index):
  648. result_size = list(self.size())
  649. if self.dim() > 0:
  650. result_size[dim] = index.numel()
  651. return self.new_empty(result_size)
  652. @register_meta(aten.segment_reduce.default)
  653. def meta_segment_reduce(
  654. data: Tensor,
  655. reduce: str,
  656. *,
  657. lengths: Tensor | None = None,
  658. indices: Tensor | None = None,
  659. offsets: Tensor | None = None,
  660. axis: int = 0,
  661. unsafe: bool = False,
  662. initial=None,
  663. ) -> Tensor:
  664. if indices is not None:
  665. raise NotImplementedError(
  666. "segment_reduce(): indices based reduction is not supported yet."
  667. )
  668. def segment_reduce_lengths_tensor(lengths_shape):
  669. return torch.empty(
  670. lengths_shape + data.shape[axis + 1 :],
  671. dtype=data.dtype,
  672. device="meta",
  673. memory_format=torch.contiguous_format,
  674. )
  675. if lengths is not None:
  676. return segment_reduce_lengths_tensor(lengths.shape)
  677. # FIXME should probably check that lengths and offset aren't both set, but
  678. # the ATen implementation neglects this too
  679. if offsets is not None:
  680. # lengths == torch.diff(offsets)
  681. lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
  682. return segment_reduce_lengths_tensor(lengths_shape)
  683. raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
  684. @register_meta([aten.max.default, aten.max.unary_out])
  685. @out_wrapper()
  686. def meta_max(self):
  687. return self.new_empty(())
  688. @register_meta(aten.max.dim)
  689. def meta_max_dim(self, dim, keepdim=False):
  690. dim = utils.reduction_dims(self.shape, (dim,))
  691. output_shape = _compute_reduction_shape(self, dim, keepdim)
  692. return (
  693. self.new_empty(output_shape),
  694. self.new_empty(output_shape, dtype=torch.long),
  695. )
  696. @register_meta([aten.min.default, aten.min.unary_out])
  697. @out_wrapper()
  698. def meta_min(self):
  699. return self.new_empty(())
  700. @register_meta(aten.min.dim)
  701. def meta_min_dim(self, dim, keepdim=False):
  702. dim = utils.reduction_dims(self.shape, (dim,))
  703. output_shape = _compute_reduction_shape(self, dim, keepdim)
  704. return (
  705. self.new_empty(output_shape),
  706. self.new_empty(output_shape, dtype=torch.long),
  707. )
  708. @register_meta(aten.angle.default)
  709. def meta_angle(self):
  710. if self.is_complex():
  711. result_dtype = corresponding_real_dtype(self.dtype)
  712. else:
  713. _, result_dtype = elementwise_dtypes(
  714. self,
  715. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  716. )
  717. return torch.empty_like(self, dtype=result_dtype)
  718. @register_meta(aten.angle.out)
  719. def meta_angle_out(self, out):
  720. torch._resize_output_(out, self.size(), self.device)
  721. return out.copy_(torch.angle(self))
@register_meta(aten._assert_async.default)
def assert_async(val):
    # Meta kernel for aten._assert_async.default: `val` is ignored and None
    # is returned.
    return
@register_meta(aten._assert_async.msg)
def assert_async_meta(val, assert_msg):
    # Meta kernel for aten._assert_async.msg: both arguments are ignored and
    # None is returned.
    return
@register_meta(aten._print.default)
def print_meta(s):
    # Meta kernel for aten._print.default: `s` is ignored and None is returned.
    return
@register_meta(aten._make_dep_token.default)
def make_dep_token(
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Meta kernel for ``_make_dep_token``: a zero-element tensor on "meta".

    All keyword arguments are accepted but ignored.
    """
    return torch.empty(0, device="meta")
  741. @register_meta(aten.sym_constrain_range.default)
  742. def sym_constrain_range(size, min=None, max=None):
  743. # Avoid importing sympy at a module level
  744. from torch.fx.experimental.symbolic_shapes import constrain_range
  745. if isinstance(size, (SymFloat, SymBool)):
  746. raise ValueError("Constraining SymFloat or Symbool is nyi")
  747. constrain_range(size, min=min, max=max)
@register_meta(aten._functional_sym_constrain_range.default)
def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
    """Functional form of ``sym_constrain_range``.

    Calls ``aten.sym_constrain_range`` and returns ``dep_token`` unchanged.
    """
    aten.sym_constrain_range(size, min=min, max=max)
    return dep_token
  752. @register_meta(aten.sym_constrain_range_for_size.default)
  753. def sym_constrain_range_for_size(size, min=None, max=None):
  754. # Avoid importing sympy at a module level
  755. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  756. if min is None and max is None:
  757. torch._check(size >= 0)
  758. return
  759. if isinstance(size, (SymFloat, SymBool)):
  760. raise ValueError("Constraining SymFloat or Symbool is nyi")
  761. if type(size) is int:
  762. if min is not None:
  763. torch._check(size >= min)
  764. if max is not None:
  765. torch._check(size <= max)
  766. return
  767. _constrain_range_for_size(size, min=min, max=max)
@register_meta(aten._functional_sym_constrain_range_for_size.default)
def functional_sym_constrain_range_for_size(size, min, max, dep_token):
    """Functional form of ``sym_constrain_range_for_size``.

    Calls ``aten.sym_constrain_range_for_size`` and returns ``dep_token``
    unchanged.
    """
    aten.sym_constrain_range_for_size(size, min=min, max=max)
    return dep_token
  772. @register_meta(aten._functional_assert_async.msg)
  773. def functional_assert_async_meta(val, assert_msg, dep_token):
  774. return dep_token
  775. # From aten/src/ATen/native/LinearAlgebraUtils.h
  776. def squareCheckInputs(self: Tensor, f_name: str):
  777. if self.dim() < 2:
  778. raise AssertionError(
  779. f"{f_name}: The input tensor must have at least 2 dimensions, got {self.dim()}"
  780. )
  781. # Use torch._check to defer validation to runtime for unbacked symbolic dimensions.
  782. torch._check(
  783. self.size(-1) == self.size(-2),
  784. lambda: f"{f_name}: A must be batches of square matrices, "
  785. f"but they are {self.size(-2)} by {self.size(-1)} matrices",
  786. )
  787. # Validates input shapes and devices
  788. # for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
  789. # From aten/src/ATen/native/LinearAlgebraUtils.h
  790. def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str):
  791. torch._check(
  792. self.device == A.device,
  793. lambda: (
  794. f"Expected b and A to be on the same device, but found b on "
  795. f"{self.device} and A on {A.device} instead."
  796. ),
  797. )
  798. torch._check(
  799. self.dtype == A.dtype,
  800. lambda: (
  801. f"Expected b and A to have the same dtype, but found b of type "
  802. f"{self.dtype} and A of type {A.dtype} instead."
  803. ),
  804. )
  805. torch._check(
  806. A.size(-1) == A.size(-2),
  807. lambda: (
  808. f"A must be batches of square matrices, "
  809. f"but they are {A.size(-2)} by {A.size(-1)} matrices"
  810. ),
  811. )
  812. torch._check(
  813. A.size(-1) == self.size(-2),
  814. lambda: (
  815. f"Incompatible matrix sizes for {name}: each A "
  816. f"matrix is {A.size(-1)} by {A.size(-1)}"
  817. f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
  818. ),
  819. )
  820. # From aten/src/ATen/native/LinearAlgebraUtils.h
  821. def checkFloatingOrComplex(
  822. t: Tensor,
  823. f_name: str,
  824. allow_low_precision_dtypes: bool = True,
  825. ):
  826. dtype = t.dtype
  827. torch._check(
  828. t.is_floating_point() or t.is_complex(),
  829. lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
  830. )
  831. if not allow_low_precision_dtypes:
  832. torch._check(
  833. dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
  834. lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
  835. )
  836. # From aten/src/ATen/native/LinearAlgebraUtils.h
  837. def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
  838. torch._check(
  839. A.dim() >= 2,
  840. lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
  841. )
  842. def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str):
  843. squareCheckInputs(A, f_name)
  844. checkIsMatrix(B, f_name)
  845. torch._check(
  846. A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
  847. lambda: (
  848. f"{f_name}: Incompatible shapes of A and B for the equation "
  849. f"{'AX = B' if left else 'XA = B'}"
  850. f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
  851. ),
  852. )
  853. def checkSameDevice(
  854. fn_name: str,
  855. result: Tensor,
  856. input: Tensor,
  857. result_name: str = "result",
  858. ):
  859. torch._check(
  860. result.device == input.device,
  861. lambda: (
  862. f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
  863. f"{result_name} on {result.device} and input on {input.device}"
  864. ),
  865. )
  866. def checkUplo(UPLO: str):
  867. UPLO_uppercase = UPLO.upper()
  868. torch._check(
  869. len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
  870. lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
  871. )
  872. @register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
  873. @out_wrapper("eigenvalues", "eigenvectors")
  874. def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True):
  875. squareCheckInputs(A, "linalg.eigh")
  876. checkUplo(UPLO)
  877. shape = list(A.shape)
  878. if compute_v:
  879. vecs = A.new_empty(shape)
  880. vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
  881. else:
  882. vecs = A.new_empty([0])
  883. shape.pop()
  884. vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))
  885. return vals, vecs
  886. @register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
  887. @out_wrapper()
  888. def meta__linalg_eigvals(input: Tensor) -> Tensor:
  889. squareCheckInputs(input, "linalg.eigvals")
  890. complex_dtype = (
  891. input.dtype
  892. if utils.is_complex_dtype(input.dtype)
  893. else utils.corresponding_complex_dtype(input.dtype)
  894. )
  895. return input.new_empty(input.shape[:-1], dtype=complex_dtype)
  896. @register_meta([aten.linalg_eig])
  897. @out_wrapper("eigenvalues", "eigenvectors")
  898. def meta_linalg_eig(input: Tensor):
  899. squareCheckInputs(input, "linalg.eig")
  900. complex_dtype = (
  901. input.dtype
  902. if utils.is_complex_dtype(input.dtype)
  903. else utils.corresponding_complex_dtype(input.dtype)
  904. )
  905. values = input.new_empty(input.shape[:-1], dtype=complex_dtype)
  906. vectors = input.new_empty(input.shape, dtype=complex_dtype)
  907. is_cuda = device_hint(input) == "cuda"
  908. vectors.as_strided_(
  909. input.shape, make_contiguous_strides_for(input.shape, row_major=is_cuda)
  910. )
  911. return values, vectors
  912. def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
  913. return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)
  914. @register_meta(aten._cholesky_solve_helper)
  915. @out_wrapper()
  916. def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
  917. return cloneBatchedColumnMajor(self)
  918. @register_meta(aten.cholesky_solve)
  919. @out_wrapper()
  920. def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
  921. torch._check(
  922. self.ndim >= 2,
  923. lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead",
  924. )
  925. torch._check(
  926. A.ndim >= 2,
  927. lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead",
  928. )
  929. self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name(
  930. self, A, "cholesky_solve"
  931. )
  932. return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)
  933. @register_meta(aten.cholesky)
  934. @out_wrapper()
  935. def cholesky(self: Tensor, upper: bool = False) -> Tensor:
  936. if self.numel() == 0:
  937. return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
  938. squareCheckInputs(self, "cholesky")
  939. return cloneBatchedColumnMajor(self)
  940. @register_meta(aten.cholesky_inverse)
  941. @out_wrapper()
  942. def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
  943. squareCheckInputs(self, "cholesky_inverse")
  944. return cloneBatchedColumnMajor(self)
  945. # From aten/src/ATen/native/BatchLinearAlgebra.cpp
  946. @register_meta(aten.linalg_cholesky_ex.default)
  947. def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
  948. squareCheckInputs(A, "linalg.cholesky")
  949. checkFloatingOrComplex(A, "linalg.cholesky")
  950. A_shape = A.shape
  951. ndim = len(A_shape)
  952. # L
  953. L_strides = make_contiguous_strides_for(A_shape, False)
  954. L = A.new_empty(A_shape)
  955. L.as_strided_(A_shape, L_strides)
  956. # infos
  957. infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
  958. return L, infos
  959. @register_meta(
  960. [aten.linalg_householder_product.default, aten.linalg_householder_product.out]
  961. )
  962. @out_wrapper()
  963. def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
  964. torch._check(
  965. input.ndim >= 2,
  966. lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
  967. )
  968. torch._check(
  969. input.size(-2) >= input.size(-1),
  970. lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
  971. )
  972. torch._check(
  973. input.size(-1) >= tau.size(-1),
  974. lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
  975. )
  976. torch._check(
  977. input.ndim - tau.ndim == 1,
  978. lambda: (
  979. f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
  980. f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
  981. ),
  982. )
  983. if input.ndim > 2:
  984. expected_batch_tau_shape = input.shape[:-2]
  985. actual_batch_tau_shape = tau.shape[:-1]
  986. torch._check(
  987. actual_batch_tau_shape == expected_batch_tau_shape,
  988. lambda: (
  989. f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
  990. f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
  991. ),
  992. )
  993. torch._check(
  994. tau.dtype == input.dtype,
  995. lambda: (
  996. f"torch.linalg.householder_product: tau dtype {tau.dtype}"
  997. f" does not match input dtype {input.dtype}"
  998. ),
  999. )
  1000. checkSameDevice("torch.linalg.householder_product", tau, input, "tau")
  1001. return torch.empty_strided(
  1002. size=input.shape,
  1003. stride=make_contiguous_strides_for(input.shape, row_major=False),
  1004. dtype=input.dtype,
  1005. device=input.device,
  1006. )
  1007. # From aten/src/ATen/native/BatchLinearAlgebra.cpp
  1008. @register_meta(aten.linalg_inv_ex.default)
  1009. def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
  1010. squareCheckInputs(A, "linalg.inv_ex")
  1011. checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
  1012. L = A.new_empty(A.shape)
  1013. L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
  1014. infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
  1015. return L, infos
  1016. @register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
  1017. @out_wrapper("LD", "pivots", "info")
  1018. def linalg_ldl_factor_ex_meta(
  1019. self: Tensor,
  1020. *,
  1021. hermitian: bool = False,
  1022. check_errors: bool = False,
  1023. ) -> tuple[Tensor, Tensor, Tensor]:
  1024. squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
  1025. checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
  1026. LD = torch.empty_strided(
  1027. size=self.shape,
  1028. stride=make_contiguous_strides_for(self.shape, row_major=False),
  1029. dtype=self.dtype,
  1030. device=self.device,
  1031. )
  1032. pivots = self.new_empty(self.shape[:-1], dtype=torch.int)
  1033. info = self.new_empty(self.shape[:-2], dtype=torch.int)
  1034. return LD, pivots, info
  1035. @register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
  1036. @out_wrapper()
  1037. def linalg_ldl_solve_meta(
  1038. LD: Tensor,
  1039. pivots: Tensor,
  1040. B: Tensor,
  1041. *,
  1042. hermitian: bool = False,
  1043. ) -> Tensor:
  1044. squareCheckInputs(LD, "torch.linalg.ldl_solve")
  1045. checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
  1046. linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
  1047. torch._check(
  1048. B.ndim >= 2,
  1049. lambda: (
  1050. f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
  1051. f"but it has {B.ndim} dimensions instead"
  1052. ),
  1053. )
  1054. expected_pivots_shape = LD.shape[:-1]
  1055. torch._check(
  1056. expected_pivots_shape == pivots.shape,
  1057. lambda: (
  1058. f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
  1059. f"but got pivots with shape {pivots.shape} instead"
  1060. ),
  1061. )
  1062. torch._check(
  1063. utils.is_integer_dtype(pivots.dtype),
  1064. lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
  1065. )
  1066. torch._check(
  1067. LD.dtype == B.dtype,
  1068. lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
  1069. )
  1070. B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
  1071. return torch.empty_strided(
  1072. size=B_broadcast_size,
  1073. stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
  1074. dtype=B.dtype,
  1075. device=B.device,
  1076. )
  1077. @register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
  1078. @out_wrapper("P", "L", "U")
  1079. def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> tuple[Tensor, Tensor, Tensor]:
  1080. torch._check(
  1081. A.ndim >= 2,
  1082. lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
  1083. )
  1084. sizes = list(A.shape)
  1085. m = sizes[-2]
  1086. n = sizes[-1]
  1087. # Use sym_min to handle unbacked symbolic dimensions
  1088. k = sym_min(m, n)
  1089. sizes[-1] = m
  1090. if pivot:
  1091. P = A.new_empty(sizes)
  1092. else:
  1093. P = A.new_empty([0])
  1094. sizes[-1] = k
  1095. L = A.new_empty(sizes)
  1096. sizes[-2] = k
  1097. sizes[-1] = n
  1098. U = A.new_empty(sizes)
  1099. return P, L, U
  1100. @register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
  1101. @out_wrapper("LU", "pivots", "info")
  1102. def linalg_lu_factor_ex_meta(
  1103. A: Tensor,
  1104. *,
  1105. pivot: bool = True,
  1106. check_errors: bool = False,
  1107. ) -> tuple[Tensor, Tensor, Tensor]:
  1108. torch._check(
  1109. A.ndim >= 2,
  1110. lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
  1111. )
  1112. sizes = list(A.shape)
  1113. m = sizes[-2]
  1114. n = sizes[-1]
  1115. LU = torch.empty_strided(
  1116. size=sizes,
  1117. stride=make_contiguous_strides_for(sizes, row_major=False),
  1118. dtype=A.dtype,
  1119. device=A.device,
  1120. )
  1121. # Sets sizes to the size of pivots
  1122. sizes.pop()
  1123. # Use sym_min to handle unbacked symbolic dimensions
  1124. sizes[-1] = sym_min(m, n)
  1125. pivots = A.new_empty(sizes, dtype=torch.int)
  1126. # Sets sizes to the size of info
  1127. sizes.pop()
  1128. info = A.new_empty(sizes, dtype=torch.int)
  1129. return LU, pivots, info
  1130. @register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
  1131. @out_wrapper()
  1132. def linalg_lu_solve_meta(
  1133. LU: Tensor,
  1134. pivots: Tensor,
  1135. B: Tensor,
  1136. *,
  1137. left: bool = True,
  1138. adjoint: bool = False,
  1139. ) -> Tensor:
  1140. # dtype
  1141. checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
  1142. torch._check(
  1143. LU.dtype == B.dtype,
  1144. lambda: (
  1145. f"linalg.lu_solve: Expected LU and B to have the same dtype, "
  1146. f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
  1147. ),
  1148. )
  1149. torch._check(
  1150. pivots.dtype == torch.int,
  1151. lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
  1152. )
  1153. # matrix shapes
  1154. squareCheckInputs(LU, "torch.linalg.lu_solve")
  1155. checkInputsSolver(LU, B, left, "linalg.lu_solve")
  1156. torch._check(
  1157. LU.size(-1) == pivots.size(-1),
  1158. lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
  1159. )
  1160. # batches
  1161. torch._check(
  1162. LU.shape[:-1] == pivots.shape,
  1163. lambda: (
  1164. f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
  1165. f"but got pivots with shape {pivots.shape} instead"
  1166. ),
  1167. )
  1168. B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)
  1169. result = torch.empty_strided(
  1170. size=B_broadcast_size,
  1171. stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
  1172. dtype=B.dtype,
  1173. device=B.device,
  1174. )
  1175. if result.numel() != 0 and not left:
  1176. if result.is_complex():
  1177. result = result.conj()
  1178. return result
  1179. @register_meta(aten.lu_unpack)
  1180. @out_wrapper("P", "L", "U")
  1181. def lu_unpack_meta(
  1182. LU: Tensor,
  1183. pivots: Tensor,
  1184. unpack_data: bool = True,
  1185. unpack_pivots: bool = True,
  1186. ) -> tuple[Tensor, Tensor, Tensor]:
  1187. torch._check(
  1188. LU.ndim >= 2,
  1189. lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
  1190. )
  1191. if unpack_pivots:
  1192. torch._check(
  1193. pivots.dtype == torch.int32,
  1194. lambda: (
  1195. "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
  1196. "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
  1197. ),
  1198. )
  1199. sizes = list(LU.shape)
  1200. m = sizes[-2]
  1201. n = sizes[-1]
  1202. k = min(m, n)
  1203. sizes[-1] = m
  1204. if unpack_pivots:
  1205. P = LU.new_empty(sizes)
  1206. else:
  1207. P = LU.new_empty([0])
  1208. if unpack_data:
  1209. sizes[-1] = k
  1210. L = LU.new_empty(sizes)
  1211. sizes[-2] = k
  1212. sizes[-1] = n
  1213. U = LU.new_empty(sizes)
  1214. else:
  1215. L = LU.new_empty([0])
  1216. U = LU.new_empty([0])
  1217. return P, L, U
  1218. # parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
  1219. def _parse_qr_mode(mode: str) -> tuple[bool, bool]:
  1220. if mode == "reduced":
  1221. compute_q = True
  1222. reduced = True
  1223. elif mode == "complete":
  1224. compute_q = True
  1225. reduced = False
  1226. elif mode == "r":
  1227. compute_q = False
  1228. reduced = True # this is actually irrelevant in this mode
  1229. else:
  1230. torch._check(
  1231. False,
  1232. lambda: (
  1233. f"qr received unrecognized mode '{mode}' "
  1234. f"but expected one of 'reduced' (default), 'r', or 'complete'"
  1235. ),
  1236. )
  1237. return compute_q, reduced # type: ignore[possibly-undefined]
  1238. @register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
  1239. @out_wrapper("Q", "R")
  1240. def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> tuple[Tensor, Tensor]:
  1241. checkIsMatrix(A, "linalg.qr")
  1242. checkFloatingOrComplex(A, "linalg.qr")
  1243. compute_q, reduced_mode = _parse_qr_mode(mode)
  1244. m = A.shape[-2]
  1245. n = A.shape[-1]
  1246. k = min(m, n)
  1247. if compute_q:
  1248. Q_shape = list(A.shape)
  1249. Q_shape[-1] = k if reduced_mode else m
  1250. Q = A.new_empty(Q_shape)
  1251. Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
  1252. else:
  1253. Q = A.new_empty([0])
  1254. # For readability
  1255. R_shape = list(A.shape)
  1256. R_shape[-2] = k if reduced_mode or not compute_q else m
  1257. R = A.new_empty(R_shape)
  1258. R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
  1259. return Q, R
  1260. @register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
  1261. @out_wrapper("sign", "logabsdet", "LU", "pivots")
  1262. def _linalg_slogdet(A: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
  1263. squareCheckInputs(A, "linalg.slogdet")
  1264. checkFloatingOrComplex(A, "linalg.slogdet", False)
  1265. shape = A.shape
  1266. sign = A.new_empty(shape[:-2])
  1267. logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
  1268. LU = torch.empty_strided(
  1269. size=shape,
  1270. stride=make_contiguous_strides_for(shape, False),
  1271. dtype=A.dtype,
  1272. device=A.device,
  1273. )
  1274. pivots = A.new_empty(shape[:-1], dtype=torch.int32)
  1275. return sign, logabsdet, LU, pivots
  1276. # From aten/src/ATen/native/BatchLinearAlgebra.cpp
  1277. # NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
  1278. @register_meta(aten._linalg_svd.default)
  1279. def _linalg_svd_meta(
  1280. A: Tensor,
  1281. full_matrices: bool = False,
  1282. compute_uv: bool = True,
  1283. driver: str | None = None,
  1284. ):
  1285. checkIsMatrix(A, "linalg.svd")
  1286. checkFloatingOrComplex(A, "linalg.svd")
  1287. batch_dims = list(A.shape[:-2])
  1288. m = A.shape[-2]
  1289. n = A.shape[-1]
  1290. k = min(m, n)
  1291. if compute_uv:
  1292. U_shape = batch_dims + [m, m if full_matrices else k]
  1293. U = A.new_empty(U_shape)
  1294. U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
  1295. V_shape = batch_dims + [n if full_matrices else k, n]
  1296. V = A.new_empty(V_shape)
  1297. # NB: This checks for CUDA since there is no way to check for cuSolver.
  1298. # Also, this might not work correctly on CPU when fake_device is not
  1299. # available as device_hint just defaults to CUDA in that case. See
  1300. # _linalg_svd meta in core.
  1301. is_cuda = device_hint(A) == "cuda"
  1302. V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
  1303. else:
  1304. # doesn't matter
  1305. U = A.new_empty([0])
  1306. V = A.new_empty([0])
  1307. # S is always real, even when A is complex.
  1308. S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
  1309. return U, S, V
  1310. def _linalg_broadcast_batch_dims(
  1311. arg1: Tensor,
  1312. arg2: Tensor,
  1313. ) -> tuple[list[int], list[int]]:
  1314. # broadcast the batch dimensions of arg1 and arg2.
  1315. arg1_batch_sizes = arg1.shape[:-2]
  1316. arg2_batch_sizes = arg2.shape[:-2]
  1317. expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
  1318. arg1_expand_size = list(expand_batch_portion)
  1319. arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
  1320. arg2_expand_size = list(expand_batch_portion)
  1321. arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
  1322. return arg1_expand_size, arg2_expand_size
  1323. def _linalg_broadcast_batch_dims_name(
  1324. arg1: Tensor,
  1325. arg2: Tensor,
  1326. name: str | None,
  1327. ) -> tuple[Tensor, Tensor]:
  1328. # If there's no name we assume we don't want to check the errors
  1329. if name:
  1330. linearSolveCheckInputs(arg1, arg2, name)
  1331. arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
  1332. arg1_broadcasted = (
  1333. arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
  1334. )
  1335. arg2_broadcasted = (
  1336. arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
  1337. )
  1338. return arg1_broadcasted, arg2_broadcasted
  1339. def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
  1340. expected_batched_rhs_shape = input.shape[:-1]
  1341. vector_case = other.ndim == 1 or (
  1342. input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
  1343. )
  1344. return vector_case
  1345. @register_meta(aten._linalg_solve_ex)
  1346. def _linalg_solve_ex(
  1347. A: Tensor,
  1348. B: Tensor,
  1349. *,
  1350. left: bool = True,
  1351. check_errors: bool = False,
  1352. result: Tensor | None = None,
  1353. LU: Tensor | None = None,
  1354. pivots: Tensor | None = None,
  1355. info: Tensor | None = None,
  1356. ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
  1357. checkFloatingOrComplex(A, "linalg.solve")
  1358. torch._check(
  1359. A.dtype == B.dtype,
  1360. lambda: (
  1361. f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
  1362. f"{A.dtype} and B of type {B.dtype} instead"
  1363. ),
  1364. )
  1365. vector_case = linalg_solve_is_vector_rhs(A, B)
  1366. B_ = B.unsqueeze(-1) if vector_case else B
  1367. checkInputsSolver(A, B_, left, "linalg.solve")
  1368. B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
  1369. torch._check(
  1370. left or not vector_case,
  1371. lambda: (
  1372. "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
  1373. "In this case linalg.solve is equivalent to B / A.squeeze(-1)"
  1374. ),
  1375. )
  1376. result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
  1377. result_ = torch.empty_strided(
  1378. size=result_shape,
  1379. stride=make_contiguous_strides_for(result_shape, not left),
  1380. dtype=B.dtype,
  1381. device=B.device,
  1382. )
  1383. shape = A.shape
  1384. LU_ = torch.empty_strided(
  1385. size=shape,
  1386. stride=make_contiguous_strides_for(shape, False),
  1387. dtype=A.dtype,
  1388. device=A.device,
  1389. )
  1390. pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
  1391. info_ = A.new_empty(shape[:-2], dtype=torch.int32)
  1392. out = (result, LU, pivots, info)
  1393. res = (result_, LU_, pivots_, info_)
  1394. if all(x is not None for x in out):
  1395. for r, o in zip(res, out):
  1396. # resize and copy operations are done in-place
  1397. _maybe_resize_out(o, r.shape) # type: ignore[arg-type]
  1398. # strides are not copied in out_wrapper
  1399. o.as_strided_(r.shape, r.stride()) # type: ignore[union-attr]
  1400. _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False) # type: ignore[arg-type]
  1401. return res
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
def linalg_solve_triangular_meta(
    A: Tensor,
    B: Tensor,
    *,
    upper: bool,
    left: bool = True,
    unitriangular: bool = False,
    out: Tensor | None = None,
) -> Tensor:
    """Meta kernel for ``linalg.solve_triangular`` (default and out overloads).

    Args:
        A: coefficient tensor, validated together with ``B`` by
            ``checkInputsSolver``.
        B: right-hand side tensor.
        upper: unused in this function.
        left: forwarded to ``checkInputsSolver``.
        unitriangular: unused in this function.
        out: optional output tensor; a size-``[0]`` tensor is created from
            ``A`` when omitted.

    Returns:
        ``out``, resized to the batch-broadcast shape of ``B`` when needed.

    Raises:
        AssertionError: if ``out`` is not ``TensorLike``.
    """
    if out is None:
        out = A.new_empty([0])
    if not isinstance(out, TensorLike):
        raise AssertionError(f"out must be TensorLike, got {type(out)}")
    checkInputsSolver(A, B, left, "linalg.solve_triangular")
    # Passing ``None`` as the name skips the linearSolveCheckInputs validation.
    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
    if avoid_copy_A:
        out = _maybe_resize_out(out, B_.shape)
    else:
        # reimplementation of resize_output with result F-contig
        if _resize_output_check(out, B_.shape):
            # Resize to the transposed shape, then transpose in place so that
            # ``out`` ends up with B_'s shape.
            out.resize_(B_.transpose(-2, -1).shape)
            out.transpose_(-2, -1)
    return out  # type: ignore[return-value]
  1427. @register_meta(aten.triangular_solve)
  1428. @out_wrapper("X", "M", exact_dtype=True)
  1429. def triangular_solve_meta(
  1430. self: Tensor,
  1431. A: Tensor,
  1432. upper: bool = True,
  1433. transpose: bool = False,
  1434. unitriangular: bool = False,
  1435. ) -> tuple[Tensor, Tensor]:
  1436. torch._check(
  1437. self.ndim >= 2,
  1438. lambda: (
  1439. f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
  1440. f"but it has {self.ndim} dimensions instead"
  1441. ),
  1442. )
  1443. torch._check(
  1444. A.ndim >= 2,
  1445. lambda: (
  1446. f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
  1447. f"but it has {A.ndim} dimensions instead"
  1448. ),
  1449. )
  1450. linearSolveCheckInputs(self, A, "triangular_solve")
  1451. if A.layout == torch.strided:
  1452. self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
  1453. solution = torch.empty_strided(
  1454. size=self_broadcast_size,
  1455. stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
  1456. dtype=self.dtype,
  1457. device=self.device,
  1458. )
  1459. cloned_coefficient = torch.empty_strided(
  1460. size=A_broadcast_size,
  1461. stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
  1462. dtype=A.dtype,
  1463. device=A.device,
  1464. )
  1465. elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
  1466. solution = torch.empty_like(self)
  1467. cloned_coefficient = self.new_empty([0])
  1468. else:
  1469. torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
  1470. return solution, cloned_coefficient # type: ignore[possibly-undefined]
  1471. # From aten/src/ATen/native/LinearAlgebra.cpp
  1472. @register_meta(aten._linalg_det.default)
  1473. def _linalg_det_meta(A):
  1474. squareCheckInputs(A, "linalg.det")
  1475. checkFloatingOrComplex(A, "linalg.det")
  1476. det = A.new_empty(A.shape[:-2])
  1477. LU = A.new_empty(A.shape)
  1478. LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
  1479. pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
  1480. return det, LU, pivots
  1481. @register_meta(aten.ormqr)
  1482. @out_wrapper()
  1483. def ormqr(
  1484. input: Tensor,
  1485. tau: Tensor,
  1486. other: Tensor,
  1487. left: bool = True,
  1488. transpose: bool = False,
  1489. ) -> Tensor:
  1490. torch._check(
  1491. input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
  1492. )
  1493. torch._check(
  1494. other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
  1495. )
  1496. left_size_condition = -2 if left else -1
  1497. torch._check(
  1498. other.shape[left_size_condition] >= tau.shape[-1],
  1499. lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
  1500. )
  1501. torch._check(
  1502. other.shape[left_size_condition] == input.shape[-2],
  1503. lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
  1504. )
  1505. torch._check(
  1506. tau.shape[-1] <= input.shape[-1],
  1507. lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
  1508. )
  1509. torch._check(
  1510. input.ndim - tau.ndim == 1,
  1511. lambda: (
  1512. f"torch.ormqr: Expected tau to have one dimension less than input, "
  1513. f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
  1514. ),
  1515. )
  1516. torch._check(
  1517. input.ndim == other.ndim,
  1518. lambda: (
  1519. f"torch.ormqr: Expected other to have the same number of dimensions as input, "
  1520. f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
  1521. ),
  1522. )
  1523. if input.ndim > 2:
  1524. expected_batch_shape = input.shape[:-2]
  1525. actual_batch_tau_shape = tau.shape[:-1]
  1526. torch._check(
  1527. actual_batch_tau_shape == expected_batch_shape,
  1528. lambda: (
  1529. f"torch.ormqr: Expected batch dimensions of tau to be "
  1530. f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
  1531. ),
  1532. )
  1533. actual_batch_other_shape = other.shape[:-2]
  1534. torch._check(
  1535. actual_batch_other_shape == expected_batch_shape,
  1536. lambda: (
  1537. f"torch.ormqr: Expected batch dimensions of other to be "
  1538. f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
  1539. ),
  1540. )
  1541. torch._check(
  1542. tau.dtype == input.dtype,
  1543. lambda: (
  1544. f"torch.ormqr: Expected input and tau to have the same dtype, "
  1545. f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
  1546. ),
  1547. )
  1548. torch._check(
  1549. other.dtype == input.dtype,
  1550. lambda: (
  1551. f"torch.ormqr: Expected input and other to have the same dtype, "
  1552. f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
  1553. ),
  1554. )
  1555. checkSameDevice("torch.ormqr", tau, input, "tau")
  1556. checkSameDevice("torch.ormqr", other, input, "other")
  1557. return torch.empty_strided(
  1558. size=other.shape,
  1559. stride=make_contiguous_strides_for(other.shape, row_major=False),
  1560. dtype=other.dtype,
  1561. device=other.device,
  1562. )
  1563. def _padding_check_valid_input(input, padding, *, dim):
  1564. torch._check(
  1565. len(padding) == 2 * dim,
  1566. lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
  1567. )
  1568. input_dim = input.ndim
  1569. is_batch_mode = input_dim == (dim + 2)
  1570. valid_batch_mode = is_batch_mode
  1571. valid_non_batch_mode = not is_batch_mode
  1572. if is_batch_mode:
  1573. # allow batch size of 0-dim.
  1574. for d in range(1, input_dim):
  1575. valid_batch_mode = valid_batch_mode and input.size(d) != 0
  1576. else:
  1577. for d in range(input_dim):
  1578. valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
  1579. # allow empty batch size but not other dimensions.
  1580. torch._check(
  1581. valid_batch_mode or valid_non_batch_mode,
  1582. lambda: (
  1583. f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
  1584. f"and other non-zero dimensions for input, but got: {input.shape}"
  1585. ),
  1586. )
  1587. def _pad1d_common(input, padding, *, is_reflection):
  1588. dim_plane = 0
  1589. dim_w = 1
  1590. nbatch = 1
  1591. if input.ndim == 3:
  1592. nbatch = input.size(0)
  1593. dim_w += 1
  1594. dim_plane += 1
  1595. _padding_check_valid_input(input, padding, dim=1)
  1596. pad_l, pad_r = padding
  1597. nplane = input.size(dim_plane)
  1598. input_w = input.size(dim_w)
  1599. output_w = input_w + pad_l + pad_r
  1600. if is_reflection:
  1601. torch._check(
  1602. pad_l < input_w and pad_r < input_w,
  1603. lambda: (
  1604. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1605. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1606. ),
  1607. )
  1608. torch._check(
  1609. output_w >= 1,
  1610. lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
  1611. )
  1612. if input.ndim == 2:
  1613. return input.new_empty((nplane, output_w))
  1614. else:
  1615. return input.new_empty((nbatch, nplane, output_w))
  1616. @register_meta(aten.reflection_pad1d)
  1617. @out_wrapper()
  1618. def meta_reflection_pad1d(input, padding):
  1619. return _pad1d_common(input, padding, is_reflection=True)
  1620. @register_meta(aten.replication_pad1d)
  1621. @out_wrapper()
  1622. def meta_replication_pad1d(input, padding):
  1623. torch._check(
  1624. input.dtype != torch.bool,
  1625. lambda: f""""replication_pad1d" not implemented for '{input.dtype.__str__()}'""",
  1626. )
  1627. return _pad1d_common(input, padding, is_reflection=False)
  1628. def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
  1629. dim_w = 1
  1630. if not is_reflection:
  1631. torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
  1632. if input.ndim == 3:
  1633. dim_w += 1
  1634. pad_l, pad_r = padding
  1635. input_w = input.size(dim_w)
  1636. output_w = input_w + pad_l + pad_r
  1637. if is_reflection:
  1638. torch._check(
  1639. pad_l < input_w and pad_r < input_w,
  1640. lambda: (
  1641. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1642. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1643. ),
  1644. )
  1645. torch._check(
  1646. output_w == grad_output.size(dim_w),
  1647. lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
  1648. )
  1649. return input.new_empty(input.shape)
  1650. @register_meta(aten.reflection_pad1d_backward)
  1651. @out_wrapper("grad_input")
  1652. def meta_reflection_pad1d_backward(grad_output, input, padding):
  1653. return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
  1654. @register_meta(aten.replication_pad1d_backward)
  1655. @out_wrapper("grad_input")
  1656. def meta_replication_pad1d_backward(grad_output, input, padding):
  1657. return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
  1658. def _pad2d_common(input, padding, *, is_reflection):
  1659. dim_w = 2
  1660. dim_h = 1
  1661. dim_slices = 0
  1662. nbatch = 1
  1663. _padding_check_valid_input(input, padding, dim=2)
  1664. ndim = input.ndim
  1665. if ndim == 4:
  1666. nbatch = input.size(0)
  1667. dim_w += 1
  1668. dim_h += 1
  1669. dim_slices += 1
  1670. pad_l, pad_r, pad_t, pad_b = padding
  1671. nplane = input.size(dim_slices)
  1672. input_h = input.size(dim_h)
  1673. input_w = input.size(dim_w)
  1674. output_h = input_h + pad_t + pad_b
  1675. output_w = input_w + pad_l + pad_r
  1676. if is_reflection:
  1677. torch._check(
  1678. pad_l < input_w and pad_r < input_w,
  1679. lambda: (
  1680. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1681. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1682. ),
  1683. )
  1684. torch._check(
  1685. pad_t < input_h and pad_b < input_h,
  1686. lambda: (
  1687. f"Argument #6: Padding size should be less than the corresponding input dimension, "
  1688. f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
  1689. ),
  1690. )
  1691. torch._check(
  1692. output_w >= 1 or output_h >= 1,
  1693. lambda: (
  1694. f"input (H: {input_h} W: {input_w}) is too small. "
  1695. f"Calculated output H: {output_h} W: {output_w}"
  1696. ),
  1697. )
  1698. if input.ndim == 3:
  1699. return input.new_empty((nplane, output_h, output_w))
  1700. else:
  1701. return input.new_empty((nbatch, nplane, output_h, output_w))
  1702. @register_meta(aten.reflection_pad2d)
  1703. @out_wrapper()
  1704. def meta_reflection_pad2d(input, padding):
  1705. return _pad2d_common(input, padding, is_reflection=True)
  1706. @register_meta(aten.replication_pad2d)
  1707. @out_wrapper()
  1708. def meta_replication_pad2d(input, padding):
  1709. torch._check(
  1710. input.dtype != torch.bool,
  1711. lambda: f""""replication_pad2d" not implemented for '{input.dtype.__str__()}'""",
  1712. )
  1713. return _pad2d_common(input, padding, is_reflection=False)
  1714. @register_meta(
  1715. aten._weight_norm_interface_backward.default,
  1716. )
  1717. def meta_weight_norm_backward(grad_w, saved_v, saved_g, saved_norms, dim):
  1718. grad_v = torch.empty_like(saved_v)
  1719. grad_g = torch.empty_like(saved_g)
  1720. return grad_v, grad_g
  1721. @register_meta(
  1722. [
  1723. aten.reflection_pad2d_backward.default,
  1724. aten.reflection_pad2d_backward.grad_input,
  1725. aten.replication_pad2d_backward.default,
  1726. aten.replication_pad2d_backward.grad_input,
  1727. ]
  1728. )
  1729. @out_wrapper("grad_input")
  1730. def meta_pad2d_backward(grad_output, self, padding):
  1731. dim_w = 2
  1732. dim_h = 1
  1733. dim_plane = 0
  1734. self_shape = self.shape
  1735. if self.dim() == 4:
  1736. dim_w += 1
  1737. dim_h += 1
  1738. dim_plane += 1
  1739. pad_l, pad_r, pad_t, pad_b = padding
  1740. input_h = self_shape[dim_h]
  1741. input_w = self_shape[dim_w]
  1742. output_h = input_h + pad_t + pad_b
  1743. output_w = input_w + pad_l + pad_r
  1744. torch._check(
  1745. output_w == grad_output.size(dim_w),
  1746. lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
  1747. )
  1748. torch._check(
  1749. output_h == grad_output.size(dim_h),
  1750. lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
  1751. )
  1752. return self.new_empty(self.shape)
  1753. def _pad3d_common(input, padding, *, is_reflection):
  1754. dim_w = 3
  1755. dim_h = 2
  1756. dim_d = 1
  1757. dim_plane = 0
  1758. _padding_check_valid_input(input, padding, dim=3)
  1759. batch_mode = input.ndim == 5
  1760. if batch_mode:
  1761. nbatch = input.size(0)
  1762. dim_w += 1
  1763. dim_h += 1
  1764. dim_d += 1
  1765. dim_plane += 1
  1766. pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
  1767. nplane = input.size(dim_plane)
  1768. input_d = input.size(dim_d)
  1769. input_h = input.size(dim_h)
  1770. input_w = input.size(dim_w)
  1771. output_d = input_d + pad_f + pad_bk
  1772. output_h = input_h + pad_t + pad_b
  1773. output_w = input_w + pad_l + pad_r
  1774. if is_reflection:
  1775. torch._check(
  1776. pad_l < input_w and pad_r < input_w,
  1777. lambda: (
  1778. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1779. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1780. ),
  1781. )
  1782. torch._check(
  1783. pad_t < input_h and pad_b < input_h,
  1784. lambda: (
  1785. f"Argument #6: Padding size should be less than the corresponding input dimension, "
  1786. f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
  1787. ),
  1788. )
  1789. torch._check(
  1790. pad_f < input_d and pad_bk < input_d,
  1791. lambda: (
  1792. f"Argument #8: Padding size should be less than the corresponding input dimension, "
  1793. f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
  1794. ),
  1795. )
  1796. torch._check(
  1797. output_w >= 1 or output_h >= 1 or output_d >= 1,
  1798. lambda: (
  1799. f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
  1800. f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
  1801. ),
  1802. )
  1803. if batch_mode:
  1804. return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined]
  1805. else:
  1806. return input.new_empty((nplane, output_d, output_h, output_w))
  1807. @register_meta(aten.reflection_pad3d)
  1808. @out_wrapper()
  1809. def meta_reflection_pad3d(input, padding):
  1810. return _pad3d_common(input, padding, is_reflection=True)
  1811. @register_meta(aten.replication_pad3d)
  1812. @out_wrapper()
  1813. def meta_replication_pad3d(input, padding):
  1814. torch._check(
  1815. input.dtype != torch.bool,
  1816. lambda: f""""replication_pad3d" not implemented for '{input.dtype.__str__()}'""",
  1817. )
  1818. return _pad3d_common(input, padding, is_reflection=False)
  1819. @register_meta(
  1820. [
  1821. aten.reflection_pad3d_backward.default,
  1822. aten.reflection_pad3d_backward.grad_input,
  1823. aten.replication_pad3d_backward.default,
  1824. aten.replication_pad3d_backward.grad_input,
  1825. ]
  1826. )
  1827. @out_wrapper("grad_input")
  1828. def meta_pad3d_backward(grad_output, input, padding):
  1829. torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
  1830. if input.ndim <= 3:
  1831. raise AssertionError(f"input.ndim must be > 3, got {input.ndim}")
  1832. if grad_output.ndim != input.ndim:
  1833. raise AssertionError(
  1834. f"grad_output.ndim must equal input.ndim, got {grad_output.ndim} != {input.ndim}"
  1835. )
  1836. dim_w = 3
  1837. dim_h = 2
  1838. dim_d = 1
  1839. if input.ndim == 5:
  1840. dim_w += 1
  1841. dim_h += 1
  1842. dim_d += 1
  1843. pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
  1844. input_d = input.size(dim_d)
  1845. input_h = input.size(dim_h)
  1846. input_w = input.size(dim_w)
  1847. output_d = input_d + pad_f + pad_bk
  1848. output_h = input_h + pad_t + pad_b
  1849. output_w = input_w + pad_l + pad_r
  1850. torch._check(
  1851. output_w == grad_output.size(dim_w),
  1852. lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
  1853. )
  1854. torch._check(
  1855. output_h == grad_output.size(dim_h),
  1856. lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
  1857. )
  1858. torch._check(
  1859. output_d == grad_output.size(dim_d),
  1860. lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
  1861. )
  1862. return input.new_empty(input.shape)
  1863. @register_meta(aten._pdist_forward)
  1864. @out_wrapper()
  1865. def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
  1866. torch._check(
  1867. self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
  1868. )
  1869. n = self.size(0)
  1870. if n <= 1:
  1871. return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format) # type: ignore[call-overload]
  1872. else:
  1873. return self.new_empty((n * (n - 1) // 2,)).to(
  1874. memory_format=torch.legacy_contiguous_format
  1875. ) # type: ignore[call-overload]
  1876. @register_meta(aten._pdist_backward)
  1877. @out_wrapper()
  1878. def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
  1879. torch._check(
  1880. self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous"
  1881. )
  1882. torch._check(
  1883. pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous"
  1884. )
  1885. return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
  1886. @register_meta([aten.baddbmm.default, aten.baddbmm.out])
  1887. @out_wrapper(exact_dtype=True)
  1888. def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
  1889. from torch.fx.experimental.symbolic_shapes import guard_or_true, sym_eq
  1890. dim1 = batch1.size(0)
  1891. dim2 = batch1.size(1)
  1892. dim3 = batch2.size(2)
  1893. if guard_or_true(torch.sym_not(sym_eq(self.shape, (dim1, dim2, dim3)))):
  1894. self = self.expand((dim1, dim2, dim3))
  1895. torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
  1896. torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
  1897. if not exp_config.skip_dtype_check_in_meta_registrations:
  1898. torch._check(
  1899. self.dtype == batch1.dtype == batch2.dtype,
  1900. lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
  1901. )
  1902. batch1_sizes = batch1.shape
  1903. batch2_sizes = batch2.shape
  1904. bs = batch1_sizes[0]
  1905. contraction_size = batch1_sizes[2]
  1906. torch._check(
  1907. batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
  1908. lambda: (
  1909. f"Expected size for first two dimensions of batch2 tensor to be: "
  1910. f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
  1911. ),
  1912. )
  1913. return self.new_empty(self.size())
  1914. @register_meta([aten.bernoulli.default, aten.bernoulli.out])
  1915. @out_wrapper()
  1916. def meta_bernoulli(self, *, generator=None):
  1917. # https://github.com/pytorch/pytorch/issues/88612
  1918. return torch.empty_like(self, memory_format=torch.contiguous_format)
  1919. @register_meta(aten.bernoulli_.float)
  1920. def meta_bernoulli_(self, p=0.5, generator=None):
  1921. return self
  1922. @register_meta(aten.bernoulli.p)
  1923. def meta_bernoulli_p(self, p=0.5, generator=None):
  1924. # https://github.com/pytorch/pytorch/issues/88612
  1925. return torch.empty_like(self, memory_format=torch.contiguous_format)
  1926. @register_meta([aten.poisson.default, aten.poisson.out])
  1927. @out_wrapper()
  1928. def meta_poisson(self, generator=None):
  1929. return torch.empty_like(self)
  1930. @register_meta(aten._fused_moving_avg_obs_fq_helper.default)
  1931. def meta__fused_moving_avg_obs_fq_helper(
  1932. self,
  1933. observer_on,
  1934. fake_quant_on,
  1935. running_min,
  1936. running_max,
  1937. scale,
  1938. zero_point,
  1939. averaging_const,
  1940. quant_min,
  1941. quant_max,
  1942. ch_axis,
  1943. per_row_fake_quant=False,
  1944. symmetric_quant=False,
  1945. ):
  1946. torch._check(
  1947. ch_axis < self.dim(),
  1948. lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
  1949. )
  1950. mask = torch.empty_like(self, dtype=torch.bool)
  1951. return (torch.empty_like(self), mask)
  1952. @register_meta(aten.mm)
  1953. @out_wrapper(exact_dtype=True)
  1954. def meta_mm(a, b, out_dtype: torch.dtype | None = None):
  1955. torch._check(a.dim() == 2, lambda: "a must be 2D")
  1956. torch._check(b.dim() == 2, lambda: "b must be 2D")
  1957. N, M1 = a.shape
  1958. M2, P = b.shape
  1959. torch._check(
  1960. M1 == M2,
  1961. lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
  1962. )
  1963. if out_dtype is not None:
  1964. torch._check(
  1965. out_dtype == a.dtype
  1966. or (
  1967. out_dtype == torch.float32
  1968. and a.dtype in (torch.float16, torch.bfloat16)
  1969. ),
  1970. lambda: "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs",
  1971. )
  1972. result_dtype = a.dtype if out_dtype is None else out_dtype
  1973. return a.new_empty((N, P), dtype=result_dtype)
  1974. def _compute_reduction_shape(self, dims, keepdim):
  1975. if keepdim:
  1976. return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
  1977. return utils.compute_reduction_output_shape(self.shape, dims)
  1978. # FakeTensors (meta tensors with a device) will report device as meta
  1979. # when running meta kernels. Here, access the "fake device" of FakeTensor if it
  1980. # exists so meta kernels which have diverge per device will be more
  1981. # accurate when run with FakeTensors
  1982. def device_hint(tensor) -> "str":
  1983. if isinstance(tensor, torch._subclasses.FakeTensor):
  1984. return tensor.fake_device.type
  1985. elif (
  1986. hasattr(tensor, "device")
  1987. and hasattr(tensor.device, "type")
  1988. and tensor.device.type != "meta"
  1989. ):
  1990. return tensor.device.type
  1991. else:
  1992. return "cuda" # default to cuda
  1993. def calc_conv_nd_return_shape(
  1994. input_tensor: torch.Tensor,
  1995. weight: torch.Tensor,
  1996. stride: list[int] | int,
  1997. padding: list[int] | int,
  1998. dilation: list[int] | int,
  1999. is_transposed: bool,
  2000. groups: int,
  2001. output_padding: list[int] | int | None = None,
  2002. ):
  2003. def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
  2004. """
  2005. Formula to apply to calculate the length of some dimension of the output
  2006. See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
  2007. Args:
  2008. ln: length of the dimension
  2009. p: padding in that dim
  2010. d: dilation in that dim
  2011. k: kernel size in that dim
  2012. s: stride in that dim
  2013. Returns:
  2014. The output length
  2015. """
  2016. return (ln + 2 * p - d * (k - 1) - 1) // s + 1
  2017. def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
  2018. """
  2019. Formula to apply to calculate the length of some dimension of the output
  2020. if transposed convolution is used.
  2021. See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
  2022. Args:
  2023. ln: length of the dimension
  2024. p: padding in that dim
  2025. d: dilation in that dim
  2026. k: kernel size in that dim
  2027. s: stride in that dim
  2028. op: output padding in that dim
  2029. Returns:
  2030. The output length
  2031. """
  2032. return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
  2033. kernel_size = weight.shape[2:]
  2034. dims = input_tensor.shape[2:]
  2035. if is_transposed:
  2036. out_channels = groups * weight.shape[1]
  2037. else:
  2038. out_channels = weight.shape[0]
  2039. if weight.shape[1] * groups != input_tensor.shape[1]:
  2040. raise RuntimeError("Invalid channel dimensions")
  2041. ret_shape = [input_tensor.shape[0], out_channels]
  2042. if isinstance(stride, IntLike):
  2043. # pyrefly: ignore [bad-assignment]
  2044. stride = [stride] * len(dims)
  2045. elif len(stride) == 1:
  2046. stride = [stride[0]] * len(dims)
  2047. if isinstance(padding, IntLike):
  2048. # pyrefly: ignore [bad-assignment]
  2049. padding = [padding] * len(dims)
  2050. elif len(padding) == 1:
  2051. padding = [padding[0]] * len(dims)
  2052. if isinstance(dilation, IntLike):
  2053. # pyrefly: ignore [bad-assignment]
  2054. dilation = [dilation] * len(dims)
  2055. elif len(dilation) == 1:
  2056. dilation = [dilation[0]] * len(dims)
  2057. output_padding_list: list[int] | None = None
  2058. if output_padding:
  2059. if isinstance(output_padding, IntLike):
  2060. # pyrefly: ignore [bad-assignment]
  2061. output_padding_list = [output_padding] * len(dims)
  2062. elif len(output_padding) == 1:
  2063. output_padding_list = [output_padding[0]] * len(dims)
  2064. else:
  2065. output_padding_list = output_padding
  2066. for i in range(len(dims)):
  2067. # If output_padding is present, we are dealing with a transposed convolution
  2068. if output_padding_list:
  2069. ret_shape.append(
  2070. _formula_transposed(
  2071. dims[i],
  2072. # pyrefly: ignore [bad-index]
  2073. padding[i],
  2074. # pyrefly: ignore [bad-index, index-error]
  2075. # pyrefly: ignore [bad-index, index-error]
  2076. dilation[i],
  2077. kernel_size[i],
  2078. # pyrefly: ignore [bad-index, index-error]
  2079. stride[i],
  2080. output_padding_list[i],
  2081. )
  2082. )
  2083. else:
  2084. ret_shape.append(
  2085. # pyrefly: ignore [bad-index, index-error]
  2086. _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
  2087. )
  2088. # NOTE: Backend behavior for zero-sized spatial dimensions is inconsistent.
  2089. # CUDA (cuDNN) handles zero-sized outputs gracefully by short-circuiting,
  2090. # but other backends fail: CPU rejects it, ROCm/miopen returns
  2091. # miopenStatusBadParm, and MPS asserts "Placeholder tensor is empty".
  2092. # We only allow zero-sized outputs on CUDA with cuDNN (not ROCm/HIP).
  2093. from torch._subclasses.fake_tensor import FakeTensor
  2094. from torch.fx.experimental.symbolic_shapes import sym_or
  2095. device = (
  2096. input_tensor.fake_device
  2097. if isinstance(input_tensor, FakeTensor)
  2098. else input_tensor.device
  2099. )
  2100. # ROCm also reports device.type as "cuda", but miopen doesn't support zero-sized outputs
  2101. is_cudnn = device.type == "cuda" and torch.version.hip is None
  2102. if not is_cudnn:
  2103. torch._check(
  2104. sym_or(*[x > 0 for x in ret_shape[2:]]),
  2105. lambda: f"Given input size per channel: {list(dims)}. "
  2106. f"Calculated output size per channel: {ret_shape[2:]}. "
  2107. f"Output size is too small",
  2108. )
  2109. return ret_shape
  2110. def is_channels_last(ten):
  2111. return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
  2112. @register_meta(aten.miopen_batch_norm.default)
  2113. def meta_miopen_batch_norm(
  2114. input_tensor: torch.Tensor,
  2115. weight: torch.Tensor,
  2116. bias: torch.Tensor | None,
  2117. running_mean: torch.Tensor | None,
  2118. running_var: torch.Tensor | None,
  2119. training: bool,
  2120. exponential_average_factor: float,
  2121. epsilon: float,
  2122. ):
  2123. # In batch norm the output is of the same shape as the input
  2124. out_shape = input_tensor.shape
  2125. # If tensor is provided for running_mean and running_var then use this. If these are not
  2126. # provided then we return the shape of weight tensor. Similar to how this is handled in the decomposition
  2127. save_mean_shape = running_mean.shape if running_mean is not None else weight.shape
  2128. save_var_shape = running_var.shape if running_var is not None else weight.shape
  2129. def pick_memory_format():
  2130. if is_channels_last(input_tensor):
  2131. return torch.channels_last
  2132. if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
  2133. return torch.contiguous_format
  2134. return torch.contiguous_format
  2135. out = input_tensor.new_empty(out_shape).to(memory_format=pick_memory_format())
  2136. if training:
  2137. save_mean = input_tensor.new_empty(save_mean_shape)
  2138. save_var = input_tensor.new_empty(save_var_shape)
  2139. else:
  2140. save_mean = input_tensor.new_empty((0,))
  2141. save_var = input_tensor.new_empty((0,))
  2142. return out, save_mean, save_var
  2143. @register_meta(aten.convolution.default)
  2144. def meta_conv(
  2145. input_tensor: torch.Tensor,
  2146. weight: torch.Tensor,
  2147. bias: torch.Tensor,
  2148. stride: list[int],
  2149. padding: list[int],
  2150. dilation: list[int],
  2151. is_transposed: bool,
  2152. output_padding: list[int],
  2153. groups: int,
  2154. ):
  2155. shape_out = calc_conv_nd_return_shape(
  2156. input_tensor,
  2157. weight,
  2158. stride,
  2159. padding,
  2160. dilation,
  2161. is_transposed,
  2162. groups,
  2163. output_padding if is_transposed else None,
  2164. )
  2165. input_channels_dim = 1
  2166. output_channels_dim = 1
  2167. if input_tensor.size(input_channels_dim) == 0:
  2168. shape_out[output_channels_dim] = 0
  2169. out = input_tensor.new_empty(shape_out)
  2170. return out
  2171. if torch._C._has_mkldnn:
  2172. _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
  2173. "mkldnn", "IMPL", "Meta"
  2174. )
  2175. @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
  2176. def meta_mkldnn_convolution_default(
  2177. input_tensor,
  2178. weight,
  2179. bias,
  2180. padding,
  2181. stride,
  2182. dilation,
  2183. groups,
  2184. attr,
  2185. scalars,
  2186. algorithm,
  2187. ):
  2188. shape_out = calc_conv_nd_return_shape(
  2189. input_tensor, weight, stride, padding, dilation, False, groups, []
  2190. )
  2191. out = input_tensor.new_empty(shape_out)
  2192. out_memory_format = torch.channels_last
  2193. if input_tensor.dim() == 5:
  2194. out_memory_format = torch.channels_last_3d
  2195. out = out.to(memory_format=out_memory_format) # type: ignore[call-overload]
  2196. return out
  2197. @register_meta(torch.ops.mkldnn._linear_pointwise.default)
  2198. def meta_linear_pointwise_default(
  2199. input_tensor, weight, bias, attr, scalars, algorithm
  2200. ):
  2201. return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
  2202. if torch._C.has_mkl:
  2203. _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
  2204. "mkl", "IMPL", "Meta"
  2205. )
  2206. @register_meta(torch.ops.mkl._mkl_linear)
  2207. def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size):
  2208. return input_tensor.new_empty(
  2209. (*input_tensor.shape[:-1], orig_weight.shape[0])
  2210. )
  2211. _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
  2212. "onednn", "IMPL", "Meta"
  2213. )
  2214. @register_meta(torch.ops.onednn.qconv2d_pointwise.default)
  2215. @register_meta(torch.ops.onednn.qconv_pointwise.default)
  2216. @register_meta(torch.ops.onednn.qconv_pointwise.tensor)
  2217. def meta_qconv_pointwise(
  2218. x,
  2219. x_scale,
  2220. x_zp,
  2221. w, # prepacked_weight
  2222. w_scale,
  2223. w_zp,
  2224. bias,
  2225. stride,
  2226. padding,
  2227. dilation,
  2228. groups,
  2229. output_scale,
  2230. output_zero_point,
  2231. output_dtype,
  2232. attr,
  2233. scalars,
  2234. algorithm,
  2235. ):
  2236. shape_out = calc_conv_nd_return_shape(
  2237. x,
  2238. w,
  2239. stride,
  2240. padding,
  2241. dilation,
  2242. False,
  2243. groups,
  2244. None,
  2245. )
  2246. if output_dtype is None:
  2247. output_dtype = x.dtype
  2248. if output_dtype not in [
  2249. torch.float32,
  2250. torch.bfloat16,
  2251. torch.uint8,
  2252. torch.int8,
  2253. torch.float8_e4m3fn,
  2254. ]:
  2255. raise AssertionError(
  2256. f"output_dtype must be one of float32, bfloat16, uint8, int8, float8_e4m3fn, got {output_dtype}"
  2257. )
  2258. out = x.new_empty(shape_out, dtype=output_dtype)
  2259. if len(shape_out) not in [3, 4, 5]:
  2260. raise AssertionError(
  2261. f"Expect output to be 3d/4d/5d for conv1d/2d/3d, got {len(shape_out)}d"
  2262. )
  2263. format = {
  2264. 3: torch.contiguous_format,
  2265. 4: torch.channels_last,
  2266. 5: torch.channels_last_3d,
  2267. }[len(shape_out)]
  2268. out = out.to(memory_format=format)
  2269. return out
  2270. @register_meta(torch.ops.onednn.qconv2d_pointwise.binary)
  2271. @register_meta(torch.ops.onednn.qconv2d_pointwise.binary_tensor)
  2272. def meta_qconv2d_pointwise_binary(
  2273. x,
  2274. x_scale,
  2275. x_zp,
  2276. w,
  2277. w_scale,
  2278. w_zp,
  2279. accum,
  2280. bias,
  2281. stride,
  2282. padding,
  2283. dilation,
  2284. groups,
  2285. output_scale,
  2286. output_zero_point,
  2287. output_dtype,
  2288. accum_scale,
  2289. accum_zero_point,
  2290. binary_op_name,
  2291. alpha,
  2292. unary_op_name,
  2293. unary_op_args,
  2294. unary_op_algorithm,
  2295. ):
  2296. if binary_op_name != "sum":
  2297. raise AssertionError(
  2298. f"binary_op_name must be 'sum', got '{binary_op_name}'"
  2299. )
  2300. return accum
  2301. @register_meta(torch.ops.onednn.qlinear_pointwise.default)
  2302. @register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
  2303. def meta_qlinear_pointwise(
  2304. x,
  2305. x_scale,
  2306. x_zp,
  2307. w,
  2308. w_scale,
  2309. w_zp,
  2310. bias,
  2311. output_scale,
  2312. output_zero_point,
  2313. output_dtype,
  2314. post_op_name,
  2315. post_op_args,
  2316. post_op_algorithm,
  2317. ):
  2318. output_shape = list(x.shape)
  2319. # The weight has been transposed during the qlinear weight prepack process.
  2320. output_shape[-1] = w.shape[1]
  2321. if output_dtype not in [
  2322. torch.float32,
  2323. torch.bfloat16,
  2324. torch.int8,
  2325. torch.uint8,
  2326. torch.float8_e4m3fn,
  2327. ]:
  2328. raise AssertionError(
  2329. f"output_dtype must be one of float32, bfloat16, int8, uint8, float8_e4m3fn, got {output_dtype}"
  2330. )
  2331. out = x.new_empty(output_shape, dtype=output_dtype)
  2332. return out
  2333. @register_meta(torch.ops.onednn.qlinear_pointwise.binary)
  2334. @register_meta(torch.ops.onednn.qlinear_pointwise.binary_tensor)
  2335. def meta_qlinear_pointwise_binary(
  2336. x,
  2337. x_scale,
  2338. x_zp,
  2339. w,
  2340. w_scale,
  2341. w_zp,
  2342. x_2,
  2343. bias,
  2344. output_scale,
  2345. output_zero_point,
  2346. output_dtype,
  2347. x2_scale,
  2348. x2_zp,
  2349. binary_op_name,
  2350. alpha,
  2351. unary_op_name,
  2352. unary_op_args,
  2353. unary_op_algorithm,
  2354. ):
  2355. if binary_op_name == "sum":
  2356. return x_2
  2357. output_shape = list(x.shape)
  2358. # The weight has been transposed during the qlinear weight prepack process.
  2359. output_shape[-1] = w.shape[1]
  2360. if output_dtype not in [
  2361. torch.float32,
  2362. torch.bfloat16,
  2363. torch.uint8,
  2364. torch.int8,
  2365. torch.float8_e4m3fn,
  2366. ]:
  2367. raise AssertionError(
  2368. f"output_dtype must be one of float32, bfloat16, uint8, int8, float8_e4m3fn, got {output_dtype}"
  2369. )
  2370. out = x.new_empty(output_shape, dtype=output_dtype)
  2371. return out
  2372. @register_meta(torch.ops.onednn.linear_dynamic_fp16.default)
  2373. @register_meta(torch.ops.onednn.linear_relu_dynamic_fp16.default)
  2374. def meta_linear_dynamic_fp16(
  2375. x,
  2376. w,
  2377. bias,
  2378. ):
  2379. output_shape = list(x.shape)
  2380. # The weight has been transposed during the qlinear weight prepack process.
  2381. output_shape[-1] = w.shape[1]
  2382. out = x.new_empty(output_shape)
  2383. return out
  2384. _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
  2385. "quantized", "IMPL", "Meta"
  2386. )
  2387. @register_meta(torch.ops.quantized.max_pool2d)
  2388. def meta_quantized_max_pool2d(
  2389. input,
  2390. kernel_size,
  2391. stride=(),
  2392. padding=(0,),
  2393. dilation=(1,),
  2394. ceil_mode=False,
  2395. ):
  2396. (
  2397. nInputPlane,
  2398. outputHeight,
  2399. outputWidth,
  2400. ) = max_pool2d_checks_and_compute_shape(
  2401. input, kernel_size, stride, padding, dilation, ceil_mode
  2402. )
  2403. nbatch = input.size(-4) if input.dim() == 4 else 1
  2404. memory_format = torch.channels_last
  2405. if input.dim() == 3:
  2406. size = [nInputPlane, outputHeight, outputWidth]
  2407. else:
  2408. size = [nbatch, nInputPlane, outputHeight, outputWidth]
  2409. return torch.empty(
  2410. size,
  2411. dtype=input.dtype,
  2412. device=input.device,
  2413. memory_format=memory_format,
  2414. )
  2415. @register_meta(torch.ops.quantized.int4mm_packed_weight_cpu)
  2416. def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros):
  2417. torch._check(x.dim() == 2, lambda: f"x must be a 2D tensor, got {x.dim()}D")
  2418. torch._check(w.dim() == 2, lambda: f"w must be a 2D tensor, got {w.dim()}D")
  2419. torch._check(
  2420. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  2421. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  2422. )
  2423. torch._check(
  2424. w.dtype == torch.uint8, lambda: f"expected w to be uint8, got {w.dtype}"
  2425. )
  2426. torch._check(
  2427. q_group_size.dtype == torch.int64,
  2428. lambda: f"q_group_size must be int64, got {q_group_size.dtype}",
  2429. )
  2430. torch._check(
  2431. q_scale_and_zeros.dtype == x.dtype,
  2432. lambda: f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
  2433. )
  2434. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  2435. # from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
  2436. def check_dim_size(tensor, dim, dim_size, size):
  2437. torch._check(
  2438. tensor.dim() == dim and tensor.shape[dim_size] == size,
  2439. lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
  2440. + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
  2441. )
  2442. @register_meta(aten.avg_pool2d.default)
  2443. def meta_avg_pool2d(
  2444. input,
  2445. kernel_size,
  2446. stride=(),
  2447. padding=(0,),
  2448. ceil_mode=False,
  2449. count_include_pad=True,
  2450. divisor_override=None,
  2451. ):
  2452. def unpack(name, val):
  2453. torch._check(
  2454. len(val) in [1, 2],
  2455. lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
  2456. )
  2457. H = val[0]
  2458. W = H if len(val) == 1 else val[1]
  2459. return H, W
  2460. kH, kW = unpack("kernel_size", kernel_size)
  2461. torch._check(
  2462. len(stride) in [0, 1, 2],
  2463. lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
  2464. )
  2465. torch._check(
  2466. input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64],
  2467. lambda: f""""avg_pool2d" not implemented for '{input.dtype.__str__()}'""",
  2468. )
  2469. if len(stride) == 0:
  2470. dH, dW = kH, kW
  2471. elif len(stride) == 1:
  2472. dH, dW = stride[0], stride[0]
  2473. else:
  2474. dH, dW = unpack("stride", stride)
  2475. padH, padW = unpack("padding", padding)
  2476. torch._check(
  2477. divisor_override is None or divisor_override != 0,
  2478. lambda: "divisor must be not zero",
  2479. )
  2480. nbatch = input.size(-4) if input.dim() == 4 else 1
  2481. nInputPlane = input.size(-3)
  2482. inputHeight = input.size(-2)
  2483. inputWidth = input.size(-1)
  2484. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
  2485. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
  2486. memory_format = utils.suggest_memory_format(input)
  2487. pool2d_shape_check(
  2488. input,
  2489. kH,
  2490. kW,
  2491. dH,
  2492. dW,
  2493. padH,
  2494. padW,
  2495. 1,
  2496. 1,
  2497. nInputPlane,
  2498. inputHeight,
  2499. inputWidth,
  2500. outputHeight,
  2501. outputWidth,
  2502. memory_format,
  2503. )
  2504. if input.dim() == 3:
  2505. size = [nInputPlane, outputHeight, outputWidth]
  2506. else:
  2507. size = [nbatch, nInputPlane, outputHeight, outputWidth]
  2508. return torch.empty(
  2509. size,
  2510. dtype=input.dtype,
  2511. device=input.device,
  2512. memory_format=memory_format,
  2513. )
  2514. # from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
  2515. def avg_pool2d_backward_shape_check(
  2516. input,
  2517. gradOutput,
  2518. nbatch,
  2519. kH,
  2520. kW,
  2521. dH,
  2522. dW,
  2523. padH,
  2524. padW,
  2525. nInputPlane,
  2526. inputHeight,
  2527. inputWidth,
  2528. outputHeight,
  2529. outputWidth,
  2530. mem_format,
  2531. ):
  2532. pool2d_shape_check(
  2533. input,
  2534. kH,
  2535. kW,
  2536. dH,
  2537. dW,
  2538. padH,
  2539. padW,
  2540. 1,
  2541. 1,
  2542. nInputPlane,
  2543. inputHeight,
  2544. inputWidth,
  2545. outputHeight,
  2546. outputWidth,
  2547. mem_format,
  2548. )
  2549. ndim = input.dim()
  2550. nOutputPlane = nInputPlane
  2551. check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
  2552. check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
  2553. check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
  2554. # Don't override the C++ registration.
  2555. @register_meta(aten.avg_pool2d_backward.default)
  2556. def meta_avg_pool2d_backward(
  2557. gradOutput_,
  2558. input,
  2559. kernel_size,
  2560. stride,
  2561. padding,
  2562. ceil_mode,
  2563. count_include_pad,
  2564. divisor_override,
  2565. ):
  2566. # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
  2567. torch._check(
  2568. len(kernel_size) == 1 or len(kernel_size) == 2,
  2569. lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
  2570. )
  2571. kH = kernel_size[0]
  2572. kW = kH if len(kernel_size) == 1 else kernel_size[1]
  2573. torch._check(
  2574. len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
  2575. lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
  2576. )
  2577. dH = kH if len(stride) == 0 else stride[0]
  2578. dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
  2579. torch._check(
  2580. len(padding) == 1 or len(padding) == 2,
  2581. lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
  2582. )
  2583. padH = padding[0]
  2584. padW = padH if len(padding) == 1 else padding[1]
  2585. torch._check(
  2586. divisor_override is None or divisor_override != 0,
  2587. lambda: "divisor must be not zero",
  2588. )
  2589. input_size = input.shape
  2590. nbatch = input_size[-4] if input.dim() == 4 else 1
  2591. nInputPlane = input_size[-3]
  2592. inputHeight = input_size[-2]
  2593. inputWidth = input_size[-1]
  2594. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
  2595. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
  2596. mem_format = utils.suggest_memory_format(input)
  2597. avg_pool2d_backward_shape_check(
  2598. input,
  2599. gradOutput_,
  2600. nbatch,
  2601. kH,
  2602. kW,
  2603. dH,
  2604. dW,
  2605. padH,
  2606. padW,
  2607. nInputPlane,
  2608. inputHeight,
  2609. inputWidth,
  2610. outputHeight,
  2611. outputWidth,
  2612. mem_format,
  2613. )
  2614. return torch.empty(
  2615. input_size,
  2616. dtype=input.dtype,
  2617. device=input.device,
  2618. memory_format=mem_format,
  2619. )
  2620. @register_meta(aten.avg_pool3d)
  2621. @out_wrapper()
  2622. def meta_avg_pool3d(
  2623. input,
  2624. kernel_size,
  2625. stride=(),
  2626. padding=(0,),
  2627. ceil_mode=False,
  2628. count_include_pad=True,
  2629. divisor_override=None,
  2630. ):
  2631. torch._check(
  2632. len(kernel_size) in (1, 3),
  2633. lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
  2634. )
  2635. kT = kernel_size[0]
  2636. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  2637. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  2638. torch._check(
  2639. not stride or len(stride) in (1, 3),
  2640. lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
  2641. )
  2642. torch._check(
  2643. input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64],
  2644. lambda: f""""avg_pool3d" not implemented for '{input.dtype.__str__()}'""",
  2645. )
  2646. dT = kT if not stride else stride[0]
  2647. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  2648. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  2649. torch._check(
  2650. len(padding) in (1, 3),
  2651. lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
  2652. )
  2653. padT = padding[0]
  2654. padH = padT if len(padding) == 1 else padding[1]
  2655. padW = padT if len(padding) == 1 else padding[2]
  2656. torch._check(
  2657. input.ndim in (4, 5),
  2658. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  2659. )
  2660. torch._check(
  2661. not divisor_override or divisor_override != 0,
  2662. lambda: "divisor must be not zero",
  2663. )
  2664. nbatch = input.size(0)
  2665. nslices = input.size(-4)
  2666. itime = input.size(-3)
  2667. iheight = input.size(-2)
  2668. iwidth = input.size(-1)
  2669. otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
  2670. oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
  2671. owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
  2672. pool3d_shape_check(
  2673. input,
  2674. nslices,
  2675. kT,
  2676. kH,
  2677. kW,
  2678. dT,
  2679. dH,
  2680. dW,
  2681. padT,
  2682. padH,
  2683. padW,
  2684. 1,
  2685. 1,
  2686. 1,
  2687. itime,
  2688. iheight,
  2689. iwidth,
  2690. otime,
  2691. oheight,
  2692. owidth,
  2693. "avg_pool3d()",
  2694. check_input_size=True,
  2695. )
  2696. if input.ndim == 4:
  2697. return input.new_empty((nslices, otime, oheight, owidth))
  2698. else:
  2699. return input.new_empty((nbatch, nslices, otime, oheight, owidth))
  2700. @register_meta(aten.avg_pool3d_backward)
  2701. @out_wrapper("grad_input")
  2702. def meta_avg_pool3d_backward(
  2703. grad_output,
  2704. input,
  2705. kernel_size,
  2706. stride,
  2707. padding,
  2708. ceil_mode,
  2709. count_include_pad,
  2710. divisor_override,
  2711. ):
  2712. torch._check(
  2713. len(kernel_size) in (1, 3),
  2714. lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
  2715. )
  2716. kT = kernel_size[0]
  2717. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  2718. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  2719. torch._check(
  2720. not stride or len(stride) in (1, 3),
  2721. lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
  2722. )
  2723. dT = kT if not stride else stride[0]
  2724. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  2725. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  2726. torch._check(
  2727. len(padding) in (1, 3),
  2728. lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
  2729. )
  2730. padT = padding[0]
  2731. padH = padT if len(padding) == 1 else padding[1]
  2732. padW = padT if len(padding) == 1 else padding[2]
  2733. torch._check(
  2734. input.ndim in (4, 5),
  2735. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  2736. )
  2737. torch._check(
  2738. not divisor_override or divisor_override != 0,
  2739. lambda: "divisor must be not zero",
  2740. )
  2741. nslices = input.size(-4)
  2742. itime = input.size(-3)
  2743. iheight = input.size(-2)
  2744. iwidth = input.size(-1)
  2745. otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
  2746. oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
  2747. owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
  2748. avg_pool3d_backward_shape_check(
  2749. input,
  2750. grad_output,
  2751. nslices,
  2752. kT,
  2753. kH,
  2754. kW,
  2755. dT,
  2756. dH,
  2757. dW,
  2758. padT,
  2759. padH,
  2760. padW,
  2761. itime,
  2762. iheight,
  2763. iwidth,
  2764. otime_for_shape_check,
  2765. oheight_for_shape_check,
  2766. owidth_for_shape_check,
  2767. "avg_pool3d_backward()",
  2768. )
  2769. return input.new_empty(input.shape)
  2770. @register_meta(aten._adaptive_avg_pool2d.default)
  2771. def meta_adaptive_avg_pool2d(self, output_size):
  2772. torch._check(
  2773. self.ndim == 3 or self.ndim == 4,
  2774. lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
  2775. )
  2776. output_shape = self.shape[:-2] + tuple(output_size)
  2777. memory_format = utils.suggest_memory_format(self)
  2778. # need to set memory_format to preserve the memory format of the input
  2779. # channel last input should have channel last output
  2780. return torch.empty(
  2781. output_shape,
  2782. dtype=self.dtype,
  2783. device=self.device,
  2784. memory_format=memory_format,
  2785. )
  2786. @register_meta(aten._adaptive_avg_pool3d.default)
  2787. def meta_adaptive_avg_pool3d(self, output_size):
  2788. torch._check(
  2789. self.ndim == 4 or self.ndim == 5,
  2790. lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
  2791. )
  2792. return self.new_empty(self.shape[:-3] + tuple(output_size))
  2793. @register_meta(aten._adaptive_avg_pool2d_backward.default)
  2794. def meta__adaptive_avg_pool2d_backward(grad_out, self):
  2795. ndim = grad_out.ndim
  2796. for i in range(1, ndim):
  2797. torch._check(
  2798. grad_out.size(i) > 0,
  2799. lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
  2800. size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
  2801. )
  2802. torch._check(
  2803. ndim == 3 or ndim == 4,
  2804. lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
  2805. )
  2806. torch._check(
  2807. self.dtype == grad_out.dtype,
  2808. lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
  2809. )
  2810. memory_format = torch.contiguous_format
  2811. if is_channels_last(self):
  2812. memory_format = torch.channels_last
  2813. return self.new_empty(self.shape).to(memory_format=memory_format)
  2814. @register_meta(aten._adaptive_avg_pool3d_backward)
  2815. @out_wrapper("grad_input")
  2816. def meta__adaptive_avg_pool3d_backward(grad_output, self):
  2817. _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
  2818. return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
  2819. def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
  2820. ndim = grad_output.ndim
  2821. for i in range(1, ndim):
  2822. torch._check(
  2823. grad_output.size(i) > 0,
  2824. lambda: (
  2825. f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
  2826. f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
  2827. ),
  2828. )
  2829. @register_meta(aten.adaptive_max_pool2d)
  2830. @out_wrapper("out", "indices")
  2831. def meta_adaptive_max_pool2d(input, output_size):
  2832. ndim = input.ndim
  2833. torch._check(
  2834. ndim in (3, 4),
  2835. lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
  2836. )
  2837. for i in range(1, ndim):
  2838. torch._check(
  2839. input.size(i) > 0,
  2840. lambda: (
  2841. f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
  2842. f"but input has sizes {input.shape} with dimension {i} being empty"
  2843. ),
  2844. )
  2845. torch._check(
  2846. len(output_size) == 2,
  2847. lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
  2848. )
  2849. dimH = 1
  2850. sizeB = 1
  2851. sizeD = 0
  2852. if input.ndim == 4:
  2853. sizeB = input.size(0)
  2854. dimH += 1
  2855. sizeD = input.size(dimH - 1)
  2856. osizeH, osizeW = output_size
  2857. if input.ndim == 3:
  2858. out_shape = (sizeD, osizeH, osizeW)
  2859. out = input.new_empty(out_shape)
  2860. indices = input.new_empty(out_shape, dtype=torch.int64)
  2861. return out, indices
  2862. else:
  2863. out_shape = (sizeB, sizeD, osizeH, osizeW) # type: ignore[assignment]
  2864. memory_format = utils.suggest_memory_format(input)
  2865. out = input.new_empty(out_shape).to(memory_format=memory_format)
  2866. indices = input.new_empty(out_shape, dtype=torch.int64).to(
  2867. memory_format=memory_format
  2868. )
  2869. return out, indices
  2870. @register_meta(aten.adaptive_max_pool2d_backward)
  2871. @out_wrapper("grad_input")
  2872. def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
  2873. ndim = grad_output.ndim
  2874. torch._check(
  2875. ndim in (3, 4),
  2876. lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
  2877. )
  2878. _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
  2879. torch._check(
  2880. input.dtype == grad_output.dtype,
  2881. lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
  2882. )
  2883. memory_format = utils.suggest_memory_format(input)
  2884. return input.new_empty(input.shape).to(memory_format=memory_format)
  2885. @register_meta(aten.adaptive_max_pool3d)
  2886. @out_wrapper("out", "indices")
  2887. def meta_adaptive_max_pool3d(input, output_size):
  2888. ndim = input.ndim
  2889. torch._check(
  2890. ndim in (4, 5),
  2891. lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
  2892. )
  2893. for i in range(1, ndim):
  2894. torch._check(
  2895. input.size(i) > 0,
  2896. lambda: (
  2897. f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
  2898. f"but input has sizes {input.shape} with dimension {i} being empty"
  2899. ),
  2900. )
  2901. torch._check(
  2902. len(output_size) == 3,
  2903. lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
  2904. )
  2905. dimD = 0
  2906. sizeB = 1
  2907. sizeD = 0
  2908. if ndim == 5:
  2909. sizeB = input.size(0)
  2910. dimD += 1
  2911. sizeD = input.size(dimD)
  2912. osizeT, osizeH, osizeW = output_size
  2913. if ndim == 4:
  2914. out_shape = (sizeD, osizeT, osizeH, osizeW)
  2915. else:
  2916. out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW) # type: ignore[assignment]
  2917. out = input.new_empty(out_shape)
  2918. indices = input.new_empty(out_shape, dtype=torch.int64)
  2919. return out, indices
  2920. @register_meta(aten.adaptive_max_pool3d_backward)
  2921. @out_wrapper("grad_input")
  2922. def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
  2923. _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
  2924. return input.new_empty(input.shape)
  2925. @register_meta(aten.repeat_interleave.Tensor)
  2926. def meta_repeat_interleave_Tensor(repeats, output_size=None):
  2927. if output_size is None:
  2928. raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
  2929. return repeats.new_empty(output_size)
  2930. @register_meta([aten.complex.default, aten.complex.out])
  2931. @out_wrapper()
  2932. def meta_complex(real, imag):
  2933. if not real.dtype.is_floating_point:
  2934. raise AssertionError(f"real must be floating point, got {real.dtype}")
  2935. if not imag.dtype.is_floating_point:
  2936. raise AssertionError(f"imag must be floating point, got {imag.dtype}")
  2937. result = elementwise_meta(
  2938. real.to(corresponding_complex_dtype(real.dtype)),
  2939. imag.to(corresponding_complex_dtype(imag.dtype)),
  2940. type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  2941. )
  2942. return result
  2943. @register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
  2944. @out_wrapper()
  2945. def nonzero_static(self, *, size, fill_value: int = -1):
  2946. # The impl of xpu nonzero_static is different with cuda but aligned with cpu
  2947. if device_hint(self) in ("cpu", "xpu"):
  2948. return self.new_empty((size, self.dim()), dtype=torch.long)
  2949. else:
  2950. return torch.empty_strided(
  2951. (size, self.dim()),
  2952. (1, size),
  2953. dtype=torch.long,
  2954. device=self.device,
  2955. )
  2956. @register_meta([torch.ops.aten.nonzero.default, torch.ops.aten.nonzero.out])
  2957. @out_wrapper()
  2958. def nonzero(self):
  2959. torch._check_not_implemented(
  2960. exp_config.meta_nonzero_assume_all_nonzero,
  2961. lambda: "The register_meta function for torch.nonzero() raises unimplemented by default, "
  2962. "as a correct data-independent implementation does not exist. This implementation "
  2963. "returns a fake value, assuming all elements of the tensor are non-zero. "
  2964. "To enable this registration, please set "
  2965. "'torch.fx.experimental._config.meta_nonzero_assume_all_nonzero' to True.",
  2966. )
  2967. return torch.empty_strided(
  2968. (self.numel(), self.dim()),
  2969. (1, self.numel()),
  2970. dtype=torch.long,
  2971. device=self.device,
  2972. )
  2973. @register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
  2974. def meta_index_Tensor(self, indices):
  2975. torch._check(bool(indices), lambda: "at least one index must be provided")
  2976. # aten::index is the internal advanced indexing implementation
  2977. # checkIndexTensorTypes and expandTensors
  2978. result: list[Tensor | None] = []
  2979. for i, index in enumerate(indices):
  2980. if index is not None:
  2981. torch._check(
  2982. index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
  2983. lambda: "tensors used as indices must be long, int, byte or bool tensors",
  2984. )
  2985. if index.dtype in [torch.int8, torch.bool]:
  2986. nonzero = index.nonzero()
  2987. k = len(result)
  2988. torch._check_index(
  2989. k + index.ndim <= self.ndim,
  2990. lambda: f"too many indices for tensor of dimension {self.ndim}",
  2991. )
  2992. for j in range(index.ndim):
  2993. torch._check_index(
  2994. index.shape[j] == self.shape[k + j],
  2995. lambda: f"The shape of the mask {index.shape} at index {i} "
  2996. f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
  2997. )
  2998. result.append(nonzero.select(1, j))
  2999. else:
  3000. result.append(index)
  3001. else:
  3002. result.append(index)
  3003. indices = result
  3004. torch._check(
  3005. len(indices) <= self.ndim,
  3006. lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
  3007. )
  3008. # expand_outplace
  3009. import torch._refs as refs # avoid import cycle in mypy
  3010. indices = list(refs._maybe_broadcast(*indices))
  3011. # add missing null tensors
  3012. while len(indices) < self.ndim:
  3013. indices.append(None)
  3014. # hasContiguousSubspace
  3015. # true if all non-null tensors are adjacent
  3016. # See:
  3017. # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
  3018. # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
  3019. state = 0
  3020. has_contiguous_subspace = False
  3021. for index in indices:
  3022. if state == 0:
  3023. if index is not None:
  3024. state = 1
  3025. elif state == 1:
  3026. if index is None:
  3027. state = 2
  3028. else:
  3029. if index is not None:
  3030. break
  3031. else:
  3032. has_contiguous_subspace = True
  3033. # transposeToFront
  3034. # This is the logic that causes the newly inserted dimensions to show up
  3035. # at the beginning of the tensor, if they're not contiguous
  3036. if not has_contiguous_subspace:
  3037. dims = []
  3038. transposed_indices = []
  3039. for i, index in enumerate(indices):
  3040. if index is not None:
  3041. dims.append(i)
  3042. transposed_indices.append(index)
  3043. for i, index in enumerate(indices):
  3044. if index is None:
  3045. dims.append(i)
  3046. transposed_indices.append(index)
  3047. self = self.permute(dims)
  3048. indices = transposed_indices
  3049. # AdvancedIndex::AdvancedIndex
  3050. # Now we can assume the indices have contiguous subspace
  3051. # This is simplified from AdvancedIndex which goes to more effort
  3052. # to put the input and indices in a form so that TensorIterator can
  3053. # take them. If we write a ref for this, probably that logic should
  3054. # get implemented
  3055. before_shape: list[int] = []
  3056. after_shape: list[int] = []
  3057. replacement_shape: list[int] = []
  3058. for dim, index in enumerate(indices):
  3059. if index is None:
  3060. if replacement_shape:
  3061. after_shape.append(self.shape[dim])
  3062. else:
  3063. before_shape.append(self.shape[dim])
  3064. else:
  3065. replacement_shape = list(index.shape)
  3066. def _restride_src(self):
  3067. """
  3068. This follows restride_src in TensorAdvancedIndexing.cpp
  3069. """
  3070. shape = before_shape + replacement_shape + after_shape
  3071. strides = list(self.stride())
  3072. # pyrefly: ignore [unsupported-operation]
  3073. strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
  3074. replacement_shape
  3075. )
  3076. return self.as_strided(shape, strides)
  3077. out = self.new_empty(before_shape + replacement_shape + after_shape)
  3078. from torch.fx.experimental.symbolic_shapes import guard_or_false
  3079. if guard_or_false(self.numel() == 0):
  3080. # No need to worry about the output strides if self is empty.
  3081. return out
  3082. # Try to follow eager to decide the output stride based on self.
  3083. # Note that perm here is the reverse of the 'perm_' decided by
  3084. # TensorIteratorBase::reorder_dimensions
  3085. restrided_self = _restride_src(self)
  3086. perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)
  3087. # Follow TensorIteratorBase::allocate_or_resize_outputs
  3088. if list(perm) != list(range(len(perm))):
  3089. perm_shape = utils.apply_perm(out.shape, perm)
  3090. new_stride = utils.make_contiguous_strides_for(perm_shape)
  3091. new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm))
  3092. out = out.as_strided(out.size(), new_stride)
  3093. return out
  3094. @register_meta([aten.convolution_backward.default])
  3095. def meta_convolution_backward(
  3096. grad_output_,
  3097. input_,
  3098. weight_,
  3099. bias_sizes_opt,
  3100. stride,
  3101. padding,
  3102. dilation,
  3103. transposed,
  3104. output_padding,
  3105. groups,
  3106. output_mask,
  3107. ):
  3108. # High level logic taken from slow_conv3d_backward_cpu which should
  3109. # be representative of all convolution_backward impls
  3110. backend_grad_input = None
  3111. backend_grad_weight = None
  3112. backend_grad_bias = None
  3113. # Backend layout expectation: GPU backends (CUDA via cudnn_conv_suggest_memory_format,
  3114. # MPS via mps_conv_use_channels_last) return channels_last outputs when either input
  3115. # tensor is channels_last. This must be matched here to avoid stride assertion failures
  3116. # in inductor when the predicted strides don't match actual backend output strides.
  3117. # See: https://github.com/pytorch/pytorch/issues/171622
  3118. #
  3119. # Memory format inference rules (matching backend behavior):
  3120. # - grad_input format: derived from grad_output and weight
  3121. # - grad_weight format: derived from input and grad_output
  3122. def _conv_memory_format(t1, t2):
  3123. # Match the logic in cudnn_conv_suggest_memory_format and mps_conv_use_channels_last:
  3124. # Use channels_last if either tensor suggests it
  3125. fmt1 = suggest_memory_format(t1)
  3126. fmt2 = suggest_memory_format(t2)
  3127. if fmt1 == torch.channels_last or fmt2 == torch.channels_last:
  3128. return torch.channels_last
  3129. if fmt1 == torch.channels_last_3d or fmt2 == torch.channels_last_3d:
  3130. return torch.channels_last_3d
  3131. return torch.contiguous_format
  3132. if output_mask[0]:
  3133. memory_format = _conv_memory_format(grad_output_, weight_)
  3134. backend_grad_input = grad_output_.new_empty(input_.size()).to(
  3135. memory_format=memory_format
  3136. )
  3137. if output_mask[1]:
  3138. memory_format = _conv_memory_format(input_, grad_output_)
  3139. backend_grad_weight = grad_output_.new_empty(weight_.size()).to(
  3140. memory_format=memory_format
  3141. )
  3142. if output_mask[2]:
  3143. backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
  3144. return (backend_grad_input, backend_grad_weight, backend_grad_bias)
  3145. @register_meta([aten.addbmm.default, aten.addbmm.out])
  3146. @out_wrapper(exact_dtype=True)
  3147. def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
  3148. dim1 = batch1.size(1)
  3149. dim2 = batch2.size(2)
  3150. self = self.expand((dim1, dim2))
  3151. torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
  3152. torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
  3153. torch._check(
  3154. batch1.size(0) == batch2.size(0),
  3155. lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
  3156. )
  3157. torch._check(
  3158. batch1.size(2) == batch2.size(1),
  3159. lambda: (
  3160. f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
  3161. f"and {batch2.size(1)}x{batch2.size(2)})"
  3162. ),
  3163. )
  3164. torch._check(
  3165. self.size(0) == dim1 and self.size(1) == dim2,
  3166. lambda: "self tensor does not match matmul output shape",
  3167. )
  3168. return self.new_empty(self.size())
  3169. @register_meta([aten.randint_like.Tensor])
  3170. def meta_randint_like(self, high, **kwargs):
  3171. return self.new_empty(self.size())
  3172. @register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
  3173. def meta__fused_adam_(
  3174. self,
  3175. grads,
  3176. exp_avgs,
  3177. exp_avg_sqs,
  3178. max_exp_avg_sqs,
  3179. state_steps,
  3180. *,
  3181. lr,
  3182. beta1,
  3183. beta2,
  3184. weight_decay,
  3185. eps,
  3186. amsgrad,
  3187. maximize,
  3188. grad_scale=None,
  3189. found_inf=None,
  3190. ):
  3191. for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
  3192. torch._check(
  3193. isinstance(l, list),
  3194. lambda: f"exponent must be a tensor list but got {type(l)}",
  3195. )
  3196. @register_meta([aten._fused_adam.default])
  3197. def meta__fused_adam(
  3198. self,
  3199. grads,
  3200. exp_avgs,
  3201. exp_avg_sqs,
  3202. max_exp_avg_sqs,
  3203. state_steps,
  3204. *,
  3205. lr,
  3206. beta1,
  3207. beta2,
  3208. weight_decay,
  3209. eps,
  3210. amsgrad,
  3211. maximize,
  3212. grad_scale=None,
  3213. found_inf=None,
  3214. ):
  3215. for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
  3216. torch._check(
  3217. isinstance(l, list),
  3218. lambda: f"exponent must be a tensor list but got {type(l)}",
  3219. )
  3220. def empty_like_list(tensor_list):
  3221. return [torch.empty_like(t) for t in tensor_list]
  3222. return (
  3223. empty_like_list(self),
  3224. empty_like_list(grads),
  3225. empty_like_list(exp_avgs),
  3226. empty_like_list(exp_avg_sqs),
  3227. empty_like_list(max_exp_avg_sqs),
  3228. )
  3229. @register_meta([aten._int_mm])
  3230. @out_wrapper()
  3231. def meta__int_mm(a, b):
  3232. torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
  3233. torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
  3234. torch._check(
  3235. a.dtype is torch.int8,
  3236. lambda: f"expected self to be int8, got {a.dtype}",
  3237. )
  3238. torch._check(
  3239. b.dtype is torch.int8,
  3240. lambda: f"expected mat2 to be int8, got {b.dtype}",
  3241. )
  3242. torch._check(
  3243. a.size(1) == b.size(0),
  3244. lambda: (
  3245. f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
  3246. f"and {b.size(0)}x{b.size(1)})"
  3247. ),
  3248. )
  3249. return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
  3250. @register_meta([aten._convert_weight_to_int4pack])
  3251. def meta__convert_weight_to_int4pack(w, inner_k_tiles):
  3252. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3253. torch._check(
  3254. w.dtype is torch.uint8,
  3255. lambda: f"expected w to be uint8, got {w.dtype}",
  3256. )
  3257. n = w.size(0)
  3258. k = w.size(1) * 2 # w is [n][k / 2] uint8
  3259. return w.new_empty(
  3260. (
  3261. n // 8,
  3262. k // (inner_k_tiles * 16),
  3263. 32,
  3264. inner_k_tiles // 2,
  3265. ),
  3266. dtype=torch.int32,
  3267. )
  3268. @register_meta([aten._convert_weight_to_int4pack_for_cpu])
  3269. def meta__convert_weight_to_int4pack_for_cpu(w, inner_k_tiles):
  3270. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3271. torch._check(
  3272. w.dtype is torch.int32,
  3273. lambda: f"expected w to be int32, got {w.dtype}",
  3274. )
  3275. n = w.size(0)
  3276. k = w.size(1) # w is [n][k] int32
  3277. return w.new_empty(
  3278. (n, k // 2),
  3279. dtype=torch.uint8,
  3280. )
  3281. @register_meta([aten._weight_int4pack_mm])
  3282. def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
  3283. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3284. expected_dim = 2 if w.fake_device.type == "xpu" else 4
  3285. torch._check(w.dim() == expected_dim, lambda: f"w must be a {expected_dim}D tensor")
  3286. torch._check(
  3287. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3288. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3289. )
  3290. torch._check(
  3291. w.dtype is torch.int32,
  3292. lambda: f"expected w to be int32, got {w.dtype}",
  3293. )
  3294. dim_n = w.size(0) if w.fake_device.type == "xpu" else w.size(0) * 8
  3295. return x.new_empty(x.size(0), dim_n, dtype=x.dtype)
  3296. @register_meta([aten._weight_int4pack_mm_for_cpu])
  3297. def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
  3298. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3299. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3300. torch._check(
  3301. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3302. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3303. )
  3304. torch._check(
  3305. w.dtype is torch.uint8,
  3306. lambda: f"expected w to be uint8, got {w.dtype}",
  3307. )
  3308. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  3309. @register_meta([aten._weight_int4pack_mm_with_scales_and_zeros])
  3310. def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros):
  3311. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3312. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3313. torch._check(
  3314. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3315. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3316. )
  3317. torch._check(
  3318. w.dtype is torch.int32,
  3319. lambda: f"expected w to be int32, got {w.dtype}",
  3320. )
  3321. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  3322. def kai_roundup(a: int, b: int) -> int:
  3323. return ((a + b - 1) // b) * b
  3324. def get_kai_packed_weight_size(n_bits, N, K, groupsize):
  3325. if n_bits == 4:
  3326. # Works for both fp32 and bf16 Kernels
  3327. if groupsize == K: # channelwise
  3328. # dotprod params only [1x8x32_neon_dotprod]
  3329. kai_nr = 8
  3330. kai_kr = 16
  3331. kai_sr = 2
  3332. kai_num_bytes_sum_rhs = 4 # sizeof(int32_t)
  3333. kai_num_bytes_multiplier_rhs = 4 # sizeof(float)
  3334. kai_num_bytes_bias = 4 # sizeof(float)
  3335. def kai_k_roundedup(k, kr, sr):
  3336. # Since we pack a float and int32 value at the end of the row,
  3337. # we must make sure that k is a multiple of 4 for alignment
  3338. kr_sr_roundedup4 = kai_roundup(kr * sr, 4)
  3339. return kai_roundup(k, kr_sr_roundedup4)
  3340. def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3341. k, nr, kr, sr
  3342. ):
  3343. k_internal = kai_k_roundedup(k, kr, sr)
  3344. if (k_internal % 2) != 0:
  3345. raise AssertionError(f"k_internal must be even, got {k_internal}")
  3346. return nr * (
  3347. (k_internal // 2)
  3348. + kai_num_bytes_multiplier_rhs
  3349. + kai_num_bytes_sum_rhs
  3350. + kai_num_bytes_bias
  3351. )
  3352. def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3353. n, k, nr, kr, sr
  3354. ):
  3355. num_rows = kai_roundup(n, nr) // nr
  3356. return (
  3357. num_rows
  3358. * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3359. k, nr, kr, sr
  3360. )
  3361. )
  3362. return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3363. N, K, kai_nr, kai_kr, kai_sr
  3364. )
  3365. elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise
  3366. kai_nr = 8
  3367. kai_kr = 16
  3368. kai_sr = 2
  3369. kai_num_bytes_sum_rhs = 4
  3370. kai_num_bytes_bias = 4
  3371. kai_nr_multiple_of = 4
  3372. kai_bl_multiple_of = 32
  3373. def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3374. n, k, nr, kr, sr, bl
  3375. ):
  3376. if (bl % kr) != 0:
  3377. raise AssertionError(f"bl ({bl}) must be divisible by kr ({kr})")
  3378. if (nr % kai_nr_multiple_of) != 0:
  3379. raise AssertionError(
  3380. f"nr ({nr}) must be divisible by kai_nr_multiple_of ({kai_nr_multiple_of})"
  3381. )
  3382. if (bl % kai_bl_multiple_of) != 0:
  3383. raise AssertionError(
  3384. f"bl ({bl}) must be divisible by kai_bl_multiple_of ({kai_bl_multiple_of})"
  3385. )
  3386. num_rows = kai_roundup(n, nr) // nr
  3387. return (
  3388. num_rows
  3389. * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3390. k, nr, kr, sr, bl
  3391. )
  3392. )
  3393. def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3394. k, nr, kr, sr, bl
  3395. ):
  3396. if (bl % kr) != 0:
  3397. raise AssertionError(f"bl ({bl}) must be divisible by kr ({kr})")
  3398. if (nr % kai_nr_multiple_of) != 0:
  3399. raise AssertionError(
  3400. f"nr ({nr}) must be divisible by kai_nr_multiple_of ({kai_nr_multiple_of})"
  3401. )
  3402. if (bl % kai_bl_multiple_of) != 0:
  3403. raise AssertionError(
  3404. f"bl ({bl}) must be divisible by kai_bl_multiple_of ({kai_bl_multiple_of})"
  3405. )
  3406. # kr and sr are unused in the calculation
  3407. num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes()
  3408. num_blocks_per_row = kai_num_blocks_per_row(k, bl)
  3409. num_bytes_per_block = kai_num_bytes_per_block(
  3410. bl, num_bytes_multiplier_rhs
  3411. )
  3412. return nr * (
  3413. (num_bytes_per_block * num_blocks_per_row)
  3414. + kai_num_bytes_sum_rhs
  3415. + kai_num_bytes_bias
  3416. )
  3417. # This function returns size of these datatypes stored as enum. We modify it to just return bf16 datatype
  3418. # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55
  3419. def kai_get_bf16_datatype_size_in_bytes():
  3420. return 2 # 2 bytes
  3421. def kai_num_blocks_per_row(k, bl):
  3422. if (bl % kai_bl_multiple_of) != 0:
  3423. raise AssertionError(
  3424. f"bl ({bl}) must be divisible by kai_bl_multiple_of ({kai_bl_multiple_of})"
  3425. )
  3426. return kai_roundup(k, bl) // bl
  3427. def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs):
  3428. if (bl % kai_bl_multiple_of) != 0:
  3429. raise AssertionError(
  3430. f"bl ({bl}) must be divisible by kai_bl_multiple_of ({kai_bl_multiple_of})"
  3431. )
  3432. return (bl // 2) + num_bytes_multiplier_rhs
  3433. return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3434. N, K, kai_nr, kai_kr, kai_sr, groupsize
  3435. )
  3436. @register_meta([aten._dyn_quant_pack_4bit_weight])
  3437. def meta__dyn_quant_pack_4bit_weight(
  3438. weights, scales_zeros, bias: Tensor | None, block_size, in_features, out_features
  3439. ):
  3440. torch._check(
  3441. weights.dtype is torch.uint8,
  3442. lambda: f"expected w to be uint8, got {weights.dtype}",
  3443. )
  3444. if torch.backends.kleidiai.is_available() and (
  3445. (block_size == in_features and scales_zeros.dtype == torch.float)
  3446. or (
  3447. block_size < in_features
  3448. and block_size % 32 == 0
  3449. and in_features % block_size == 0
  3450. and scales_zeros.dtype == torch.bfloat16
  3451. )
  3452. ):
  3453. packed_weight_size = get_kai_packed_weight_size(
  3454. 4, out_features, in_features, block_size
  3455. )
  3456. return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
  3457. packed_weight_size = weights.numel() + scales_zeros.numel()
  3458. if bias is not None:
  3459. packed_weight_size += bias.numel()
  3460. return weights.new_empty(packed_weight_size, dtype=torch.float)
  3461. @register_meta([aten._dyn_quant_matmul_4bit])
  3462. def meta__dyn_quant_matmul_4bit(
  3463. inp,
  3464. packed_weights,
  3465. block_size,
  3466. in_features,
  3467. out_features,
  3468. ):
  3469. torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
  3470. torch._check(
  3471. (inp.dtype == torch.float32)
  3472. or (inp.dtype == torch.bfloat16 and block_size == in_features),
  3473. lambda: (
  3474. f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), "
  3475. f"got {inp.dtype} with block_size={block_size} and in_features={in_features}"
  3476. ),
  3477. )
  3478. M = inp.size(0)
  3479. return inp.new_empty(M, out_features, dtype=inp.dtype)
  3480. @register_meta([aten._weight_int8pack_mm])
  3481. def meta__weight_int8pack_mm(x, w, q_scales):
  3482. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3483. torch._check(
  3484. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3485. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3486. )
  3487. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3488. torch._check(
  3489. w.dtype is torch.int8,
  3490. lambda: f"expected w to be int8, got {w.dtype}",
  3491. )
  3492. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  3493. @register_meta(aten._cdist_forward.default)
  3494. def meta_cdist_forward(x1, x2, p, compute_mode):
  3495. torch._check(
  3496. x1.dim() >= 2,
  3497. lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
  3498. )
  3499. torch._check(
  3500. x2.dim() >= 2,
  3501. lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
  3502. )
  3503. torch._check(
  3504. x1.size(-1) == x2.size(-1),
  3505. lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
  3506. )
  3507. torch._check(
  3508. utils.is_float_dtype(x1.dtype),
  3509. lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
  3510. )
  3511. torch._check(
  3512. utils.is_float_dtype(x2.dtype),
  3513. lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
  3514. )
  3515. torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
  3516. torch._check(
  3517. compute_mode in (None, 0, 1, 2),
  3518. lambda: f"possible modes: None, 0, 1, 2, but was: {compute_mode}",
  3519. )
  3520. r1 = x1.size(-2)
  3521. r2 = x2.size(-2)
  3522. batch_tensor1 = x1.shape[:-2]
  3523. batch_tensor2 = x2.shape[:-2]
  3524. output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
  3525. output_shape.extend([r1, r2])
  3526. return x1.new_empty(output_shape)
  3527. @register_meta(aten._cdist_backward)
  3528. @out_wrapper()
  3529. def meta_cdist_backward(grad, x1, x2, p, cdist):
  3530. c1 = x1.shape[-1]
  3531. r1 = x1.shape[-2]
  3532. r2 = x2.shape[-2]
  3533. batch_tensor1 = x1.shape[:-2]
  3534. batch_tensor2 = x2.shape[:-2]
  3535. expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
  3536. tensor1_expand_size = expand_batch_portion.copy()
  3537. tensor1_expand_size.extend([r1, c1])
  3538. batch_product = math.prod(expand_batch_portion)
  3539. if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
  3540. return torch.zeros_like(x1)
  3541. if tensor1_expand_size != list(x1.shape):
  3542. x1 = x1.expand(tensor1_expand_size)
  3543. return torch.empty_like(x1, memory_format=torch.contiguous_format)
  3544. # NB: This meta function accepts non-meta arguments! When this behavior
  3545. # was originally introduced this was accidental, but it is now load bearing
  3546. # as people are using this so that they can conveniently test code involving
  3547. # embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module)
  3548. @register_meta(aten._embedding_bag.default)
  3549. def meta_embedding_bag(
  3550. weight,
  3551. indices,
  3552. offsets,
  3553. scale_grad_by_freq=False,
  3554. mode=0,
  3555. sparse=False,
  3556. per_sample_weights=None,
  3557. include_last_offset=False,
  3558. padding_idx=-1,
  3559. ):
  3560. torch._check(
  3561. indices.dtype in (torch.long, torch.int),
  3562. lambda: f"expected indices to be long or int, got {indices.dtype}",
  3563. )
  3564. torch._check(
  3565. offsets.dtype in (torch.long, torch.int),
  3566. lambda: f"expected offsets to be long or int, got {offsets.dtype}",
  3567. )
  3568. torch._check(
  3569. utils.is_float_dtype(weight.dtype),
  3570. lambda: f"expected weight to be floating point type, got {weight.dtype}",
  3571. )
  3572. num_bags = offsets.size(0)
  3573. if include_last_offset:
  3574. torch._check(
  3575. num_bags >= 1,
  3576. lambda: "include_last_offset: numBags should be at least 1",
  3577. )
  3578. num_bags -= 1
  3579. output = weight.new_empty(num_bags, weight.size(1))
  3580. if per_sample_weights is not None:
  3581. torch._check(
  3582. mode == MODE_SUM,
  3583. lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
  3584. )
  3585. torch._check(
  3586. per_sample_weights.ndim == 1,
  3587. lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
  3588. )
  3589. torch._check(
  3590. per_sample_weights.numel() == indices.numel(),
  3591. lambda: (
  3592. f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
  3593. f"to be the same as indices.numel() ({indices.numel()})"
  3594. ),
  3595. )
  3596. def is_fast_path_index_select_scale(src, scale, output, padding_idx):
  3597. return (
  3598. is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
  3599. )
  3600. def is_fast_path_index_select(src, output, padding_idx):
  3601. return (
  3602. (src.dtype == torch.float or src.dtype == torch.half)
  3603. and src.stride(1) == 1
  3604. and output.stride(1) == 1
  3605. and padding_idx < 0
  3606. )
  3607. def is_fast_path(src, scale, output, padding_idx):
  3608. if scale is not None:
  3609. return is_fast_path_index_select_scale(src, scale, output, padding_idx)
  3610. else:
  3611. return is_fast_path_index_select(src, output, padding_idx)
  3612. if device_hint(offsets) != "cpu":
  3613. offset2bag = indices.new_empty(indices.size(0))
  3614. bag_size = indices.new_empty(offsets.size())
  3615. if mode == MODE_MAX:
  3616. max_indices = indices.new_empty(num_bags, weight.size(1))
  3617. else:
  3618. max_indices = indices.new_empty(0)
  3619. else:
  3620. fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
  3621. if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
  3622. offset2bag = offsets.new_empty(indices.size(0))
  3623. else:
  3624. offset2bag = offsets.new_empty(0)
  3625. bag_size = offsets.new_empty(num_bags)
  3626. # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
  3627. numBags = offsets.shape[0]
  3628. if mode == MODE_MAX:
  3629. if include_last_offset:
  3630. torch._check(
  3631. numBags >= 1,
  3632. lambda: "include_last_offset: numBags should be at least 1",
  3633. )
  3634. numBags -= 1
  3635. max_indices = offsets.new_empty(numBags, weight.shape[1])
  3636. else:
  3637. max_indices = offsets.new_empty(bag_size.size())
  3638. return output, offset2bag, bag_size, max_indices
  3639. @register_meta(aten._embedding_bag_forward_only.default)
  3640. def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
  3641. output, offset2bag, bag_size, max_indices = meta_embedding_bag(
  3642. weight, indices, offsets, *args
  3643. )
  3644. if device_hint(offsets) == "cpu":
  3645. bag_size = offsets.new_empty(offsets.size())
  3646. return output, offset2bag, bag_size, max_indices
  3647. def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
  3648. # if specified, dtype takes precedence
  3649. if dtype:
  3650. return dtype
  3651. if input.dtype.is_floating_point or input.dtype.is_complex:
  3652. return input.dtype
  3653. elif promote_int_to_long:
  3654. return torch.long
  3655. return input.dtype
  3656. @register_meta([aten.nansum.default, aten.nansum.out])
  3657. @out_wrapper()
  3658. def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
  3659. output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
  3660. dims = utils.reduction_dims(input.shape, dims)
  3661. output_shape = _compute_reduction_shape(input, dims, keepdim)
  3662. return input.new_empty(output_shape, dtype=output_dtype)
  3663. @register_meta([aten.median.default, aten.nanmedian.default])
  3664. def meta_median(input):
  3665. output_shape = utils.compute_reduction_output_shape(
  3666. input.shape, tuple(range(input.dim()))
  3667. )
  3668. return input.new_empty(output_shape)
  3669. @register_meta(
  3670. [
  3671. aten.median.dim,
  3672. aten.median.dim_values,
  3673. aten.nanmedian.dim,
  3674. aten.nanmedian.dim_values,
  3675. aten.mode.default,
  3676. aten.mode.values,
  3677. ]
  3678. )
  3679. @out_wrapper("values", "indices")
  3680. def meta_median_mode_dim(input, dim=-1, keepdim=False):
  3681. if device_hint(input) == "cuda":
  3682. utils.alert_not_deterministic("median CUDA with indices output")
  3683. dim = utils.reduction_dims(input.shape, (dim,))
  3684. output_shape = _compute_reduction_shape(input, dim, keepdim)
  3685. return (
  3686. input.new_empty(output_shape),
  3687. input.new_empty(output_shape, dtype=torch.long),
  3688. )
  3689. @register_meta(aten.logical_not_.default)
  3690. def meta_logical_not_(self):
  3691. return self
  3692. @register_meta(aten.repeat.default)
  3693. def meta_repeat(self, repeats):
  3694. torch._check(
  3695. len(repeats) >= self.dim(),
  3696. lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
  3697. )
  3698. for i, rep in enumerate(repeats):
  3699. torch._check(
  3700. rep >= 0,
  3701. lambda: f"Repeats cannot be negative, found {rep} at index {i}",
  3702. )
  3703. # Add new leading dimensions to the tensor if the
  3704. # number of target dimensions is larger than the
  3705. # number of source dimensions.
  3706. num_new_dimensions = len(repeats) - self.dim()
  3707. padded_size = (1,) * num_new_dimensions + tuple(self.shape)
  3708. target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
  3709. return self.new_empty(target_size)
  3710. @register_meta(aten.zero_.default)
  3711. def meta_zero_(self):
  3712. return self
  3713. @register_meta(
  3714. [
  3715. aten.mul_.Scalar,
  3716. aten.div_.Scalar,
  3717. aten.mul_.Tensor,
  3718. aten.div_.Tensor,
  3719. aten.logical_and_.default,
  3720. aten.logical_or_.default,
  3721. aten.logical_xor_.default,
  3722. ],
  3723. )
  3724. def meta_binop_inplace(self, other):
  3725. if isinstance(other, torch.Tensor):
  3726. check_inplace_broadcast(self.shape, other.shape)
  3727. return self
  3728. @register_meta(
  3729. [
  3730. aten.add_.Scalar,
  3731. aten.sub_.Scalar,
  3732. aten.add_.Tensor,
  3733. aten.sub_.Tensor,
  3734. ],
  3735. )
  3736. def meta_binop_inplace_alpha(self, other, alpha=1):
  3737. """
  3738. Some checks for inplace ops.
  3739. Checks for promotion rules for some dtypes.
  3740. int.add/sub_(float) and bool.add/sub_(others) are rejected.
  3741. Promoting in these in-place operations would require reallocating
  3742. and copying over elements, hence not allowed.
  3743. Checks for alpha param.
  3744. """
  3745. def is_integeric(arg):
  3746. if isinstance(arg, TensorLike):
  3747. return utils.is_integer_dtype(arg.dtype)
  3748. else:
  3749. return isinstance(arg, IntLike)
  3750. def is_floatic(arg):
  3751. if isinstance(arg, TensorLike):
  3752. return utils.is_float_dtype(arg.dtype)
  3753. else:
  3754. return isinstance(arg, FloatLike)
  3755. def is_booleanic(arg):
  3756. if isinstance(arg, TensorLike):
  3757. return utils.is_boolean_dtype(arg.dtype)
  3758. else:
  3759. return isinstance(arg, BoolLike)
  3760. # Do not allow int+float->int in-place
  3761. if is_integeric(self) and is_floatic(other):
  3762. raise RuntimeError(
  3763. "Promotion of int.add/sub_(float) in in-place ops are not possible due to element size change."
  3764. )
  3765. # Do not allow bool+other->bool in-place
  3766. if is_booleanic(self) and not is_booleanic(other):
  3767. raise RuntimeError(
  3768. "Promotion of book.add/sub_(others) in in-place ops are not possible due to element size change."
  3769. )
  3770. if isinstance(other, torch.Tensor):
  3771. check_inplace_broadcast(self.shape, other.shape)
  3772. return self
  3773. @register_meta(
  3774. [
  3775. aten.add.Scalar,
  3776. aten.sub.Scalar,
  3777. ],
  3778. )
  3779. def meta_binop_alpha(self, other, alpha=1):
  3780. return elementwise_meta(
  3781. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3782. )
  3783. @register_meta([aten.round.default, aten.round.decimals])
  3784. def meta_round(self, **kwargs):
  3785. return elementwise_meta(
  3786. self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3787. )
  3788. def shift_dtype_check(fn_name, self, val):
  3789. torch._check(
  3790. utils.is_integer_dtype(self.dtype),
  3791. lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
  3792. )
  3793. if isinstance(val, torch.Tensor):
  3794. torch._check(
  3795. utils.is_integer_dtype(val.dtype),
  3796. lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
  3797. )
  3798. else:
  3799. torch._check(
  3800. isinstance(val, IntLike),
  3801. lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
  3802. )
  3803. @register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
  3804. def meta_rshifts(self, other):
  3805. shift_dtype_check("rshift", self, other)
  3806. return elementwise_meta(
  3807. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3808. )
  3809. @register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
  3810. def meta_lshifts(self, other):
  3811. shift_dtype_check("lshift", self, other)
  3812. return elementwise_meta(
  3813. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3814. )
  3815. @register_meta(aten.zero.default)
  3816. def meta_zero(self):
  3817. return self.new_empty(self.shape)
  3818. @register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
  3819. def meta_fill_(self, val):
  3820. return self
  3821. @register_meta([aten.fill.Tensor, aten.fill.Scalar])
  3822. def meta_fill(self, val):
  3823. return torch.empty_like(self)
  3824. @register_meta(aten.relu_.default)
  3825. def meta_relu_(self):
  3826. return self
  3827. @register_meta(aten._add_relu.Tensor)
  3828. @out_wrapper()
  3829. def meta__add_relu(self, other, alpha=1) -> Tensor:
  3830. return elementwise_meta(
  3831. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3832. )
  3833. @register_meta([aten.rrelu_with_noise])
  3834. @out_wrapper()
  3835. def meta_rrelu_with_noise(
  3836. self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
  3837. ):
  3838. return torch.empty_like(self)
  3839. @register_meta([aten.rrelu_with_noise_functional])
  3840. def meta_rrelu_with_noise_functional(
  3841. self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
  3842. ):
  3843. return torch.empty_like(self), torch.empty_like(noise)
  3844. @register_meta([aten.rrelu_with_noise_])
  3845. def meta_rrelu_with_noise_(
  3846. self, lower=0.125, upper=0.3333333333333333, training=False, generator=None
  3847. ):
  3848. return self
  3849. @register_meta([aten.index_put.default, aten._unsafe_index_put.default])
  3850. def meta_index_put(self, indices, values, accumulate=False):
  3851. return torch.empty_like(self)
  3852. @register_meta(aten.masked_fill_.Scalar)
  3853. def meta_masked_fill_(self, mask, value):
  3854. check_inplace_broadcast(self.shape, mask.shape)
  3855. return self
  3856. @register_meta(aten._masked_scale.default)
  3857. def meta__masked_scale(self, mask, scale):
  3858. masked_scale = self.new_empty(self.size()).to(
  3859. memory_format=utils.suggest_memory_format(self)
  3860. )
  3861. return masked_scale
  3862. @register_meta(aten.masked_scatter_)
  3863. def meta_masked_scatter_(self, mask, source):
  3864. torch._check(
  3865. mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
  3866. )
  3867. torch._check(
  3868. self.dtype == source.dtype,
  3869. lambda: "masked_scatter: expected self and source to have same "
  3870. f"dtypes but got {self.dtype} and {source.dtype}",
  3871. )
  3872. return self
  3873. @register_meta(aten.masked_scatter)
  3874. @out_wrapper()
  3875. def meta_masked_scatter(self, mask, source):
  3876. self, mask = _maybe_broadcast(self, mask)
  3877. output = torch.empty_like(self, memory_format=torch.contiguous_format)
  3878. return meta_masked_scatter_(output, mask, source)
  3879. @register_meta(aten.masked_scatter_backward)
  3880. def meta_masked_scatter_backward(self, mask, sizes):
  3881. return self.new_empty(sizes)
  3882. @register_meta(aten.index_put_.default)
  3883. def meta_index_put_(self, indices, values, accumulate=False):
  3884. return self
  3885. def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None):
  3886. from torch.fx.experimental.symbolic_shapes import sym_and, sym_eq
  3887. torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
  3888. torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
  3889. batch1_sizes = batch1.size()
  3890. batch2_sizes = batch2.size()
  3891. bs = batch1_sizes[0]
  3892. contraction_size = batch1_sizes[2]
  3893. res_rows = batch1_sizes[1]
  3894. res_cols = batch2_sizes[2]
  3895. output_size = (bs, res_rows, res_cols)
  3896. torch._check(
  3897. sym_and(sym_eq(batch2_sizes[0], bs), sym_eq(batch2_sizes[1], contraction_size)),
  3898. lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
  3899. f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
  3900. )
  3901. if out_dtype:
  3902. supported_out_dtype = (
  3903. batch1.dtype == torch.float16 or batch1.dtype == torch.bfloat16
  3904. ) and out_dtype == torch.float32
  3905. torch._check(
  3906. out_dtype == batch1.dtype or supported_out_dtype,
  3907. lambda: "out_dtype only supported for torch.float32 output with float16/bfloat16 inputs or same as input dtypes",
  3908. )
  3909. output = batch2.new_empty(output_size).to(out_dtype)
  3910. else:
  3911. # TODO: handle out
  3912. output = batch2.new_empty(output_size)
  3913. if not is_bmm and self_baddbmm is not None:
  3914. torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
  3915. torch._check(
  3916. sym_eq(self_baddbmm.size(), output_size),
  3917. lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
  3918. )
  3919. return output
  3920. @register_meta(aten.bmm.default)
  3921. def meta_bmm(self, mat2):
  3922. return common_meta_baddbmm_bmm(self, mat2, True)
  3923. @register_meta(aten.bmm.dtype)
  3924. def meta_bmm_dtype(self, mat2, out_dtype):
  3925. return common_meta_baddbmm_bmm(self, mat2, True, out_dtype=out_dtype)
  3926. def div_rtn(x, y):
  3927. q = x // y
  3928. r = x % y
  3929. # WARNING: explicit bool conversion here is necessary;
  3930. # would be fixed by SymBool
  3931. if r != 0 and (bool(r < 0) != bool(y < 0)):
  3932. q -= 1
  3933. return q
  3934. def pooling_output_shape_pad_lr(
  3935. inputSize,
  3936. kernelSize,
  3937. pad_l,
  3938. pad_r,
  3939. stride,
  3940. dilation,
  3941. ceil_mode,
  3942. ):
  3943. outputSize = (
  3944. div_rtn(
  3945. inputSize
  3946. + pad_l
  3947. + pad_r
  3948. - dilation * (kernelSize - 1)
  3949. - 1
  3950. + (stride - 1 if ceil_mode else 0),
  3951. stride,
  3952. )
  3953. + 1
  3954. )
  3955. if ceil_mode:
  3956. if (outputSize - 1) * stride >= inputSize + pad_l:
  3957. outputSize -= 1
  3958. return outputSize
  3959. def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
  3960. torch._check(stride != 0, lambda: "stride should not be zero")
  3961. torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
  3962. torch._check(
  3963. pad <= ((kernelSize - 1) * dilation + 1) // 2,
  3964. lambda: (
  3965. f"pad should be at most half of effective kernel size, but got pad={pad}, "
  3966. f"kernel_size={kernelSize} and dilation={dilation}"
  3967. ),
  3968. )
  3969. return pooling_output_shape_pad_lr(
  3970. inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
  3971. )
  3972. def pool2d_shape_check(
  3973. input,
  3974. kH,
  3975. kW,
  3976. dH,
  3977. dW,
  3978. padH,
  3979. padW,
  3980. dilationH,
  3981. dilationW,
  3982. nInputPlane,
  3983. inputHeight,
  3984. inputWidth,
  3985. outputHeight,
  3986. outputWidth,
  3987. memory_format,
  3988. ):
  3989. ndim = input.dim()
  3990. nOutputPlane = nInputPlane
  3991. torch._check(
  3992. kW > 0 and kH > 0,
  3993. lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
  3994. )
  3995. torch._check(
  3996. dW > 0 and dH > 0,
  3997. lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
  3998. )
  3999. torch._check(
  4000. dilationH > 0 and dilationW > 0,
  4001. lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
  4002. )
  4003. valid_dims = input.size(1) != 0 and input.size(2) != 0
  4004. if memory_format == torch.channels_last:
  4005. torch._check(
  4006. ndim == 4 and valid_dims and input.size(3) != 0,
  4007. lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
  4008. f" with optional 0 dim batch size for input, but got: {input.size()}",
  4009. )
  4010. else:
  4011. torch._check(
  4012. (ndim == 3 and input.size(0) != 0 and valid_dims)
  4013. or (ndim == 4 and valid_dims and input.size(3) != 0),
  4014. lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
  4015. )
  4016. torch._check(
  4017. kW // 2 >= padW and kH // 2 >= padH,
  4018. lambda: "pad should be smaller than or equal to half of kernel size, but got "
  4019. f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
  4020. )
  4021. torch._check(
  4022. outputWidth >= 1 and outputHeight >= 1,
  4023. lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
  4024. f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
  4025. "Output size is too small",
  4026. )
  4027. def pool3d_shape_check(
  4028. input: Tensor,
  4029. nslices: int,
  4030. kT: int,
  4031. kH: int,
  4032. kW: int,
  4033. dT: int,
  4034. dH: int,
  4035. dW: int,
  4036. pT: int,
  4037. pH: int,
  4038. pW: int,
  4039. dilationT: int,
  4040. dilationH: int,
  4041. dilationW: int,
  4042. itime: int,
  4043. iheight: int,
  4044. iwidth: int,
  4045. otime: int,
  4046. oheight: int,
  4047. owidth: int,
  4048. fn_name: str,
  4049. check_input_size: bool = False,
  4050. ):
  4051. ndim = input.ndim
  4052. torch._check(
  4053. kT > 0 and kW > 0 and kH > 0,
  4054. lambda: (
  4055. f"kernel size should be greater than zero, but got "
  4056. f"kT: {kT}, kH: {kH}, kW: {kW}"
  4057. ),
  4058. )
  4059. torch._check(
  4060. dT > 0 and dW > 0 and dH > 0,
  4061. lambda: (
  4062. f"stride should be greater than zero, but got dT: {dT}, dH: {dH}, dW: {dW}"
  4063. ),
  4064. )
  4065. torch._check(
  4066. dilationT > 0 and dilationW > 0 and dilationH > 0,
  4067. lambda: (
  4068. f"dilation should be greater than zero, but got "
  4069. f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
  4070. ),
  4071. )
  4072. torch._check(
  4073. ndim in (4, 5),
  4074. lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
  4075. )
  4076. for i in range(ndim):
  4077. if ndim == 5 and i == 0:
  4078. # size of batch-dim can be 0.
  4079. continue
  4080. torch._check(
  4081. input.size(i) > 0,
  4082. lambda: (
  4083. f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
  4084. f" but input has a shape of {input.shape}"
  4085. f" and non-batch dimension {input.size(i)} has length zero!"
  4086. ),
  4087. )
  4088. if check_input_size: # AveragePool3d
  4089. torch._check(
  4090. itime >= kT and iheight >= kH and iwidth >= kW,
  4091. lambda: (
  4092. f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
  4093. f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
  4094. ),
  4095. )
  4096. torch._check(
  4097. kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
  4098. lambda: (
  4099. f"pad should be smaller than or equal to half of kernel size, but got "
  4100. f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
  4101. ),
  4102. )
  4103. torch._check(
  4104. otime >= 1 and owidth >= 1 and oheight >= 1,
  4105. lambda: (
  4106. f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
  4107. f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
  4108. f"Output size is too small"
  4109. ),
  4110. )
  4111. def max_pool3d_backward_shape_check(
  4112. input,
  4113. grad_output,
  4114. indices,
  4115. nslices,
  4116. kT,
  4117. kH,
  4118. kW,
  4119. dT,
  4120. dH,
  4121. dW,
  4122. pT,
  4123. pH,
  4124. pW,
  4125. dilationT,
  4126. dilationH,
  4127. dilationW,
  4128. itime,
  4129. iheight,
  4130. iwidth,
  4131. otime,
  4132. oheight,
  4133. owidth,
  4134. fn_name,
  4135. ):
  4136. ndim = input.ndim
  4137. pool3d_shape_check(
  4138. input,
  4139. nslices,
  4140. kT,
  4141. kH,
  4142. kW,
  4143. dT,
  4144. dH,
  4145. dW,
  4146. pT,
  4147. pH,
  4148. pW,
  4149. dilationT,
  4150. dilationH,
  4151. dilationW,
  4152. itime,
  4153. iheight,
  4154. iwidth,
  4155. otime,
  4156. oheight,
  4157. owidth,
  4158. fn_name,
  4159. )
  4160. check_dim_size(grad_output, ndim, ndim - 4, nslices)
  4161. check_dim_size(grad_output, ndim, ndim - 3, otime)
  4162. check_dim_size(grad_output, ndim, ndim - 2, oheight)
  4163. check_dim_size(grad_output, ndim, ndim - 1, owidth)
  4164. check_dim_size(indices, ndim, ndim - 4, nslices)
  4165. check_dim_size(indices, ndim, ndim - 3, otime)
  4166. check_dim_size(indices, ndim, ndim - 2, oheight)
  4167. check_dim_size(indices, ndim, ndim - 1, owidth)
  4168. def avg_pool3d_backward_shape_check(
  4169. input: Tensor,
  4170. grad_output: Tensor,
  4171. nslices: int,
  4172. kT: int,
  4173. kH: int,
  4174. kW: int,
  4175. dT: int,
  4176. dH: int,
  4177. dW: int,
  4178. pT: int,
  4179. pH: int,
  4180. pW: int,
  4181. itime: int,
  4182. iheight: int,
  4183. iwidth: int,
  4184. otime: int,
  4185. oheight: int,
  4186. owidth: int,
  4187. fn_name: str,
  4188. ):
  4189. ndim = input.ndim
  4190. pool3d_shape_check(
  4191. input,
  4192. nslices,
  4193. kT,
  4194. kH,
  4195. kW,
  4196. dT,
  4197. dH,
  4198. dW,
  4199. pT,
  4200. pH,
  4201. pW,
  4202. 1,
  4203. 1,
  4204. 1,
  4205. itime,
  4206. iheight,
  4207. iwidth,
  4208. otime,
  4209. oheight,
  4210. owidth,
  4211. fn_name,
  4212. True,
  4213. )
  4214. check_dim_size(grad_output, ndim, ndim - 4, nslices)
  4215. check_dim_size(grad_output, ndim, ndim - 3, otime)
  4216. check_dim_size(grad_output, ndim, ndim - 2, oheight)
  4217. check_dim_size(grad_output, ndim, ndim - 1, owidth)
  4218. def max_pool2d_checks_and_compute_shape(
  4219. input,
  4220. kernel_size,
  4221. stride,
  4222. padding,
  4223. dilation,
  4224. ceil_mode,
  4225. ):
  4226. # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
  4227. def unpack(name, val):
  4228. torch._check(
  4229. len(val) in [1, 2],
  4230. lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
  4231. )
  4232. H = val[0]
  4233. W = H if len(val) == 1 else val[1]
  4234. return H, W
  4235. kH, kW = unpack("kernel_size", kernel_size)
  4236. torch._check(
  4237. len(stride) in [0, 1, 2],
  4238. lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
  4239. )
  4240. if len(stride) == 0:
  4241. dH, dW = kH, kW
  4242. else:
  4243. dH, dW = unpack("stride", stride)
  4244. padH, padW = unpack("padding", padding)
  4245. dilationH, dilationW = unpack("dilation", dilation)
  4246. nInputPlane = input.size(-3)
  4247. inputHeight = input.size(-2)
  4248. inputWidth = input.size(-1)
  4249. memory_format = utils.suggest_memory_format(input)
  4250. if memory_format == torch.channels_last:
  4251. torch._check(
  4252. input.dim() == 4,
  4253. lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
  4254. )
  4255. elif memory_format == torch.contiguous_format:
  4256. torch._check(
  4257. input.dim() in [3, 4],
  4258. lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
  4259. )
  4260. else:
  4261. torch._check(
  4262. False,
  4263. lambda: "Unsupported memory format. Supports only ChannelsLast, Contiguous",
  4264. )
  4265. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
  4266. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
  4267. pool2d_shape_check(
  4268. input,
  4269. kH,
  4270. kW,
  4271. dH,
  4272. dW,
  4273. padH,
  4274. padW,
  4275. dilationH,
  4276. dilationW,
  4277. nInputPlane,
  4278. inputHeight,
  4279. inputWidth,
  4280. outputHeight,
  4281. outputWidth,
  4282. memory_format,
  4283. )
  4284. return nInputPlane, outputHeight, outputWidth
  4285. @register_meta(aten.max_pool2d_with_indices_backward.default)
  4286. def meta_max_pool2d_with_indices_backward(
  4287. grad_output,
  4288. self,
  4289. kernel_size,
  4290. stride,
  4291. padding,
  4292. dilation,
  4293. ceil_mode,
  4294. indices,
  4295. ):
  4296. (
  4297. nInputPlane,
  4298. outputHeight,
  4299. outputWidth,
  4300. ) = max_pool2d_checks_and_compute_shape(
  4301. self, kernel_size, stride, padding, dilation, ceil_mode
  4302. )
  4303. torch._check(
  4304. self.dtype == grad_output.dtype,
  4305. lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
  4306. )
  4307. nOutputPlane = nInputPlane
  4308. ndim = self.ndim
  4309. def _check_dim_size(t):
  4310. check_dim_size(t, ndim, ndim - 3, nOutputPlane)
  4311. check_dim_size(t, ndim, ndim - 2, outputHeight)
  4312. check_dim_size(t, ndim, ndim - 1, outputWidth)
  4313. _check_dim_size(grad_output)
  4314. _check_dim_size(indices)
  4315. memory_format = utils.suggest_memory_format(self)
  4316. return torch.empty(
  4317. self.shape,
  4318. dtype=self.dtype,
  4319. device=self.device,
  4320. memory_format=memory_format,
  4321. )
  4322. @register_meta(aten.max_pool2d_with_indices.default)
  4323. def meta_max_pool2d_with_indices(
  4324. input,
  4325. kernel_size,
  4326. stride=(),
  4327. padding=(0,),
  4328. dilation=(1,),
  4329. ceil_mode=False,
  4330. ):
  4331. (
  4332. nInputPlane,
  4333. outputHeight,
  4334. outputWidth,
  4335. ) = max_pool2d_checks_and_compute_shape(
  4336. input, kernel_size, stride, padding, dilation, ceil_mode
  4337. )
  4338. nbatch = input.size(-4) if input.dim() == 4 else 1
  4339. memory_format = utils.suggest_memory_format(input)
  4340. if input.dim() == 3:
  4341. size = [nInputPlane, outputHeight, outputWidth]
  4342. else:
  4343. size = [nbatch, nInputPlane, outputHeight, outputWidth]
  4344. return (
  4345. torch.empty(
  4346. size,
  4347. dtype=input.dtype,
  4348. device=input.device,
  4349. memory_format=memory_format,
  4350. ),
  4351. torch.empty(
  4352. size,
  4353. dtype=torch.int64,
  4354. device=input.device,
  4355. memory_format=memory_format,
  4356. ),
  4357. )
  4358. @register_meta(aten.fractional_max_pool2d.default)
  4359. def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
  4360. torch._check(
  4361. self.ndim in (3, 4),
  4362. lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
  4363. )
  4364. ndim = self.ndim
  4365. for d in range(ndim - 3, ndim):
  4366. torch._check(
  4367. self.size(d) > 0,
  4368. lambda: f"fractional_max_pool2d: Expected input to have non-zero "
  4369. f" size for non-batch dimensions, but got {self.size()} with dimension {d} empty",
  4370. )
  4371. # the check and message are out of sync, but this matches the structured meta
  4372. torch._check(
  4373. len(kernel_size) == 2,
  4374. lambda: "fractional_max_pool2d: kernel_size must"
  4375. "either be a single int or tuple of Ints",
  4376. )
  4377. torch._check(
  4378. len(output_size) == 2,
  4379. lambda: "fractional_max_pool2d: output_size must "
  4380. "either be a single int or tuple of Ints",
  4381. )
  4382. input_channels = self.size(-3)
  4383. input_height = self.size(-2)
  4384. input_width = self.size(-1)
  4385. if ndim == 4:
  4386. input_batch = self.size(0)
  4387. else:
  4388. input_batch = 1
  4389. torch._check(
  4390. self.dtype == random_samples.dtype,
  4391. lambda: "Expect _random_samples to have the same dtype as input",
  4392. )
  4393. torch._check(
  4394. random_samples.ndim == 3,
  4395. lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}",
  4396. )
  4397. n = random_samples.size(0)
  4398. c = random_samples.size(1)
  4399. d = random_samples.size(2)
  4400. torch._check(
  4401. n >= input_batch,
  4402. lambda: "Expect _random_samples.size(0) no less then input batch size.",
  4403. )
  4404. torch._check(
  4405. c == input_channels,
  4406. lambda: "Expect _random_samples.size(1) equals to input channel size.",
  4407. )
  4408. torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.")
  4409. torch._check(
  4410. output_size[0] + kernel_size[0] - 1 <= input_height,
  4411. lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
  4412. )
  4413. torch._check(
  4414. output_size[1] + kernel_size[1] - 1 <= input_width,
  4415. lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
  4416. )
  4417. if self.dim() == 4:
  4418. size = [input_batch, input_channels, output_size[0], output_size[1]]
  4419. else:
  4420. size = [input_channels, output_size[0], output_size[1]]
  4421. return (
  4422. torch.empty(
  4423. size,
  4424. dtype=self.dtype,
  4425. device=self.device,
  4426. ),
  4427. torch.empty(
  4428. size,
  4429. dtype=torch.int64,
  4430. device=self.device,
  4431. ),
  4432. )
  4433. @register_meta(aten.max_pool3d_with_indices)
  4434. @out_wrapper("out", "indices")
  4435. def meta_max_pool3d_with_indices(
  4436. input,
  4437. kernel_size,
  4438. stride=(),
  4439. padding=(0,),
  4440. dilation=(1,),
  4441. ceil_mode=False,
  4442. ):
  4443. torch._check(
  4444. len(kernel_size) in (1, 3),
  4445. lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
  4446. )
  4447. kT = kernel_size[0]
  4448. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  4449. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  4450. torch._check(
  4451. not stride or len(stride) in (1, 3),
  4452. lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
  4453. )
  4454. dT = kT if not stride else stride[0]
  4455. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  4456. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  4457. torch._check(
  4458. len(padding) in (1, 3),
  4459. lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
  4460. )
  4461. pT = padding[0]
  4462. pH = pT if len(padding) == 1 else padding[1]
  4463. pW = pT if len(padding) == 1 else padding[2]
  4464. torch._check(
  4465. len(dilation) in (1, 3),
  4466. lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
  4467. )
  4468. dilationT = dilation[0]
  4469. dilationH = dilationT if len(dilation) == 1 else dilation[1]
  4470. dilationW = dilationT if len(dilation) == 1 else dilation[2]
  4471. torch._check(
  4472. input.ndim in (4, 5),
  4473. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  4474. )
  4475. nbatch = input.size(-5) if input.ndim == 5 else 1
  4476. nslices = input.size(-4)
  4477. itime = input.size(-3)
  4478. iheight = input.size(-2)
  4479. iwidth = input.size(-1)
  4480. otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
  4481. oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
  4482. owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)
  4483. pool3d_shape_check(
  4484. input,
  4485. nslices,
  4486. kT,
  4487. kH,
  4488. kW,
  4489. dT,
  4490. dH,
  4491. dW,
  4492. pT,
  4493. pH,
  4494. pW,
  4495. dilationT,
  4496. dilationH,
  4497. dilationW,
  4498. itime,
  4499. iheight,
  4500. iwidth,
  4501. otime,
  4502. oheight,
  4503. owidth,
  4504. "max_pool3d_with_indices()",
  4505. )
  4506. # channels_last_3d only applies to 5D tensors (C++ enforces this)
  4507. channels_last = (
  4508. input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
  4509. )
  4510. if input.ndim == 4:
  4511. out_shape = (nslices, otime, oheight, owidth)
  4512. else:
  4513. out_shape = (nbatch, nslices, otime, oheight, owidth) # type: ignore[assignment]
  4514. out = input.new_empty(out_shape)
  4515. indices = input.new_empty(out_shape, dtype=torch.int64)
  4516. if channels_last:
  4517. out = out.to(memory_format=torch.channels_last_3d)
  4518. indices = indices.to(memory_format=torch.channels_last_3d)
  4519. return out, indices
  4520. @register_meta(aten.max_pool3d_with_indices_backward)
  4521. @out_wrapper("grad_input")
  4522. def meta_max_pool3d_with_indices_backward(
  4523. grad_output,
  4524. input,
  4525. kernel_size,
  4526. stride,
  4527. padding,
  4528. dilation,
  4529. ceil_mode,
  4530. indices,
  4531. ):
  4532. torch._check(
  4533. len(kernel_size) in (1, 3),
  4534. lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
  4535. )
  4536. kT = kernel_size[0]
  4537. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  4538. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  4539. torch._check(
  4540. not stride or len(stride) in (1, 3),
  4541. lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
  4542. )
  4543. dT = kT if not stride else stride[0]
  4544. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  4545. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  4546. torch._check(
  4547. len(padding) in (1, 3),
  4548. lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
  4549. )
  4550. pT = padding[0]
  4551. pH = pT if len(padding) == 1 else padding[1]
  4552. pW = pT if len(padding) == 1 else padding[2]
  4553. torch._check(
  4554. len(dilation) in (1, 3),
  4555. lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
  4556. )
  4557. dilationT = dilation[0]
  4558. dilationH = dilationT if len(dilation) == 1 else dilation[1]
  4559. dilationW = dilationT if len(dilation) == 1 else dilation[2]
  4560. torch._check(
  4561. input.ndim in (4, 5),
  4562. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  4563. )
  4564. nslices = input.size(-4)
  4565. itime = input.size(-3)
  4566. iheight = input.size(-2)
  4567. iwidth = input.size(-1)
  4568. otime = grad_output.size(-3)
  4569. oheight = grad_output.size(-2)
  4570. owidth = grad_output.size(-1)
  4571. max_pool3d_backward_shape_check(
  4572. input,
  4573. grad_output,
  4574. indices,
  4575. nslices,
  4576. kT,
  4577. kH,
  4578. kW,
  4579. dT,
  4580. dH,
  4581. dW,
  4582. pT,
  4583. pH,
  4584. pW,
  4585. dilationT,
  4586. dilationH,
  4587. dilationW,
  4588. itime,
  4589. iheight,
  4590. iwidth,
  4591. otime,
  4592. oheight,
  4593. owidth,
  4594. "max_pool3d_with_indices_backward()",
  4595. )
  4596. # channels_last_3d only applies to 5D tensors (C++ enforces this)
  4597. channels_last = (
  4598. input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
  4599. )
  4600. grad_input = input.new_empty(input.shape)
  4601. if channels_last:
  4602. grad_input = grad_input.to(memory_format=torch.channels_last_3d)
  4603. return grad_input
  4604. def check_grid_sampler_common(input: Tensor, grid: Tensor):
  4605. torch._check(
  4606. input.device == grid.device,
  4607. lambda: (
  4608. f"grid_sampler(): expected input and grid to be on same device, but input "
  4609. f"is on {input.device} and grid is on {grid.device}"
  4610. ),
  4611. )
  4612. torch._check(
  4613. input.layout == torch.strided and grid.layout == torch.strided,
  4614. lambda: (
  4615. f"grid_sampler(): expected input and grid to have torch.strided layout, but "
  4616. f"input has {input.layout} and grid has {grid.layout}"
  4617. ),
  4618. )
  4619. torch._check(
  4620. input.shape[0] == grid.shape[0],
  4621. lambda: (
  4622. f"grid_sampler(): expected grid and input to have same batch size, but got "
  4623. f"input with sizes {input.shape} and grid with sizes {grid.shape}"
  4624. ),
  4625. )
  4626. torch._check(
  4627. grid.shape[-1] == input.ndim - 2,
  4628. lambda: (
  4629. f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
  4630. f"dimension, but got grid with sizes {grid.shape}"
  4631. ),
  4632. )
  4633. for i in range(2, input.ndim):
  4634. torch._check(
  4635. input.shape[i] > 0,
  4636. lambda: (
  4637. f"grid_sampler(): expected input to have non-empty spatial dimensions, "
  4638. f"but input has sizes {input.shape} with dimension {i} being empty"
  4639. ),
  4640. )
class GridSamplerInterpolation(Enum):
    """Integer codes naming the grid-sampler interpolation modes."""

    # check_grid_sampler_3d compares an ``interpolation_mode`` int against
    # ``BICUBIC.value``, so these numbers are the values such an int carries.
    BILINEAR = 0
    NEAREST = 1
    BICUBIC = 2
  4645. def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
  4646. torch._check(
  4647. input.ndim == 5 and input.ndim == grid.ndim,
  4648. lambda: (
  4649. f"grid_sampler(): expected 5D input and grid with same number of "
  4650. f"dimensions, but got input with sizes {input.shape}"
  4651. f" and grid with sizes {grid.shape}"
  4652. ),
  4653. )
  4654. torch._check(
  4655. not (
  4656. input.ndim == 5
  4657. and interpolation_mode == GridSamplerInterpolation.BICUBIC.value
  4658. ),
  4659. lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
  4660. )
  4661. @register_meta(aten.grid_sampler_2d_backward.default)
  4662. def grid_sampler_2d_backward_meta(
  4663. grad_output,
  4664. input,
  4665. grid,
  4666. interpolation_mode,
  4667. padding_mode,
  4668. align_corners,
  4669. output_mask,
  4670. ):
  4671. input_requires_grad = output_mask[0]
  4672. if input_requires_grad:
  4673. grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
  4674. else:
  4675. grad_input = None
  4676. grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
  4677. return (grad_input, grad_grid)
  4678. @register_meta(aten.grid_sampler_3d)
  4679. @out_wrapper()
  4680. def grid_sampler_3d(
  4681. input,
  4682. grid,
  4683. interpolation_mode,
  4684. padding_mode,
  4685. align_corners,
  4686. ):
  4687. check_grid_sampler_common(input, grid)
  4688. check_grid_sampler_3d(input, grid, interpolation_mode)
  4689. N = input.shape[0]
  4690. C = input.shape[1]
  4691. out_D = grid.shape[1]
  4692. out_H = grid.shape[2]
  4693. out_W = grid.shape[3]
  4694. return input.new_empty((N, C, out_D, out_H, out_W))
  4695. @register_meta(aten.grid_sampler_3d_backward)
  4696. @out_wrapper("grad_input", "grad_grid")
  4697. def grid_sampler_3d_backward(
  4698. grad_output,
  4699. input,
  4700. grid,
  4701. interpolation_mode,
  4702. padding_mode,
  4703. align_corners,
  4704. output_mask,
  4705. ):
  4706. check_grid_sampler_common(input, grid)
  4707. check_grid_sampler_3d(input, grid, interpolation_mode)
  4708. input_requires_grad = output_mask[0]
  4709. if input_requires_grad:
  4710. grad_input = torch.zeros_like(
  4711. input, memory_format=torch.legacy_contiguous_format
  4712. )
  4713. else:
  4714. grad_input = None
  4715. grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format)
  4716. return grad_input, grad_grid
  4717. @register_meta([aten.full.default])
  4718. def full(size, fill_value, *args, **kwargs):
  4719. dtype = kwargs.get("dtype")
  4720. if not dtype:
  4721. dtype = utils.get_dtype(fill_value)
  4722. kwargs["dtype"] = dtype
  4723. return torch.empty(size, *args, **kwargs)
  4724. # zeros_like is special cased to work for sparse
  4725. @register_meta(aten.zeros_like.default)
  4726. def zeros_like(
  4727. self,
  4728. dtype=None,
  4729. layout=None,
  4730. device=None,
  4731. pin_memory=None,
  4732. memory_format=None,
  4733. ):
  4734. if layout == torch.sparse_coo:
  4735. torch._check(
  4736. memory_format is None,
  4737. lambda: "memory format option is only supported by strided tensors",
  4738. )
  4739. res = torch.empty(
  4740. 0,
  4741. dtype=self.dtype if dtype is None else dtype,
  4742. layout=layout,
  4743. device=self.device if device is None else device,
  4744. pin_memory=pin_memory,
  4745. )
  4746. if self.is_sparse:
  4747. res.sparse_resize_and_clear_(
  4748. self.size(), self.sparse_dim(), self.dense_dim()
  4749. )
  4750. else:
  4751. res.sparse_resize_and_clear_(self.size(), self.dim(), 0)
  4752. res._coalesced_(True)
  4753. return res
  4754. res = aten.empty_like.default(
  4755. self,
  4756. dtype=dtype,
  4757. layout=layout,
  4758. device=device,
  4759. pin_memory=pin_memory,
  4760. memory_format=memory_format,
  4761. )
  4762. # device can be not "meta"
  4763. res.fill_(0)
  4764. return res
  4765. @register_meta([aten.ones.default, aten.ones.out])
  4766. @out_wrapper()
  4767. def meta_ones(
  4768. size,
  4769. *,
  4770. dtype=None,
  4771. layout=None,
  4772. device=None,
  4773. pin_memory=None,
  4774. requires_grad=False,
  4775. ):
  4776. if dtype is None:
  4777. dtype = torch.get_default_dtype()
  4778. if device is None:
  4779. device = torch.get_default_device()
  4780. if layout is None:
  4781. layout = torch.strided
  4782. return torch.empty(
  4783. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  4784. )
  4785. @register_meta([aten.zeros.default, aten.zeros.out])
  4786. @out_wrapper()
  4787. def meta_zeros(
  4788. size,
  4789. *,
  4790. dtype=None,
  4791. layout=None,
  4792. device=None,
  4793. pin_memory=None,
  4794. requires_grad=False,
  4795. ):
  4796. if dtype is None:
  4797. dtype = torch.get_default_dtype()
  4798. if device is None:
  4799. device = torch.get_default_device()
  4800. if layout is None:
  4801. layout = torch.strided
  4802. return torch.empty(
  4803. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  4804. )
  4805. @register_meta(aten.select_scatter.default)
  4806. def meta_select_scatter(self, src, dim, index):
  4807. return utils.clone_preserve_strides(self)
  4808. @register_meta(aten.slice_scatter.default)
  4809. def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
  4810. return utils.clone_preserve_strides(self)
  4811. # TODO: Deduplicate this with canonicalize_dim
  4812. def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
  4813. if dim_post_expr <= 0:
  4814. if not wrap_scalar:
  4815. raise AssertionError(
  4816. f"dim_post_expr={dim_post_expr} <= 0 but wrap_scalar is False"
  4817. )
  4818. dim_post_expr = 1
  4819. min = -dim_post_expr
  4820. max = dim_post_expr - 1
  4821. if dim < min or dim > max:
  4822. raise AssertionError(f"dim {dim} out of bounds ({min}, {max})")
  4823. if dim < 0:
  4824. dim += dim_post_expr
  4825. return dim
  4826. def ensure_nonempty_size(t, dim):
  4827. return 1 if t.dim() == 0 else t.shape[dim]
  4828. # From aten/src/ATen/native/ScatterGatherChecks.h
  4829. def gather_shape_check(self, dim, index):
  4830. self_dims = max(self.dim(), 1)
  4831. index_dims = max(index.dim(), 1)
  4832. torch._check(
  4833. self_dims == index_dims,
  4834. lambda: "Index tensor must have the same number of dimensions as input tensor",
  4835. )
  4836. for i in range(self_dims):
  4837. if i != dim:
  4838. torch._check(
  4839. ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
  4840. lambda: f"Size does not match at dimension {i} expected index {index.shape}"
  4841. + f" to be no larger than self {self.shape} apart from dimension {dim}",
  4842. )
  4843. @register_meta(aten.gather.default)
  4844. def meta_gather(self, dim, index, sparse_grad=False):
  4845. from torch.fx.experimental.symbolic_shapes import guard_or_false
  4846. wrapped_dim = maybe_wrap_dim(dim, self.dim())
  4847. is_index_empty = guard_or_false(index.numel() == 0)
  4848. if not is_index_empty:
  4849. torch._check(
  4850. index.dtype == torch.long or index.dtype == torch.int,
  4851. lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
  4852. )
  4853. gather_shape_check(self, wrapped_dim, index)
  4854. return self.new_empty(index.shape)
  4855. # From aten/src/ATen/native/TensorAdvancedIndexing.cpp
  4856. def get_operator_enum(reduce_, use_new_options=False):
  4857. if use_new_options:
  4858. if reduce_ == "sum":
  4859. return "REDUCE_ADD"
  4860. elif reduce_ == "prod":
  4861. return "REDUCE_MULTIPLY"
  4862. elif reduce_ == "mean":
  4863. return "REDUCE_MEAN"
  4864. elif reduce_ == "amax":
  4865. return "REDUCE_MAXIMUM"
  4866. elif reduce_ == "amin":
  4867. return "REDUCE_MINIMUM"
  4868. torch._check(
  4869. False,
  4870. lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
  4871. )
  4872. return
  4873. else:
  4874. if reduce_ == "add":
  4875. return "REDUCE_ADD"
  4876. elif reduce_ == "multiply":
  4877. return "REDUCE_MULTIPLY"
  4878. torch._check(False, lambda: "reduce argument must be either add or multiply.")
  4879. return
  4880. # From aten/src/ATen/native/ScatterGatherChecks.h
  4881. def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
  4882. from torch.fx.experimental.symbolic_shapes import guard_or_true
  4883. if guard_or_true(index.numel() != 0):
  4884. torch._check(
  4885. index.dtype == torch.long or index.dtype == torch.int,
  4886. lambda: f"{method_name}(): Expected dtype int32/int64 for index",
  4887. )
  4888. if src_opt is not None:
  4889. torch._check(
  4890. self.dtype == src_opt.dtype,
  4891. lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
  4892. )
  4893. def ensure_nonempty_dim(dim):
  4894. return max(dim, 1)
  4895. # From aten/src/ATen/native/ScatterGatherChecks.h
  4896. def scatter_shape_check(self, dim, index, src_opt=None):
  4897. from torch.fx.experimental.symbolic_shapes import guard_or_false
  4898. if guard_or_false(index.numel() == 0):
  4899. return
  4900. torch._check(
  4901. ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
  4902. lambda: "Index tensor must have the same number of dimensions as self tensor",
  4903. )
  4904. self_dims = ensure_nonempty_dim(self.dim())
  4905. # Check: index.size(d) <= self.size(d) for all d != dim
  4906. # Use torch._check to defer validation to runtime for unbacked symbols.
  4907. for d in range(self_dims):
  4908. if d == dim:
  4909. continue
  4910. index_d_size = ensure_nonempty_size(index, d)
  4911. self_d_size = ensure_nonempty_size(self, d)
  4912. torch._check(
  4913. index_d_size <= self_d_size,
  4914. lambda: f"Expected index {index.shape} to be no larger than self {self.shape}"
  4915. + f" apart from dimension {dim}",
  4916. )
  4917. # Check: index.size(d) <= src.size(d) for all d if src is Tensor
  4918. if src_opt is not None:
  4919. torch._check(
  4920. ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(src_opt.dim()),
  4921. lambda: "Index tensor must have the same number of dimensions as src tensor",
  4922. )
  4923. for d in range(self_dims):
  4924. index_d_size = ensure_nonempty_size(index, d)
  4925. src_d_size = ensure_nonempty_size(src_opt, d)
  4926. torch._check(
  4927. index_d_size <= src_d_size,
  4928. lambda: f"Expected index {index.shape} to be no larger than src {src_opt.shape}",
  4929. )
  4930. # From aten/src/ATen/native/TensorAdvancedIndexing.cpp
  4931. def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
  4932. wrapped_dim = maybe_wrap_dim(dim, self.dim())
  4933. scatter_gather_dtype_check("scatter", self, index, src)
  4934. scatter_shape_check(self, wrapped_dim, index, src)
  4935. if reduce_ is not None:
  4936. # Check if we have a valid reduce operator.
  4937. get_operator_enum(reduce_, use_new_options)
  4938. @register_meta(aten.scatter_add.default)
  4939. def meta_scatter_add(self, dim, index, src):
  4940. scatter_meta_impl(self, dim, index, src, "add")
  4941. return self.new_empty(self.shape)
  4942. @register_meta(aten.scatter_add_)
  4943. def meta_scatter_add_(self, dim, index, src):
  4944. scatter_meta_impl(self, dim, index, src, "add")
  4945. return self
  4946. @register_meta(
  4947. [
  4948. aten.scatter.src,
  4949. aten.scatter.value,
  4950. aten.scatter.reduce,
  4951. aten.scatter.value_reduce,
  4952. ]
  4953. )
  4954. @out_wrapper()
  4955. def meta_scatter(self, dim, index, src_or_value, reduce=None):
  4956. src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
  4957. scatter_meta_impl(self, dim, index, src, reduce)
  4958. return self.new_empty(self.shape)
  4959. @register_meta(
  4960. [
  4961. aten.scatter_.src,
  4962. aten.scatter_.value,
  4963. aten.scatter_.reduce,
  4964. aten.scatter_.value_reduce,
  4965. ]
  4966. )
  4967. def meta_scatter_(self, dim, index, src_or_value, reduce=None):
  4968. src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
  4969. scatter_meta_impl(self, dim, index, src, reduce)
  4970. return self
@register_meta([aten._scaled_dot_product_flash_attention.default])
def meta__scaled_dot_product_flash_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    return_debug_mask: bool = False,
    scale: float | None = None,
):
    """Meta implementation of ``aten._scaled_dot_product_flash_attention``.

    Only output allocation happens here. Sizes are read from dims 0-3 of
    ``query`` and dim 2 of ``key``; ``value``, ``dropout_p``, ``is_causal``
    and ``scale`` are not read in this body.

    Returns:
        A 9-tuple ``(attention, logsumexp, None, None, max_seqlen_batch_q,
        max_seqlen_batch_k, seed, offset, debug_mask)``. ``attention`` is
        ``empty_like(query)``, ``logsumexp`` is a float tensor over the first
        three query sizes, ``seed``/``offset`` are meta-device tensors, and
        ``debug_mask`` is a size-0 tensor unless ``return_debug_mask`` is set.
    """
    # NOTE(review): the local names imply a (batch, heads, seq, head_dim)
    # layout for query/key - confirm against callers.
    batch_size = query.size(0)
    num_heads = query.size(1)
    max_seqlen_batch_q = query.size(2)
    head_dim = query.size(3)
    max_seqlen_batch_k = key.size(2)

    attention = torch.empty_like(query)
    logsumexp = torch.empty(
        (batch_size, num_heads, max_seqlen_batch_q),
        dtype=torch.float,
        device=query.device,
    )

    if return_debug_mask:
        blocksize_c = 128 if head_dim > 64 else 256
        # The ceil-based value is only kept when the key length exceeds 256;
        # shorter key lengths override it with 128 or 256 below.
        max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
        if max_seqlen_batch_k <= 128:
            max_seqlen_k = 128
        elif max_seqlen_batch_k <= 256:
            max_seqlen_k = 256
        debug_mask = torch.empty(
            (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
            dtype=query.dtype,
            device=query.device,
        )
    else:
        # No debug mask requested: return an empty placeholder tensor.
        debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)

    # Note [Seed and Offset]: device for seed and offset below depends on whether we are
    # capturing or not, but at the time of tracing we don't know if we
    # are going to use cudagraphs or not, so we return meta tensors here
    # it's possible we'll need to have some special handling in inductor for sdpa
    # See [Note] BC breaking change to flash seed/offset
    if torch.version.hip and torch.cuda.is_available() or device_hint(query) == "xpu":
        # Maintain old path on AMD
        seed = torch.empty((), dtype=torch.long, device="meta")
        offset = torch.empty((), dtype=torch.long, device="meta")
    else:
        # ``(2)`` is the plain int 2, so ``seed`` is one-dimensional with two elements.
        seed = torch.empty((2), dtype=torch.uint64, device="meta")
        offset = torch.empty((), dtype=torch.uint64, device="meta")

    return (
        attention,
        logsumexp,
        None,
        None,
        max_seqlen_batch_q,
        max_seqlen_batch_k,
        seed,
        offset,
        debug_mask,
    )
  5029. @register_meta([aten._scaled_dot_product_flash_attention.quantized])
  5030. def meta__scaled_dot_product_flash_attention_quantized(
  5031. query: Tensor,
  5032. key: Tensor,
  5033. value: Tensor,
  5034. q_descale: Tensor | None,
  5035. k_descale: Tensor | None,
  5036. v_descale: Tensor | None,
  5037. dropout_p: float = 0.0,
  5038. is_causal: bool = False,
  5039. return_debug_mask: bool = False,
  5040. scale: float | None = None,
  5041. ):
  5042. if query.dtype == torch.float8_e4m3fn:
  5043. query = query.to(torch.bfloat16)
  5044. return meta__scaled_dot_product_flash_attention(
  5045. query,
  5046. key,
  5047. value,
  5048. dropout_p,
  5049. is_causal,
  5050. return_debug_mask,
  5051. scale,
  5052. )
  5053. def alloc_with_matching_layout(
  5054. query: Tensor,
  5055. res_shape: tuple[int, ...],
  5056. ):
  5057. if tuple(query.shape) == res_shape:
  5058. res = torch.empty_like(query)
  5059. else:
  5060. dim_order = sorted(
  5061. [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
  5062. )
  5063. permuted_shape = [res_shape[idx] for idx in dim_order]
  5064. final_permute = [dim_order.index(i) for i in range(len(dim_order))]
  5065. res = torch.empty(
  5066. permuted_shape, dtype=query.dtype, device=query.device
  5067. ).permute(final_permute)
  5068. return res
  5069. @register_meta([aten._scaled_dot_product_cudnn_attention])
  5070. def meta__scaled_dot_product_cudnn_attention(
  5071. query: Tensor,
  5072. key: Tensor,
  5073. value: Tensor,
  5074. attn_bias: Tensor | None,
  5075. compute_log_sumexp: bool,
  5076. dropout_p: float = 0.0,
  5077. is_causal: bool = False,
  5078. return_debug_mask: bool = False,
  5079. scale: float | None = None,
  5080. ):
  5081. B = query.size(0)
  5082. H = query.size(1)
  5083. S_Q = query.size(2)
  5084. S_KV = key.size(2)
  5085. D_V = value.size(-1)
  5086. res_shape = (B, H, S_Q, D_V)
  5087. res = alloc_with_matching_layout(query, res_shape)
  5088. logsum_exp = torch.empty(
  5089. (B, H, S_Q, 1),
  5090. dtype=torch.float,
  5091. device=query.device,
  5092. )
  5093. # See Note [Seed and Offset]
  5094. seed = torch.empty((), dtype=torch.long, device="meta")
  5095. offset = torch.empty((), dtype=torch.long, device="meta")
  5096. return (
  5097. res,
  5098. logsum_exp,
  5099. None,
  5100. None,
  5101. S_Q,
  5102. S_KV,
  5103. seed,
  5104. offset,
  5105. None,
  5106. )
  5107. @register_meta([aten._scaled_dot_product_fused_attention_overrideable])
  5108. def meta__scaled_dot_product_fused_attention_overrideable(
  5109. query: Tensor,
  5110. key: Tensor,
  5111. value: Tensor,
  5112. attn_bias: Tensor | None = None,
  5113. dropout_p: float = 0.0,
  5114. is_causal: bool = False,
  5115. return_debug_mask: bool = False,
  5116. scale: float | None = None,
  5117. ):
  5118. B = query.size(0)
  5119. H_Q = query.size(1)
  5120. S_Q = query.size(2)
  5121. S_KV = key.size(2)
  5122. D_V = value.size(-1)
  5123. res_shape = (B, H_Q, S_Q, D_V)
  5124. res = alloc_with_matching_layout(query, res_shape)
  5125. logsum_exp = torch.empty(
  5126. (B, H_Q, S_Q),
  5127. dtype=torch.float,
  5128. device=query.device,
  5129. )
  5130. # See Note [Seed and Offset]
  5131. seed = torch.empty((), dtype=torch.long, device="meta")
  5132. offset = torch.empty((), dtype=torch.long, device="meta")
  5133. return (
  5134. res,
  5135. logsum_exp,
  5136. None,
  5137. None,
  5138. S_Q,
  5139. S_KV,
  5140. seed,
  5141. offset,
  5142. None,
  5143. )
  5144. @register_meta(
  5145. [
  5146. aten._scaled_dot_product_flash_attention_backward,
  5147. ]
  5148. )
  5149. def meta__scaled_dot_product_flash_backward(
  5150. grad_out: Tensor,
  5151. query: Tensor,
  5152. key: Tensor,
  5153. value: Tensor,
  5154. out: Tensor,
  5155. logsumexp: Tensor,
  5156. cum_seq_q: Tensor,
  5157. cum_seq_k: Tensor,
  5158. max_q: int,
  5159. max_k: int,
  5160. dropout_p: float,
  5161. is_causal: bool,
  5162. philox_seed: Tensor,
  5163. philox_offset: Tensor,
  5164. scale: float | None = None,
  5165. ):
  5166. grad_q = torch.empty_like(query)
  5167. grad_k = torch.empty_like(key)
  5168. grad_v = torch.empty_like(value)
  5169. return grad_q, grad_k, grad_v
  5170. @register_meta(
  5171. [
  5172. aten._scaled_dot_product_flash_attention_for_cpu,
  5173. ]
  5174. )
  5175. def meta__scaled_dot_product_flash_attention_for_cpu(
  5176. query: Tensor,
  5177. key: Tensor,
  5178. value: Tensor,
  5179. dropout_p: float = 0.0,
  5180. is_causal: bool = False,
  5181. attn_mask: Tensor | None = None,
  5182. scale: float | None = None,
  5183. ):
  5184. batch_size = query.size(0)
  5185. num_heads = query.size(1)
  5186. max_seqlen_batch_q = query.size(2)
  5187. attention = torch.empty_like(query)
  5188. logsumexp = torch.empty(
  5189. (
  5190. batch_size,
  5191. max_seqlen_batch_q,
  5192. num_heads,
  5193. ),
  5194. dtype=torch.float,
  5195. device=query.device,
  5196. ).transpose(1, 2)
  5197. return (
  5198. attention,
  5199. logsumexp,
  5200. )
  5201. @register_meta(
  5202. [
  5203. aten._scaled_dot_product_flash_attention_for_cpu_backward,
  5204. ]
  5205. )
  5206. def meta__scaled_dot_product_flash_attention_for_cpu_backward(
  5207. grad_out: Tensor,
  5208. query: Tensor,
  5209. key: Tensor,
  5210. value: Tensor,
  5211. out: Tensor,
  5212. logsumexp: Tensor,
  5213. dropout_p: float,
  5214. is_causal: bool,
  5215. attn_mask: Tensor | None = None,
  5216. scale: float | None = None,
  5217. ):
  5218. # cpus's grad layout is different from cuda's,
  5219. # i.e. (batch_size, seq_len, num_heads, head_dim)
  5220. grad_q = torch.empty_permuted(
  5221. query.size(),
  5222. (0, 2, 1, 3),
  5223. dtype=query.dtype,
  5224. device=query.device,
  5225. )
  5226. grad_k = torch.empty_permuted(
  5227. key.size(),
  5228. (0, 2, 1, 3),
  5229. dtype=key.dtype,
  5230. device=key.device,
  5231. )
  5232. grad_v = torch.empty_permuted(
  5233. value.size(),
  5234. (0, 2, 1, 3),
  5235. dtype=value.dtype,
  5236. device=value.device,
  5237. )
  5238. return grad_q, grad_k, grad_v
  5239. @register_meta([aten._scaled_dot_product_attention_math_for_mps])
  5240. def meta__scaled_dot_product_attention_math_for_mps(
  5241. query: Tensor,
  5242. key: Tensor,
  5243. value: Tensor,
  5244. attn_mask: Tensor | None = None,
  5245. dropout_p: float = 0.0,
  5246. is_causal: bool = False,
  5247. dropout_mask: Tensor | None = None,
  5248. scale: float | None = None,
  5249. ) -> tuple[Tensor, Tensor]:
  5250. def ensure_4d(x):
  5251. if x.dim() == 3:
  5252. return x.unsqueeze(0), True
  5253. elif x.dim() > 4:
  5254. batch_size = 1
  5255. for i in range(x.dim() - 3):
  5256. batch_size *= x.shape[i]
  5257. return x.view(batch_size, x.size(-3), x.size(-2), x.size(-1)), True
  5258. else:
  5259. return x, False
  5260. q_, unsqueezed = ensure_4d(query)
  5261. k_, _ = ensure_4d(key)
  5262. v_, _ = ensure_4d(value)
  5263. batch_size, num_head, q_size, head_size = q_.shape
  5264. _, k_size, max_seq_length, _ = k_.shape
  5265. def sdpa_vector_fast_mps():
  5266. out = q_.new_empty(q_.shape)
  5267. if unsqueezed:
  5268. out = out.view_as(query)
  5269. attn = q_.new_empty((batch_size, num_head, q_size, max_seq_length))
  5270. if unsqueezed:
  5271. if query.dim() == 3:
  5272. attn = attn.squeeze(0)
  5273. else:
  5274. shape = list(query.shape[:-3]) + attn.shape[1:4]
  5275. attn = attn.view(shape)
  5276. return out, attn
  5277. def sdpa_vector_2pass_mps():
  5278. blocks = 32
  5279. out = q_.new_empty(q_.shape)
  5280. intermediate = q_.new_empty((batch_size, num_head, q_size, blocks, head_size))
  5281. return out, intermediate
  5282. if (max_seq_length >= 1024) or (k_size < q_size and max_seq_length >= 4096):
  5283. return sdpa_vector_2pass_mps()
  5284. else:
  5285. return sdpa_vector_fast_mps()
  5286. @register_meta([aten._scaled_dot_product_efficient_attention])
  5287. def meta__scaled_dot_product_efficient_attention(
  5288. query: Tensor,
  5289. key: Tensor,
  5290. value: Tensor,
  5291. attn_bias: Tensor | None,
  5292. compute_log_sumexp: bool,
  5293. dropout_p=0.0,
  5294. is_causal: bool = False,
  5295. scale: float | None = None,
  5296. ):
  5297. query = query.transpose(1, 2)
  5298. key = key.transpose(1, 2)
  5299. value = value.transpose(1, 2)
  5300. B = query.size(0)
  5301. M = query.size(1)
  5302. num_heads = query.size(-2)
  5303. Kv = value.size(-1)
  5304. res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
  5305. if torch.version.hip and torch.cuda.is_available():
  5306. """Please see: https://github.com/pytorch/pytorch/issues/146848
  5307. longsumexp last dim should be seq length
  5308. """
  5309. logsumexp_dim = M if compute_log_sumexp else 0
  5310. else:
  5311. logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
  5312. logsum_exp = torch.empty(
  5313. (B, num_heads, logsumexp_dim),
  5314. dtype=torch.float,
  5315. device=query.device,
  5316. )
  5317. res = res.transpose(1, 2)
  5318. # See Note [Seed and Offset]:
  5319. seed = torch.empty((), dtype=torch.long, device="meta")
  5320. offset = torch.empty((), dtype=torch.long, device="meta")
  5321. return res, logsum_exp, seed, offset
  5322. @register_meta(
  5323. [
  5324. aten._scaled_dot_product_efficient_attention_backward,
  5325. ]
  5326. )
  5327. def meta__scaled_dot_product_efficient_backward(
  5328. grad_out: Tensor,
  5329. query: Tensor,
  5330. key: Tensor,
  5331. value: Tensor,
  5332. attn_bias: Tensor | None,
  5333. out: Tensor,
  5334. logsumexp: Tensor,
  5335. philox_seed: Tensor,
  5336. philox_offset: Tensor,
  5337. dropout_p: float,
  5338. grad_input_mask: list[bool],
  5339. is_causal: bool = False,
  5340. scale: float | None = None,
  5341. ):
  5342. batch_size = query.size(0)
  5343. num_heads = query.size(1)
  5344. max_q = query.size(2)
  5345. head_dim = query.size(3)
  5346. head_dim_v = value.size(3)
  5347. max_k = key.size(2)
  5348. grad_q = torch.empty_permuted(
  5349. (batch_size, num_heads, max_q, head_dim),
  5350. (0, 2, 1, 3),
  5351. dtype=query.dtype,
  5352. device=query.device,
  5353. )
  5354. grad_k = torch.empty_permuted(
  5355. (batch_size, num_heads, max_k, head_dim),
  5356. (0, 2, 1, 3),
  5357. dtype=key.dtype,
  5358. device=key.device,
  5359. )
  5360. grad_v = torch.empty_permuted(
  5361. (batch_size, num_heads, max_k, head_dim_v),
  5362. (0, 2, 1, 3),
  5363. dtype=value.dtype,
  5364. device=value.device,
  5365. )
  5366. grad_bias = None
  5367. if attn_bias is not None and grad_input_mask[3]:
  5368. lastDim = attn_bias.size(-1)
  5369. lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
  5370. new_sizes = list(attn_bias.size())
  5371. new_sizes[-1] = lastDimAligned
  5372. grad_bias = torch.empty(
  5373. new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
  5374. )
  5375. grad_bias = grad_bias[..., :lastDim]
  5376. return grad_q, grad_k, grad_v, grad_bias
  5377. @register_meta(
  5378. [
  5379. aten._scaled_dot_product_cudnn_attention_backward,
  5380. ]
  5381. )
  5382. def meta__scaled_dot_product_cudnn_backward(
  5383. grad_out: Tensor,
  5384. query: Tensor,
  5385. key: Tensor,
  5386. value: Tensor,
  5387. out: Tensor,
  5388. logsumexp: Tensor,
  5389. philox_seed: Tensor,
  5390. philox_offset: Tensor,
  5391. attn_bias: Tensor,
  5392. cum_seq_q: Tensor,
  5393. cum_seq_k: Tensor,
  5394. max_q: int,
  5395. max_k: int,
  5396. dropout_p: float,
  5397. is_causal: bool,
  5398. scale: float | None = None,
  5399. ):
  5400. grad_q = torch.empty_like(query)
  5401. grad_k = torch.empty_like(key)
  5402. grad_v = torch.empty_like(value)
  5403. return grad_q, grad_k, grad_v
  5404. @register_meta(
  5405. [
  5406. aten._flash_attention_forward.default,
  5407. ]
  5408. )
  5409. def meta__flash_attention_forward(
  5410. query: Tensor,
  5411. key: Tensor,
  5412. value: Tensor,
  5413. cum_seq_q: Tensor | None,
  5414. cum_seq_k: Tensor | None,
  5415. max_q: int,
  5416. max_k: int,
  5417. dropout_p: float,
  5418. is_causal: bool,
  5419. return_debug_mask: bool,
  5420. scale: float | None = None,
  5421. window_size_left: int | None = None,
  5422. window_size_right: int | None = None,
  5423. seqused_k: Tensor | None = None,
  5424. alibi_slopes: Tensor | None = None,
  5425. ):
  5426. # NB: there are two underlying paths:
  5427. # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
  5428. # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
  5429. # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
  5430. batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
  5431. max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
  5432. max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
  5433. num_heads = query.size(-2)
  5434. head_dim = query.size(-1)
  5435. # Cuda Path
  5436. attention = torch.empty_like(query)
  5437. if cum_seq_q is None:
  5438. logsumexp = torch.empty(
  5439. (batch_size, num_heads, max_seqlen_batch_q),
  5440. dtype=torch.float,
  5441. device=query.device,
  5442. )
  5443. else:
  5444. total_q = query.size(0)
  5445. logsumexp = torch.empty(
  5446. (num_heads, total_q), dtype=torch.float, device=query.device
  5447. )
  5448. if return_debug_mask:
  5449. blocksize_c = 128 if head_dim > 64 else 256
  5450. max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
  5451. if max_seqlen_batch_k <= 128:
  5452. max_seqlen_k = 128
  5453. elif max_seqlen_batch_k <= 256:
  5454. max_seqlen_k = 256
  5455. debug_mask = torch.empty(
  5456. (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
  5457. dtype=query.dtype,
  5458. device=query.device,
  5459. )
  5460. else:
  5461. debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
  5462. # See Note [Seed and Offset]
  5463. # See [Note] BC breaking change to flash seed/offset
  5464. seed, offset = None, None
  5465. if torch.version.hip and torch.cuda.is_available():
  5466. # Maintain old path on AMD
  5467. seed = torch.empty((), dtype=torch.long, device="meta")
  5468. offset = torch.empty((), dtype=torch.long, device="meta")
  5469. else:
  5470. seed = torch.empty((2), dtype=torch.uint64, device="meta")
  5471. offset = torch.empty((), dtype=torch.uint64, device="meta")
  5472. return (
  5473. attention,
  5474. logsumexp,
  5475. seed,
  5476. offset,
  5477. debug_mask,
  5478. )
  5479. @register_meta([aten._flash_attention_forward.quantized])
  5480. def meta__flash_attention_forward_quantized(
  5481. query: Tensor,
  5482. key: Tensor,
  5483. value: Tensor,
  5484. cum_seq_q: Tensor | None,
  5485. cum_seq_k: Tensor | None,
  5486. max_q: int,
  5487. max_k: int,
  5488. dropout_p: float,
  5489. is_causal: bool,
  5490. return_debug_mask: bool,
  5491. q_descale: Tensor | None,
  5492. k_descale: Tensor | None,
  5493. v_descale: Tensor | None,
  5494. scale: float | None = None,
  5495. window_size_left: int | None = None,
  5496. window_size_right: int | None = None,
  5497. seqused_k: Tensor | None = None,
  5498. alibi_slopes: Tensor | None = None,
  5499. ):
  5500. if query.dtype == torch.float8_e4m3fn:
  5501. query = query.to(torch.bfloat16)
  5502. return meta__flash_attention_forward(
  5503. query,
  5504. key,
  5505. value,
  5506. cum_seq_q,
  5507. cum_seq_k,
  5508. max_q,
  5509. max_k,
  5510. dropout_p,
  5511. is_causal,
  5512. return_debug_mask,
  5513. scale,
  5514. window_size_left,
  5515. window_size_right,
  5516. seqused_k,
  5517. alibi_slopes,
  5518. )
  5519. @register_meta(
  5520. [
  5521. aten._flash_attention_backward,
  5522. ]
  5523. )
  5524. def meta__flash_attention_backward(
  5525. grad_out: Tensor,
  5526. query: Tensor,
  5527. key: Tensor,
  5528. value: Tensor,
  5529. out: Tensor,
  5530. logsumexp: Tensor,
  5531. cum_seq_q: Tensor,
  5532. cum_seq_k: Tensor,
  5533. max_q: int,
  5534. max_k: int,
  5535. dropout_p: float,
  5536. is_causal: bool,
  5537. philox_seed: Tensor,
  5538. philox_offset: Tensor,
  5539. scale: float | None = None,
  5540. window_size_left: int | None = None,
  5541. window_size_right: int | None = None,
  5542. ):
  5543. grad_query = torch.empty_like(query)
  5544. grad_key = torch.empty_like(key)
  5545. grad_value = torch.empty_like(value)
  5546. return grad_query, grad_key, grad_value
  5547. @register_meta(
  5548. [
  5549. aten._efficient_attention_forward,
  5550. ]
  5551. )
  5552. def meta__efficient_attention_forward(
  5553. query: Tensor,
  5554. key: Tensor,
  5555. value: Tensor,
  5556. bias: Tensor | None,
  5557. cu_seqlens_q: Tensor | None,
  5558. cu_seqlens_k: Tensor | None,
  5559. max_seqlen_q: int | None,
  5560. max_seqlen_k: int | None,
  5561. dropout_p: float,
  5562. custom_mask_type: int,
  5563. compute_log_sumexp: bool = False,
  5564. scale: float | None = None,
  5565. causal_diagonal: Tensor | None = None,
  5566. seqlen_k: Tensor | None = None,
  5567. window_size: int | None = None,
  5568. ):
  5569. B = query.size(0)
  5570. M = query.size(1)
  5571. N = key.size(1)
  5572. num_heads = query.size(-2)
  5573. Kv = value.size(-1)
  5574. res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
  5575. logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
  5576. actual_max_seqlen_q = M
  5577. if cu_seqlens_q is not None:
  5578. if max_seqlen_q is None:
  5579. raise AssertionError(
  5580. "max_seqlen_q must not be None when cu_seqlens_q is provided"
  5581. )
  5582. actual_max_seqlen_q = max_seqlen_q
  5583. actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
  5584. logsumexp_dim = (
  5585. math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
  5586. )
  5587. logsum_exp = torch.empty(
  5588. (logsumexp_batch_dim, num_heads, logsumexp_dim),
  5589. dtype=torch.float,
  5590. device=query.device,
  5591. )
  5592. # See Note [Seed and Offset]:
  5593. seed = torch.empty((), dtype=torch.long, device="meta")
  5594. offset = torch.empty((), dtype=torch.long, device="meta")
  5595. return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
  5596. @register_meta(
  5597. [
  5598. aten._efficient_attention_backward,
  5599. ]
  5600. )
  5601. def meta__efficient_attention_backward(
  5602. grad_out: Tensor,
  5603. query: Tensor,
  5604. key: Tensor,
  5605. value: Tensor,
  5606. bias: Tensor | None,
  5607. cu_seqlens_q: Tensor | None,
  5608. cu_seqlens_k: Tensor | None,
  5609. max_seqlen_q: torch.SymInt,
  5610. max_seqlen_k: torch.SymInt,
  5611. logsumexp: Tensor,
  5612. dropout_p: float,
  5613. philox_seed: Tensor,
  5614. philox_offset: Tensor,
  5615. custom_mask_type: int,
  5616. bias_requires_grad: bool,
  5617. scale: float | None = None,
  5618. num_splits_key: int | None = None,
  5619. shared_storage_dqdkdv: bool = False,
  5620. ):
  5621. if shared_storage_dqdkdv:
  5622. torch._check(
  5623. query.shape[1] == key.shape[1],
  5624. lambda: "seqlen must match for `shared_storage_dqdkdv",
  5625. )
  5626. torch._check(
  5627. query.shape[3] == key.shape[3],
  5628. lambda: "embedding dim must match for `shared_storage_dqdkdv",
  5629. )
  5630. chunk = torch.empty(
  5631. (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
  5632. dtype=query.dtype,
  5633. device=query.device,
  5634. )
  5635. grad_query = chunk.select(-3, 0)
  5636. grad_key = chunk.select(-3, 1)
  5637. grad_value = chunk.select(-3, 2)
  5638. else:
  5639. grad_query = torch.empty_like(query)
  5640. grad_key = torch.empty_like(key)
  5641. grad_value = torch.empty_like(value)
  5642. if bias is not None:
  5643. lastDim = bias.size(-1)
  5644. lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
  5645. new_sizes = list(bias.size())
  5646. new_sizes[-1] = lastDimAligned
  5647. grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
  5648. grad_bias = grad_bias[..., :lastDim]
  5649. else:
  5650. grad_bias = torch.empty((), device=query.device)
  5651. return grad_query, grad_key, grad_value, grad_bias
def _check_scaled_mm_sizes(
    self: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
    scale_result: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
):
    """Validate scaled-matmul operands and return an empty result tensor.

    Both operands must be 2-D and of an fp8/fp4 dtype. When the device hint of
    ``self`` is ``"cuda"`` or ``"xpu"``, layout, divisibility and scale-tensor
    checks are applied as well.

    NOTE(review): the nesting of the scale checks under the cuda/xpu branch
    was reconstructed from an unindented source; confirm against the
    original file.

    ``bias``, ``scale_result`` and ``use_fast_accum`` are accepted but not
    inspected here.

    Returns:
        An uninitialized tensor of shape ``(self.size(0), mat2.size(1))`` on
        ``self``'s device, with dtype ``out_dtype`` if given, else
        ``self.dtype``.

    Raises:
        Whatever ``torch._check`` raises when one of the conditions fails.
    """

    def is_fp8_or_fp4_type(dtype):
        return dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2fnuz,
            torch.float4_e2m1fn_x2,
        )

    torch._check(
        self.dim() == 2 and mat2.dim() == 2,
        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
    )
    torch._check(
        is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype),
        lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
    )

    if device_hint(self) == "cuda" or device_hint(self) == "xpu":

        def is_row_major(stride):
            return stride[0] > stride[1] and stride[1] == 1

        def is_col_major(stride):
            return stride[0] == 1 and stride[1] > 1

        def has_zero_dim(tensor_2d):
            return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0

        # Layout requirements are waived for operands with a zero-sized dim.
        torch._check(
            is_row_major(self.stride()) or has_zero_dim(self),
            lambda: f"self must be row_major, got stride {self.stride()}",
        )
        torch._check(
            is_col_major(mat2.stride()) or has_zero_dim(mat2),
            lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
        )
        torch._check(
            self.size(1) % 16 == 0,
            lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
        )
        torch._check(
            mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
            lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
        )

        # determine scaling type and check input dimensions (refer to Blas.cpp op)

        m, _k = self.shape
        n = mat2.size(1)

        is_blockwise_scaling = (
            (
                scale_a.dtype == torch.float8_e8m0fnu
                and scale_b.dtype == torch.float8_e8m0fnu
            )
            or (
                scale_a.dtype == torch.float8_e4m3fn
                and scale_b.dtype == torch.float8_e4m3fn
            )
        )  # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales)

        if scale_a.numel() == 1 and scale_b.numel() == 1:
            # tensorwise scaling
            torch._check(
                scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
                lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.",
            )
        elif is_blockwise_scaling:
            # blockwise scaling

            if scale_a.dtype == torch.float8_e4m3fn:
                # NVIDIA's nvfp4 recipe:
                # * block size is 16 elements packed (32 unpacked)
                # * _k needs to be translated to the unpacked version
                block_size_k = 16
                _k = _k * 2
            else:
                block_size_k = 32

            block_size_mn = 128

            # Expected scale element counts: rows/cols rounded up to a multiple
            # of block_size_mn, times the K-block count rounded up to a
            # multiple of 4.
            num_k_blocks = ceil_div(_k, block_size_k)
            padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4

            expected_a_size = (
                block_size_mn * ceil_div(m, block_size_mn) * padded_num_k_blocks
            )
            expected_b_size = (
                block_size_mn * ceil_div(n, block_size_mn) * padded_num_k_blocks
            )

            if (
                scale_a.numel() == expected_a_size
                and scale_b.numel() == expected_b_size
            ):
                torch._check(
                    scale_a.is_contiguous(),
                    lambda: "scale_a must be contiguous",
                )
                torch._check(
                    scale_b.is_contiguous(),
                    lambda: "scale_b must be contiguous",
                )
            else:
                torch._check(
                    False,
                    lambda: (
                        "Invalid blockwise scaling configuration. "
                        f"For blockwise scaling, scale_a should have {expected_a_size} elements, got {scale_a.numel()}, "
                        f"scale_b should have {expected_b_size} elements, got {scale_b.numel()}."
                    ),
                )
        else:
            torch._check(
                scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
                lambda: "For rowwise scaling, both scale_a and scale_b must be float (fp32) tensors.",
            )
            # for rowwise scaling, enforce 2D input tensors
            torch._check(
                scale_a.dim() == 2 and scale_b.dim() == 2,
                lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}",
            )

            # Match the scale shapes against each accepted fp32 layout in turn.
            if (
                scale_a.size(0) == m
                and scale_a.size(1) == 1
                and scale_b.size(0) == 1
                and scale_b.size(1) == n
            ):
                # rowwise scaling
                torch._check(
                    scale_a.is_contiguous() and scale_b.is_contiguous(),
                    lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
                )
            elif (
                scale_a.size(0) == m
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == ceil_div(n, 128)
            ):
                # (BlockWise1x128, BlockWise128x128)
                pass  # do nothing, but do not error
            elif (
                scale_a.size(0) == m
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == n
            ):
                # (BlockWise1x128, BlockWise1x128)
                pass  # do nothing, but do not error
            elif (
                scale_a.size(0) == ceil_div(m, 128)
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == n
            ):
                # (BlockWise128x128, BlockWise1x128)
                pass  # do nothing, but do not error
            else:
                # does not match any valid scaling type
                torch._check(
                    False,
                    lambda: (
                        "Invalid scaling configuration. "
                        "For tensorwise scaling, both scales should be scalar. "
                        f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
                        f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). "
                        f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
                        f"For (BlockWise128x128, BlockWise1x128), scale_a should be ({ceil_div(m, 128)}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
                        f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
                        f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
                    ),
                )

    # The result dtype defaults to the dtype of ``self`` when not overridden.
    _out_dtype = out_dtype if out_dtype is not None else self.dtype
    return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device)
  5822. @register_meta([aten._scaled_mm.default])
  5823. def meta_scaled_mm(
  5824. self: torch.Tensor,
  5825. mat2: torch.Tensor,
  5826. scale_a: torch.Tensor,
  5827. scale_b: torch.Tensor,
  5828. bias: torch.Tensor | None = None,
  5829. scale_result: torch.Tensor | None = None,
  5830. out_dtype: torch.dtype | None = None,
  5831. use_fast_accum: bool = False,
  5832. ):
  5833. return _check_scaled_mm_sizes(
  5834. self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum
  5835. )
  5836. def _check_scaled_mm_sizes_v2(
  5837. self: torch.Tensor,
  5838. mat2: torch.Tensor,
  5839. scale_a: list[torch.Tensor],
  5840. scale_recipe_a: list[ScalingType],
  5841. scale_b: list[torch.Tensor],
  5842. scale_recipe_b: list[ScalingType],
  5843. bias: torch.Tensor | None = None,
  5844. out_dtype: torch.dtype | None = None,
  5845. swizzle_a: list[SwizzleType] | None = None,
  5846. swizzle_b: list[SwizzleType] | None = None,
  5847. use_fast_accum: bool = False,
  5848. ):
  5849. def is_fp8_or_fp4_type(dtype):
  5850. return dtype in (
  5851. torch.float8_e4m3fn,
  5852. torch.float8_e5m2,
  5853. torch.float8_e4m3fnuz,
  5854. torch.float8_e5m2fnuz,
  5855. torch.float4_e2m1fn_x2,
  5856. )
  5857. def is_fp4_type(dtype):
  5858. return dtype == torch.float4_e2m1fn_x2
  5859. torch._check(
  5860. self.dim() == 2 and mat2.dim() == 2,
  5861. lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
  5862. )
  5863. torch._check(
  5864. is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype),
  5865. lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
  5866. )
  5867. # Passed tensors:
  5868. # self: [M, K]
  5869. # mat2: [K, N]
  5870. M = self.shape[0]
  5871. K = self.shape[1]
  5872. N = mat2.shape[1]
  5873. # If we're using fp4, using fp4x2 packed format - adjust K appropriately
  5874. if is_fp4_type(self.dtype) and is_fp4_type(mat2.dtype):
  5875. K_packed_multiplier = 2
  5876. K *= K_packed_multiplier
  5877. scale_recipe_a = [ScalingType(si) for si in scale_recipe_a]
  5878. scale_recipe_b = [ScalingType(si) for si in scale_recipe_b]
  5879. if swizzle_a:
  5880. swizzle_a = [SwizzleType(si) for si in swizzle_a]
  5881. else:
  5882. swizzle_a = [
  5883. SwizzleType.NO_SWIZZLE,
  5884. ]
  5885. if swizzle_b:
  5886. swizzle_b = [SwizzleType(si) for si in swizzle_b]
  5887. else:
  5888. swizzle_b = [
  5889. SwizzleType.NO_SWIZZLE,
  5890. ]
  5891. if device_hint(self) == "cuda" or device_hint(self) == "xpu":
  5892. def is_row_major(stride):
  5893. return stride[0] > stride[1] and stride[1] == 1
  5894. def is_col_major(stride):
  5895. return stride[0] == 1 and stride[1] > 1
  5896. def has_zero_dim(tensor_2d):
  5897. return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0
  5898. torch._check(
  5899. is_row_major(self.stride()) or has_zero_dim(self),
  5900. lambda: f"self must be row_major, got stride {self.stride()}",
  5901. )
  5902. torch._check(
  5903. is_col_major(mat2.stride()) or has_zero_dim(mat2),
  5904. lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
  5905. )
  5906. torch._check(
  5907. self.size(1) % 16 == 0,
  5908. lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
  5909. )
  5910. torch._check(
  5911. mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
  5912. lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
  5913. )
  5914. def is_tensorwise(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5915. return (
  5916. len(recipe_a) == 1
  5917. and len(recipe_b) == 1
  5918. and recipe_a[0] == ScalingType.TensorWise
  5919. and recipe_b[0] == ScalingType.TensorWise
  5920. )
  5921. def is_rowwise(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5922. return (
  5923. len(recipe_a) == 1
  5924. and len(recipe_b) == 1
  5925. and recipe_a[0] == ScalingType.RowWise
  5926. and recipe_b[0] == ScalingType.RowWise
  5927. )
  5928. def is_mx(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5929. return (
  5930. len(recipe_a) == 1
  5931. and len(recipe_b) == 1
  5932. and recipe_a[0] == ScalingType.BlockWise1x32
  5933. and recipe_b[0] == ScalingType.BlockWise1x32
  5934. )
  5935. def is_nv_single_level(
  5936. recipe_a: list[ScalingType], recipe_b: list[ScalingType]
  5937. ):
  5938. return (
  5939. len(recipe_a) == 1
  5940. and len(recipe_b) == 1
  5941. and recipe_a[0] == ScalingType.BlockWise1x16
  5942. and recipe_b[0] == ScalingType.BlockWise1x16
  5943. )
  5944. def is_nv(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5945. return (
  5946. len(recipe_a) == 2
  5947. and len(recipe_b) == 2
  5948. and recipe_a[0] == ScalingType.BlockWise1x16
  5949. and recipe_a[1] == ScalingType.TensorWise
  5950. and recipe_b[0] == ScalingType.BlockWise1x16
  5951. and recipe_b[1] == ScalingType.TensorWise
  5952. )
  5953. def is_1x128_1x128(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5954. return (
  5955. len(recipe_a) == 1
  5956. and len(recipe_b) == 1
  5957. and recipe_a[0] == ScalingType.BlockWise1x128
  5958. and recipe_b[0] == ScalingType.BlockWise1x128
  5959. )
  5960. def is_1x128_128x128(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5961. return (
  5962. len(recipe_a) == 1
  5963. and len(recipe_b) == 1
  5964. and recipe_a[0] == ScalingType.BlockWise1x128
  5965. and recipe_b[0] == ScalingType.BlockWise128x128
  5966. )
  5967. def is_128x128_1x128(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5968. return (
  5969. len(recipe_a) == 1
  5970. and len(recipe_b) == 1
  5971. and recipe_a[0] == ScalingType.BlockWise128x128
  5972. and recipe_b[0] == ScalingType.BlockWise1x128
  5973. )
  5974. # Given scaling types, check input dimensions
  5975. if is_tensorwise(scale_recipe_a, scale_recipe_b):
  5976. # TensorWise
  5977. torch._check(
  5978. scale_a[0].numel() == 1
  5979. and scale_b[0].numel() == 1
  5980. and scale_a[0].dtype == torch.float32
  5981. and scale_b[0].dtype == torch.float32,
  5982. lambda: "For Tensorwise scaling, both scale_a and scale_b must be single element float (fp32) tensors",
  5983. )
  5984. elif is_rowwise(scale_recipe_a, scale_recipe_b):
  5985. torch._check(
  5986. scale_a[0].shape[0] == M
  5987. and scale_a[0].numel() == M
  5988. and scale_a[0].dtype == torch.float32
  5989. and scale_b[0].numel() == N
  5990. and scale_b[0].dtype == torch.float32,
  5991. lambda: (
  5992. f"For Rowwise scaling, scale_a must have {self.shape[0]} elements (got: {scale_a[0].numel()})"
  5993. f", and scale_b must have {mat2.shape[1]} elements (got: {scale_b[0].numel()})"
  5994. ),
  5995. )
  5996. elif is_1x128_1x128(scale_recipe_a, scale_recipe_b):
  5997. # A, B are fp8, scales are fp32
  5998. # As: [M x K // 128], stride: [1, M]
  5999. # Bs: [N x K // 128], stride: [1, N]
  6000. types_ok = (
  6001. scale_a[0].dtype == torch.float32 and scale_b[0].dtype == torch.float32
  6002. )
  6003. sa = scale_a[0]
  6004. scale_a_ok = (
  6005. sa.shape[0] == M
  6006. and sa.shape[1] == K // 128
  6007. and sa.stride(0) == 1
  6008. and (sa.stride(1) == M or (sa.shape[1] == 1 and sa.stride(1) == 1))
  6009. )
  6010. sb = scale_b[0]
  6011. scale_b_ok = (
  6012. sb.shape[0] == N
  6013. and sb.shape[1] == K // 128
  6014. and sb.stride(0) == 1
  6015. and (sb.stride(1) == N or (sb.shape[1] == 1 and sb.stride(1) == 1))
  6016. )
  6017. torch._check(
  6018. types_ok and scale_a_ok and scale_b_ok,
  6019. lambda: (
  6020. "For 1x128 x 1x128 blockwise scaling, "
  6021. f"scale a must have shape [{M}, {K // 128}] (got: {sa.shape}) and stride [1, {M}] (got: {sa.stride})"
  6022. f"scale b must have shape [{N}, {K // 128}] (got: {sb.shape}) and stride [1, {N}] (got: {sb.stride})"
  6023. ),
  6024. )
  6025. elif is_128x128_1x128(scale_recipe_a, scale_recipe_b):
  6026. # A, B are fp8, scales are fp32
  6027. # L4 = round_up(K // 128, 4)
  6028. # As: [L4 x M // 128], stride: [1, L4]
  6029. # Bs: [N x K // 128], stride: [1, N]
  6030. types_ok = (
  6031. scale_a[0].dtype == torch.float32 and scale_b[0].dtype == torch.float32
  6032. )
  6033. L4 = round_up(K / 128, 4)
  6034. sa = scale_a[0]
  6035. scale_a_ok = (
  6036. sa.shape[0] == L4
  6037. and sa.shape[1] == M // 128
  6038. and sa.stride(0) == 1
  6039. and (sa.stride(1) == L4 or (sa.shape[1] == 1 and sa.stride(1) == 1))
  6040. )
  6041. sb = scale_b[0]
  6042. scale_b_ok = (
  6043. sb.shape[0] == N
  6044. and sb.shape[1] == K // 128
  6045. and sb.stride(0) == 1
  6046. and (sb.stride(1) == N or (sb.shape[1] == 1 and sb.stride(1) == 1))
  6047. )
  6048. torch._check(
  6049. types_ok and scale_a_ok and scale_b_ok,
  6050. lambda: (
  6051. "For 128x128 x 1x128 blockwise scaling, L4 = {round_up(K / 128, 4)}, "
  6052. f"scale a must have shape [{L4}, {M // 128}] (got: {sa.shape}) and stride [1, {L4}] (got: {sa.stride})"
  6053. f"scale b must have shape [{N}, {K // 128}] (got: {sb.shape}) and stride [1, {N}] (got: {sb.stride})"
  6054. ),
  6055. )
  6056. elif is_1x128_128x128(scale_recipe_a, scale_recipe_b):
  6057. # A, B are fp8, scales are fp32
  6058. # L4 = round_up(K // 128, 4)
  6059. # As: [M x K // 128], stride: [1, M]
  6060. # Bs: [L4 x N // 128], stride: [1, L4]
  6061. types_ok = (
  6062. scale_a[0].dtype == torch.float32 and scale_b[0].dtype == torch.float32
  6063. )
  6064. L4 = round_up(K / 128, 4)
  6065. sa = scale_a[0]
  6066. scale_a_ok = (
  6067. sa.shape[0] == M
  6068. and sa.shape[1] == K // 128
  6069. and sa.stride(0) == 1
  6070. and (sa.stride(1) == M or (sa.shape[1] == 1 and sa.stride(1) == 1))
  6071. )
  6072. sb = scale_b[0]
  6073. scale_b_ok = (
  6074. sb.shape[0] == L4
  6075. and sb.shape[1] == N // 128
  6076. and sb.stride(0) == 1
  6077. and (sb.stride(1) == L4 or (sb.shape[1] == 1 and sb.stride(1) == 1))
  6078. )
  6079. torch._check(
  6080. types_ok and scale_a_ok and scale_b_ok,
  6081. lambda: (
  6082. "For 1x128 x 128x128 blockwise scaling, L4 = {round_up(K / 128, 4)}, "
  6083. f"scale a must have shape [{M}, {K // 128}] (got: {sa.shape}) and stride [1, {M}] (got: {sa.stride})"
  6084. f"scale b must have shape [{L4}, {N // 128}] (got: {sb.shape}) and stride [1, {L4}] (got: {sb.stride})"
  6085. ),
  6086. )
  6087. elif is_mx(scale_recipe_a, scale_recipe_b):
  6088. if torch.version.hip:
  6089. # Note(slayton58): These mirror ROCm in ScaledBlas.cpp, but I think they're wrong..
  6090. expected_scale_a_elems = ceil_div(self.shape[0], 32) * self.shape[1]
  6091. expected_scale_b_elems = ceil_div(self.shape[1], 32) * self.shape[0]
  6092. expected_swizzle = SwizzleType.NO_SWIZZLE
  6093. else:
  6094. expected_scale_a_elems = round_up(self.shape[0], 128) * round_up(
  6095. ceil_div(self.shape[1], 32), 4
  6096. )
  6097. expected_scale_b_elems = round_up(mat2.shape[1], 128) * round_up(
  6098. ceil_div(self.shape[1], 32), 4
  6099. )
  6100. expected_swizzle = SwizzleType.SWIZZLE_32_4_4
  6101. torch._check(
  6102. scale_a[0].numel() == expected_scale_a_elems
  6103. and scale_a[0].dtype == torch.float8_e8m0fnu
  6104. and scale_b[0].numel() == expected_scale_b_elems
  6105. and scale_b[0].dtype == torch.float8_e8m0fnu
  6106. and swizzle_a[0] == expected_swizzle
  6107. and swizzle_b[0] == expected_swizzle,
  6108. lambda: (
  6109. f"for MX scaling scale_a must have {expected_scale_a_elems} (got: {scale_a[0].numel()}) "
  6110. f"and scale_b must have {expected_scale_b_elems} (got: {scale_b[0].numel()}). Scales must "
  6111. f"have types {torch.float8_e8m0fnu} (for self: {scale_a[0].dtype}, mat_b: {scale_b[0].dtype}) "
  6112. f"Must have swizzle type {expected_swizzle} (got self: {swizzle_a[0]}, mat_b: {swizzle_b[0]})"
  6113. ),
  6114. )
  6115. elif is_nv_single_level(scale_recipe_a, scale_recipe_b):
  6116. expected_scale_a_elems = round_up(M, 128) * round_up(ceil_div(K, 16), 4)
  6117. expected_scale_b_elems = round_up(N, 128) * round_up(ceil_div(K, 16), 4)
  6118. expected_swizzle = SwizzleType.SWIZZLE_32_4_4
  6119. torch._check(
  6120. scale_a[0].numel() == expected_scale_a_elems
  6121. and scale_a[0].dtype == torch.float8_e4m3fn
  6122. and scale_b[0].numel() == expected_scale_b_elems
  6123. and scale_b[0].dtype == torch.float8_e4m3fn
  6124. and swizzle_a[0] == expected_swizzle
  6125. and swizzle_b[0] == expected_swizzle,
  6126. lambda: (
  6127. f"for single-level NV scaling scale_a must have {expected_scale_a_elems} (got: {scale_a[0].numel()}) "
  6128. f"and scale_b must have {expected_scale_b_elems} (got: {scale_b[0].numel()}). Must have "
  6129. f"swizzle type {expected_swizzle} (got self: {swizzle_a[0]}, mat_b: {swizzle_b[0]})"
  6130. ),
  6131. )
  6132. elif is_nv(scale_recipe_a, scale_recipe_b):
  6133. expected_scale_a_elems = round_up(M, 128) * round_up(ceil_div(K, 16), 4)
  6134. expected_scale_b_elems = round_up(N, 128) * round_up(ceil_div(K, 16), 4)
  6135. expected_swizzle = SwizzleType.SWIZZLE_32_4_4
  6136. torch._check(
  6137. scale_a[0].numel() == expected_scale_a_elems
  6138. and scale_a[0].dtype == torch.float8_e4m3fn
  6139. and scale_a[1].numel() == 1
  6140. and scale_a[1].dtype == torch.float32
  6141. and scale_b[0].numel() == expected_scale_b_elems
  6142. and scale_b[0].dtype == torch.float8_e4m3fn
  6143. and scale_b[1].numel() == 1
  6144. and scale_b[1].dtype == torch.float32
  6145. and swizzle_a[0] == expected_swizzle
  6146. and swizzle_b[0] == expected_swizzle,
  6147. lambda: (
  6148. f"for NV scaling scale_a must have {expected_scale_a_elems} (got: {scale_a[0].numel()}) "
  6149. f"and scale_b must have {expected_scale_b_elems} (got: {scale_b[0].numel()}). Must have "
  6150. f"swizzle type {expected_swizzle} (got self: {swizzle_a[0]}, mat_b: {swizzle_b[0]})"
  6151. ),
  6152. )
  6153. else:
  6154. torch._check(
  6155. False,
  6156. lambda: (
  6157. "Invalid scaling configuration. "
  6158. "For tensorwise scaling, both scales should be scalar. "
  6159. f"For rowwise scaling, scale_a should be ({M}, 1), scale_b should be (1, {N}). "
  6160. f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({M}, {ceil_div(K, 128)}), "
  6161. + f"scale_b should be ({ceil_div(K, 128)}, {ceil_div(N, 128)}). "
  6162. f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({M}, {ceil_div(K, 128)}), "
  6163. + f"scale_b should be ({ceil_div(K, 128)}, {N}). "
  6164. f"Got scale_a.size()=({scale_a[0].size(0)}, {scale_a[0].size(1)}) "
  6165. f"and scale_b.size()=({scale_b[0].size(0)}, {scale_b[0].size(1)})"
  6166. ),
  6167. )
  6168. _out_dtype = out_dtype if out_dtype is not None else self.dtype
  6169. return torch.empty(M, N, dtype=_out_dtype, device=self.device)
  6170. @register_meta([aten._scaled_mm_v2.default])
  6171. def meta_scaled_mm_v2(
  6172. self: torch.Tensor,
  6173. mat2: torch.Tensor,
  6174. scale_a: list[torch.Tensor],
  6175. scale_recipe_a: list[ScalingType],
  6176. swizzle_a: list[SwizzleType],
  6177. scale_b: list[torch.Tensor],
  6178. scale_recipe_b: list[ScalingType],
  6179. swizzle_b: list[SwizzleType],
  6180. bias: torch.Tensor | None = None,
  6181. output_dtype: torch.dtype | None = None,
  6182. contraction_dims: list[int] | None = None,
  6183. use_fast_accum: bool = False,
  6184. ):
  6185. return _check_scaled_mm_sizes_v2(
  6186. self,
  6187. mat2,
  6188. scale_a,
  6189. scale_recipe_a,
  6190. scale_b,
  6191. scale_recipe_b,
  6192. bias=bias,
  6193. out_dtype=output_dtype,
  6194. swizzle_a=swizzle_a,
  6195. swizzle_b=swizzle_b,
  6196. use_fast_accum=use_fast_accum,
  6197. )
  6198. @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
  6199. @out_wrapper()
  6200. def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
  6201. scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
  6202. return self.new_empty(self.shape)
  6203. @register_meta(aten.scatter_reduce_.two)
  6204. def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
  6205. scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
  6206. return self
  6207. @register_meta([aten.multinomial.default, aten.multinomial.out])
  6208. @out_wrapper()
  6209. def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
  6210. torch._check(
  6211. 0 < input.dim() <= 2,
  6212. lambda: f"The probability distributions dimensions must be 1 or 2, but got {input.dim()}",
  6213. )
  6214. if input.dim() == 1:
  6215. return torch.empty(num_samples, dtype=torch.long, device=input.device)
  6216. return torch.empty(
  6217. input.size(0), num_samples, dtype=torch.long, device=input.device
  6218. )
  6219. def multiply_integers(vs):
  6220. r = 1
  6221. for v in vs:
  6222. r *= v
  6223. return r
  6224. def upsample_common_check(input_size, output_size, num_spatial_dims):
  6225. torch._check(
  6226. len(output_size) == num_spatial_dims,
  6227. lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
  6228. )
  6229. expected_input_dims = num_spatial_dims + 2 # N, C, ...
  6230. torch._check(
  6231. len(input_size) == expected_input_dims,
  6232. lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
  6233. )
  6234. torch._check(
  6235. all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
  6236. lambda: f"Input and output sizes should be greater than 0, but got "
  6237. f"input size {input_size} and output size {output_size}",
  6238. )
  6239. nbatch, channels = input_size[:2]
  6240. return (nbatch, channels, *output_size)
  6241. @register_meta(
  6242. [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
  6243. )
  6244. def upsample_nearest1d(input, output_size, scales=None):
  6245. torch._check(
  6246. input.numel() != 0 or multiply_integers(input.size()[1:]),
  6247. lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
  6248. )
  6249. full_output_size = upsample_common_check(
  6250. input.size(), output_size, num_spatial_dims=1
  6251. )
  6252. return input.new_empty(full_output_size).to(
  6253. memory_format=utils.suggest_memory_format(input)
  6254. )
  6255. @register_meta(
  6256. [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
  6257. )
  6258. def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
  6259. torch._check(
  6260. input.numel() != 0 or multiply_integers(input.size()[1:]),
  6261. lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
  6262. )
  6263. full_output_size = upsample_common_check(
  6264. input.size(), output_size, num_spatial_dims=2
  6265. )
  6266. output = input.new_empty(full_output_size)
  6267. # convert output to correct memory format, if necessary
  6268. memory_format = utils.suggest_memory_format(input)
  6269. # following "heuristic: only use channels_last path when it's faster than the contiguous path"
  6270. _, n_channels, _, _ = input.shape
  6271. if input.device.type == "cuda" and n_channels < 4:
  6272. memory_format = torch.contiguous_format
  6273. output = output.contiguous(memory_format=memory_format)
  6274. return output
  6275. @register_meta(
  6276. [
  6277. aten.upsample_nearest2d_backward.default,
  6278. aten._upsample_nearest_exact2d_backward.default,
  6279. ]
  6280. )
  6281. def upsample_nearest2d_backward(
  6282. grad_output: Tensor,
  6283. output_size: Sequence[int | torch.SymInt],
  6284. input_size: Sequence[int | torch.SymInt],
  6285. scales_h: float | None = None,
  6286. scales_w: float | None = None,
  6287. ):
  6288. full_output_size = upsample_common_check(
  6289. input_size, output_size, num_spatial_dims=2
  6290. )
  6291. torch._check(
  6292. grad_output.ndim == 4,
  6293. lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
  6294. )
  6295. for i in range(4):
  6296. torch._check(
  6297. grad_output.size(i) == full_output_size[i],
  6298. lambda: (
  6299. f"Expected grad_output to have the same shape as output;"
  6300. f" output.size({i}) = {full_output_size[i]}"
  6301. f" but got grad_output.size({i}) = {grad_output.size(i)}"
  6302. ),
  6303. )
  6304. return grad_output.new_empty(input_size).to(
  6305. memory_format=utils.suggest_memory_format(grad_output)
  6306. ) # type: ignore[call-overload]
  6307. @register_meta(
  6308. [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
  6309. )
  6310. def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
  6311. torch._check(
  6312. input.numel() != 0 or multiply_integers(input.size()[1:]),
  6313. lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
  6314. )
  6315. full_output_size = upsample_common_check(
  6316. input.size(), output_size, num_spatial_dims=3
  6317. )
  6318. return input.new_empty(full_output_size).to(
  6319. memory_format=utils.suggest_memory_format(input)
  6320. )
  6321. @register_meta(
  6322. [
  6323. aten.sort.default,
  6324. aten.sort.stable,
  6325. aten.sort.values,
  6326. aten.sort.values_stable,
  6327. ]
  6328. )
  6329. def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
  6330. v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
  6331. if values is not None and indices is not None:
  6332. if not isinstance(values, TensorLike):
  6333. raise AssertionError(f"values must be TensorLike, got {type(values)}")
  6334. if not isinstance(indices, TensorLike):
  6335. raise AssertionError(f"indices must be TensorLike, got {type(indices)}")
  6336. # Makes sure values and indices have the same strides. For cases where
  6337. # these have different shapes, like (5, 10, 5) and (0) in msort.
  6338. out_shape = v.shape
  6339. out_stride = v.stride()
  6340. values = _maybe_resize_out(values, out_shape)
  6341. indices = _maybe_resize_out(indices, out_shape)
  6342. values.as_strided_(out_shape, out_stride)
  6343. indices.as_strided_(out_shape, out_stride)
  6344. _safe_copy_out(copy_from=v, copy_to=values) # type: ignore[arg-type]
  6345. _safe_copy_out(copy_from=i, copy_to=indices) # type: ignore[arg-type]
  6346. return values, indices
  6347. return v, i
  6348. def rnn_cell_checkSizes(
  6349. input_gates,
  6350. hidden_gates,
  6351. input_bias,
  6352. hidden_bias,
  6353. factor,
  6354. prev_hidden,
  6355. ):
  6356. torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
  6357. torch._check(
  6358. input_gates.shape == hidden_gates.shape,
  6359. lambda: f"{input_gates.shape} != {hidden_gates.shape}",
  6360. )
  6361. gates_size = input_gates.size(1)
  6362. if input_bias is not None:
  6363. torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
  6364. torch._check(
  6365. input_bias.numel() == gates_size,
  6366. lambda: f"{input_bias.numel()} != {gates_size}",
  6367. )
  6368. torch._check(
  6369. input_bias.shape == hidden_bias.shape,
  6370. lambda: f"{input_bias.shape} != {hidden_bias.shape}",
  6371. )
  6372. torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
  6373. expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
  6374. torch._check(
  6375. prev_hidden.numel() == expected_prev_hidden_numel,
  6376. lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
  6377. )
  6378. torch._check(
  6379. all(
  6380. x.device == input_gates.device
  6381. for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
  6382. ),
  6383. lambda: "expected all inputs to be same device",
  6384. )
  6385. @register_meta(aten._thnn_fused_lstm_cell.default)
  6386. def _thnn_fused_lstm_cell_meta(
  6387. input_gates,
  6388. hidden_gates,
  6389. cx,
  6390. input_bias=None,
  6391. hidden_bias=None,
  6392. ):
  6393. rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
  6394. workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
  6395. hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
  6396. cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
  6397. return (hy, cy, workspace)
  6398. @register_meta(aten._cudnn_rnn.default)
  6399. def _cudnn_rnn(
  6400. input,
  6401. weight,
  6402. weight_stride0,
  6403. weight_buf,
  6404. hx,
  6405. cx,
  6406. mode,
  6407. hidden_size,
  6408. proj_size,
  6409. num_layers,
  6410. batch_first,
  6411. dropout,
  6412. train,
  6413. bidirectional,
  6414. batch_sizes,
  6415. dropout_state,
  6416. ):
  6417. is_input_packed = len(batch_sizes) != 0
  6418. if is_input_packed:
  6419. seq_length = len(batch_sizes)
  6420. mini_batch = batch_sizes[0]
  6421. batch_sizes_sum = input.shape[0]
  6422. else:
  6423. seq_length = input.shape[1] if batch_first else input.shape[0]
  6424. mini_batch = input.shape[0] if batch_first else input.shape[1]
  6425. batch_sizes_sum = -1
  6426. num_directions = 2 if bidirectional else 1
  6427. out_size = proj_size if proj_size != 0 else hidden_size
  6428. if is_input_packed:
  6429. out_shape = [batch_sizes_sum, out_size * num_directions]
  6430. else:
  6431. out_shape = (
  6432. [mini_batch, seq_length, out_size * num_directions]
  6433. if batch_first
  6434. else [seq_length, mini_batch, out_size * num_directions]
  6435. )
  6436. output = input.new_empty(out_shape)
  6437. cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
  6438. if cx is None:
  6439. cy = torch.empty(0, device=input.device)
  6440. else:
  6441. cy = cx.new_empty(cell_shape)
  6442. hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
  6443. # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
  6444. reserve_shape = 0 if train else 0
  6445. reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
  6446. return output, hy, cy, reserve, weight_buf
  6447. @register_meta(aten.miopen_rnn.default)
  6448. def miopen_rnn(
  6449. input,
  6450. weight,
  6451. weight_stride0,
  6452. # weight_buf,
  6453. hx,
  6454. cx,
  6455. mode,
  6456. hidden_size,
  6457. # proj_size,
  6458. num_layers,
  6459. batch_first,
  6460. dropout,
  6461. train,
  6462. bidirectional,
  6463. batch_sizes,
  6464. dropout_state,
  6465. ):
  6466. total_weight_elems = 0
  6467. for w in weight:
  6468. if w.numel() > 0:
  6469. total_weight_elems += w.numel()
  6470. weight_buf = input.new_empty((total_weight_elems,))
  6471. return _cudnn_rnn(
  6472. input,
  6473. weight,
  6474. weight_stride0,
  6475. weight_buf,
  6476. hx,
  6477. cx,
  6478. mode,
  6479. hidden_size,
  6480. 0,
  6481. num_layers,
  6482. batch_first,
  6483. dropout,
  6484. train,
  6485. bidirectional,
  6486. batch_sizes,
  6487. dropout_state,
  6488. )
  6489. @register_meta(aten.mkldnn_rnn_layer.default)
  6490. def mkldnn_rnn_layer(
  6491. input,
  6492. w0,
  6493. w1,
  6494. w2,
  6495. w3,
  6496. hx_,
  6497. cx_,
  6498. reverse,
  6499. batch_sizes,
  6500. mode,
  6501. hidden_size,
  6502. num_layers,
  6503. has_biases,
  6504. bidirectional,
  6505. batch_first,
  6506. train,
  6507. ):
  6508. seq_length = input.shape[1] if batch_first else input.shape[0]
  6509. mini_batch = input.shape[0] if batch_first else input.shape[1]
  6510. output_chanels = hidden_size
  6511. out_shape = (
  6512. [mini_batch, seq_length, output_chanels]
  6513. if batch_first
  6514. else [seq_length, mini_batch, output_chanels]
  6515. )
  6516. output = input.new_empty(out_shape)
  6517. if hx_ is None:
  6518. hy = torch.empty(0, device=input.device)
  6519. else:
  6520. hy = hx_.new_empty(hx_.shape)
  6521. if cx_ is None:
  6522. cy = torch.empty(0, device=input.device)
  6523. else:
  6524. cy = cx_.new_empty(cx_.shape)
  6525. workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
  6526. return output, hy, cy, workspace
  6527. def zero_numel_check_dims(self, dim, fn_name):
  6528. if self.ndim == 0:
  6529. torch._check_index(
  6530. dim == 0 or dim == -1,
  6531. lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
  6532. )
  6533. else:
  6534. torch._check_index(
  6535. self.size(dim) != 0,
  6536. lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
  6537. )
  6538. # From aten/src/ATen/native/ReduceOps.cpp
  6539. def check_argmax_argmin(name, self, dim):
  6540. if dim is not None:
  6541. dim = maybe_wrap_dim(dim, self.dim())
  6542. zero_numel_check_dims(self, dim, name)
  6543. else:
  6544. torch._check(
  6545. self.numel() != 0,
  6546. lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
  6547. )
  6548. @register_meta([aten.argmax.default, aten.argmin.default])
  6549. def argmax_argmin_meta(self, dim=None, keepdim=False):
  6550. check_argmax_argmin("argmax", self, dim)
  6551. dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
  6552. shape = _compute_reduction_shape(self, dims, keepdim)
  6553. return self.new_empty(shape, dtype=torch.int64)
  6554. @register_meta(aten.scalar_tensor.default)
  6555. def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
  6556. # NB: It's always wrong to try to create a scalar tensor with the jagged layout.
  6557. # Rather than fix this everywhere, just use the strided layout and let NJT handle
  6558. # scalar tensor broadcasting.
  6559. if layout == torch.jagged:
  6560. layout = torch.strided
  6561. return torch.empty(
  6562. (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  6563. )
  6564. @register_meta(aten.topk.default)
  6565. def topk_meta(self, k, dim=-1, largest=True, sorted=True):
  6566. # From aten/src/ATen/native/Sorting.cpp
  6567. dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
  6568. sliceSize = 1 if self.dim() == 0 else self.size(dim)
  6569. torch._check(k >= 0)
  6570. torch._check(k <= sliceSize, lambda: "k not in range for dimension")
  6571. topKSize = list(self.shape)
  6572. if len(topKSize) > 0:
  6573. topKSize[dim] = k
  6574. return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
  6575. @register_meta(aten._segment_reduce_backward)
  6576. @out_wrapper()
  6577. def meta__segment_reduce_backward(
  6578. grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None
  6579. ):
  6580. if lengths is None and offsets is None:
  6581. raise AssertionError(
  6582. "segment_reduce(): Either lengths or offsets must be defined"
  6583. )
  6584. data_contig = data.contiguous()
  6585. grad_contig = grad.contiguous()
  6586. return torch.empty_like(
  6587. data_contig,
  6588. dtype=grad_contig.dtype,
  6589. device=grad_contig.device,
  6590. layout=grad_contig.layout,
  6591. )
  6592. @register_meta([aten.kthvalue.default, aten.kthvalue.values])
  6593. @out_wrapper("values", "indices")
  6594. def kthvalue_meta(self, k, dim=-1, keepdim=False):
  6595. from torch.fx.experimental.symbolic_shapes import sym_and
  6596. dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
  6597. dimSize = self.size(dim) if self.dim() > 0 else 1
  6598. torch._check(
  6599. sym_and(k >= 1, k <= dimSize),
  6600. lambda: f"kthvalue(): selected number k out of range for dimension {dim}",
  6601. )
  6602. shape = list(self.shape[:dim] + self.shape[dim + 1 :])
  6603. if keepdim and self.dim() > 0:
  6604. shape.insert(dim, 1)
  6605. return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64)
# Module-level alias for torch.contiguous_format; the fused LSTM cell backward
# meta kernel below passes it as the memory_format of its outputs.
legacy_contiguous_memory_format = torch.contiguous_format
  6607. # From aten/src/ATen/native/cuda/RNN.cu
  6608. def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
  6609. defined_grad = grad_hy if grad_hy is not None else grad_cy
  6610. torch._check(defined_grad.dim() == 2, lambda: "")
  6611. exp_size = defined_grad.size()
  6612. if grad_hy is not None:
  6613. torch._check(grad_hy.size() == exp_size, lambda: "")
  6614. if grad_cy is not None:
  6615. torch._check(grad_cy.size() == exp_size, lambda: "")
  6616. torch._check(cx.size() == exp_size, lambda: "")
  6617. torch._check(cy.size() == exp_size, lambda: "")
  6618. torch._check(workspace.dim() == 2, lambda: "")
  6619. torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
  6620. # From aten/src/ATen/native/cuda/RNN.cu
  6621. @register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
  6622. def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
  6623. if grad_hy is None and grad_cy is None:
  6624. return None, None, None
  6625. checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
  6626. grad_gates = torch.empty_like(
  6627. workspace, memory_format=legacy_contiguous_memory_format
  6628. )
  6629. grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
  6630. grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
  6631. return grad_gates, grad_cx, grad_bias
  6632. # From aten/src/ATen/native/mps/operations/Linear.mm
  6633. @register_meta(aten.linear_backward.default)
  6634. def linear_backward(input_, grad_output_, weight_, output_mask):
  6635. grad_input = None
  6636. grad_weight = None
  6637. grad_bias = None
  6638. if output_mask[0]:
  6639. grad_input = grad_output_.new_empty(input_.size())
  6640. if output_mask[1] or output_mask[2]:
  6641. grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
  6642. grad_bias = grad_output_.new_empty(grad_output_.size(-1))
  6643. return (grad_input, grad_weight, grad_bias)
  6644. @register_meta(aten.pixel_shuffle.default)
  6645. def meta_pixel_shuffle(self, upscale_factor):
  6646. if not (
  6647. len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
  6648. ):
  6649. raise AssertionError(
  6650. f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
  6651. )
  6652. def is_channels_last(ten):
  6653. return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
  6654. def pick_memory_format():
  6655. if is_channels_last(self):
  6656. if device_hint(self) == "cuda":
  6657. return torch.contiguous_format
  6658. else:
  6659. return torch.channels_last
  6660. elif self.is_contiguous(memory_format=torch.contiguous_format):
  6661. return torch.contiguous_format
  6662. elif self.is_contiguous(memory_format=torch.preserve_format):
  6663. return torch.preserve_format
  6664. C = self.shape[-3] // (upscale_factor * upscale_factor)
  6665. Hr = self.shape[-2] * upscale_factor
  6666. Wr = self.shape[-1] * upscale_factor
  6667. out_shape = (*self.shape[:-3], C, Hr, Wr)
  6668. out = self.new_empty(out_shape)
  6669. out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload]
  6670. return out
  6671. @register_meta(aten.mkldnn_rnn_layer_backward.default)
  6672. def mkldnn_rnn_layer_backward(
  6673. input,
  6674. weight0,
  6675. weight1,
  6676. weight2,
  6677. weight3,
  6678. hx_,
  6679. cx_tmp,
  6680. output,
  6681. hy_,
  6682. cy_,
  6683. grad_output_r_opt,
  6684. grad_hy_r_opt,
  6685. grad_cy_r_opt,
  6686. reverse,
  6687. mode,
  6688. hidden_size,
  6689. num_layers,
  6690. has_biases,
  6691. train,
  6692. bidirectional,
  6693. batch_sizes,
  6694. batch_first,
  6695. workspace,
  6696. ):
  6697. diff_x = input.new_empty(input.shape)
  6698. diff_hx = hx_.new_empty(hx_.shape)
  6699. diff_cx = cx_tmp.new_empty(cx_tmp.shape)
  6700. diff_w1 = weight0.new_empty(weight0.shape)
  6701. diff_w2 = weight1.new_empty(weight1.shape)
  6702. diff_b = weight2.new_empty(weight2.shape)
  6703. return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
  6704. @register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
  6705. @out_wrapper()
  6706. def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
  6707. return torch.empty_like(
  6708. self,
  6709. dtype=torch.int32 if out_int32 else torch.int64,
  6710. memory_format=torch.contiguous_format,
  6711. )
  6712. @register_meta([aten.bucketize.Scalar, aten.bucketize.Scalar_out])
  6713. def meta_bucketize_scalar(
  6714. self: NumberType,
  6715. boundaries: Tensor,
  6716. *,
  6717. out_int32: bool = False,
  6718. right: bool = False,
  6719. ):
  6720. return boundaries.new_empty(
  6721. (),
  6722. dtype=torch.int32 if out_int32 else torch.int64,
  6723. )
  6724. @register_meta([aten.histc])
  6725. @out_wrapper()
  6726. def meta_histc(input, bins=100, min=0, max=0):
  6727. fn_name = "histc()"
  6728. if device_hint(input) == "cpu":
  6729. torch._check(
  6730. input.is_floating_point(),
  6731. lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'",
  6732. )
  6733. if device_hint(input) == "cuda" and input.is_floating_point():
  6734. utils.alert_not_deterministic("_histc_cuda with floating point input")
  6735. torch._check(
  6736. isinstance(bins, IntLike),
  6737. lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}",
  6738. )
  6739. torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}")
  6740. torch._check(
  6741. isinstance(min, Number),
  6742. lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}",
  6743. )
  6744. torch._check(
  6745. isinstance(max, Number),
  6746. lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
  6747. )
  6748. torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min")
  6749. return torch.empty(bins, device=input.device, dtype=input.dtype)
  6750. @register_meta(
  6751. [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
  6752. )
  6753. def meta_upsample_bimode2d_aa(
  6754. input,
  6755. output_size,
  6756. align_corners,
  6757. scales_h=None,
  6758. scales_w=None,
  6759. ):
  6760. full_output_size = upsample_common_check(
  6761. input.size(), output_size, num_spatial_dims=2
  6762. )
  6763. torch._check(
  6764. input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
  6765. lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
  6766. )
  6767. return input.new_empty(full_output_size).to(
  6768. memory_format=utils.suggest_memory_format(input)
  6769. )
  6770. @register_meta([aten._upsample_bilinear2d_aa_backward.default])
  6771. def meta_upsample_bimode2d_aa_backward(
  6772. grad_output,
  6773. output_size,
  6774. input_size,
  6775. align_corners,
  6776. scales_h=None,
  6777. scales_w=None,
  6778. ):
  6779. full_output_size = upsample_common_check(
  6780. input_size, output_size, num_spatial_dims=2
  6781. )
  6782. torch._check(
  6783. grad_output.ndim == 4,
  6784. lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
  6785. )
  6786. for i in range(4):
  6787. torch._check(
  6788. grad_output.shape[i] == full_output_size[i],
  6789. lambda: f"""
  6790. Expected grad_output to have the same shape as output; output.size({i}) = {full_output_size[i]}
  6791. but got grad_output_size({i}) = {grad_output.size(i)}""",
  6792. )
  6793. return grad_output.new_empty(input_size).to(
  6794. memory_format=utils.suggest_memory_format(grad_output)
  6795. )
  6796. # From aten/src/ATen/native/cuda/AmpKernels.cu
  6797. @register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
  6798. def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
  6799. torch._check(
  6800. found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
  6801. )
  6802. torch._check(
  6803. inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
  6804. )
  6805. torch._check(
  6806. found_inf.dtype.is_floating_point,
  6807. lambda: "found_inf must be a float tensor.",
  6808. )
  6809. torch._check(
  6810. inv_scale.dtype.is_floating_point,
  6811. lambda: "inv_scale must be a float tensor.",
  6812. )
  6813. # From aten/src/ATen/native/UnaryOps.cpp
  6814. @register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
  6815. @out_wrapper()
  6816. def nan_to_num(self, nan=None, posinf=None, neginf=None):
  6817. return torch.empty_like(self)
  6818. @register_meta(torch.ops.aten.transpose_)
  6819. def transpose_(self, dim0, dim1):
  6820. if self.layout in {
  6821. torch.sparse_csr,
  6822. torch.sparse_csc,
  6823. torch.sparse_bsr,
  6824. torch.sparse_bsc,
  6825. }:
  6826. raise AssertionError(
  6827. f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
  6828. )
  6829. ndims = self.ndim
  6830. dim0 = maybe_wrap_dim(dim0, ndims)
  6831. dim1 = maybe_wrap_dim(dim1, ndims)
  6832. if dim0 == dim1:
  6833. return self
  6834. size = list(self.size())
  6835. stride = list(self.stride())
  6836. stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
  6837. size[dim0], size[dim1] = size[dim1], size[dim0]
  6838. self.as_strided_(size, stride)
  6839. return self
  6840. @register_meta(torch.ops.aten.t_)
  6841. def t_(self):
  6842. ndims = self.ndim
  6843. if self.is_sparse:
  6844. sparse_dim = self.sparse_dim()
  6845. dense_dim = self.dense_dim()
  6846. if not (sparse_dim <= 2 and dense_dim == 0):
  6847. raise AssertionError(
  6848. f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, "
  6849. f"but got {sparse_dim} sparse and {dense_dim} dense dimensions"
  6850. )
  6851. else:
  6852. if self.dim() > 2:
  6853. raise AssertionError(
  6854. f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
  6855. )
  6856. return transpose_(self, 0, 0 if ndims < 2 else 1)
  6857. @register_meta(aten.searchsorted)
  6858. @out_wrapper()
  6859. def meta_searchsorted(
  6860. sorted_sequence,
  6861. self,
  6862. *,
  6863. out_int32=False,
  6864. right=False,
  6865. side=None,
  6866. sorter=None,
  6867. ):
  6868. # If the sorted_sequence is not one-dimensional, its shape must match that of values
  6869. # in all but the last dimension.
  6870. torch._check(
  6871. len(sorted_sequence.shape) <= 1
  6872. or sorted_sequence.shape[:-1] == self.shape[:-1],
  6873. lambda: (
  6874. "torch.searchsorted(): boundaries tensor should be 1 dimension or the "
  6875. "first N-1 dimensions of boundaries tensor and input value tensor must "
  6876. f"match, but we got boundaries tensor {list(sorted_sequence.shape)} and "
  6877. f"input value tensor {list(self.shape)}"
  6878. ),
  6879. )
  6880. # If a sorter array is provided, its dimensions must exactly match sorted_sequence.
  6881. torch._check(
  6882. sorter is None or sorted_sequence.shape == sorter.shape,
  6883. lambda: (
  6884. "torch.searchsorted(): boundary and sorter must have the same size, but "
  6885. f"got boundary tensor {list(sorted_sequence.shape)} and got sorter tensor "
  6886. f"{list(sorter.shape) if sorter is not None else []}"
  6887. ),
  6888. )
  6889. # Per the docs, if side == "left" and right is True, we error.
  6890. torch._check(
  6891. side != "left" or not right,
  6892. lambda: "torch.searchsorted(): side and right can't be set to opposites, got side of "
  6893. "left while right was True",
  6894. )
  6895. dtype = torch.int32 if out_int32 else torch.int64
  6896. if isinstance(self, torch.Tensor):
  6897. return torch.empty_like(
  6898. self, dtype=dtype, memory_format=torch.contiguous_format
  6899. )
  6900. else: # Scalar
  6901. return torch.empty((), dtype=dtype, device=sorted_sequence.device)
  6902. def _check_for_unsupported_isin_dtype(dtype):
  6903. torch._check(
  6904. dtype not in (torch.bool, torch.complex128, torch.complex64),
  6905. lambda: f"Unsupported input type encountered for isin(): {dtype}",
  6906. )
  6907. @register_meta(aten.embedding_dense_backward)
  6908. def meta_embedding_dense_backward(
  6909. grad_output,
  6910. indices,
  6911. num_weights,
  6912. padding_idx,
  6913. scale_grad_by_freq,
  6914. ):
  6915. grad_weight = grad_output.new_empty((num_weights, grad_output.size(-1)))
  6916. return grad_weight
  6917. @register_meta(aten._embedding_bag_backward)
  6918. def meta_embedding_bag_backward(
  6919. grad,
  6920. indices,
  6921. offsets,
  6922. offset2bag,
  6923. bag_size,
  6924. maximum_indices,
  6925. num_weights,
  6926. scale_grad_by_freq,
  6927. mode,
  6928. sparse,
  6929. per_sample_weights,
  6930. padding_idx=-1,
  6931. ):
  6932. if sparse:
  6933. return aten._embedding_bag_sparse_backward(
  6934. grad,
  6935. indices,
  6936. offsets,
  6937. offset2bag,
  6938. bag_size,
  6939. num_weights,
  6940. scale_grad_by_freq,
  6941. mode,
  6942. per_sample_weights,
  6943. padding_idx,
  6944. )
  6945. else:
  6946. return meta_embedding_bag_dense_backward(
  6947. grad,
  6948. indices,
  6949. offset2bag,
  6950. bag_size,
  6951. maximum_indices,
  6952. num_weights,
  6953. scale_grad_by_freq,
  6954. mode,
  6955. per_sample_weights,
  6956. padding_idx,
  6957. )
  6958. @register_meta(aten._embedding_bag_dense_backward)
  6959. def meta_embedding_bag_dense_backward(
  6960. grad,
  6961. indices,
  6962. offset2bag,
  6963. bag_size,
  6964. maximum_indices,
  6965. num_weights,
  6966. scale_grad_by_freq,
  6967. mode,
  6968. per_sample_weights,
  6969. padding_idx=-1,
  6970. ):
  6971. torch._check(
  6972. grad.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64],
  6973. lambda: f"Unsupported input type encountered: {grad.dtype}",
  6974. )
  6975. if mode == MODE_MAX:
  6976. torch._check(maximum_indices is not None)
  6977. index_grad_weight = grad.new_empty((num_weights, grad.size(1)))
  6978. return index_grad_weight
  6979. @register_meta(aten._embedding_bag_per_sample_weights_backward)
  6980. def meta_embedding_bag_per_sample_weights_backward(
  6981. grad,
  6982. weight,
  6983. indices,
  6984. offsets,
  6985. offset2bag,
  6986. mode,
  6987. padding_idx=-1,
  6988. ):
  6989. embedding_features = grad.size(1)
  6990. torch._check(
  6991. mode == MODE_SUM,
  6992. lambda: "embedding_bag_backward: per_sample_weights only supported for mode='sum'",
  6993. )
  6994. torch._check(grad.dim() == 2)
  6995. torch._check(indices.dim() == 1)
  6996. num_samples = indices.size(0)
  6997. torch._check(weight.dim() == 2)
  6998. torch._check(weight.size(1) == embedding_features)
  6999. output = grad.new_empty((num_samples,))
  7000. return output
  7001. @register_meta(aten.isin)
  7002. @out_wrapper()
  7003. def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
  7004. torch._check(
  7005. isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
  7006. lambda: "At least one of elements and test_elements must be a Tensor.",
  7007. )
  7008. if not isinstance(elements, Tensor):
  7009. elements = torch.tensor(elements, device=test_elements.device)
  7010. if not isinstance(test_elements, Tensor):
  7011. test_elements = torch.tensor(test_elements, device=elements.device)
  7012. _check_for_unsupported_isin_dtype(elements.dtype)
  7013. _check_for_unsupported_isin_dtype(test_elements.dtype)
  7014. return torch.empty_like(elements, dtype=torch.bool)
  7015. @register_meta(aten.polygamma)
  7016. @out_wrapper()
  7017. def meta_polygamma(n: int, self: Tensor) -> Tensor:
  7018. torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
  7019. _, result_dtype = elementwise_dtypes(
  7020. self,
  7021. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  7022. )
  7023. return torch.empty_like(self, dtype=result_dtype)
  7024. @register_meta(aten._local_scalar_dense)
  7025. def meta_local_scalar_dense(self: Tensor):
  7026. raise RuntimeError("Tensor.item() cannot be called on meta tensors")
  7027. @register_meta(aten.silu)
  7028. @out_wrapper(exact_dtype=True)
  7029. def silu(self: Tensor) -> Tensor:
  7030. return torch.empty_like(self)
  7031. @register_meta(aten.sigmoid)
  7032. @out_wrapper()
  7033. def sigmoid(self: Tensor) -> Tensor:
  7034. _, result_dtype = elementwise_dtypes(
  7035. self,
  7036. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  7037. )
  7038. return torch.empty_like(self, dtype=result_dtype)
  7039. def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype):
  7040. mat1_is_2d = mat1.dim() == 2
  7041. mat2_is_2d = mat2.dim() == 2
  7042. if mat1_is_2d:
  7043. if mat2_is_2d:
  7044. out_size = [offs.size(0), mat1.size(0), mat2.size(1)]
  7045. else:
  7046. torch._check(
  7047. offs.size(0) == mat2.size(0), lambda: "matrix batch sizes have to match"
  7048. )
  7049. out_size = [mat1.size(0), mat2.size(-1)]
  7050. else:
  7051. if mat2_is_2d:
  7052. torch._check(
  7053. offs.size(0) == mat1.size(0), lambda: "matrix batch sizes have to match"
  7054. )
  7055. out_size = [mat1.size(1), mat2.size(1)]
  7056. else:
  7057. # regular bmm
  7058. torch._check(
  7059. mat1.size(0) == mat2.size(0), lambda: "batched dimension has to match"
  7060. )
  7061. out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)]
  7062. out_dtype = out_dtype or mat1.dtype
  7063. if torch.version.cuda:
  7064. alignment = 16 // out_dtype.itemsize
  7065. size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
  7066. if mat1_is_2d == mat2_is_2d:
  7067. out_stride = [out_size[1] * size_padded, size_padded, 1]
  7068. else:
  7069. out_stride = [size_padded, 1]
  7070. out = torch.empty_strided(
  7071. out_size, out_stride, dtype=out_dtype, device=mat1.device
  7072. )
  7073. else:
  7074. out = torch.empty(out_size, dtype=out_dtype, device=mat1.device)
  7075. return out
def _meta_grouped_mm_common(
    mat_a: Tensor,
    mat_b: Tensor,
    scale_a: torch.Tensor | None,
    scale_b: torch.Tensor | None,
    offs: Tensor | None = None,
    bias: Tensor | None = None,
    scale_result: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
):
    """Shared validation and output allocation for the grouped-mm meta kernels.

    Runs a sequence of ``torch._check`` validations on the operands and then
    returns the tensor produced by ``_create_grouped_mm_output_tensor``.

    Args:
        mat_a, mat_b: Multiplicands; each must be 2-D or 3-D.
        scale_a, scale_b: Either both ``None`` (unscaled path, BF16 inputs
            required) or both given (scaled path, E4M3 FP8 inputs required).
        offs: 1-D int32 offsets; required when either multiplicand is 2-D and
            rejected when both are 3-D.
        bias: Must be ``None``.
        scale_result: Must be ``None`` on the scaled path.
        out_dtype: ``None`` or ``torch.bfloat16``.
        use_fast_accum: Accepted but not read in this function.

    Raises:
        RuntimeError: via ``torch._check`` when any validation fails.
    """
    # The two scale factors must be supplied together or not at all.
    torch._check(
        (scale_a is None) == (scale_b is None),
        lambda: "Either both scale factors are given, or none",
    )
    scaled = scale_a is not None and scale_b is not None

    # Implementing all the checks from
    # _grouped_mm_cuda()/_scaled_grouped_mm_cuda() code in
    # aten/src/ATen/native/cuda/Blas.cpp.

    if scaled:
        fp8_dtype = torch.float8_e4m3fn
        # On a HIP build with an available device whose arch name contains
        # "gfx94", the expected FP8 dtype is the "fnuz" variant instead.
        if (
            torch.version.hip
            and torch.cuda.is_available()
            and "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
        ):
            fp8_dtype = torch.float8_e4m3fnuz
        torch._check(
            mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
            lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",  # noqa: B950
        )
    else:
        torch._check(
            mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
            lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.",  # noqa: B950
        )

    torch._check(
        mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
        lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}",  # noqa: B950
    )

    mat_a_is_2d = mat_a.dim() == 2
    mat_b_is_2d = mat_b.dim() == 2

    # The contraction-dimension check is skipped only in the 2D/2D case.
    if not mat_a_is_2d or not mat_b_is_2d:
        torch._check(
            mat_a.size(-1) == mat_b.size(-2),
            lambda: "contraction dimension of mat_a and mat_b must match",
        )

    if scaled:

        def is_row_major(mat):
            # Unit stride in the last dim, larger stride in the one before it.
            mat_stride = mat.stride()
            return mat_stride[-2] > 1 and mat_stride[-1] == 1

        def is_col_major(mat):
            # Unit stride in the second-to-last dim, larger stride in the last.
            mat_stride = mat.stride()
            return mat_stride[-2] == 1 and mat_stride[-1] > 1

        torch._check(
            is_row_major(mat_a),
            lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}",  # noqa: B950
        )
        torch._check(
            is_col_major(mat_b),
            lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}",  # noqa: B950
        )

    def check_valid_strides(mat_name, mat):
        # Of the last two dims, one must have unit stride and the other a
        # stride that is at least the unit-stride dim's extent and a multiple
        # of ``alignment`` elements (16 bytes / element size).
        end_dim = mat.dim() - 1
        alignment = 16 // mat.element_size()
        mat_stride = mat.stride()
        if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
            1, mat.shape[end_dim - 1]
        ):
            torch._check(
                mat_stride[end_dim] % alignment == 0,
                lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.",  # noqa: B950
            )
        elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
            1, mat.shape[end_dim]
        ):
            torch._check(
                mat_stride[end_dim - 1] % alignment == 0,
                lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.",  # noqa: B950
            )
        else:
            # Neither of the two accepted stride patterns matched.
            torch._check(
                False,
                lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.",  # noqa: B950
            )

    check_valid_strides("mat_a", mat_a)
    check_valid_strides("mat_b", mat_b)

    # Same condition as ``scaled``; spelled out again, presumably so the type
    # checker can narrow scale_a/scale_b to non-None below.
    if scale_a is not None and scale_b is not None:
        torch._check(
            (scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32)
            or (
                scale_a.dtype == torch.float8_e8m0fnu
                and scale_b.dtype == torch.float8_e8m0fnu
            ),
            lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.",  # noqa: B950
        )
        # Both scales being float8_e8m0fnu selects the MXFP8 rules below.
        is_mxfp8 = (
            scale_a.dtype == torch.float8_e8m0fnu
            and scale_b.dtype == torch.float8_e8m0fnu
        )

        def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
            # Validate one scale tensor against the matrix it scales.
            # ``scaled_dim`` selects which matrix dim the scale length is
            # compared with.
            if mat.dim() == 2:
                torch._check(
                    scale.is_contiguous(),
                    lambda: f"Expected {scale_name} to be contiguous.",
                )
                # For MXFP8, 2d tensors have variable size groups represented as subtensors,
                # that are converted to blocked padded format individually. At compile time we don't know
                # the group sizes yet, so we don't know the expect size of the blocked format scale.
                # This limits what we can check here.
                if is_mxfp8:
                    torch._check(
                        scale.dim() == mat.dim(),
                        lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
                    )
                else:
                    torch._check(
                        scale.dim() == 1,
                        lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
                    )
                    torch._check(
                        scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
                        lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.",  # noqa: B950
                    )
            else:
                torch._check(
                    scale.stride(-1) == 1,
                    lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
                )
                torch._check(
                    scale.shape[0] == mat.shape[0],
                    lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
                )
                # For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked
                # scale sizes at compile time.
                if is_mxfp8:
                    torch._check(
                        scale.ndim == mat.ndim - 1,
                        lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}",  # noqa: B950
                    )
                    # TODO: This logic only holds for RHS tensor in 2d-3d case.
                    # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
                    G, K, N = mat.shape
                    block_size = 32
                    # NOTE(review): ``K / block_size`` is true division, so
                    # blocked_K is a float; presumably K is expected to be a
                    # multiple of block_size here — confirm against callers.
                    blocked_K = round_up(K / block_size, 4)
                    blocked_N = round_up(N, 128)
                    torch._check(
                        scale.shape[0] == G and scale.shape[1] == blocked_K * blocked_N,
                        lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}",  # noqa: B950
                    )
                else:
                    torch._check(
                        scale.dim() == 2,
                        lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
                    )
                    torch._check(
                        scale.shape[1] == mat.shape[1 + scaled_dim],
                        lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.",  # noqa: B950
                    )

        # Only in the 2D/2D case with offsets is the expected 1-D scale
        # length multiplied by the number of offsets.
        scale_multiplier = (
            offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
        )
        check_scale("scale_a", scale_a, mat_a, 0, scale_multiplier)
        check_scale("scale_b", scale_b, mat_b, 1, scale_multiplier)

        torch._check(
            scale_result is None,
            lambda: "Scale result tensor provided, but it is not supported yet.",
        )

    # Offsets are mandatory whenever a multiplicand is 2-D and forbidden for
    # the 3D/3D layout.
    if mat_a_is_2d or mat_b_is_2d:
        torch._check(
            offs is not None,
            lambda: f"Offsets tensor not provided, but is needed for {mat_a.dim()}D/{mat_b.dim()}D multiplicand layouts.",
        )
        if offs is not None:  # to silence Mypy
            torch._check(
                offs.dim() == 1,
                lambda: f"Offsets tensor must be 1D, but got offs.dim()={offs.dim()}.",
            )
            torch._check(
                offs.dtype == torch.int32,
                lambda: f"Offsets tensor must be integer (int32) tensor, but got {offs.dtype}.",
            )
    else:
        torch._check(
            offs is None,
            lambda: "Offsets tensor provided, but is not needed for 3D/3D multiplicand layouts.",
        )
    torch._check(
        bias is None,
        lambda: "Bias tensor provided, but it is not supported yet.",
    )
    torch._check(
        out_dtype is None or out_dtype == torch.bfloat16,
        lambda: "If output dtype provided, it must be torch.bfloat16.",
    )

    return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype)
  7272. @register_meta(aten._grouped_mm)
  7273. @out_wrapper()
  7274. def meta_grouped_mm(
  7275. mat_a: Tensor,
  7276. mat_b: Tensor,
  7277. offs: Tensor | None = None,
  7278. bias: Tensor | None = None,
  7279. out_dtype: torch.dtype | None = None,
  7280. ) -> Tensor:
  7281. return _meta_grouped_mm_common(
  7282. mat_a,
  7283. mat_b,
  7284. scale_a=None,
  7285. scale_b=None,
  7286. offs=offs,
  7287. bias=bias,
  7288. scale_result=None,
  7289. out_dtype=out_dtype,
  7290. )
  7291. @register_meta([aten._scaled_grouped_mm])
  7292. def meta_scaled_grouped_mm(
  7293. mat_a: torch.Tensor,
  7294. mat_b: torch.Tensor,
  7295. scale_a: torch.Tensor,
  7296. scale_b: torch.Tensor,
  7297. offs: torch.Tensor | None = None,
  7298. bias: torch.Tensor | None = None,
  7299. scale_result: torch.Tensor | None = None,
  7300. out_dtype: torch.dtype | None = None,
  7301. use_fast_accum: bool = False,
  7302. ):
  7303. # matching _scaled_grouped_mm_cuda Blas.cpp implementation
  7304. out_dtype = out_dtype or torch.bfloat16
  7305. return _meta_grouped_mm_common(
  7306. mat_a,
  7307. mat_b,
  7308. scale_a=scale_a,
  7309. scale_b=scale_b,
  7310. offs=offs,
  7311. bias=bias,
  7312. scale_result=scale_result,
  7313. out_dtype=out_dtype,
  7314. use_fast_accum=use_fast_accum,
  7315. )
  7316. @register_meta(aten._foreach_norm.Scalar)
  7317. def meta_foreach_norm(tensors, ord=2, dtype=None):
  7318. if float(ord) == float("inf"):
  7319. for t in tensors:
  7320. torch._check(
  7321. t.numel() > 0,
  7322. lambda: "_foreach_norm cannot compute infinity norm on empty tensor",
  7323. )
  7324. results = []
  7325. for t in tensors:
  7326. out_dtype = dtype if dtype is not None else t.dtype
  7327. if out_dtype.is_complex:
  7328. out_dtype = corresponding_real_dtype(out_dtype)
  7329. results.append(t.new_empty((), dtype=out_dtype))
  7330. return results
  7331. @register_meta(aten._softmax)
  7332. @out_wrapper()
  7333. def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
  7334. if half_to_float:
  7335. if x.dtype not in [torch.half, torch.bfloat16]:
  7336. raise AssertionError(
  7337. f"half_to_float is True but x.dtype is {x.dtype}, expected half or bfloat16"
  7338. )
  7339. computation_dtype, result_dtype = utils.elementwise_dtypes(
  7340. x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7341. )
  7342. result_dtype = result_dtype if not half_to_float else computation_dtype
  7343. res = torch.empty_like(x, dtype=result_dtype, memory_format=torch.contiguous_format)
  7344. return res
  7345. @register_meta(aten.constant_pad_nd)
  7346. @out_wrapper()
  7347. def _constant_pad_nd_meta(input, pad, value=0):
  7348. # same checks as decomposition in torch/_refs/__init__.py:constant_pad_nd()
  7349. torch._check(
  7350. len(pad) % 2 == 0,
  7351. lambda: f"Length of pad must be even but instead it equals {len(pad)}",
  7352. )
  7353. input_sizes = input.shape
  7354. l_inp = len(input_sizes)
  7355. l_pad = len(pad) // 2
  7356. l_diff = l_inp - l_pad
  7357. torch._check(
  7358. l_inp >= l_pad,
  7359. lambda: "Length of pad should be no more than twice the number of "
  7360. f"dimensions of the input. Pad length is {len(pad)} while the input has "
  7361. f"{l_inp} dimensions.",
  7362. )
  7363. if all(isinstance(p, utils.IntWithoutSymInt) and p <= 0 for p in pad):
  7364. c_input = input
  7365. for i in range(l_diff, l_inp):
  7366. pad_idx = 2 * (l_inp - i - 1)
  7367. if pad[pad_idx] < 0:
  7368. c_input = c_input.narrow(
  7369. i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]
  7370. )
  7371. if pad[pad_idx + 1] < 0:
  7372. c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
  7373. return c_input.clone()
  7374. new_shape = list(input_sizes[:l_diff])
  7375. for i in range(l_pad):
  7376. pad_idx = len(pad) - ((i + 1) * 2)
  7377. new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
  7378. torch._check(
  7379. new_dim >= 0,
  7380. lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
  7381. f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
  7382. f"which is invalid. Check dimension {l_diff + i} of your input.",
  7383. )
  7384. new_shape.append(new_dim)
  7385. return torch.empty(
  7386. new_shape,
  7387. dtype=input.dtype,
  7388. device=input.device,
  7389. requires_grad=input.requires_grad,
  7390. memory_format=suggest_memory_format(input),
  7391. )
  7392. @register_meta(aten.embedding)
  7393. @out_wrapper()
  7394. def embedding(
  7395. weight: Tensor,
  7396. indices: Tensor,
  7397. padding_idx: int = -1,
  7398. scale_grad_by_freq: bool = False,
  7399. sparse: bool = False,
  7400. ) -> Tensor:
  7401. if weight.dim() != 2:
  7402. raise AssertionError(f"'weight' must be 2-D, got {weight.dim()}-D")
  7403. weight_shape = weight.shape
  7404. indices_shape = indices.shape
  7405. if indices.ndim == 0:
  7406. out_shape: tuple[int, ...] = (weight_shape[1],)
  7407. elif indices.ndim == 1:
  7408. out_shape = (indices_shape[0], weight_shape[1])
  7409. else:
  7410. out_shape = (*indices_shape, weight_shape[1])
  7411. out_dtype = weight.dtype
  7412. return weight.new_empty(out_shape, dtype=out_dtype)
  7413. @register_meta(aten._jagged_to_padded_dense_forward.default)
  7414. def meta__jagged_to_padded_dense_forward(
  7415. values: Tensor,
  7416. offsets: list[Tensor],
  7417. max_lengths: list[int],
  7418. padding_value: float = 0.0,
  7419. ):
  7420. # only one jagged dim is supported for now
  7421. if len(offsets) != 1:
  7422. raise AssertionError(
  7423. f"Only one jagged dim is supported, got {len(offsets)} offsets"
  7424. )
  7425. if len(max_lengths) != 1:
  7426. raise AssertionError(
  7427. f"Only one jagged dim is supported, got {len(max_lengths)} max_lengths"
  7428. )
  7429. B = offsets[0].shape[0] - 1
  7430. S = max_lengths[0]
  7431. output_shape = (B, S, *values.shape[1:])
  7432. return values.new_empty(output_shape)
  7433. def _create_unary_float_meta_func(func):
  7434. @register_meta(func)
  7435. @out_wrapper()
  7436. def _f(x):
  7437. return elementwise_meta(
  7438. x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  7439. )
  7440. return _f
  7441. # Implementation follows cuda implementation native_multi_head_attention_cuda
  7442. @register_meta(aten._native_multi_head_attention.default)
  7443. def native_multi_head_attention_fake(
  7444. query,
  7445. key,
  7446. value,
  7447. embed_dim,
  7448. num_head,
  7449. qkv_weight,
  7450. qkv_bias,
  7451. proj_weight,
  7452. proj_bias,
  7453. mask=None,
  7454. need_weights=True,
  7455. average_attn_weights=True,
  7456. mask_type=None,
  7457. ):
  7458. if query.is_nested or key.is_nested or value.is_nested:
  7459. raise NotImplementedError(
  7460. "_native_multi_head_attention fake implementation does not support nested tensors"
  7461. )
  7462. if query.numel() == 0:
  7463. return (query.new_empty(query.shape), query.new_empty(0))
  7464. B = query.size(0) # B: batch size
  7465. T = query.size(1) # T: target sequence length
  7466. # In native_multi_head_attention_cuda,
  7467. # we have proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query)
  7468. # , which does attn_ctx @ proj_weight.T + proj_bias
  7469. # so the last dim of output shape is proj_weight.size(0)
  7470. output_dim = proj_weight.size(0)
  7471. output = query.new_empty(B, T, output_dim)
  7472. if need_weights:
  7473. if average_attn_weights:
  7474. # When averaging attention weights, shape is [B, T, T] (averaged over heads)
  7475. # T = query seq len, S = key/value seq len
  7476. attn_weights = query.new_empty(B, T, T)
  7477. else:
  7478. # When not averaging, shape is [B, num_head, T, T]
  7479. # T = query seq len, S = key/value seq len
  7480. attn_weights = query.new_empty(B, num_head, T, T)
  7481. else:
  7482. attn_weights = query.new_empty(0)
  7483. return (output, attn_weights)
  7484. def _create_binary_float_meta_func(func):
  7485. @register_meta(func)
  7486. @out_wrapper()
  7487. def _f(x, y):
  7488. return elementwise_meta(
  7489. x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  7490. )
  7491. return _f
  7492. _create_unary_float_meta_func(aten.special_airy_ai)
  7493. _create_unary_float_meta_func(aten.special_bessel_y0)
  7494. _create_unary_float_meta_func(aten.special_bessel_y1)
  7495. _create_unary_float_meta_func(aten.special_modified_bessel_i0)
  7496. _create_unary_float_meta_func(aten.special_modified_bessel_i1)
  7497. _create_unary_float_meta_func(aten.special_modified_bessel_k0)
  7498. _create_unary_float_meta_func(aten.special_modified_bessel_k1)
  7499. _create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0)
  7500. _create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1)
  7501. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
  7502. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
  7503. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
  7504. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
  7505. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
  7506. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
  7507. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
  7508. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
  7509. _create_binary_float_meta_func(aten.special_hermite_polynomial_h)
  7510. _create_binary_float_meta_func(aten.special_hermite_polynomial_he)
  7511. _create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
  7512. _create_binary_float_meta_func(aten.special_legendre_polynomial_p)
  7513. def _register_inplace_meta(fn):
  7514. @wraps(fn)
  7515. def _fn(self, *args, **kwargs):
  7516. out = fn(self, *args, **kwargs)
  7517. check_inplace_broadcast(self.shape, out.shape)
  7518. return self
  7519. inplace_name = f"{fn.__name__}_"
  7520. _fn.__name__ = inplace_name
  7521. _fn = register_meta(getattr(aten, inplace_name))(_fn) # type: ignore[assignment]
  7522. return _fn
  7523. @register_meta(aten.lerp)
  7524. @out_wrapper()
  7525. def lerp(start, end, weight):
  7526. torch._check(
  7527. start.dtype == end.dtype,
  7528. lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
  7529. )
  7530. args = [start, end]
  7531. if isinstance(weight, TensorLike):
  7532. if weight.ndim != 0:
  7533. torch._check(
  7534. start.dtype == weight.dtype,
  7535. lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
  7536. )
  7537. args.append(weight)
  7538. return elementwise_meta(
  7539. *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7540. )
  7541. @register_meta(aten.addcmul)
  7542. @out_wrapper()
  7543. def addcmul(input, tensor1, tensor2, *, value=1):
  7544. return elementwise_meta(
  7545. input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7546. )
  7547. @register_meta(aten.addcdiv)
  7548. @out_wrapper()
  7549. def addcdiv(input, tensor1, tensor2, *, value=1):
  7550. torch._check(
  7551. not (
  7552. utils.is_integer_dtype(tensor1.dtype)
  7553. and utils.is_integer_dtype(tensor2.dtype)
  7554. ),
  7555. lambda: (
  7556. "Integer division with addcdiv is no longer supported, and in a future ",
  7557. "release addcdiv will perform a true division of tensor1 and tensor2. ",
  7558. "The historic addcdiv behavior can be implemented as ",
  7559. "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
  7560. "for integer inputs and as ",
  7561. "(input + value * tensor1 / tensor2) for float inputs. ",
  7562. "The future addcdiv behavior is just the latter implementation: ",
  7563. "(input + value * tensor1 / tensor2), for all dtypes.",
  7564. ),
  7565. )
  7566. return elementwise_meta(
  7567. input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7568. )
  7569. lerp_ = _register_inplace_meta(aten.lerp)
  7570. addcmul_ = _register_inplace_meta(aten.addcmul)
  7571. addcdiv_ = _register_inplace_meta(aten.addcdiv)
  7572. # We must also trigger meta registrations from PrimTorch ref
  7573. # decompositions
  7574. import torch._refs
  7575. import torch._refs.nn.functional
  7576. import torch._refs.special
  7577. def activate_meta():
  7578. activate_meta_table = {}
  7579. # For a given op, we pick the most specific decomp function from
  7580. # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
  7581. for typ in ["meta", "post_autograd", "pre_autograd"]:
  7582. registry = global_decomposition_table[typ]
  7583. for opo in registry:
  7584. if opo not in activate_meta_table:
  7585. activate_meta_table[opo] = registry[opo]
  7586. for op_overload, fn in activate_meta_table.items():
  7587. # Don't register meta for HigherOrderOp's decomp.
  7588. # We can reconsider this in the future, but in general,
  7589. # the way you do a meta for a HigherOrderOp is different from
  7590. # OpOverload.
  7591. if isinstance(op_overload, torch._ops.HigherOrderOperator):
  7592. continue
  7593. if not isinstance(op_overload, OpOverload):
  7594. raise AssertionError(
  7595. f"op_overload must be OpOverload, got {type(op_overload)}"
  7596. )
  7597. op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
  7598. if torch._C._dispatch_has_kernel_for_dispatch_key(
  7599. op_overload.name(), "CompositeImplicitAutograd"
  7600. ):
  7601. # Internally, we shouldn't be registering meta kernels for any operators that
  7602. # have CompositeImplicitAutograd kernels.
  7603. # Instead, we should be letting those decompositions run, and writing meta kernels
  7604. # only for the base operators.
  7605. if op_overload in global_decomposition_table["meta"]:
  7606. raise RuntimeError(
  7607. f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
  7608. "register meta function for it. Instead, we should let the decomposition run and write "
  7609. "meta kernels for the base operators."
  7610. )
  7611. elif op_overload.is_view:
  7612. # Attempting to register a python meta kernel for a view operator.
  7613. # We shouldn't do this, because the output will report as not having aliased storages.
  7614. # All view ops have meta kernels in C++ today, so we should use those instead.
  7615. pass
  7616. elif (
  7617. op_overload.name()
  7618. in {
  7619. "aten::empty_strided", # causing infinite recursion, test_meta.py
  7620. "aten::clone", # causing infinite recursion
  7621. "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950
  7622. "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950
  7623. "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950
  7624. "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950
  7625. "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950
  7626. }
  7627. ):
  7628. pass
  7629. else:
  7630. if "mkldnn::" in op_overload.name():
  7631. _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
  7632. elif "mkl::" in op_overload.name():
  7633. _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
  7634. elif "onednn::" in op_overload.name():
  7635. _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
  7636. elif "quantized::" in op_overload.name():
  7637. _meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
  7638. op_overload, fn
  7639. )
  7640. else:
  7641. _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
  7642. activate_meta()