decompositions.py 182 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
7527852795280528152825283528452855286528752885289529052915292529352945295529652975298529953005301530253035304530553065307530853095310531153125313531453155316531753185319532053215322532353245325532653275328532953305331533253335334533553365337533853395340534153425343534453455346534753485349535053515352535353545355535653575358535953605361536253635364536553665367536853695370537153725373537453755376537753785379538053815382538353845385538653875388538953905391539253935394539553965397539853995400540154025403540454055406540754085409541054115412541354145415541654175418541954205421542254235424542554265427542854295430543154325433543454355436543754385439544054415442544354445445544654475448544954505451545254535454545554565457545854595460546154625463546454655466546754685469547054715472547354745475547654775478547954805481548254835484548554865487548854895490
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import functools
  4. import itertools
  5. import numbers
  6. import operator
  7. import sys
  8. from collections.abc import Callable, Iterable
  9. from contextlib import nullcontext
  10. from enum import Enum
  11. from functools import partial, reduce
  12. from itertools import chain, product
  13. from typing import Any, cast, Optional, Union
  14. import torch
  15. import torch._meta_registrations
  16. import torch._prims as prims
  17. import torch._prims_common as utils
  18. import torch.nn.functional as F
  19. from torch import sym_float, sym_int, Tensor
  20. from torch._decomp import register_decomposition
  21. from torch._higher_order_ops.out_dtype import out_dtype
  22. from torch._prims_common import (
  23. IntLike,
  24. NumberType,
  25. suggest_memory_format,
  26. TensorLike,
  27. TensorSequenceType,
  28. )
  29. from torch._prims_common.wrappers import (
  30. _maybe_convert_to_dtype,
  31. _maybe_resize_out,
  32. _safe_copy_out,
  33. out_wrapper,
  34. )
  35. from torch.utils import _pytree as pytree
  36. from torch.utils._pytree import tree_map
# Module-level alias for the dispatch-key enum exposed by torch._C.
DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
# None of these functions are publicly accessible; get at them
# through torch._decomp (the package whose `register_decomposition`
# decorator registers every decomposition in this file).
__all__: list[str] = []
# Shorthand for the ATen operator namespace used throughout this file.
aten = torch._ops.ops.aten
class Reduction(Enum):
    """Loss-reduction modes.

    The loss decompositions in this file take the integer ``.value`` of a
    member (e.g. ``Reduction.MEAN.value``) rather than the member itself.
    """

    # No reduction: apply_loss_reduction returns the loss unchanged.
    NONE = 0
    # Reduce with torch.mean.
    MEAN = 1
    # Reduce with torch.sum.
    SUM = 2
  46. # This wraps a decomposition and performs various type promotion logic within it, depending on the strategy provided
  47. # We're currently reusing ELEMENTWISE_TYPE_PROMOTION_KIND, although some of the usages are on non-elementwise ops
  48. # Will need to validate the non-elementwise uses
  49. def type_casts(
  50. f: Callable,
  51. type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND,
  52. compute_dtype_only: bool = False,
  53. include_non_tensor_args: bool = False,
  54. ):
  55. @functools.wraps(f)
  56. def inner(*args, **kwargs):
  57. allowed_types = (
  58. (Tensor, torch.types._Number) if include_non_tensor_args else (Tensor,)
  59. ) # type: ignore[arg-type]
  60. flat_args = [
  61. x
  62. for x in pytree.arg_tree_leaves(*args, **kwargs)
  63. if isinstance(x, allowed_types)
  64. ]
  65. computation_dtype, result_dtype = utils.elementwise_dtypes(
  66. *flat_args, type_promotion_kind=type_promotion
  67. )
  68. # TODO: pretty sure this is not quite right
  69. def increase_prec(x):
  70. if isinstance(x, Tensor):
  71. return x.to(computation_dtype)
  72. else:
  73. return x
  74. def decrease_prec(x):
  75. if isinstance(x, Tensor):
  76. return x.to(result_dtype)
  77. else:
  78. return x
  79. r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
  80. if compute_dtype_only:
  81. return r
  82. else:
  83. return tree_map(decrease_prec, r)
  84. return inner
# DEFAULT promotion: tensor args are cast to the computation dtype, but the
# wrapped function's result is returned as computed (compute_dtype_only=True
# skips the cast back to the result dtype).
compute_only_pw_cast_for_opmath = partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    compute_dtype_only=True,
)
# DEFAULT promotion: tensor args are cast to the computation dtype and tensor
# outputs are cast back to the promoted result dtype.
pw_cast_for_opmath = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
# Same as pw_cast_for_opmath, except numeric (torch.types._Number) arguments
# also take part in computing the dtypes.
pw_cast_for_opmath_non_tensor_args = partial(
    type_casts,
    type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    include_non_tensor_args=True,
)
# INT_TO_FLOAT promotion. Presumably integer inputs compute and return in a
# floating dtype; confirm against ELEMENTWISE_TYPE_PROMOTION_KIND.
pw_cast_for_int_to_real = partial(
    type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
  101. # This expands x until x.dim() == dim. Might be useful as an operator
  102. def _unsqueeze_to_dim(x: Tensor, dim: int) -> Tensor:
  103. for _ in range(dim - x.dim()):
  104. x = x.unsqueeze(-1)
  105. return x
  106. @register_decomposition(aten.tanh_backward)
  107. @out_wrapper("grad_input")
  108. @pw_cast_for_opmath
  109. def tanh_backward(out_grad: Tensor, y: Tensor):
  110. return out_grad * (1 - y * y).conj_physical()
  111. @register_decomposition(aten.sigmoid_backward)
  112. @out_wrapper("grad_input")
  113. @pw_cast_for_opmath
  114. def sigmoid_backward(out_grad: Tensor, y: Tensor):
  115. return out_grad * (y * (1 - y)).conj_physical()
  116. @register_decomposition(aten.softplus_backward)
  117. @out_wrapper("grad_input")
  118. @pw_cast_for_opmath
  119. def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
  120. z = (x * beta).exp()
  121. return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
  122. @register_decomposition(aten.elu_backward)
  123. @out_wrapper("grad_input")
  124. @pw_cast_for_opmath
  125. def elu_backward(
  126. grad_output: Tensor,
  127. alpha: float,
  128. scale: float,
  129. input_scale: float,
  130. is_result: bool,
  131. self_or_result: Tensor,
  132. ):
  133. negcoef = alpha * scale
  134. poscoef = scale
  135. negiptcoef = input_scale
  136. if is_result:
  137. return torch.where(
  138. self_or_result <= 0,
  139. grad_output * negiptcoef * (self_or_result + negcoef),
  140. grad_output * poscoef,
  141. )
  142. else:
  143. return torch.where(
  144. self_or_result <= 0,
  145. grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
  146. grad_output * poscoef,
  147. )
  148. @register_decomposition([aten.fill.Scalar])
  149. def fill_scalar(self, value):
  150. return torch.full_like(self, value)
  151. @register_decomposition([aten.fill.Tensor])
  152. def fill_tensor(self, value: Tensor):
  153. torch._check(
  154. value.dim() == 0,
  155. lambda: f"fill only supports 0-dimension value tensor but got tensor with {value.dim()} dimensions",
  156. )
  157. return aten.copy(self, value)
  158. @register_decomposition(aten.hardsigmoid)
  159. @out_wrapper()
  160. @pw_cast_for_opmath
  161. def hardsigmoid(self: Tensor) -> Tensor:
  162. return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
  163. @register_decomposition(aten.hardsigmoid_backward)
  164. @out_wrapper("grad_input")
  165. @pw_cast_for_opmath
  166. def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
  167. return torch.where(
  168. (self > -3.0) & (self < 3.0),
  169. grad_output * (1.0 / 6.0),
  170. 0.0,
  171. )
  172. @register_decomposition(aten.hardtanh_backward)
  173. @out_wrapper("grad_input")
  174. def hardtanh_backward(
  175. grad_output: Tensor, self: Tensor, min_val: float, max_val: float
  176. ):
  177. return torch.where((self <= min_val) | (self >= max_val), 0.0, grad_output)
  178. @register_decomposition(aten.hardswish)
  179. @out_wrapper()
  180. @pw_cast_for_opmath
  181. def hardswish(self: Tensor) -> Tensor:
  182. return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
  183. @register_decomposition(aten.hardswish_backward)
  184. @out_wrapper()
  185. @pw_cast_for_opmath
  186. def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
  187. return torch.where(
  188. self <= -3,
  189. 0.0,
  190. torch.where(self < 3, grad_output * ((self / 3) + 0.5), grad_output),
  191. )
  192. @register_decomposition(aten.threshold_backward)
  193. @out_wrapper("grad_input")
  194. def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
  195. return torch.where(self <= threshold, 0, grad_output)
  196. @register_decomposition(aten.leaky_relu_backward)
  197. @out_wrapper("grad_input")
  198. @pw_cast_for_opmath
  199. def leaky_relu_backward(
  200. grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
  201. ):
  202. return torch.where(self > 0, grad_output, grad_output * negative_slope)
  203. @register_decomposition(aten.gelu_backward)
  204. @out_wrapper("grad_input")
  205. @pw_cast_for_opmath
  206. def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
  207. M_SQRT2 = 1.41421356237309504880
  208. M_SQRT1_2 = 0.70710678118654752440
  209. M_2_SQRTPI = 1.12837916709551257390
  210. if approximate == "tanh":
  211. kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
  212. kKappa = 0.044715
  213. x_sq = self * self
  214. x_cube = x_sq * self
  215. inner = kBeta * (self + kKappa * x_cube)
  216. tanh_inner = torch.tanh(inner)
  217. left = 0.5 * self
  218. right = 1 + tanh_inner
  219. left_derivative = 0.5 * right
  220. tanh_derivative = 1 - tanh_inner * tanh_inner
  221. inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
  222. right_derivative = left * tanh_derivative * inner_derivative
  223. return grad * (left_derivative + right_derivative)
  224. else:
  225. kAlpha = M_SQRT1_2
  226. kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
  227. cdf = 0.5 * (1 + torch.erf(self * kAlpha))
  228. pdf = kBeta * torch.exp(self * self * -0.5)
  229. return grad * (cdf + self * pdf)
  230. @register_decomposition(aten.mish_backward)
  231. @pw_cast_for_opmath
  232. def mish_backward(grad_output: Tensor, input: Tensor):
  233. input_tanh_softplus = torch.tanh(F.softplus(input))
  234. input_sigmoid = torch.sigmoid(input)
  235. out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
  236. return grad_output * (input_tanh_softplus + out)
  237. @register_decomposition(aten.silu)
  238. @out_wrapper()
  239. @pw_cast_for_opmath
  240. def silu(self: Tensor) -> Tensor:
  241. return self * torch.sigmoid(self)
  242. @register_decomposition(aten.silu_backward)
  243. @out_wrapper("grad_input")
  244. @pw_cast_for_opmath
  245. def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
  246. sigmoid = 1 / (1 + torch.exp(-self))
  247. return grad_output * sigmoid * (1 + self * (1 - sigmoid))
  248. @register_decomposition(aten._prelu_kernel)
  249. def _prelu_kernel(self: Tensor, weight: Tensor) -> Tensor:
  250. return torch.where(self > 0, self, weight * self)
  251. @register_decomposition(aten._prelu_kernel_backward)
  252. def _prelu_kernel_backward(
  253. grad_output: Tensor,
  254. self: Tensor,
  255. weight: Tensor,
  256. ) -> tuple[Tensor, Tensor]:
  257. input_grad = torch.where(self > 0, grad_output, weight * grad_output)
  258. weight_grad = torch.where(self > 0, 0.0, self * grad_output)
  259. return (input_grad, weight_grad)
  260. @register_decomposition(aten.rrelu_with_noise_backward)
  261. @out_wrapper()
  262. @pw_cast_for_opmath
  263. def rrelu_with_noise_backward(
  264. grad_output: Tensor,
  265. self: Tensor,
  266. noise: Tensor,
  267. lower: float,
  268. upper: float,
  269. training: bool,
  270. self_is_result: bool,
  271. ) -> Tensor:
  272. if training and upper - lower > 1e-6:
  273. return grad_output.mul(noise)
  274. else:
  275. negative_slope = (lower + upper) / 2
  276. return aten.leaky_relu_backward(
  277. grad_output, self, negative_slope, self_is_result
  278. )
  279. @register_decomposition(aten.log_sigmoid_backward)
  280. @out_wrapper("grad_input")
  281. @pw_cast_for_opmath
  282. def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
  283. in_negative = self < 0
  284. max_deriv = torch.where(in_negative, 1, 0)
  285. sign = torch.where(in_negative, 1, -1)
  286. z = torch.exp(-torch.abs(self))
  287. return grad_output * (max_deriv - sign * (z / (1 + z)))
  288. # CPU has a special formula that uses buffer, but disabled for convenience sake
  289. # return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
  290. @register_decomposition(aten.ldexp)
  291. @out_wrapper()
  292. def ldexp(self: Tensor, other: Tensor) -> Tensor:
  293. two_dtype = (
  294. torch.float32
  295. if utils.is_integer_dtype(self.dtype) or utils.is_boolean_dtype(self.dtype)
  296. else self.dtype
  297. )
  298. two_tensor = self.new_full((), 2.0, dtype=two_dtype)
  299. return self * torch.pow(two_tensor, other)
  300. def apply_loss_reduction(loss: Tensor, reduction: int):
  301. if reduction == Reduction.MEAN.value:
  302. return torch.mean(loss)
  303. elif reduction == Reduction.SUM.value:
  304. return torch.sum(loss)
  305. else:
  306. return loss
  307. def to_real_dtype(dtype: torch.dtype):
  308. if dtype == torch.complex32:
  309. return torch.float16
  310. elif dtype == torch.complex64:
  311. return torch.float32
  312. elif dtype == torch.complex128:
  313. return torch.float64
  314. # TODO: None of these loss castings are quite correct, see
  315. # https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
  316. # perform the pointwise portion in opmath, but don't maintain it between the
  317. # pointwise portion and the reduction
  318. @register_decomposition(aten.mse_loss)
  319. @out_wrapper()
  320. @pw_cast_for_opmath
  321. def mse_loss(
  322. self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
  323. ) -> Tensor:
  324. loss = (self - target) ** 2
  325. return apply_loss_reduction(loss, reduction)
  326. @register_decomposition(aten.mse_loss_backward)
  327. @out_wrapper("grad_input")
  328. @pw_cast_for_opmath
  329. def mse_loss_backward(
  330. grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
  331. ):
  332. norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
  333. return norm * (input - target) * grad_output
  334. @register_decomposition(aten._safe_softmax)
  335. def safe_softmax(self, dim, dtype=None):
  336. out = torch.softmax(self, dim=dim, dtype=dtype)
  337. masked = self.eq(float("-inf"))
  338. masked_rows = torch.all(masked, dim=dim, keepdim=True)
  339. zeros = torch.zeros_like(out)
  340. return torch.where(masked_rows, zeros, out)
  341. @register_decomposition(aten.smooth_l1_loss)
  342. @out_wrapper()
  343. @pw_cast_for_opmath
  344. def smooth_l1_loss(
  345. self: Tensor,
  346. target: Tensor,
  347. reduction: int = Reduction.MEAN.value,
  348. beta: float = 1.0,
  349. ):
  350. loss = (self - target).abs()
  351. loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
  352. return apply_loss_reduction(loss, reduction)
  353. @register_decomposition(aten.smooth_l1_loss_backward.default)
  354. @pw_cast_for_opmath
  355. def smooth_l1_loss_backward(
  356. grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, beta: float
  357. ):
  358. norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
  359. x = self - target
  360. abs_x = torch.abs(x)
  361. norm_grad = norm * grad_output
  362. return torch.where(
  363. abs_x < beta,
  364. norm_grad * x / beta,
  365. norm_grad * torch.sign(x),
  366. )
  367. @register_decomposition(aten.smooth_l1_loss_backward.grad_input)
  368. @pw_cast_for_opmath
  369. def smooth_l1_loss_backward_out(
  370. grad_output: Tensor,
  371. self: Tensor,
  372. target: Tensor,
  373. reduction: int,
  374. beta: float,
  375. grad_input: Tensor,
  376. ):
  377. result = smooth_l1_loss_backward(grad_output, self, target, reduction, beta)
  378. _maybe_resize_out(grad_input, result.shape)
  379. return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)
  380. @register_decomposition(aten.huber_loss_backward.default)
  381. @pw_cast_for_opmath
  382. def huber_loss_backward(
  383. grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
  384. ):
  385. norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
  386. x = self - target
  387. return torch.where(
  388. x < -delta,
  389. -norm * grad_output * delta,
  390. torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
  391. )
  392. # We cannot use @out_wrapper() here, because the output tensor is not named 'out', it's 'grad_input'
  393. @register_decomposition(aten.huber_loss_backward.out)
  394. @pw_cast_for_opmath
  395. def huber_loss_backward_out(
  396. grad_output: Tensor,
  397. self: Tensor,
  398. target: Tensor,
  399. reduction: int,
  400. delta: float,
  401. grad_input: Tensor,
  402. ):
  403. result = huber_loss_backward(grad_output, self, target, reduction, delta)
  404. _maybe_resize_out(grad_input, result.shape)
  405. return _safe_copy_out(copy_from=result, copy_to=grad_input, exact_dtype=True)
  406. def _nll_loss_backward(
  407. grad_output: Tensor,
  408. self: Tensor,
  409. target: Tensor,
  410. weight: Optional[Tensor],
  411. reduction: int,
  412. ignore_index: int,
  413. total_weight: Tensor,
  414. ) -> Tensor:
  415. channel_dim = 0 if self.dim() < 2 else 1
  416. if reduction == Reduction.MEAN.value:
  417. grad_output = grad_output / total_weight
  418. target = target.unsqueeze(channel_dim)
  419. safe_target = torch.where(target != ignore_index, target, 0)
  420. grad_input = torch.zeros_like(self)
  421. grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
  422. if grad_input.dim() > grad_output.dim() > 0:
  423. grad_output = grad_output.unsqueeze(channel_dim)
  424. if weight is not None:
  425. new_shape = [1 for _ in range(self.dim())]
  426. new_shape[channel_dim] = weight.shape[0]
  427. weight = weight.reshape(new_shape)
  428. grad_output = grad_output * weight
  429. grad_output = torch.where(target != ignore_index, grad_output, 0)
  430. return grad_input * grad_output
  431. @register_decomposition(aten.glu_backward)
  432. @out_wrapper("grad_input")
  433. @pw_cast_for_opmath
  434. def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
  435. if self.dim() <= 0:
  436. raise AssertionError("glu does not support 0-dimensional tensors")
  437. wrap_dim = utils.canonicalize_dim(self.dim(), dim)
  438. nIn = self.size(wrap_dim)
  439. if nIn % 2 != 0:
  440. raise AssertionError(
  441. f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
  442. )
  443. inputSize = nIn // 2
  444. firstHalf = self.narrow(wrap_dim, 0, inputSize)
  445. secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
  446. gradInputFirstHalf = torch.sigmoid(secondHalf)
  447. gradInputSecondHalf = (
  448. (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
  449. )
  450. gradInputFirstHalf = gradInputFirstHalf * grad_output
  451. return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)
  452. @register_decomposition(aten.nll_loss_backward)
  453. @out_wrapper("grad_input")
  454. def nll_loss_backward(
  455. grad_output: Tensor,
  456. self: Tensor,
  457. target: Tensor,
  458. weight: Optional[Tensor],
  459. reduction: int,
  460. ignore_index: int,
  461. total_weight: Tensor,
  462. ) -> Tensor:
  463. if not (0 <= self.dim() <= 2):
  464. raise AssertionError(f"input tensor should be 1D or 2D, got {self.dim()}D")
  465. if target.dim() > 1:
  466. raise AssertionError(
  467. f"0D or 1D target tensor expected, multi-target not supported, got {target.dim()}D"
  468. )
  469. no_batch_dim = self.dim() == 1 and target.dim() == 0
  470. if not (no_batch_dim or (self.shape[0] == target.shape[0])):
  471. raise AssertionError(
  472. f"size mismatch (got input: {self.shape}, target: {target.shape})"
  473. )
  474. if total_weight.numel() != 1:
  475. raise AssertionError(
  476. f"expected total_weight to be a single element tensor, got: "
  477. f"{total_weight.shape} ({total_weight.numel()} elements)"
  478. )
  479. if weight is not None and weight.numel() != self.shape[-1]:
  480. raise AssertionError(
  481. "weight tensor should be defined either for all or no classes"
  482. )
  483. if reduction == Reduction.NONE.value and self.dim() == 2:
  484. if not (grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0]):
  485. raise AssertionError(
  486. f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
  487. f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
  488. )
  489. else:
  490. if not (grad_output.dim() <= 1 and grad_output.numel() == 1):
  491. raise AssertionError(
  492. f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
  493. )
  494. return _nll_loss_backward(
  495. grad_output, self, target, weight, reduction, ignore_index, total_weight
  496. )
  497. @register_decomposition(aten.nll_loss2d_backward)
  498. @out_wrapper("grad_input")
  499. def nll_loss2d_backward(
  500. grad_output: Tensor,
  501. self: Tensor,
  502. target: Tensor,
  503. weight: Optional[Tensor],
  504. reduction: int,
  505. ignore_index: int,
  506. total_weight: Tensor,
  507. ) -> Tensor:
  508. if self.dim() != 4:
  509. raise AssertionError(
  510. f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
  511. )
  512. if target.dim() != 3:
  513. raise AssertionError(
  514. f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
  515. )
  516. if not (
  517. self.shape[0] == target.shape[0]
  518. and self.shape[2] == target.shape[1]
  519. and self.shape[3] == target.shape[2]
  520. ):
  521. raise AssertionError(
  522. f"size mismatch (got input: {self.shape}, target: {target.shape}"
  523. )
  524. if total_weight.numel() != 1:
  525. raise AssertionError(
  526. f"expected total_weight to be a single element tensor, "
  527. f"got: {total_weight.shape} ( {total_weight.numel()}, elements)"
  528. )
  529. return _nll_loss_backward(
  530. grad_output, self, target, weight, reduction, ignore_index, total_weight
  531. )
  532. @register_decomposition(aten.binary_cross_entropy)
  533. @out_wrapper()
  534. @pw_cast_for_opmath
  535. def binary_cross_entropy(
  536. self: Tensor,
  537. target: Tensor,
  538. weight: Optional[Tensor] = None,
  539. reduction: int = Reduction.MEAN.value,
  540. ) -> Tensor:
  541. # We cannot currently model this without introducing data-dependent control flow
  542. # TORCH_CHECK(
  543. # (input_val >= 0) && (input_val <= 1),
  544. # "all elements of input should be between 0 and 1"
  545. # )
  546. loss = (target - 1) * torch.maximum(
  547. torch.log1p(-self), self.new_full((), -100)
  548. ) - target * torch.maximum(torch.log(self), self.new_full((), -100))
  549. if weight is not None:
  550. loss = loss * weight
  551. return apply_loss_reduction(loss, reduction)
  552. @register_decomposition(aten.binary_cross_entropy_backward)
  553. @out_wrapper("grad_input")
  554. @pw_cast_for_opmath
  555. def binary_cross_entropy_backward(
  556. grad_output: Tensor,
  557. self: Tensor,
  558. target: Tensor,
  559. weight: Optional[Tensor] = None,
  560. reduction: int = Reduction.MEAN.value,
  561. ) -> Tensor:
  562. EPSILON = 1e-12
  563. result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
  564. if weight is not None:
  565. result = result * weight
  566. if reduction == Reduction.MEAN.value:
  567. result = result / self.numel()
  568. return result
  569. @register_decomposition(aten.soft_margin_loss)
  570. @out_wrapper()
  571. @pw_cast_for_opmath
  572. def soft_margin_loss(
  573. input: Tensor,
  574. target: Tensor,
  575. reduction: int = Reduction.MEAN.value,
  576. ) -> Tensor:
  577. loss = torch.log1p(torch.exp(-input * target))
  578. return apply_loss_reduction(loss, reduction)
  579. @register_decomposition(aten.soft_margin_loss_backward)
  580. @out_wrapper("grad_input")
  581. @pw_cast_for_opmath
  582. def soft_margin_loss_backward(
  583. grad_output: Tensor,
  584. self: Tensor,
  585. target: Tensor,
  586. reduction: int = Reduction.MEAN.value,
  587. ) -> Tensor:
  588. grad_input = target * grad_output * (torch.sigmoid(target * self) - 1)
  589. if reduction == Reduction.MEAN.value:
  590. grad_input = grad_input / self.numel()
  591. return grad_input
  592. @register_decomposition(aten.dist)
  593. @out_wrapper()
  594. def dist(input: Tensor, other: Tensor, p: float = 2):
  595. return aten.norm(input - other, p=p)
  596. @register_decomposition(aten._euclidean_dist)
  597. @out_wrapper()
  598. def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
  599. x1_norm = x1.pow(2).sum(-1, True)
  600. x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
  601. x2_norm = x2.pow(2).sum(-1, True)
  602. x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
  603. x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
  604. x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
  605. result = x1_.matmul(x2_.mT)
  606. return result.clamp_min(0).sqrt()
  607. @register_decomposition(aten.slice_backward)
  608. @out_wrapper()
  609. def slice_backward(
  610. grad_output: Tensor,
  611. input_sizes: list[int],
  612. dim: int,
  613. start: int,
  614. end: int,
  615. step: int,
  616. ):
  617. grad_input = grad_output.new_zeros(input_sizes)
  618. return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)
  619. @register_decomposition(aten.slice.Tensor)
  620. def slice_forward(
  621. # Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1
  622. self: Tensor,
  623. dim: int = 0,
  624. start: Optional[int] = None,
  625. end: Optional[int] = None,
  626. step: int = 1,
  627. ):
  628. from torch.fx.experimental.symbolic_shapes import statically_known_true
  629. ndim = self.dim()
  630. if ndim == 0:
  631. raise RuntimeError("slice() cannot be applied to a 0-dim tensor.")
  632. dim = utils.canonicalize_dim(self.dim(), dim)
  633. sizes = list(self.size())
  634. strides = list(self.stride())
  635. if step <= 0:
  636. raise RuntimeError("slice step must be positive")
  637. start_val = start if start is not None else 0
  638. end_val = end if end is not None else sys.maxsize # 2^63 - 1
  639. if start_val < 0:
  640. start_val += sizes[dim]
  641. if end_val < 0:
  642. end_val += sizes[dim]
  643. if start_val < 0:
  644. start_val = 0
  645. elif start_val > sizes[dim]:
  646. start_val = sizes[dim]
  647. if statically_known_true(end_val == sys.maxsize):
  648. end_val = sizes[dim]
  649. elif end_val < start_val:
  650. end_val = start_val
  651. elif end_val > sizes[dim]:
  652. end_val = sizes[dim]
  653. storage_offset = self.storage_offset() + start_val * strides[dim]
  654. len = end_val - start_val
  655. sizes[dim] = (len + step - 1) // step
  656. strides[dim] *= step
  657. if self.is_quantized:
  658. raise NotImplementedError(
  659. "Slice decomposition for quantized tensors aren't implemented"
  660. )
  661. else:
  662. return self.as_strided(sizes, strides, storage_offset)
  663. def _normalize_start_end(
  664. x: Tensor, dim: int, start: Optional[int], end: Optional[int]
  665. ) -> tuple[int, int]:
  666. """
  667. Normalize start and end such that both are in the range
  668. [0, x.get_size()[dim]] and start <= end.
  669. """
  670. dim_size = x.shape[dim]
  671. def clamp_wrap(val, lower, upper, default) -> int:
  672. if val is None:
  673. return default
  674. if val < 0:
  675. val = val + dim_size
  676. return min(max(val, lower), upper)
  677. start = clamp_wrap(start, 0, dim_size, 0)
  678. end = clamp_wrap(end, start, dim_size, dim_size)
  679. return start, end
# This is not in torch._refs because aten.index used by
# aten._unsafe_masked_index does not have a decomposition.
@register_decomposition(aten.slice_scatter)
@out_wrapper()
def slice_scatter(
    input: Tensor,
    src: Tensor,
    dim: int = 0,
    start: Optional[int] = None,
    end: Optional[int] = None,
    step: int = 1,
):
    """Return a copy of ``input`` whose ``[start:end:step]`` slice along ``dim`` is ``src``.

    ``src`` is first expanded to the slice's shape. Each position ``i`` along
    ``dim`` that belongs to the slice takes ``src[(i - start) // step]``; every
    other position keeps ``input``. The selection is done with a boolean mask
    and ``aten.where``.

    NOTE(review): ``step`` is not validated here; ``step <= 0`` is not rejected.
    """
    dim = utils.canonicalize_dim(input.ndim, dim)
    dim_size = input.shape[dim]
    start, end = _normalize_start_end(input, dim, start, end)

    # Shape of the slice being written: same as input except along dim,
    # where it is ceil((end - start) / step).
    src_size = list(input.shape)
    src_size[dim] = (end - start + (step - 1)) // step
    src = src.expand(src_size)

    # Full-range, unit-step slice: the result is just (a copy of) src.
    if start == 0 and end == dim_size and step == 1:
        return src.clone()

    # Index only along dim (None leaves the other dims untouched); position i
    # reads src[(i - start) // step].
    indices: list[Optional[Tensor]] = [None] * input.dim()
    idx = torch.arange(dim_size, device=input.device)
    indices[dim] = (idx - start) // step

    # mask[i] is True iff position i along dim belongs to the slice.
    mask = torch.ones(dim_size, device=input.device, dtype=torch.bool)
    if start != 0:
        mask = torch.logical_and(mask, idx >= start)

    if end != dim_size:
        mask = torch.logical_and(mask, idx < end)

    if step != 1:
        mask = torch.logical_and(mask, (idx - start) % step == 0)

    # Reshape the 1-D mask so it broadcasts along dim only.
    mask_shape = [1] * input.dim()
    mask_shape[dim] = -1
    mask = mask.view(mask_shape)
    # NOTE(review): presumably _unsafe_masked_index only reads src where mask is
    # True, so out-of-range indices at masked-off positions are harmless — confirm.
    return aten.where(mask, aten._unsafe_masked_index(src, mask, indices, 0), input)
  714. @register_decomposition(aten.select_backward)
  715. @out_wrapper()
  716. def select_backward(grad_output: Tensor, input_sizes: list[int], dim: int, index: int):
  717. grad_input = grad_output.new_zeros(input_sizes)
  718. return torch.select_scatter(grad_input, grad_output, dim, index)
  719. @register_decomposition(aten.diagonal_backward)
  720. @out_wrapper()
  721. def diagonal_backward(
  722. grad_output: Tensor, input_sizes: list[int], offset: int, dim1: int, dim2: int
  723. ):
  724. grad_input = grad_output.new_zeros(input_sizes)
  725. return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
  726. def _cast_grad_to_input_dtype(
  727. grad_output: Tensor, grad_input: Tensor, input_dtype: torch.dtype
  728. ):
  729. if grad_output.dtype != input_dtype:
  730. grad_input = grad_input.to(input_dtype)
  731. return grad_input
  732. @register_decomposition(aten._softmax_backward_data)
  733. @out_wrapper("grad_input")
  734. @compute_only_pw_cast_for_opmath
  735. def _softmax_backward_data(
  736. grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
  737. ):
  738. new_grad_output = grad_output * output
  739. grad_input = new_grad_output - output * torch.sum(
  740. new_grad_output, dim=dim, keepdim=True
  741. )
  742. # CPU kernel doesn't respect input_dtype, but following check doesn't work for meta tensor
  743. # if grad_output.device == torch.device("cpu"):
  744. # return grad_input.contiguous()
  745. return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype).contiguous()
  746. @register_decomposition(aten._log_softmax_backward_data)
  747. @out_wrapper()
  748. @compute_only_pw_cast_for_opmath
  749. def _log_softmax_backward_data(
  750. grad_output: Tensor, output: Tensor, dim: int, input_dtype: torch.dtype
  751. ):
  752. grad_input = grad_output - torch.exp(output) * torch.sum(
  753. grad_output, dim=dim, keepdim=True
  754. )
  755. return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)
  756. def _im2col_col2im_indices_along_dim(
  757. input_d, kernel_d, dilation_d, padding_d, stride_d, device
  758. ):
  759. """Utility function to implement im2col and col2im"""
  760. blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)
  761. arange_kw = partial(torch.arange, dtype=torch.int64, device=device)
  762. # Stride kernel over input and find starting indices along dim d
  763. blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)
  764. # Apply dilation on kernel and find its indices along dim d
  765. kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)
  766. # Broadcast and add kernel starting positions (indices) with
  767. # kernel_grid along dim d, to get block indices along dim d
  768. return blocks_d_indices + kernel_grid
  769. @register_decomposition(aten.im2col)
  770. @out_wrapper()
  771. def im2col(
  772. input: Tensor,
  773. kernel_size: list[int],
  774. dilation: list[int],
  775. padding: list[int],
  776. stride: list[int],
  777. ) -> Tensor:
  778. torch._check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
  779. torch._check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
  780. torch._check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
  781. torch._check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
  782. def check_positive(param, param_name, strict=True):
  783. cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
  784. torch._check(
  785. cond, lambda: f"{param_name} should be greater than zero, but got {param}"
  786. )
  787. check_positive(kernel_size, "kernel_size")
  788. check_positive(dilation, "dilation")
  789. check_positive(dilation, "padding", strict=False)
  790. check_positive(stride, "stride")
  791. shape = input.shape
  792. ndim = len(shape)
  793. torch._check(
  794. ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
  795. lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
  796. f"and non-zero dimensions, but got: {tuple(shape)}",
  797. )
  798. output_size = tuple(
  799. 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
  800. for out, pad, dil, ker, st in zip(
  801. shape[-2:], padding, dilation, kernel_size, stride
  802. )
  803. )
  804. torch._check(
  805. all(c > 0 for c in output_size),
  806. lambda: f"Given an input with spatial size {tuple(shape[-2:])}, "
  807. f"kernel_size={kernel_size}, dilation={dilation}, "
  808. f"padding={padding}, stride={stride}, "
  809. "the calculated shape of the array of sliding blocks "
  810. f"is {output_size}, but its components must be at least one.",
  811. )
  812. batched_input = ndim == 4
  813. if not batched_input:
  814. input = input.unsqueeze(0)
  815. batch_dim, channel_dim, input_h, input_w = input.shape
  816. stride_h, stride_w = stride
  817. padding_h, padding_w = padding
  818. dilation_h, dilation_w = dilation
  819. kernel_h, kernel_w = kernel_size
  820. blocks_row_indices = _im2col_col2im_indices_along_dim(
  821. input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
  822. )
  823. blocks_col_indices = _im2col_col2im_indices_along_dim(
  824. input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
  825. )
  826. # Note that F.pad takes (padding_left, padding_right, padding_top, padding_bottom)
  827. # ugh
  828. padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))
  829. blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
  830. output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
  831. output = output.permute(0, 1, 2, 4, 3, 5)
  832. num_blocks_row = blocks_row_indices.size(1)
  833. num_blocks_col = blocks_col_indices.size(1)
  834. output = output.reshape(
  835. batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
  836. )
  837. if not batched_input:
  838. output = output.squeeze(0)
  839. return output
  840. @register_decomposition(aten.col2im)
  841. @out_wrapper()
  842. @pw_cast_for_opmath
  843. def col2im(
  844. input: Tensor,
  845. output_size: list[int],
  846. kernel_size: list[int],
  847. dilation: list[int],
  848. padding: list[int],
  849. stride: list[int],
  850. ) -> Tensor:
  851. torch._check(len(output_size) == 2, lambda: "only 2D output_size supported")
  852. torch._check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
  853. torch._check(len(dilation) == 2, lambda: "only 2D dilation supported")
  854. torch._check(len(padding) == 2, lambda: "only 2D padding supported")
  855. torch._check(len(stride) == 2, lambda: "only 2D stride supported")
  856. def check_positive(param, param_name, strict=True):
  857. cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
  858. torch._check(
  859. cond, lambda: f"{param_name} should be greater than zero, but got {param}"
  860. )
  861. check_positive(kernel_size, "kernel_size")
  862. check_positive(dilation, "dilation")
  863. check_positive(padding, "padding", strict=False)
  864. check_positive(stride, "stride")
  865. check_positive(output_size, "output_size")
  866. shape = input.shape
  867. ndim = len(shape)
  868. torch._check(
  869. ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
  870. lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
  871. f"and non-zero dimensions, but got: {tuple(shape)}",
  872. )
  873. prod_kernel_size = kernel_size[0] * kernel_size[1]
  874. torch._check(
  875. shape[-2] % prod_kernel_size == 0,
  876. lambda: "Expected size of input's first non-batch dimension to be divisible by the "
  877. f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
  878. f"kernel_size={kernel_size}",
  879. )
  880. col = [
  881. 1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
  882. for out, pad, dil, ker, st in zip(
  883. output_size, padding, dilation, kernel_size, stride
  884. )
  885. ]
  886. L = col[0] * col[1]
  887. torch._check(
  888. shape[-1] == L,
  889. lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
  890. f"dilation={dilation}, padding={padding}, stride={stride}, "
  891. f"expected input.size(-1) to be {L} but got {shape[-1]}.",
  892. )
  893. torch._check(
  894. L > 0,
  895. lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
  896. f"dilation={dilation}, padding={padding}, stride={stride}, "
  897. f"expected input.size(-1) to be {L} but got {shape[-1]}.",
  898. )
  899. batched_input = ndim == 3
  900. if not batched_input:
  901. input = input.unsqueeze(0)
  902. shape = input.shape
  903. out_h, out_w = output_size
  904. stride_h, stride_w = stride
  905. padding_h, padding_w = padding
  906. dilation_h, dilation_w = dilation
  907. kernel_h, kernel_w = kernel_size
  908. # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
  909. input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
  910. input = input.permute(0, 1, 2, 4, 3, 5)
  911. indices_row = _im2col_col2im_indices_along_dim(
  912. out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
  913. )
  914. indices_row = _unsqueeze_to_dim(indices_row, 4)
  915. indices_col = _im2col_col2im_indices_along_dim(
  916. out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
  917. )
  918. output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
  919. output = input.new_zeros(
  920. [shape[0], shape[1] // prod(kernel_size)] + output_padded_size
  921. )
  922. idx = (None, None, indices_row, indices_col)
  923. output = aten._unsafe_index_put(output, idx, input, accumulate=True)
  924. output = F.pad(output, (-padding_w, -padding_w, -padding_h, -padding_h))
  925. if not batched_input:
  926. output = output.squeeze(0)
  927. return output
  928. @register_decomposition(aten.native_dropout_backward)
  929. @out_wrapper()
  930. def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
  931. # According to the CUDA kernel implementation we should have this test;
  932. # but it seems to fail tests!
  933. # torch._check(mask.dtype == torch.bool, lambda: f"Mask should be Bool Scalar Type {mask.dtype}")
  934. # Mimicking CUDA kernel's behavior for output stride: output follow input's memory format
  935. # This different from TensorIterator's behavior
  936. r = (grad_output * (mask.type_as(grad_output) * scale)).clone(
  937. memory_format=utils.suggest_memory_format(grad_output)
  938. )
  939. return r
  940. @register_decomposition(aten.unfold_backward)
  941. @out_wrapper()
  942. def unfold_backward(
  943. grad: Tensor, input_size: list[int], dimension: int, size: int, step: int
  944. ) -> Tensor:
  945. if len(input_size) == 0:
  946. return torch.squeeze_copy(grad, 0)
  947. dim = utils.canonicalize_dim(len(input_size), dimension)
  948. idx = torch.arange(input_size[dim], device=grad.device, dtype=torch.int32)
  949. idx = idx.unfold(0, size, step).flatten()
  950. grad = grad.movedim(-1, dim + 1).flatten(dim, dim + 1)
  951. # nb. At the moment this generates two kernels in triton
  952. # It could potentially be fused into one call to scatter_reduce,
  953. # in the case step <= size provided scatter_reduce generates 1 kernel
  954. grad_input = grad.new_zeros(input_size)
  955. index = (None,) * dim + (idx,)
  956. return aten._unsafe_index_put(grad_input, index, grad, accumulate=True).contiguous()
  957. @register_decomposition(aten.logit_backward.default)
  958. @pw_cast_for_opmath
  959. def logit_backward(
  960. grad_output: Tensor, self: Tensor, eps: Optional[float] = None
  961. ) -> Tensor:
  962. if eps is not None:
  963. lo = eps
  964. hi = 1.0 - lo
  965. return torch.where(
  966. torch.logical_and(self >= lo, self <= hi),
  967. grad_output / (self * (1.0 - self)),
  968. 0.0,
  969. )
  970. else:
  971. return torch.where(
  972. torch.logical_and(self >= 0.0, self <= 1.0),
  973. grad_output / (self * (1.0 - self)),
  974. self.new_full((), float("nan")),
  975. )
  976. @register_decomposition(aten.dropout)
  977. @aten.dropout.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  978. @aten.dropout.default.py_impl(DispatchKey.Autograd)
  979. def dropout(input: Tensor, p: float, train: Optional[bool]):
  980. if train and p != 0:
  981. return aten.native_dropout(input, p, train)[0]
  982. else:
  983. return input.clone()
  984. @register_decomposition(aten.native_dropout)
  985. @out_wrapper("out0", "out1")
  986. def native_dropout(input: Tensor, p: float, train: Optional[bool]):
  987. if train and p != 0:
  988. if p == 1:
  989. return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
  990. if not input.dtype.is_floating_point:
  991. raise RuntimeError(
  992. "result type Float can't be cast to the desired output type Long"
  993. )
  994. bool_mask = torch.rand_like(input) > p
  995. res = bool_mask * input * float(1.0 / (1.0 - p))
  996. return (res, bool_mask)
  997. else:
  998. return (input, torch.ones_like(input, dtype=torch.bool))
  999. @register_decomposition(aten._softmax)
  1000. @out_wrapper()
  1001. def _softmax(x: Tensor, dim: int, half_to_float: bool):
  1002. from torch.fx.experimental.symbolic_shapes import guard_or_false
  1003. # eager softmax returns a contiguous tensor. Ensure that decomp also returns
  1004. # a contiguous tensor.
  1005. x = x.contiguous()
  1006. if half_to_float:
  1007. if x.dtype != torch.half:
  1008. raise AssertionError(
  1009. f"half_to_float is True but x.dtype is {x.dtype}, expected torch.half"
  1010. )
  1011. computation_dtype, result_dtype = utils.elementwise_dtypes(
  1012. x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  1013. )
  1014. x = x.to(computation_dtype)
  1015. if guard_or_false(x.numel() == 0):
  1016. unnormalized = torch.exp(x)
  1017. else:
  1018. x_max = torch.amax(x, dim, keepdim=True)
  1019. unnormalized = torch.exp(x - x_max)
  1020. result = unnormalized / torch.sum(unnormalized, dim, keepdim=True)
  1021. if not half_to_float:
  1022. result = result.to(result_dtype)
  1023. return result
  1024. @register_decomposition(aten._log_softmax)
  1025. @out_wrapper(exact_dtype=True)
  1026. def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
  1027. from torch.fx.experimental.symbolic_shapes import guard_or_false
  1028. # eager log_softmax returns a contiguous tensor. Ensure that decomp also
  1029. # returns a contiguous tensor.
  1030. x = x.contiguous()
  1031. if half_to_float:
  1032. if x.dtype != torch.half:
  1033. raise AssertionError(
  1034. f"half_to_float is True but x.dtype is {x.dtype}, expected torch.half"
  1035. )
  1036. computation_dtype, result_dtype = utils.elementwise_dtypes(
  1037. x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  1038. )
  1039. x = x.to(computation_dtype)
  1040. if guard_or_false(x.numel() == 0):
  1041. shifted = x
  1042. else:
  1043. x_max = torch.amax(x, dim, keepdim=True)
  1044. shifted = x - x_max
  1045. shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
  1046. result = shifted - shifted_logsumexp
  1047. if not half_to_float:
  1048. result = result.to(result_dtype)
  1049. return result
  1050. @register_decomposition(aten.embedding)
  1051. @out_wrapper()
  1052. def embedding(
  1053. weight: Tensor,
  1054. indices: Tensor,
  1055. padding_idx: int = -1,
  1056. scale_grad_by_freq: bool = False,
  1057. sparse: bool = False,
  1058. ) -> Tensor:
  1059. if weight.dim() != 2:
  1060. raise AssertionError(f"'weight' must be 2-D, got {weight.dim()}-D")
  1061. # Nb. scale_grad_by_freq is not used in the forward
  1062. if indices.ndim <= 1:
  1063. # We need this one as weight[indices] calls item() in these cases
  1064. out = weight.index_select(0, indices)
  1065. if indices.ndim == 0:
  1066. out = out.squeeze(0)
  1067. return out
  1068. else:
  1069. return weight[indices]
@register_decomposition(aten.embedding_dense_backward)
@out_wrapper()
def embedding_dense_backward(
    grad_output: Tensor,
    indices: Tensor,
    num_weights: int,
    padding_idx: int,
    scale_grad_by_freq: bool,
):
    """Backward of ``embedding`` with respect to a dense weight table.

    Scatter-adds ``grad_output`` into a zero-initialised buffer of shape
    ``(num_weights, *grad_output.shape[indices.ndim:])`` at the rows named by
    ``indices``.
    - Entries whose index equals ``padding_idx`` contribute zero.
    - With ``scale_grad_by_freq``, each contribution is divided by the number
      of times its index occurs in ``indices``.
    - The accumulation runs in the computation dtype and is cast to the
      promoted result dtype at the end.
    """
    computation_dtype, result_dtype = utils.elementwise_dtypes(
        grad_output, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    grad_output = grad_output.to(computation_dtype)
    # Indices are converted to torch.long (if needed) before indexing.
    indices = _maybe_convert_to_dtype(indices, torch.long)  # type: ignore[assignment]
    if scale_grad_by_freq:
        # Histogram of index occurrences (accumulating ones), then look up each
        # position's count and divide its gradient row by it.
        counts = indices.new_zeros((num_weights,))
        ones = torch.ones_like(indices)
        counts = aten._unsafe_index_put(counts, [indices], ones, accumulate=True)
        grad_weights_scale = counts[indices]
        grad_output = grad_output / grad_weights_scale.unsqueeze(-1)
    # Zero the rows whose index is padding_idx; the boolean mask is expanded
    # to grad_output's rank so it can be used with masked_fill.
    mask = _unsqueeze_to_dim(indices == padding_idx, grad_output.ndim)
    grad = grad_output.masked_fill(mask, 0)
    grad_weight = grad_output.new_zeros(
        (num_weights,) + grad_output.shape[indices.ndim :]
    )
    # accumulate=True sums the contributions of repeated indices.
    return aten._unsafe_index_put(grad_weight, [indices], grad, accumulate=True).to(
        result_dtype
    )
  1098. def prod(x: list[int]):
  1099. r = 1
  1100. for i in x:
  1101. r *= i
  1102. return r
  1103. def _pad_chunk(
  1104. tensors: list[Tensor],
  1105. dim: int,
  1106. num_chunks: int,
  1107. ) -> list[Tensor]:
  1108. padded_tensors = []
  1109. for tensor in tensors:
  1110. tensor_size = tensor.size()
  1111. pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
  1112. if pad_along_dim != tensor_size[dim]:
  1113. # Use aten.constant_pad_nd instead of copy_ for functionalization
  1114. pad = [0] * 2 * (tensor.ndim - dim - 1) + [
  1115. 0,
  1116. pad_along_dim - tensor_size[dim],
  1117. ]
  1118. tensor = aten.constant_pad_nd(tensor, pad, 0)
  1119. view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
  1120. padded_tensors.append(tensor.reshape(view_size))
  1121. return padded_tensors
  1122. def have_same_ndims(tensors: list[Tensor]):
  1123. ndim = tensors[0].ndim
  1124. for tensor in tensors:
  1125. if tensor.ndim != ndim:
  1126. return False
  1127. return True
  1128. def leading_dimension_matches(tensors: list[Tensor], dim: int):
  1129. leading_dim_sizes = tensors[0].size()[:dim]
  1130. for tensor in tensors:
  1131. torch._check(
  1132. tensor.size()[:dim] == leading_dim_sizes,
  1133. lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
  1134. )
  1135. def _preprocess_chunk_cat_inputs(
  1136. tensors: list[Tensor],
  1137. dim: int,
  1138. num_chunks: int,
  1139. ):
  1140. torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
  1141. torch._check(
  1142. len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
  1143. )
  1144. expected_dtype = tensors[0].dtype
  1145. expected_device = tensors[0].device
  1146. for tensor in tensors:
  1147. torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
  1148. torch._check(
  1149. tensor.dtype == expected_dtype,
  1150. lambda: "_chunk_cat expects all input tensors with the same dtype",
  1151. )
  1152. torch._check(
  1153. tensor.device == expected_device,
  1154. lambda: "_chunk_cat expects all inputs tensors on the same device",
  1155. )
  1156. if have_same_ndims(tensors):
  1157. dim = utils.canonicalize_dim(tensors[0].dim(), dim)
  1158. else:
  1159. torch._check(
  1160. dim >= 0,
  1161. lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
  1162. )
  1163. for tensor in tensors:
  1164. torch._check(
  1165. dim < tensor.ndim,
  1166. lambda: "_chunk_cat expects dim < ndim for all input tensors",
  1167. )
  1168. leading_dimension_matches(tensors, dim)
  1169. return dim
  1170. @register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
  1171. def _chunk_cat(
  1172. tensors: list[Tensor],
  1173. dim: int,
  1174. num_chunks: int,
  1175. out: Optional[Tensor] = None,
  1176. ) -> Tensor:
  1177. dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
  1178. padded_tensors = _pad_chunk(tensors, dim, num_chunks)
  1179. if out is None:
  1180. return torch.cat(padded_tensors, dim + 1)
  1181. else:
  1182. torch.cat(padded_tensors, dim + 1, out=out)
  1183. return out
  1184. # out_wrapper currently does not allow optional outputs
  1185. @register_decomposition(
  1186. [aten.split_with_sizes_copy.default, aten.split_with_sizes_copy.out]
  1187. )
  1188. def split_with_sizes_copy(
  1189. self: Tensor,
  1190. split_sizes: list[int],
  1191. dim: int = 0,
  1192. out: Optional[list[Tensor]] = None,
  1193. ) -> Optional[list[Tensor]]:
  1194. splits = aten.split_with_sizes(self, split_sizes, dim=dim)
  1195. if out is None:
  1196. return [s.clone(memory_format=torch.contiguous_format) for s in splits]
  1197. else:
  1198. for output, split in zip(out, splits):
  1199. _maybe_resize_out(output, split.shape)
  1200. _safe_copy_out(copy_from=split, copy_to=output, exact_dtype=True)
  1201. return None
  1202. @register_decomposition(aten.unsafe_split.Tensor)
  1203. def unsafe_split(input: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
  1204. return aten.split.Tensor(input, split_size, dim)
  1205. @register_decomposition(aten.unsafe_split_with_sizes.default)
  1206. def unsafe_split_with_sizes(
  1207. input: Tensor, split_sizes: list[int], dim: int = 0
  1208. ) -> tuple[Tensor, ...]:
  1209. return aten.split_with_sizes.default(input, split_sizes, dim)
  1210. @register_decomposition(aten.split.Tensor)
  1211. def split(self: Tensor, split_size: int, dim: int = 0) -> tuple[Tensor, ...]:
  1212. input_sizes = self.shape
  1213. dim_size = input_sizes[dim]
  1214. if split_size == 0:
  1215. if dim_size != 0:
  1216. raise AssertionError(
  1217. f"split_size is 0 but dim_size is {dim_size}, expected 0"
  1218. )
  1219. return (self.detach(),)
  1220. chunks = (dim_size + split_size - 1) // split_size
  1221. # Avoid importing sympy at a module level
  1222. from torch.fx.experimental.symbolic_shapes import guard_int
  1223. chunks = guard_int(chunks)
  1224. split_sizes = [split_size for i in range(chunks)]
  1225. split_sizes[-1] = split_size - (split_size * chunks - dim_size)
  1226. return torch.split(self, split_sizes, dim)
@aten.tensor_split.tensor_indices_or_sections.py_impl(
    DispatchKey.CompositeImplicitAutograd
)
def tensor_split_tensor_indices_or_sections_py_impl(
    self: Tensor,
    tensor_indices_or_sections: Tensor,
    dim: int = 0,
) -> tuple[Tensor, ...]:
    """Python implementation of ``tensor_split`` taking a tensor argument.

    ``tensor_indices_or_sections`` must be an int64 CPU tensor with 0 or 1
    dims. Its values are extracted with ``.item()`` and passed to
    ``self.tensor_split(..., dim)``:
    - a 0-d tensor becomes a single section count;
    - a 1-d tensor becomes a list of split indices.

    Raises:
        AssertionError: for a non-CPU or non-int64 tensor, or when the 0-d
            value is not IntLike.
        RuntimeError: (via ``torch._check``) when the tensor has more than
            one dimension.
    """
    if tensor_indices_or_sections.device.type != "cpu":
        raise AssertionError(
            f"tensor_indices_or_sections must be on CPU, got {tensor_indices_or_sections.device}"
        )
    if tensor_indices_or_sections.dtype != torch.int64:
        raise AssertionError(
            f"tensor_indices_or_sections must be int64, got {tensor_indices_or_sections.dtype}"
        )
    split_dim = tensor_indices_or_sections.dim()
    torch._check(
        split_dim == 1 or split_dim == 0,
        lambda: "tensor_split expected tensor_indices_or_sections to be a zero-dimensional "
        f"or one-dimensional tensor, but got a tensor with {split_dim} dims",
    )
    if split_dim == 0:
        # Scalar tensor: interpreted as the number of sections.
        sections = tensor_indices_or_sections.item()
        if not isinstance(sections, IntLike):
            raise AssertionError(
                f"Expected sections to be IntLike, got {type(sections).__name__}"
            )
        return self.tensor_split(sections, dim)
    else:
        # 1-d tensor: interpreted as split indices. Under a fake mode that has
        # a ShapeEnv, the .item() calls run inside
        # shape_env.ignore_fresh_unbacked_symbols (see the note below);
        # otherwise a no-op context is used.
        ctx = nullcontext
        if (fake_mode := torch._guards.detect_fake_mode()) and (
            shape_env := fake_mode.shape_env
        ):
            ctx = shape_env.ignore_fresh_unbacked_symbols  # type: ignore[assignment]
        # In fake tensor prop, we end up calling slice() with these unbacked indices.
        # Because slice has flexible semantics, the unbacked handling generates new output sizes
        # for each slice, effectively clobbering over these index symbols.
        # To avoid PendingUnbackedSymbolNotFound errors, we tell the compiler it's fine to not bind these.
        with ctx():
            indices = [i.item() for i in tensor_indices_or_sections]
        # WARNING: Tempted to torch._check(x>0) on the indices here?  You
        # can't: tensor_split works with negative values in indices:
        #
        # >>> torch.tensor_split(torch.randn(10), torch.tensor([-5, 5]))
        # (tensor([ 0.3540,  2.1074, -0.8507,  1.1639,  0.3055]), tensor([]),
        # tensor([-0.4285,  1.0692, -0.1776,  0.9362,  1.6143]))
        #
        # Sorry, I don't make the rules.  Explicitly do the item call in user
        # code if you KNOW that they are non-negative.
        return self.tensor_split(indices, dim)
  1278. # TODO: this doesn't appear to have enough precision in bfloat16
  1279. @register_decomposition(aten.addmm)
  1280. @out_wrapper(exact_dtype=True)
  1281. @pw_cast_for_opmath
  1282. def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
  1283. if not self.is_floating_point() and not self.is_complex():
  1284. beta = int(beta)
  1285. alpha = int(alpha)
  1286. out = alpha * torch.mm(mat1, mat2)
  1287. if beta == 0:
  1288. return out
  1289. # The output of aten.addmm is contiguous, we need to match this behavior in the decomposition.
  1290. # The original implementation 'beta * self + out' would return a strided tensor if `self` is strided.
  1291. # We thus use `out`, the output of torch.mm, which is always contiguous, as the first argument for addition.
  1292. # This is relying on TensorIterator's behavior that it takes higher precedence on the stride of first input.
  1293. # Alternative, we can write `(beta * self + out).contiguous()`, but it introduces another copy in some cases.
  1294. # This implementation is not ideal, and we should revisit this when we have a better solution.
  1295. return out + beta * self
  1296. @register_decomposition(aten._addmm_activation)
  1297. @out_wrapper()
  1298. @pw_cast_for_opmath
  1299. def _addmm_activation(
  1300. self: Tensor,
  1301. mat1: Tensor,
  1302. mat2: Tensor,
  1303. beta: int = 1,
  1304. alpha: int = 1,
  1305. use_gelu: bool = False,
  1306. ):
  1307. out = addmm(self, mat1, mat2, beta, alpha)
  1308. if use_gelu:
  1309. if self.is_cuda:
  1310. return aten.gelu(out, approximate="tanh")
  1311. else:
  1312. return aten.gelu(out)
  1313. return aten.relu(out)
  1314. @register_decomposition(aten.addmv)
  1315. @out_wrapper(exact_dtype=True)
  1316. @pw_cast_for_opmath
  1317. def addmv(self: Tensor, mat1: Tensor, vec: Tensor, beta: int = 1, alpha: int = 1):
  1318. if not self.is_floating_point() and not self.is_complex():
  1319. beta = int(beta)
  1320. alpha = int(alpha)
  1321. out = alpha * torch.mv(mat1, vec)
  1322. if beta == 0:
  1323. return out
  1324. if out.numel() == 0: # handle empty matrix
  1325. return beta * self
  1326. return out + beta * self
  1327. @register_decomposition(aten.native_group_norm_backward.default)
  1328. @pw_cast_for_opmath
  1329. def native_group_norm_backward(
  1330. grad_output: Tensor,
  1331. input: Tensor,
  1332. mean: Tensor,
  1333. rstd: Tensor,
  1334. gamma: Optional[Tensor],
  1335. N: int,
  1336. C: int,
  1337. HxW: int,
  1338. group: int,
  1339. output_mask: list[bool],
  1340. ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1341. utils.check_same_device(
  1342. grad_output, input, mean, rstd, allow_cpu_scalar_tensors=False
  1343. )
  1344. utils.check_same_shape(input, grad_output, allow_cpu_scalar_tensors=False)
  1345. utils.check_same_shape(mean, rstd, allow_cpu_scalar_tensors=False)
  1346. torch._check(
  1347. input.numel() == N * C * HxW,
  1348. lambda: f"Expect input to have {N * C * HxW} elements",
  1349. )
  1350. torch._check(
  1351. mean.shape == (N, group),
  1352. lambda: f"Expect mean to have shape ({N}, {group}, but got {mean.shape}",
  1353. )
  1354. torch._check(
  1355. gamma is None or gamma.numel() == C,
  1356. lambda: f"Expect gamma to have {C} elements but got {gamma.numel() if gamma is not None else -1}",
  1357. )
  1358. cpg = C // group
  1359. torch._check(
  1360. C == cpg * group,
  1361. lambda: f"Expect number of channels {C} to be evenly-divisible by number of groups {group}",
  1362. )
  1363. # Compute Internal gradients
  1364. ds = torch.mul(grad_output, input).view(N, C, HxW).sum(dim=[2])
  1365. db = grad_output.view(N, C, HxW).sum(dim=[2])
  1366. d_input: Optional[Tensor] = None
  1367. d_gamma: Optional[Tensor] = None
  1368. d_bias: Optional[Tensor] = None
  1369. if output_mask[0]:
  1370. s = 1.0 / (HxW * cpg)
  1371. if gamma is not None:
  1372. ds_val = torch.mul(ds, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
  1373. db_val = torch.mul(db, gamma.unsqueeze(0)).reshape(N, group, cpg).sum(2)
  1374. c1 = torch.mul(
  1375. rstd.unsqueeze(-1),
  1376. gamma.reshape(1, group, cpg),
  1377. )
  1378. else:
  1379. ds_val = ds.reshape(N, group, cpg).sum(2)
  1380. db_val = db.reshape(N, group, cpg).sum(2)
  1381. c1 = torch.mul(
  1382. rstd.unsqueeze(-1),
  1383. torch.ones((1, group, cpg), device=rstd.device),
  1384. )
  1385. c2 = (db_val * mean - ds_val) * rstd * rstd * rstd * s
  1386. c3 = -c2 * mean - db_val * rstd * s
  1387. c1 = c1.unsqueeze(-1)
  1388. c2 = _unsqueeze_to_dim(c2, 4)
  1389. c3 = _unsqueeze_to_dim(c3, 4)
  1390. d_input = (
  1391. torch.mul(grad_output.reshape(N, group, cpg, HxW), c1)
  1392. + torch.mul(input.reshape(N, group, cpg, HxW), c2)
  1393. + c3
  1394. )
  1395. d_input = d_input.reshape(input.shape).to(input.dtype)
  1396. if output_mask[1]:
  1397. d_gamma = (
  1398. (
  1399. (ds.view(N, group, cpg) - db.view(N, group, cpg) * mean.unsqueeze(-1))
  1400. * rstd.unsqueeze(-1)
  1401. )
  1402. .sum(dim=[0])
  1403. .reshape(C)
  1404. )
  1405. if output_mask[2]:
  1406. d_bias = db.sum(dim=[0])
  1407. return (d_input, d_gamma, d_bias)
  1408. # out_wrapper currently does not allow optional outputs
  1409. @register_decomposition(aten.native_group_norm_backward.out)
  1410. def native_group_norm_backward_out(
  1411. grad_output: Tensor,
  1412. input: Tensor,
  1413. mean: Tensor,
  1414. rstd: Tensor,
  1415. gamma: Optional[Tensor],
  1416. N: int,
  1417. C: int,
  1418. HxW: int,
  1419. group: int,
  1420. output_mask: list[bool],
  1421. *,
  1422. out0: torch.Tensor,
  1423. out1: torch.Tensor,
  1424. out2: torch.Tensor,
  1425. ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1426. result = native_group_norm_backward(
  1427. grad_output, input, mean, rstd, gamma, N, C, HxW, group, output_mask
  1428. )
  1429. grad_input = (out0, out1, out2)
  1430. for i, r in enumerate(result):
  1431. if r is not None:
  1432. _maybe_resize_out(grad_input[i], r.shape)
  1433. _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
  1434. return grad_input
  1435. def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
  1436. if x is not None:
  1437. return x.to(dtype)
  1438. return x
  1439. # TODO: Take a closer look at the type promotion semantics
  1440. @register_decomposition(aten.native_layer_norm_backward.default)
  1441. def native_layer_norm_backward(
  1442. grad_out: Tensor,
  1443. input: Tensor,
  1444. normalized_shape: list[int],
  1445. mean: Tensor,
  1446. rstd: Tensor,
  1447. weight: Optional[Tensor],
  1448. bias: Optional[Tensor],
  1449. output_mask: list[bool],
  1450. ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1451. input_shape = input.shape
  1452. input_ndim = input.dim()
  1453. computation_dtype = utils.get_computation_dtype(input.dtype)
  1454. grad_out_cast, input_cast, weight_cast, bias_cast = (
  1455. x.to(computation_dtype, memory_format=torch.contiguous_format)
  1456. if x is not None
  1457. else x
  1458. for x in (grad_out, input, weight, bias)
  1459. )
  1460. if grad_out_cast is None:
  1461. raise AssertionError("grad_out_cast should not be None")
  1462. axis = input_ndim - len(normalized_shape)
  1463. inner_dims = input_shape[axis:]
  1464. outer_dims = input_shape[:axis]
  1465. inner_dim_indices: list[int] = []
  1466. outer_dim_indices: list[int] = []
  1467. for i in range(input_ndim):
  1468. if i >= axis:
  1469. inner_dim_indices.append(i)
  1470. else:
  1471. outer_dim_indices.append(i)
  1472. N = prod(inner_dims) # type: ignore[arg-type]
  1473. M = prod(outer_dims) # type: ignore[arg-type]
  1474. from torch.fx.experimental.symbolic_shapes import statically_known_true
  1475. if statically_known_true(M == 0) or statically_known_true(N == 0):
  1476. return (
  1477. input.new_zeros(input_shape) if output_mask[0] else None,
  1478. input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
  1479. input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
  1480. )
  1481. mean = _unsqueeze_to_dim(mean, input_cast.dim()) # type: ignore[union-attr]
  1482. rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
  1483. if input_cast is None:
  1484. raise AssertionError("input_cast should not be None")
  1485. x_hat = (input_cast - mean) * rstd
  1486. if weight_cast is not None:
  1487. grad_x_hat = grad_out_cast * weight_cast
  1488. else:
  1489. grad_x_hat = grad_out_cast
  1490. a = grad_x_hat * N
  1491. b = torch.sum(grad_x_hat, inner_dim_indices, True)
  1492. c1 = torch.mul(grad_x_hat, x_hat)
  1493. c2 = torch.sum(c1, inner_dim_indices, True)
  1494. c3 = torch.mul(x_hat, c2)
  1495. inner = a - b - c3
  1496. d_input: Optional[Tensor] = None
  1497. d_weight: Optional[Tensor] = None
  1498. d_bias: Optional[Tensor] = None
  1499. if output_mask[0]:
  1500. d_input = (rstd / N) * inner
  1501. if output_mask[1] and weight_cast is not None:
  1502. if len(outer_dim_indices) > 0:
  1503. d_weight = torch.sum(grad_out_cast * x_hat, outer_dim_indices, False)
  1504. else:
  1505. d_weight = grad_out_cast * x_hat
  1506. if output_mask[2] and bias_cast is not None:
  1507. if len(outer_dim_indices) > 0:
  1508. d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
  1509. else:
  1510. d_bias = grad_out_cast.clone()
  1511. return (
  1512. _maybe_cast(d_input, input.dtype),
  1513. _maybe_cast(d_weight, weight.dtype if weight is not None else None),
  1514. _maybe_cast(d_bias, bias.dtype if bias is not None else None),
  1515. )
  1516. # out_wrapper currently does not allow optional outputs
  1517. @register_decomposition(aten.native_layer_norm_backward.out)
  1518. def native_layer_norm_backward_out(
  1519. grad_out: Tensor,
  1520. input: Tensor,
  1521. normalized_shape: list[int],
  1522. mean: Tensor,
  1523. rstd: Tensor,
  1524. weight: Optional[Tensor],
  1525. bias: Optional[Tensor],
  1526. output_mask: list[bool],
  1527. *,
  1528. out0: torch.Tensor,
  1529. out1: torch.Tensor,
  1530. out2: torch.Tensor,
  1531. ) -> tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
  1532. result = native_layer_norm_backward(
  1533. grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask
  1534. )
  1535. grad_input = (out0, out1, out2)
  1536. for i, r in enumerate(result):
  1537. if r is not None:
  1538. _maybe_resize_out(grad_input[i], r.shape)
  1539. _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
  1540. return grad_input
@register_decomposition(aten._fused_rms_norm.default)
def _fused_rms_norm(
    input: Tensor,
    normalized_shape: list[int],
    weight: Optional[Tensor],
    eps: Optional[float],
) -> tuple[Tensor, Tensor]:
    """RMS norm over the trailing ``len(normalized_shape)`` dims of ``input``.

    Computes ``input * rsqrt(mean(input**2) + eps)`` in
    ``utils.get_computation_dtype(input.dtype)`` and optionally multiplies by
    ``weight``.

    ``eps=None`` defaults to the machine epsilon of float32 (for float32 /
    complex64 computation) or of float64 (otherwise).

    Returns:
        ``(result, rsqrt_of_mean_square)``. ``result`` is cast back to
        ``input``'s type; the second tensor stays in the computation dtype
        and keeps the reduced dims with size 1. Both are made contiguous
        unless an input is nested or the suggested memory format is
        channels-last.
    """
    # Reduce over the last len(normalized_shape) dims.
    dims_to_reduce: list[int] = []
    for i in range(len(normalized_shape)):
        dims_to_reduce.append(input.dim() - i - 1)

    # upcast is needed for fp16 and bf16
    computation_dtype = utils.get_computation_dtype(input.dtype)
    upcasted_input = input.to(computation_dtype)

    # computation_dtype would be one of [Double, Float, ComplexFloat, ComplexDouble]
    if eps is None:
        if computation_dtype in (torch.float32, torch.complex64):
            eps_val = torch.finfo(torch.float32).eps
        else:
            eps_val = torch.finfo(torch.float64).eps
    else:
        eps_val = eps

    # NOTE(review): the name "rqrst_input" appears to be a typo for
    # "rsqrt_input"; left as-is in this comment-only change.
    rqrst_input = torch.rsqrt(
        # NB: don't inplace here, will violate functional IR invariant
        # NB: carefully use the Scalar overload of add to ensure compatibility with the C++ decomp
        torch.ops.aten.add.Scalar(
            torch.pow(upcasted_input, 2).mean(dim=dims_to_reduce, keepdim=True), eps_val
        )
    )

    upcasted_result = upcasted_input.mul(rqrst_input)

    if weight is not None:
        upcasted_result = upcasted_result.mul(weight)

    # NB: nested should be dead here, just here for fidelity
    is_nested = input.is_nested or (weight is not None and weight.is_nested)
    memory_format = utils.suggest_memory_format(input)
    is_channels_last = memory_format in (
        torch.channels_last,
        torch.channels_last_3d,
    )

    # Only force contiguity for the plain (non-nested, non-channels-last) case.
    if not is_nested and not is_channels_last:
        upcasted_result = upcasted_result.contiguous()
        rqrst_input = rqrst_input.contiguous()

    # Cast normalized result back to original input type
    result = upcasted_result.type_as(input)

    return result, rqrst_input
  1585. @register_decomposition(aten._fused_rms_norm_backward.default)
  1586. def _fused_rms_norm_backward(
  1587. grad_out: Tensor,
  1588. input: Tensor,
  1589. normalized_shape: list[int],
  1590. rstd: Tensor,
  1591. weight: Optional[Tensor],
  1592. output_mask: list[bool],
  1593. ) -> tuple[Optional[Tensor], Optional[Tensor]]:
  1594. input_shape = input.shape
  1595. input_ndim = input.dim()
  1596. computation_dtype = utils.get_computation_dtype(input.dtype)
  1597. grad_out_cast = grad_out.to(
  1598. computation_dtype, memory_format=torch.contiguous_format
  1599. )
  1600. input_cast = input.to(computation_dtype, memory_format=torch.contiguous_format)
  1601. weight_cast = (
  1602. weight.to(computation_dtype, memory_format=torch.contiguous_format)
  1603. if weight is not None
  1604. else None
  1605. )
  1606. if grad_out_cast is None:
  1607. raise AssertionError("grad_out_cast should not be None")
  1608. axis = input_ndim - len(normalized_shape)
  1609. inner_dims = input_shape[axis:]
  1610. outer_dims = input_shape[:axis]
  1611. inner_dim_indices: list[int] = []
  1612. outer_dim_indices: list[int] = []
  1613. for i in range(input_ndim):
  1614. if i >= axis:
  1615. inner_dim_indices.append(i)
  1616. else:
  1617. outer_dim_indices.append(i)
  1618. N = prod(inner_dims) # type: ignore[arg-type]
  1619. M = prod(outer_dims) # type: ignore[arg-type]
  1620. from torch.fx.experimental.symbolic_shapes import guard_or_false
  1621. if guard_or_false(M == 0) or guard_or_false(N == 0):
  1622. return (
  1623. input.new_zeros(input_shape) if output_mask[0] else None,
  1624. input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
  1625. )
  1626. rstd = _unsqueeze_to_dim(rstd, input_cast.dim()) # type: ignore[union-attr]
  1627. if weight_cast is not None:
  1628. grad_x_hat = grad_out_cast * weight_cast
  1629. else:
  1630. grad_x_hat = grad_out_cast
  1631. d_input: Optional[Tensor] = None
  1632. d_weight: Optional[Tensor] = None
  1633. x_hat = input_cast * rstd
  1634. if output_mask[0]:
  1635. sum_val = torch.sum(x_hat * grad_x_hat, dim=inner_dim_indices, keepdim=True)
  1636. d_input = (grad_x_hat - (x_hat / N) * sum_val) * rstd
  1637. if output_mask[1] and weight_cast is not None:
  1638. d_weight_full_shape = grad_out_cast * x_hat
  1639. if len(outer_dim_indices) > 0:
  1640. d_weight = torch.sum(
  1641. d_weight_full_shape, dim=outer_dim_indices, keepdim=False
  1642. )
  1643. else:
  1644. d_weight = d_weight_full_shape
  1645. return (
  1646. _maybe_cast(d_input, input.dtype),
  1647. _maybe_cast(d_weight, input.dtype),
  1648. )
def native_batch_norm_helper(
    input: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    training: bool,
    momentum: float,
    eps: float,
    functional: bool,
) -> tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
    """Shared implementation behind the batch-norm decompositions.

    Statistics are kept per index of dim 1; dims 0 and 2..ndim-1 are reduced.

    Training mode:
    - The batch mean and biased variance are computed in the computation
      dtype.
    - When running statistics are provided, they are blended as
      ``momentum * batch_stat + (1 - momentum) * running_stat``.
    - The running variance uses the unbiased ``n / (n - 1)`` correction.

    Eval mode:
    - ``running_mean`` / ``running_var`` (both required) are copied into the
      computation dtype and used for normalisation.

    Args:
        functional: when False (training mode only), the new running
            statistics are also written into ``running_mean`` /
            ``running_var`` in place via ``copy_``.

    Returns:
        ``(output, save_mean, save_rstd, new_running_mean, new_running_var)``.
        ``output`` has ``input``'s dtype. In eval mode, ``save_mean`` and
        ``save_rstd`` are the running mean and invstd on non-CPU devices, and
        empty tensors on CPU. On CPU both are cast to ``input``'s dtype.

    Raises:
        AssertionError: in eval mode if either running statistic is None.
    """
    # Reduce over every dim except dim 1.
    reduction_dims = [0] + list(range(2, input.dim()))
    computation_dtype = utils.get_computation_dtype(input.dtype)
    new_running_mean = running_mean
    new_running_var = running_var
    if training:
        # NOTE(review): recomputes the same computation_dtype as above; redundant.
        computation_dtype = utils.get_computation_dtype(input.dtype)
        input_acc = input.to(dtype=computation_dtype)
        biased_var, mean = torch.var_mean(
            input_acc, dim=reduction_dims, correction=0, keepdim=True
        )
        rstd = torch.rsqrt(biased_var + eps)

        output = (input - mean) * rstd

        save_mean = torch.squeeze(mean, reduction_dims)
        save_rstd = torch.squeeze(rstd, reduction_dims)
        if running_mean is not None:
            new_running_mean = momentum * save_mean + (1 - momentum) * running_mean
            if not functional:
                running_mean.copy_(new_running_mean)
        if running_var is not None:
            # Number of elements reduced per dim-1 index.
            n = input.numel() / input.shape[1]
            # This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
            # But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
            # numerics probably don't matter.
            squeezed_var = torch.squeeze(biased_var, reduction_dims)
            unbiased_var = squeezed_var * (n / (n - 1))
            new_running_var = momentum * unbiased_var + (1 - momentum) * running_var
            if not functional:
                running_var.copy_(new_running_var)
    else:
        if running_mean is None or running_var is None:
            raise AssertionError(
                "running_mean and running_var must not be None in eval mode"
            )
        # copy=True: the caller's running statistics are not aliased.
        running_mean = running_mean.to(dtype=computation_dtype, copy=True)
        new_running_mean = running_mean
        running_var = running_var.to(dtype=computation_dtype, copy=True)
        new_running_var = running_var
        mean = running_mean
        invstd = 1 / (torch.sqrt(running_var + eps))
        # Very annoying inconsistency where CPU and CUDA give different shapes
        if input.device.type != "cpu":
            save_mean = running_mean
            save_rstd = invstd
        else:
            save_mean = input.new_zeros((0,))
            save_rstd = input.new_zeros((0,))
        mean = _unsqueeze_to_dim(mean, input.dim() - 1)
        invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
        output = (input - mean) * invstd

    # Optional affine transform; weight/bias are flattened and expanded so
    # they broadcast over the non-channel dims.
    if weight is not None:
        weight = weight.flatten()
        weight = _unsqueeze_to_dim(weight, input.dim() - 1)
        output = output * weight

    if bias is not None:
        bias = bias.flatten()
        bias = _unsqueeze_to_dim(bias, input.dim() - 1)
        output = output + bias

    if input.device.type == "cpu":
        save_mean = save_mean.to(dtype=input.dtype)
        save_rstd = save_rstd.to(dtype=input.dtype)
    return (
        output.to(dtype=input.dtype),
        save_mean,
        save_rstd,
        new_running_mean,
        new_running_var,
    )
  1727. @register_decomposition(aten.native_batch_norm)
  1728. @out_wrapper("out", "save_mean", "save_invstd")
  1729. def native_batch_norm(
  1730. input: Tensor,
  1731. weight: Optional[Tensor],
  1732. bias: Optional[Tensor],
  1733. running_mean: Optional[Tensor],
  1734. running_var: Optional[Tensor],
  1735. training: bool,
  1736. momentum: float,
  1737. eps: float,
  1738. ) -> tuple[Tensor, Tensor, Tensor]:
  1739. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1740. input, weight, bias, running_mean, running_var, training, momentum, eps, False
  1741. )
  1742. return output, save_mean, save_rstd
  1743. # TODO: this decomposition is NOT here to stay. We would much prefer replacing native_batch_norm
  1744. # with our new correctly schema'd _native_batch_norm_legit and its variants, but
  1745. # we cannot do that immediately in the C++ because it would be forwards incompatible
  1746. # with some mobile use cases.
  1747. #
  1748. # Since this change is most impactful for aot autograd/functionalization, we simply
  1749. # register this decomposition on the Autograd key for the python dispatcher (which is
  1750. # currently only used by aot autograd/functionalization and no one else, really).
  1751. # In two weeks or so, we should remove this decomposition and phase out the current native_batch_norm
  1752. # to be _native_batch_norm_legit and have the right schema (stating that there are input mutations).
  1753. @aten.native_batch_norm.default.py_impl(DispatchKey.Autograd)
  1754. @aten.native_batch_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  1755. def native_batch_norm_decomposition(
  1756. input: Tensor,
  1757. weight: Optional[Tensor],
  1758. bias: Optional[Tensor],
  1759. running_mean: Optional[Tensor],
  1760. running_var: Optional[Tensor],
  1761. training: bool,
  1762. momentum: float,
  1763. eps: float,
  1764. ) -> tuple[Tensor, Tensor, Tensor]:
  1765. if running_mean is None and running_var is None:
  1766. return aten._native_batch_norm_legit(
  1767. input, weight, bias, training, momentum, eps
  1768. )
  1769. if running_mean is None:
  1770. raise RuntimeError(
  1771. "running_mean is None, but running_var is provided. "
  1772. "They should both be None or both be provided."
  1773. )
  1774. if running_var is None:
  1775. raise RuntimeError(
  1776. "running_var is None, but running_mean is provided. "
  1777. "They should both be None or both be provided."
  1778. )
  1779. if training:
  1780. # HACK: batch norm consolidation should clean this up so this op doesn't take in a training arg.
  1781. return aten._native_batch_norm_legit(
  1782. input, weight, bias, running_mean, running_var, training, momentum, eps
  1783. )
  1784. else:
  1785. return aten._native_batch_norm_legit_no_training(
  1786. input, weight, bias, running_mean, running_var, momentum, eps
  1787. )
  1788. @aten.unsafe_chunk.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  1789. def unsafe_chunk_py_impl(tensor, chunks, dim=0) -> list[Tensor]:
  1790. dim_size = tensor.size(dim)
  1791. split_size = (dim_size + chunks - 1) // chunks
  1792. if split_size == 0 and dim_size == 0:
  1793. split_sizes = [split_size for _ in range(chunks)]
  1794. split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
  1795. return torch.ops.aten.unsafe_split_with_sizes.default(tensor, split_sizes, dim)
  1796. return torch.ops.aten.unsafe_split.Tensor(tensor, split_size, dim)
  1797. @register_decomposition(aten._native_batch_norm_legit_no_training.default)
  1798. def _native_batch_norm_legit_no_training(
  1799. input: Tensor,
  1800. weight: Optional[Tensor],
  1801. bias: Optional[Tensor],
  1802. running_mean: Tensor,
  1803. running_var: Tensor,
  1804. momentum: float,
  1805. eps: float,
  1806. ) -> tuple[Tensor, Tensor, Tensor]:
  1807. return aten._native_batch_norm_legit.default(
  1808. input,
  1809. weight,
  1810. bias,
  1811. running_mean,
  1812. running_var,
  1813. False, # training
  1814. momentum,
  1815. eps,
  1816. )
  1817. @register_decomposition(aten._native_batch_norm_legit.default)
  1818. def _native_batch_norm_legit(
  1819. input: Tensor,
  1820. weight: Optional[Tensor],
  1821. bias: Optional[Tensor],
  1822. running_mean: Tensor,
  1823. running_var: Tensor,
  1824. training: bool,
  1825. momentum: float,
  1826. eps: float,
  1827. ) -> tuple[Tensor, Tensor, Tensor]:
  1828. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1829. input, weight, bias, running_mean, running_var, training, momentum, eps, False
  1830. )
  1831. return output, save_mean, save_rstd
  1832. @register_decomposition(aten._native_batch_norm_legit.no_stats)
  1833. def _native_batch_norm_legit_no_stats(
  1834. input: Tensor,
  1835. weight: Optional[Tensor],
  1836. bias: Optional[Tensor],
  1837. training: bool,
  1838. momentum: float,
  1839. eps: float,
  1840. ) -> tuple[Tensor, Tensor, Tensor]:
  1841. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1842. input, weight, bias, None, None, training, momentum, eps, False
  1843. )
  1844. return output, save_mean, save_rstd
  1845. @register_decomposition(aten._native_batch_norm_legit_functional.default)
  1846. def _native_batch_norm_legit_functional(
  1847. input: Tensor,
  1848. weight: Optional[Tensor],
  1849. bias: Optional[Tensor],
  1850. running_mean: Tensor,
  1851. running_var: Tensor,
  1852. training: bool,
  1853. momentum: float,
  1854. eps: float,
  1855. ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
  1856. (
  1857. output,
  1858. save_mean,
  1859. save_rstd,
  1860. new_running_mean,
  1861. new_running_var,
  1862. ) = native_batch_norm_helper(
  1863. input, weight, bias, running_mean, running_var, training, momentum, eps, True
  1864. )
  1865. if new_running_mean is None:
  1866. raise AssertionError("new_running_mean should not be None")
  1867. if new_running_var is None:
  1868. raise AssertionError("new_running_var should not be None")
  1869. return output, save_mean, save_rstd, new_running_mean, new_running_var
  1870. def _get_batch_norm_reserve_tensor(
  1871. input: Tensor,
  1872. weight: Optional[Tensor],
  1873. bias: Optional[Tensor],
  1874. running_mean: Tensor,
  1875. running_var: Tensor,
  1876. eps: float,
  1877. training: bool,
  1878. ) -> Tensor:
  1879. """
  1880. Return a reserve tensor for batch norm, used only by cudnn to pass forward state to the
  1881. backward pass. This is needed for `_batch_norm_with_update` and `_batch_norm_no_update`,
  1882. which support a variety of backends including cudnn. We create this tensor here to get
  1883. the correct shape in the traced graph if we detect that will call the cudnn kernel,
  1884. and rely on DCE to avoid materializing this tensor.
  1885. """
  1886. backend = torch._C._select_batch_norm_backend( # type: ignore[attr-defined]
  1887. input, weight, bias, running_mean, running_var, True, eps
  1888. )
  1889. reserve_size = 0
  1890. if backend == torch._C._BatchNormBackend.Cudnn: # type: ignore[attr-defined]
  1891. reserve_size = torch._C._get_cudnn_batch_norm_reserve_space_size( # type: ignore[attr-defined]
  1892. input, training
  1893. )
  1894. return torch.empty(
  1895. reserve_size, dtype=torch.uint8, layout=input.layout, device=input.device
  1896. )
  1897. @register_decomposition(aten._batch_norm_with_update.default)
  1898. def _batch_norm_with_update(
  1899. input: Tensor,
  1900. weight: Optional[Tensor],
  1901. bias: Optional[Tensor],
  1902. running_mean: Tensor,
  1903. running_var: Tensor,
  1904. momentum: float,
  1905. eps: float,
  1906. ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
  1907. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1908. input,
  1909. weight,
  1910. bias,
  1911. running_mean,
  1912. running_var,
  1913. True, # training
  1914. momentum,
  1915. eps,
  1916. False, # functional
  1917. )
  1918. reserve = _get_batch_norm_reserve_tensor(
  1919. input, weight, bias, running_mean, running_var, eps, training=True
  1920. )
  1921. return output, save_mean, save_rstd, reserve
  1922. @register_decomposition(aten._batch_norm_with_update_functional.default)
  1923. def _batch_norm_with_update_functional(
  1924. input: Tensor,
  1925. weight: Optional[Tensor],
  1926. bias: Optional[Tensor],
  1927. running_mean: Tensor,
  1928. running_var: Tensor,
  1929. momentum: float,
  1930. eps: float,
  1931. ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
  1932. (
  1933. output,
  1934. save_mean,
  1935. save_rstd,
  1936. new_rm,
  1937. new_rv,
  1938. ) = native_batch_norm_helper(
  1939. input, weight, bias, running_mean, running_var, True, momentum, eps, True
  1940. )
  1941. reserve = _get_batch_norm_reserve_tensor(
  1942. input, weight, bias, running_mean, running_var, eps, training=True
  1943. )
  1944. if new_rm is None:
  1945. raise AssertionError("new_running_mean should not be None")
  1946. if new_rv is None:
  1947. raise AssertionError("new_running_var should not be None")
  1948. return (output, save_mean, save_rstd, reserve, new_rm, new_rv)
  1949. @register_decomposition(aten._batch_norm_no_update.default)
  1950. def _batch_norm_no_update(
  1951. input: Tensor,
  1952. weight: Optional[Tensor],
  1953. bias: Optional[Tensor],
  1954. running_mean: Tensor,
  1955. running_var: Tensor,
  1956. momentum: float,
  1957. eps: float,
  1958. ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
  1959. output, save_mean, save_rstd, _, _ = native_batch_norm_helper(
  1960. input,
  1961. weight,
  1962. bias,
  1963. running_mean,
  1964. running_var,
  1965. False, # training
  1966. momentum,
  1967. eps,
  1968. False, # functional
  1969. )
  1970. reserve = _get_batch_norm_reserve_tensor(
  1971. input, weight, bias, running_mean, running_var, eps, training=False
  1972. )
  1973. return output, save_mean, save_rstd, reserve
  1974. @register_decomposition(aten._fused_dropout)
  1975. @out_wrapper("out0", "out1")
  1976. @pw_cast_for_opmath
  1977. def _fused_dropout_decomposition(input, p, generator=None):
  1978. if generator is not None:
  1979. raise AssertionError(
  1980. f"generator must be None for _fused_dropout decomposition, got {generator}"
  1981. )
  1982. mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
  1983. res = mask.type_as(input) * input * (1.0 / p)
  1984. return (res, mask)
  1985. @register_decomposition(aten._to_copy)
  1986. @out_wrapper()
  1987. def _to_copy(
  1988. x: Union[Tensor, NumberType],
  1989. *,
  1990. dtype: Optional[torch.dtype] = None,
  1991. layout=None,
  1992. device: Optional[torch.device] = None,
  1993. pin_memory: bool = False,
  1994. non_blocking: bool = False,
  1995. memory_format: Optional[torch.memory_format] = None,
  1996. ):
  1997. if layout and layout != torch.strided:
  1998. raise AssertionError(f"layout must be None or torch.strided, got {layout}")
  1999. if pin_memory:
  2000. raise AssertionError(
  2001. "pin_memory=True is not supported in _to_copy decomposition"
  2002. )
  2003. if not isinstance(x, (torch.Tensor, int, float, bool, complex)):
  2004. raise AssertionError(f"x must be Tensor or scalar, got {type(x).__name__}")
  2005. if device is None and dtype is None and memory_format is None:
  2006. if isinstance(x, torch.Tensor):
  2007. return x.clone()
  2008. else:
  2009. return x
  2010. dtype_converted = False
  2011. if isinstance(x, torch.Tensor):
  2012. x_tensor = x
  2013. else:
  2014. x_tensor = torch.scalar_tensor(x)
  2015. if device is not None and device != x_tensor.device:
  2016. # avoid conversions on cpu
  2017. if dtype is not None and device.type == "cpu":
  2018. x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
  2019. dtype_converted = True
  2020. x_tensor = torch._prims.device_put(x_tensor, device, non_blocking)
  2021. if dtype is not None and not dtype_converted:
  2022. x_tensor = torch._prims.convert_element_type(x_tensor, dtype)
  2023. dtype_converted = True
  2024. if memory_format is not None: # no ref/prim for memory format
  2025. return torch.clone(x_tensor, memory_format=memory_format)
  2026. return x_tensor
  2027. # Questionable decompositions
  2028. # This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
  2029. # Note that this decomposition causes issues with in-place ops
  2030. @register_decomposition([aten.detach, aten.lift, aten.lift_fresh])
  2031. @out_wrapper()
  2032. def nop_decomposition(x):
  2033. return aten.alias(x)
  2034. # Also register to the Autograd dispatch key, so this decomp can run above autograd.
  2035. # native_batch_norm needs to decompose into other ops before autograd.
  2036. @aten.cudnn_batch_norm.default.py_impl(DispatchKey.Autograd)
  2037. @register_decomposition(aten.cudnn_batch_norm)
  2038. @out_wrapper("out0", "out1", "out2", "out3")
  2039. def cudnn_batch_norm(
  2040. input: Tensor,
  2041. weight: Tensor,
  2042. bias: Optional[Tensor],
  2043. running_mean: Optional[Tensor],
  2044. running_var: Optional[Tensor],
  2045. training: bool,
  2046. exponential_average_factor: float,
  2047. epsilon: float,
  2048. ):
  2049. a, b, c = aten.native_batch_norm(
  2050. input,
  2051. weight,
  2052. bias,
  2053. running_mean,
  2054. running_var,
  2055. training,
  2056. exponential_average_factor,
  2057. epsilon,
  2058. )
  2059. # Cudnn return running mean and variance when training is True
  2060. if training:
  2061. return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
  2062. return (
  2063. a,
  2064. weight.new_zeros((0,)),
  2065. weight.new_zeros((0,)),
  2066. input.new_zeros((0,), dtype=torch.uint8),
  2067. )
  2068. def _broadcast_batch_norm_backward(x, broadcast_mask):
  2069. for axis, mask in enumerate(broadcast_mask):
  2070. if mask == 1 and not (axis < x.ndim and x.shape[axis] == mask):
  2071. x = x.unsqueeze(axis)
  2072. return x
  2073. @register_decomposition(aten.batch_norm_backward.default)
  2074. def batch_norm_backward(
  2075. grad_out: Tensor,
  2076. input: Tensor,
  2077. weight: Optional[Tensor],
  2078. running_mean: Optional[Tensor],
  2079. running_var: Optional[Tensor],
  2080. save_mean: Optional[Tensor],
  2081. save_invstd: Optional[Tensor],
  2082. train: bool,
  2083. eps: float,
  2084. output_mask: list[bool],
  2085. reserve: Tensor,
  2086. ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  2087. return native_batch_norm_backward(
  2088. grad_out,
  2089. input,
  2090. weight,
  2091. running_mean,
  2092. running_var,
  2093. save_mean,
  2094. save_invstd,
  2095. train,
  2096. eps,
  2097. output_mask,
  2098. )
@register_decomposition(aten.native_batch_norm_backward.default)
def native_batch_norm_backward(
    grad_out: Tensor,
    input: Tensor,
    weight: Optional[Tensor],
    running_mean: Optional[Tensor],
    running_var: Optional[Tensor],
    save_mean: Optional[Tensor],
    save_invstd: Optional[Tensor],
    train: bool,
    eps: float,
    output_mask: list[bool],
) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    """Decomposition of the batch-norm backward pass.

    All tensor inputs are cast to ``utils.get_computation_dtype(input.dtype)`` and
    the results are cast back at the end. The channel axis is fixed at dim 1, and
    every other dim is reduced over.

    Args:
        grad_out: gradient with respect to the forward output.
        input: forward input; its rank must be at least 2.
        weight: optional per-channel scale. When None, grad_input is scaled by
            ``invstd`` alone, and the weight/bias grads are cast to ``input.dtype``.
        running_mean, running_var: statistics used when ``train`` is False;
            required in that case.
        save_mean, save_invstd: statistics saved by the forward pass; required
            when ``train`` is True.
        train: selects between saved batch statistics (True) and running
            statistics (False).
        eps: added to ``running_var`` before ``rsqrt`` in eval mode.
        output_mask: only entries 1 (grad_weight) and 2 (grad_bias) are
            consulted. grad_input is always computed.

    Returns:
        ``(grad_input, grad_weight, grad_bias)``. grad_input is in ``input.dtype``;
        the other two are passed through ``_maybe_cast`` to the weight dtype, and
        are None when masked out.

    Raises:
        AssertionError: if the input rank is < 2, or the statistics required by
            the selected mode are missing.
    """
    input_dtype = input.dtype
    if weight is not None:
        weight_dtype = weight.dtype
    else:
        weight_dtype = input_dtype
    computation_dtype = utils.get_computation_dtype(input.dtype)
    # Upcast every tensor argument (None entries stay None) for the computation.
    (
        grad_out_cast,
        input_cast,
        weight_cast,
        running_mean_cast,
        running_var_cast,
        save_mean_cast,
        save_invstd_cast,
    ) = (
        x.to(computation_dtype) if x is not None else x
        for x in (
            grad_out,
            input,
            weight,
            running_mean,
            running_var,
            save_mean,
            save_invstd,
        )
    )
    input_shape = input.shape
    input_rank = input.dim()
    if input_rank < 2:
        raise AssertionError(f"rank of the input must be at least 2, got {input_rank}")
    # Channel axis.
    axis = 1
    # Despite the name, this is the number of elements per channel
    # (total element count divided by the channel count).
    num_features = prod(list(input_shape)) / input_shape[axis]
    mean = save_mean_cast
    invstd = save_invstd_cast
    if train:
        if mean is None or invstd is None:
            raise AssertionError("mean and invstd must not be None in training mode")
    else:
        if running_mean_cast is None or running_var_cast is None:
            raise AssertionError(
                "running_mean_cast and running_var_cast must not be None in eval mode"
            )
        # Eval mode: statistics come from the running buffers.
        mean = running_mean_cast
        invstd = torch.rsqrt(running_var_cast + eps)
    # Shape used to broadcast per-channel vectors: 1 everywhere except the channel axis.
    broadcast_mask: list[int] = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]
    reduction_axes: list[int] = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)
    mean = _broadcast_batch_norm_backward(mean, broadcast_mask)  # type: ignore[arg-type]
    norm = 1.0 / num_features
    # Per-channel sum(grad_out) and sum(grad_out * (x - mean)).
    grad_output_sum = torch.sum(grad_out_cast, reduction_axes)  # type: ignore[arg-type]
    dot_p = torch.sum(grad_out_cast * (input_cast - mean), reduction_axes)  # type: ignore[operator]
    grad_mean = _broadcast_batch_norm_backward(grad_output_sum * norm, broadcast_mask)
    proj_scale = _broadcast_batch_norm_backward(
        torch.mul(dot_p * norm, invstd * invstd),  # type: ignore[operator]
        broadcast_mask,
    )
    if weight_cast is None:
        grad_scale = _broadcast_batch_norm_backward(invstd, broadcast_mask) * 1.0  # type: ignore[arg-type]
    else:
        grad_scale = _broadcast_batch_norm_backward(
            invstd * weight_cast, broadcast_mask
        )
    if train:
        # Training: the statistics depend on the input, so subtract the projection
        # onto (x - mean) and the mean gradient before scaling.
        proj = (input_cast - mean) * proj_scale  # type: ignore[operator]
        grad_input = ((grad_out_cast - proj) - grad_mean) * grad_scale
    else:
        # Eval: the statistics are constants, so the gradient is a pure rescale.
        grad_input = grad_out_cast * grad_scale
    if output_mask[1]:
        grad_weight = dot_p * invstd
    else:
        grad_weight = None  # "None" doesn't work with vjp, should use zeros for vjp
    if output_mask[2]:
        grad_bias = grad_output_sum
    else:
        grad_bias = None  # "None" doesn't work with vjp, should use zeros for vjp
    return (
        grad_input.to(input_dtype),
        _maybe_cast(grad_weight, weight_dtype),
        _maybe_cast(grad_bias, weight_dtype),
    )
  2195. # out_wrapper currently does not allow optional outputs
  2196. @register_decomposition(aten.native_batch_norm_backward.out)
  2197. def native_batch_norm_backward_out(
  2198. grad_out: Tensor,
  2199. input: Tensor,
  2200. weight: Optional[Tensor],
  2201. running_mean: Optional[Tensor],
  2202. running_var: Optional[Tensor],
  2203. save_mean: Optional[Tensor],
  2204. save_invstd: Optional[Tensor],
  2205. train: bool,
  2206. eps: float,
  2207. output_mask: list[bool],
  2208. *,
  2209. out0: torch.Tensor,
  2210. out1: torch.Tensor,
  2211. out2: torch.Tensor,
  2212. ) -> tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
  2213. result = native_batch_norm_backward(
  2214. grad_out,
  2215. input,
  2216. weight,
  2217. running_mean,
  2218. running_var,
  2219. save_mean,
  2220. save_invstd,
  2221. train,
  2222. eps,
  2223. output_mask,
  2224. )
  2225. grad_input = (out0, out1, out2)
  2226. for i, r in enumerate(result):
  2227. if r is not None:
  2228. _maybe_resize_out(grad_input[i], r.shape)
  2229. _safe_copy_out(copy_from=r, copy_to=grad_input[i], exact_dtype=True)
  2230. return grad_input
  2231. @register_decomposition(aten.miopen_batch_norm_backward)
  2232. @out_wrapper("out0", "out1", "out2")
  2233. def miopen_batch_norm_backward(
  2234. input: Tensor,
  2235. grad_output: Tensor,
  2236. weight: Tensor,
  2237. running_mean: Optional[Tensor],
  2238. running_var: Optional[Tensor],
  2239. save_mean: Optional[Tensor],
  2240. save_var: Optional[Tensor],
  2241. epsilon: float,
  2242. ):
  2243. return aten.native_batch_norm_backward(
  2244. grad_output,
  2245. input,
  2246. weight,
  2247. running_mean,
  2248. running_var,
  2249. save_mean,
  2250. save_var,
  2251. True,
  2252. epsilon,
  2253. [True, True, True],
  2254. )
  2255. @register_decomposition(aten.cudnn_batch_norm_backward)
  2256. @out_wrapper("out0", "out1", "out2")
  2257. def cudnn_batch_norm_backward(
  2258. input: Tensor,
  2259. grad_output: Tensor,
  2260. weight: Tensor,
  2261. running_mean: Optional[Tensor],
  2262. running_var: Optional[Tensor],
  2263. save_mean: Optional[Tensor],
  2264. save_var: Optional[Tensor],
  2265. epsilon: float,
  2266. reserveSpace: Tensor,
  2267. ):
  2268. return aten.native_batch_norm_backward(
  2269. grad_output,
  2270. input,
  2271. weight,
  2272. running_mean,
  2273. running_var,
  2274. save_mean,
  2275. save_var,
  2276. True,
  2277. epsilon,
  2278. [True, True, True],
  2279. )
@register_decomposition(aten._adaptive_avg_pool2d)
@out_wrapper()
@pw_cast_for_opmath
def adaptive_avg_pool2d(input: Tensor, output_size: tuple[int, int]):
    """Decomposition of ``aten._adaptive_avg_pool2d`` for 3-D or 4-D inputs.

    Pools the last two dims of ``input`` down to ``output_size``.

    Fast path: if both input spatial sizes are divisible by the matching output sizes,
    this is a fixed-window ``avg_pool2d`` (kernel == stride == in // out).

    General case: along each dim, window ``k`` covers
    ``[floor(k * in / out), ceil((k + 1) * in / out))``. Every window is gathered with
    advanced indexing, padded to the longest window length. When window lengths
    differ, padded positions are zeroed and the sum is divided by the true per-window
    length.

    Raises:
        RuntimeError (via ``torch._check``): if the rank is not 3 or 4, or either
            of the last two dims is 0.
    """
    # Preconditions
    device = input.device
    shape = input.shape
    ndim = len(shape)
    torch._check(
        ndim in (3, 4),
        lambda: f"adaptive_avg_pool2d(): Expected 3D or 4D tensor, but got {ndim}",
    )
    for d in input.shape[-2:]:
        torch._check(
            d != 0,
            lambda: "adaptive_avg_pool2d(): Expected input to have non-zero size for "
            f"non-batch dimensions, but input has shape {tuple(shape)}.",
        )
    # Optimisation (we should also do this in the kernel implementation)
    # Divisible sizes: every window has the same length, so a plain avg_pool2d works.
    if shape[-2] % output_size[-2] == 0 and shape[-1] % output_size[-1] == 0:
        stride = tuple(i // o for i, o in zip(shape[-2:], output_size))
        kernel = tuple(
            i - (o - 1) * s for i, o, s in zip(shape[-2:], output_size, stride)
        )
        return torch.nn.functional.avg_pool2d(input, kernel, stride)

    # floor(a * c / b): start of window a when pooling c elements into b windows.
    def start_index(a, b, c):
        return torch.div(a * c, b, rounding_mode="trunc")

    # ceil((a + 1) * c / b): exclusive end of window a.
    def end_index(a, b, c):
        return torch.div((a + 1) * c + b - 1, b, rounding_mode="trunc")

    # For one spatial dim, returns (gather indices of shape (out_size, maxlength),
    # window length (an int if constant, else a per-window tensor),
    # arange(maxlength), adaptive flag).
    def compute_idx(in_size, out_size):
        orange = torch.arange(out_size, device=device, dtype=torch.int64)
        i0 = start_index(orange, out_size, in_size)
        # Let length = end_index - start_index, i.e. the length of the pooling kernels
        # length.max() can be computed analytically as follows:
        maxlength = in_size // out_size + 1
        in_size_mod = in_size % out_size
        # adaptive = True iff there are kernels with different lengths
        adaptive = not (in_size_mod == 0 or out_size % in_size_mod == 0)
        if adaptive:
            maxlength += 1
        elif in_size_mod == 0:
            maxlength -= 1
        range_max = torch.arange(maxlength, device=device, dtype=torch.int64)
        idx = i0.unsqueeze(-1) + range_max
        if adaptive:
            # Need to clamp to avoid accessing out-of-bounds memory
            # TODO make minimum accept scalars
            maxval = torch.scalar_tensor(
                in_size - 1, dtype=idx.dtype, device=idx.device
            )
            idx = torch.minimum(idx, maxval)
            # Compute the length
            i1 = end_index(orange, out_size, in_size)
            length = i1 - i0
        else:
            length = maxlength
        return idx, length, range_max, adaptive

    # length is a plain int when all windows along that dim have the same length;
    # otherwise it is a per-window tensor, and padded positions must be masked below.
    idxh, length_h, range_max_h, adaptive_h = compute_idx(shape[-2], output_size[-2])
    idxw, length_w, range_max_w, adaptive_w = compute_idx(shape[-1], output_size[-1])
    # Gather every window. The height indices are unsqueezed so that they broadcast
    # against the width indices.
    vals = input[..., _unsqueeze_to_dim(idxh, 4), idxw]
    # Shortcut for the simpler case
    if not adaptive_h and not adaptive_w:
        return torch.mean(vals, dim=(-3, -1))

    # Zero out gathered positions that lie past each window's true length, and return
    # that length reshaped for broadcasting. (The `adaptive` argument is unused.)
    def maybe_mask(vals, length, range_max, adaptive, dim):
        if isinstance(length, IntLike):
            return vals, length
        else:
            # zero-out the things we didn't really want to select
            if dim >= 0:
                raise AssertionError(f"dim should be negative when masking, got {dim}")
            # hack
            mask = range_max >= length.unsqueeze(-1)
            if dim == -2:
                mask = _unsqueeze_to_dim(mask, 4)
            vals = torch.masked_fill(vals, mask, 0.0)
            # Compute the length of each window
            length = _unsqueeze_to_dim(length, -dim)
            return vals, length

    vals, length_h = maybe_mask(
        vals, length_h, range_max_h, adaptive=adaptive_h, dim=-2
    )
    vals, length_w = maybe_mask(
        vals, length_w, range_max_w, adaptive=adaptive_w, dim=-1
    )
    # We unroll the sum as we assume that the kernels are going to be small
    ret = None
    for i, j in product(range(vals.shape[-3]), range(vals.shape[-1])):
        if ret is None:
            ret = vals[..., i, :, j]
        else:
            ret = ret + vals[..., i, :, j]
    return ret / (length_h * length_w)
  2373. def _max_unpoolnd(
  2374. self: TensorLike, indices: TensorLike, output_size: list[int], dim: int
  2375. ):
  2376. # If the input tensors self and indices came from max_pool call as
  2377. # required by the documentation, this operation is deterministic
  2378. # because that ensures that if there are two entries in `indices`
  2379. # tensor that are equal, the corresponding values in `self` are also
  2380. # equal. If this condition is not satisfied, the operation is
  2381. # non-deterministic as one of the different values in `self` 'wins'.
  2382. utils.alert_not_deterministic(f"max_unpooling{dim}d_forward_out")
  2383. output_shape = list(self.shape[:-dim]) + list(output_size)
  2384. if any(s == 0 for s in output_shape):
  2385. return self.new_zeros(output_shape)
  2386. nc = reduce(operator.mul, self.shape[:-dim])
  2387. hw = reduce(operator.mul, output_size)
  2388. indices_nc_shape = [1] * self.ndim
  2389. indices_nc_shape[:-dim] = self.shape[:-dim]
  2390. indices_flat = (
  2391. indices + aten.arange(nc, device=self.device).view(indices_nc_shape) * hw
  2392. ).reshape(-1)
  2393. output = self.new_zeros(output_shape)
  2394. return aten._unsafe_index_put(
  2395. output.reshape(-1), [indices_flat], self.reshape(-1), accumulate=False
  2396. ).view(output.shape)
  2397. @register_decomposition(aten.max_unpool2d)
  2398. @out_wrapper()
  2399. def max_unpool2d(
  2400. self: TensorLike,
  2401. indices: TensorLike,
  2402. output_size: list[int],
  2403. ):
  2404. torch._check(
  2405. indices.dtype == torch.int64,
  2406. lambda: f"elements in indices should be type int64 but got: {indices.dtype}",
  2407. )
  2408. torch._check(
  2409. len(output_size) == 2,
  2410. lambda: (
  2411. f"There should be exactly two elements (height, width) in output_size, "
  2412. f"but got {len(output_size)} elements."
  2413. ),
  2414. )
  2415. torch._check(
  2416. self.ndim in (3, 4),
  2417. lambda: (
  2418. f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
  2419. f"but got a tensor with {self.ndim} dimensions."
  2420. ),
  2421. )
  2422. torch._check(
  2423. self.shape == indices.shape,
  2424. lambda: (
  2425. f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
  2426. f"but got indices tensor with shape: {indices.shape}"
  2427. ),
  2428. )
  2429. for i in range(1, self.ndim):
  2430. torch._check(
  2431. self.size(i) > 0,
  2432. lambda: (
  2433. f"max_unpooling2d(): "
  2434. f"Expected input to have non-zero size for non-batch dimensions, "
  2435. f"but got {self.shape} with dimension {i} being empty."
  2436. ),
  2437. )
  2438. return _max_unpoolnd(self, indices, output_size, 2)
  2439. @register_decomposition(aten.max_unpool3d)
  2440. @out_wrapper()
  2441. def max_unpool3d(
  2442. input: TensorLike,
  2443. indices: TensorLike,
  2444. output_size: list[int],
  2445. stride: list[int],
  2446. padding: list[int],
  2447. ):
  2448. torch._check(
  2449. indices.dtype == torch.int64, lambda: "elements in indices should be type int64"
  2450. )
  2451. torch._check(
  2452. input.ndim in (4, 5),
  2453. lambda: f"Input to max_unpooling3d should be a 4d or 5d Tensor, but got a tensor with {input.ndim} dimensions.",
  2454. )
  2455. torch._check(
  2456. len(output_size) == 3,
  2457. lambda: (
  2458. f"There should be exactly three elements (depth, height, width) in output_size, "
  2459. f"but got {len(output_size)} elements."
  2460. ),
  2461. )
  2462. torch._check(
  2463. len(stride) == 3,
  2464. lambda: f"There should be exactly three elements (depth, height, width) in stride, but got: {len(stride)} elements.",
  2465. )
  2466. torch._check(
  2467. len(padding) == 3,
  2468. lambda: f"There should be exactly three elements (depth, height, width) in padding, but got: {len(padding)} elements.",
  2469. )
  2470. torch._check(
  2471. input.shape == indices.shape,
  2472. lambda: (
  2473. f"Expected shape of indices to be same as that of the input tensor ({input.shape}) "
  2474. f"but got indices tensor with shape: {indices.shape}"
  2475. ),
  2476. )
  2477. for i in range(1, input.ndim):
  2478. torch._check(
  2479. input.size(i) > 0,
  2480. lambda: (
  2481. f"max_unpooling3d(): "
  2482. f"Expected input to have non-zero size for non-batch dimensions, "
  2483. f"but got {input.shape} with dimension {i} being empty."
  2484. ),
  2485. )
  2486. torch._check(
  2487. stride[0] > 0 and stride[1] > 0 and stride[2] > 0,
  2488. lambda: f"strides should be greater than zero, but got stride: {stride}",
  2489. )
  2490. return _max_unpoolnd(input, indices, output_size, 3)
  2491. @register_decomposition(aten.index_add_)
  2492. def index_add_(
  2493. x: TensorLike,
  2494. dim: int,
  2495. index: TensorLike,
  2496. tensor: TensorLike,
  2497. *,
  2498. alpha: NumberType = 1,
  2499. ):
  2500. return _index_add(x, dim, index, tensor, inplace=True, alpha=alpha)
  2501. @register_decomposition(aten.index_add)
  2502. @out_wrapper()
  2503. def index_add(
  2504. x: TensorLike,
  2505. dim: int,
  2506. index: TensorLike,
  2507. tensor: TensorLike,
  2508. *,
  2509. alpha: NumberType = 1,
  2510. ):
  2511. return _index_add(x, dim, index, tensor, inplace=False, alpha=alpha)
  2512. def _index_add(
  2513. x: TensorLike,
  2514. dim: int,
  2515. index: TensorLike,
  2516. tensor: TensorLike,
  2517. *,
  2518. inplace: bool,
  2519. alpha: NumberType = 1,
  2520. ):
  2521. dim = utils.canonicalize_dims(x.ndim, dim)
  2522. torch._check(
  2523. index.ndim <= 1,
  2524. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  2525. )
  2526. index_size = index.size(0) if index.ndim == 1 else 1
  2527. tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1
  2528. torch._check(
  2529. tensor_size == index_size,
  2530. lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}",
  2531. )
  2532. if alpha != 1:
  2533. python_type = utils.dtype_to_type(x.dtype)
  2534. torch._check(
  2535. python_type is bool
  2536. or utils.is_weakly_lesser_type(type(alpha), python_type),
  2537. lambda: f"alpha argument of type {type(alpha)} cannot be safely cast to type {python_type}!",
  2538. )
  2539. tensor = tensor * alpha
  2540. # Treat scalars as elements of \R^1
  2541. zero_dim = x.ndim == 0
  2542. x1 = x.unsqueeze(0) if zero_dim else x
  2543. idx = (None,) * dim + (index,)
  2544. index_put = aten.index_put_ if inplace else aten.index_put
  2545. out = index_put(x1, idx, tensor, accumulate=True)
  2546. if inplace:
  2547. return x
  2548. else:
  2549. return out.squeeze(0) if zero_dim else out.contiguous()
  2550. @register_decomposition(aten.pad_sequence.default)
  2551. @aten.pad_sequence.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2552. def pad_sequence(sequences, batch_first=False, padding_value=0.0):
  2553. torch._check(len(sequences) > 0, lambda: "received an empty list of sequences")
  2554. sequences_size = len(sequences)
  2555. max_size = sequences[0].size()
  2556. trailing_dims = max_size[1:]
  2557. max_len = max(x.size(0) for x in sequences)
  2558. if batch_first:
  2559. out_dims = (sequences_size, max_len)
  2560. else:
  2561. out_dims = (max_len, sequences_size)
  2562. out_dims = out_dims + trailing_dims
  2563. out = sequences[0].new_full(out_dims, padding_value)
  2564. dim_paddings = (0, 0) * len(trailing_dims)
  2565. for i in range(sequences_size):
  2566. currseq = sequences[i]
  2567. row = aten.constant_pad_nd(
  2568. currseq, dim_paddings + (0, max_len - currseq.size(0)), padding_value
  2569. )
  2570. if batch_first:
  2571. out = aten.select_scatter(out, row, dim=0, index=i)
  2572. else:
  2573. out = aten.select_scatter(out, row, dim=1, index=i)
  2574. return out
  2575. @register_decomposition(aten.index_copy_)
  2576. def index_copy_(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
  2577. return _index_copy(x, dim, index, tensor, inplace=True)
  2578. @register_decomposition(aten.index_copy)
  2579. @out_wrapper()
  2580. def index_copy(x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike):
  2581. return _index_copy(x, dim, index, tensor, inplace=False)
  2582. def _index_copy(
  2583. x: TensorLike, dim: int, index: TensorLike, tensor: TensorLike, *, inplace: bool
  2584. ):
  2585. dim = utils.canonicalize_dims(x.ndim, dim)
  2586. torch._check(
  2587. index.ndim <= 1,
  2588. lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
  2589. )
  2590. # Treat scalars as elements of \R^1
  2591. zero_dim = x.ndim == 0
  2592. x1 = x.unsqueeze(0) if zero_dim else x
  2593. index = index.unsqueeze(0) if index.ndim == 0 else index
  2594. idx = (None,) * dim + (index,)
  2595. index_put = aten.index_put_ if inplace else aten.index_put
  2596. out = index_put(x1, idx, tensor)
  2597. if inplace:
  2598. return x
  2599. else:
  2600. return out.squeeze(0) if zero_dim else out.contiguous()
  2601. # nb: Should use acc_t, not op_math
  2602. @register_decomposition(aten.log_sigmoid_forward)
  2603. @out_wrapper("output", "buffer")
  2604. @pw_cast_for_opmath
  2605. def log_sigmoid_forward(self: Tensor) -> tuple[Tensor, Tensor]:
  2606. min = torch.minimum(self.new_zeros(()), self)
  2607. z = torch.exp(-torch.abs(self))
  2608. if self.is_cuda or self.is_xpu:
  2609. buffer = self.new_zeros((0,))
  2610. else:
  2611. buffer = z
  2612. return min - torch.log1p(z), buffer
  2613. @register_decomposition(aten.uniform)
  2614. @out_wrapper()
  2615. def uniform(
  2616. x: Tensor,
  2617. low: Union[bool, int, float] = 0.0,
  2618. high: Union[bool, int, float] = 1.0,
  2619. generator: Optional[torch.Generator] = None,
  2620. ):
  2621. return prims._uniform_helper(
  2622. x.shape,
  2623. stride=x.stride(),
  2624. low=sym_float(low),
  2625. high=sym_float(high),
  2626. dtype=x.dtype,
  2627. device=x.device,
  2628. generator=generator,
  2629. )
  2630. @register_decomposition(aten.uniform_)
  2631. def uniform_(self, low=0, high=1, generator=None):
  2632. return self.copy_(uniform(self, low, high, generator))
  2633. # aten/src/ATen/native/UpSample.cpp compute_output_size
  2634. def upsample_compute_output_size(input_size, output_size, scale_factors):
  2635. spatial_dimensions = len(input_size) - 2
  2636. if output_size is not None:
  2637. torch._check(
  2638. scale_factors is None,
  2639. lambda: "Must specify exactly one of output_size and scale_factors",
  2640. )
  2641. torch._check(len(output_size) == spatial_dimensions, lambda: "")
  2642. return output_size
  2643. if scale_factors is not None:
  2644. # NB: this isn't necessary lol
  2645. torch._check(
  2646. output_size is None,
  2647. lambda: "Must specify exactly one of output_size and scale_factors",
  2648. )
  2649. torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
  2650. output_size = []
  2651. for i, s in enumerate(scale_factors):
  2652. if int(s) == s:
  2653. output_size.append(input_size[i + 2] * int(s))
  2654. else:
  2655. output_size.append(sym_int(input_size[i + 2] * s))
  2656. return output_size
  2657. torch._check(
  2658. False, lambda: "Must specify exactly one of output_size and scale_factors"
  2659. )
  2660. def get_scale_value(scales, idx):
  2661. if scales is None:
  2662. return None
  2663. return scales[idx]
  2664. @register_decomposition(aten.upsample_nearest1d.vec)
  2665. @register_decomposition(aten.upsample_nearest2d.vec)
  2666. @register_decomposition(aten.upsample_nearest3d.vec)
  2667. @aten.upsample_nearest1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2668. @aten.upsample_nearest1d.vec.py_impl(DispatchKey.Autograd)
  2669. @aten.upsample_nearest2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2670. @aten.upsample_nearest2d.vec.py_impl(DispatchKey.Autograd)
  2671. @aten.upsample_nearest3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2672. @aten.upsample_nearest3d.vec.py_impl(DispatchKey.Autograd)
  2673. def _upsample_nearest_vec(
  2674. input: Tensor,
  2675. output_size: Optional[list[int]],
  2676. scale_factors: Optional[list[float]],
  2677. ) -> Tensor:
  2678. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  2679. scales = (
  2680. scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
  2681. )
  2682. return _upsample_nearest(input, osize, scales)
  2683. @register_decomposition(aten._upsample_nearest_exact1d.vec)
  2684. @register_decomposition(aten._upsample_nearest_exact2d.vec)
  2685. @register_decomposition(aten._upsample_nearest_exact3d.vec)
  2686. @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2687. @aten._upsample_nearest_exact1d.vec.py_impl(DispatchKey.Autograd)
  2688. @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2689. @aten._upsample_nearest_exact2d.vec.py_impl(DispatchKey.Autograd)
  2690. @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  2691. @aten._upsample_nearest_exact3d.vec.py_impl(DispatchKey.Autograd)
  2692. def _upsample_nearest_exact_vec(
  2693. input: Tensor,
  2694. output_size: Optional[list[int]],
  2695. scale_factors: Optional[list[float]],
  2696. ) -> Tensor:
  2697. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  2698. scales = (
  2699. scale_factors if scale_factors else [None] * len(osize) # type: ignore[list-item]
  2700. )
  2701. return _upsample_nearest(input, osize, scales, exact=True)
  2702. def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
  2703. # For each dim in output_size, compute the set of input indices used
  2704. # to produce the upsampled output.
  2705. indices = []
  2706. num_spatial_dims = len(output_size)
  2707. offset = 0.5 if exact else 0.0
  2708. for d in range(num_spatial_dims):
  2709. # Math matches aten/src/ATen/native/cpu/UpSampleKernel.cpp
  2710. #
  2711. # Indices are computed as following:
  2712. # scale = isize / osize
  2713. # Case: exact=False
  2714. # input_index = floor(output_index * scale)
  2715. # Same as OpenCV INTER_NEAREST
  2716. #
  2717. # Case: exact=False
  2718. # index_f32 = (output_index + 0.5) * scale - 0.5
  2719. # input_index = round(index_f32)
  2720. # Same as Pillow and Scikit-Image/Scipy ndi.zoom
  2721. osize = output_size[d]
  2722. isize = input.shape[-num_spatial_dims + d]
  2723. # check for scales[d] > 0 is in compute_scales_value in aten/src/ATen/native/UpSample.h
  2724. scale = (
  2725. isize / (isize * scales[d])
  2726. if scales[d] is not None and scales[d] > 0
  2727. else isize / osize
  2728. )
  2729. output_indices = torch.arange(osize, dtype=torch.float32, device=input.device)
  2730. input_indices = ((output_indices + offset) * scale).to(torch.int64)
  2731. for _ in range(num_spatial_dims - 1 - d):
  2732. input_indices = input_indices.unsqueeze(-1)
  2733. indices.append(input_indices)
  2734. return indices
  2735. @register_decomposition([aten.upsample_nearest1d.default, aten.upsample_nearest1d.out])
  2736. @aten.upsample_nearest1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2737. @aten.upsample_nearest1d.default.py_impl(DispatchKey.Autograd)
  2738. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2739. def upsample_nearest1d(
  2740. input: Tensor,
  2741. output_size: list[int],
  2742. scales: Optional[float] = None,
  2743. ) -> Tensor:
  2744. return _upsample_nearest(input, output_size, [scales])
  2745. @register_decomposition(
  2746. [aten._upsample_nearest_exact1d.default, aten._upsample_nearest_exact1d.out]
  2747. )
  2748. @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2749. @aten._upsample_nearest_exact1d.default.py_impl(DispatchKey.Autograd)
  2750. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2751. def upsample_nearest_exact1d(
  2752. input: Tensor,
  2753. output_size: list[int],
  2754. scales: Optional[float] = None,
  2755. ) -> Tensor:
  2756. return _upsample_nearest(input, output_size, [scales], exact=True)
  2757. @register_decomposition([aten.upsample_nearest2d.default, aten.upsample_nearest2d.out])
  2758. @aten.upsample_nearest2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2759. @aten.upsample_nearest2d.default.py_impl(DispatchKey.Autograd)
  2760. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2761. def upsample_nearest2d(
  2762. input: Tensor,
  2763. output_size: list[int],
  2764. scales_h: Optional[float] = None,
  2765. scales_w: Optional[float] = None,
  2766. ) -> Tensor:
  2767. return _upsample_nearest(input, output_size, [scales_h, scales_w])
  2768. @register_decomposition(
  2769. [aten._upsample_nearest_exact2d.default, aten._upsample_nearest_exact2d.out]
  2770. )
  2771. @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2772. @aten._upsample_nearest_exact2d.default.py_impl(DispatchKey.Autograd)
  2773. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2774. def _upsample_nearest_exact2d(
  2775. input: Tensor,
  2776. output_size: list[int],
  2777. scales_h: Optional[float] = None,
  2778. scales_w: Optional[float] = None,
  2779. ) -> Tensor:
  2780. return _upsample_nearest(input, output_size, [scales_h, scales_w], exact=True)
  2781. @register_decomposition([aten.upsample_nearest3d.default, aten.upsample_nearest3d.out])
  2782. @aten.upsample_nearest3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2783. @aten.upsample_nearest3d.default.py_impl(DispatchKey.Autograd)
  2784. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2785. def upsample_nearest3d(
  2786. input: Tensor,
  2787. output_size: list[int],
  2788. scales_d: Optional[float] = None,
  2789. scales_h: Optional[float] = None,
  2790. scales_w: Optional[float] = None,
  2791. ) -> Tensor:
  2792. return _upsample_nearest(input, output_size, [scales_d, scales_h, scales_w])
  2793. @register_decomposition(
  2794. [aten._upsample_nearest_exact3d.default, aten._upsample_nearest_exact3d.out]
  2795. )
  2796. @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.CompositeImplicitAutograd)
  2797. @aten._upsample_nearest_exact3d.default.py_impl(DispatchKey.Autograd)
  2798. @out_wrapper(preserve_memory_format=True, exact_dtype=True)
  2799. def _upsample_nearest_exact3d(
  2800. input: Tensor,
  2801. output_size: list[int],
  2802. scales_d: Optional[float] = None,
  2803. scales_h: Optional[float] = None,
  2804. scales_w: Optional[float] = None,
  2805. ) -> Tensor:
  2806. return _upsample_nearest(
  2807. input, output_size, [scales_d, scales_h, scales_w], exact=True
  2808. )
  2809. @pw_cast_for_opmath
  2810. def _upsample_nearest(
  2811. input: Tensor,
  2812. output_size: list[int],
  2813. scales: list[Optional[float]],
  2814. exact: bool = False,
  2815. ) -> Tensor:
  2816. spatial_indices = _compute_upsample_nearest_indices(
  2817. input, output_size, scales, exact=exact
  2818. )
  2819. indices = [None, None] + spatial_indices
  2820. result = aten._unsafe_index(input, indices)
  2821. if result.ndim == 4:
  2822. # convert output to correct memory format, if necessary
  2823. memory_format = utils.suggest_memory_format(input)
  2824. # following "heuristic: only use channels_last path when it's faster than the contiguous path"
  2825. n_channels = input.shape[1]
  2826. if input.device.type == "cuda" and n_channels < 4:
  2827. memory_format = torch.contiguous_format
  2828. result = result.contiguous(memory_format=memory_format)
  2829. return result
  2830. def gather_params(params, has_biases, has_projections):
  2831. if has_biases and has_projections:
  2832. group_size = 5
  2833. elif has_biases:
  2834. group_size = 4
  2835. elif has_projections:
  2836. group_size = 3
  2837. else:
  2838. group_size = 2
  2839. if len(params) % group_size != 0:
  2840. raise AssertionError(
  2841. f"len(params)={len(params)} is not divisible by group_size={group_size}"
  2842. )
  2843. return [
  2844. tuple(params[i : i + group_size]) for i in range(0, len(params), group_size)
  2845. ]
  2846. def params_hiddens(params, hiddens, i, bidirectional):
  2847. if bidirectional:
  2848. cur_params, cur_hidden = params[2 * i], hiddens[2 * i]
  2849. bidir_params, bidir_hidden = params[2 * i + 1], hiddens[2 * i + 1]
  2850. else:
  2851. cur_params, cur_hidden = params[i], hiddens[i]
  2852. bidir_params, bidir_hidden = None, None
  2853. return cur_params, cur_hidden, bidir_params, bidir_hidden
  2854. def update_hidden_for_packed(cur_hidden, last_batch_size, batch_size, hiddens):
  2855. if last_batch_size <= batch_size:
  2856. raise AssertionError(
  2857. f"last_batch_size ({last_batch_size}) must be > batch_size ({batch_size})"
  2858. )
  2859. hiddens.append(cur_hidden.narrow(0, batch_size, last_batch_size - batch_size))
  2860. return cur_hidden.narrow(0, 0, batch_size)
  2861. def update_hidden_for_packed_reverse(
  2862. cur_hidden, last_batch_size, batch_size, inp_hidden
  2863. ):
  2864. if last_batch_size == batch_size:
  2865. return cur_hidden
  2866. if last_batch_size >= batch_size:
  2867. raise AssertionError(
  2868. f"last_batch_size ({last_batch_size}) must be < batch_size ({batch_size})"
  2869. )
  2870. return torch.concat(
  2871. (
  2872. cur_hidden,
  2873. inp_hidden.narrow(0, last_batch_size, batch_size - last_batch_size),
  2874. )
  2875. )
def one_layer_rnn_data(
    inp, hidden, params, has_biases, hidden_fn, batch_sizes, reverse=False
):
    """Run one direction of one RNN layer over packed-sequence data.

    ``inp`` is split along dim 0 into consecutive chunks of sizes ``batch_sizes``
    (one chunk per time step); with ``reverse=True`` the chunks are walked
    last-to-first. Whenever the chunk size changes between steps, the running
    hidden state is adjusted:

    * forward: trailing rows that drop out are stashed in ``hiddens`` by
      ``update_hidden_for_packed`` and later concatenated, in reverse stash
      order, into the final hidden state;
    * reverse: extra rows are taken from the initial ``hidden`` by
      ``update_hidden_for_packed_reverse``.

    Args:
        inp: packed data; split along dim 0 by ``batch_sizes``.
        hidden: initial hidden state; its first rows (dim 0) are used.
        params: ``(ih_weight, hh_weight[, ih_bias, hh_bias])``; biases are read
            only when ``has_biases``.
        hidden_fn: step function called as
            ``hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)``;
            must return the new hidden state.
        batch_sizes: per-step batch sizes (sorted largest -> smallest).
        reverse: process time steps back to front.

    Returns:
        ``(out, hidden_out)``: per-step hidden states concatenated along dim 0 in
        input order, and the final hidden state.
    """
    ih_weight = params[0]
    hh_weight = params[1]
    ih_bias = params[2] if has_biases else None
    hh_bias = params[3] if has_biases else None
    step_output = []
    # Final hidden rows of sequences that ended early (forward direction only).
    hiddens: list[torch.Tensor] = []
    # The first processed chunk is the last one when walking in reverse.
    last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
    cur_hidden = hidden.narrow(0, 0, last_batch_size)
    split_inp = torch.split(inp, list(batch_sizes))
    if reverse:
        split_inp = split_inp[::-1]
    for inp in split_inp:
        i = inp.shape[0]
        if last_batch_size == i:
            pass  # don't update cur_hidden
        # Batch sizes are sorted largest -> smallest, so in reverse the step size
        # can only grow: pull the extra rows from the initial hidden state.
        elif reverse:
            cur_hidden = update_hidden_for_packed_reverse(
                cur_hidden, last_batch_size, i, hidden
            )
        # Forward (reverse=False): the step size can only shrink.
        else:
            cur_hidden = update_hidden_for_packed(
                cur_hidden, last_batch_size, i, hiddens
            )
        cur_hidden = hidden_fn(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
        last_batch_size = i
        step_output.append(cur_hidden)
    if reverse:
        # Restore input (time) order of the outputs.
        step_output.reverse()
    else:
        # Rows were stashed from the shortest-lived back to the longest-lived;
        # reversing puts them back in row order for concatenation.
        hiddens.append(cur_hidden)
        hiddens.reverse()
    out = torch.cat(step_output, 0)
    hidden_out = torch.cat(hiddens, 0) if not reverse else cur_hidden
    return out, hidden_out
  2914. def rnn_cell(nonlinearity):
  2915. def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2916. return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
  2917. return inner
  2918. def rnn_cell_data(nonlinearity):
  2919. def inner(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  2920. i = F.linear(i, ih_weight, ih_bias)
  2921. return nonlinearity(F.linear(cur_hidden, hh_weight, hh_bias) + i)
  2922. return inner
  2923. def one_layer_rnn(inp, hidden, params, has_biases, hidden_fn, reverse=False):
  2924. ih_weight = params[0]
  2925. hh_weight = params[1]
  2926. ih_bias = params[2] if has_biases else None
  2927. hh_bias = params[3] if has_biases else None
  2928. precomputed_input = F.linear(inp, ih_weight, ih_bias)
  2929. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  2930. cur_hidden = hidden.unsqueeze(0)
  2931. step_output = []
  2932. for i in precomputed_input:
  2933. cur_hidden = hidden_fn(i, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias)
  2934. step_output.append(cur_hidden)
  2935. if reverse:
  2936. step_output.reverse()
  2937. out = torch.cat(step_output, 0)
  2938. return out, cur_hidden.squeeze(0)
def mkldnn_one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
    """Run one direction of one LSTM layer through ``aten.mkldnn_rnn_layer``.

    Args:
        inp: layer input; made contiguous before the call.
        hidden: ``(h, c)``; each gets a leading dim of size 1 for the op.
        params: ``(w_ih, w_hh[, b_ih, b_hh])``.
        has_biases: whether ``params`` contains biases.
        reverse: process the sequence back to front.

    Returns:
        ``(y, (hy, cy))`` where the leading size-1 dim of ``hy``/``cy`` is
        squeezed away again.
    """
    w0 = params[0]
    w1 = params[1]
    if has_biases:
        w2 = params[2]
        w3 = params[3]
    else:
        # NOTE(review): placeholders shaped like the *weights* with default
        # dtype/device; presumably ignored since has_biases=False is also passed
        # to the op — confirm.
        w2 = torch.zeros(w0.size())
        w3 = torch.zeros(w1.size())
    hx = hidden[0].unsqueeze(0)
    cx = hidden[1].unsqueeze(0)
    # Not packed-sequence data: no per-step batch sizes.
    batch_sizes: list[int] = []
    mode = 2  # third_party/ideep/include/ideep/abstract_types.hpp: ideep::rnn_kind::LSTM = 2
    hidden_size = hx.size(2)
    # One layer per call; _rnn_helper stacks layers.
    num_layers = 1
    # _rnn_helper already handles bidirectional and batch_first so we hard-code them to False here
    bidirectional = False
    batch_first = False
    # Only selected for inference (see select_one_layer_lstm_function).
    train = False
    # If batch_first, inp has been permuted in _rnn_helper. Convert to contiguous here.
    # Same as aten/src/ATen/native/mkldnn/RNN.cpp: mkldnn_rnn: input = input.contiguous();
    inp = inp.contiguous()
    hx = hx.contiguous()
    cx = cx.contiguous()
    # Positional argument order must match the mkldnn_rnn_layer schema.
    outputs = torch.ops.aten.mkldnn_rnn_layer.default(
        inp,
        w0,
        w1,
        w2,
        w3,
        hx,
        cx,
        reverse,
        batch_sizes,
        mode,
        hidden_size,
        num_layers,
        has_biases,
        bidirectional,
        batch_first,
        train,
    )
    y, hy, cy = outputs[0], outputs[1], outputs[2]
    return y, (hy.squeeze(0), cy.squeeze(0))
  2983. def _rnn_helper(
  2984. input,
  2985. hidden,
  2986. params,
  2987. has_biases,
  2988. num_layers,
  2989. dropout,
  2990. train,
  2991. bidirectional,
  2992. batch_first,
  2993. layer_fn,
  2994. ):
  2995. input = input.transpose(0, 1) if batch_first else input
  2996. final_hiddens = []
  2997. for i in range(num_layers):
  2998. cur_params, cur_hidden, bidir_params, bidir_hidden = params_hiddens(
  2999. params, hidden, i, bidirectional
  3000. )
  3001. dropout = dropout if (train and num_layers < i - 1) else 0.0
  3002. fwd_inp, fwd_hidden = layer_fn(input, cur_hidden, cur_params, has_biases)
  3003. final_hiddens.append(fwd_hidden)
  3004. if bidirectional:
  3005. bwd_inp, bwd_hidden = layer_fn(
  3006. input, bidir_hidden, bidir_params, has_biases, reverse=True
  3007. )
  3008. final_hiddens.append(bwd_hidden)
  3009. if bidirectional:
  3010. input = torch.cat([fwd_inp, bwd_inp], fwd_inp.dim() - 1) # type: ignore[possibly-undefined]
  3011. else:
  3012. input = fwd_inp
  3013. if dropout != 0 and train and i < num_layers - 1:
  3014. input = torch.dropout(input, dropout, train=True)
  3015. input = input.transpose(0, 1) if batch_first else input
  3016. return input, final_hiddens
  3017. @register_decomposition(aten.rnn_tanh.input)
  3018. @aten.rnn_tanh.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  3019. @aten.rnn_tanh.input.py_impl(DispatchKey.Autograd)
  3020. def rnn_tanh_input(
  3021. input,
  3022. hx,
  3023. params,
  3024. has_biases,
  3025. num_layers,
  3026. dropout,
  3027. train,
  3028. bidirectional,
  3029. batch_first,
  3030. ):
  3031. hidden = hx.unbind(0)
  3032. params = gather_params(params, has_biases, False)
  3033. out, final_hiddens = _rnn_helper(
  3034. input,
  3035. hidden,
  3036. params,
  3037. has_biases,
  3038. num_layers,
  3039. dropout,
  3040. train,
  3041. bidirectional,
  3042. batch_first,
  3043. partial(one_layer_rnn, hidden_fn=rnn_cell(torch.tanh)),
  3044. )
  3045. return out, torch.stack(final_hiddens, 0)
  3046. @register_decomposition(aten.rnn_relu.input)
  3047. @aten.rnn_relu.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  3048. @aten.rnn_relu.input.py_impl(DispatchKey.Autograd)
  3049. def rnn_relu_input(
  3050. input,
  3051. hx,
  3052. params,
  3053. has_biases,
  3054. num_layers,
  3055. dropout,
  3056. train,
  3057. bidirectional,
  3058. batch_first,
  3059. ):
  3060. hidden = hx.unbind(0)
  3061. params = gather_params(params, has_biases, False)
  3062. out, final_hiddens = _rnn_helper(
  3063. input,
  3064. hidden,
  3065. params,
  3066. has_biases,
  3067. num_layers,
  3068. dropout,
  3069. train,
  3070. bidirectional,
  3071. batch_first,
  3072. partial(one_layer_rnn, hidden_fn=rnn_cell(torch.relu)),
  3073. )
  3074. return out, torch.stack(final_hiddens, 0)
  3075. @register_decomposition(aten.rnn_relu.data)
  3076. @aten.rnn_relu.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  3077. @aten.rnn_relu.data.py_impl(DispatchKey.Autograd)
  3078. def rnn_relu_data(
  3079. data,
  3080. batch_sizes,
  3081. hx,
  3082. params,
  3083. has_biases,
  3084. num_layers,
  3085. dropout,
  3086. train,
  3087. bidirectional,
  3088. ):
  3089. hidden = hx.unbind(0)
  3090. params = gather_params(params, has_biases, False)
  3091. out, final_hiddens = _rnn_helper(
  3092. data,
  3093. hidden,
  3094. params,
  3095. has_biases,
  3096. num_layers,
  3097. dropout,
  3098. train,
  3099. bidirectional,
  3100. False,
  3101. partial(
  3102. one_layer_rnn_data,
  3103. batch_sizes=batch_sizes,
  3104. hidden_fn=rnn_cell_data(torch.relu),
  3105. ),
  3106. )
  3107. return out, torch.stack(final_hiddens, 0)
  3108. @register_decomposition(aten.rnn_tanh.data)
  3109. @aten.rnn_tanh.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  3110. @aten.rnn_tanh.data.py_impl(DispatchKey.Autograd)
  3111. def rnn_tanh_data(
  3112. data,
  3113. batch_sizes,
  3114. hx,
  3115. params,
  3116. has_biases,
  3117. num_layers,
  3118. dropout,
  3119. train,
  3120. bidirectional,
  3121. ):
  3122. hidden = hx.unbind(0)
  3123. params = gather_params(params, has_biases, False)
  3124. out, final_hiddens = _rnn_helper(
  3125. data,
  3126. hidden,
  3127. params,
  3128. has_biases,
  3129. num_layers,
  3130. dropout,
  3131. train,
  3132. bidirectional,
  3133. False,
  3134. partial(
  3135. one_layer_rnn_data,
  3136. batch_sizes=batch_sizes,
  3137. hidden_fn=rnn_cell_data(torch.tanh),
  3138. ),
  3139. )
  3140. return out, torch.stack(final_hiddens, 0)
  3141. def lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim):
  3142. gates = F.linear(hx, hh_weight, hh_bias) + inp
  3143. chunked_gates = gates.chunk(4, chunk_dim)
  3144. in_gate = chunked_gates[0].sigmoid()
  3145. forget_gate = chunked_gates[1].sigmoid()
  3146. cell_gate = chunked_gates[2].tanh()
  3147. out_gate = chunked_gates[3].sigmoid()
  3148. cy = forget_gate * cx + (in_gate * cell_gate)
  3149. hy = out_gate * cy.tanh()
  3150. hy = hy if hr_weight is None else F.linear(hy, hr_weight, None)
  3151. return hy, cy
  3152. def one_layer_lstm(inp, hidden, params, has_biases, reverse=False):
  3153. ih_weight = params[0]
  3154. hh_weight = params[1]
  3155. ih_bias = params[2] if has_biases else None
  3156. hh_bias = params[3] if has_biases else None
  3157. hr_weight = (
  3158. params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
  3159. )
  3160. hx = hidden[0].unsqueeze(0)
  3161. cx = hidden[1].unsqueeze(0)
  3162. precomputed_input = F.linear(inp, ih_weight, ih_bias)
  3163. precomputed_input = precomputed_input.flip(0) if reverse else precomputed_input
  3164. step_output = []
  3165. for inp in precomputed_input:
  3166. hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=2)
  3167. step_output.append(hx)
  3168. if reverse:
  3169. step_output.reverse()
  3170. out = torch.cat(step_output, 0)
  3171. return out, (hx.squeeze(1), cx.squeeze(1))
def one_layer_lstm_data(inp, hidden, params, has_biases, batch_sizes, reverse=False):
    """Run one direction of one LSTM layer over packed-sequence data.

    ``inp`` is split along dim 0 into chunks of sizes ``batch_sizes`` (one chunk
    per time step; walked last-to-first when ``reverse``). Each chunk is
    projected with the input weights and fed to ``lstm_cell``. The running
    ``(hx, cx)`` is narrowed when the step size shrinks (the dropped rows are
    stashed as final states) and extended from the initial ``hidden`` when it
    grows.

    Args:
        params: ``(w_ih, w_hh[, b_ih, b_hh][, w_hr])``; the projection weight is
            the last entry when present (5 params with biases, 3 without).
        batch_sizes: per-step batch sizes (sorted largest -> smallest).

    Returns:
        ``(out, (h_n, c_n))``: per-step hidden states concatenated along dim 0
        in input order, and the final states.
    """
    ih_weight = params[0]
    hh_weight = params[1]
    ih_bias = params[2] if has_biases else None
    hh_bias = params[3] if has_biases else None
    hr_weight = (
        params[4] if len(params) == 5 else params[2] if len(params) == 3 else None
    )
    step_output = []
    # (h, c) rows of sequences that ended early (forward direction only).
    hiddens = []
    # The first processed chunk is the last one when walking in reverse.
    last_batch_size = batch_sizes[-1] if reverse else batch_sizes[0]
    split_inp = torch.split(inp, list(batch_sizes))
    if reverse:
        split_inp = split_inp[::-1]
    orig_hx = hidden[0]
    orig_cx = hidden[1]
    hx, cx = (
        orig_hx.narrow(0, 0, last_batch_size),
        orig_cx.narrow(0, 0, last_batch_size),
    )
    for inp in split_inp:
        i = inp.shape[0]
        inp = F.linear(inp, ih_weight, ih_bias)
        # this will only happen when reverse=False, since batch sizes are sorted largest -> smallest
        if i < last_batch_size:
            # Stash the rows that drop out; they are final states.
            hiddens.append(
                (
                    hx.narrow(0, i, last_batch_size - i),
                    cx.narrow(0, i, last_batch_size - i),
                )
            )
            hx, cx = hx.narrow(0, 0, i), cx.narrow(0, 0, i)
        # this will only happen when reverse=True
        if i > last_batch_size:
            # Sequences that start at this step take their initial state from `hidden`.
            hx = torch.concat(
                (hx, orig_hx.narrow(0, last_batch_size, i - last_batch_size)), 0
            )
            cx = torch.concat(
                (cx, orig_cx.narrow(0, last_batch_size, i - last_batch_size)), 0
            )
        # 2-d (batch, features) states here, so gates are chunked on dim 1.
        hx, cx = lstm_cell(inp, hx, cx, hh_weight, hh_bias, hr_weight, chunk_dim=1)
        last_batch_size = i
        step_output.append(hx)
    if reverse:
        # Restore input (time) order; the last step covers every sequence.
        step_output.reverse()
        hidden_out = (hx, cx)
    else:
        # Stashed rows were appended shortest-lived first; reverse to restore
        # row order before concatenating h and c separately.
        hiddens.append((hx, cx))
        hiddens.reverse()
        hidden0, hidden1 = zip(*hiddens)
        hidden_out = torch.cat(hidden0, 0), torch.cat(hidden1, 0)
    out = torch.cat(step_output, 0)
    return out, hidden_out
  3225. def select_one_layer_lstm_function(input, hx, params):
  3226. r"""Check whether we could use decompose lstm with mkldnn_rnn_layer.
  3227. All the below conditions need to be met:
  3228. * ``torch._C._get_mkldnn_enabled()`` returns ``True``.
  3229. * All the input args are on CPU.
  3230. * The dtypes of args are either torch.float or torch.bfloat16.
  3231. * Inference.
  3232. * ``has_projections`` returns ``False``.
  3233. Args:
  3234. * input: the input sequence to LSTM
  3235. * hx: a tuple of the input hidden state and cell state ``(h_0, c_0)`` to LSTM
  3236. * params: the weight and bias tensors of LSTM
  3237. """
  3238. def use_mkldnn(input, hx, params):
  3239. if not torch._C._get_mkldnn_enabled():
  3240. return False
  3241. tensors = [input] + list(hx) + list(chain.from_iterable(params))
  3242. devices = {t.device for t in tensors}
  3243. if len(devices) != 1:
  3244. return False
  3245. device = devices.pop()
  3246. if device != torch.device("cpu"):
  3247. return False
  3248. # With autocast, possible to have mixed dtype here
  3249. dtypes = {t.dtype for t in tensors}
  3250. for dtype in dtypes:
  3251. if dtype not in [torch.float, torch.bfloat16]:
  3252. return False
  3253. if input.requires_grad:
  3254. return False
  3255. has_projections = hx[0].size(2) != hx[1].size(2)
  3256. if has_projections:
  3257. return False
  3258. return True
  3259. # mkldnn_one_layer_lstm does not depend on seq_len while one_layer_lstm
  3260. # will expand over the seq_len dim
  3261. if use_mkldnn(input, hx, params):
  3262. return mkldnn_one_layer_lstm
  3263. else:
  3264. return one_layer_lstm
  3265. @register_decomposition(aten.lstm.input)
  3266. @aten.lstm.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  3267. @aten.lstm.input.py_impl(DispatchKey.Autograd)
  3268. def lstm_impl(
  3269. input,
  3270. hx,
  3271. params,
  3272. has_biases,
  3273. num_layers,
  3274. dropout,
  3275. train,
  3276. bidirectional,
  3277. batch_first,
  3278. ):
  3279. if len(hx) != 2:
  3280. raise AssertionError(f"lstm expects two hidden states, got {len(hx)}")
  3281. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  3282. hidden = list(zip(hx[0], hx[1]))
  3283. layer_fn = select_one_layer_lstm_function(input, hx, params)
  3284. out, final_hiddens = _rnn_helper(
  3285. input,
  3286. hidden,
  3287. params,
  3288. has_biases,
  3289. num_layers,
  3290. dropout,
  3291. train,
  3292. bidirectional,
  3293. batch_first,
  3294. layer_fn,
  3295. )
  3296. final_hiddens = list(zip(*final_hiddens))
  3297. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  3298. @register_decomposition(aten.lstm.data)
  3299. @aten.lstm.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  3300. @aten.lstm.data.py_impl(DispatchKey.Autograd)
  3301. def lstm_data_impl(
  3302. data,
  3303. batch_sizes,
  3304. hx,
  3305. params,
  3306. has_biases,
  3307. num_layers,
  3308. dropout,
  3309. train,
  3310. bidirectional,
  3311. ):
  3312. if len(hx) != 2:
  3313. raise AssertionError(f"lstm expects two hidden states, got {len(hx)}")
  3314. params = gather_params(params, has_biases, hx[0].size(2) != hx[1].size(2))
  3315. hidden = list(zip(hx[0], hx[1]))
  3316. out, final_hiddens = _rnn_helper(
  3317. data,
  3318. hidden,
  3319. params,
  3320. has_biases,
  3321. num_layers,
  3322. dropout,
  3323. train,
  3324. bidirectional,
  3325. False,
  3326. partial(one_layer_lstm_data, batch_sizes=batch_sizes),
  3327. )
  3328. final_hiddens = list(zip(*final_hiddens))
  3329. return out, torch.stack(final_hiddens[0], 0), torch.stack(final_hiddens[1], 0)
  3330. def gru_cell(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  3331. chunked_igates = inp.chunk(3, 1)
  3332. chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 2)
  3333. reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
  3334. input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
  3335. new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
  3336. return (cur_hidden - new_gate) * input_gate + new_gate
  3337. def gru_cell_data(inp, cur_hidden, ih_weight, ih_bias, hh_weight, hh_bias):
  3338. chunked_igates = F.linear(inp, ih_weight, ih_bias).chunk(3, 1)
  3339. chunked_hgates = F.linear(cur_hidden, hh_weight, hh_bias).chunk(3, 1)
  3340. reset_gate = (chunked_hgates[0] + chunked_igates[0]).sigmoid()
  3341. input_gate = (chunked_hgates[1] + chunked_igates[1]).sigmoid()
  3342. new_gate = (chunked_igates[2] + (chunked_hgates[2] * reset_gate)).tanh()
  3343. return (cur_hidden - new_gate) * input_gate + new_gate
  3344. @register_decomposition(aten.gru.data)
  3345. @aten.gru.data.py_impl(DispatchKey.CompositeImplicitAutograd)
  3346. @aten.gru.data.py_impl(DispatchKey.Autograd)
  3347. def gru_impl_data(
  3348. data,
  3349. batch_sizes,
  3350. hx,
  3351. params,
  3352. has_biases,
  3353. num_layers,
  3354. dropout,
  3355. train,
  3356. bidirectional,
  3357. ):
  3358. params = gather_params(params, has_biases, False)
  3359. out, final_hiddens = _rnn_helper(
  3360. data,
  3361. hx.unbind(0),
  3362. params,
  3363. has_biases,
  3364. num_layers,
  3365. dropout,
  3366. train,
  3367. bidirectional,
  3368. False,
  3369. partial(one_layer_rnn_data, batch_sizes=batch_sizes, hidden_fn=gru_cell_data),
  3370. )
  3371. return out, torch.stack(final_hiddens, 0)
  3372. @register_decomposition(aten.gru.input)
  3373. @aten.gru.input.py_impl(DispatchKey.CompositeImplicitAutograd)
  3374. @aten.gru.input.py_impl(DispatchKey.Autograd)
  3375. def gru_impl(
  3376. input,
  3377. hx,
  3378. params,
  3379. has_biases,
  3380. num_layers,
  3381. dropout,
  3382. train,
  3383. bidirectional,
  3384. batch_first,
  3385. ):
  3386. params = gather_params(params, has_biases, False)
  3387. out, final_hiddens = _rnn_helper(
  3388. input,
  3389. hx.unbind(0),
  3390. params,
  3391. has_biases,
  3392. num_layers,
  3393. dropout,
  3394. train,
  3395. bidirectional,
  3396. batch_first,
  3397. partial(one_layer_rnn, hidden_fn=gru_cell),
  3398. )
  3399. return out, torch.stack(final_hiddens, 0)
  3400. @register_decomposition(aten._upsample_bilinear2d_aa.vec)
  3401. @aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3402. @aten._upsample_bilinear2d_aa.vec.py_impl(DispatchKey.Autograd)
  3403. def upsample_bilinear2d_aa_vec(input, output_size, align_corners, scale_factors):
  3404. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  3405. scale_h = get_scale_value(scale_factors, 0)
  3406. scale_w = get_scale_value(scale_factors, 1)
  3407. return torch.ops.aten._upsample_bilinear2d_aa(
  3408. input, osize, align_corners, scale_h, scale_w
  3409. )
  3410. @register_decomposition(aten._upsample_bicubic2d_aa.vec)
  3411. @aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3412. @aten._upsample_bicubic2d_aa.vec.py_impl(DispatchKey.Autograd)
  3413. def upsample_bicubic2d_aa_vec(input, output_size, align_corners, scale_factors):
  3414. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  3415. scale_h = get_scale_value(scale_factors, 0)
  3416. scale_w = get_scale_value(scale_factors, 1)
  3417. return torch.ops.aten._upsample_bicubic2d_aa(
  3418. input, osize, align_corners, scale_h, scale_w
  3419. )
  3420. @register_decomposition(aten.upsample_bilinear2d.vec)
  3421. @register_decomposition(aten.upsample_trilinear3d.vec)
  3422. @aten.upsample_linear1d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3423. @aten.upsample_linear1d.vec.py_impl(DispatchKey.Autograd)
  3424. @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3425. @aten.upsample_bilinear2d.vec.py_impl(DispatchKey.Autograd)
  3426. @aten.upsample_trilinear3d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  3427. @aten.upsample_trilinear3d.vec.py_impl(DispatchKey.Autograd)
  3428. def _upsample_linear_vec(input, output_size, align_corners, scale_factors):
  3429. osize = upsample_compute_output_size(input.size(), output_size, scale_factors)
  3430. scales = scale_factors if scale_factors else [None] * len(osize)
  3431. return _upsample_linear(input, osize, align_corners, scales)
  3432. @register_decomposition([aten.upsample_linear1d.default, aten.upsample_linear1d.out])
  3433. @out_wrapper()
  3434. def upsample_linear1d(
  3435. input: Tensor,
  3436. output_size: list[int],
  3437. align_corners: bool,
  3438. scales_w: Optional[float] = None,
  3439. ) -> Tensor:
  3440. return _upsample_linear(input, output_size, align_corners, [scales_w])
  3441. @register_decomposition(
  3442. [aten.upsample_bilinear2d.default, aten.upsample_bilinear2d.out]
  3443. )
  3444. @aten.upsample_bilinear2d.default.py_impl(DispatchKey.Autograd)
  3445. @out_wrapper()
  3446. def upsample_bilinear2d(
  3447. input: Tensor,
  3448. output_size: list[int],
  3449. align_corners: bool,
  3450. scales_h: Optional[float] = None,
  3451. scales_w: Optional[float] = None,
  3452. ) -> Tensor:
  3453. return _upsample_linear(input, output_size, align_corners, [scales_h, scales_w])
  3454. @register_decomposition(
  3455. [aten.upsample_trilinear3d.default, aten.upsample_trilinear3d.out]
  3456. )
  3457. @out_wrapper()
  3458. def upsample_trilinear3d(
  3459. input: Tensor,
  3460. output_size: list[int],
  3461. align_corners: bool,
  3462. scales_d: Optional[float] = None,
  3463. scales_h: Optional[float] = None,
  3464. scales_w: Optional[float] = None,
  3465. ) -> Tensor:
  3466. return _upsample_linear(
  3467. input, output_size, align_corners, [scales_d, scales_h, scales_w]
  3468. )
  3469. def _compute_scale(in_size, out_size, align_corners, scale=None):
  3470. if align_corners:
  3471. return (in_size - 1.0) / (out_size - 1.0) if out_size > 1 else 0
  3472. else:
  3473. return 1.0 / scale if scale is not None and scale > 0 else in_size / out_size
  3474. def _compute_source_index(scale, dst_index, align_corners):
  3475. if align_corners:
  3476. return scale * dst_index
  3477. else:
  3478. return scale * (dst_index + 0.5) - 0.5
  3479. def _sum_tensors_uint8(
  3480. src: Iterable[Tensor], weights: Iterable[Tensor], weights_precision: Tensor
  3481. ) -> Tensor:
  3482. output = _sum_tensors(
  3483. s.to(torch.int32) * c.to(torch.int32) for s, c in zip(src, weights)
  3484. ) + (1 << (weights_precision - 1))
  3485. output = output >> weights_precision
  3486. return torch.clamp(output, 0, 255).to(torch.uint8)
  3487. def _compute_weight_precision(weights: TensorSequenceType) -> Tensor:
  3488. max_weight = torch.stack(weights).max()
  3489. max_weight_precision = 22
  3490. precisions = torch.arange(max_weight_precision, device=max_weight.device)
  3491. values = 0.5 + max_weight * (1 << (precisions + 1))
  3492. mask = values >= (1 << 15)
  3493. return max_weight_precision - mask.sum()
@pw_cast_for_opmath
def _upsample_linear(
    input: Tensor,
    output_size: list[int],
    align_corners: bool,
    scales: list[Optional[float]],
) -> Tensor:
    """Shared N-d linear interpolation used by the linear/bilinear/trilinear decomps.

    Args:
        input: tensor whose dim 1 is read as the channel count and whose
            trailing dims (``input.shape[2:]``) are interpolated.
        output_size: target size for each trailing dim, zipped against
            ``input.shape[2:]``.
        align_corners: forwarded to ``_compute_scale`` and
            ``_compute_source_index``.
        scales: optional per-dim scale factors forwarded to ``_compute_scale``.

    Returns:
        The interpolated tensor, made contiguous in the memory format chosen
        below. It is rounded when ``input`` is not floating point.

    Raises:
        AssertionError: if the pairwise reduction does not end with exactly one
            tensor, or that result is not a ``torch.Tensor``.
    """
    # get dimensions of original image
    n_channels = input.shape[1]
    inp_sizes = input.shape[2:]
    n_dims = len(inp_sizes)

    # dtype used for the interpolation math (INT_TO_FLOAT promotion).
    _, dtype = utils.elementwise_dtypes(
        input,
        type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
    )

    def get_values(inp_size, out_size, scales, nsqueeze):
        # For one interpolated dim, return (fractional source coord, lower
        # index, upper index). The lower index truncates the coordinate after it
        # has been clamped at 0. The upper index is clamped to inp_size - 1.
        # Each result gets `nsqueeze` trailing singleton dims so it broadcasts
        # against the later interpolated dims when used as an index.
        # First Calculate scaling factor
        scale_factor = _compute_scale(inp_size, out_size, align_corners, scales)
        # We have to create arange with int64 dtype and use .to in order to avoid
        # additional kernels creation in inductor and get a perf slowdown
        i = torch.arange(out_size, device=input.device).to(dtype=dtype)

        x_f32 = _compute_source_index(scale_factor, i, align_corners).clamp(min=0.0)
        x_f32 = x_f32.reshape(x_f32.shape[0], *[1] * (nsqueeze))
        x = x_f32.to(torch.int64)
        xp1 = (x + 1).clamp(max=inp_size - 1)
        return x_f32, x, xp1

    values = [
        get_values(inp_size, out_size, scales, n_dims - 1 - i)
        for i, (inp_size, out_size, scales) in enumerate(
            zip(inp_sizes, output_size, scales)
        )
    ]
    xs_f32, xs, xp1s = list(zip(*values))

    # Gather the 2**n_dims corner values. itertools.product advances its
    # rightmost element fastest, so neighbouring entries of `vs` differ only in
    # the last interpolated dim.
    vs = []
    for a in product(*[[0, 1]] * n_dims):
        idx = [None, None] + [xs[k] if a[k] == 0 else xp1s[k] for k in range(n_dims)]
        v = aten._unsafe_index(input, idx)
        v = _maybe_convert_to_dtype(v, dtype)
        vs.append(v)

    # Blend neighbouring pairs one dim at a time, innermost dim first, halving
    # `vs` on every pass until a single tensor remains.
    for i in reversed(range(n_dims)):
        xscale = (xs_f32[i] - xs[i]).clamp(0.0, 1.0).to(dtype)
        vs = [
            # x1 * (1 - alpha) + x2 * alpha == x1 + (x2 - x1) * alpha
            v1 + torch.mul(v2 - v1, xscale)
            for v1, v2 in zip(vs[::2], vs[1::2])
        ]

    if len(vs) != 1:
        raise AssertionError(f"Expected vs to have exactly 1 element, got {len(vs)}")
    result = vs[0]

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)

    # following "heuristic: only use channels_last path when it's faster than the contiguous path"
    if input.device.type == "cuda" and n_channels < 16:
        memory_format = torch.contiguous_format

    if not isinstance(result, torch.Tensor):
        raise AssertionError(
            f"Expected result to be a Tensor, got {type(result).__name__}"
        )
    result = result.contiguous(memory_format=memory_format)

    # Integer inputs were interpolated in a floating dtype; round the result.
    if not input.is_floating_point():
        result = result.round()

    return result
  3556. # We should be applying decompositions after all transformations
  3557. @register_decomposition(aten.is_same_size.default)
  3558. def is_same_size(a: Tensor, b: Tensor) -> bool:
  3559. return a.shape == b.shape
  3560. @register_decomposition([aten._reshape_alias, aten._unsafe_view])
  3561. @out_wrapper()
  3562. def _reshape_alias(x, shape, *args):
  3563. return aten.view(x, shape)
  3564. @register_decomposition([aten._unsafe_index])
  3565. def _unsafe_index(x, indices):
  3566. return aten.index(x, indices)
  3567. @register_decomposition([aten._unsafe_index_put])
  3568. def _unsafe_index_put(x, indices, value, accumulate=False):
  3569. return aten.index_put(x, indices, value, accumulate)
  3570. @register_decomposition([aten._unsafe_masked_index])
  3571. def _unsafe_masked_index(x, mask, indices, fill):
  3572. for index in indices:
  3573. if index is not None:
  3574. torch._check(
  3575. index.dtype in [torch.long, torch.int],
  3576. lambda: "tensors used as indices must be long or int tensors",
  3577. )
  3578. torch._check(
  3579. mask.dtype == torch.bool,
  3580. lambda: "tensors used as masks must be bool tensors",
  3581. )
  3582. from torch.fx.experimental.symbolic_shapes import guard_or_false
  3583. if guard_or_false(x.numel() == 0):
  3584. meta_result = torch._meta_registrations.meta_index_Tensor(x, indices)
  3585. return x.new_full(meta_result.shape, fill)
  3586. for i in range(len(indices)):
  3587. index = indices[i]
  3588. if index is not None:
  3589. indices[i] = index.clamp(min=0, max=x.size(i) - 1)
  3590. return aten._unsafe_index(x, indices).masked_fill(~mask, fill)
  3591. @register_decomposition([aten._unsafe_masked_index_put_accumulate])
  3592. def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
  3593. for index in indices:
  3594. if index is not None:
  3595. torch._check(
  3596. index.dtype in [torch.long, torch.int],
  3597. lambda: "tensors used as indices must be long or int tensors",
  3598. )
  3599. torch._check(
  3600. mask.dtype == torch.bool,
  3601. lambda: "tensors used as masks must be bool tensors",
  3602. )
  3603. if x.numel() == 0:
  3604. return x.clone()
  3605. for i in range(len(indices)):
  3606. index = indices[i]
  3607. if index is not None:
  3608. indices[i] = index.clamp(min=-x.size(i), max=x.size(i) - 1)
  3609. masked_value = values.masked_fill(~mask, 0)
  3610. return aten._unsafe_index_put(x, indices, masked_value, accumulate=True)
def _nll_loss_forward(
    self: Tensor,
    target: Tensor,
    weight: Optional[Tensor],
    reduction: int,
    ignore_index: int,
) -> tuple[Tensor, Tensor]:
    """Shared body of ``nll_loss_forward`` and ``nll_loss2d_forward``.

    Gathers ``-self`` at ``target`` along the class dim. The class dim is dim 1,
    or dim 0 when ``self`` has fewer than 2 dims. Values are optionally scaled
    by ``weight``, and entries whose target equals ``ignore_index`` are zeroed.

    Returns:
        ``(result, total_weight)``.
        - ``result`` is left unreduced, summed, or summed and divided by
          ``total_weight``, according to ``reduction``.
        - ``total_weight`` is the sum of the gathered weights of non-ignored
          targets, or their count when ``weight`` is None.
        - For ``Reduction.NONE`` with a ``self`` of 2+ dims, ``total_weight``
          is instead a 0-d zero.
    """
    # self can be [N, C] or [C]
    # target can be [N] or []

    n_dims = self.dim()
    channel_dim = 1
    if n_dims < 2:
        channel_dim = 0

    if weight is not None:
        if n_dims > 1:
            # View weight as all-ones dims except the channel dim so it
            # broadcasts along channel_dim only.
            shape = [
                1,
            ] * n_dims
            shape[channel_dim] = weight.shape[0]
            w = weight.view(shape)
        else:
            w = weight
        # Note: `self` is rebound to the weighted input; the dtype/device of
        # total_weight below derive from this rebound tensor.
        self = self * w
    # Ignored targets are replaced by class 0 so the gather index is valid
    # (ignore_index may lie outside [0, C)); their outputs are masked right after.
    safe_target = torch.where(target != ignore_index, target, 0)
    safe_target_ = safe_target.unsqueeze(channel_dim)
    # target can be [N, 1] or [1]

    result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)

    result = torch.where(target != ignore_index, result, 0)

    if reduction == Reduction.NONE.value and n_dims > 1:
        total_weight = self.new_full((), 0.0)
        return result, total_weight

    if weight is not None:
        # pyrefly: ignore [unbound-name]
        w = w.expand(self.shape)
        wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
        wsum = torch.where(target != ignore_index, wsum, 0)
        total_weight = wsum.sum()
    else:
        total_weight = (target != ignore_index).sum().to(self)

    if reduction == Reduction.SUM.value:
        result = result.sum()
    elif reduction == Reduction.MEAN.value:
        result = result.sum() / total_weight

    return result, total_weight
  3655. @register_decomposition(aten.nll_loss_forward)
  3656. @out_wrapper("output", "total_weight")
  3657. def nll_loss_forward(
  3658. self: Tensor,
  3659. target: Tensor,
  3660. weight: Optional[Tensor],
  3661. reduction: int,
  3662. ignore_index: int,
  3663. ) -> tuple[Tensor, Tensor]:
  3664. if not (self.dim() > 0 and self.dim() <= 2):
  3665. raise AssertionError(f"input tensor should be 1D or 2D, got {self.dim()}D")
  3666. if target.dim() > 1:
  3667. raise AssertionError(
  3668. f"0D or 1D target tensor expected, multi-target not supported, got {target.dim()}D"
  3669. )
  3670. no_batch_dim = self.dim() == 1 and target.dim() == 0
  3671. if not (no_batch_dim or (self.shape[0] == target.shape[0])):
  3672. raise AssertionError(
  3673. f"size mismatch (got input: {self.shape}, target: {target.shape})"
  3674. )
  3675. n_classes = self.shape[-1]
  3676. if weight is not None and not (weight.dim() == 1 and weight.numel() == n_classes):
  3677. raise AssertionError(
  3678. f"weight tensor should be defined either for all {n_classes} classes or no classes "
  3679. f"but got weight tensor of shape: {weight.shape}"
  3680. )
  3681. return _nll_loss_forward(self, target, weight, reduction, ignore_index)
  3682. @register_decomposition(aten.nll_loss2d_forward)
  3683. @out_wrapper("output", "total_weight")
  3684. def nll_loss2d_forward(
  3685. self: Tensor,
  3686. target: Tensor,
  3687. weight: Optional[Tensor],
  3688. reduction: int,
  3689. ignore_index: int,
  3690. ) -> tuple[Tensor, Tensor]:
  3691. return _nll_loss_forward(self, target, weight, reduction, ignore_index)
  3692. # These are adapted from aten/src/ATen/native/UpSample.h, which is based on
  3693. # https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
  3694. def _upsample_cubic_convolution1(x: Tensor, A: float) -> Tensor:
  3695. return ((A + 2) * x - (A + 3)) * x * x + 1
  3696. def _upsample_cubic_convolution2(x: Tensor, A: float) -> Tensor:
  3697. return ((A * x - 5 * A) * x + 8 * A) * x - 4 * A
  3698. def _upsample_get_cubic_coefficients(t: Tensor) -> TensorSequenceType:
  3699. A = -0.75
  3700. if t.device == torch.device("cpu"):
  3701. tt1 = torch.stack([t, 1.0 - t], dim=0)
  3702. tt2 = torch.stack([t + 1.0, 2.0 - t], dim=0)
  3703. w03 = _upsample_cubic_convolution2(tt2, A)
  3704. w12 = _upsample_cubic_convolution1(tt1, A)
  3705. w0, w3 = torch.unbind(w03, dim=0)
  3706. w1, w2 = torch.unbind(w12, dim=0)
  3707. return w0, w1, w2, w3
  3708. else:
  3709. return (
  3710. _upsample_cubic_convolution2(t + 1.0, A),
  3711. _upsample_cubic_convolution1(t, A),
  3712. _upsample_cubic_convolution1(1.0 - t, A),
  3713. _upsample_cubic_convolution2(2.0 - t, A),
  3714. )
  3715. def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
  3716. coeffs2 = _upsample_get_cubic_coefficients(ts)
  3717. return _sum_tensors(c1 * c2 for (c1, c2) in zip(coeffs, coeffs2))
  3718. # Need this instead of just sum() to keep mypy happy
  3719. def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
  3720. return reduce(torch.add, ts)
  3721. def _linspace_from_neg_one(
  3722. num_steps: int, align_corners: bool, dtype: torch.dtype, device: torch.device
  3723. ):
  3724. if num_steps <= 1:
  3725. return torch.tensor(0, device=device, dtype=dtype)
  3726. a = ((num_steps - 1) / num_steps) if not align_corners else 1
  3727. return torch.linspace(-a, a, steps=num_steps, device=device, dtype=dtype)
  3728. def _make_base_grid_4d(theta: Tensor, h: int, w: int, align_corners: bool):
  3729. dtype = theta.dtype
  3730. device = theta.device
  3731. # Using padding and summation generates a single kernel vs using torch.stack where 3 kernels generated
  3732. # corresponding to each individual tensor: grid_x, grid_y, grid_one
  3733. grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, w, 1)
  3734. grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(h, 1, 1)
  3735. grid_one = torch.ones((1, 1, 1), dtype=dtype, device=device)
  3736. # this is just a temporary hack and we should use torch.stack here once #104480 is merged
  3737. grid_x = torch.nn.functional.pad(grid_x, pad=(0, 2), mode="constant", value=0)
  3738. grid_y = torch.nn.functional.pad(grid_y, pad=(1, 1), mode="constant", value=0)
  3739. grid_one = torch.nn.functional.pad(grid_one, pad=(2, 0), mode="constant", value=0)
  3740. return grid_x + grid_y + grid_one
  3741. def _make_base_grid_5d(theta: Tensor, d: int, h: int, w: int, align_corners: bool):
  3742. dtype = theta.dtype
  3743. device = theta.device
  3744. grid_x = _linspace_from_neg_one(w, align_corners, dtype, device).view(1, 1, w, 1)
  3745. grid_y = _linspace_from_neg_one(h, align_corners, dtype, device).view(1, h, 1, 1)
  3746. grid_z = _linspace_from_neg_one(d, align_corners, dtype, device).view(d, 1, 1, 1)
  3747. grid_one = torch.ones((1, 1, 1, 1), dtype=dtype, device=device)
  3748. # this is just a temporary hack and we should use torch.stack here once #104480 is merged
  3749. grid_x = torch.nn.functional.pad(grid_x, pad=(0, 3), mode="constant", value=0)
  3750. grid_y = torch.nn.functional.pad(grid_y, pad=(1, 2), mode="constant", value=0)
  3751. grid_z = torch.nn.functional.pad(grid_z, pad=(2, 1), mode="constant", value=0)
  3752. grid_one = torch.nn.functional.pad(grid_one, pad=(3, 0), mode="constant", value=0)
  3753. return grid_x + grid_y + grid_z + grid_one
  3754. def _affine_grid_generator_4d(theta: Tensor, size: list[int], align_corners: bool):
  3755. n, _, h, w = size
  3756. base_grid = _make_base_grid_4d(theta, h, w, align_corners=align_corners)
  3757. # base_grid shape is (h, w, 3) and theta shape is (n, 2, 3)
  3758. # We do manually a matrix multiplication which is faster than mm()
  3759. # (h * w, 3, 1) * (n, 1, 3, 2) -> (n, h * w, 2)
  3760. grid = (base_grid.view(-1, 3, 1) * theta.mT.unsqueeze(1)).sum(-2)
  3761. return grid.view(n, h, w, 2)
  3762. def _affine_grid_generator_5d(theta: Tensor, size: list[int], align_corners: bool):
  3763. n, _, d, h, w = size
  3764. base_grid = _make_base_grid_5d(theta, d, h, w, align_corners=align_corners)
  3765. # base_grid shape is (d, h, w, 4) and theta shape is (n, 3, 4)
  3766. # We do manually a matrix multiplication which is faster than mm()
  3767. # (d * h * w, 4, 1) * (n, 1, 4, 3) -> (n, h * w, 3)
  3768. grid = (base_grid.view(-1, 4, 1) * theta.mT.unsqueeze(1)).sum(-2)
  3769. return grid.view(n, d, h, w, 3)
  3770. @register_decomposition(aten.affine_grid_generator)
  3771. @out_wrapper()
  3772. @pw_cast_for_opmath
  3773. def affine_grid_generator(theta: Tensor, size: list[int], align_corners: bool):
  3774. torch._check(
  3775. len(size) in (4, 5),
  3776. lambda: "affine_grid_generator needs 4d (spatial) or 5d (volumetric) inputs.",
  3777. )
  3778. if len(size) == 4:
  3779. return _affine_grid_generator_4d(theta, size, align_corners=align_corners)
  3780. else:
  3781. return _affine_grid_generator_5d(theta, size, align_corners=align_corners)
def _grid_sampler_2d(
    a: Tensor,
    grid: Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
    _expand_grid: bool = True,
) -> Tensor:
    """Sample ``a`` at the normalized (x, y) locations stored in ``grid``.

    Args:
        a: input, unpacked as ``(N, C, iH, iW)``.
        grid: sample locations, unpacked as ``(N, oH, oW, 2)``. The last dim
            holds (x, y), which ``unnormalize`` rescales from ``[-1, 1]``.
        interpolation_mode: 0 = bilinear, 1 = nearest, 2 = bicubic.
        padding_mode: 0 = zeros, 1 = border, 2 = reflection.
        align_corners: selects how ``[-1, 1]`` maps onto pixel coordinates.
        _expand_grid: when True, ``grid`` is expanded over the channel dim
            before indices and weights are computed (see the note below).

    Returns:
        Sampled values. The advanced indexing in ``get_summand`` broadcasts
        ``N_idx`` (N, 1, 1, 1) and ``C_idx`` (1, C, 1, 1), producing
        ``(N, C, oH, oW)``.

    Raises:
        ``torch._check`` failure for an unknown interpolation or padding mode;
        ``AssertionError`` if ``grid``'s last dim is not 2.
    """
    # This method is a copy of grid_sampler_2d implementation and introduced with additional arg _expand_grid to
    # optionally expand the input grid for performance reasons.
    # Experimenting locally it was found that compiled CUDA code is accelerated by ~5x
    # and CPU code by ~2x on bicubic mode, if we expand the grid from (N, H, W, 2) into (N, C, H, W, 2)
    # However, this leads to a slowdown around ~0.8x on CPU bilinear mode, channels first.
    # Thus we apply this hack to not expand the grid for this case.

    torch._check(
        interpolation_mode in (0, 1, 2),
        lambda: f"Invalid interpolation mode {interpolation_mode}",
    )
    torch._check(
        padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
    )

    def unnormalize(coords: Tensor, size: int) -> Tensor:
        # Rescale coordinates from [-1, 1] to:
        #   [0, size - 1] if align_corners is True
        #   [-.5, size -.5] if align_corners is False
        mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
        ofs = size * 0.5 - 0.5
        return coords * mul + ofs

    # Reflects coordinates until they fall between low and high (inclusive).
    # The bounds are passed as twice their value so that half-integer values
    # can be represented as ints.
    def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
        if twice_low == twice_high:
            return torch.zeros_like(coords)
        coords_min = twice_low / 2
        coords_span = (twice_high - twice_low) / 2
        coords2 = (coords - coords_min).abs()
        extra = torch.fmod(coords2, coords_span)
        # The parity of the number of whole spans travelled decides between the
        # forward copy (even) and the mirrored copy (odd) of the interval.
        # NOTE(review): only parity is used after the int8 cast; confirm the cast
        # preserves parity for coordinates beyond int8 range.
        flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
        return torch.where(
            flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
        )

    def compute_coordinates(coords: Tensor, size: int) -> Tensor:
        # Apply the padding mode to already-unnormalized pixel coordinates.
        if padding_mode == 0:  # Zero
            # Out-of-bounds taps are handled later by `clip` (weight set to 0).
            return coords
        elif padding_mode == 1:  # Borders
            return torch.clamp(coords, 0, size - 1)
        else:  # padding_mode == 2, Reflection
            if align_corners:
                coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
            else:
                coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
            return torch.clamp(coords_reflected, 0, size - 1)

    def compute_source_index(coords: Tensor, size: int) -> Tensor:
        coords_un = unnormalize(coords, size)
        return compute_coordinates(coords_un, size)

    N, C, iH, iW = a.shape
    _, oH, oW, two = grid.shape
    if two != 2:
        raise AssertionError(
            f"grid last dimension must be 2 (for x,y coords), got {two}"
        )

    if _expand_grid:
        # Let's expand grid to [N, C, oH, oW, 2]
        # This allows to generate a single triton cuda kernel instead of two kernels.
        # Two kernels are due source indices, weights have shape (N, 1, oH, oW), xnumel=N*oH*oW
        # and output has shape (N, C, oH, oW), xnumel=N*C*oH*oW
        # Expanding grid to (N, C, oH, oW, two) unifies xnumel to N*C*oH*oW
        grid = grid.view(N, 1, oH, oW, two).expand(N, C, oH, oW, 2)

    def in_bounds_cond(xs: Tensor, ys: Tensor) -> Tensor:
        # True where 0 <= x < iW and 0 <= y < iH.
        return torch.logical_and(
            0 <= xs, torch.logical_and(xs < iW, torch.logical_and(0 <= ys, ys < iH))
        )

    N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1)
    C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1)

    def clip(xs: Tensor, ys: Tensor, ws: Tensor) -> TensorSequenceType:
        cond = in_bounds_cond(xs, ys)
        # To clip to inside valid coordinates, we map the coordinates
        # to (x, y) = (0, 0) and also set the weight to 0
        # We also change the shape of the tensor to the appropriate one for
        # broadcasting with N_idx, C_idx for the purposes of advanced indexing
        c = C if _expand_grid else 1
        return tuple(
            torch.where(cond, t, 0).view(N, c, oH, oW)
            for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), ws)
        )

    def get_summand(ix: Tensor, iy: Tensor, w) -> Tensor:
        # Perform clipping, index into input tensor and multiply by weight
        idx_x, idx_y, w_ = clip(ix, iy, w)
        return a[N_idx, C_idx, idx_y, idx_x] * w_

    x = grid[..., 0]
    y = grid[..., 1]

    if interpolation_mode == 0:  # Bilinear
        ix = compute_source_index(x, iW)
        iy = compute_source_index(y, iH)

        # Four neighbouring integer taps: north-west/north-east/south-west/south-east.
        ix_nw, iy_nw = ix.floor(), iy.floor()
        ix_ne, iy_ne = ix_nw + 1, iy_nw
        ix_sw, iy_sw = ix_nw, iy_nw + 1
        ix_se, iy_se = ix_ne, iy_sw

        # Each tap's weight is the area of the rectangle opposite to it.
        w_nw = (ix_se - ix) * (iy_se - iy)
        w_ne = (ix - ix_sw) * (iy_sw - iy)
        w_sw = (ix_ne - ix) * (iy - iy_ne)
        w_se = (ix - ix_nw) * (iy - iy_nw)

        return _sum_tensors(
            get_summand(ix, iy, w)
            for (ix, iy, w) in (
                (ix_nw, iy_nw, w_nw),
                (ix_ne, iy_ne, w_ne),
                (ix_sw, iy_sw, w_sw),
                (ix_se, iy_se, w_se),
            )
        )
    elif interpolation_mode == 1:  # Nearest
        ix = compute_source_index(x, iW)
        iy = compute_source_index(y, iH)

        ix_nearest = ix.round()
        iy_nearest = iy.round()

        return get_summand(ix_nearest, iy_nearest, 1)
    else:  # interpolation_mode == 2, Bicubic
        # Padding is applied per tap inside get_value_bounded, not to the centre.
        ix = unnormalize(x, iW)
        iy = unnormalize(y, iH)

        ix_nw = ix.floor()
        iy_nw = iy.floor()

        tx = ix - ix_nw
        ty = iy - iy_nw

        if not _expand_grid:
            # Insert the channel dim so tx/ty broadcast against the
            # (N, C, oH, oW) tap values.
            tx = tx.unsqueeze(1)
            ty = ty.unsqueeze(1)

        def get_value_bounded(ix: Tensor, iy: Tensor) -> Tensor:
            x = compute_coordinates(ix, iW)
            y = compute_coordinates(iy, iH)
            return get_summand(x, y, 1)

        def get_coeff(ofs: int) -> Tensor:
            # Interpolate row iy_nw + (ofs - 1) horizontally over 4 taps.
            iy_ofs = iy_nw + (ofs - 1)
            cs = (
                get_value_bounded(ix_nw - 1, iy_ofs),
                get_value_bounded(ix_nw, iy_ofs),
                get_value_bounded(ix_nw + 1, iy_ofs),
                get_value_bounded(ix_nw + 2, iy_ofs),
            )
            return _upsample_cubic_interp1d(cs, tx)

        # Then interpolate the 4 row results vertically.
        coeffs = tuple(get_coeff(ofs) for ofs in range(4))
        return _upsample_cubic_interp1d(coeffs, ty)
  3925. @register_decomposition(aten.grid_sampler_2d)
  3926. @out_wrapper()
  3927. @pw_cast_for_opmath
  3928. def grid_sampler_2d(
  3929. a: Tensor,
  3930. grid: Tensor,
  3931. interpolation_mode: int = 0,
  3932. padding_mode: int = 0,
  3933. align_corners: bool = False,
  3934. ) -> Tensor:
  3935. return _grid_sampler_2d(
  3936. a,
  3937. grid=grid,
  3938. interpolation_mode=interpolation_mode,
  3939. padding_mode=padding_mode,
  3940. align_corners=align_corners,
  3941. )
  3942. @register_decomposition(aten.mv)
  3943. @out_wrapper(exact_dtype=True)
  3944. @pw_cast_for_opmath
  3945. def mv(self, vec):
  3946. torch._check(
  3947. self.dim() == 2 and vec.dim() == 1,
  3948. lambda: f"matrix @ vector expected, got {self.dim()}, {vec.dim()}",
  3949. )
  3950. torch._check(
  3951. self.size(1) == vec.size(0),
  3952. lambda: f"size mismatch, got input ({self.size(0)}x{self.size(1)}), vec ({vec.size(0)})",
  3953. )
  3954. return (self * vec).sum(dim=1)
  3955. @register_decomposition(aten.binary_cross_entropy_with_logits)
  3956. @out_wrapper()
  3957. def binary_cross_entropy_with_logits(
  3958. self, target, weight=None, pos_weight=None, reduction=Reduction.MEAN.value
  3959. ):
  3960. if pos_weight is not None:
  3961. log_weight = (pos_weight - 1) * target + 1
  3962. loss = (1 - target) * self - (log_weight * F.logsigmoid(self))
  3963. else:
  3964. loss = (1 - target) * self - F.logsigmoid(self)
  3965. if weight is not None:
  3966. loss = loss * weight
  3967. return apply_loss_reduction(loss, reduction)
  3968. def should_fold(tensor1: torch.Tensor, tensor2: torch.Tensor, is_out: bool) -> bool:
  3969. # For comments of the logic of this function see eager in /native/LinearAlgebra.cpp
  3970. t1, t2 = (tensor1, tensor2) if tensor1.ndim >= tensor2.ndim else (tensor2, tensor1)
  3971. from torch.fx.experimental.symbolic_shapes import guard_or_false
  3972. if not (t1.ndim >= 3 and t2.ndim <= 2):
  3973. return False
  3974. if t2.requires_grad and not is_out:
  3975. return True
  3976. if tensor1.ndim == 2:
  3977. return False
  3978. if guard_or_false(t1.numel() == 0):
  3979. return True
  3980. t1_shape = t1.shape
  3981. t1_stride = t1.stride()
  3982. # Check the contiguous, we can skip the dim with size of 1
  3983. # as aten: https://github.com/pytorch/pytorch/blob/e201460f8aa1510b4c4686627d57b69756c4b916/aten/src/ATen/TensorGeometry.cpp#L17
  3984. expected_stride = [1]
  3985. for size in reversed(t1_shape[1:]):
  3986. expected_stride.append(size * expected_stride[-1])
  3987. return all(
  3988. guard_or_false(size == 1) or guard_or_false(left == right)
  3989. for left, right, size in zip(
  3990. t1_stride, list(reversed(expected_stride)), t1_shape
  3991. )
  3992. )
@aten.matmul.default.py_impl(DispatchKey.CompositeImplicitAutograd)
@aten.matmul.out.py_impl(DispatchKey.CompositeImplicitAutograd)
@out_wrapper(pass_is_out=True)
def matmul(tensor1, tensor2, *, is_out=False):
    """Decompose ``matmul`` by dispatching on the operand ranks.

    Routing:
    - 1-D @ 1-D -> ``dot``.
    - 2-D @ 1-D -> ``mv``.
    - 1-D @ 2-D -> ``mm`` with a temporary leading dim.
    - 2-D @ 2-D -> ``mm``.
    - When ``should_fold`` allows it, a folded ``mm``/``mv``.
    - Otherwise, a broadcasted ``bmm`` over the expanded batch dims.

    Args:
        tensor1, tensor2: operands; 0-d operands are rejected.
        is_out: forwarded to ``should_fold``. Presumably set by
            ``out_wrapper(pass_is_out=True)`` for ``out=`` calls — NOTE(review):
            confirm.

    Raises:
        AssertionError: if either operand is 0-dimensional.
    """
    from torch.fx.experimental.symbolic_shapes import guard_or_false, guard_or_true

    dim_tensor1 = tensor1.dim()
    dim_tensor2 = tensor2.dim()
    if dim_tensor1 == 0 or dim_tensor2 == 0:
        raise AssertionError(
            f"matmul does not support 0-dimensional tensors, got dims: {dim_tensor1} and {dim_tensor2}"
        )

    if dim_tensor1 == 1 and dim_tensor2 == 1:
        return torch.dot(tensor1, tensor2)
    elif dim_tensor1 == 2 and dim_tensor2 == 1:
        return torch.mv(tensor1, tensor2)
    elif dim_tensor1 == 1 and dim_tensor2 == 2:
        return torch.squeeze(torch.mm(torch.unsqueeze(tensor1, 0), tensor2), 0)
    elif dim_tensor1 == 2 and dim_tensor2 == 2:
        return torch.mm(tensor1, tensor2)
    elif should_fold(tensor1, tensor2, is_out):
        # dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
        # dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
        # and some condition on the strides is fulfilled

        # optimization: use mm instead of bmm by folding the batch of the larger tensor
        # into its leading matrix dimension
        # When tensor2 is the larger operand, compute (tensor2^T @ tensor1^T)
        # and transpose the result back at the end.
        transpose = dim_tensor2 > dim_tensor1
        t1 = tensor2.mT if transpose else tensor1
        t2 = (
            tensor2 if not transpose else (tensor1.t() if dim_tensor1 == 2 else tensor1)
        )
        # Invariant: t1.dim() >= 3 && (t2.dim() == 1 || t2.dim() == 2)
        #            and t1 and t2 are matmul-compatible

        # Why not t1.view(-1, sizes_1[-1])?
        # If the last dim is 0, then view(-1, 0) won't work because the -1 becomes ambiguous.
        # This can happen in e.g. [3, 5, 0] @ [0, 0].
        sizes_1 = t1.shape
        output_shape = list(sizes_1[:-1])
        folded_dim1 = reduce(operator.mul, output_shape)

        # Readjust output_shape if we are multiplying by a matrix
        t2_is_matrix = t2.dim() == 2
        if t2_is_matrix:
            output_shape.append(t2.shape[1])

        # This will almost always be a view.
        # It may not be a view if t2->requires_grad(). See should_fold in aten/ for an explanation
        t1_folded = t1.reshape(folded_dim1, sizes_1[-1])
        if t2_is_matrix:
            # This copies if we perform a 2D @ 3D and the first tensor requires_grad
            # See should_fold native/LinearAlgebra.cpp for why.
            output = torch.ops.aten._unsafe_view(t1_folded.mm(t2), output_shape)
            return output.mT.contiguous() if transpose else output
        else:
            return torch.ops.aten._unsafe_view(t1_folded.mv(t2), output_shape)

    elif dim_tensor1 >= 1 and dim_tensor2 >= 1:
        # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
        # we track m1 vs m2 separately even though they must match for nicer error messages
        n = tensor1.size(-2) if dim_tensor1 > 1 else 1
        m1 = tensor1.size(-1)
        batch_tensor1 = tensor1.shape[:-2]
        m2 = tensor2.size(-2) if dim_tensor2 > 1 else tensor2.size(-1)
        p = tensor2.size(-1) if dim_tensor2 > 1 else 1

        batch_tensor2: list[int] = []
        # TODO: handling of slice
        for i in range(dim_tensor2 - 2):
            batch_tensor2.append(tensor2.size(i))

        # Same optimization for the gradients as that in should_fold
        # If we're going to broadcast, we force it to go through the should_fold branch
        # guard_or_true / guard_or_false supply the fallback answer used when a
        # symbolic size comparison cannot be decided.
        if (
            dim_tensor1 == 3
            and dim_tensor2 == 3
            and guard_or_true(batch_tensor1[0] != batch_tensor2[0])
        ):
            if guard_or_false(batch_tensor1[0] == 1) and tensor1.requires_grad:
                return matmul(tensor1.squeeze(0), tensor2)
            if guard_or_false(batch_tensor2[0] == 1) and tensor2.requires_grad:
                return matmul(tensor1, tensor2.squeeze(0))

        # expand the batch portion (i.e. cut off matrix dimensions and expand rest)
        expand_batch_portion = list(
            torch.broadcast_shapes(batch_tensor1, batch_tensor2)
        )

        tensor1_expand_size = expand_batch_portion + [n, m1]

        expand_batch_product = prod(expand_batch_portion)

        # HACK: We need reshape with symint support
        tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(
            expand_batch_product, n, m1
        )

        vector_rhs = dim_tensor2 == 1
        if vector_rhs:
            # Treat the vector as an (m2, 1) matrix per batch for bmm.
            tensor2_expand_size = expand_batch_portion + [m2]
            tensor2_expanded = (
                tensor2.expand(tensor2_expand_size)
                .reshape(expand_batch_product, m2)
                .unsqueeze(2)
            )
        else:
            tensor2_expand_size = expand_batch_portion + [m2, p]
            tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(
                expand_batch_product, m2, p
            )

        # Note: aliases (and then extends) expand_batch_portion, which is not
        # used again below.
        output_shape = expand_batch_portion
        if dim_tensor1 > 1:
            output_shape.append(n)

        if dim_tensor2 > 1:
            output_shape.append(p)

        if vector_rhs:
            return tensor1_expanded.bmm(tensor2_expanded).squeeze(-1).view(output_shape)
        else:
            return tensor1_expanded.bmm(tensor2_expanded).view(output_shape)
    else:
        torch._check(False, lambda: "both arguments to matmul need to be at least 1D")
@register_decomposition([aten.upsample_bicubic2d.default, aten.upsample_bicubic2d.out])
@aten.upsample_bicubic2d.default.py_impl(DispatchKey.Autograd)
@out_wrapper()
@pw_cast_for_opmath
def upsample_bicubic2d_default(
    input: Tensor,
    output_size: tuple[int, int],
    align_corners: bool,
    scale_h: Optional[float] = None,
    scale_w: Optional[float] = None,
) -> Tensor:
    """Decompose ``aten.upsample_bicubic2d`` into index gathers and pointwise ops.

    Each output pixel is a separable bicubic blend of a 4x4 neighbourhood of
    input pixels. For each of the four source rows, the four horizontal taps
    are blended with ``weights_x``. The four row results are then blended with
    ``weights_y``. Taps that fall outside the image are clamped to the nearest
    edge pixel.

    Args:
        input: 4-D tensor; the last two dims are treated as (height, width).
        output_size: ``(out_h, out_w)`` of the result.
        align_corners: forwarded to ``_compute_scale`` and
            ``_compute_source_index``.
        scale_h, scale_w: optional explicit scales, forwarded to
            ``_compute_scale``.

    Returns:
        The resized tensor, made contiguous in the memory format suggested
        for ``input``.

    Raises:
        AssertionError: internal invariant violation on the ``uint8`` path
            (fixed-point precision was not computed).
    """
    # get dimensions of original image
    _, _, in_h, in_w = input.shape

    # Calculate horizontal and vertical scaling factor
    h_scale_factor = _compute_scale(in_h, output_size[0], align_corners, scale_h)
    w_scale_factor = _compute_scale(in_w, output_size[1], align_corners, scale_w)

    # Dtype used for the coordinate arithmetic (integer inputs promote to float).
    _, dtype = utils.elementwise_dtypes(
        input, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    )

    # We have to create arange with int64 dtype and use .to in order to avoid
    # additional kernels creation in inductor and get a perf slowdown
    i = torch.arange(output_size[0], device=input.device).to(dtype=dtype)
    j = torch.arange(output_size[1], device=input.device).to(dtype=dtype)

    # Fractional source coordinates for every output column (x) and row (y).
    # y is turned into a column vector so that x and y broadcast against each other.
    x_float = _compute_source_index(w_scale_factor, j, align_corners)
    y_float = _compute_source_index(h_scale_factor, i, align_corners)
    y_float = y_float.unsqueeze(-1)

    x = x_float.floor()
    y = y_float.floor()

    # We should also clamp xscale/yscale
    # See guard_index_and_lambda in UpSample.h
    yscale = (y_float - y).clamp(0.0, 1.0)
    xscale = (x_float - x).clamp(0.0, 1.0)
    x = x.to(torch.int64)
    y = y.to(torch.int64)

    # The four integer tap positions around each source coordinate.
    iys_ofs = (y - 1, y, y + 1, y + 2)
    ixs_ofs = (x - 1, x, x + 1, x + 2)

    weights_x = _upsample_get_cubic_coefficients(xscale)
    weights_y = _upsample_get_cubic_coefficients(yscale)

    weights_precision_x, weights_precision_y = None, None
    if input.dtype == torch.uint8:
        # uint8 path: convert the float weights to int16 fixed point with
        # weights_precision_* fractional bits. Adding sign(w) * 0.5 before the
        # truncating .to(int16) rounds half away from zero.
        weights_precision_x = _compute_weight_precision(weights_x)
        weights_precision_y = _compute_weight_precision(weights_y)

        weights_x = [
            (w * (1 << weights_precision_x) + torch.sign(w) * 0.5).to(torch.int16)
            for w in weights_x
        ]
        weights_y = [
            (w * (1 << weights_precision_y) + torch.sign(w) * 0.5).to(torch.int16)
            for w in weights_y
        ]

    def load_bounded(ys, xs):
        # Gather input[..., ys, xs] with both indices clamped into the image,
        # so out-of-range taps read the nearest edge pixel.
        y_idx = torch.clamp(ys, 0, in_h - 1)
        x_idx = torch.clamp(xs, 0, in_w - 1)
        v = aten._unsafe_index(input, [None, None, y_idx, x_idx])
        return v

    def get_x_interp(y):
        # Horizontal pass: blend the four x-taps of source row(s) y.
        src_x = tuple(load_bounded(y, x_ofs) for x_ofs in ixs_ofs)
        if input.dtype == torch.uint8:
            if weights_precision_x is None:
                raise AssertionError(
                    "weights_precision_x must not be None for uint8 input"
                )
            return _sum_tensors_uint8(src_x, weights_x, weights_precision_x)
        return _sum_tensors(c1 * c2 for (c1, c2) in zip(src_x, weights_x))

    # Vertical pass: blend the four horizontally-interpolated rows.
    src_y = tuple(get_x_interp(y_ofs) for y_ofs in iys_ofs)
    if input.dtype == torch.uint8:
        if weights_precision_y is None:
            raise AssertionError("weights_precision_y must not be None for uint8 input")
        result = _sum_tensors_uint8(src_y, weights_y, weights_precision_y)
    else:
        result = _sum_tensors(c1 * c2 for (c1, c2) in zip(src_y, weights_y))

    # convert output to correct memory format, if necessary
    memory_format = utils.suggest_memory_format(input)
    result = result.contiguous(memory_format=memory_format)
    return result
  4177. @register_decomposition(aten.upsample_bicubic2d.vec)
  4178. @aten.upsample_bicubic2d.vec.py_impl(DispatchKey.CompositeImplicitAutograd)
  4179. @aten.upsample_bicubic2d.vec.py_impl(DispatchKey.Autograd)
  4180. @out_wrapper()
  4181. @pw_cast_for_opmath
  4182. def upsample_bicubic2d_vec(
  4183. a: Tensor,
  4184. output_size: Optional[tuple[int, int]],
  4185. align_corners: bool,
  4186. scale_factors: Optional[tuple[float, float]] = None,
  4187. ) -> Tensor:
  4188. torch._check(
  4189. bool(output_size) + bool(scale_factors) == 1,
  4190. lambda: "Must specify exactly one of output_size and scale_factors.",
  4191. )
  4192. if output_size is None:
  4193. if scale_factors is None:
  4194. raise AssertionError(
  4195. "scale_factors must not be None when output_size is None"
  4196. )
  4197. output_size = cast(
  4198. tuple[int, int],
  4199. tuple(
  4200. sym_int(sym_float(w) * scale)
  4201. for w, scale in zip(a.shape[2:], scale_factors)
  4202. ),
  4203. )
  4204. scale_h, scale_w = scale_factors if scale_factors else (None, None)
  4205. return upsample_bicubic2d_default(a, output_size, align_corners, scale_h, scale_w)
  4206. @register_decomposition(aten.reflection_pad1d)
  4207. @register_decomposition(aten.reflection_pad2d)
  4208. @register_decomposition(aten.reflection_pad3d)
  4209. @pw_cast_for_opmath
  4210. @out_wrapper()
  4211. def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
  4212. def idx(left, middle, right):
  4213. dim_idx = torch.arange(-left, middle + right, device=a.device)
  4214. return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
  4215. return _reflection_or_replication_pad(
  4216. a,
  4217. padding,
  4218. idx,
  4219. )
  4220. @register_decomposition(aten.replication_pad1d)
  4221. @register_decomposition(aten.replication_pad2d)
  4222. @register_decomposition(aten.replication_pad3d)
  4223. @pw_cast_for_opmath
  4224. @out_wrapper()
  4225. def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
  4226. def idx(left, middle, right):
  4227. dim_idx = torch.arange(-left, middle + right, device=a.device)
  4228. return torch.clamp(dim_idx, 0, middle - 1)
  4229. return _reflection_or_replication_pad(
  4230. a,
  4231. padding,
  4232. idx,
  4233. )
  4234. def _reflection_or_replication_pad(
  4235. a: Tensor,
  4236. padding: tuple[int, ...],
  4237. idx_fn: Callable[[int, int, int], Tensor],
  4238. ) -> Tensor:
  4239. dim = len(padding) // 2
  4240. torch._check(
  4241. a.dim() in (dim + 1, dim + 2),
  4242. lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
  4243. )
  4244. inp_shape = a.shape[-dim:]
  4245. nc_dim = a.dim() - dim
  4246. padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
  4247. padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
  4248. result = a
  4249. for i in range(dim):
  4250. idx: list[Any] = [None] * result.dim()
  4251. idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
  4252. result = aten._unsafe_index(result, idx)
  4253. # convert output to correct memory format, if necessary
  4254. memory_format = utils.suggest_memory_format(result)
  4255. result = result.contiguous(memory_format=memory_format)
  4256. return result
@register_decomposition(aten.reflection_pad1d_backward)
@register_decomposition(aten.reflection_pad2d_backward)
@register_decomposition(aten.reflection_pad3d_backward)
@out_wrapper("grad_input")
def _reflection_pad_backward(grad_output, x, padding):
    """Backward of ``reflection_pad{1,2,3}d``.

    The forward pad copies each input position to its own (center) output
    position. It may also copy it into the mirrored border regions. The
    gradient for an input position is therefore the sum of the ``grad_output``
    entries at every padded-output position that maps back to it.

    That sum is built as one masked gather from ``grad_output`` per region of
    the 3**dim grid (left / center / right along each padded dim). Each mask
    restricts the gather to input indices that actually appear in that region.

    Args:
        grad_output: gradient w.r.t. the padded output.
        x: the forward input; only its shape, ndim and device are used.
        padding: ``(left, right)`` pairs, listed starting from the last dim.

    Returns:
        Gradient w.r.t. ``x``. The gather indices are built from ``x.shape``.
    """
    dim = len(padding) // 2

    # dhw[i]: last valid input index along the i-th padded dim.
    dhw = [h - 1 for h in x.shape[-dim:]]

    # padding lists pairs from the last dim backwards; reorder so that index i
    # refers to the i-th padded dim counting forward.
    padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
    padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]

    # One broadcastable arange per dim of x: indices[i] has size x.shape[i]
    # along dim i and size 1 everywhere else.
    indices = []
    for i in range(x.ndim):
        view_shape = [1] * x.ndim
        view_shape[i] = -1
        indices.append(torch.arange(x.shape[i], device=x.device).view(view_shape))

    # b: leading (unpadded) dims, gathered 1:1; xyz: the padded dims.
    b = indices[:-dim]
    xyz = indices[-dim:]

    def index_range_condition(index_range):
        # (index, lower, upper) -> boolean mask of lower <= index <= upper.
        i, lb, ub = index_range
        return torch.logical_and(i >= lb, i <= ub)

    # Areas after reflection:
    #
    #   top-left    |   top     |   top-right
    # -----------------------------------------
    #   left        |   center  |   right
    # -----------------------------------------
    #   bottom-left |   bottom  |   bottom-right
    #
    # The center area is the original matrix. Other areas are reflections.

    # For input index xyz[i], the grad_output position it was written to in each area.
    center = [xyz[i] + padding_left[i] for i in range(dim)]
    left_reflect = [padding_left[i] - xyz[i] for i in range(dim)]
    right_reflect = [2 * dhw[i] + padding_left[i] - xyz[i] for i in range(dim)]

    # Accumulate gradients from different areas
    # If some of the padding is negative, center load is not always valid
    range_c = [
        (center[i], 0, dhw[i] + padding_left[i] + padding_right[i]) for i in range(dim)
    ]
    cond = functools.reduce(
        aten.logical_and, [index_range_condition(range_c[i]) for i in range(dim)]
    )
    grad = aten._unsafe_masked_index(grad_output, cond, b + center, 0.0)

    def accumulate(grad, out, index_ranges):
        # If the upper bound is less than the lower bound, we can get rid of one accumulation.
        # This happens when the padding size is zero.
        # Only comparisons that evaluate to a plain Python bool can skip the
        # accumulation. Non-bool results fall through to the masked gather below.
        for i in range(dim):
            upper_less_than_lower = index_ranges[i][2] < index_ranges[i][1]
            if isinstance(upper_less_than_lower, bool) and upper_less_than_lower:
                return grad

        cond = functools.reduce(
            aten.logical_and,
            [index_range_condition(index_range) for index_range in index_ranges],
        )
        g = aten._unsafe_masked_index(grad_output, cond, b + out, 0.0)
        return grad + g

    # Visit every non-center combination of (left, center, right) per padded dim.
    for area in itertools.product(*[[-1, 0, 1] for _ in range(dim)]):
        if area == tuple([0] * dim):
            # center, this is already done.
            continue

        outs = []
        index_ranges = []

        for i in range(dim):
            if area[i] == 0:
                out = center[i]
                index_range = range_c[i]
            elif area[i] == -1:
                # Left border mirrors input indices 1..padding_left.
                out = left_reflect[i]
                index_range = (xyz[i], 1, padding_left[i])
            elif area[i] == 1:
                # Right border mirrors input indices dhw - padding_right .. dhw - 1.
                out = right_reflect[i]
                index_range = (xyz[i], dhw[i] - padding_right[i], dhw[i] - 1)

            outs.append(out)  # type: ignore[possibly-undefined]
            index_ranges.append(index_range)  # type: ignore[possibly-undefined]

        grad = accumulate(grad, outs, index_ranges)

    return grad
  4330. @register_decomposition(aten.aminmax)
  4331. @out_wrapper("min", "max")
  4332. def aminmax(self, *, dim=None, keepdim=False):
  4333. # pyrefly: ignore [bad-argument-type]
  4334. amin = torch.amin(self, dim=dim, keepdim=keepdim)
  4335. # pyrefly: ignore [bad-argument-type]
  4336. amax = torch.amax(self, dim=dim, keepdim=keepdim)
  4337. return amin, amax
  4338. @register_decomposition(aten.nansum)
  4339. @out_wrapper()
  4340. def nansum(self, dim=None, keepdim=False, *, dtype=None):
  4341. return aten.sum(torch.where(torch.isnan(self), 0, self), dim, keepdim, dtype=dtype)
  4342. @register_decomposition([aten.arange.default, aten.arange.out])
  4343. @out_wrapper()
  4344. def arange_default(
  4345. end: NumberType,
  4346. *,
  4347. dtype: Optional[torch.dtype] = None,
  4348. layout: torch.layout = torch.strided,
  4349. device: Optional[torch.device] = None,
  4350. pin_memory: bool = False,
  4351. ):
  4352. return aten.arange.start_step(
  4353. 0, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  4354. )
  4355. @register_decomposition([aten.arange.start])
  4356. def arange_start(
  4357. start: NumberType,
  4358. end: NumberType,
  4359. *,
  4360. dtype: Optional[torch.dtype] = None,
  4361. layout: torch.layout = torch.strided,
  4362. device: Optional[torch.device] = None,
  4363. pin_memory: bool = False,
  4364. ):
  4365. return aten.arange.start_step(
  4366. start, end, 1, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  4367. )
  4368. @register_decomposition(out_dtype)
  4369. def out_dtype_decomp(*args, **kwargs):
  4370. from torch._higher_order_ops.out_dtype import out_dtype_dense
  4371. return out_dtype_dense(*args, **kwargs)
  4372. @register_decomposition(aten.multi_margin_loss)
  4373. @aten.multi_margin_loss.default.py_impl(DispatchKey.Autograd)
  4374. @out_wrapper()
  4375. def multi_margin_loss(
  4376. input: Tensor,
  4377. target: Tensor,
  4378. p: NumberType = 1,
  4379. margin: NumberType = 1,
  4380. weight: Optional[Tensor] = None,
  4381. reduction: int = Reduction.MEAN.value,
  4382. ) -> Tensor:
  4383. input = torch.atleast_2d(input)
  4384. target = torch.atleast_1d(target)
  4385. nframe = input.shape[0]
  4386. dim = input.shape[1]
  4387. torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
  4388. torch._check(
  4389. input.ndim == 2 and dim != 0,
  4390. lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {input.shape}",
  4391. )
  4392. torch._check(
  4393. target.ndim == 1 and target.numel() == nframe,
  4394. lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
  4395. )
  4396. if weight is not None:
  4397. weight = torch.atleast_1d(weight)
  4398. torch._check(
  4399. weight.ndim == 1 and weight.numel() == dim, # type: ignore[union-attr]
  4400. lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}", # type: ignore[union-attr]
  4401. )
  4402. target = target.unsqueeze(1)
  4403. u = torch.gather(input, dim=1, index=target)
  4404. z = margin - u + input
  4405. z = z.clamp_min(0)
  4406. z = z if p == 1 else z * z
  4407. if weight is not None:
  4408. z = z * weight[target]
  4409. idx = torch.arange(dim, device=input.device)
  4410. z = torch.where(idx != target, z, 0)
  4411. if reduction == Reduction.MEAN.value:
  4412. return z.mean()
  4413. elif reduction == Reduction.SUM.value:
  4414. return z.sum() / z.shape[1]
  4415. else:
  4416. return z.mean(dim=1)
  4417. @register_decomposition(aten.multilabel_margin_loss_forward)
  4418. @aten.multilabel_margin_loss_forward.default.py_impl(DispatchKey.Autograd)
  4419. @out_wrapper("output", "is_target")
  4420. def multilabel_margin_loss_forward(
  4421. input: Tensor,
  4422. target: Tensor,
  4423. reduction: int,
  4424. ) -> tuple[Tensor, Tensor]:
  4425. orig_input_shape = input.shape
  4426. orig_target_shape = target.shape
  4427. input = torch.atleast_2d(input)
  4428. target = torch.atleast_2d(target)
  4429. dim = input.shape[1]
  4430. torch._check(
  4431. len(orig_input_shape) <= 2 and dim != 0,
  4432. lambda: f"Expected non-empty vector or matrix with optional 0-dim batch size, but got: {orig_input_shape}",
  4433. )
  4434. torch._check(
  4435. len(orig_target_shape) <= 2 and orig_target_shape == orig_input_shape,
  4436. lambda: f"inconsistent target size: {orig_target_shape} for input of size: {orig_input_shape}",
  4437. )
  4438. # ignores labels after the first -1, detects when -1 is not present
  4439. idx = torch.arange(dim, device=target.device)
  4440. is_end = target == -1
  4441. end_idx = torch.amin(torch.where(is_end, idx, dim), dim=-1, keepdim=True)
  4442. # target indices
  4443. target_mask = idx < end_idx
  4444. # masks target to be able to use gather, which doesn't allow -1
  4445. tidx0 = torch.where(target_mask, target, 0)
  4446. u = torch.gather(input, dim=-1, index=tidx0)
  4447. # is_target
  4448. tidx1 = torch.where(target_mask, target, -1)
  4449. is_target = torch.any(idx == tidx1.unsqueeze(dim=-1), dim=1)
  4450. # loss
  4451. z = 1.0 - u.T.unsqueeze(dim=-1) + input
  4452. z = z.clamp_min(0)
  4453. z = z / dim
  4454. # masks loss
  4455. z = torch.where(is_target, 0, z)
  4456. # reduction
  4457. if reduction == Reduction.MEAN.value:
  4458. z = z.sum(dim=(0, -1)).mean()
  4459. elif reduction == Reduction.SUM.value:
  4460. z = z.sum()
  4461. else:
  4462. z = z.sum(dim=(0, -1))
  4463. # result
  4464. is_target = is_target.to(input.dtype).reshape(orig_target_shape)
  4465. return z, is_target
# scaled_dot_product_attention used to be decomposed in pre-autograd, given that
# it calls _scaled_dot_product_attention_math and
# _scaled_dot_product_attention_math only has a CompositeImplicitAutograd
# kernel. As a result it's decomposed into ops with finer granularity.
# However, recent PRs (#103826 #105131 #115913) added new logic in
# scaled_dot_product_attention, and it now calls
# _scaled_dot_product_flash_attention_for_cpu in the export path. As a result,
# _scaled_dot_product_flash_attention_for_cpu shows up in the export result.
# This decomposition ensures scaled_dot_product_attention is still decomposed
# the same way as before, i.e., going through
# _scaled_dot_product_attention_math. Notice that this decomp rule should be
# excluded by inductor.
@register_decomposition(aten._scaled_dot_product_flash_attention_for_cpu.default)
def scaled_dot_product_flash_attention_for_cpu(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    attn_mask: Optional[Tensor] = None,
    scale: Optional[float] = None,
) -> tuple[Tensor, Tensor]:
    """Decompose the CPU flash-attention op via ``_scaled_dot_product_attention_math``.

    Validates the inputs: floating-point query, 4-D q/k/v, zero dropout, and
    equal sizes on dim 3 ("head size"). It then runs the math kernel, with
    ``enable_gqa`` set when query and key differ in size on dim 1.

    Returns:
        ``(output, attn)`` from the math kernel. ``output`` keeps its shape but
        is re-strided to match the flash kernel's layout (see comment below).
    """
    torch._check(
        torch.is_floating_point(query),
        lambda: f"query must be FP32, FP64, BF16, FP16 but got {query.dtype}",
    )
    torch._check(
        query.dim() == 4 and key.dim() == 4 and value.dim() == 4,
        lambda: f"q, k, v must be a 4 dimensional tensor, got {query.dim()}, {key.dim()}, {value.dim()}",
    )
    torch._check(
        dropout_p == 0.0, lambda: f"dropout probability must be zero, got {dropout_p}"
    )
    torch._check(
        query.shape[3] == value.shape[3] and key.shape[3] == value.shape[3],
        lambda: "q, k, v should have the same head size",
    )

    output, attn = aten._scaled_dot_product_attention_math.default(
        query,
        key,
        value,
        attn_mask=attn_mask,
        dropout_p=dropout_p,
        is_causal=is_causal,
        dropout_mask=None,
        scale=scale,
        enable_gqa=query.size(1) != key.size(1),
    )
    # Why this change?
    # In pre-dispatch export, scaled_dot_product_attention is executed via
    # * flash_attention.
    # flash_attention allocates the output tensor as (N, H, L, E) (see PR #134656)
    # assume x: [N, H, L, E] is the sdpa output
    # In MHA code, this output is then permuted via (2, 0, 1, 3) to get
    # (L, N, H, E) dim tensor
    # x = x.permute(2, 0, 1, 3).contiguous() and then viewed via
    # x = x.view(L * N, H * E)
    # During the pre-autograd dispatch, the call to contiguous is not traced because
    # the flash_attention output after x.permute is already contiguous,
    # so the view is valid on it.
    # However, during the 2nd stage of export (post-dispatch), we run the _math variant
    # instead of flash* to get the decomposition. The _math variant returns
    # x: [N, H, L, E]; applying x.permute(2, 0, 1, 3) returns
    # x: [L, N, H, E], and without converting this to a contiguous tensor
    # the subsequent view is not valid and the export fails.
    # The solution is to make the tensor returned by the decomp have
    # exactly the same layout as the *flash* variant's output.
    # Really the invariant you want to maintain is:
    # the pre-dispatch op output and its decomposed representation must
    # return tensors with the same view and dims.
    # Net effect: same [N, H, L, E] shape, with strides of a contiguous
    # [L, N, H, E] tensor.
    output = (
        output.permute(2, 0, 1, 3)
        .contiguous(memory_format=torch.contiguous_format)
        .permute(1, 2, 0, 3)
    )
    return output, attn
  4543. def register_inplace(aten_op, outplace_op):
  4544. @register_decomposition(aten_op)
  4545. def inplace_op(*args, **kwargs):
  4546. out = outplace_op(*args, **kwargs)
  4547. return args[0].copy_(out)
  4548. return inplace_op
  4549. @register_decomposition([aten.baddbmm])
  4550. @out_wrapper(exact_dtype=True)
  4551. @pw_cast_for_opmath
  4552. def baddbmm(self, batch1, batch2, beta=1, alpha=1):
  4553. if not self.is_floating_point() and not self.is_complex():
  4554. beta = int(beta)
  4555. alpha = int(alpha)
  4556. result = torch.bmm(batch1, batch2)
  4557. if not isinstance(alpha, numbers.Number) or alpha != 1:
  4558. result = result * alpha
  4559. if beta == 0:
  4560. return result
  4561. if not isinstance(beta, numbers.Number) or beta != 1:
  4562. self = self * beta
  4563. return self + result
  4564. @register_decomposition(aten.floor_divide)
  4565. @out_wrapper()
  4566. def floor_divide(self, other):
  4567. return torch.div(self, other, rounding_mode="floor")
  4568. @register_decomposition(aten.sym_numel)
  4569. def sym_numel(t):
  4570. return functools.reduce(operator.mul, t.shape, 1)
  4571. @register_decomposition([aten.sum.default, aten.sum.out])
  4572. def sum_default(
  4573. self: Tensor,
  4574. *,
  4575. dtype: Optional[torch.dtype] = None,
  4576. out: Optional[Tensor] = None,
  4577. ) -> Tensor:
  4578. if out is None:
  4579. return aten.sum.dim_IntList(self, [], dtype=dtype)
  4580. else:
  4581. return aten.sum.IntList_out(self, [], dtype=dtype, out=out)
  4582. @register_decomposition([aten.squeeze.default, aten.squeeze.dim])
  4583. def squeeze_default(self: Tensor, dim: Optional[int] = None):
  4584. # handle a scalar directly
  4585. if not isinstance(self, torch.Tensor):
  4586. return self
  4587. # perform squeeze
  4588. if dim is None:
  4589. return aten.squeeze.dims(self, list(range(self.dim())))
  4590. else:
  4591. return aten.squeeze.dims(self, [dim])
  4592. @register_decomposition(torch.ops.aten._weight_norm_interface)
  4593. def _weight_norm_interface(v, g, dim=0):
  4594. # https://github.com/pytorch/pytorch/blob/852f8526c52190125446adc9a6ecbcc28fb66182/aten/src/ATen/native/WeightNorm.cpp#L58
  4595. keep_dim = tuple(i for i in range(len(v.shape)) if i != dim)
  4596. # align with cuda behavior, keep norm in 'float' when g is 'bfloat16'
  4597. norm_dtype = torch.float if g.dtype == torch.bfloat16 else None
  4598. norm = v.norm(2, keep_dim, keepdim=True, dtype=norm_dtype)
  4599. return v * (g / norm.to(g.dtype)), norm
  4600. @register_decomposition(aten.isin)
  4601. @out_wrapper()
  4602. def isin(elements, test_elements, *, assume_unique=False, invert=False):
  4603. # handle when either elements or test_elements are Scalars (they can't both be)
  4604. if not isinstance(elements, torch.Tensor):
  4605. elements = torch.scalar_tensor(elements, device=test_elements.device)
  4606. if not isinstance(test_elements, torch.Tensor):
  4607. if invert:
  4608. return torch.ne(elements, test_elements)
  4609. else:
  4610. return torch.eq(elements, test_elements)
  4611. if test_elements.numel() < 10.0 * pow(elements.numel(), 0.145):
  4612. return isin_default(elements, test_elements, invert=invert)
  4613. else:
  4614. return isin_sorting(
  4615. elements, test_elements, assume_unique=assume_unique, invert=invert
  4616. )
  4617. @register_decomposition(aten.bernoulli.default)
  4618. def bernoulli(
  4619. self: torch.Tensor,
  4620. *,
  4621. generator: Optional[torch.Generator] = None,
  4622. ) -> torch.Tensor:
  4623. if generator is None:
  4624. raw_p = torch.rand(self.size(), dtype=torch.float32, device=self.device)
  4625. else:
  4626. raw_p = torch.rand(
  4627. self.size(),
  4628. generator=generator,
  4629. dtype=torch.float32,
  4630. device=self.device,
  4631. )
  4632. p = (raw_p < self).to(self.dtype)
  4633. return p
  4634. def isin_default(elements, test_elements, *, invert=False):
  4635. if elements.numel() == 0:
  4636. return torch.empty_like(elements, dtype=torch.bool)
  4637. if test_elements.ndim == 0:
  4638. res = elements == test_elements
  4639. return ~res if invert else res
  4640. expanded_elem_shape = elements.shape + (1,) * test_elements.ndim
  4641. x = elements.view(expanded_elem_shape)
  4642. dim = tuple(range(-1, -test_elements.ndim - 1, -1))
  4643. res = (x == test_elements).any(dim=dim)
  4644. return ~res if invert else res
  4645. def isin_sorting(elements, test_elements, *, assume_unique=False, invert=False):
  4646. elements_flat = elements.flatten()
  4647. test_elements_flat = test_elements.flatten()
  4648. if assume_unique:
  4649. # This is the same as the aten implementation. For
  4650. # assume_unique=False, we cannot use unique() here, so we use a
  4651. # version with searchsorted instead.
  4652. all_elements = torch.cat([elements_flat, test_elements_flat])
  4653. sorted_elements, sorted_order = torch.sort(all_elements, stable=True)
  4654. duplicate_mask = sorted_elements[1:] == sorted_elements[:-1]
  4655. duplicate_mask = torch.constant_pad_nd(duplicate_mask, [0, 1], False)
  4656. if invert:
  4657. duplicate_mask = duplicate_mask.logical_not()
  4658. mask = torch.empty_like(duplicate_mask)
  4659. mask = mask.index_copy(0, sorted_order, duplicate_mask)
  4660. return mask[0 : elements.numel()].reshape(elements.shape)
  4661. else:
  4662. sorted_test_elements, _ = torch.sort(test_elements_flat)
  4663. idx = torch.searchsorted(sorted_test_elements, elements_flat)
  4664. test_idx = torch.where(idx < sorted_test_elements.numel(), idx, 0)
  4665. cmp = sorted_test_elements[test_idx] == elements_flat
  4666. cmp = cmp.logical_not() if invert else cmp
  4667. return cmp.reshape(elements.shape)
  4668. @register_decomposition(aten.take)
  4669. @out_wrapper()
  4670. def take(self, index):
  4671. flattened = self.reshape(-1)
  4672. return flattened[index]
  4673. @register_decomposition(aten.resize_as)
  4674. def resize_as(self, other, memory_format=None):
  4675. if memory_format is None:
  4676. memory_format = torch.contiguous_format
  4677. if memory_format == torch.preserve_format:
  4678. memory_format = suggest_memory_format(other)
  4679. return aten.resize(self, other.shape, memory_format=memory_format)
# Register in-place variants: each runs the out-of-place op and then copy_'s
# the result into the first positional argument (see register_inplace).
register_inplace(aten.addbmm_, aten.addbmm)
register_inplace(aten.addmm_, aten.addmm)
register_inplace(aten.addmv_, aten.addmv)
register_inplace(aten.baddbmm_, aten.baddbmm)
register_inplace(aten.fill_, aten.fill)
register_inplace(aten.gelu_, aten.gelu)
register_inplace(aten.hardswish_, aten.hardswish)
register_inplace(aten.hardtanh_, aten.hardtanh)
register_inplace(aten.hardsigmoid_, aten.hardsigmoid)
# Augmented-assignment dunders (&=, <<=, |=, >>=, ^=).
register_inplace(aten.__iand__, aten.__and__)
register_inplace(aten.__ilshift__, aten.__lshift__)
register_inplace(aten.index_put_, aten.index_put)
register_inplace(aten.index_reduce_, aten.index_reduce)
register_inplace(aten.__ior__, aten.__or__)
register_inplace(aten.__irshift__, aten.__rshift__)
register_inplace(aten.__ixor__, aten.__xor__)
register_inplace(aten.ldexp_, aten.ldexp)
register_inplace(aten.leaky_relu_, aten.leaky_relu)
register_inplace(aten.logit_, aten.logit)
register_inplace(aten.relu_, aten.relu)
register_inplace(aten.renorm_, aten.renorm)
register_inplace(aten.round_, aten.round)
register_inplace(aten.scatter_, aten.scatter)
register_inplace(aten.scatter_add_, aten.scatter_add)
register_inplace(aten.scatter_reduce_, aten.scatter_reduce)
register_inplace(aten.silu_, aten.silu)