dnnl.hpp 642 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927
79278927992809281928292839284928592869287928892899290929192929293929492959296929792989299930093019302930393049305930693079308930993109311931293139314931593169317931893199320932193229323932493259326932793289329933093319332933393349335933693379338933993409341934293439344934593469347934893499350935193529353935493559356935793589359936093619362936393649365936693679368936993709371937293739374937593769377937893799380938193829383938493859386938793889389939093919392939393949395939693979398939994009401940294039404940594069407940894099410941194129413941494159416941794189419942094219422942394249425942694279428942994309431943294339434943594369437943894399440944194429443944494459446944794489449945094519452945394549455945694579458945994609461946294639464946594669467946894699470947194729473947494759476947794789479948094819482948394849485948694879488948994909491949294939494949594969497949894999500950195029503950495059506950795089509951095119512951395149515951695179518951995209521952295239524952595269527952895299530953195329533953495359536953795389539954095419542954395449545954695479548954995509551955295539554955595569557955895599560956195629563956495659566956795689569957095719572957395749575957695779578957995809581958295839584958595869587958895899590959195929593959495959596959795989599960096019602960396049605960696079608960996109611961296139614961596169617961896199620962196229623962496259626962796289629963096319632963396349635963696379638963996409641964296439644964596469647964896499650965196529653965496559656965796589659966096619662966396649665966696679668966996709671967296739674967596769677967896799680968196829683968496859686968796889689969096919692969396949695969696979698969997009701970297039704970597069707970897099710971197129713971497159716971797189719972097219722972397249725972697279728972997309731973297339734973597369737973897399740974197429743974497459746974797489749975097519752975397549755975697579758975997609761976297639764976597669767976897699770977197729773977497759776977
79778977997809781978297839784978597869787978897899790979197929793979497959796979797989799980098019802980398049805980698079808980998109811981298139814981598169817981898199820982198229823982498259826982798289829983098319832983398349835983698379838983998409841984298439844984598469847984898499850985198529853985498559856985798589859986098619862986398649865986698679868986998709871987298739874987598769877987898799880988198829883988498859886988798889889989098919892989398949895989698979898989999009901990299039904990599069907990899099910991199129913991499159916991799189919992099219922992399249925992699279928992999309931993299339934993599369937993899399940994199429943994499459946994799489949995099519952995399549955995699579958995999609961996299639964996599669967996899699970997199729973997499759976997799789979998099819982998399849985998699879988998999909991999299939994999599969997999899991000010001100021000310004100051000610007100081000910010100111001210013100141001510016100171001810019100201002110022100231002410025100261002710028100291003010031100321003310034100351003610037100381003910040100411004210043100441004510046100471004810049100501005110052100531005410055100561005710058100591006010061100621006310064100651006610067100681006910070100711007210073100741007510076100771007810079100801008110082100831008410085100861008710088100891009010091100921009310094100951009610097100981009910100101011010210103101041010510106101071010810109101101011110112101131011410115101161011710118101191012010121101221012310124101251012610127101281012910130101311013210133101341013510136101371013810139101401014110142101431014410145101461014710148101491015010151101521015310154101551015610157101581015910160101611016210163101641016510166101671016810169101701017110172101731017410175101761017710178101791018010181101821018310184101851018610187101881018910190101911019210193101941019510196101971019810199102001020110202102031020410205102061020710208102091021010211102121021310214102151021610217102181021910220102211
02221022310224102251022610227102281022910230102311023210233102341023510236102371023810239102401024110242102431024410245102461024710248102491025010251102521025310254102551025610257102581025910260102611026210263102641026510266102671026810269102701027110272102731027410275102761027710278102791028010281102821028310284102851028610287102881028910290102911029210293102941029510296102971029810299103001030110302103031030410305103061030710308103091031010311103121031310314103151031610317103181031910320103211032210323103241032510326103271032810329103301033110332103331033410335103361033710338103391034010341103421034310344103451034610347103481034910350103511035210353103541035510356103571035810359103601036110362103631036410365103661036710368103691037010371103721037310374103751037610377103781037910380103811038210383103841038510386103871038810389103901039110392103931039410395103961039710398103991040010401104021040310404104051040610407104081040910410104111041210413104141041510416104171041810419104201042110422104231042410425104261042710428104291043010431104321043310434104351043610437104381043910440104411044210443104441044510446104471044810449104501045110452104531045410455104561045710458104591046010461104621046310464104651046610467104681046910470104711047210473104741047510476104771047810479104801048110482104831048410485104861048710488104891049010491104921049310494104951049610497104981049910500105011050210503105041050510506105071050810509105101051110512105131051410515105161051710518105191052010521105221052310524105251052610527105281052910530105311053210533105341053510536105371053810539105401054110542105431054410545105461054710548105491055010551105521055310554105551055610557105581055910560105611056210563105641056510566105671056810569105701057110572105731057410575105761057710578105791058010581105821058310584105851058610587105881058910590105911059210593105941059510596105971059810599106001060110602106031060410605106061060710608106091061010611106121061310614106151061610617106181061910620106211
06221062310624106251062610627106281062910630106311063210633106341063510636106371063810639106401064110642106431064410645106461064710648106491065010651106521065310654106551065610657106581065910660106611066210663106641066510666106671066810669106701067110672106731067410675106761067710678106791068010681106821068310684106851068610687106881068910690106911069210693106941069510696106971069810699107001070110702107031070410705107061070710708107091071010711107121071310714107151071610717107181071910720107211072210723107241072510726107271072810729107301073110732107331073410735107361073710738107391074010741107421074310744107451074610747107481074910750107511075210753107541075510756107571075810759107601076110762107631076410765107661076710768107691077010771107721077310774107751077610777107781077910780107811078210783107841078510786107871078810789107901079110792107931079410795107961079710798107991080010801108021080310804108051080610807108081080910810108111081210813108141081510816108171081810819108201082110822108231082410825108261082710828108291083010831108321083310834108351083610837108381083910840108411084210843108441084510846108471084810849108501085110852108531085410855108561085710858108591086010861108621086310864108651086610867108681086910870108711087210873108741087510876108771087810879108801088110882108831088410885108861088710888108891089010891108921089310894108951089610897108981089910900109011090210903109041090510906109071090810909109101091110912109131091410915109161091710918109191092010921109221092310924109251092610927109281092910930109311093210933109341093510936109371093810939109401094110942109431094410945109461094710948109491095010951109521095310954109551095610957109581095910960109611096210963109641096510966109671096810969109701097110972109731097410975109761097710978109791098010981109821098310984109851098610987109881098910990109911099210993109941099510996109971099810999110001100111002110031100411005110061100711008110091101011011110121101311014110151101611017110181101911020110211
10221102311024110251102611027110281102911030110311103211033110341103511036110371103811039110401104111042110431104411045110461104711048110491105011051110521105311054110551105611057110581105911060110611106211063110641106511066110671106811069110701107111072110731107411075110761107711078110791108011081110821108311084110851108611087110881108911090110911109211093110941109511096110971109811099111001110111102111031110411105111061110711108111091111011111111121111311114111151111611117111181111911120111211112211123111241112511126111271112811129111301113111132111331113411135111361113711138111391114011141111421114311144111451114611147111481114911150111511115211153111541115511156111571115811159111601116111162111631116411165111661116711168111691117011171111721117311174111751117611177111781117911180111811118211183111841118511186111871118811189111901119111192111931119411195111961119711198111991120011201112021120311204112051120611207112081120911210112111121211213112141121511216112171121811219112201122111222112231122411225112261122711228112291123011231112321123311234112351123611237112381123911240112411124211243112441124511246112471124811249112501125111252112531125411255112561125711258112591126011261112621126311264112651126611267112681126911270112711127211273112741127511276112771127811279112801128111282112831128411285112861128711288112891129011291112921129311294112951129611297112981129911300113011130211303113041130511306113071130811309113101131111312113131131411315113161131711318113191132011321113221132311324113251132611327113281132911330113311133211333113341133511336113371133811339113401134111342113431134411345113461134711348113491135011351113521135311354113551135611357113581135911360113611136211363113641136511366113671136811369113701137111372113731137411375113761137711378113791138011381113821138311384113851138611387113881138911390113911139211393113941139511396113971139811399114001140111402114031140411405114061140711408114091141011411114121141311414114151141611417114181141911420114211
14221142311424114251142611427114281142911430114311143211433114341143511436114371143811439114401144111442114431144411445114461144711448114491145011451114521145311454114551145611457114581145911460114611146211463114641146511466114671146811469114701147111472114731147411475114761147711478114791148011481114821148311484114851148611487114881148911490114911149211493114941149511496114971149811499115001150111502115031150411505115061150711508115091151011511115121151311514115151151611517115181151911520115211152211523115241152511526115271152811529115301153111532115331153411535115361153711538115391154011541115421154311544115451154611547115481154911550115511155211553115541155511556115571155811559115601156111562115631156411565115661156711568115691157011571115721157311574115751157611577115781157911580115811158211583115841158511586115871158811589115901159111592115931159411595115961159711598115991160011601116021160311604116051160611607116081160911610116111161211613116141161511616116171161811619116201162111622116231162411625116261162711628116291163011631116321163311634116351163611637116381163911640116411164211643116441164511646116471164811649116501165111652116531165411655116561165711658116591166011661116621166311664116651166611667116681166911670116711167211673116741167511676116771167811679116801168111682116831168411685116861168711688116891169011691116921169311694116951169611697116981169911700117011170211703117041170511706117071170811709117101171111712117131171411715117161171711718117191172011721117221172311724117251172611727117281172911730117311173211733117341173511736117371173811739117401174111742117431174411745117461174711748117491175011751117521175311754117551175611757117581175911760117611176211763117641176511766117671176811769117701177111772117731177411775117761177711778117791178011781117821178311784117851178611787117881178911790117911179211793117941179511796117971179811799118001180111802118031180411805118061180711808118091181011811118121181311814118151181611817118181181911820118211
18221182311824118251182611827118281182911830118311183211833118341183511836118371183811839118401184111842118431184411845118461184711848118491185011851118521185311854118551185611857118581185911860118611186211863118641186511866118671186811869118701187111872118731187411875118761187711878118791188011881118821188311884118851188611887118881188911890118911189211893118941189511896118971189811899119001190111902119031190411905119061190711908119091191011911119121191311914119151191611917119181191911920119211192211923119241192511926119271192811929119301193111932119331193411935119361193711938119391194011941119421194311944119451194611947119481194911950119511195211953119541195511956119571195811959119601196111962119631196411965119661196711968119691197011971119721197311974119751197611977119781197911980119811198211983119841198511986119871198811989119901199111992119931199411995119961199711998119991200012001120021200312004120051200612007120081200912010120111201212013120141201512016120171201812019120201202112022120231202412025120261202712028120291203012031120321203312034120351203612037120381203912040120411204212043120441204512046120471204812049120501205112052120531205412055120561205712058120591206012061120621206312064120651206612067120681206912070120711207212073120741207512076120771207812079120801208112082120831208412085120861208712088120891209012091120921209312094120951209612097120981209912100121011210212103121041210512106121071210812109121101211112112121131211412115121161211712118121191212012121121221212312124121251212612127121281212912130121311213212133121341213512136121371213812139121401214112142121431214412145121461214712148121491215012151121521215312154121551215612157121581215912160121611216212163121641216512166121671216812169121701217112172121731217412175121761217712178121791218012181121821218312184121851218612187121881218912190121911219212193121941219512196121971219812199122001220112202122031220412205122061220712208122091221012211122121221312214122151221612217122181221912220122211
22221222312224122251222612227122281222912230122311223212233122341223512236122371223812239122401224112242122431224412245122461224712248122491225012251122521225312254122551225612257122581225912260122611226212263122641226512266122671226812269122701227112272122731227412275122761227712278122791228012281122821228312284122851228612287122881228912290122911229212293122941229512296122971229812299123001230112302123031230412305123061230712308123091231012311123121231312314123151231612317123181231912320123211232212323123241232512326123271232812329123301233112332123331233412335123361233712338123391234012341123421234312344123451234612347123481234912350123511235212353123541235512356123571235812359123601236112362123631236412365123661236712368123691237012371123721237312374123751237612377123781237912380123811238212383123841238512386123871238812389123901239112392123931239412395123961239712398123991240012401124021240312404124051240612407124081240912410124111241212413124141241512416124171241812419124201242112422124231242412425124261242712428124291243012431124321243312434124351243612437124381243912440124411244212443124441244512446124471244812449124501245112452124531245412455124561245712458124591246012461124621246312464124651246612467124681246912470124711247212473124741247512476124771247812479124801248112482124831248412485124861248712488124891249012491124921249312494124951249612497124981249912500125011250212503125041250512506125071250812509125101251112512125131251412515125161251712518125191252012521125221252312524125251252612527125281252912530125311253212533125341253512536125371253812539125401254112542125431254412545125461254712548125491255012551125521255312554125551255612557125581255912560125611256212563125641256512566125671256812569125701257112572125731257412575125761257712578125791258012581125821258312584125851258612587125881258912590125911259212593125941259512596125971259812599126001260112602126031260412605126061260712608126091261012611126121261312614126151261612617126181261912620126211
26221262312624126251262612627126281262912630126311263212633126341263512636126371263812639126401264112642126431264412645126461264712648126491265012651126521265312654126551265612657126581265912660126611266212663126641266512666126671266812669126701267112672126731267412675126761267712678126791268012681126821268312684126851268612687126881268912690126911269212693126941269512696126971269812699127001270112702127031270412705127061270712708127091271012711127121271312714127151271612717127181271912720127211272212723127241272512726127271272812729127301273112732127331273412735127361273712738127391274012741127421274312744127451274612747127481274912750127511275212753127541275512756127571275812759127601276112762127631276412765127661276712768127691277012771127721277312774127751277612777127781277912780127811278212783127841278512786127871278812789127901279112792127931279412795127961279712798127991280012801128021280312804128051280612807128081280912810128111281212813128141281512816128171281812819128201282112822128231282412825128261282712828128291283012831128321283312834128351283612837128381283912840128411284212843128441284512846128471284812849128501285112852128531285412855128561285712858128591286012861128621286312864128651286612867128681286912870128711287212873128741287512876128771287812879128801288112882128831288412885128861288712888128891289012891128921289312894128951289612897128981289912900129011290212903129041290512906129071290812909129101291112912129131291412915129161291712918129191292012921129221292312924129251292612927129281292912930129311293212933129341293512936129371293812939129401294112942129431294412945129461294712948129491295012951129521295312954129551295612957129581295912960129611296212963129641296512966129671296812969129701297112972129731297412975129761297712978129791298012981129821298312984129851298612987129881298912990129911299212993129941299512996129971299812999130001300113002130031300413005130061300713008130091301013011130121301313014130151301613017130181301913020130211
30221302313024130251302613027130281302913030130311303213033130341303513036130371303813039130401304113042130431304413045130461304713048130491305013051130521305313054130551305613057130581305913060130611306213063130641306513066130671306813069130701307113072130731307413075130761307713078130791308013081130821308313084130851308613087130881308913090130911309213093130941309513096130971309813099131001310113102131031310413105131061310713108131091311013111131121311313114131151311613117131181311913120131211312213123131241312513126131271312813129131301313113132131331313413135131361313713138131391314013141131421314313144131451314613147131481314913150131511315213153131541315513156131571315813159131601316113162131631316413165131661316713168131691317013171131721317313174131751317613177131781317913180131811318213183131841318513186131871318813189131901319113192131931319413195131961319713198131991320013201132021320313204132051320613207132081320913210132111321213213132141321513216132171321813219132201322113222132231322413225132261322713228132291323013231132321323313234132351323613237132381323913240132411324213243132441324513246132471324813249132501325113252132531325413255132561325713258132591326013261132621326313264132651326613267132681326913270132711327213273132741327513276132771327813279132801328113282132831328413285132861328713288132891329013291132921329313294132951329613297132981329913300133011330213303133041330513306133071330813309133101331113312133131331413315133161331713318133191332013321133221332313324133251332613327133281332913330133311333213333133341333513336133371333813339133401334113342133431334413345133461334713348133491335013351133521335313354133551335613357133581335913360133611336213363133641336513366133671336813369133701337113372133731337413375133761337713378133791338013381133821338313384133851338613387133881338913390133911339213393133941339513396133971339813399134001340113402134031340413405134061340713408134091341013411134121341313414134151341613417134181341913420134211
34221342313424134251342613427134281342913430134311343213433134341343513436134371343813439134401344113442134431344413445134461344713448134491345013451134521345313454134551345613457134581345913460134611346213463134641346513466134671346813469134701347113472134731347413475134761347713478134791348013481134821348313484134851348613487134881348913490134911349213493134941349513496134971349813499135001350113502135031350413505135061350713508135091351013511135121351313514135151351613517135181351913520135211352213523135241352513526135271352813529135301353113532135331353413535135361353713538135391354013541135421354313544135451354613547135481354913550135511355213553135541355513556135571355813559135601356113562135631356413565135661356713568135691357013571135721357313574135751357613577135781357913580135811358213583135841358513586135871358813589135901359113592135931359413595135961359713598135991360013601136021360313604136051360613607136081360913610136111361213613136141361513616136171361813619136201362113622136231362413625136261362713628136291363013631136321363313634136351363613637136381363913640136411364213643136441364513646136471364813649136501365113652136531365413655136561365713658136591366013661136621366313664136651366613667136681366913670136711367213673136741367513676136771367813679136801368113682136831368413685136861368713688136891369013691136921369313694136951369613697136981369913700137011370213703137041370513706137071370813709137101371113712137131371413715137161371713718137191372013721137221372313724137251372613727137281372913730137311373213733137341373513736137371373813739137401374113742137431374413745137461374713748137491375013751137521375313754137551375613757137581375913760137611376213763137641376513766137671376813769137701377113772137731377413775137761377713778137791378013781137821378313784137851378613787137881378913790137911379213793137941379513796137971379813799138001380113802138031380413805138061380713808138091381013811138121381313814138151381613817138181381913820138211
382213823138241382513826138271382813829138301383113832138331383413835138361383713838138391384013841138421384313844138451384613847138481384913850138511385213853138541385513856138571385813859138601386113862138631386413865138661386713868138691387013871138721387313874138751387613877138781387913880138811388213883138841388513886138871388813889138901389113892138931389413895138961389713898138991390013901139021390313904139051390613907139081390913910139111391213913139141391513916139171391813919139201392113922139231392413925139261392713928139291393013931139321393313934139351393613937139381393913940139411394213943139441394513946139471394813949139501395113952139531395413955139561395713958139591396013961139621396313964139651396613967139681396913970139711397213973139741397513976139771397813979139801398113982139831398413985139861398713988139891399013991139921399313994139951399613997139981399914000140011400214003140041400514006140071400814009140101401114012140131401414015140161401714018140191402014021140221402314024140251402614027140281402914030140311403214033140341403514036140371403814039140401404114042140431404414045140461404714048140491405014051140521405314054140551405614057140581405914060140611406214063140641406514066140671406814069140701407114072140731407414075140761407714078140791408014081140821408314084140851408614087140881408914090140911409214093140941409514096140971409814099141001410114102141031410414105141061410714108141091411014111141121411314114141151411614117141181411914120141211412214123141241412514126141271412814129141301413114132141331413414135141361413714138141391414014141141421414314144141451414614147141481414914150141511415214153141541415514156141571415814159141601416114162141631416414165141661416714168141691417014171141721417314174141751417614177141781417914180141811418214183141841418514186141871418814189141901419114192141931419414195141961419714198141991420014201142021420314204142051420614207142081420914210142111421214213142141421514216142171421814219
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /*******************************************************************************
  3. * Copyright 2016-2025 Intel Corporation
  4. * Copyright 2024-2025 FUJITSU LIMITED
  5. * Copyright 2025 Arm Ltd. and affiliates
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. *******************************************************************************/
  19. /// @file
  20. /// C++ API
  21. #ifndef ONEAPI_DNNL_DNNL_HPP
  22. #define ONEAPI_DNNL_DNNL_HPP
  23. // NOLINTBEGIN(readability-identifier-naming)
  24. #include "oneapi/dnnl/dnnl_config.h"
  25. /// @cond DO_NOT_DOCUMENT_THIS
  26. #include <algorithm>
  27. #include <cstdlib>
  28. #include <iterator>
  29. #include <memory>
  30. #include <string>
  31. #include <vector>
  32. #include <unordered_map>
  33. #include "oneapi/dnnl/dnnl.h"
  34. #include "oneapi/dnnl/dnnl_common.hpp"
  35. /// @endcond
  36. /// @addtogroup dnnl_api oneDNN API
  37. /// @{
  38. /// oneDNN namespace
  39. namespace dnnl {
  40. /// @addtogroup dnnl_api_utils Utilities
  41. /// Utility types and definitions.
  42. /// @{
  43. /// @cond DO_NOT_DOCUMENT_THIS
  44. template <typename T>
  45. void validate_container_size(const T &v, const char *error_message,
  46. int min_size = 1, int max_size = -1) {
  47. const int size = (int)v.size();
  48. if (size < min_size || (max_size >= 0 && size > max_size))
  49. DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
  50. }
  51. /// @endcond
  52. /// @cond DO_NOT_DOCUMENT_THIS
  53. template <>
  54. struct handle_traits<dnnl_memory_desc_t> {
  55. static dnnl_status_t destructor(dnnl_memory_desc_t p) {
  56. return dnnl_memory_desc_destroy(p);
  57. }
  58. };
  59. template <>
  60. struct handle_traits<dnnl_memory_t> {
  61. static dnnl_status_t destructor(dnnl_memory_t p) {
  62. return dnnl_memory_destroy(p);
  63. }
  64. };
  65. template <>
  66. struct handle_traits<dnnl_primitive_desc_t> {
  67. static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
  68. return dnnl_primitive_desc_destroy(p);
  69. }
  70. };
  71. template <>
  72. struct handle_traits<dnnl_primitive_t> {
  73. static dnnl_status_t destructor(dnnl_primitive_t p) {
  74. return dnnl_primitive_destroy(p);
  75. }
  76. };
  77. /// @endcond
  78. /// @} dnnl_api_utils
  79. struct stream;
  80. struct memory;
  81. struct primitive_desc;
  82. /// @addtogroup dnnl_api_primitives Primitives
  83. /// Compute primitives
  84. /// @sa @ref dev_guide_basic_concepts
  85. /// @{
  86. /// @addtogroup dnnl_api_primitives_common Common
  87. /// Common operations to create, destroy and inspect primitives
  88. /// @{
  89. /// Base class for all computational primitives.
  90. struct primitive : public handle<dnnl_primitive_t> {
  91. /// Kinds of primitives supported by the library.
  92. enum class kind {
  93. /// Undefined primitive
  94. undef = dnnl_undefined_primitive,
  95. /// A reorder primitive.
  96. reorder = dnnl_reorder,
  97. /// A shuffle primitive.
  98. shuffle = dnnl_shuffle,
  99. /// A (out-of-place) tensor concatenation primitive.
  100. concat = dnnl_concat,
  101. /// A summation primitive.
  102. sum = dnnl_sum,
  103. /// A convolution primitive.
  104. convolution = dnnl_convolution,
  105. /// A deconvolution primitive.
  106. deconvolution = dnnl_deconvolution,
  107. /// An element-wise primitive.
  108. eltwise = dnnl_eltwise,
  109. /// An LRN primitive.
  110. lrn = dnnl_lrn,
  111. /// A batch normalization primitive.
  112. batch_normalization = dnnl_batch_normalization,
  113. /// An inner product primitive.
  114. inner_product = dnnl_inner_product,
  115. /// An RNN primitive.
  116. rnn = dnnl_rnn,
  117. /// A binary primitive.
  118. binary = dnnl_binary,
  119. /// A matmul (matrix multiplication) primitive.
  120. matmul = dnnl_matmul,
  121. /// A resampling primitive.
  122. resampling = dnnl_resampling,
  123. /// A pooling primitive.
  124. pooling = dnnl_pooling,
  125. /// A reduction primitive.
  126. reduction = dnnl_reduction,
  127. /// A PReLU primitive.
  128. prelu = dnnl_prelu,
  129. /// A softmax primitive.
  130. softmax = dnnl_softmax,
  131. /// A layer normalization primitive.
  132. layer_normalization = dnnl_layer_normalization,
  133. /// A group normalization primitive
  134. group_normalization = dnnl_group_normalization,
  135. };
  136. using handle::handle;
  137. /// Default constructor. Constructs an empty object.
  138. primitive() = default;
  139. /// Constructs a primitive from a C API primitive descriptor.
  140. ///
  141. /// @param c_pd C API primitive descriptor.
  142. primitive(const_dnnl_primitive_desc_t c_pd);
  143. /// Constructs a primitive from a C API primitive descriptor and a cache blob.
  144. ///
  145. /// @param c_pd C API primitive descriptor.
  146. /// @param cache_blob Cache blob.
  147. primitive(const_dnnl_primitive_desc_t c_pd,
  148. const std::vector<uint8_t> &cache_blob);
  149. /// Constructs a primitive from a primitive descriptor.
  150. ///
  151. /// @param pd Primitive descriptor.
  152. primitive(const primitive_desc &pd);
  153. /// Constructs a primitive from a primitive descriptor and a cache blob.
  154. ///
  155. /// @param pd Primitive descriptor.
  156. /// @param cache_blob Cache blob.
  157. primitive(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob);
  158. /// Returns the C API primitive descriptor of the underlying C API
  159. /// primitive.
  160. ///
  161. /// @returns The underlying C API primitive descriptor.
  162. inline const_dnnl_primitive_desc_t get_primitive_desc() const;
  163. /// Returns the kind of the primitive.
  164. ///
  165. /// @returns The primitive kind.
  166. inline kind get_kind() const;
  167. /// Returns a cache blob for the primitive.
  168. ///
  169. /// @returns Vector containing the cache blob.
  170. ///
  171. /// @note The cache blob can be empty. It's the user's responsibility to
  172. /// check whether it's empty prior to passing it to the primitive
  173. /// constructor.
  174. inline std::vector<uint8_t> get_cache_blob() const;
  175. /// Executes computations specified by the primitive in a specified stream.
  176. ///
  177. /// Arguments are passed via an arguments map containing <index,
  178. /// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
  179. /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
  180. /// matching the one returned by
  181. /// primitive_desc::query_md(#query::exec_arg_md, index) unless using
  182. /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
  183. ///
  184. /// @param astream Stream object. The stream must belong to the same engine
  185. /// as the primitive.
  186. /// @param args Arguments map.
  187. void execute(const stream &astream,
  188. const std::unordered_map<int, memory> &args) const;
  189. };
  190. /// Converts primitive kind enum value from C++ API to C API type.
  191. ///
  192. /// @param akind C++ API primitive kind enum value.
  193. /// @returns Corresponding C API primitive kind enum value.
  194. inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
  195. return static_cast<dnnl_primitive_kind_t>(akind);
  196. }
  197. const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
  198. const_dnnl_primitive_desc_t pd;
  199. error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd),
  200. "could not get a primitive descriptor from a primitive");
  201. return pd;
  202. }
  203. dnnl::primitive::kind primitive::get_kind() const {
  204. const_dnnl_primitive_desc_t pd = get_primitive_desc();
  205. // TODO (Roma): the code below is only needed because get_primitive_desc
  206. // returns a C type.
  207. dnnl_primitive_kind_t kind;
  208. error::wrap_c_api(dnnl_primitive_desc_query(
  209. pd, dnnl_query_primitive_kind, 0, (void *)&kind),
  210. "could not get a primitive kind from a primitive descriptor");
  211. return static_cast<dnnl::primitive::kind>(kind);
  212. }
  213. std::vector<uint8_t> primitive::get_cache_blob() const {
  214. size_t size;
  215. error::wrap_c_api(dnnl_primitive_get_cache_blob(get(), &size, nullptr),
  216. "could not get cache blob size from a primitive");
  217. std::vector<uint8_t> cache_blob(size);
  218. error::wrap_c_api(
  219. dnnl_primitive_get_cache_blob(get(), &size, cache_blob.data()),
  220. "could not get a cache blob from a primitive");
  221. return cache_blob;
  222. }
  223. /// @} dnnl_api_primitives_common
  224. /// @addtogroup dnnl_api_attributes
  225. ///
/// A container for parameters that extend primitives' behavior.
  227. ///
  228. /// Attributes can also contain Post-ops, which are computations executed
  229. /// after the primitive.
  230. ///
  231. /// @sa @ref dev_guide_attributes
  232. /// @sa @ref dev_guide_attributes_post_ops
  233. ///
  234. /// @{
  235. /// Scratchpad mode
  236. enum class scratchpad_mode {
  237. /// The library manages the scratchpad allocation according to the policy
  238. /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
  239. /// [build option](@ref dev_guide_build_options) (default).
  240. ///
  241. /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
  242. /// scratchpad is common to all primitives to reduce the memory footprint.
  243. /// This configuration comes with limited thread-safety properties, namely
  244. /// primitives can be created and executed in parallel but cannot migrate
  245. /// between threads (in other words, each primitive should be executed in
  246. /// the same thread it was created in).
  247. ///
  248. /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
  249. /// private to each primitive. The memory footprint is larger than when
  250. /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be
  251. /// created and run concurrently (the same primitive cannot be run
  252. /// concurrently from two different threads though).
  253. library = dnnl_scratchpad_mode_library,
  254. /// The user manages the scratchpad allocation by querying and providing
  255. /// the scratchpad memory to primitives. This mode is thread-safe as long
  256. /// as the scratchpad buffers are not used concurrently by two primitive
  257. /// executions.
  258. user = dnnl_scratchpad_mode_user,
  259. };
  260. /// Converts a scratchpad mode enum value from C++ API to C API type.
  261. ///
  262. /// @param mode C++ API scratchpad mode enum value.
  263. /// @returns Corresponding C API scratchpad mode enum value.
  264. inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
  265. return static_cast<dnnl_scratchpad_mode_t>(mode);
  266. }
  267. /// Rounding mode
  268. enum class rounding_mode {
  269. /// rounding mode dictated by the floating-point environment
  270. environment = dnnl_rounding_mode_environment,
  271. /// stochastic rounding mode where a random bias is added to the
  272. /// trailing mantissa bits before conversion.
  273. stochastic = dnnl_rounding_mode_stochastic
  274. };
  275. /// Converts a rounding mode enum value from C++ API to C API type.
  276. ///
  277. /// @param mode C++ API rounding mode enum value.
  278. /// @returns Corresponding C API rounding mode enum value.
  279. inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) {
  280. return static_cast<dnnl_rounding_mode_t>(mode);
  281. }
  282. /// Propagation kind.
  283. enum class prop_kind {
  284. /// Undefined propagation kind.
  285. undef = dnnl_prop_kind_undef,
  286. /// Forward data propagation (training mode). In this mode, primitives
  287. /// perform computations necessary for subsequent backward propagation.
  288. forward_training = dnnl_forward_training,
  289. /// Forward data propagation (inference mode). In this mode, primitives
  290. /// perform only computations that are necessary for inference and omit
  291. /// computations that are necessary only for backward propagation.
  292. forward_inference = dnnl_forward_inference,
  293. /// Forward data propagation,
  294. /// alias for #dnnl::prop_kind::forward_training.
  295. forward = dnnl_forward,
  296. /// Backward propagation (with respect to all parameters).
  297. backward = dnnl_backward,
  298. /// Backward data propagation.
  299. backward_data = dnnl_backward_data,
  300. /// Backward weights propagation.
  301. backward_weights = dnnl_backward_weights,
  302. /// Backward bias propagation.
  303. backward_bias = dnnl_backward_bias
  304. };
  305. /// Converts propagation kind enum value from C++ API to C API type.
  306. ///
  307. /// @param akind C++ API propagation kind enum value.
  308. /// @returns Corresponding C API propagation kind enum value.
  309. inline dnnl_prop_kind_t convert_to_c(prop_kind akind) {
  310. return static_cast<dnnl_prop_kind_t>(akind);
  311. }
  312. /// Kinds of algorithms.
  313. enum class algorithm {
  314. /// Undefined algorithm
  315. undef = dnnl_alg_kind_undef,
  316. /// Convolution algorithm that is chosen to be either direct or Winograd
  317. /// automatically
  318. convolution_auto = dnnl_convolution_auto,
  319. /// Direct convolution
  320. convolution_direct = dnnl_convolution_direct,
  321. /// Winograd convolution
  322. convolution_winograd = dnnl_convolution_winograd,
  323. /// Direct deconvolution
  324. deconvolution_direct = dnnl_deconvolution_direct,
  325. /// Winograd deconvolution
  326. deconvolution_winograd = dnnl_deconvolution_winograd,
  327. /// Elementwise: rectified linear unit (ReLU)
  328. eltwise_relu = dnnl_eltwise_relu,
  329. /// Elementwise: hyperbolic tangent non-linearity (tanh)
  330. eltwise_tanh = dnnl_eltwise_tanh,
  331. /// Elementwise: exponential linear unit (ELU)
  332. eltwise_elu = dnnl_eltwise_elu,
  333. /// Elementwise: square
  334. eltwise_square = dnnl_eltwise_square,
  335. /// Elementwise: abs
  336. eltwise_abs = dnnl_eltwise_abs,
  337. /// Elementwise: square root
  338. eltwise_sqrt = dnnl_eltwise_sqrt,
  339. /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
  340. eltwise_swish = dnnl_eltwise_swish,
  341. /// Elementwise: linear
  342. eltwise_linear = dnnl_eltwise_linear,
  343. /// Elementwise: soft_relu
  344. eltwise_soft_relu = dnnl_eltwise_soft_relu,
  345. /// Elementwise: mish
  346. eltwise_mish = dnnl_eltwise_mish,
  347. /// Elementwise: logistic
  348. eltwise_logistic = dnnl_eltwise_logistic,
  349. /// Elementwise: exponent
  350. eltwise_exp = dnnl_eltwise_exp,
  351. /// Elementwise: tanh-based gelu
  352. eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
  353. /// Elementwise: erf-based gelu
  354. eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
  355. /// Elementwise: natural logarithm
  356. eltwise_log = dnnl_eltwise_log,
  357. /// Elementwise: clip
  358. eltwise_clip = dnnl_eltwise_clip,
  359. /// Eltwise: clip version 2
  360. eltwise_clip_v2 = dnnl_eltwise_clip_v2,
  361. /// Elementwise: pow
  362. eltwise_pow = dnnl_eltwise_pow,
  363. /// Elementwise: round
  364. eltwise_round = dnnl_eltwise_round,
  365. /// Elementwise: hardswish
  366. eltwise_hardswish = dnnl_eltwise_hardswish,
  367. /// Elementwise: hardsigmoid
  368. eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid,
  369. /// Elementwise: rectified linar unit (ReLU) (dst for backward)
  370. eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
  371. /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
  372. eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
  373. /// Elementwise: exponential linear unit (ELU) (dst for backward)
  374. eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
  375. /// Elementwise: square root (dst for backward)
  376. eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
  377. /// Elementwise: logistic (dst for backward)
  378. eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
  379. /// Elementwise: exponent (dst for backward)
  380. eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
  381. /// Elementwise: clip version 2 (dst for backward)
  382. eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd,
  383. /// Local response normalization (LRN) across multiple channels
  384. lrn_across_channels = dnnl_lrn_across_channels,
  385. /// LRN within a single channel
  386. lrn_within_channel = dnnl_lrn_within_channel,
  387. /// Max pooling
  388. pooling_max = dnnl_pooling_max,
  389. /// Average pooling include padding
  390. pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
  391. /// Average pooling exclude padding
  392. pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
  393. /// RNN cell
  394. vanilla_rnn = dnnl_vanilla_rnn,
  395. /// LSTM cell
  396. vanilla_lstm = dnnl_vanilla_lstm,
  397. /// GRU cell
  398. vanilla_gru = dnnl_vanilla_gru,
  399. /// GRU cell with linear before reset. Differs from the vanilla GRU
  400. /// in how the new memory gate is calculated:
  401. /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
  402. /// LRB GRU expects 4 bias tensors on input:
  403. /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
  404. lbr_gru = dnnl_lbr_gru,
  405. /// AUGRU cell
  406. vanilla_augru = dnnl_vanilla_augru,
  407. /// AUGRU cell with linear before reset
  408. lbr_augru = dnnl_lbr_augru,
  409. /// Binary add
  410. binary_add = dnnl_binary_add,
  411. /// Binary mul
  412. binary_mul = dnnl_binary_mul,
  413. /// Binary max
  414. binary_max = dnnl_binary_max,
  415. /// Binary min
  416. binary_min = dnnl_binary_min,
  417. /// Binary div
  418. binary_div = dnnl_binary_div,
  419. /// Binary sub
  420. binary_sub = dnnl_binary_sub,
  421. /// Binary greater than or equal
  422. binary_ge = dnnl_binary_ge,
  423. /// Binary greater than
  424. binary_gt = dnnl_binary_gt,
  425. /// Binary less than or equal
  426. binary_le = dnnl_binary_le,
  427. /// Binary less than
  428. binary_lt = dnnl_binary_lt,
  429. /// Binary equal
  430. binary_eq = dnnl_binary_eq,
  431. /// Binary not equal
  432. binary_ne = dnnl_binary_ne,
  433. /// Binary select
  434. binary_select = dnnl_binary_select,
  435. /// Nearest Neighbor resampling method
  436. resampling_nearest = dnnl_resampling_nearest,
  437. /// Linear (Bilinear, Trilinear) resampling method
  438. resampling_linear = dnnl_resampling_linear,
  439. /// Reduction using max operation
  440. reduction_max = dnnl_reduction_max,
  441. /// Reduction using min operation
  442. reduction_min = dnnl_reduction_min,
  443. /// Reduction using sum operation
  444. reduction_sum = dnnl_reduction_sum,
  445. /// Reduction using mul operation
  446. reduction_mul = dnnl_reduction_mul,
  447. /// Reduction using mean operation
  448. reduction_mean = dnnl_reduction_mean,
  449. /// Reduction using norm_lp_max operation
  450. reduction_norm_lp_max = dnnl_reduction_norm_lp_max,
  451. /// Reduction using norm_lp_sum operation
  452. reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum,
  453. /// Reduction using norm_lp_power_p_max operation
  454. reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max,
  455. /// Reduction using norm_lp_power_p_sum operation
  456. reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum,
  457. /// Softmax, numerically stable
  458. softmax_accurate = dnnl_softmax_accurate,
  459. /// LogSoftmax, numerically stable
  460. softmax_log = dnnl_softmax_log,
  461. };
  462. /// Converts algorithm kind enum value from C++ API to C API type.
  463. /// @param aalgorithm C++ API algorithm kind enum value.
  464. /// @returns Corresponding C API algorithm kind enum value.
  465. inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
  466. return static_cast<dnnl_alg_kind_t>(aalgorithm);
  467. }
  468. /// @} dnnl_api_attributes
  469. /// @addtogroup dnnl_api_primitives_common
  470. /// @{
  471. /// Flags for normalization primitives.
  472. enum class normalization_flags : unsigned {
  473. /// Use no normalization flags. If specified, the library computes mean and
  474. /// variance on forward propagation for training and inference, outputs
  475. /// them on forward propagation for training, and computes the respective
  476. /// derivatives on backward propagation.
  477. ///
  478. /// @note
  479. /// Backward propagation of type #dnnl::prop_kind::backward_data has
  480. /// the same behavior as #dnnl::prop_kind::backward.
  481. none = dnnl_normalization_flags_none,
  482. /// Use global statistics. If specified, the library uses mean and
  483. /// variance provided by the user as an input on forward propagation and
  484. /// does not compute their derivatives on backward propagation. Otherwise,
  485. /// the library computes mean and variance on forward propagation for
  486. /// training and inference, outputs them on forward propagation for
  487. /// training, and computes the respective derivatives on backward
  488. /// propagation.
  489. use_global_stats = dnnl_use_global_stats,
  490. /// Use scale parameter. If specified, the user is expected to pass scale as
  491. /// input on forward propagation. On backward propagation of type
  492. /// #dnnl::prop_kind::backward, the library computes its derivative.
  493. use_scale = dnnl_use_scale,
  494. /// Use shift parameter. If specified, the user is expected to pass shift as
  495. /// input on forward propagation. On backward propagation of type
  496. /// #dnnl::prop_kind::backward, the library computes its derivative.
  497. use_shift = dnnl_use_shift,
  498. /// Fuse normalization with ReLU. On training, normalization will require
  499. /// the workspace to implement backward propagation. On inference, the
  500. /// workspace is not required and behavior is the same as when normalization
  501. /// is fused with ReLU using the post-ops API.
  502. ///
  503. /// @note
  504. /// The flag implies negative slope being 0. On training this is the only
  505. /// configuration supported. For inference, to use non-zero negative slope
  506. /// consider using @ref dev_guide_attributes_post_ops.
  507. fuse_norm_relu = dnnl_fuse_norm_relu,
  508. /// Fuse normalization with an elementwise binary Add operation
  509. /// followed by ReLU.
  510. /// During training, normalization will require a workspace to implement
  511. /// backward propagation. For inference, the workspace is not needed.
  512. /// On forward propagation, an elementwise binary Add operation is applied
  513. /// to the normalization results with an additional input tensor, followed
  514. /// by ReLU with a negative slope of 0.
  515. /// On backward propagation, the result of the backward ReLU operation
  516. /// with the input tensor and workspace from the forward pass is saved
  517. /// to an extra output tensor, and backward normalization is performed.
  518. fuse_norm_add_relu = dnnl_fuse_norm_add_relu,
  519. /// Use Root Mean Square (RMS) Normalization. In forward propagation,
  520. /// the mean is considered zero, and RMS norm is used instead of variance
  521. /// for scaling. Only the RMS norm is output during forward propagation for
  522. /// training. In backward propagation, the library calculates the derivative
  523. /// with respect to the RMS norm only, assuming the mean is zero.
  524. ///
  525. /// @note
  526. /// When used with #dnnl::normalization_flags::use_global_stats,
  527. /// only RMS norm is required to be provided as input.
  528. rms_norm = dnnl_rms_norm,
  529. };
  530. /// Converts normalization flags enum value from C++ API to C API type.
  531. /// @param flags C++ API normalization flags enum value.
  532. /// @returns Corresponding C API normalization flags enum value.
  533. inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
  534. return static_cast<dnnl_normalization_flags_t>(flags);
  535. }
  536. /// @} dnnl_api_primitives_common
  537. /// @addtogroup dnnl_api_rnn
  538. /// @{
  539. /// RNN cell flags.
  540. enum class rnn_flags : unsigned {
  541. /// Undefined RNN flags
  542. undef = dnnl_rnn_flags_undef,
  543. /// Do not add weights gradient to existing diff_weights memory
  544. diff_weights_overwrite = dnnl_rnn_flags_diff_weights_overwrite,
  545. };
  546. /// Converts RNN cell flags enum value from C++ API to C API type.
  547. /// @param flags C++ API RNN cell flags enum value.
  548. /// @returns Corresponding C API RNN cell flags enum value.
  549. inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
  550. return static_cast<dnnl_rnn_flags_t>(flags);
  551. }
  552. DNNL_DEFINE_BITMASK_OPS(normalization_flags)
  553. DNNL_DEFINE_BITMASK_OPS(rnn_flags)
  554. /// A direction of RNN primitive execution
  555. enum class rnn_direction {
  556. /// Undefined RNN direction.
  557. undef = dnnl_rnn_direction_undef,
  558. /// Unidirectional execution of RNN primitive from left to right.
  559. unidirectional_left2right = dnnl_unidirectional_left2right,
  560. /// Unidirectional execution of RNN primitive from right to left.
  561. unidirectional_right2left = dnnl_unidirectional_right2left,
  562. /// Bidirectional execution of RNN primitive with concatenation of the
  563. /// results.
  564. bidirectional_concat = dnnl_bidirectional_concat,
  565. /// Bidirectional execution of RNN primitive with summation of the
  566. /// results.
  567. bidirectional_sum = dnnl_bidirectional_sum,
  568. };
  569. /// Converts RNN direction enum value from C++ API to C API type.
  570. /// @param dir C++ API RNN direction enum value.
  571. /// @returns Corresponding C API RNN direction enum value.
  572. inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
  573. return static_cast<dnnl_rnn_direction_t>(dir);
  574. }
  575. /// @} dnnl_api_rnn
  576. /// @addtogroup dnnl_api_primitives_common
  577. /// @{
  578. /// Primitive descriptor query specification.
  579. ///
  580. /// In general, queries are not used with the C++ API because most queries are
  581. /// implemented as class members.
  582. ///
  583. /// See @ref dnnl_query_t for more information.
  584. enum class query {
  585. /// no query
  586. undef = dnnl_query_undef,
  587. /// execution engine
  588. engine = dnnl_query_engine,
  589. /// primitive kind
  590. primitive_kind = dnnl_query_primitive_kind,
  591. /// number of inputs expected
  592. num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
  593. /// number of outputs expected
  594. num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,
  595. /// runtime estimation (seconds), unimplemented
  596. time_estimate_f64 = dnnl_query_time_estimate_f64,
  597. /// memory required for scratchpad (bytes)
  598. ///
  599. /// @sa @ref dev_guide_attributes_scratchpad
  600. memory_consumption_s64 = dnnl_query_memory_consumption_s64,
  601. /// scratchpad engine
  602. ///
  603. /// engine to be used for creating scratchpad memory
  604. scratchpad_engine = dnnl_query_scratchpad_engine,
  605. /// reorder source engine
  606. reorder_src_engine = dnnl_query_reorder_src_engine,
  607. /// reorder destination engine
  608. reorder_dst_engine = dnnl_query_reorder_dst_engine,
  609. /// implementation name
  610. impl_info_str = dnnl_query_impl_info_str,
  611. /// propagation kind
  612. prop_kind = dnnl_query_prop_kind,
  613. /// size of cache blob ID in bytes
  614. cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64,
  615. /// cache blob ID (pointer to array)
  616. cache_blob_id = dnnl_query_cache_blob_id,
  617. /// strides
  618. strides = dnnl_query_strides,
  619. /// dilations
  620. dilations = dnnl_query_dilations,
  621. /// left padding
  622. padding_l = dnnl_query_padding_l,
  623. /// right padding
  624. padding_r = dnnl_query_padding_r,
  625. /// epsilon
  626. epsilon_f32 = dnnl_query_epsilon_f32,
  627. /// flags
  628. flags = dnnl_query_flags,
  629. /// algorithm kind
  630. alg_kind = dnnl_query_alg_kind,
  631. /// alpha
  632. alpha_f32 = dnnl_query_alpha_f32,
  633. /// beta
  634. beta_f32 = dnnl_query_beta_f32,
  635. /// axis
  636. axis_s32 = dnnl_query_axis_s32,
  637. /// LRN parameter local size
  638. local_size_s64 = dnnl_query_local_size_s64,
  639. /// LRN parameter K
  640. k_f32 = dnnl_query_k_f32,
  641. /// Reduction parameter P
  642. p_f32 = dnnl_query_p_f32,
  643. /// Resampling parameter factors
  644. factors = dnnl_query_factors,
  645. /// RNN parameter cell kind
  646. cell_kind = dnnl_query_cell_kind,
  647. /// RNN parameter direction
  648. direction = dnnl_query_direction,
  649. /// RNN parameter activation kind
  650. activation_kind = dnnl_query_activation_kind,
  651. /// Pooling parameter kernel
  652. kernel = dnnl_query_kernel,
  653. /// Shuffle parameter group size
  654. group_size_s64 = dnnl_query_group_size_s64,
  655. /// source memory desc
  656. src_md = dnnl_query_src_md,
  657. /// source gradient (diff) memory desc
  658. diff_src_md = dnnl_query_diff_src_md,
  659. /// weights memory descriptor desc
  660. weights_md = dnnl_query_weights_md,
  661. /// weights gradient (diff) memory desc
  662. diff_weights_md = dnnl_query_diff_weights_md,
  663. /// destination memory desc
  664. dst_md = dnnl_query_dst_md,
  665. /// destination gradient (diff) memory desc
  666. diff_dst_md = dnnl_query_diff_dst_md,
  667. /// workspace memory desc
  668. workspace_md = dnnl_query_workspace_md,
  669. /// scratchpad memory desc
  670. scratchpad_md = dnnl_query_scratchpad_md,
  671. /// memory desc of an execute argument
  672. exec_arg_md = dnnl_query_exec_arg_md,
  673. /// number of dimensions
  674. ndims_s32 = dnnl_query_ndims_s32,
  675. /// vector of dimensions
  676. dims = dnnl_query_dims,
  677. /// data type
  678. data_type = dnnl_query_data_type,
  679. /// submemory offset
  680. submemory_offset_s64 = dnnl_query_submemory_offset_s64,
  681. /// vector of padded dimensions
  682. padded_dims = dnnl_query_padded_dims,
  683. /// vector of padded offsets
  684. padded_offsets = dnnl_query_padded_offsets,
  685. /// format kind
  686. format_kind = dnnl_query_format_kind,
  687. /// number of innermost blocks
  688. inner_nblks_s32 = dnnl_query_inner_nblks_s32,
  689. /// vector of sizes of the innermost blocks
  690. inner_blks = dnnl_query_inner_blks,
  691. /// vector of logical indices of the blocks
  692. inner_idxs = dnnl_query_inner_idxs,
  693. /// Sparse encoding
  694. sparse_encoding = dnnl_query_sparse_encoding,
  695. /// Number of non-zero entries
  696. nnz_s64 = dnnl_query_nnz_s64,
  697. /// Number of buffers required for a memory descriptor
  698. num_handles_s32 = dnnl_query_num_handles_s32,
  699. };
  700. /// Converts query enum value from C++ API to C API type.
  701. /// @param aquery C++ API query enum value.
  702. /// @returns Corresponding C API query enum value.
  703. inline dnnl_query_t convert_to_c(query aquery) {
  704. return static_cast<dnnl_query_t>(aquery);
  705. }
  706. /// @} dnnl_api_primitives_common
  707. /// @} dnnl_api_primitives
  708. /// @addtogroup dnnl_api_memory Memory
  709. ///
  710. /// A container that describes and stores data. Memory objects can contain
  711. /// data of various types and formats. There are two levels of abstraction:
  712. ///
  713. /// 1. **Memory descriptor** -- engine-agnostic logical description of data
  714. /// (number of dimensions, dimension sizes, and data type), and,
  715. /// optionally, the information about the physical format of data in
  716. /// memory. If this information is not known yet, a memory descriptor can
  717. /// be created with #dnnl::memory::format_tag::any. This allows
  718. /// compute-intensive primitives to choose the best format for
  719. /// computation. The user is responsible for reordering the data into the
  720. /// chosen format when formats do not match.
  721. ///
  722. /// A memory descriptor can be initialized either by specifying dimensions
  723. /// and a memory format tag or strides for each of them, or by
  724. /// manipulating the dnnl_memory_desc_t structure directly.
  725. ///
  726. /// @warning
  727. /// The latter approach requires understanding how the physical data
  728. /// representation is mapped to the structure and is discouraged. This
  729. /// topic is discussed in @ref dev_guide_understanding_memory_formats.
  730. ///
  731. /// The user can query the amount of memory required by a memory
  732. /// descriptor using the #dnnl::memory::desc::get_size() function. The
  733. /// size of data in general cannot be computed as the product of
  734. /// dimensions multiplied by the size of the data type. So users are
  735. /// required to use this function for better code portability.
  736. ///
  737. /// Two memory descriptors can be compared using the equality and
  738. /// inequality operators. The comparison is especially useful when
  739. /// checking whether it is necessary to reorder data from the user's data
  740. /// format to a primitive's format.
  741. ///
  742. /// 2. **Memory object** -- an engine-specific object that handles the memory
  743. /// buffer and its description (a memory descriptor). For the CPU engine or
  744. /// with USM, the memory buffer handle is simply a pointer to @c void. The
  745. /// memory buffer can be queried using #dnnl::memory::get_data_handle() and
  746. /// set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer,
  747. /// when used, can be queried using #dnnl::sycl_interop::get_buffer and set
  748. /// using #dnnl::sycl_interop::set_buffer. A memory object can also be
  749. /// queried for the underlying memory descriptor and for its engine using
/// #dnnl::memory::get_desc() and #dnnl::memory::get_engine().
  751. ///
  752. /// Along with ordinary memory descriptors with all dimensions being positive,
  753. /// the library supports *zero-volume* memory descriptors with one or more
  754. /// dimensions set to zero. This is used to support the NumPy\* convention.
  755. /// If a zero-volume memory is passed to a primitive, the primitive typically
  756. /// does not perform any computations with this memory. For example:
  757. ///
/// - A concatenation primitive would ignore all memory objects with zeroes in
  759. /// the concat dimension / axis.
  760. ///
  761. /// - A forward convolution with a source memory object with zero in the
  762. /// minibatch dimension would always produce a destination memory object
  763. /// with a zero in the minibatch dimension and perform no computations.
  764. ///
  765. /// - However, a forward convolution with a zero in one of the weights
  766. /// dimensions is ill-defined and is considered to be an error by the
  767. /// library because there is no clear definition of what the output values
  768. /// should be.
  769. ///
  770. /// Memory buffer of a zero-volume memory is never accessed.
  771. ///
  772. /// @{
  773. /// Memory object.
  774. ///
  775. /// A memory object encapsulates a handle to a memory buffer allocated on a
  776. /// specific engine, tensor dimensions, data type, and memory format, which is
  777. /// the way tensor indices map to offsets in linear memory space. Memory
  778. /// objects are passed to primitives during execution.
  779. struct memory : public handle<dnnl_memory_t> {
  780. using handle::handle;
  781. /// Integer type for representing dimension sizes and indices.
  782. using dim = dnnl_dim_t;
  783. /// Vector of dimensions. Implementations are free to force a limit on the
  784. /// vector's length.
  785. using dims = std::vector<dim>;
  786. /// Helper function that validates that an `std::vector` of dimensions can
  787. /// be safely converted to the C API array ::dnnl_dims_t. Throws if
  788. /// validation fails.
  789. ///
  790. /// @param v Vector of dimensions.
  791. /// @param min_size Minimum expected size of the vector.
  792. template <typename T>
  793. static void validate_dims(const std::vector<T> &v, int min_size = 0) {
  794. validate_container_size(
  795. v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
  796. }
  797. /// Data type specification.
  798. enum class data_type {
  799. /// Undefined data type (used for empty memory descriptors).
  800. undef = dnnl_data_type_undef,
  801. /// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
  802. f4_e3m0 = dnnl_f4_e3m0,
  803. /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
  804. f4_e2m1 = dnnl_f4_e2m1,
  805. /// [MX-compliant 8-bit compliant scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
  806. e8m0 = dnnl_e8m0,
  807. /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
  808. /// with a 5-bit exponent and a 2-bit mantissa.
  809. f8_e5m2 = dnnl_f8_e5m2,
  810. /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
  811. /// with a 4-bit exponent and a 3-bit mantissa.
  812. f8_e4m3 = dnnl_f8_e4m3,
  813. /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format).
  814. f16 = dnnl_f16,
  815. /// non-standard
  816. /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format).
  817. bf16 = dnnl_bf16,
  818. /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format).
  819. f32 = dnnl_f32,
  820. //// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format).
  821. f64 = dnnl_f64,
  822. /// 32-bit signed integer.
  823. s32 = dnnl_s32,
  824. /// 8-bit signed integer.
  825. s8 = dnnl_s8,
  826. /// 8-bit unsigned integer.
  827. u8 = dnnl_u8,
  828. /// 4-bit signed integer.
  829. s4 = dnnl_s4,
  830. /// 4-bit unsigned integer.
  831. u4 = dnnl_u4,
  832. };
  833. /// Returns size of data type in bytes.
  834. /// @returns The number of bytes occupied by data type.
  835. static size_t data_type_size(data_type adata_type) {
  836. return dnnl_data_type_size(convert_to_c(adata_type));
  837. }
  838. /// Memory format kind
  839. enum class format_kind {
  840. /// Undefined memory format kind, used for empty memory descriptors.
  841. undef = dnnl_format_kind_undef,
  842. /// A special format kind that indicates that the actual format will be
  843. /// selected by a primitive automatically.
  844. any = dnnl_format_kind_any,
  845. /// A tensor in a generic format described by the stride and blocking
  846. /// values in each dimension.
  847. blocked = dnnl_blocked,
  848. /// Format kind for sparse tensors.
  849. sparse = dnnl_format_kind_sparse,
  850. /// Format kind for host scalars.
  851. host_scalar = dnnl_format_kind_host_scalar,
  852. /// A special format kind that indicates that tensor format is opaque.
  853. opaque = dnnl_format_kind_opaque,
  854. };
  855. /// Sparse encodings.
  856. /// @sa @ref dev_guide_sparsity
  857. enum class sparse_encoding {
  858. /// Undefined sparse encoding kind, used for empty memory descriptors.
  859. undef = dnnl_sparse_encoding_undef,
  860. /// Compressed Sparse Row (CSR) encoding.
  861. csr = dnnl_csr,
  862. /// An encoding that is used for an opaque storage schema for
  863. /// tensors with unstructured sparsity. A memory descriptor with the
  864. /// packed encoding cannot be used to create a memory object. It can
  865. /// only be used to create a primitive descriptor to query the
  866. /// actual memory descriptor (similar to the format tag `any`).
  867. packed = dnnl_packed,
  868. /// Coordinate Sparse (COO) encoding.
  869. coo = dnnl_coo,
  870. };
  871. /// Memory format tag specification.
  872. ///
  873. /// Memory format tags can be further divided into two categories:
  874. ///
  875. /// - Domain-agnostic names, i.e. names that do not depend on the tensor
  876. /// usage in the specific primitive. These names use letters from `a`
  877. /// to `f` to denote logical dimensions and form the order in which the
  878. /// dimensions are laid in memory. For example,
  879. /// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
  880. /// second logical dimension (denoted as `b`) is the innermost, i.e.
  881. /// has stride = 1, and the first logical dimension (`a`) is laid out in
  882. /// memory with stride equal to the size of the second dimension. On the
  883. /// other hand, #dnnl::memory::format_tag::ba is the transposed version
  884. /// of the same tensor: the outermost dimension (`a`) becomes the
  885. /// innermost one.
  886. ///
  887. /// - Domain-specific names, i.e. names that make sense only in the
  888. /// context of a certain domain, such as CNN. These names are
  889. /// aliases to the corresponding domain-agnostic tags and used mostly
  890. /// for convenience. For example, #dnnl::memory::format_tag::nc
  891. /// is used to denote 2D CNN activations tensor memory format, where
  892. /// the channels dimension is the innermost one and the batch dimension
  893. /// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
  894. /// an alias for #dnnl::memory::format_tag::ab, because for
  895. /// CNN primitives the logical dimensions of activations tensors come
  896. /// in order: batch, channels, spatial. In other words, batch
  897. /// corresponds to the first logical dimension (`a`), and channels
  898. /// correspond to the second one (`b`).
  899. ///
  900. /// The following domain-specific notation applies to memory format tags:
  901. /// - @c 'n' denotes the mini-batch dimension
  902. /// - @c 'c' denotes a channels dimension
  903. /// - When there are multiple channel dimensions (for example,
  904. /// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
  905. /// of input and output channels
  906. /// - @c 'g' denotes a groups dimension for convolution weights
  907. /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
  908. /// respectively
  909. ///
  910. /// See @ref dnnl_format_tag_t for a detailed description.
  911. enum class format_tag {
  912. /// Undefined memory format tag
  913. undef = dnnl_format_tag_undef,
  914. /// Placeholder memory format tag. Used to instruct the primitive to
  915. /// select a format automatically.
  916. any = dnnl_format_tag_any,
  917. /// plain 1D tensor
  918. a = dnnl_a,
  919. /// plain 2D tensor
  920. ab = dnnl_ab,
  921. /// permuted 2D tensor
  922. ba = dnnl_ba,
  923. /// plain 3D tensor
  924. abc = dnnl_abc,
  925. /// permuted 3D tensor
  926. acb = dnnl_acb,
  927. /// permuted 3D tensor
  928. bac = dnnl_bac,
  929. /// permuted 3D tensor
  930. bca = dnnl_bca,
  931. /// permuted 3D tensor
  932. cba = dnnl_cba,
  933. /// plain 4D tensor
  934. abcd = dnnl_abcd,
  935. /// permuted 4D tensor
  936. abdc = dnnl_abdc,
  937. /// permuted 4D tensor
  938. acbd = dnnl_acbd,
  939. /// permuted 4D tensor
  940. acdb = dnnl_acdb,
  941. /// permuted 4D tensor
  942. adbc = dnnl_adbc,
  943. /// permuted 4D tensor
  944. bacd = dnnl_bacd,
  945. /// permuted 4D tensor
  946. bcda = dnnl_bcda,
  947. /// permuted 4D tensor
  948. cdba = dnnl_cdba,
  949. /// permuted 4D tensor
  950. dcab = dnnl_dcab,
  951. /// plain 5D tensor
  952. abcde = dnnl_abcde,
  953. /// permuted 5D tensor
  954. abdec = dnnl_abdec,
  955. /// permuted 5D tensor
  956. acbde = dnnl_acbde,
  957. /// permuted 5D tensor
  958. acdeb = dnnl_acdeb,
  959. /// permuted 5D tensor
  960. bacde = dnnl_bacde,
  961. /// permuted 5D tensor
  962. bcdea = dnnl_bcdea,
  963. /// permuted 5D tensor
  964. cdeba = dnnl_cdeba,
  965. /// permuted 5D tensor
  966. decab = dnnl_decab,
  967. /// permuted 5D tensor
  968. abced = dnnl_abced,
  969. /// plain 6D tensor
  970. abcdef = dnnl_abcdef,
  971. /// permuted 6D tensor
  972. abdfce = dnnl_abdfce,
  973. /// permuted 6D tensor
  974. acbdef = dnnl_acbdef,
  975. /// permuted 6D tensor
  976. abdefc = dnnl_abdefc,
  977. /// permuted 6D tensor
  978. defcab = dnnl_defcab,
  979. /// permuted 6D tensor
  980. abcdfe = dnnl_abcdfe,
  981. /// plain 7D tensor
  982. abcdefg = dnnl_abcdefg,
  983. /// permuted 7D tensor
  984. abcdegf = dnnl_abcdegf,
  985. /// plain 8D tensor
  986. abcdefgh = dnnl_abcdefgh,
  987. /// permuted 8D tensor
  988. abcdefhg = dnnl_abcdefhg,
  989. /// plain 9D tensor
  990. abcdefghi = dnnl_abcdefghi,
  991. /// permuted 9D tensor
  992. abcdefgih = dnnl_abcdefgih,
  993. /// plain 10D tensor
  994. abcdefghij = dnnl_abcdefghij,
  995. /// permuted 10D tensor
  996. abcdefghji = dnnl_abcdefghji,
  997. /// plain 11D tensor
  998. abcdefghijk = dnnl_abcdefghijk,
  999. /// permuted 11D tensor
  1000. abcdefghikj = dnnl_abcdefghikj,
  1001. /// plain 12D tensor
  1002. abcdefghijkl = dnnl_abcdefghijkl,
  1003. /// permuted 12D tensor
  1004. abcdefghijlk = dnnl_abcdefghijlk,
  1005. /// 1D tensor; an alias for #dnnl::memory::format_tag::a
  1006. x = a,
  1007. /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
  1008. nc = ab,
  1009. /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
  1010. cn = ba,
  1011. /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
  1012. tn = ab,
  1013. /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
  1014. nt = ba,
  1015. /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
  1016. ncw = abc,
  1017. /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
  1018. nwc = acb,
  1019. /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
  1020. nchw = abcd,
  1021. /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
  1022. nhwc = acdb,
  1023. /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
  1024. chwn = bcda,
  1025. /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
  1026. ncdhw = abcde,
  1027. /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
  1028. ndhwc = acdeb,
  1029. /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
  1030. oi = ab,
  1031. /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
  1032. io = ba,
  1033. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
  1034. oiw = abc,
  1035. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
  1036. owi = acb,
  1037. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
  1038. wio = cba,
  1039. /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
  1040. iwo = bca,
  1041. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
  1042. oihw = abcd,
  1043. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
  1044. hwio = cdba,
  1045. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
  1046. ohwi = acdb,
  1047. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
  1048. ihwo = bcda,
  1049. /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
  1050. iohw = bacd,
  1051. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
  1052. oidhw = abcde,
  1053. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
  1054. dhwio = cdeba,
  1055. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
  1056. odhwi = acdeb,
  1057. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde
  1058. iodhw = bacde,
  1059. /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
  1060. idhwo = bcdea,
  1061. /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
  1062. goiw = abcd,
  1063. /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc
  1064. gowi = abdc,
  1065. /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
  1066. wigo = dcab,
  1067. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec
  1068. gohwi = abdec,
  1069. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
  1070. goihw = abcde,
  1071. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
  1072. hwigo = decab,
  1073. /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
  1074. giohw = acbde,
  1075. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
  1076. goidhw = abcdef,
  /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
  1078. giodhw = acbdef,
  1079. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc
  1080. godhwi = abdefc,
  1081. /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
  1082. dhwigo = defcab,
  1083. /// 3D RNN data tensor in the format (seq_length, batch, input
  1084. /// channels); an alias for #dnnl::memory::format_tag::abc.
  1085. tnc = abc,
  1086. /// 3D RNN data tensor in the format (batch, seq_length, input
  1087. /// channels); an alias for #dnnl::memory::format_tag::bac.
  1088. ntc = bac,
  1089. /// 4D RNN states tensor in the format (num_layers, num_directions,
  1090. /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd.
  1091. ldnc = abcd,
  1092. /// 5D RNN weights tensor in the format (num_layers, num_directions,
  1093. /// input_channels, num_gates, output_channels);
  1094. /// an alias for #dnnl::memory::format_tag::abcde.
  1095. ///
  1096. /// - For LSTM cells, the gates order is input, forget, candidate
  1097. /// and output gate.
  1098. /// - For GRU cells, the gates order is update, reset and output gate.
  1099. ldigo = abcde,
  1100. /// 5D RNN weights tensor in the format (num_layers, num_directions,
  1101. /// num_gates, output_channels, input_channels);
  1102. /// an alias for #dnnl::memory::format_tag::abdec.
  1103. ///
  1104. /// - For LSTM cells, the gates order is input, forget, candidate
  1105. /// and output gate.
  1106. /// - For GRU cells, the gates order is update, reset and output gate.
  1107. ldgoi = abdec,
  1108. /// 4D LSTM projection tensor in the format (num_layers, num_directions,
  1109. /// num_channels_in_hidden_state, num_channels_in_recurrent_projection);
  1110. /// an alias for #dnnl::memory::format_tag::abcd.
  1111. ldio = abcd,
  1112. /// 4D LSTM projection tensor in the format (num_layers, num_directions,
  1113. /// num_channels_in_recurrent_projection, num_channels_in_hidden_state);
  1114. /// an alias for #dnnl::memory::format_tag::abdc.
  1115. ldoi = abdc,
  1116. /// 4D RNN bias tensor in the format (num_layers, num_directions,
  1117. /// num_gates, output_channels);
  1118. /// an alias for #dnnl::memory::format_tag::abcd.
  1119. ///
  1120. /// - For LSTM cells, the gates order is input, forget, candidate
  1121. /// and output gate.
  1122. /// - For GRU cells, the gates order is update, reset and output gate.
  1123. ldgo = abcd,
  1124. // Opaque blocked formats
  1125. AB16b16a = dnnl_AB16b16a,
  1126. AB16b32a = dnnl_AB16b32a,
  1127. AB16b48a = dnnl_AB16b48a,
  1128. AB16b64a = dnnl_AB16b64a,
  1129. AB8b16a2b = dnnl_AB8b16a2b,
  1130. AB8b32a2b = dnnl_AB8b32a2b,
  1131. AB8b64a2b = dnnl_AB8b64a2b,
  1132. AB4b16a4b = dnnl_AB4b16a4b,
  1133. AB4b32a4b = dnnl_AB4b32a4b,
  1134. AB4b64a4b = dnnl_AB4b64a4b,
  1135. AB16b16a4b = dnnl_AB16b16a4b,
  1136. AB16b32a4b = dnnl_AB16b32a4b,
  1137. AB16b48a4b = dnnl_AB16b48a4b,
  1138. AB16b64a4b = dnnl_AB16b64a4b,
  1139. AB16b16a2b = dnnl_AB16b16a2b,
  1140. AB16b32a2b = dnnl_AB16b32a2b,
  1141. AB16b48a2b = dnnl_AB16b48a2b,
  1142. AB16b64a2b = dnnl_AB16b64a2b,
  1143. Ab4a = dnnl_Ab4a,
  1144. Ab8a = dnnl_Ab8a,
  1145. Ab32a = dnnl_Ab32a,
  1146. Abc16a = dnnl_Abc16a,
  1147. ABc16a16b = dnnl_ABc16a16b,
  1148. ABc4a4b = dnnl_ABc4a4b,
  1149. aBc16b = dnnl_aBc16b,
  1150. aBc32b = dnnl_aBc32b,
  1151. ABc16b16a = dnnl_ABc16b16a,
  1152. AcB16b16a = dnnl_AcB16b16a,
  1153. ABc16b32a = dnnl_ABc16b32a,
  1154. AcB16b32a = dnnl_AcB16b32a,
  1155. ABc16b48a = dnnl_ABc16b48a,
  1156. AcB16b48a = dnnl_AcB16b48a,
  1157. ABc16b64a = dnnl_ABc16b64a,
  1158. AcB16b64a = dnnl_AcB16b64a,
  1159. Abc4a = dnnl_Abc4a,
  1160. aBc4b = dnnl_aBc4b,
  1161. ABc4b16a4b = dnnl_ABc4b16a4b,
  1162. AcB4b16a4b = dnnl_AcB4b16a4b,
  1163. ABc4b32a4b = dnnl_ABc4b32a4b,
  1164. AcB4b32a4b = dnnl_AcB4b32a4b,
  1165. ABc4b64a4b = dnnl_ABc4b64a4b,
  1166. AcB4b64a4b = dnnl_AcB4b64a4b,
  1167. ABc2b8a4b = dnnl_ABc2b8a4b,
  1168. ABc16a16b2a = dnnl_ABc16a16b2a,
  1169. ABc16b16a4b = dnnl_ABc16b16a4b,
  1170. ABc16b32a4b = dnnl_ABc16b32a4b,
  1171. ABc16b48a4b = dnnl_ABc16b48a4b,
  1172. ABc16b64a4b = dnnl_ABc16b64a4b,
  1173. ABc16b16a2b = dnnl_ABc16b16a2b,
  1174. ABc16b32a2b = dnnl_ABc16b32a2b,
  1175. ABc16b48a2b = dnnl_ABc16b48a2b,
  1176. ABc16b64a2b = dnnl_ABc16b64a2b,
  1177. ABc4b4a = dnnl_ABc4b4a,
  1178. ABc8a16b2a = dnnl_ABc8a16b2a,
  1179. ABc8a8b = dnnl_ABc8a8b,
  1180. ABc8a4b = dnnl_ABc8a4b,
  1181. aBc8b = dnnl_aBc8b,
  1182. ABc8b16a2b = dnnl_ABc8b16a2b,
  1183. AcB8b16a2b = dnnl_AcB8b16a2b,
  1184. ABc8b32a2b = dnnl_ABc8b32a2b,
  1185. AcB8b32a2b = dnnl_AcB8b32a2b,
  1186. ABc8b64a2b = dnnl_ABc8b64a2b,
  1187. AcB8b64a2b = dnnl_AcB8b64a2b,
  1188. ABc8b8a = dnnl_ABc8b8a,
  1189. AcB8b8a = dnnl_AcB8b8a,
  1190. Abcd8a = dnnl_Abcd8a,
  1191. Abcd16a = dnnl_Abcd16a,
  1192. Abcd32a = dnnl_Abcd32a,
  1193. ABcd16a16b = dnnl_ABcd16a16b,
  1194. aBcd16b = dnnl_aBcd16b,
  1195. aBcd32b = dnnl_aBcd32b,
  1196. ABcd16b16a = dnnl_ABcd16b16a,
  1197. AcdB16b16a = dnnl_AcdB16b16a,
  1198. ABcd16b32a = dnnl_ABcd16b32a,
  1199. AcdB16b32a = dnnl_AcdB16b32a,
  1200. ABcd16b48a = dnnl_ABcd16b48a,
  1201. AcdB16b48a = dnnl_AcdB16b48a,
  1202. ABcd16b64a = dnnl_ABcd16b64a,
  1203. AcdB16b64a = dnnl_AcdB16b64a,
  1204. aBCd16b16c = dnnl_aBCd16b16c,
  1205. aBCd16c16b = dnnl_aBCd16c16b,
  1206. Abcd4a = dnnl_Abcd4a,
  1207. aBcd4b = dnnl_aBcd4b,
  1208. ABcd4b16a4b = dnnl_ABcd4b16a4b,
  1209. AcdB4b16a4b = dnnl_AcdB4b16a4b,
  1210. ABcd4b32a4b = dnnl_ABcd4b32a4b,
  1211. AcdB4b32a4b = dnnl_AcdB4b32a4b,
  1212. ABcd4b64a4b = dnnl_ABcd4b64a4b,
  1213. AcdB4b64a4b = dnnl_AcdB4b64a4b,
  1214. ABcd2b8a4b = dnnl_ABcd2b8a4b,
  1215. ABcd4b4a = dnnl_ABcd4b4a,
  1216. ABcd4a4b = dnnl_ABcd4a4b,
  1217. aBCd4c16b4c = dnnl_aBCd4c16b4c,
  1218. aBCd2c8b4c = dnnl_aBCd2c8b4c,
  1219. ABcd16a16b2a = dnnl_ABcd16a16b2a,
  1220. ABcd16b16a4b = dnnl_ABcd16b16a4b,
  1221. ABcd16b32a4b = dnnl_ABcd16b32a4b,
  1222. ABcd16b48a4b = dnnl_ABcd16b48a4b,
  1223. ABcd16b64a4b = dnnl_ABcd16b64a4b,
  1224. ABcd16b16a2b = dnnl_ABcd16b16a2b,
  1225. ABcd16b32a2b = dnnl_ABcd16b32a2b,
  1226. ABcd16b48a2b = dnnl_ABcd16b48a2b,
  1227. ABcd16b64a2b = dnnl_ABcd16b64a2b,
  1228. aBCd16b16c2b = dnnl_aBCd16b16c2b,
  1229. aBCd16c16b4c = dnnl_aBCd16c16b4c,
  1230. aBCd16c16b2c = dnnl_aBCd16c16b2c,
  1231. aBCd4c4b = dnnl_aBCd4c4b,
  1232. aBCd4b4c = dnnl_aBCd4b4c,
  1233. ABcd8a16b2a = dnnl_ABcd8a16b2a,
  1234. ABcd8a8b = dnnl_ABcd8a8b,
  1235. ABcd8a4b = dnnl_ABcd8a4b,
  1236. ABcd8a2b = dnnl_ABcd8a2b,
  1237. /// 4D tensor blocked by 2nd dimension with block size 8
  1238. aBcd8b = dnnl_aBcd8b,
  1239. ABcd8b16a2b = dnnl_ABcd8b16a2b,
  1240. AcdB8b16a2b = dnnl_AcdB8b16a2b,
  1241. ABcd8b32a2b = dnnl_ABcd8b32a2b,
  1242. AcdB8b32a2b = dnnl_AcdB8b32a2b,
  1243. ABcd8b64a2b = dnnl_ABcd8b64a2b,
  1244. AcdB8b64a2b = dnnl_AcdB8b64a2b,
  1245. aBCd8b16c2b = dnnl_aBCd8b16c2b,
  1246. /// 4D tensor blocked by 1st and 2nd dimension with block size 8
  1247. ABcd8b8a = dnnl_ABcd8b8a,
  1248. AcdB8b8a = dnnl_AcdB8b8a,
  1249. aBCd8b8c = dnnl_aBCd8b8c,
  1250. aBCd8b4c = dnnl_aBCd8b4c,
  1251. aBCd8c16b2c = dnnl_aBCd8c16b2c,
  1252. aBCd8c8b = dnnl_aBCd8c8b,
  1253. Abcde16a = dnnl_Abcde16a,
  1254. Abcde32a = dnnl_Abcde32a,
  1255. ABcde16a16b = dnnl_ABcde16a16b,
  1256. aBcde16b = dnnl_aBcde16b,
  1257. aBcde32b = dnnl_aBcde32b,
  1258. ABcde16b16a = dnnl_ABcde16b16a,
  1259. AcdeB16b16a = dnnl_AcdeB16b16a,
  1260. ABcde16b32a = dnnl_ABcde16b32a,
  1261. AcdeB16b32a = dnnl_AcdeB16b32a,
  1262. ABcde16b48a = dnnl_ABcde16b48a,
  1263. AcdeB16b48a = dnnl_AcdeB16b48a,
  1264. ABcde16b64a = dnnl_ABcde16b64a,
  1265. AcdeB16b64a = dnnl_AcdeB16b64a,
  1266. aBCde16b16c = dnnl_aBCde16b16c,
  1267. aBCde16c16b = dnnl_aBCde16c16b,
  1268. aBCde2c8b4c = dnnl_aBCde2c8b4c,
  1269. Abcde4a = dnnl_Abcde4a,
  1270. aBcde4b = dnnl_aBcde4b,
  1271. ABcde4b4a = dnnl_ABcde4b4a,
  1272. ABcde4a4b = dnnl_ABcde4a4b,
  1273. aBCde4b4c = dnnl_aBCde4b4c,
  1274. aBCde4c16b4c = dnnl_aBCde4c16b4c,
  1275. aBCde16b16c2b = dnnl_aBCde16b16c2b,
  1276. aBCde16c16b4c = dnnl_aBCde16c16b4c,
  1277. aBCde16c16b2c = dnnl_aBCde16c16b2c,
  1278. aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
  1279. aBCde4c4b = dnnl_aBCde4c4b,
  1280. Abcde8a = dnnl_Abcde8a,
  1281. ABcde8a8b = dnnl_ABcde8a8b,
  1282. ABcde8a4b = dnnl_ABcde8a4b,
  1283. aBcde8b = dnnl_aBcde8b,
  1284. ABcde8b16a2b = dnnl_ABcde8b16a2b,
  1285. AcdeB8b16a2b = dnnl_AcdeB8b16a2b,
  1286. ABcde8b32a2b = dnnl_ABcde8b32a2b,
  1287. AcdeB8b32a2b = dnnl_AcdeB8b32a2b,
  1288. ABcde8b64a2b = dnnl_ABcde8b64a2b,
  1289. AcdeB8b64a2b = dnnl_AcdeB8b64a2b,
  1290. ABcde4b16a4b = dnnl_ABcde4b16a4b,
  1291. AcdeB4b16a4b = dnnl_AcdeB4b16a4b,
  1292. ABcde4b32a4b = dnnl_ABcde4b32a4b,
  1293. AcdeB4b32a4b = dnnl_AcdeB4b32a4b,
  1294. ABcde4b64a4b = dnnl_ABcde4b64a4b,
  1295. AcdeB4b64a4b = dnnl_AcdeB4b64a4b,
  1296. ABcde16b16a4b = dnnl_ABcde16b16a4b,
  1297. ABcde16b32a4b = dnnl_ABcde16b32a4b,
  1298. ABcde16b48a4b = dnnl_ABcde16b48a4b,
  1299. ABcde16b64a4b = dnnl_ABcde16b64a4b,
  1300. ABcde16b16a2b = dnnl_ABcde16b16a2b,
  1301. ABcde16b32a2b = dnnl_ABcde16b32a2b,
  1302. ABcde16b48a2b = dnnl_ABcde16b48a2b,
  1303. ABcde16b64a2b = dnnl_ABcde16b64a2b,
  1304. ABcde2b8a4b = dnnl_ABcde2b8a4b,
  1305. aBCde8b16c2b = dnnl_aBCde8b16c2b,
  1306. ABcde8b8a = dnnl_ABcde8b8a,
  1307. AcdeB8b8a = dnnl_AcdeB8b8a,
  1308. aBCde8b8c = dnnl_aBCde8b8c,
  1309. aBCde8b4c = dnnl_aBCde8b4c,
  1310. ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
  1311. ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
  1312. aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
  1313. aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
  1314. aBCde8c16b2c = dnnl_aBCde8c16b2c,
  1315. aBCde8c8b = dnnl_aBCde8c8b,
  1316. aBcdef16b = dnnl_aBcdef16b,
  1317. aBCdef16b16c = dnnl_aBCdef16b16c,
  1318. aBCdef16c16b = dnnl_aBCdef16c16b,
  1319. aBcdef4b = dnnl_aBcdef4b,
  1320. aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
  1321. aBCdef4c4b = dnnl_aBCdef4c4b,
  1322. aBCdef4b4c = dnnl_aBCdef4b4c,
  1323. aBCdef8b8c = dnnl_aBCdef8b8c,
  1324. aBCdef8b4c = dnnl_aBCdef8b4c,
  1325. aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
  1326. aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
  1327. aBCdef8c8b = dnnl_aBCdef8c8b,
  1328. aBdc16b = dnnl_aBdc16b,
  1329. aBdc4b = dnnl_aBdc4b,
  1330. aBdc8b = dnnl_aBdc8b,
  1331. aBdC8b2c = dnnl_aBdC8b2c,
  1332. aBdC8b4c = dnnl_aBdC8b4c,
  1333. aBdec16b = dnnl_aBdec16b,
  1334. aBdec4b = dnnl_aBdec4b,
  1335. aBdec8b = dnnl_aBdec8b,
  1336. aBdeC8b2c = dnnl_aBdeC8b2c,
  1337. aBdeC8b4c = dnnl_aBdeC8b4c,
  1338. aBdefc16b = dnnl_aBdefc16b,
  1339. aCBdef16c16b = dnnl_aCBdef16c16b,
  1340. aCBdef8b8c = dnnl_aCBdef8b8c,
  1341. aCBdef16b16c = dnnl_aCBdef16b16c,
  1342. aBdefc4b = dnnl_aBdefc4b,
  1343. aBdefc8b = dnnl_aBdefc8b,
  1344. aBdefC8b2c = dnnl_aBdefC8b2c,
  1345. aBdefC8b4c = dnnl_aBdefC8b4c,
  1346. Acb16a = dnnl_Acb16a,
  1347. Acb4a = dnnl_Acb4a,
  1348. Acb8a = dnnl_Acb8a,
  1349. AcB8a2b = dnnl_AcB8a2b,
  1350. AcB8a4b = dnnl_AcB8a4b,
  1351. aCBd8b8c = dnnl_aCBd8b8c,
  1352. aCBd16b16c = dnnl_aCBd16b16c,
  1353. aCBd16c16b = dnnl_aCBd16c16b,
  1354. aCBde8b8c = dnnl_aCBde8b8c,
  1355. aCBde16b16c = dnnl_aCBde16b16c,
  1356. aCBde16c16b = dnnl_aCBde16c16b,
  1357. Acdb16a = dnnl_Acdb16a,
  1358. Acdb4a = dnnl_Acdb4a,
  1359. Acdb8a = dnnl_Acdb8a,
  1360. AcdB8a2b = dnnl_AcdB8a2b,
  1361. AcdB8a4b = dnnl_AcdB8a4b,
  1362. Acdeb16a = dnnl_Acdeb16a,
  1363. Acdeb4a = dnnl_Acdeb4a,
  1364. Acdeb8a = dnnl_Acdeb8a,
  1365. AcdeB8a2b = dnnl_AcdeB8a2b,
  1366. AcdeB8a4b = dnnl_AcdeB8a4b,
  1367. BAc8a8b = dnnl_BAc8a8b,
  1368. BAc16a16b = dnnl_BAc16a16b,
  1369. BAc16b16a = dnnl_BAc16b16a,
  1370. BAcd8a8b = dnnl_BAcd8a8b,
  1371. BAcd16a16b = dnnl_BAcd16a16b,
  1372. BAcd16b16a = dnnl_BAcd16b16a,
  1373. ABcd32a32b = dnnl_ABcd32a32b,
  1374. BAcde16b16a = dnnl_BAcde16b16a,
  1375. BAcde8a8b = dnnl_BAcde8a8b,
  1376. BAcde16a16b = dnnl_BAcde16a16b,
  1377. aBdec32b = dnnl_aBdec32b,
  1378. Abcdef16a = dnnl_Abcdef16a,
  1379. Abcdef32a = dnnl_Abcdef32a,
  1380. Acdb32a = dnnl_Acdb32a,
  1381. aBCd2b4c2b = dnnl_aBCd2b4c2b,
  1382. aBCde2b4c2b = dnnl_aBCde2b4c2b,
  1383. aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
  1384. aBCd2c4b2c = dnnl_aBCd2c4b2c,
  1385. aBCde2c4b2c = dnnl_aBCde2c4b2c,
  1386. aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
  1387. aBCd4b8c2b = dnnl_aBCd4b8c2b,
  1388. aBCde4b8c2b = dnnl_aBCde4b8c2b,
  1389. aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
  1390. aBCd4c8b2c = dnnl_aBCd4c8b2c,
  1391. aBCde4c8b2c = dnnl_aBCde4c8b2c,
  1392. aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
  1393. AB32a32b8a4b = dnnl_AB32a32b8a4b,
  1394. AB32a32b8a2b = dnnl_AB32a32b8a2b,
  1395. AB8a4b = dnnl_AB8a4b,
  1396. AB8a2b = dnnl_AB8a2b,
  1397. abDc16d = dnnl_abDc16d,
  1398. abDc32d = dnnl_abDc32d,
  1399. abDC16d4c = dnnl_abDC16d4c,
  1400. abDC32d4c = dnnl_abDC32d4c,
  1401. abCd32c = dnnl_abCd32c,
  1402. abdEc16e = dnnl_abdEc16e,
  1403. abdEc32e = dnnl_abdEc32e,
  1404. abdEC16e4c = dnnl_abdEC16e4c,
  1405. abdEC32e2c = dnnl_abdEC32e2c,
  1406. abdEC32e4c = dnnl_abdEC32e4c,
  1407. abdCe16c = dnnl_abdCe16c,
  1408. abdCe32c = dnnl_abdCe32c,
  1409. abdCE32c2e = dnnl_abdCE32c2e,
  1410. aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
  1411. aBdC16b4c = dnnl_aBdC16b4c,
  1412. aBdeC16b4c = dnnl_aBdeC16b4c,
  1413. AcB16a4b = dnnl_AcB16a4b,
  1414. AcdB16a2b = dnnl_AcdB16a2b,
  1415. aBdefC16b4c = dnnl_aBdefC16b4c,
  1416. AcdeB16a4b = dnnl_AcdeB16a4b,
  1417. Acb32a = dnnl_Acb32a,
  1418. AcB32a2b = dnnl_AcB32a2b,
  1419. AcB32a4b = dnnl_AcB32a4b,
  1420. Acb48a = dnnl_Acb48a,
  1421. AcB48a2b = dnnl_AcB48a2b,
  1422. AcB48a4b = dnnl_AcB48a4b,
  1423. Acb64a = dnnl_Acb64a,
  1424. AcB64a2b = dnnl_AcB64a2b,
  1425. AcB64a4b = dnnl_AcB64a4b,
  1426. cBa2b = dnnl_cBa2b,
  1427. cBa4b = dnnl_cBa4b,
  1428. aBdc32b = dnnl_aBdc32b,
  1429. aBdC32b2c = dnnl_aBdC32b2c,
  1430. aBdC32b4c = dnnl_aBdC32b4c,
  1431. aBdc48b = dnnl_aBdc48b,
  1432. aBdC48b2c = dnnl_aBdC48b2c,
  1433. aBdC48b4c = dnnl_aBdC48b4c,
  1434. aBdc64b = dnnl_aBdc64b,
  1435. aBdC64b2c = dnnl_aBdC64b2c,
  1436. aBdC64b4c = dnnl_aBdC64b4c,
  1437. adcb = dnnl_adcb,
  1438. adCb2c = dnnl_adCb2c,
  1439. adCb4c = dnnl_adCb4c,
  1440. AcdB32a2b = dnnl_AcdB32a2b,
  1441. AcdB32a4b = dnnl_AcdB32a4b,
  1442. Acdb48a = dnnl_Acdb48a,
  1443. AcdB48a2b = dnnl_AcdB48a2b,
  1444. AcdB48a4b = dnnl_AcdB48a4b,
  1445. Acdb64a = dnnl_Acdb64a,
  1446. AcdB64a2b = dnnl_AcdB64a2b,
  1447. AcdB64a4b = dnnl_AcdB64a4b,
  1448. cdBa2b = dnnl_cdBa2b,
  1449. cdBa4b = dnnl_cdBa4b,
  1450. aBdeC32b2c = dnnl_aBdeC32b2c,
  1451. aBdeC32b4c = dnnl_aBdeC32b4c,
  1452. aBdec48b = dnnl_aBdec48b,
  1453. aBdeC48b2c = dnnl_aBdeC48b2c,
  1454. aBdeC48b4c = dnnl_aBdeC48b4c,
  1455. aBdec64b = dnnl_aBdec64b,
  1456. aBdeC64b2c = dnnl_aBdeC64b2c,
  1457. aBdeC64b4c = dnnl_aBdeC64b4c,
  1458. adecb = dnnl_adecb,
  1459. adeCb2c = dnnl_adeCb2c,
  1460. adeCb4c = dnnl_adeCb4c,
  1461. Acdeb32a = dnnl_Acdeb32a,
  1462. AcdeB32a2b = dnnl_AcdeB32a2b,
  1463. AcdeB32a4b = dnnl_AcdeB32a4b,
  1464. Acdeb48a = dnnl_Acdeb48a,
  1465. AcdeB48a2b = dnnl_AcdeB48a2b,
  1466. AcdeB48a4b = dnnl_AcdeB48a4b,
  1467. Acdeb64a = dnnl_Acdeb64a,
  1468. AcdeB64a2b = dnnl_AcdeB64a2b,
  1469. AcdeB64a4b = dnnl_AcdeB64a4b,
  1470. cdeBa2b = dnnl_cdeBa2b,
  1471. cdeBa4b = dnnl_cdeBa4b,
  1472. aBdefc32b = dnnl_aBdefc32b,
  1473. aBdefC32b2c = dnnl_aBdefC32b2c,
  1474. aBdefC32b4c = dnnl_aBdefC32b4c,
  1475. aBdefc48b = dnnl_aBdefc48b,
  1476. aBdefC48b2c = dnnl_aBdefC48b2c,
  1477. aBdefC48b4c = dnnl_aBdefC48b4c,
  1478. aBdefc64b = dnnl_aBdefc64b,
  1479. aBdefC64b2c = dnnl_aBdefC64b2c,
  1480. aBdefC64b4c = dnnl_aBdefC64b4c,
  1481. adefcb = dnnl_adefcb,
  1482. adefCb2c = dnnl_adefCb2c,
  1483. adefCb4c = dnnl_adefCb4c,
  1484. ABc32a32b = dnnl_ABc32a32b,
  1485. BAc8a16b2a = dnnl_BAc8a16b2a,
  1486. BAcd8a16b2a = dnnl_BAcd8a16b2a,
  1487. ABcde8a16b2a = dnnl_ABcde8a16b2a,
  1488. aCBd8b16c2b = dnnl_aCBd8b16c2b,
  1489. BAcde8a16b2a = dnnl_BAcde8a16b2a,
  1490. aCBde8b16c2b = dnnl_aCBde8b16c2b,
  1491. ABcde32a32b = dnnl_ABcde32a32b,
  1492. ABc4a8b8a4b = dnnl_ABc4a8b8a4b,
  1493. ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b,
  1494. BAc4b8a8b4a = dnnl_BAc4b8a8b4a,
  1495. BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a,
  1496. BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a,
  1497. aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c,
  1498. aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c,
  1499. aBCdef8b16c2b = dnnl_aBCdef8b16c2b,
  1500. aCBdef8b16c2b = dnnl_aCBdef8b16c2b,
  1501. aBdC16b2c = dnnl_aBdC16b2c,
  1502. aBdeC16b2c = dnnl_aBdeC16b2c,
  1503. aBdefC16b2c = dnnl_aBdefC16b2c,
  1504. aBedc16b = dnnl_aBedc16b,
  1505. AcB16a2b = dnnl_AcB16a2b,
  1506. AcdB16a4b = dnnl_AcdB16a4b,
  1507. AcdeB16a2b = dnnl_AcdeB16a2b,
  1508. Adcb16a = dnnl_Adcb16a,
  1509. aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b,
  1510. aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b,
  1511. aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b,
  1512. ABc32a16b = dnnl_ABc32a16b,
  1513. ABcd16a32b = dnnl_ABcd16a32b,
  1514. ABcd32a16b = dnnl_ABcd32a16b,
  1515. ABcde32a16b = dnnl_ABcde32a16b,
  1516. AB48a16b = dnnl_AB48a16b,
  1517. AB48a32b = dnnl_AB48a32b,
  1518. ABc40a16b = dnnl_ABc40a16b,
  1519. ABc40a32b = dnnl_ABc40a32b,
  1520. aBC48b16c = dnnl_aBC48b16c,
  1521. aBC48b32c = dnnl_aBC48b32c,
  1522. ABcd40a16b = dnnl_ABcd40a16b,
  1523. ABcd40a32b = dnnl_ABcd40a32b,
  1524. BA16a16b = dnnl_BA16a16b,
  1525. BA16a32b = dnnl_BA16a32b,
  1526. BA16a48b = dnnl_BA16a48b,
  1527. BA16a64b = dnnl_BA16a64b,
  1528. BA16a16b2a = dnnl_BA16a16b2a,
  1529. BA16a32b2a = dnnl_BA16a32b2a,
  1530. BA16a48b2a = dnnl_BA16a48b2a,
  1531. BA16a64b2a = dnnl_BA16a64b2a,
  1532. BA16a16b4a = dnnl_BA16a16b4a,
  1533. BA16a32b4a = dnnl_BA16a32b4a,
  1534. BA16a48b4a = dnnl_BA16a48b4a,
  1535. BA16a64b4a = dnnl_BA16a64b4a,
  1536. BA24b8a = dnnl_BA24b8a,
  1537. aCB24c8b = dnnl_aCB24c8b,
  1538. abDC24d8c = dnnl_abDC24d8c,
  1539. decbA16a = dnnl_decbA16a,
  1540. decbA8a = dnnl_decbA8a,
  1541. defcbA16a = dnnl_defcbA16a,
  1542. defcbA8a = dnnl_defcbA8a,
  1543. aCB16b16c = dnnl_aCB16b16c,
  1544. aCB16b32c = dnnl_aCB16b32c,
  1545. aCB16b48c = dnnl_aCB16b48c,
  1546. aCB16b64c = dnnl_aCB16b64c,
  1547. aCB16b16c2b = dnnl_aCB16b16c2b,
  1548. aCB16b32c2b = dnnl_aCB16b32c2b,
  1549. aCB16b48c2b = dnnl_aCB16b48c2b,
  1550. aCB16b64c2b = dnnl_aCB16b64c2b,
  1551. aCB16b16c4b = dnnl_aCB16b16c4b,
  1552. aCB16b32c4b = dnnl_aCB16b32c4b,
  1553. aCB16b48c4b = dnnl_aCB16b48c4b,
  1554. aCB16b64c4b = dnnl_aCB16b64c4b,
  1555. Acb24a = dnnl_Acb24a,
  1556. Acdb24a = dnnl_Acdb24a,
  1557. Acdeb24a = dnnl_Acdeb24a,
  1558. aBdc24b = dnnl_aBdc24b,
  1559. aBdec24b = dnnl_aBdec24b,
  1560. aBdefc24b = dnnl_aBdefc24b,
  1561. AcB24a2b = dnnl_AcB24a2b,
  1562. AcdB24a2b = dnnl_AcdB24a2b,
  1563. AcdeB24a2b = dnnl_AcdeB24a2b,
  1564. aBdC24b2c = dnnl_aBdC24b2c,
  1565. aBdeC24b2c = dnnl_aBdeC24b2c,
  1566. aBdefC24b2c = dnnl_aBdefC24b2c,
  1567. AcB24a4b = dnnl_AcB24a4b,
  1568. AcdB24a4b = dnnl_AcdB24a4b,
  1569. AcdeB24a4b = dnnl_AcdeB24a4b,
  1570. aBdC24b4c = dnnl_aBdC24b4c,
  1571. aBdeC24b4c = dnnl_aBdeC24b4c,
  1572. aBdefC24b4c = dnnl_aBdefC24b4c,
  1573. AB8b32a = dnnl_AB8b32a,
  1574. ABc8b32a = dnnl_ABc8b32a,
  1575. AcB8b32a = dnnl_AcB8b32a,
  1576. ABcd8b32a = dnnl_ABcd8b32a,
  1577. AcdB8b32a = dnnl_AcdB8b32a,
  1578. ABcde8b32a = dnnl_ABcde8b32a,
  1579. AcdeB8b32a = dnnl_AcdeB8b32a,
  1580. AB8b24a = dnnl_AB8b24a,
  1581. ABc8b24a = dnnl_ABc8b24a,
  1582. AcB8b24a = dnnl_AcB8b24a,
  1583. ABcd8b24a = dnnl_ABcd8b24a,
  1584. AcdB8b24a = dnnl_AcdB8b24a,
  1585. ABcde8b24a = dnnl_ABcde8b24a,
  1586. AcdeB8b24a = dnnl_AcdeB8b24a,
  1587. AB8b16a = dnnl_AB8b16a,
  1588. ABc8b16a = dnnl_ABc8b16a,
  1589. AcB8b16a = dnnl_AcB8b16a,
  1590. ABcd8b16a = dnnl_ABcd8b16a,
  1591. AcdB8b16a = dnnl_AcdB8b16a,
  1592. ABcde8b16a = dnnl_ABcde8b16a,
  1593. AcdeB8b16a = dnnl_AcdeB8b16a,
  1594. AB8b8a = dnnl_AB8b8a,
  1595. abDC8d8c = dnnl_abDC8d8c,
  1596. abDC16d8c = dnnl_abDC16d8c,
  1597. aCB8c8b = dnnl_aCB8c8b,
  1598. aCB16c8b = dnnl_aCB16c8b,
  1599. BA8b8a = dnnl_BA8b8a,
  1600. BA16b8a = dnnl_BA16b8a,
  1601. AB2a4b = dnnl_AB2a4b,
  1602. format_tag_last = dnnl_format_tag_last,
  1603. nCdhw16c = dnnl_nCdhw16c,
  1604. nCdhw4c = dnnl_nCdhw4c,
  1605. nCdhw8c = dnnl_nCdhw8c,
  1606. nChw16c = dnnl_nChw16c,
  1607. nChw4c = dnnl_nChw4c,
  1608. nChw8c = dnnl_nChw8c,
  1609. nCw16c = dnnl_nCw16c,
  1610. nCw4c = dnnl_nCw4c,
  1611. nCw8c = dnnl_nCw8c,
  1612. NCw16n16c = dnnl_NCw16n16c,
  1613. NChw16n16c = dnnl_NChw16n16c,
  1614. NCdhw16n16c = dnnl_NCdhw16n16c,
  1615. NCdhw32n32c = dnnl_NCdhw32n32c,
  1616. NChw32n32c = dnnl_NChw32n32c,
  1617. IOhw16i16o = dnnl_IOhw16i16o,
  1618. OI16i16o = dnnl_OI16i16o,
  1619. OI16i32o = dnnl_OI16i32o,
  1620. OI16i48o = dnnl_OI16i48o,
  1621. OI16i64o = dnnl_OI16i64o,
  1622. OI8i16o2i = dnnl_OI8i16o2i,
  1623. OI8i32o2i = dnnl_OI8i32o2i,
  1624. OI8i64o2i = dnnl_OI8i64o2i,
  1625. OI4i8o4i = dnnl_OI4i8o4i,
  1626. OI4i16o4i = dnnl_OI4i16o4i,
  1627. OI4i24o4i = dnnl_OI4i24o4i,
  1628. OI4i32o4i = dnnl_OI4i32o4i,
  1629. OI4i64o4i = dnnl_OI4i64o4i,
  1630. Ohwi32o = dnnl_Ohwi32o,
  1631. IOdhw16i16o = dnnl_IOdhw16i16o,
  1632. gIOhw16i16o = dnnl_gIOhw16i16o,
  1633. gOhwi32o = dnnl_gOhwi32o,
  1634. Goidhw16g = dnnl_Goidhw16g,
  1635. IOw8o8i = dnnl_IOw8o8i,
  1636. IOw16o16i = dnnl_IOw16o16i,
  1637. OIw16i16o = dnnl_OIw16i16o,
  1638. OwI16i16o = dnnl_OwI16i16o,
  1639. OIw16i32o = dnnl_OIw16i32o,
  1640. OwI16i32o = dnnl_OwI16i32o,
  1641. OIw16i48o = dnnl_OIw16i48o,
  1642. OwI16i48o = dnnl_OwI16i48o,
  1643. OIw16i64o = dnnl_OIw16i64o,
  1644. OwI16i64o = dnnl_OwI16i64o,
  1645. IOw16i16o = dnnl_IOw16i16o,
  1646. gIOw16i16o = dnnl_gIOw16i16o,
  1647. OIw16o16i = dnnl_OIw16o16i,
  1648. Oiw16o = dnnl_Oiw16o,
  1649. OIw4i8o4i = dnnl_OIw4i8o4i,
  1650. OwI4i8o4i = dnnl_OwI4i8o4i,
  1651. OIw4i16o4i = dnnl_OIw4i16o4i,
  1652. OwI4i16o4i = dnnl_OwI4i16o4i,
  1653. OIw4i24o4i = dnnl_OIw4i24o4i,
  1654. OwI4i24o4i = dnnl_OwI4i24o4i,
  1655. OIw4i32o4i = dnnl_OIw4i32o4i,
  1656. OwI4i32o4i = dnnl_OwI4i32o4i,
  1657. OIw4i64o4i = dnnl_OIw4i64o4i,
  1658. OwI4i64o4i = dnnl_OwI4i64o4i,
  1659. OIw2i8o4i = dnnl_OIw2i8o4i,
  1660. OIw4i4o = dnnl_OIw4i4o,
  1661. OIw4o4i = dnnl_OIw4o4i,
  1662. Oiw4o = dnnl_Oiw4o,
  1663. OIw8i16o2i = dnnl_OIw8i16o2i,
  1664. OwI8i16o2i = dnnl_OwI8i16o2i,
  1665. OIw8i32o2i = dnnl_OIw8i32o2i,
  1666. OwI8i32o2i = dnnl_OwI8i32o2i,
  1667. OIw8i64o2i = dnnl_OIw8i64o2i,
  1668. OwI8i64o2i = dnnl_OwI8i64o2i,
  1669. OIw8i8o = dnnl_OIw8i8o,
  1670. OwI8i8o = dnnl_OwI8i8o,
  1671. OIw8o16i2o = dnnl_OIw8o16i2o,
  1672. OIw8o8i = dnnl_OIw8o8i,
  1673. OIw8o4i = dnnl_OIw8o4i,
  1674. OIw16i16o4i = dnnl_OIw16i16o4i,
  1675. OIw16i32o4i = dnnl_OIw16i32o4i,
  1676. OIw16i48o4i = dnnl_OIw16i48o4i,
  1677. OIw16i64o4i = dnnl_OIw16i64o4i,
  1678. OIw16i16o2i = dnnl_OIw16i16o2i,
  1679. OIw16i32o2i = dnnl_OIw16i32o2i,
  1680. OIw16i48o2i = dnnl_OIw16i48o2i,
  1681. OIw16i64o2i = dnnl_OIw16i64o2i,
  1682. OIw16o16i2o = dnnl_OIw16o16i2o,
  1683. Owi16o = dnnl_Owi16o,
  1684. OwI16o2i = dnnl_OwI16o2i,
  1685. Iwo16i = dnnl_Iwo16i,
  1686. IwO16i2o = dnnl_IwO16i2o,
  1687. IwO16i4o = dnnl_IwO16i4o,
  1688. Owi4o = dnnl_Owi4o,
  1689. Owi8o = dnnl_Owi8o,
  1690. OwI8o2i = dnnl_OwI8o2i,
  1691. OwI8o4i = dnnl_OwI8o4i,
  1692. IOhw8o8i = dnnl_IOhw8o8i,
  1693. IOhw16o16i = dnnl_IOhw16o16i,
  1694. Ohwi16o = dnnl_Ohwi16o,
  1695. OhwI16o2i = dnnl_OhwI16o2i,
  1696. Ihwo16i = dnnl_Ihwo16i,
  1697. IhwO16i2o = dnnl_IhwO16i2o,
  1698. IhwO16i4o = dnnl_IhwO16i4o,
  1699. Ohwi4o = dnnl_Ohwi4o,
  1700. Ohwi8o = dnnl_Ohwi8o,
  1701. OhwI8o2i = dnnl_OhwI8o2i,
  1702. OhwI8o4i = dnnl_OhwI8o4i,
  1703. OIhw16i16o = dnnl_OIhw16i16o,
  1704. OhwI16i16o = dnnl_OhwI16i16o,
  1705. OIhw16i32o = dnnl_OIhw16i32o,
  1706. OhwI16i32o = dnnl_OhwI16i32o,
  1707. OIhw16i48o = dnnl_OIhw16i48o,
  1708. OhwI16i48o = dnnl_OhwI16i48o,
  1709. OIhw16i64o = dnnl_OIhw16i64o,
  1710. OhwI16i64o = dnnl_OhwI16i64o,
  1711. OIhw16o16i = dnnl_OIhw16o16i,
  1712. Oihw16o = dnnl_Oihw16o,
  1713. OIhw4i8o4i = dnnl_OIhw4i8o4i,
  1714. OhwI4i8o4i = dnnl_OhwI4i8o4i,
  1715. OIhw4i16o4i = dnnl_OIhw4i16o4i,
  1716. OhwI4i16o4i = dnnl_OhwI4i16o4i,
  1717. OIhw4i24o4i = dnnl_OIhw4i24o4i,
  1718. OhwI4i24o4i = dnnl_OhwI4i24o4i,
  1719. OIhw4i32o4i = dnnl_OIhw4i32o4i,
  1720. OhwI4i32o4i = dnnl_OhwI4i32o4i,
  1721. OIhw4i64o4i = dnnl_OIhw4i64o4i,
  1722. OhwI4i64o4i = dnnl_OhwI4i64o4i,
  1723. OIhw4i4o = dnnl_OIhw4i4o,
  1724. OIhw4o4i = dnnl_OIhw4o4i,
  1725. Oihw4o = dnnl_Oihw4o,
  1726. OIhw8i16o2i = dnnl_OIhw8i16o2i,
  1727. OhwI8i16o2i = dnnl_OhwI8i16o2i,
  1728. OIhw8i32o2i = dnnl_OIhw8i32o2i,
  1729. OhwI8i32o2i = dnnl_OhwI8i32o2i,
  1730. OIhw8i64o2i = dnnl_OIhw8i64o2i,
  1731. OhwI8i64o2i = dnnl_OhwI8i64o2i,
  1732. OIhw8i8o = dnnl_OIhw8i8o,
  1733. OhwI8i8o = dnnl_OhwI8i8o,
  1734. OIhw8o16i2o = dnnl_OIhw8o16i2o,
  1735. OIhw8o8i = dnnl_OIhw8o8i,
  1736. OIhw8o4i = dnnl_OIhw8o4i,
  1737. OIhw2i8o4i = dnnl_OIhw2i8o4i,
  1738. IOdhw8o8i = dnnl_IOdhw8o8i,
  1739. IOdhw16o16i = dnnl_IOdhw16o16i,
  1740. Odhwi16o = dnnl_Odhwi16o,
  1741. OdhwI16o2i = dnnl_OdhwI16o2i,
  1742. Idhwo16i = dnnl_Idhwo16i,
  1743. IdhwO16i2o = dnnl_IdhwO16i2o,
  1744. IdhwO16i4o = dnnl_IdhwO16i4o,
  1745. Odhwi4o = dnnl_Odhwi4o,
  1746. Odhwi8o = dnnl_Odhwi8o,
  1747. OdhwI8o2i = dnnl_OdhwI8o2i,
  1748. OdhwI8o4i = dnnl_OdhwI8o4i,
  1749. OIdhw16i16o = dnnl_OIdhw16i16o,
  1750. OdhwI16i16o = dnnl_OdhwI16i16o,
  1751. OIdhw16i32o = dnnl_OIdhw16i32o,
  1752. OdhwI16i32o = dnnl_OdhwI16i32o,
  1753. OIdhw16i48o = dnnl_OIdhw16i48o,
  1754. OdhwI16i48o = dnnl_OdhwI16i48o,
  1755. OIdhw16i64o = dnnl_OIdhw16i64o,
  1756. OdhwI16i64o = dnnl_OdhwI16i64o,
  1757. OIdhw16o16i = dnnl_OIdhw16o16i,
  1758. OIdhw16o16i2o = dnnl_OIdhw16o16i2o,
  1759. Oidhw16o = dnnl_Oidhw16o,
  1760. OIdhw4i4o = dnnl_OIdhw4i4o,
  1761. OIdhw4o4i = dnnl_OIdhw4o4i,
  1762. Oidhw4o = dnnl_Oidhw4o,
  1763. OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
  1764. OdhwI8i16o2i = dnnl_OdhwI8i16o2i,
  1765. OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
  1766. OdhwI8i32o2i = dnnl_OdhwI8i32o2i,
  1767. OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
  1768. OdhwI8i64o2i = dnnl_OdhwI8i64o2i,
  1769. OIdhw4i8o4i = dnnl_OIdhw4i8o4i,
  1770. OdhwI4i8o4i = dnnl_OdhwI4i8o4i,
  1771. OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
  1772. OdhwI4i16o4i = dnnl_OdhwI4i16o4i,
  1773. OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
  1774. OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
  1775. OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
  1776. OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
  1777. OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
  1778. OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
  1779. OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
  1780. OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
  1781. OIdhw4i24o4i = dnnl_OIdhw4i24o4i,
  1782. OdhwI4i24o4i = dnnl_OdhwI4i24o4i,
  1783. OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
  1784. OdhwI4i32o4i = dnnl_OdhwI4i32o4i,
  1785. OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
  1786. OdhwI4i64o4i = dnnl_OdhwI4i64o4i,
  1787. OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
  1788. OIdhw8i8o = dnnl_OIdhw8i8o,
  1789. OdhwI8i8o = dnnl_OdhwI8i8o,
  1790. OIdhw8o8i = dnnl_OIdhw8o8i,
  1791. OIdhw8o4i = dnnl_OIdhw8o4i,
  1792. gIOw8o8i = dnnl_gIOw8o8i,
  1793. gIOw16o16i = dnnl_gIOw16o16i,
  1794. gOIw16i16o = dnnl_gOIw16i16o,
  1795. gOIw16o16i = dnnl_gOIw16o16i,
  1796. gOiw16o = dnnl_gOiw16o,
  1797. gOIw4i16o4i = dnnl_gOIw4i16o4i,
  1798. gOIw2i8o4i = dnnl_gOIw2i8o4i,
  1799. gOIw4i4o = dnnl_gOIw4i4o,
  1800. gOIw4o4i = dnnl_gOIw4o4i,
  1801. gOiw4o = dnnl_gOiw4o,
  1802. gOIw8i16o2i = dnnl_gOIw8i16o2i,
  1803. gOIw8i8o = dnnl_gOIw8i8o,
  1804. gOIw8o16i2o = dnnl_gOIw8o16i2o,
  1805. gOIw8o8i = dnnl_gOIw8o8i,
  1806. gOIw8o4i = dnnl_gOIw8o4i,
  1807. gOIw16i16o4i = dnnl_gOIw16i16o4i,
  1808. gOIw16i16o2i = dnnl_gOIw16i16o2i,
  1809. gOIw16o16i2o = dnnl_gOIw16o16i2o,
  1810. gOwi16o = dnnl_gOwi16o,
  1811. gOwI16o2i = dnnl_gOwI16o2i,
  1812. gIwo16i = dnnl_gIwo16i,
  1813. gIwO16i2o = dnnl_gIwO16i2o,
  1814. gIwO16i4o = dnnl_gIwO16i4o,
  1815. gOwi4o = dnnl_gOwi4o,
  1816. gOwi8o = dnnl_gOwi8o,
  1817. gOwI8o2i = dnnl_gOwI8o2i,
  1818. gOwI8o4i = dnnl_gOwI8o4i,
  1819. Goiw8g = dnnl_Goiw8g,
  1820. Goiw16g = dnnl_Goiw16g,
  1821. gIOhw8o8i = dnnl_gIOhw8o8i,
  1822. gIOhw16o16i = dnnl_gIOhw16o16i,
  1823. gOhwi16o = dnnl_gOhwi16o,
  1824. gOhwI16o2i = dnnl_gOhwI16o2i,
  1825. gIhwo16i = dnnl_gIhwo16i,
  1826. gIhwO16i2o = dnnl_gIhwO16i2o,
  1827. gIhwO16i4o = dnnl_gIhwO16i4o,
  1828. gOhwi4o = dnnl_gOhwi4o,
  1829. gOhwi8o = dnnl_gOhwi8o,
  1830. gOhwI8o2i = dnnl_gOhwI8o2i,
  1831. gOhwI8o4i = dnnl_gOhwI8o4i,
  1832. Goihw16g = dnnl_Goihw16g,
  1833. gOIhw16i16o = dnnl_gOIhw16i16o,
  1834. gOIhw16o16i = dnnl_gOIhw16o16i,
  1835. gOihw16o = dnnl_gOihw16o,
  1836. gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
  1837. gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
  1838. gOIhw4i4o = dnnl_gOIhw4i4o,
  1839. gOIhw4o4i = dnnl_gOIhw4o4i,
  1840. gOihw4o = dnnl_gOihw4o,
  1841. Goihw8g = dnnl_Goihw8g,
  1842. gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
  1843. gOIhw8i8o = dnnl_gOIhw8i8o,
  1844. gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
  1845. OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
  1846. OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
  1847. OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
  1848. OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
  1849. gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
  1850. gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
  1851. gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
  1852. gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
  1853. OIhw16i16o4i = dnnl_OIhw16i16o4i,
  1854. OIhw16i32o4i = dnnl_OIhw16i32o4i,
  1855. OIhw16i48o4i = dnnl_OIhw16i48o4i,
  1856. OIhw16i64o4i = dnnl_OIhw16i64o4i,
  1857. OIhw16i16o2i = dnnl_OIhw16i16o2i,
  1858. OIhw16i32o2i = dnnl_OIhw16i32o2i,
  1859. OIhw16i48o2i = dnnl_OIhw16i48o2i,
  1860. OIhw16i64o2i = dnnl_OIhw16i64o2i,
  1861. OIhw16o16i2o = dnnl_OIhw16o16i2o,
  1862. gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
  1863. gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
  1864. gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
  1865. gOIhw8o8i = dnnl_gOIhw8o8i,
  1866. gOIhw8o4i = dnnl_gOIhw8o4i,
  1867. gIOdhw16i16o = dnnl_gIOdhw16i16o,
  1868. gIOdhw8o8i = dnnl_gIOdhw8o8i,
  1869. gIOdhw16o16i = dnnl_gIOdhw16o16i,
  1870. gOdhwi16o = dnnl_gOdhwi16o,
  1871. gOdhwI16o2i = dnnl_gOdhwI16o2i,
  1872. gIdhwo16i = dnnl_gIdhwo16i,
  1873. gIdhwO16i2o = dnnl_gIdhwO16i2o,
  1874. gIdhwO16i4o = dnnl_gIdhwO16i4o,
  1875. gOdhwi4o = dnnl_gOdhwi4o,
  1876. gOdhwi8o = dnnl_gOdhwi8o,
  1877. gOdhwI8o2i = dnnl_gOdhwI8o2i,
  1878. gOdhwI8o4i = dnnl_gOdhwI8o4i,
  1879. gOIdhw16i16o = dnnl_gOIdhw16i16o,
  1880. gOIdhw16o16i = dnnl_gOIdhw16o16i,
  1881. gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o,
  1882. gOidhw16o = dnnl_gOidhw16o,
  1883. gOIdhw4i4o = dnnl_gOIdhw4i4o,
  1884. gOIdhw4o4i = dnnl_gOIdhw4o4i,
  1885. gOidhw4o = dnnl_gOidhw4o,
  1886. gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
  1887. gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
  1888. gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
  1889. gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
  1890. gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
  1891. gOIdhw8i8o = dnnl_gOIdhw8i8o,
  1892. gOIdhw8o8i = dnnl_gOIdhw8o8i,
  1893. gOIdhw8o4i = dnnl_gOIdhw8o4i,
  1894. gOIw2i4o2i = dnnl_gOIw2i4o2i,
  1895. gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
  1896. gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
  1897. gOIw2o4i2o = dnnl_gOIw2o4i2o,
  1898. gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
  1899. gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
  1900. gOIw4i8o2i = dnnl_gOIw4i8o2i,
  1901. gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
  1902. gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
  1903. gOIw4o8i2o = dnnl_gOIw4o8i2o,
  1904. gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
  1905. gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
  1906. ldOi16o = abDc16d,
  1907. ldOi32o = abDc32d,
  1908. ldOI16o4i = abDC16d4c,
  1909. ldOI32o4i = abDC32d4c,
  1910. ldgOi16o = abdEc16e,
  1911. ldgOI16o4i = abdEC16e4c,
  1912. ldgOi32o = abdEc32e,
  1913. ldgOI32o2i = abdEC32e2c,
  1914. ldgOI32o4i = abdEC32e4c,
  1915. OwI16o4i = dnnl_OwI16o4i,
  1916. OhwI16o4i = dnnl_OhwI16o4i,
  1917. gOwI16o4i = dnnl_gOwI16o4i,
  1918. gOhwI16o4i = dnnl_gOhwI16o4i,
  1919. OdhwI16o4i = dnnl_OdhwI16o4i,
  1920. gOdhwI16o4i = dnnl_gOdhwI16o4i,
  1921. Owi32o = dnnl_Owi32o,
  1922. OwI32o2i = dnnl_OwI32o2i,
  1923. OwI32o4i = dnnl_OwI32o4i,
  1924. Owi48o = dnnl_Owi48o,
  1925. OwI48o2i = dnnl_OwI48o2i,
  1926. OwI48o4i = dnnl_OwI48o4i,
  1927. Owi64o = dnnl_Owi64o,
  1928. OwI64o2i = dnnl_OwI64o2i,
  1929. OwI64o4i = dnnl_OwI64o4i,
  1930. Iwo32i = dnnl_Iwo32i,
  1931. IwO32i2o = dnnl_IwO32i2o,
  1932. IwO32i4o = dnnl_IwO32i4o,
  1933. Iwo48i = dnnl_Iwo48i,
  1934. IwO48i2o = dnnl_IwO48i2o,
  1935. IwO48i4o = dnnl_IwO48i4o,
  1936. Iwo64i = dnnl_Iwo64i,
  1937. IwO64i2o = dnnl_IwO64i2o,
  1938. IwO64i4o = dnnl_IwO64i4o,
  1939. wIo2i = dnnl_wIo2i,
  1940. wIo4i = dnnl_wIo4i,
  1941. gOwi32o = dnnl_gOwi32o,
  1942. gOwI32o2i = dnnl_gOwI32o2i,
  1943. gOwI32o4i = dnnl_gOwI32o4i,
  1944. gOwi48o = dnnl_gOwi48o,
  1945. gOwI48o2i = dnnl_gOwI48o2i,
  1946. gOwI48o4i = dnnl_gOwI48o4i,
  1947. gOwi64o = dnnl_gOwi64o,
  1948. gOwI64o2i = dnnl_gOwI64o2i,
  1949. gOwI64o4i = dnnl_gOwI64o4i,
  1950. gIwo32i = dnnl_gIwo32i,
  1951. gIwO32i2o = dnnl_gIwO32i2o,
  1952. gIwO32i4o = dnnl_gIwO32i4o,
  1953. gIwo48i = dnnl_gIwo48i,
  1954. gIwO48i2o = dnnl_gIwO48i2o,
  1955. gIwO48i4o = dnnl_gIwO48i4o,
  1956. gIwo64i = dnnl_gIwo64i,
  1957. gIwO64i2o = dnnl_gIwO64i2o,
  1958. gIwO64i4o = dnnl_gIwO64i4o,
  1959. gwio = dnnl_gwio,
  1960. gwIo2i = dnnl_gwIo2i,
  1961. gwIo4i = dnnl_gwIo4i,
  1962. OhwI32o = dnnl_OhwI32o,
  1963. OhwI32o2i = dnnl_OhwI32o2i,
  1964. OhwI32o4i = dnnl_OhwI32o4i,
  1965. Ohwi48o = dnnl_Ohwi48o,
  1966. OhwI48o2i = dnnl_OhwI48o2i,
  1967. OhwI48o4i = dnnl_OhwI48o4i,
  1968. Ohwi64o = dnnl_Ohwi64o,
  1969. OhwI64o2i = dnnl_OhwI64o2i,
  1970. OhwI64o4i = dnnl_OhwI64o4i,
  1971. Ihwo32i = dnnl_Ihwo32i,
  1972. IhwO32i2o = dnnl_IhwO32i2o,
  1973. IhwO32i4o = dnnl_IhwO32i4o,
  1974. Ihwo48i = dnnl_Ihwo48i,
  1975. IhwO48i2o = dnnl_IhwO48i2o,
  1976. IhwO48i4o = dnnl_IhwO48i4o,
  1977. Ihwo64i = dnnl_Ihwo64i,
  1978. IhwO64i2o = dnnl_IhwO64i2o,
  1979. IhwO64i4o = dnnl_IhwO64i4o,
  1980. hwIo2i = dnnl_hwIo2i,
  1981. hwIo4i = dnnl_hwIo4i,
  1982. gOhwI32o = dnnl_gOhwI32o,
  1983. gOhwI32o2i = dnnl_gOhwI32o2i,
  1984. gOhwI32o4i = dnnl_gOhwI32o4i,
  1985. gOhwi48o = dnnl_gOhwi48o,
  1986. gOhwI48o2i = dnnl_gOhwI48o2i,
  1987. gOhwI48o4i = dnnl_gOhwI48o4i,
  1988. gOhwi64o = dnnl_gOhwi64o,
  1989. gOhwI64o2i = dnnl_gOhwI64o2i,
  1990. gOhwI64o4i = dnnl_gOhwI64o4i,
  1991. gIhwo32i = dnnl_gIhwo32i,
  1992. gIhwO32i2o = dnnl_gIhwO32i2o,
  1993. gIhwO32i4o = dnnl_gIhwO32i4o,
  1994. gIhwo48i = dnnl_gIhwo48i,
  1995. gIhwO48i2o = dnnl_gIhwO48i2o,
  1996. gIhwO48i4o = dnnl_gIhwO48i4o,
  1997. gIhwo64i = dnnl_gIhwo64i,
  1998. gIhwO64i2o = dnnl_gIhwO64i2o,
  1999. gIhwO64i4o = dnnl_gIhwO64i4o,
  2000. ghwio = dnnl_ghwio,
  2001. ghwIo2i = dnnl_ghwIo2i,
  2002. ghwIo4i = dnnl_ghwIo4i,
  2003. Odhwi32o = dnnl_Odhwi32o,
  2004. OdhwI32o2i = dnnl_OdhwI32o2i,
  2005. OdhwI32o4i = dnnl_OdhwI32o4i,
  2006. Odhwi48o = dnnl_Odhwi48o,
  2007. OdhwI48o2i = dnnl_OdhwI48o2i,
  2008. OdhwI48o4i = dnnl_OdhwI48o4i,
  2009. Odhwi64o = dnnl_Odhwi64o,
  2010. OdhwI64o2i = dnnl_OdhwI64o2i,
  2011. OdhwI64o4i = dnnl_OdhwI64o4i,
  2012. Idhwo32i = dnnl_Idhwo32i,
  2013. IdhwO32i2o = dnnl_IdhwO32i2o,
  2014. IdhwO32i4o = dnnl_IdhwO32i4o,
  2015. Idhwo48i = dnnl_Idhwo48i,
  2016. IdhwO48i2o = dnnl_IdhwO48i2o,
  2017. IdhwO48i4o = dnnl_IdhwO48i4o,
  2018. Idhwo64i = dnnl_Idhwo64i,
  2019. IdhwO64i2o = dnnl_IdhwO64i2o,
  2020. IdhwO64i4o = dnnl_IdhwO64i4o,
  2021. dhwIo2i = dnnl_dhwIo2i,
  2022. dhwIo4i = dnnl_dhwIo4i,
  2023. gOdhwi32o = dnnl_gOdhwi32o,
  2024. gOdhwI32o2i = dnnl_gOdhwI32o2i,
  2025. gOdhwI32o4i = dnnl_gOdhwI32o4i,
  2026. gOdhwi48o = dnnl_gOdhwi48o,
  2027. gOdhwI48o2i = dnnl_gOdhwI48o2i,
  2028. gOdhwI48o4i = dnnl_gOdhwI48o4i,
  2029. gOdhwi64o = dnnl_gOdhwi64o,
  2030. gOdhwI64o2i = dnnl_gOdhwI64o2i,
  2031. gOdhwI64o4i = dnnl_gOdhwI64o4i,
  2032. gIdhwo32i = dnnl_gIdhwo32i,
  2033. gIdhwO32i2o = dnnl_gIdhwO32i2o,
  2034. gIdhwO32i4o = dnnl_gIdhwO32i4o,
  2035. gIdhwo48i = dnnl_gIdhwo48i,
  2036. gIdhwO48i2o = dnnl_gIdhwO48i2o,
  2037. gIdhwO48i4o = dnnl_gIdhwO48i4o,
  2038. gIdhwo64i = dnnl_gIdhwo64i,
  2039. gIdhwO64i2o = dnnl_gIdhwO64i2o,
  2040. gIdhwO64i4o = dnnl_gIdhwO64i4o,
  2041. gdhwio = dnnl_gdhwio,
  2042. gdhwIo2i = dnnl_gdhwIo2i,
  2043. gdhwIo4i = dnnl_gdhwIo4i,
  2044. ldIo32i = dnnl_ldIo32i,
  2045. ldgIo16i = dnnl_ldgIo16i,
  2046. ldgIo32i = dnnl_ldgIo32i,
  2047. ldgIO32i2o = dnnl_ldgIO32i2o,
  2048. nCdhw32c = dnnl_nCdhw32c,
  2049. nChw32c = dnnl_nChw32c,
  2050. nCw32c = dnnl_nCw32c,
  2051. NCw32n16c = dnnl_NCw32n16c,
  2052. NChw32n16c = dnnl_NChw32n16c,
  2053. NCdhw32n16c = dnnl_NCdhw32n16c,
  2054. NCw32n32c = dnnl_NCw32n32c,
  2055. OI16i16o4i = dnnl_OI16i16o4i,
  2056. IOw8o16i2o = dnnl_IOw8o16i2o,
  2057. IOhw8o16i2o = dnnl_IOhw8o16i2o,
  2058. Owhi16o = dnnl_Owhi16o,
  2059. OIdhw8o16i2o = dnnl_OIdhw8o16i2o,
  2060. IOdhw8o16i2o = dnnl_IOdhw8o16i2o,
  2061. Goiw4g = dnnl_Goiw4g,
  2062. gIOw8o16i2o = dnnl_gIOw8o16i2o,
  2063. Goiw32g = dnnl_Goiw32g,
  2064. Goihw4g = dnnl_Goihw4g,
  2065. gIOhw8o16i2o = dnnl_gIOhw8o16i2o,
  2066. Goihw32g = dnnl_Goihw32g,
  2067. gOwhi16o = dnnl_gOwhi16o,
  2068. IOw4i8o8i4o = dnnl_IOw4i8o8i4o,
  2069. IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o,
  2070. IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o,
  2071. gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o,
  2072. gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o,
  2073. gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o,
  2074. gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o,
  2075. gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o,
  2076. Goidhw32g = dnnl_Goidhw32g,
  2077. OI16i32o4i = dnnl_OI16i32o4i,
  2078. OI16i48o4i = dnnl_OI16i48o4i,
  2079. OI16i64o4i = dnnl_OI16i64o4i,
  2080. OI16i16o2i = dnnl_OI16i16o2i,
  2081. OI16i32o2i = dnnl_OI16i32o2i,
  2082. OI16i48o2i = dnnl_OI16i48o2i,
  2083. OI16i64o2i = dnnl_OI16i64o2i,
  2084. aBdeC16c16b4c = dnnl_aBdeC16c16b4c,
  2085. AcB16b16a2b = dnnl_AcB16b16a2b,
  2086. aBdC16c16b2c = dnnl_aBdC16c16b2c,
  2087. AcB16b16a4b = dnnl_AcB16b16a4b,
  2088. aBdC16c16b4c = dnnl_aBdC16c16b4c,
  2089. AcdB16b16a2b = dnnl_AcdB16b16a2b,
  2090. aBdefC16c16b4c = dnnl_aBdefC16c16b4c,
  2091. AcdeB16b16a4b = dnnl_AcdeB16b16a4b,
  2092. AcB16b32a2b = dnnl_AcB16b32a2b,
  2093. AcB16b32a4b = dnnl_AcB16b32a4b,
  2094. AcB16b48a2b = dnnl_AcB16b48a2b,
  2095. AcB16b48a4b = dnnl_AcB16b48a4b,
  2096. AcB16b64a2b = dnnl_AcB16b64a2b,
  2097. AcB16b64a4b = dnnl_AcB16b64a4b,
  2098. aBdC16c32b2c = dnnl_aBdC16c32b2c,
  2099. aBdC16c32b4c = dnnl_aBdC16c32b4c,
  2100. aBdC16c48b2c = dnnl_aBdC16c48b2c,
  2101. aBdC16c48b4c = dnnl_aBdC16c48b4c,
  2102. aBdC16c64b2c = dnnl_aBdC16c64b2c,
  2103. aBdC16c64b4c = dnnl_aBdC16c64b4c,
  2104. AcdB16b32a2b = dnnl_AcdB16b32a2b,
  2105. AcdB16b32a4b = dnnl_AcdB16b32a4b,
  2106. AcdB16b48a2b = dnnl_AcdB16b48a2b,
  2107. AcdB16b48a4b = dnnl_AcdB16b48a4b,
  2108. AcdB16b64a2b = dnnl_AcdB16b64a2b,
  2109. AcdB16b64a4b = dnnl_AcdB16b64a4b,
  2110. aBdeC16c32b2c = dnnl_aBdeC16c32b2c,
  2111. aBdeC16c32b4c = dnnl_aBdeC16c32b4c,
  2112. aBdeC16c48b2c = dnnl_aBdeC16c48b2c,
  2113. aBdeC16c48b4c = dnnl_aBdeC16c48b4c,
  2114. aBdeC16c64b2c = dnnl_aBdeC16c64b2c,
  2115. aBdeC16c64b4c = dnnl_aBdeC16c64b4c,
  2116. AcdeB16b32a2b = dnnl_AcdeB16b32a2b,
  2117. AcdeB16b32a4b = dnnl_AcdeB16b32a4b,
  2118. AcdeB16b48a2b = dnnl_AcdeB16b48a2b,
  2119. AcdeB16b48a4b = dnnl_AcdeB16b48a4b,
  2120. AcdeB16b64a2b = dnnl_AcdeB16b64a2b,
  2121. AcdeB16b64a4b = dnnl_AcdeB16b64a4b,
  2122. aBdefC16c32b2c = dnnl_aBdefC16c32b2c,
  2123. aBdefC16c32b4c = dnnl_aBdefC16c32b4c,
  2124. aBdefC16c48b2c = dnnl_aBdefC16c48b2c,
  2125. aBdefC16c48b4c = dnnl_aBdefC16c48b4c,
  2126. aBdefC16c64b2c = dnnl_aBdefC16c64b2c,
  2127. aBdefC16c64b4c = dnnl_aBdefC16c64b4c,
  2128. OwI16i16o2i = dnnl_OwI16i16o2i,
  2129. gOwI16i16o2i = dnnl_gOwI16i16o2i,
  2130. OhwI16i16o2i = dnnl_OhwI16i16o2i,
  2131. gOhwI16i16o2i = dnnl_gOhwI16i16o2i,
  2132. OdhwI16i16o2i = dnnl_OdhwI16i16o2i,
  2133. gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i,
  2134. OwI16i16o4i = dnnl_OwI16i16o4i,
  2135. gOwI16i16o4i = dnnl_gOwI16i16o4i,
  2136. OhwI16i16o4i = dnnl_OhwI16i16o4i,
  2137. gOhwI16i16o4i = dnnl_gOhwI16i16o4i,
  2138. OdhwI16i16o4i = dnnl_OdhwI16i16o4i,
  2139. gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i,
  2140. OwI16i32o2i = dnnl_OwI16i32o2i,
  2141. OwI16i32o4i = dnnl_OwI16i32o4i,
  2142. OwI16i48o2i = dnnl_OwI16i48o2i,
  2143. OwI16i48o4i = dnnl_OwI16i48o4i,
  2144. OwI16i64o2i = dnnl_OwI16i64o2i,
  2145. OwI16i64o4i = dnnl_OwI16i64o4i,
  2146. gOwI16i32o2i = dnnl_gOwI16i32o2i,
  2147. gOwI16i32o4i = dnnl_gOwI16i32o4i,
  2148. gOwI16i48o2i = dnnl_gOwI16i48o2i,
  2149. gOwI16i48o4i = dnnl_gOwI16i48o4i,
  2150. gOwI16i64o2i = dnnl_gOwI16i64o2i,
  2151. gOwI16i64o4i = dnnl_gOwI16i64o4i,
  2152. OhwI16i32o2i = dnnl_OhwI16i32o2i,
  2153. OhwI16i32o4i = dnnl_OhwI16i32o4i,
  2154. OhwI16i48o2i = dnnl_OhwI16i48o2i,
  2155. OhwI16i48o4i = dnnl_OhwI16i48o4i,
  2156. OhwI16i64o2i = dnnl_OhwI16i64o2i,
  2157. OhwI16i64o4i = dnnl_OhwI16i64o4i,
  2158. gOhwI16i32o2i = dnnl_gOhwI16i32o2i,
  2159. gOhwI16i32o4i = dnnl_gOhwI16i32o4i,
  2160. gOhwI16i48o2i = dnnl_gOhwI16i48o2i,
  2161. gOhwI16i48o4i = dnnl_gOhwI16i48o4i,
  2162. gOhwI16i64o2i = dnnl_gOhwI16i64o2i,
  2163. gOhwI16i64o4i = dnnl_gOhwI16i64o4i,
  2164. OdhwI16i32o2i = dnnl_OdhwI16i32o2i,
  2165. OdhwI16i32o4i = dnnl_OdhwI16i32o4i,
  2166. OdhwI16i48o2i = dnnl_OdhwI16i48o2i,
  2167. OdhwI16i48o4i = dnnl_OdhwI16i48o4i,
  2168. OdhwI16i64o2i = dnnl_OdhwI16i64o2i,
  2169. OdhwI16i64o4i = dnnl_OdhwI16i64o4i,
  2170. IdhwO16o32i2o = dnnl_IdhwO16o32i2o,
  2171. IdhwO16o32i4o = dnnl_IdhwO16o32i4o,
  2172. IdhwO16o48i2o = dnnl_IdhwO16o48i2o,
  2173. IdhwO16o48i4o = dnnl_IdhwO16o48i4o,
  2174. IdhwO16o64i2o = dnnl_IdhwO16o64i2o,
  2175. IdhwO16o64i4o = dnnl_IdhwO16o64i4o,
  2176. gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i,
  2177. gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i,
  2178. gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i,
  2179. gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i,
  2180. gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i,
  2181. gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i,
  2182. gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o,
  2183. gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o,
  2184. gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o,
  2185. gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o,
  2186. gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o,
  2187. gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o,
  2188. IwO16o16i2o = dnnl_IwO16o16i2o,
  2189. IwO16o16i4o = dnnl_IwO16o16i4o,
  2190. IhwO16o16i2o = dnnl_IhwO16o16i2o,
  2191. IhwO16o16i4o = dnnl_IhwO16o16i4o,
  2192. IdhwO16o16i2o = dnnl_IdhwO16o16i2o,
  2193. IdhwO16o16i4o = dnnl_IdhwO16o16i4o,
  2194. gIwO16o16i2o = dnnl_gIwO16o16i2o,
  2195. gIwO16o16i4o = dnnl_gIwO16o16i4o,
  2196. gIhwO16o16i2o = dnnl_gIhwO16o16i2o,
  2197. gIhwO16o16i4o = dnnl_gIhwO16o16i4o,
  2198. gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o,
  2199. gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o,
  2200. IwO16o32i2o = dnnl_IwO16o32i2o,
  2201. IwO16o32i4o = dnnl_IwO16o32i4o,
  2202. IwO16o48i2o = dnnl_IwO16o48i2o,
  2203. IwO16o48i4o = dnnl_IwO16o48i4o,
  2204. IwO16o64i2o = dnnl_IwO16o64i2o,
  2205. IwO16o64i4o = dnnl_IwO16o64i4o,
  2206. gIwO16o32i2o = dnnl_gIwO16o32i2o,
  2207. gIwO16o32i4o = dnnl_gIwO16o32i4o,
  2208. gIwO16o48i2o = dnnl_gIwO16o48i2o,
  2209. gIwO16o48i4o = dnnl_gIwO16o48i4o,
  2210. gIwO16o64i2o = dnnl_gIwO16o64i2o,
  2211. gIwO16o64i4o = dnnl_gIwO16o64i4o,
  2212. IhwO16o32i2o = dnnl_IhwO16o32i2o,
  2213. IhwO16o32i4o = dnnl_IhwO16o32i4o,
  2214. IhwO16o48i2o = dnnl_IhwO16o48i2o,
  2215. IhwO16o48i4o = dnnl_IhwO16o48i4o,
  2216. IhwO16o64i2o = dnnl_IhwO16o64i2o,
  2217. IhwO16o64i4o = dnnl_IhwO16o64i4o,
  2218. gIhwO16o32i2o = dnnl_gIhwO16o32i2o,
  2219. gIhwO16o32i4o = dnnl_gIhwO16o32i4o,
  2220. gIhwO16o48i2o = dnnl_gIhwO16o48i2o,
  2221. gIhwO16o48i4o = dnnl_gIhwO16o48i4o,
  2222. gIhwO16o64i2o = dnnl_gIhwO16o64i2o,
  2223. gIhwO16o64i4o = dnnl_gIhwO16o64i4o,
  2224. aBdeC16c16b2c = dnnl_aBdeC16c16b2c,
  2225. aBdefC16c16b2c = dnnl_aBdefC16c16b2c,
  2226. AcdB16b16a4b = dnnl_AcdB16b16a4b,
  2227. AcdeB16b16a2b = dnnl_AcdeB16b16a2b,
  2228. hwioG16g = dnnl_hwioG16g,
  2229. hwioG8g = dnnl_hwioG8g,
  2230. dhwioG16g = dnnl_dhwioG16g,
  2231. dhwioG8g = dnnl_dhwioG8g,
  2232. ABc4a2b = dnnl_ABc4a2b,
  2233. ABc8a2b = dnnl_ABc8a2b,
  2234. ABcd4a2b = dnnl_ABcd4a2b,
  2235. ABcde4a2b = dnnl_ABcde4a2b,
  2236. ABcde8a2b = dnnl_ABcde8a2b,
  2237. ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b,
  2238. NCdhw40n32c = dnnl_NCdhw40n32c,
  2239. NChw40n32c = dnnl_NChw40n32c,
  2240. NCw40n32c = dnnl_NCw40n32c,
  2241. OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i,
  2242. OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i,
  2243. OIw4o8i8o2i = dnnl_OIw4o8i8o2i,
  2244. gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i,
  2245. gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i,
  2246. gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i,
  2247. IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o,
  2248. IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o,
  2249. IOw4i8o8i2o = dnnl_IOw4i8o8i2o,
  2250. gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o,
  2251. gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o,
  2252. gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o,
  2253. aBCd8b2c = dnnl_aBCd8b2c,
  2254. ABcde40a16b = dnnl_ABcde40a16b,
  2255. ABcde40a32b = dnnl_ABcde40a32b,
  2256. aBCde8b2c = dnnl_aBCde8b2c,
  2257. ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b,
  2258. ABc4a8b8a2b = dnnl_ABc4a8b8a2b,
  2259. aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c,
  2260. aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c,
  2261. aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c,
  2262. BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a,
  2263. BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a,
  2264. BAc4b8a8b2a = dnnl_BAc4b8a8b2a,
  2265. aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b,
  2266. aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b,
  2267. aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b,
  2268. aBCdef8b2c = dnnl_aBCdef8b2c,
  2269. AB32a16b = dnnl_AB32a16b,
  2270. AB32a32b = dnnl_AB32a32b,
  2271. BA4b8a8b2a = dnnl_BA4b8a8b2a,
  2272. BA4b8a8b4a = dnnl_BA4b8a8b4a,
  2273. aBC32b16c = dnnl_aBC32b16c,
  2274. aBC32b32c = dnnl_aBC32b32c,
  2275. aCB4c8b8c2b = dnnl_aCB4c8b8c2b,
  2276. aCB4c8b8c4b = dnnl_aCB4c8b8c4b,
  2277. ABc2b8a16b4a = dnnl_ABc2b8a16b4a,
  2278. ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a,
  2279. ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a,
  2280. ABc2a8b16a4b = dnnl_ABc2a8b16a4b,
  2281. ABc2a8b16a2b = dnnl_ABc2a8b16a2b,
  2282. ABc2b32a8b = dnnl_ABc2b32a8b,
  2283. ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b,
  2284. ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b,
  2285. aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b,
  2286. ABcd2b32a8b = dnnl_ABcd2b32a8b,
  2287. aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b,
  2288. ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b,
  2289. ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b,
  2290. aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b,
  2291. ABcde2b32a8b = dnnl_ABcde2b32a8b,
  2292. aBC2b8c16b2c = dnnl_aBC2b8c16b2c,
  2293. aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c,
  2294. aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c,
  2295. aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c,
  2296. BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a,
  2297. BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a,
  2298. BAc2b8a16b4a = dnnl_BAc2b8a16b4a,
  2299. BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a,
  2300. BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a,
  2301. BAc2b8a16b2a = dnnl_BAc2b8a16b2a,
  2302. aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b,
  2303. aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b,
  2304. aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b,
  2305. aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c,
  2306. aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c,
  2307. NCdhw40n16c = dnnl_NCdhw40n16c,
  2308. NCw40n16c = dnnl_NCw40n16c,
  2309. NChw40n16c = dnnl_NChw40n16c,
  2310. NCw2c32n8c = dnnl_NCw2c32n8c,
  2311. NChw2c32n8c = dnnl_NChw2c32n8c,
  2312. NCdhw2c32n8c = dnnl_NCdhw2c32n8c,
  2313. OIw2i8o16i4o = dnnl_OIw2i8o16i4o,
  2314. OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o,
  2315. OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o,
  2316. OIw2o8i16o4i = dnnl_OIw2o8i16o4i,
  2317. OIw2o8i16o2i = dnnl_OIw2o8i16o2i,
  2318. IOw2i8o16i4o = dnnl_IOw2i8o16i4o,
  2319. IOw2i8o16i2o = dnnl_IOw2i8o16i2o,
  2320. OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i,
  2321. OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i,
  2322. IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o,
  2323. IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o,
  2324. OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i,
  2325. OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i,
  2326. IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o,
  2327. IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o,
  2328. gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i,
  2329. gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o,
  2330. gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o,
  2331. gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o,
  2332. gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i,
  2333. gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i,
  2334. gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i,
  2335. gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i,
  2336. BA4b8a16b2a = dnnl_BA4b8a16b2a,
  2337. BA4b8a16b4a = dnnl_BA4b8a16b4a,
  2338. aCB4c8b16c2b = dnnl_aCB4c8b16c2b,
  2339. aCB4c8b16c4b = dnnl_aCB4c8b16c4b,
  2340. aCB16c2b = dnnl_aCB16c2b,
  2341. aCB16c4b = dnnl_aCB16c4b,
  2342. BA16b2a = dnnl_BA16b2a,
  2343. BA16b4a = dnnl_BA16b4a,
  2344. BA4b4a = dnnl_BA4b4a,
  2345. BA8b4a = dnnl_BA8b4a,
  2346. aBC16b16c = dnnl_aBC16b16c,
  2347. aBC16b32c = dnnl_aBC16b32c,
  2348. AB16a16b = dnnl_AB16a16b,
  2349. AB16a32b = dnnl_AB16a32b,
  2350. ABcde16a16b2a = dnnl_ABcde16a16b2a,
  2351. aBCdef16b16c2b = dnnl_aBCdef16b16c2b,
  2352. Acedb16a = dnnl_Acedb16a,
  2353. aBdfec16b = dnnl_aBdfec16b,
  2354. Odwhi16o = dnnl_Odwhi16o,
  2355. gOdwhi16o = dnnl_gOdwhi16o,
  2356. abdEC64e2c = dnnl_abdEC64e2c,
  2357. abdEC64e4c = dnnl_abdEC64e4c,
  2358. ldgOI64o2i = abdEC64e2c,
  2359. ldgOI64o4i = abdEC64e4c,
  2360. abCd4c = dnnl_abCd4c,
  2361. abCde4c = dnnl_abCde4c,
  2362. abCdef4c = dnnl_abCdef4c,
  2363. abCde32c = dnnl_abCde32c,
  2364. abCdef32c = dnnl_abCdef32c,
  2365. aCdefB16b32c2b = dnnl_aCdefB16b32c2b,
  2366. aCdefB16b32c4b = dnnl_aCdefB16b32c4b,
  2367. aCdefB16b48c2b = dnnl_aCdefB16b48c2b,
  2368. aCdefB16b48c4b = dnnl_aCdefB16b48c4b,
  2369. aCdefB16b64c2b = dnnl_aCdefB16b64c2b,
  2370. aCdefB16b64c4b = dnnl_aCdefB16b64c4b,
  2371. BcdeA16a32b2a = dnnl_BcdeA16a32b2a,
  2372. BcdeA16a32b4a = dnnl_BcdeA16a32b4a,
  2373. BcdeA16a48b2a = dnnl_BcdeA16a48b2a,
  2374. BcdeA16a48b4a = dnnl_BcdeA16a48b4a,
  2375. BcdeA16a64b2a = dnnl_BcdeA16a64b2a,
  2376. BcdeA16a64b4a = dnnl_BcdeA16a64b4a,
  2377. aCdefb32c = dnnl_aCdefb32c,
  2378. aCdefB32c2b = dnnl_aCdefB32c2b,
  2379. aCdefB32c4b = dnnl_aCdefB32c4b,
  2380. aCdefb48c = dnnl_aCdefb48c,
  2381. aCdefB48c2b = dnnl_aCdefB48c2b,
  2382. aCdefB48c4b = dnnl_aCdefB48c4b,
  2383. aCdefb64c = dnnl_aCdefb64c,
  2384. aCdefB64c2b = dnnl_aCdefB64c2b,
  2385. aCdefB64c4b = dnnl_aCdefB64c4b,
  2386. Bcdea32b = dnnl_Bcdea32b,
  2387. BcdeA32b2a = dnnl_BcdeA32b2a,
  2388. BcdeA32b4a = dnnl_BcdeA32b4a,
  2389. Bcdea48b = dnnl_Bcdea48b,
  2390. BcdeA48b2a = dnnl_BcdeA48b2a,
  2391. BcdeA48b4a = dnnl_BcdeA48b4a,
  2392. Bcdea64b = dnnl_Bcdea64b,
  2393. BcdeA64b2a = dnnl_BcdeA64b2a,
  2394. BcdeA64b4a = dnnl_BcdeA64b4a,
  2395. Bca32b = dnnl_Bca32b,
  2396. BcA32b2a = dnnl_BcA32b2a,
  2397. BcA32b4a = dnnl_BcA32b4a,
  2398. Bca48b = dnnl_Bca48b,
  2399. BcA48b2a = dnnl_BcA48b2a,
  2400. BcA48b4a = dnnl_BcA48b4a,
  2401. Bca64b = dnnl_Bca64b,
  2402. BcA64b2a = dnnl_BcA64b2a,
  2403. BcA64b4a = dnnl_BcA64b4a,
  2404. aCdb32c = dnnl_aCdb32c,
  2405. aCdB32c2b = dnnl_aCdB32c2b,
  2406. aCdB32c4b = dnnl_aCdB32c4b,
  2407. aCdb48c = dnnl_aCdb48c,
  2408. aCdB48c2b = dnnl_aCdB48c2b,
  2409. aCdB48c4b = dnnl_aCdB48c4b,
  2410. aCdb64c = dnnl_aCdb64c,
  2411. aCdB64c2b = dnnl_aCdB64c2b,
  2412. aCdB64c4b = dnnl_aCdB64c4b,
  2413. BcA16a16b2a = dnnl_BcA16a16b2a,
  2414. BcA16a16b4a = dnnl_BcA16a16b4a,
  2415. BcdA16a16b2a = dnnl_BcdA16a16b2a,
  2416. BcdA16a16b4a = dnnl_BcdA16a16b4a,
  2417. BcdeA16a16b2a = dnnl_BcdeA16a16b2a,
  2418. BcdeA16a16b4a = dnnl_BcdeA16a16b4a,
  2419. aCdB16b16c2b = dnnl_aCdB16b16c2b,
  2420. aCdB16b16c4b = dnnl_aCdB16b16c4b,
  2421. aCdeB16b16c2b = dnnl_aCdeB16b16c2b,
  2422. aCdeB16b16c4b = dnnl_aCdeB16b16c4b,
  2423. aCdefB16b16c2b = dnnl_aCdefB16b16c2b,
  2424. aCdefB16b16c4b = dnnl_aCdefB16b16c4b,
  2425. BcA16a32b2a = dnnl_BcA16a32b2a,
  2426. BcA16a32b4a = dnnl_BcA16a32b4a,
  2427. BcA16a48b2a = dnnl_BcA16a48b2a,
  2428. BcA16a48b4a = dnnl_BcA16a48b4a,
  2429. BcA16a64b2a = dnnl_BcA16a64b2a,
  2430. BcA16a64b4a = dnnl_BcA16a64b4a,
  2431. aCdB16b32c2b = dnnl_aCdB16b32c2b,
  2432. aCdB16b32c4b = dnnl_aCdB16b32c4b,
  2433. aCdB16b48c2b = dnnl_aCdB16b48c2b,
  2434. aCdB16b48c4b = dnnl_aCdB16b48c4b,
  2435. aCdB16b64c2b = dnnl_aCdB16b64c2b,
  2436. aCdB16b64c4b = dnnl_aCdB16b64c4b,
  2437. BcdA16a32b2a = dnnl_BcdA16a32b2a,
  2438. BcdA16a32b4a = dnnl_BcdA16a32b4a,
  2439. BcdA16a48b2a = dnnl_BcdA16a48b2a,
  2440. BcdA16a48b4a = dnnl_BcdA16a48b4a,
  2441. BcdA16a64b2a = dnnl_BcdA16a64b2a,
  2442. BcdA16a64b4a = dnnl_BcdA16a64b4a,
  2443. aCdeB16b32c2b = dnnl_aCdeB16b32c2b,
  2444. aCdeB16b32c4b = dnnl_aCdeB16b32c4b,
  2445. aCdeB16b48c2b = dnnl_aCdeB16b48c2b,
  2446. aCdeB16b48c4b = dnnl_aCdeB16b48c4b,
  2447. aCdeB16b64c2b = dnnl_aCdeB16b64c2b,
  2448. aCdeB16b64c4b = dnnl_aCdeB16b64c4b,
  2449. Bca16b = dnnl_Bca16b,
  2450. BcA16b2a = dnnl_BcA16b2a,
  2451. BcA16b4a = dnnl_BcA16b4a,
  2452. Bcda16b = dnnl_Bcda16b,
  2453. BcdA16b2a = dnnl_BcdA16b2a,
  2454. BcdA16b4a = dnnl_BcdA16b4a,
  2455. Bcdea16b = dnnl_Bcdea16b,
  2456. BcdeA16b2a = dnnl_BcdeA16b2a,
  2457. BcdeA16b4a = dnnl_BcdeA16b4a,
  2458. aCdb16c = dnnl_aCdb16c,
  2459. aCdB16c2b = dnnl_aCdB16c2b,
  2460. aCdB16c4b = dnnl_aCdB16c4b,
  2461. aCdeb16c = dnnl_aCdeb16c,
  2462. aCdeB16c2b = dnnl_aCdeB16c2b,
  2463. aCdeB16c4b = dnnl_aCdeB16c4b,
  2464. aCdefb16c = dnnl_aCdefb16c,
  2465. aCdefB16c2b = dnnl_aCdefB16c2b,
  2466. aCdefB16c4b = dnnl_aCdefB16c4b,
  2467. Bcda32b = dnnl_Bcda32b,
  2468. BcdA32b2a = dnnl_BcdA32b2a,
  2469. BcdA32b4a = dnnl_BcdA32b4a,
  2470. Bcda48b = dnnl_Bcda48b,
  2471. BcdA48b2a = dnnl_BcdA48b2a,
  2472. BcdA48b4a = dnnl_BcdA48b4a,
  2473. Bcda64b = dnnl_Bcda64b,
  2474. BcdA64b2a = dnnl_BcdA64b2a,
  2475. BcdA64b4a = dnnl_BcdA64b4a,
  2476. aCdeb32c = dnnl_aCdeb32c,
  2477. aCdeB32c2b = dnnl_aCdeB32c2b,
  2478. aCdeB32c4b = dnnl_aCdeB32c4b,
  2479. aCdeb48c = dnnl_aCdeb48c,
  2480. aCdeB48c2b = dnnl_aCdeB48c2b,
  2481. aCdeB48c4b = dnnl_aCdeB48c4b,
  2482. aCdeb64c = dnnl_aCdeb64c,
  2483. aCdeB64c2b = dnnl_aCdeB64c2b,
  2484. aCdeB64c4b = dnnl_aCdeB64c4b,
  2485. NChw16n32c = dnnl_NChw16n32c,
  2486. goIw4i = dnnl_goIw4i,
  2487. goIw32i = dnnl_goIw32i,
  2488. goIhw4i = dnnl_goIhw4i,
  2489. goIhw32i = dnnl_goIhw32i,
  2490. goIdhw4i = dnnl_goIdhw4i,
  2491. goIdhw32i = dnnl_goIdhw32i,
  2492. cab = dnnl_cab,
  2493. cdab = dnnl_cdab,
  2494. cdeab = dnnl_cdeab,
  2495. woi = dnnl_woi,
  2496. hwoi = dnnl_hwoi,
  2497. dhwoi = dnnl_dhwoi,
  2498. Owi24o = dnnl_Owi24o,
  2499. Ohwi24o = dnnl_Ohwi24o,
  2500. Odhwi24o = dnnl_Odhwi24o,
  2501. gOwi24o = dnnl_gOwi24o,
  2502. gOhwi24o = dnnl_gOhwi24o,
  2503. gOdhwi24o = dnnl_gOdhwi24o,
  2504. OwI24o2i = dnnl_OwI24o2i,
  2505. OhwI24o2i = dnnl_OhwI24o2i,
  2506. OdhwI24o2i = dnnl_OdhwI24o2i,
  2507. gOwI24o2i = dnnl_gOwI24o2i,
  2508. gOhwI24o2i = dnnl_gOhwI24o2i,
  2509. gOdhwI24o2i = dnnl_gOdhwI24o2i,
  2510. OwI24o4i = dnnl_OwI24o4i,
  2511. OhwI24o4i = dnnl_OhwI24o4i,
  2512. OdhwI24o4i = dnnl_OdhwI24o4i,
  2513. gOwI24o4i = dnnl_gOwI24o4i,
  2514. gOhwI24o4i = dnnl_gOhwI24o4i,
  2515. gOdhwI24o4i = dnnl_gOdhwI24o4i,
  2516. OI8i32o = dnnl_OI8i32o,
  2517. OIw8i32o = dnnl_OIw8i32o,
  2518. OwI8i32o = dnnl_OwI8i32o,
  2519. OIhw8i32o = dnnl_OIhw8i32o,
  2520. OhwI8i32o = dnnl_OhwI8i32o,
  2521. OIdhw8i32o = dnnl_OIdhw8i32o,
  2522. OdhwI8i32o = dnnl_OdhwI8i32o,
  2523. OI8i24o = dnnl_OI8i24o,
  2524. OIw8i24o = dnnl_OIw8i24o,
  2525. OwI8i24o = dnnl_OwI8i24o,
  2526. OIhw8i24o = dnnl_OIhw8i24o,
  2527. OhwI8i24o = dnnl_OhwI8i24o,
  2528. OIdhw8i24o = dnnl_OIdhw8i24o,
  2529. OdhwI8i24o = dnnl_OdhwI8i24o,
  2530. OI8i16o = dnnl_OI8i16o,
  2531. OIw8i16o = dnnl_OIw8i16o,
  2532. OwI8i16o = dnnl_OwI8i16o,
  2533. OIhw8i16o = dnnl_OIhw8i16o,
  2534. OhwI8i16o = dnnl_OhwI8i16o,
  2535. OIdhw8i16o = dnnl_OIdhw8i16o,
  2536. OdhwI8i16o = dnnl_OdhwI8i16o,
  2537. OI8i8o = dnnl_OI8i8o,
  2538. AB4b8a4b = dnnl_AB4b8a4b,
  2539. AB4b24a4b = dnnl_AB4b24a4b,
  2540. ABc4b8a4b = dnnl_ABc4b8a4b,
  2541. AcB4b8a4b = dnnl_AcB4b8a4b,
  2542. ABc4b24a4b = dnnl_ABc4b24a4b,
  2543. AcB4b24a4b = dnnl_AcB4b24a4b,
  2544. ABcd4b8a4b = dnnl_ABcd4b8a4b,
  2545. AcdB4b8a4b = dnnl_AcdB4b8a4b,
  2546. ABcd4b24a4b = dnnl_ABcd4b24a4b,
  2547. AcdB4b24a4b = dnnl_AcdB4b24a4b,
  2548. ABcde4b8a4b = dnnl_ABcde4b8a4b,
  2549. AcdeB4b8a4b = dnnl_AcdeB4b8a4b,
  2550. ABcde4b24a4b = dnnl_ABcde4b24a4b,
  2551. AcdeB4b24a4b = dnnl_AcdeB4b24a4b,
  2552. Bca8b = dnnl_Bca8b,
  2553. BcA8b2a = dnnl_BcA8b2a,
  2554. Bcda8b = dnnl_Bcda8b,
  2555. BcdA8b2a = dnnl_BcdA8b2a,
  2556. Bcdea8b = dnnl_Bcdea8b,
  2557. BcdeA8b2a = dnnl_BcdeA8b2a,
  2558. aCdb8c = dnnl_aCdb8c,
  2559. aCdB8c2b = dnnl_aCdB8c2b,
  2560. aCdeb8c = dnnl_aCdeb8c,
  2561. aCdeB8c2b = dnnl_aCdeB8c2b,
  2562. aCdefb8c = dnnl_aCdefb8c,
  2563. aCdefB8c2b = dnnl_aCdefB8c2b,
  2564. Bca24b = dnnl_Bca24b,
  2565. BcA24b2a = dnnl_BcA24b2a,
  2566. Bcda24b = dnnl_Bcda24b,
  2567. BcdA24b2a = dnnl_BcdA24b2a,
  2568. Bcdea24b = dnnl_Bcdea24b,
  2569. BcdeA24b2a = dnnl_BcdeA24b2a,
  2570. aCdb24c = dnnl_aCdb24c,
  2571. aCdB24c2b = dnnl_aCdB24c2b,
  2572. aCdeb24c = dnnl_aCdeb24c,
  2573. aCdeB24c2b = dnnl_aCdeB24c2b,
  2574. aCdefb24c = dnnl_aCdefb24c,
  2575. aCdefB24c2b = dnnl_aCdefB24c2b,
  2576. Iwo8i = dnnl_Iwo8i,
  2577. IwO8i2o = dnnl_IwO8i2o,
  2578. Iwo24i = dnnl_Iwo24i,
  2579. IwO24i2o = dnnl_IwO24i2o,
  2580. Ihwo8i = dnnl_Ihwo8i,
  2581. IhwO8i2o = dnnl_IhwO8i2o,
  2582. Ihwo24i = dnnl_Ihwo24i,
  2583. IhwO24i2o = dnnl_IhwO24i2o,
  2584. Idhwo8i = dnnl_Idhwo8i,
  2585. IdhwO8i2o = dnnl_IdhwO8i2o,
  2586. Idhwo24i = dnnl_Idhwo24i,
  2587. IdhwO24i2o = dnnl_IdhwO24i2o,
  2588. gIwo8i = dnnl_gIwo8i,
  2589. gIwO8i2o = dnnl_gIwO8i2o,
  2590. gIwo24i = dnnl_gIwo24i,
  2591. gIwO24i2o = dnnl_gIwO24i2o,
  2592. gIhwo8i = dnnl_gIhwo8i,
  2593. gIhwO8i2o = dnnl_gIhwO8i2o,
  2594. gIhwo24i = dnnl_gIhwo24i,
  2595. gIhwO24i2o = dnnl_gIhwO24i2o,
  2596. gIdhwo8i = dnnl_gIdhwo8i,
  2597. gIdhwO8i2o = dnnl_gIdhwO8i2o,
  2598. gIdhwo24i = dnnl_gIdhwo24i,
  2599. gIdhwO24i2o = dnnl_gIdhwO24i2o,
  2600. OhwI24o = dnnl_OhwI24o,
  2601. gOhwI24o = dnnl_gOhwI24o,
  2602. AB8b24a2b = dnnl_AB8b24a2b,
  2603. ABc8b24a2b = dnnl_ABc8b24a2b,
  2604. AcB8b24a2b = dnnl_AcB8b24a2b,
  2605. ABcd8b24a2b = dnnl_ABcd8b24a2b,
  2606. AcdB8b24a2b = dnnl_AcdB8b24a2b,
  2607. ABcde8b24a2b = dnnl_ABcde8b24a2b,
  2608. AcdeB8b24a2b = dnnl_AcdeB8b24a2b,
  2609. AB8b8a2b = dnnl_AB8b8a2b,
  2610. ABc8b8a2b = dnnl_ABc8b8a2b,
  2611. AcB8b8a2b = dnnl_AcB8b8a2b,
  2612. ABcd8b8a2b = dnnl_ABcd8b8a2b,
  2613. AcdB8b8a2b = dnnl_AcdB8b8a2b,
  2614. ABcde8b8a2b = dnnl_ABcde8b8a2b,
  2615. AcdeB8b8a2b = dnnl_AcdeB8b8a2b,
  2616. OI8i8o2i = dnnl_OI8i8o2i,
  2617. OI8i24o2i = dnnl_OI8i24o2i,
  2618. OIw8i8o2i = dnnl_OIw8i8o2i,
  2619. OwI8i8o2i = dnnl_OwI8i8o2i,
  2620. OIw8i24o2i = dnnl_OIw8i24o2i,
  2621. OwI8i24o2i = dnnl_OwI8i24o2i,
  2622. OIhw8i8o2i = dnnl_OIhw8i8o2i,
  2623. OhwI8i8o2i = dnnl_OhwI8i8o2i,
  2624. OIhw8i24o2i = dnnl_OIhw8i24o2i,
  2625. OhwI8i24o2i = dnnl_OhwI8i24o2i,
  2626. OIdhw8i8o2i = dnnl_OIdhw8i8o2i,
  2627. OdhwI8i8o2i = dnnl_OdhwI8i8o2i,
  2628. OIdhw8i24o2i = dnnl_OIdhw8i24o2i,
  2629. OdhwI8i24o2i = dnnl_OdhwI8i24o2i,
  2630. BcA8b4a = dnnl_BcA8b4a,
  2631. BcdA8b4a = dnnl_BcdA8b4a,
  2632. BcdeA8b4a = dnnl_BcdeA8b4a,
  2633. aCdB8c4b = dnnl_aCdB8c4b,
  2634. aCdeB8c4b = dnnl_aCdeB8c4b,
  2635. aCdefB8c4b = dnnl_aCdefB8c4b,
  2636. BcA24b4a = dnnl_BcA24b4a,
  2637. BcdA24b4a = dnnl_BcdA24b4a,
  2638. BcdeA24b4a = dnnl_BcdeA24b4a,
  2639. aCdB24c4b = dnnl_aCdB24c4b,
  2640. aCdeB24c4b = dnnl_aCdeB24c4b,
  2641. aCdefB24c4b = dnnl_aCdefB24c4b,
  2642. ABc16a4b = dnnl_ABc16a4b,
  2643. ABcd16a4b = dnnl_ABcd16a4b,
  2644. ABcde16a4b = dnnl_ABcde16a4b,
  2645. IwO8i4o = dnnl_IwO8i4o,
  2646. IwO24i4o = dnnl_IwO24i4o,
  2647. IhwO8i4o = dnnl_IhwO8i4o,
  2648. IhwO24i4o = dnnl_IhwO24i4o,
  2649. IdhwO8i4o = dnnl_IdhwO8i4o,
  2650. IdhwO24i4o = dnnl_IdhwO24i4o,
  2651. gIwO8i4o = dnnl_gIwO8i4o,
  2652. gIwO24i4o = dnnl_gIwO24i4o,
  2653. gIhwO8i4o = dnnl_gIhwO8i4o,
  2654. gIhwO24i4o = dnnl_gIhwO24i4o,
  2655. gIdhwO8i4o = dnnl_gIdhwO8i4o,
  2656. gIdhwO24i4o = dnnl_gIdhwO24i4o,
  2657. BA2a24b = dnnl_BA2a24b,
  2658. aCB2b24c = dnnl_aCB2b24c,
  2659. BA2a8b = dnnl_BA2a8b,
  2660. aCB2b8c = dnnl_aCB2b8c,
  2661. BA8a24b = dnnl_BA8a24b,
  2662. aCB8b24c = dnnl_aCB8b24c,
  2663. BA8a16b = dnnl_BA8a16b,
  2664. aCB8b16c = dnnl_aCB8b16c,
  2665. BA8a8b = dnnl_BA8a8b,
  2666. aCB8b8c = dnnl_aCB8b8c,
  2667. bcad = dnnl_bcad,
  2668. cabd = dnnl_cabd,
  2669. dabc = dnnl_dabc,
  2670. decbA4a = dnnl_decbA4a,
  2671. defcbA4a = dnnl_defcbA4a,
  2672. hwioG4g = dnnl_hwioG4g,
  2673. dhwioG4g = dnnl_dhwioG4g,
  2674. aCBd4b4c = dnnl_aCBd4b4c,
  2675. aCBde4b4c = dnnl_aCBde4b4c,
  2676. aCBdef4b4c = dnnl_aCBdef4b4c,
  2677. BAc4a4b = dnnl_BAc4a4b,
  2678. BAcd4a4b = dnnl_BAcd4a4b,
  2679. BAcde4a4b = dnnl_BAcde4a4b,
  2680. IOw4o4i = dnnl_IOw4o4i,
  2681. IOhw4o4i = dnnl_IOhw4o4i,
  2682. IOdhw4o4i = dnnl_IOdhw4o4i,
  2683. gIOw4o4i = dnnl_gIOw4o4i,
  2684. gIOhw4o4i = dnnl_gIOhw4o4i,
  2685. gIOdhw4o4i = dnnl_gIOdhw4o4i,
  2686. };
  2687. /// A memory descriptor.
  2688. struct desc : public handle<dnnl_memory_desc_t> {
  2689. using handle<dnnl_memory_desc_t>::handle;
  2690. friend struct memory;
  2691. /// Constructs a zero (empty) memory descriptor. Such a memory
  2692. /// descriptor can be used to indicate absence of an argument.
  2693. desc() {
  2694. dnnl_memory_desc_t zero_md = nullptr;
  2695. error::wrap_c_api(
  2696. dnnl_memory_desc_create_with_tag(&zero_md, 0, nullptr,
  2697. dnnl_data_type_undef, dnnl_format_tag_undef),
  2698. "could not create a zero memory descriptor");
  2699. reset(zero_md);
  2700. }
  2701. /// Constructs a memory descriptor.
  2702. ///
  2703. /// @note
  2704. /// The logical order of dimensions corresponds to the `abc...`
  2705. /// format tag, and the physical meaning of the dimensions depends
  2706. /// both on the primitive that would operate on this memory and
  2707. /// the operation context.
  2708. ///
  2709. /// @param adims Tensor dimensions.
  2710. /// @param adata_type Data precision/type.
  2711. /// @param aformat_tag Memory format tag.
  2712. /// @param allow_empty A flag signifying whether construction is
  2713. /// allowed to fail without throwing an exception. In this case a
  2714. /// zero memory descriptor will be constructed. This flag is
  2715. /// optional and defaults to false.
  2716. desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
  2717. bool allow_empty = false) {
  2718. validate_dims(adims);
  2719. dnnl_memory_desc_t md = nullptr;
  2720. dnnl_status_t status = dnnl_memory_desc_create_with_tag(&md,
  2721. (int)adims.size(), adims.data(), convert_to_c(adata_type),
  2722. convert_to_c(aformat_tag));
  2723. if (!allow_empty)
  2724. error::wrap_c_api(status,
  2725. "could not construct a memory descriptor using a "
  2726. "format tag");
  2727. reset(md);
  2728. }
  2729. /// Constructs a memory descriptor by strides.
  2730. ///
  2731. /// @note
  2732. /// The logical order of dimensions corresponds to the `abc...`
  2733. /// format tag, and the physical meaning of the dimensions depends
  2734. /// both on the primitive that would operate on this memory and
  2735. /// the operation context.
  2736. ///
  2737. /// @param adims Tensor dimensions.
  2738. /// @param adata_type Data precision/type.
  2739. /// @param strides Strides for each dimension.
  2740. /// @param allow_empty A flag signifying whether construction is
  2741. /// allowed to fail without throwing an exception. In this case a
  2742. /// zero memory descriptor will be constructed. This flag is
  2743. /// optional and defaults to false.
  2744. desc(const dims &adims, data_type adata_type, const dims &strides,
  2745. bool allow_empty = false) {
  2746. validate_dims(adims);
  2747. if (!strides.empty()) validate_dims(strides, (int)adims.size());
  2748. dnnl_memory_desc_t md = nullptr;
  2749. dnnl_status_t status = dnnl_memory_desc_create_with_strides(&md,
  2750. (int)adims.size(), adims.data(), convert_to_c(adata_type),
  2751. strides.empty() ? nullptr : &strides[0]);
  2752. if (!allow_empty)
  2753. error::wrap_c_api(status,
  2754. "could not construct a memory descriptor using "
  2755. "strides");
  2756. reset(md);
  2757. }
  2758. /// Function for creating a memory descriptor for CSR sparse encoding.
  2759. ///
  2760. /// The created memory descriptor will describe a memory object that
  2761. /// contains 3 buffers. The buffers have the following meaning and
  2762. /// assigned numbers (index):
  2763. /// - 0: values
  2764. /// - 1: indices
  2765. /// - 2: pointers
  2766. ///
  2767. /// @param adims Tensor dimensions.
  2768. /// @param adata_type Data precision/type.
  2769. /// @param nnz Number of non-zero entries.
  2770. /// @param index_dt Data type of indices.
  2771. /// @param pointer_dt Data type of pointers.
  2772. /// @param allow_empty A flag signifying whether construction is
  2773. /// allowed to fail without throwing an exception. In this case a
  2774. /// zero memory descriptor will be constructed. This flag is
  2775. /// optional and defaults to false.
  2776. /// @sa @ref dev_guide_sparsity
  2777. static desc csr(const dims &adims, data_type adata_type, dim nnz,
  2778. data_type index_dt, data_type pointer_dt,
  2779. bool allow_empty = false) {
  2780. validate_dims(adims);
  2781. dnnl_memory_desc_t md = nullptr;
  2782. dnnl_status_t status = dnnl_memory_desc_create_with_csr_encoding(
  2783. &md, (int)adims.size(), adims.data(),
  2784. convert_to_c(adata_type), nnz, convert_to_c(index_dt),
  2785. convert_to_c(pointer_dt));
  2786. if (!allow_empty)
  2787. error::wrap_c_api(status,
  2788. "could not create a memory descriptor for CSR sparse "
  2789. "encoding");
  2790. return desc {md};
  2791. }
  2792. /// Function for creating a memory descriptor for COO sparse encodings.
  2793. ///
  2794. /// The created memory descriptor will describe a memory object that
  2795. /// contains n+1 buffers for an n-dimensional tensor.
  2796. /// The buffers have the following meaning and assigned numbers (index):
  2797. /// - 0: values
  2798. /// - 1: indices for dimension 0
  2799. /// - 2: indices for dimension 1 ...
  2800. /// - n: indices for dimension n-1
  2801. ///
  2802. /// @param adims Tensor dimensions.
  2803. /// @param adata_type Data precision/type.
  2804. /// @param nnz Number of non-zero entries.
  2805. /// @param index_dt Data type of indices.
  2806. /// @param allow_empty A flag signifying whether construction is
  2807. /// allowed to fail without throwing an exception. In this case a
  2808. /// zero memory descriptor will be constructed. This flag is
  2809. /// optional and defaults to false.
  2810. /// @sa @ref dev_guide_sparsity
  2811. static desc coo(const dims &adims, data_type adata_type, dim nnz,
  2812. data_type index_dt, bool allow_empty = false) {
  2813. validate_dims(adims);
  2814. dnnl_memory_desc_t md = nullptr;
  2815. dnnl_status_t status = dnnl_memory_desc_create_with_coo_encoding(
  2816. &md, (int)adims.size(), adims.data(),
  2817. convert_to_c(adata_type), nnz, convert_to_c(index_dt));
  2818. if (!allow_empty)
  2819. error::wrap_c_api(status,
  2820. "could not create a memory descriptor for COO sparse "
  2821. "encoding");
  2822. return desc {md};
  2823. }
  2824. /// Function for creating a memory descriptor for packed sparse
  2825. /// encoding.
  2826. ///
  2827. /// The created memory descriptor cannot be used to create a memory
  2828. /// object. It can only be used to create a primitive descriptor to
  2829. /// query the actual memory descriptor (similar to the format tag
  2830. /// `any`).
  2831. ///
  2832. /// @warning
  2833. /// The meaning and content of the handles of the memory object that
  2834. /// is created using the queried memory descriptor are unspecified
  2835. /// therefore using the content is an undefined behavior.
  2836. ///
  2837. /// @param adims Tensor dimensions.
  2838. /// @param adata_type Data precision/type.
  2839. /// @param nnz Number of non-zero entries.
  2840. /// @param allow_empty A flag signifying whether construction is
  2841. /// allowed to fail without throwing an exception. In this case a
  2842. /// zero memory descriptor will be constructed. This flag is
  2843. /// optional and defaults to false.
  2844. /// @sa @ref dev_guide_sparsity
  2845. static desc packed(const dims &adims, data_type adata_type, dim nnz,
  2846. bool allow_empty = false) {
  2847. validate_dims(adims);
  2848. dnnl_memory_desc_t md = nullptr;
  2849. dnnl_status_t status = dnnl_memory_desc_create_with_packed_encoding(
  2850. &md, (int)adims.size(), adims.data(),
  2851. convert_to_c(adata_type), nnz);
  2852. if (!allow_empty)
  2853. error::wrap_c_api(status,
  2854. "could not create a memory descriptor for packed "
  2855. "sparse encoding");
  2856. return desc {md};
  2857. }
  2858. /// Creates a memory descriptor for a scalar value that resides on the host.
  2859. ///
  2860. /// @param adata_type Data type of the scalar.
  2861. /// @returns A memory descriptor for host-side scalar input.
  2862. static desc host_scalar(data_type adata_type) {
  2863. dnnl_memory_desc_t md = nullptr;
  2864. error::wrap_c_api(dnnl_memory_desc_create_host_scalar(
  2865. &md, convert_to_c(adata_type)),
  2866. "could not create a memory descriptor describing host side "
  2867. "scalar");
  2868. return desc {md};
  2869. }
  2870. /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t
  2871. /// handle. The resulting handle is not weak and the C handle will be
  2872. /// destroyed during the destruction of the C++ object.
  2873. ///
  2874. /// @param md The C API memory descriptor.
  2875. desc(dnnl_memory_desc_t md) : handle<dnnl_memory_desc_t>(md) {}
  2876. /// Construct a memory descriptor from a binary blob.
  2877. ///
  2878. /// @param blob A binary blob previously queried from a memory descriptor.
  2879. desc(const std::vector<uint8_t> &blob) {
  2880. dnnl_memory_desc_t md = nullptr;
  2881. error::wrap_c_api(
  2882. dnnl_memory_desc_create_with_blob(&md, blob.data()),
  2883. "could not create a memory descriptor from blob");
  2884. reset(md);
  2885. }
  2886. /// Constructs a memory descriptor for a region inside an area
  2887. /// described by this memory descriptor.
  2888. //
  2889. /// @param adims Sizes of the region.
  2890. /// @param offsets Offsets to the region from the encompassing
  2891. /// memory object in each dimension.
  2892. /// @param allow_empty A flag signifying whether construction is
  2893. /// allowed to fail without throwing an exception. In this case a
  2894. /// zero memory descriptor will be returned. This flag is optional
  2895. /// and defaults to false.
  2896. /// @returns A memory descriptor for the region.
  2897. desc submemory_desc(const dims &adims, const dims &offsets,
  2898. bool allow_empty = false) const {
  2899. validate_dims(adims, get_ndims());
  2900. validate_dims(offsets, get_ndims());
  2901. dnnl_memory_desc_t sub_md = nullptr;
  2902. dnnl_status_t status = dnnl_memory_desc_create_submemory(
  2903. &sub_md, get(), adims.data(), offsets.data());
  2904. if (!allow_empty)
  2905. error::wrap_c_api(status, "could not construct a sub-memory");
  2906. return desc(sub_md);
  2907. }
  2908. /// Constructs a memory descriptor by reshaping an existing one. The
  2909. /// new memory descriptor inherits the data type. This operation is
  2910. /// valid only for memory descriptors that have format_kind set to
  2911. /// #dnnl::memory::format_kind::blocked or
  2912. /// #dnnl::memory::format_kind::any.
  2913. ///
  2914. /// The operation ensures that the transformation of the physical memory
  2915. /// format corresponds to the transformation of the logical dimensions.
  2916. /// If such transformation is impossible, the function either throws an
  2917. /// exception (default) or returns a zero memory descriptor depending on
  2918. /// the `allow_empty` flag.
  2919. ///
  2920. /// The reshape operation can be described as a combination of the
  2921. /// following basic operations:
  2922. /// 1. Add a dimension of size `1`. This is always possible.
  2923. /// 2. Remove a dimension of size `1`. This is possible only if the
  2924. /// dimension has no padding (i.e.
  2925. /// `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
  2926. /// 3. Split a dimension into multiple ones. This is possible only if
  2927. /// the product of all tensor dimensions stays constant and the
  2928. /// dimension being split does not have padding (i.e.
  2929. /// `padded_dims[dim] = dims[dim]`).
  2930. /// 4. Join multiple consecutive dimensions into a single one. As in
  2931. /// the cases above, this requires that the dimensions do not have
  2932. /// padding and that the memory format is such that in physical
  2933. /// memory these dimensions are dense and have the same order as
  2934. /// their logical counterparts. This also assumes that these
  2935. /// dimensions are not blocked.
  2936. /// - Here, 'dense' means:
  2937. /// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
  2938. /// - And 'same order' means:
  2939. /// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
  2940. ///
  2941. /// @warning
  2942. /// Some combinations of physical memory layout and/or offsets or
  2943. /// dimensions may result in a failure to make a reshape.
  2944. ///
  2945. /// @param adims New dimensions. The product of dimensions must
  2946. /// remain constant.
  2947. /// @param allow_empty A flag signifying whether construction is
  2948. /// allowed to fail without throwing an exception. In this case a
  2949. /// zero memory descriptor will be returned. This flag is optional
  2950. /// and defaults to false.
  2951. /// @returns A new memory descriptor with new dimensions.
  2952. desc reshape(const dims &adims, bool allow_empty = false) const {
  2953. if (get_ndims()) validate_dims(adims, 1);
  2954. dnnl_memory_desc_t out_md = nullptr;
  2955. dnnl_status_t status = dnnl_memory_desc_reshape(
  2956. &out_md, get(), (int)adims.size(), adims.data());
  2957. if (!allow_empty)
  2958. error::wrap_c_api(
  2959. status, "could not reshape a memory descriptor");
  2960. return desc(out_md);
  2961. }
  2962. /// Constructs a memory descriptor by permuting axes in an existing
  2963. /// one.
  2964. ///
  2965. /// The physical memory layout representation is adjusted accordingly
  2966. /// to maintain the consistency between the logical and physical parts
  2967. /// of the memory descriptor. The new memory descriptor inherits the
  2968. /// data type.
  2969. ///
  2970. /// The new memory descriptor inherits the data type. This operation is
  2971. /// valid only for memory descriptors that have format_kind set to
  2972. /// #dnnl::memory::format_kind::blocked or
  2973. /// #dnnl::memory::format_kind::any.
  2974. ///
  2975. /// The logical axes will be permuted in the following manner:
  2976. /// @code
  2977. /// for (i = 0; i < get_ndims(); i++)
  2978. /// new_desc.dims()[permutation[i]] = dims()[i];
  2979. /// @endcode
  2980. ///
  2981. /// Example:
  2982. /// @code
  2983. /// std::vector<int> permutation = {1, 0}; // swap the first and
  2984. /// // the second axes
  2985. /// dnnl::memory::desc in_md(
  2986. /// {2, 3}, data_type, memory::format_tag::ab);
  2987. /// dnnl::memory::desc expect_out_md(
  2988. /// {3, 2}, data_type, memory::format_tag::ba);
  2989. ///
  2990. /// assert(in_md.permute_axes(permutation) == expect_out_md);
  2991. /// @endcode
  2992. ///
  2993. /// @param permutation Axes permutation.
  2994. /// @param allow_empty A flag signifying whether construction is
  2995. /// allowed to fail without throwing an exception. In this case a
  2996. /// zero memory descriptor will be returned. This flag is optional
  2997. /// and defaults to false.
  2998. /// @returns A new memory descriptor with new dimensions.
  2999. desc permute_axes(const std::vector<int> &permutation,
  3000. bool allow_empty = false) const {
  3001. validate_dims(permutation, get_ndims());
  3002. dnnl_memory_desc_t out_md = nullptr;
  3003. dnnl_status_t status = dnnl_memory_desc_permute_axes(
  3004. &out_md, get(), permutation.data());
  3005. if (!allow_empty)
  3006. error::wrap_c_api(status,
  3007. "could not permute axes of a memory descriptor");
  3008. return desc(out_md);
  3009. }
  3010. /// Returns a number of dimensions of the memory descriptor.
  3011. ///
  3012. /// @returns A number of dimensions.
  3013. int get_ndims() const { return query_s32(query::ndims_s32); }
  3014. /// Returns padded dimensions of the memory descriptor.
  3015. ///
  3016. /// @returns A copy of the padded dimensions vector.
  3017. memory::dims get_padded_dims() const {
  3018. return query_dims(query::padded_dims);
  3019. }
  3020. /// Returns padded offsets of the memory descriptor.
  3021. ///
  3022. /// @returns A copy of the padded offsets vector.
  3023. memory::dims get_padded_offsets() const {
  3024. return query_dims(query::padded_offsets);
  3025. }
  3026. /// Returns a submemory offset of the memory descriptor.
  3027. ///
  3028. /// @returns A submemory offset.
  3029. memory::dim get_submemory_offset() const {
  3030. dnnl_dim_t submemory_offset;
  3031. dnnl_status_t status = dnnl_memory_desc_query(
  3032. get(), dnnl_query_submemory_offset_s64, &submemory_offset);
  3033. return status == dnnl_success ? submemory_offset : 0;
  3034. }
  3035. /// Returns strides of the memory descriptor.
  3036. ///
  3037. /// @note
  3038. /// This API is only applicable to memory descriptors with format
  3039. /// kind #dnnl_blocked.
  3040. ///
  3041. /// @returns A copy of the strides vector.
  3042. /// @returns An empty #dnnl::memory::dims if the memory descriptor
  3043. /// does not have strides.
  3044. memory::dims get_strides() const { return query_dims(query::strides); }
  3045. /// Returns a number of inner blocks of the memory descriptor.
  3046. ///
  3047. /// @note
  3048. /// This API is only applicable to memory descriptors with format
  3049. /// kind #dnnl_blocked.
  3050. ///
  3051. /// @returns A number of inner blocks.
  3052. int get_inner_nblks() const {
  3053. return query_s32(query::inner_nblks_s32);
  3054. }
  3055. /// Returns inner blocks of the memory descriptor.
  3056. ///
  3057. /// @note
  3058. /// This API is only applicable to memory descriptors with format
  3059. /// kind #dnnl_blocked.
  3060. ///
  3061. /// @returns A copy of the inner blocks vector.
  3062. /// @returns An empty #dnnl::memory::dims if the memory descriptor
  3063. /// does not have inner blocks.
  3064. memory::dims get_inner_blks() const {
  3065. return query_dims(query::inner_blks);
  3066. }
  3067. /// Returns inner indices of the memory descriptor.
  3068. ///
  3069. /// @note
  3070. /// This API is only applicable to memory descriptors with format
  3071. /// kind #dnnl_blocked.
  3072. ///
  3073. /// @returns A copy of the inner indices vector.
  3074. /// @returns An empty #dnnl::memory::dims if the memory descriptor
  3075. /// does not have inner indices.
  3076. memory::dims get_inner_idxs() const {
  3077. return query_dims(query::inner_idxs);
  3078. }
  3079. /// Returns number of handles.
  3080. ///
  3081. /// @returns A number of handles.
  3082. int get_num_handles() const {
  3083. int nhandles;
  3084. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3085. get(), dnnl_query_num_handles_s32, 0, &nhandles);
  3086. return status == dnnl_success ? nhandles : 0;
  3087. }
  3088. /// Returns a number of non-zero entries of the memory descriptor.
  3089. ///
  3090. /// @returns A number non-zero entries.
  3091. dim get_nnz() const {
  3092. dnnl_dim_t nnz;
  3093. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3094. get(), dnnl_query_nnz_s64, 0, &nnz);
  3095. return status == dnnl_success ? nnz : 0;
  3096. }
  3097. /// Returns the sparse encoding of the memory descriptor.
  3098. ///
  3099. /// @returns the sparse encoding kind.
  3100. /// @sa @ref dev_guide_sparsity
  3101. memory::sparse_encoding get_sparse_encoding() const {
  3102. dnnl_sparse_encoding_t sparse_encoding;
  3103. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3104. get(), dnnl_query_sparse_encoding, 0, &sparse_encoding);
  3105. return status == dnnl_success
  3106. ? static_cast<dnnl::memory::sparse_encoding>(
  3107. sparse_encoding)
  3108. : dnnl::memory::sparse_encoding::undef;
  3109. }
  3110. /// Returns the data type of the memory descriptor.
  3111. ///
  3112. /// @returns The data type.
  3113. memory::data_type get_data_type(int index = 0) const {
  3114. return query_data_type(query::data_type, index);
  3115. }
  3116. /// Returns the format kind of the memory descriptor.
  3117. ///
  3118. /// @returns the format kind.
  3119. memory::format_kind get_format_kind() const {
  3120. dnnl_format_kind_t format_kind;
  3121. dnnl_status_t status = dnnl_memory_desc_query(
  3122. get(), dnnl_query_format_kind, &format_kind);
  3123. return status == dnnl_success
  3124. ? static_cast<dnnl::memory::format_kind>(format_kind)
  3125. : dnnl::memory::format_kind::undef;
  3126. }
  3127. /// Returns dimensions of the memory descriptor.
  3128. ///
  3129. /// Potentially expensive due to the data copy involved.
  3130. /// @returns A copy of the dimensions vector.
  3131. memory::dims get_dims() const { return query_dims(query::dims); }
  3132. /// Returns size of the memory descriptor in bytes.
  3133. /// @param index Data index. Defaults to 0.
  3134. /// @returns The number of bytes required to allocate a memory buffer
  3135. /// for data with a particular @p index described by this memory
  3136. /// descriptor including the padding area.
  3137. size_t get_size(int index = 0) const {
  3138. return dnnl_memory_desc_get_size_v2(get(), index);
  3139. }
  3140. /// Returns a binary blob associated with the given memory descriptor
  3141. /// @returns The memory descriptor blob associated with the memory descriptor
  3142. std::vector<uint8_t> get_blob() {
  3143. size_t size;
  3144. dnnl_status_t status
  3145. = dnnl_memory_desc_get_blob(nullptr, &size, get());
  3146. error::wrap_c_api(
  3147. status, "could not get memory descriptor blob size");
  3148. std::vector<uint8_t> out_blob(size);
  3149. status = dnnl_memory_desc_get_blob(out_blob.data(), &size, get());
  3150. error::wrap_c_api(status, "could not get memory descriptor blob");
  3151. return out_blob;
  3152. }
  3153. /// Checks whether the memory descriptor is zero (empty).
  3154. /// @returns @c true if the memory descriptor describes an empty
  3155. /// memory and @c false otherwise.
  3156. bool is_zero() const { return get_ndims() == 0; }
  3157. /// An equality operator.
  3158. /// @param other Another memory descriptor.
  3159. /// @returns Whether this and the other memory descriptors have
  3160. /// the same format tag, dimensions, strides, blocking, etc.
  3161. bool operator==(const desc &other) const {
  3162. return dnnl_memory_desc_equal(get(), other.get()) != 0;
  3163. }
  3164. /// An inequality operator.
  3165. /// @param other Another memory descriptor.
  3166. /// @returns Whether this and the other memory descriptors describe
  3167. /// different memory.
  3168. bool operator!=(const desc &other) const { return !operator==(other); }
  3169. private:
  3170. memory::data_type query_data_type(query what, int index) const {
  3171. dnnl_data_type_t data_type;
  3172. dnnl_status_t status = dnnl_memory_desc_query_v2(
  3173. get(), dnnl::convert_to_c(what), index, &data_type);
  3174. return status == dnnl_success
  3175. ? static_cast<dnnl::memory::data_type>(data_type)
  3176. : dnnl::memory::data_type::undef;
  3177. }
  3178. int query_s32(query what) const {
  3179. int res;
  3180. dnnl_status_t status = dnnl_memory_desc_query(
  3181. get(), dnnl::convert_to_c(what), &res);
  3182. return status == dnnl_success ? res : 0;
  3183. }
  3184. memory::dims query_dims(query what) const {
  3185. dnnl_dims_t *c_dims;
  3186. dnnl_status_t status = dnnl_memory_desc_query(
  3187. get(), dnnl::convert_to_c(what), &c_dims);
  3188. const int ndims
  3189. = (what == query::inner_idxs || what == query::inner_blks)
  3190. ? get_inner_nblks()
  3191. : get_ndims();
  3192. return status == dnnl_success
  3193. ? memory::dims(*c_dims, *c_dims + ndims)
  3194. : memory::dims {};
  3195. }
  3196. };
  3197. /// Default constructor.
  3198. ///
  3199. /// Constructs an empty memory object, which can be used to indicate
  3200. /// absence of a parameter.
  3201. memory() = default;
  3202. /// Constructs a memory object.
  3203. ///
  3204. /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
  3205. /// object will have the underlying buffer set. In this case, the buffer
  3206. /// will be initialized as if #dnnl::memory::set_data_handle() had been
  3207. /// called.
  3208. ///
  3209. /// @sa memory::set_data_handle()
  3210. ///
  3211. /// @param md Memory descriptor.
  3212. /// @param aengine Engine to store the data on.
  3213. /// @param handle Handle of the memory buffer to use.
  3214. /// - A pointer to the user-allocated buffer. In this case the library
  3215. /// doesn't own the buffer.
  3216. /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  3217. /// allocate the buffer for the memory object. In this case the
  3218. /// library owns the buffer.
  3219. /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
  3220. /// buffer.
  3221. memory(const desc &md, const engine &aengine, void *handle)
  3222. : memory(md, aengine, std::vector<void *> {handle}) {}
  3223. /// Constructs a memory object with multiple handles.
  3224. ///
  3225. /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
  3226. /// object will have the underlying buffer set. In this case, the buffer
  3227. /// will be initialized as if #dnnl::memory::set_data_handle() had been
  3228. /// called.
  3229. ///
  3230. /// @sa memory::set_data_handle()
  3231. ///
  3232. /// @param md Memory descriptor.
  3233. /// @param aengine Engine to store the data on.
  3234. /// @param handles Handles of the memory buffers to use.
  3235. /// For each element of the @p handles vector the following applies:
  3236. /// - A pointer to the user-allocated buffer. In this case the library
  3237. /// doesn't own the buffer.
  3238. /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
  3239. /// allocate the buffer for the memory object. In this case the
  3240. /// library owns the buffer.
  3241. /// - #DNNL_MEMORY_NONE Instructs the library to skip allocation of the
  3242. /// memory buffer.
  3243. memory(const desc &md, const engine &aengine, std::vector<void *> handles) {
  3244. dnnl_memory_t result;
  3245. dnnl_status_t status = dnnl_memory_create_v2(&result, md.get(),
  3246. aengine.get(), (int)handles.size(), handles.data());
  3247. error::wrap_c_api(status, "could not create a memory object");
  3248. reset(result);
  3249. }
  3250. /// Constructs a memory object.
  3251. ///
  3252. /// The underlying buffer(s) for the memory will be allocated by the
  3253. /// library.
  3254. /// @param md Memory descriptor.
  3255. /// @param aengine Engine to store the data on.
  3256. memory(const desc &md, const engine &aengine) {
  3257. dnnl_status_t status;
  3258. dnnl_memory_t result;
  3259. const int nhandles = md.get_num_handles();
  3260. std::vector<void *> handles(nhandles, DNNL_MEMORY_ALLOCATE);
  3261. status = dnnl_memory_create_v2(&result, md.get(), aengine.get(),
  3262. (int)handles.size(), handles.data());
  3263. error::wrap_c_api(status, "could not create a memory object");
  3264. reset(result);
  3265. }
  3266. /// Constructs a memory object that wraps a host-side scalar value.
  3267. ///
  3268. /// @note The scalar value is copied into the newly allocated memory storage,
  3269. /// so the user does not need to manage the lifetime of the original scalar data.
  3270. ///
  3271. /// @tparam T Type of the scalar value.
  3272. /// @param md Memory descriptor describing a scalar value residing on the host.
  3273. /// @param value The scalar value to be wrapped by the memory object.
  3274. ///
  3275. /// @throws error if the memory object could not be created.
  3276. template <typename T>
  3277. memory(const desc &md, const T value) {
  3278. dnnl_memory_t result;
  3279. // Check that the data type of T matches the memory descriptor's data type
  3280. // For host-side scalars, md.get_size() is data_type size
  3281. if (sizeof(T) != md.get_size()) {
  3282. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  3283. "scalar type size does not match memory descriptor data "
  3284. "type size");
  3285. } else {
  3286. dnnl_status_t status = dnnl_memory_create_host_scalar(
  3287. &result, md.get(), (void *)&value);
  3288. error::wrap_c_api(status, "could not create a memory object");
  3289. }
  3290. reset(result);
  3291. }
  3292. /// Returns the associated memory descriptor.
  3293. desc get_desc() const {
  3294. const_dnnl_memory_desc_t cdesc;
  3295. error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
  3296. "could not get a memory descriptor from a memory object");
  3297. dnnl_memory_desc_t cloned_md = nullptr;
  3298. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  3299. "could not clone a memory descriptor");
  3300. return desc(cloned_md);
  3301. }
  3302. /// Returns the associated engine.
  3303. engine get_engine() const {
  3304. dnnl_engine_t c_engine;
  3305. error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
  3306. "could not get an engine from a memory object");
  3307. return engine(c_engine, true);
  3308. }
  3309. /// Returns an underlying memory buffer that corresponds to the given index.
  3310. ///
  3311. /// On the CPU engine, or when using USM, this is a pointer to the
  3312. /// allocated memory.
  3313. void *get_data_handle(int index = 0) const {
  3314. void *handle;
  3315. error::wrap_c_api(dnnl_memory_get_data_handle_v2(get(), &handle, index),
  3316. "could not get a native handle from a memory object");
  3317. return handle;
  3318. }
  3319. /// Sets an underlying memory buffer that corresponds to the given index.
  3320. ///
  3321. /// @param handle Memory buffer to use. On the CPU engine or when USM is
  3322. /// used, the memory buffer is a pointer to the actual data. For OpenCL
  3323. /// it is a cl_mem. It must have at least
  3324. /// #dnnl::memory::desc::get_size() bytes allocated.
  3325. /// @param index Memory index to attach the buffer. Defaults to 0.
  3326. void set_data_handle(void *handle, int index = 0) const {
  3327. error::wrap_c_api(dnnl_memory_set_data_handle_v2(get(), handle, index),
  3328. "could not set native handle of a memory object");
  3329. }
  3330. /// Returns the scalar value stored in the memory object as type T.
  3331. ///
  3332. /// @tparam T Type to cast the scalar value to.
  3333. template <typename T>
  3334. T get_host_scalar_value() const {
  3335. const_dnnl_memory_desc_t cdesc;
  3336. error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
  3337. "could not get memory descriptor");
  3338. if (sizeof(T) != dnnl_memory_desc_get_size_v2(cdesc, 0)) {
  3339. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  3340. "scalar type size does not match memory descriptor data "
  3341. "type size");
  3342. }
  3343. T value;
  3344. error::wrap_c_api(dnnl_memory_get_host_scalar_value(get(), &value),
  3345. "could not get host scalar value from a memory object");
  3346. return value;
  3347. }
  3348. /// Sets the scalar value stored in the memory object.
  3349. ///
  3350. /// @note The scalar value is copied into the memory storage, so the user
  3351. /// does not need to manage the lifetime of the original scalar data.
  3352. ///
  3353. /// @param value Pointer to the scalar value to set.
  3354. template <typename T>
  3355. void set_host_scalar_value(const T value) const {
  3356. const_dnnl_memory_desc_t cdesc;
  3357. error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
  3358. "could not get memory descriptor from a memory object");
  3359. if (sizeof(T) != dnnl_memory_desc_get_size_v2(cdesc, 0)) {
  3360. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  3361. "scalar type size does not match memory descriptor data "
  3362. "type size");
  3363. }
  3364. error::wrap_c_api(dnnl_memory_set_host_scalar_value(get(), &value),
  3365. "could not set host scalar value to a memory object");
  3366. }
  3367. /// Maps a memory object and returns a host-side pointer to a memory
  3368. /// buffer with a copy of its contents. The memory buffer corresponds to
  3369. /// the given index.
  3370. ///
  3371. /// Mapping enables read/write directly from/to the memory contents for
  3372. /// engines that do not support direct memory access.
  3373. ///
  3374. /// Mapping is an exclusive operation - a memory object cannot be used in
  3375. /// other operations until it is unmapped via #dnnl::memory::unmap_data()
  3376. /// call.
  3377. ///
  3378. /// @note
  3379. /// Any primitives working with the memory should be completed before
  3380. /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the
  3381. /// corresponding execution stream.
  3382. ///
  3383. /// @note
  3384. /// The map_data and unmap_data functions are provided mainly for
  3385. /// debug and testing purposes and their performance may be suboptimal.
  3386. ///
  3387. /// @tparam T Data type to return a pointer to.
  3388. /// @param index Index of the buffer. Defaults to 0.
  3389. /// @returns Pointer to the mapped memory.
  3390. template <typename T = void>
  3391. T *map_data(int index = 0) const {
  3392. void *mapped_ptr;
  3393. error::wrap_c_api(dnnl_memory_map_data_v2(get(), &mapped_ptr, index),
  3394. "could not map memory object data");
  3395. return static_cast<T *>(mapped_ptr);
  3396. }
  3397. /// Unmaps a memory object and writes back any changes made to the
  3398. /// previously mapped memory buffer. The memory buffer corresponds to
  3399. /// the given index.
  3400. ///
  3401. /// @note
  3402. /// The map_data and unmap_data functions are provided mainly for
  3403. /// debug and testing purposes and their performance may be
  3404. /// suboptimal.
  3405. ///
  3406. /// @param mapped_ptr A pointer previously returned by
  3407. /// #dnnl::memory::map_data().
  3408. /// @param index Index of the buffer. Defaults to 0.
  3409. void unmap_data(void *mapped_ptr, int index = 0) const {
  3410. error::wrap_c_api(dnnl_memory_unmap_data_v2(get(), mapped_ptr, index),
  3411. "could not unmap memory object data");
  3412. }
  3413. static dnnl_data_type_t convert_to_c(data_type adata_type) {
  3414. return static_cast<dnnl_data_type_t>(adata_type);
  3415. }
  3416. static dnnl_format_tag_t convert_to_c(format_tag format) {
  3417. return static_cast<dnnl_format_tag_t>(format);
  3418. }
  3419. };
  3420. inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
  3421. return a == memory::convert_to_c(b);
  3422. }
  3423. inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
  3424. return !(a == b);
  3425. }
  3426. inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
  3427. return b == a;
  3428. }
  3429. inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
  3430. return !(a == b);
  3431. }
  3432. inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
  3433. return a == memory::convert_to_c(b);
  3434. }
  3435. inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
  3436. return !(a == b);
  3437. }
  3438. inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
  3439. return b == a;
  3440. }
  3441. inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
  3442. return !(a == b);
  3443. }
  3444. /// @} dnnl_api_memory
  3445. /// @addtogroup dnnl_api_primitives
  3446. /// @{
  3447. /// @addtogroup dnnl_api_attributes Attributes
  3448. ///
  3449. /// A container for parameters that extend primitives behavior.
  3450. ///
  3451. /// @{
  3452. /// @cond DO_NOT_DOCUMENT_THIS
  3453. template <>
  3454. struct handle_traits<dnnl_post_ops_t> {
  3455. static dnnl_status_t destructor(dnnl_post_ops_t p) {
  3456. return dnnl_post_ops_destroy(p);
  3457. }
  3458. };
  3459. /// @endcond
  3460. /// Post-ops.
  3461. ///
  3462. /// Post-ops are computations executed after the main primitive computations
  3463. /// and are attached to the primitive via primitive attributes.
  3464. ///
  3465. /// @sa @ref dev_guide_attributes_post_ops
  3466. ///
  3467. struct post_ops : public handle<dnnl_post_ops_t> {
  3468. using handle<dnnl_post_ops_t>::handle;
  3469. /// Constructs an empty sequence of post-ops.
  3470. post_ops() {
  3471. dnnl_post_ops_t result;
  3472. error::wrap_c_api(
  3473. dnnl_post_ops_create(&result), "could not create post-ops");
  3474. reset(result);
  3475. }
  3476. /// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t
  3477. /// handle. The resulting handle is not weak and the C handle will be
  3478. /// destroyed during the destruction of the C++ object.
  3479. ///
  3480. /// @param post_ops The C API post-ops primitive attribute.
  3481. post_ops(dnnl_post_ops_t post_ops) : handle<dnnl_post_ops_t>(post_ops) {}
  3482. /// Returns the number of post-ops entries.
  3483. int len() const { return dnnl_post_ops_len(get()); }
  3484. /// Returns the primitive kind of post-op at entry with a certain index.
  3485. /// @param index Index of the post-op to return the kind for.
  3486. /// @returns Primitive kind of the post-op at the specified index.
  3487. primitive::kind kind(int index) const {
  3488. error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
  3489. "post-ops index is out of range");
  3490. return static_cast<primitive::kind>(
  3491. dnnl_post_ops_get_kind(get(), index));
  3492. }
  3493. /// Appends an accumulation (sum) post-op. Prior to accumulating the
  3494. /// result, the previous value will be will be reduced by zero point
  3495. /// @p zero_point and multiplied by a scaling factor @p scale.
  3496. ///
  3497. /// The kind of this post-op is #dnnl::primitive::kind::sum.
  3498. ///
  3499. /// This feature may improve performance for cases like dequantize the
  3500. /// asymmetrically quantized sum's src1 tensor to f32 domain before
  3501. /// performing the sum operation by subtracting @p zero_point before the
  3502. /// scaling.
  3503. ///
  3504. /// In the simplest case when the accumulation is the only post-op,
  3505. /// the computations will be `dst[:] := scale * (dst[:] - zero_point) +
  3506. /// op(...)` instead of `dst[:] := op(...)`.
  3507. ///
  3508. /// If @p data_type is specified, the original dst tensor will be
  3509. /// reinterpreted as a tensor with the provided data type. Because it is a
  3510. /// reinterpretation, data_type and dst data type should have the same size.
  3511. /// As a result, computations will be `dst[:] <- scale *
  3512. /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of
  3513. /// `dst[:] <- op(...)`.
  3514. ///
  3515. /// @note
  3516. /// This post-op executes in-place and does not change the
  3517. /// destination layout.
  3518. ///
  3519. /// @param scale Scaling factor.
  3520. /// @param zero_point Zero point.
  3521. /// @param data_type Data type.
  3522. void append_sum(float scale = 1.f, int32_t zero_point = 0,
  3523. memory::data_type data_type = memory::data_type::undef) {
  3524. error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point,
  3525. memory::convert_to_c(data_type)),
  3526. "could not append a sum post-op");
  3527. }
  3528. /// Returns the parameters of an accumulation (sum) post-op.
  3529. ///
  3530. /// @param index Index of the sum post-op.
  3531. /// @param scale Scaling factor of the sum post-op.
  3532. void get_params_sum(int index, float &scale) const {
  3533. error::wrap_c_api(dnnl_post_ops_get_params_sum(
  3534. get(), index, &scale, nullptr, nullptr),
  3535. "could not get parameters of a sum post-op");
  3536. }
  3537. /// Returns the parameters of an accumulation (sum) post-op.
  3538. ///
  3539. /// @param index Index of the sum post-op.
  3540. /// @param scale Scaling factor of the sum post-op.
  3541. /// @param data_type Data type of the sum post-op.
  3542. void get_params_sum(
  3543. int index, float &scale, memory::data_type &data_type) const {
  3544. dnnl_data_type_t c_data_type;
  3545. error::wrap_c_api(dnnl_post_ops_get_params_sum(
  3546. get(), index, &scale, nullptr, &c_data_type),
  3547. "could not get parameters of a sum post-op");
  3548. data_type = static_cast<memory::data_type>(c_data_type);
  3549. }
  3550. /// Returns the parameters of an accumulation (sum) post-op.
  3551. ///
  3552. /// @param index Index of the sum post-op.
  3553. /// @param scale Scaling factor of the sum post-op.
  3554. /// @param zero_point Single scalar int32_t value of zeropoint.
  3555. /// @param data_type Data type of the sum post-op.
  3556. void get_params_sum(int index, float &scale, int32_t &zero_point,
  3557. memory::data_type &data_type) const {
  3558. dnnl_data_type_t c_data_type;
  3559. error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale,
  3560. &zero_point, &c_data_type),
  3561. "could not get parameters of a sum post-op");
  3562. data_type = static_cast<memory::data_type>(c_data_type);
  3563. }
  3564. /// Appends an elementwise post-op.
  3565. ///
  3566. /// The kind of this post-op is #dnnl::primitive::kind::eltwise.
  3567. ///
  3568. /// In the simplest case when the elementwise is the only post-op, the
  3569. /// computations would be `dst[:] := eltwise_op (op(...))` instead
  3570. /// of `dst[:] <- op(...)`, where eltwise_op is configured with the given
  3571. /// parameters.
  3572. ///
  3573. /// @param aalgorithm Elementwise algorithm.
  3574. /// @param alpha Alpha parameter for the elementwise algorithm.
  3575. /// @param beta Beta parameter for the elementwise algorithm.
  3576. void append_eltwise(algorithm aalgorithm, float alpha, float beta) {
  3577. error::wrap_c_api(dnnl_post_ops_append_eltwise(
  3578. get(), convert_to_c(aalgorithm), alpha, beta),
  3579. "could not append an elementwise post-op");
  3580. }
  3581. /// Returns parameters of an elementwise post-op.
  3582. ///
  3583. /// @param index Index of the post-op.
  3584. /// @param aalgorithm Output elementwise algorithm kind.
  3585. /// @param alpha Output alpha parameter for the elementwise algorithm.
  3586. /// @param beta Output beta parameter for the elementwise algorithm.
  3587. void get_params_eltwise(
  3588. int index, algorithm &aalgorithm, float &alpha, float &beta) const {
  3589. dnnl_alg_kind_t c_alg;
  3590. error::wrap_c_api(dnnl_post_ops_get_params_eltwise(
  3591. get(), index, &c_alg, &alpha, &beta),
  3592. "could not get parameters of an elementwise post-op");
  3593. aalgorithm = static_cast<dnnl::algorithm>(c_alg);
  3594. }
  3595. /// Appends a depthwise post-op convolution.
  3596. ///
  3597. /// This post-op can only be fused with a 2D 1x1 convolution (convolution
  3598. /// with weights spatial dimension equal to 1 i.e., kh=kw=1).
  3599. ///
  3600. /// The kind of this post-op is #dnnl_convolution.
  3601. ///
  3602. /// The number of outputs for primitive remain same as before. The output
  3603. /// spatial size can be derived as below:
  3604. ///
  3605. /// output_height = ceil(output_height_1x1_convolution, stride)
  3606. /// output_width = ceil(output_width_1x1_convolution, stride)
  3607. ///
  3608. /// See @ref dev_guide_attributes_post_ops_depthwise and
  3609. /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
  3610. ///
  3611. /// @param weights_data_type Weights data type of depthwise post-op
  3612. /// @param bias_data_type Bias data type of depthwise post-op
  3613. /// @param dst_data_type Output data type of depthwise post-op
  3614. /// @param kernel_size Size of kernel of depthwise post-op
  3615. /// @param stride_size Size of stride of depthwise post-op
  3616. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  3617. void append_dw(memory::data_type weights_data_type,
  3618. memory::data_type bias_data_type, memory::data_type dst_data_type,
  3619. memory::dim kernel_size, memory::dim stride_size,
  3620. memory::dim padding_l_size) {
  3621. error::wrap_c_api(dnnl_post_ops_append_dw(get(),
  3622. memory::convert_to_c(weights_data_type),
  3623. memory::convert_to_c(bias_data_type),
  3624. memory::convert_to_c(dst_data_type),
  3625. kernel_size, stride_size, padding_l_size),
  3626. "could not append depthwise post-op");
  3627. }
  3628. /// Returns the parameters of an depthwise post-op.
  3629. ///
  3630. /// @param index Index of the elementwise post-op.
  3631. /// @param weights_data_type Weights data type of depthwise post-op
  3632. /// @param bias_data_type Bias data type of depthwise post-op
  3633. /// @param dst_data_type Output data type of depthwise post-op
  3634. /// @param kernel_size Size of kernel of depthwise post-op
  3635. /// @param stride_size Size of stride of depthwise post-op
  3636. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  3637. void get_params_dw(int index, memory::data_type &weights_data_type,
  3638. memory::data_type &bias_data_type, memory::data_type &dst_data_type,
  3639. memory::dim &kernel_size, memory::dim &stride_size,
  3640. memory::dim &padding_l_size) const {
  3641. dnnl_data_type_t c_weights_data_type;
  3642. dnnl_data_type_t c_bias_data_type;
  3643. dnnl_data_type_t c_dst_data_type;
  3644. dnnl_dim_t c_kernel_size;
  3645. dnnl_dim_t c_stride_size;
  3646. dnnl_dim_t c_padding_l_size;
  3647. error::wrap_c_api(
  3648. dnnl_post_ops_get_params_dw(get(), index, &c_weights_data_type,
  3649. &c_bias_data_type, &c_dst_data_type, &c_kernel_size,
  3650. &c_stride_size, &c_padding_l_size),
  3651. "could not get parameters of depthwise post-op");
  3652. weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
  3653. bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
  3654. dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
  3655. kernel_size = c_kernel_size;
  3656. stride_size = c_stride_size;
  3657. padding_l_size = c_padding_l_size;
  3658. }
  3659. /// Appends a binary post-op.
  3660. ///
  3661. /// This post operation is categorized as #dnnl_binary.
  3662. ///
  3663. /// In the simplest case when the binary is the only post operation, the
  3664. /// computations will be:
  3665. ///
  3666. /// dst[:] <- binary_op (dst[:], another_input[:])
  3667. ///
  3668. /// where binary_op is configured with the given parameters. binary_op
  3669. /// supports broadcast semantics for a second operand.
  3670. ///
  3671. /// @param aalgorithm Binary algorithm for the post-op.
  3672. /// @param src1_desc Memory descriptor of a second operand.
  3673. void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
  3674. error::wrap_c_api(dnnl_post_ops_append_binary(get(),
  3675. convert_to_c(aalgorithm), src1_desc.get()),
  3676. "could not append a binary post-op");
  3677. }
  3678. /// Appends a binary post-op with ternary operators.
  3679. ///
  3680. /// This post operation is categorized as #dnnl_binary.
  3681. ///
  3682. /// In the simplest case when this is the only post operation, the
  3683. /// computations will be:
  3684. ///
  3685. /// dst[:] <- binary_op (dst[:], another_input1[:], another_input2[:])
  3686. ///
  3687. /// where binary_op is configured with the given parameters. binary_op
  3688. /// supports broadcast semantics only for the second operand and not for the
  3689. /// third operand.
  3690. ///
  3691. /// @param aalgorithm Binary algorithm for the post-op.
  3692. /// @param src1_desc Memory descriptor of the second operand.
  3693. /// @param src2_desc Memory descriptor of the third operand. If the specified
  3694. /// algorithm is not one that requires a ternary input, src2_desc will be
  3695. /// ignored.
  3696. void append_binary(algorithm aalgorithm, const memory::desc &src1_desc,
  3697. const memory::desc &src2_desc) {
  3698. error::wrap_c_api(
  3699. dnnl_post_ops_append_binary_v2(get(), convert_to_c(aalgorithm),
  3700. src1_desc.get(), src2_desc.get()),
  3701. "could not append a binary post-op with ternary operators");
  3702. }
  3703. /// Returns the parameters of a binary post-op.
  3704. ///
  3705. /// @param index Index of the binary post-op.
  3706. /// @param aalgorithm Output binary algorithm kind.
  3707. /// @param src1_desc Output memory descriptor of a second operand.
  3708. void get_params_binary(
  3709. int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
  3710. dnnl_alg_kind_t c_alg;
  3711. const_dnnl_memory_desc_t cdesc;
  3712. error::wrap_c_api(
  3713. dnnl_post_ops_get_params_binary(get(), index, &c_alg, &cdesc),
  3714. "could not get parameters of a binary post-op");
  3715. aalgorithm = static_cast<dnnl::algorithm>(c_alg);
  3716. dnnl_memory_desc_t cloned_md = nullptr;
  3717. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  3718. "could not clone a memory descriptor");
  3719. src1_desc = memory::desc(cloned_md);
  3720. }
  3721. /// Returns the parameters of a binary post-op with ternary operators.
  3722. ///
  3723. /// @param index Index of the binary post-op.
  3724. /// @param aalgorithm Output binary algorithm kind.
  3725. /// @param src1_desc Output memory descriptor of the second operand.
  3726. /// @param src2_desc Output memory descriptor of the third operand.
  3727. void get_params_binary(int index, algorithm &aalgorithm,
  3728. memory::desc &src1_desc, memory::desc &src2_desc) const {
  3729. dnnl_alg_kind_t c_alg;
  3730. const_dnnl_memory_desc_t cdesc1, cdesc2;
  3731. error::wrap_c_api(dnnl_post_ops_get_params_binary_v2(
  3732. get(), index, &c_alg, &cdesc1, &cdesc2),
  3733. "could not get parameters of a binary post-op with ternary "
  3734. "operators");
  3735. aalgorithm = static_cast<dnnl::algorithm>(c_alg);
  3736. dnnl_memory_desc_t cloned_md1 = nullptr;
  3737. dnnl_memory_desc_t cloned_md2 = nullptr;
  3738. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md1, cdesc1),
  3739. "could not clone a memory descriptor");
  3740. src1_desc = memory::desc(cloned_md1);
  3741. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md2, cdesc2),
  3742. "could not clone a memory descriptor");
  3743. src2_desc = memory::desc(cloned_md2);
  3744. }
  3745. /// Appends a prelu forward post-op.
  3746. ///
  3747. /// The kind of this post-op is #dnnl::primitive::kind::prelu.
  3748. ///
  3749. /// The post-op can be defined as:
  3750. ///
  3751. /// dst[:] <- prelu(dst[:], weights[:])
  3752. /// prelu:
  3753. /// dst[:] <- dst[:] if dst[:] > 0
  3754. /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
  3755. ///
  3756. ///
  3757. /// Example usage:
  3758. /// @code
  3759. /// int mb = 32, oc = 32,
  3760. /// oh = 14, ow = 14; // convolution output params
  3761. /// // unique weights per output channel
  3762. /// vector<float> weights = { ... };
  3763. /// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
  3764. ///
  3765. /// // construct a convolution descriptor
  3766. /// dnnl::convolution::desc conv_d;
  3767. ///
  3768. /// dnnl::primitive_attr attr;
  3769. /// attr.append_prelu(1 << oc_dim);
  3770. ///
  3771. /// dnnl::primitive_desc conv_pd(conv_d, attr, engine);
  3772. /// memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data());
  3773. ///
  3774. /// std::unordered_map<int, memory> conv_args;
  3775. ///
  3776. /// conv_args.insert(
  3777. /// {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights})
  3778. /// @endcode
  3779. ///
  3780. /// @note
  3781. /// The order of dimensions does not depend on how elements are laid
  3782. /// out in memory. For example:
  3783. /// - for a 2D CNN activations tensor the order is always (n, c)
  3784. /// - for a 4D CNN activations tensor the order is always (n, c, h, w)
  3785. /// - for a 5D CNN weights tensor the order is always
  3786. /// (g, oc, ic, kh, kw)
  3787. ///
  3788. /// Prelu weights tensor is passed in runtime execution phase. Prelu
  3789. /// weights tensor data type is implicitly assumed as f32 using plain
  3790. /// layout (a, ab, acb, acdb, acdeb).
  3791. ///
  3792. /// @param mask Defines the correspondence between the output tensor
  3793. /// dimensions and the prelu weights tensor. The set i-th bit indicates
  3794. /// that a dedicated weights value is used for each index along that
  3795. /// dimension. Set the mask to 0 to use a common weights value
  3796. /// for the whole output tensor.
  3797. void append_prelu(int mask) {
  3798. error::wrap_c_api(dnnl_post_ops_append_prelu(get(), mask),
  3799. "could not append a prelu post-op");
  3800. }
  3801. /// Returns the parameters of a prelu post-op.
  3802. ///
  3803. /// @param index Index of the prelu post-op.
  3804. /// @param mask Weights mask of prelu post-op.
  3805. void get_params_prelu(int index, int &mask) const {
  3806. error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
  3807. "could not get parameters of a binary post-op");
  3808. }
  3809. };
  3810. /// @cond DO_NOT_DOCUMENT_THIS
  3811. template <>
  3812. struct handle_traits<dnnl_primitive_attr_t> {
  3813. static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
  3814. return dnnl_primitive_attr_destroy(p);
  3815. }
  3816. };
  3817. /// @endcond
  3818. /// Primitive attributes.
  3819. ///
  3820. /// @sa @ref dev_guide_attributes
  3821. struct primitive_attr : public handle<dnnl_primitive_attr_t> {
  3822. using handle<dnnl_primitive_attr_t>::handle;
  3823. /// Constructs default (empty) primitive attributes.
  3824. primitive_attr() {
  3825. dnnl_primitive_attr_t result;
  3826. error::wrap_c_api(dnnl_primitive_attr_create(&result),
  3827. "could not create primitive attribute");
  3828. reset(result);
  3829. }
  3830. /// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
  3831. /// handle. The resulting handle is not weak and the C handle will be
  3832. /// destroyed during the destruction of the C++ object.
  3833. ///
  3834. /// @param attr The C API primitive attributes.
  3835. primitive_attr(dnnl_primitive_attr_t attr)
  3836. : handle<dnnl_primitive_attr_t>(attr) {}
  3837. /// Returns the parameters of a dropout attribute.
  3838. ///
  3839. /// @param mask_desc Output memory descriptor of a dropout mask.
  3840. void get_dropout(memory::desc &mask_desc) const {
  3841. const_dnnl_memory_desc_t cdesc;
  3842. error::wrap_c_api(dnnl_primitive_attr_get_dropout(get(), &cdesc),
  3843. "could not get parameters of a dropout attribute");
  3844. dnnl_memory_desc_t cloned_md = nullptr;
  3845. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  3846. "could not clone a memory descriptor");
  3847. mask_desc = memory::desc(cloned_md);
  3848. }
  3849. /// Sets dropout probability.
  3850. ///
  3851. /// @param mask_desc Output memory descriptor of a dropout mask.
  3852. void set_dropout(const memory::desc &mask_desc) {
  3853. error::wrap_c_api(
  3854. dnnl_primitive_attr_set_dropout(get(), mask_desc.get()),
  3855. "could not set dropout primitive attribute");
  3856. }
  3857. /// Returns the fpmath mode
  3858. fpmath_mode get_fpmath_mode() const {
  3859. dnnl_fpmath_mode_t result;
  3860. error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode(get(), &result),
  3861. "could not get fpmath mode primitive attribute");
  3862. return fpmath_mode(result);
  3863. }
  3864. /// Returns the fpmath mode
  3865. ///
  3866. /// @param mode Specified fpmath mode.
  3867. /// @param apply_to_int Use floating-point arithmetic for integer primitives.
  3868. void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const {
  3869. dnnl_fpmath_mode_t c_mode;
  3870. int c_apply_to_int;
  3871. error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode_v2(
  3872. get(), &c_mode, &c_apply_to_int),
  3873. "could not get fpmath mode primitive attribute");
  3874. mode = fpmath_mode(c_mode);
  3875. apply_to_int = static_cast<bool>(c_apply_to_int);
  3876. }
  3877. /// Sets fpmath mode.
  3878. ///
  3879. /// @param mode Specified fpmath mode.
  3880. /// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives.
  3881. void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) {
  3882. error::wrap_c_api(dnnl_primitive_attr_set_fpmath_mode_v2(get(),
  3883. dnnl::convert_to_c(mode), apply_to_int),
  3884. "could not set fpmath mode primitive attribute");
  3885. }
  3886. /// Returns the accumulation mode
  3887. accumulation_mode get_accumulation_mode() const {
  3888. dnnl_accumulation_mode_t result;
  3889. error::wrap_c_api(
  3890. dnnl_primitive_attr_get_accumulation_mode(get(), &result),
  3891. "could not get accumulation mode primitive attribute");
  3892. return accumulation_mode(result);
  3893. }
  3894. /// Sets accumulation mode.
  3895. ///
  3896. /// @param mode Specified accumulation mode.
  3897. void set_accumulation_mode(accumulation_mode mode) {
  3898. error::wrap_c_api(dnnl_primitive_attr_set_accumulation_mode(
  3899. get(), dnnl::convert_to_c(mode)),
  3900. "could not set accumulation mode primitive attribute");
  3901. }
  3902. /// Returns the deterministic attribute value
  3903. bool get_deterministic() const {
  3904. int result;
  3905. error::wrap_c_api(dnnl_primitive_attr_get_deterministic(get(), &result),
  3906. "could not get deterministic primitive attribute");
  3907. return static_cast<bool>(result);
  3908. }
  3909. /// Sets deterministic attribute value
  3910. ///
  3911. /// @param value Specified deterministic mode.
  3912. void set_deterministic(bool value) {
  3913. error::wrap_c_api(dnnl_primitive_attr_set_deterministic(
  3914. get(), static_cast<int>(value)),
  3915. "could not set deterministic primitive attribute");
  3916. }
  3917. /// Returns the rounding mode attribute value
  3918. ///
  3919. /// @param arg Argument for which rounding mode query applies.
  3920. /// @returns The rounding mode applied to the specified argument.
  3921. rounding_mode get_rounding_mode(int arg) const {
  3922. dnnl_rounding_mode_t result;
  3923. error::wrap_c_api(dnnl_primitive_attr_get_rounding(get(), arg, &result),
  3924. "could not get rounding mode primitive attribute");
  3925. return rounding_mode(result);
  3926. }
  3927. /// Sets the rounding mode attribute value for a given argument
  3928. ///
  3929. /// @param arg Argument for which to set rounding mode.
  3930. /// @param mode Rounding mode to apply.
  3931. void set_rounding_mode(int arg, rounding_mode mode) {
  3932. error::wrap_c_api(dnnl_primitive_attr_set_rounding(
  3933. get(), arg, convert_to_c(mode)),
  3934. "could not set rounding mode primitive attribute");
  3935. }
  3936. /// Returns the scratchpad mode.
  3937. scratchpad_mode get_scratchpad_mode() const {
  3938. dnnl_scratchpad_mode_t result;
  3939. error::wrap_c_api(
  3940. dnnl_primitive_attr_get_scratchpad_mode(get(), &result),
  3941. "could not get scratchpad mode primitive attribute");
  3942. return scratchpad_mode(result);
  3943. }
  3944. /// Sets scratchpad mode.
  3945. ///
  3946. /// @param mode Specified scratchpad mode.
  3947. void set_scratchpad_mode(scratchpad_mode mode) {
  3948. error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode(
  3949. get(), dnnl::convert_to_c(mode)),
  3950. "could not set scratchpad mode primitive attribute");
  3951. }
  3952. /// Sets scaling factors for primitive operations for a given memory
  3953. /// argument. The scaling factors must be passed at execution time
  3954. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  3955. ///
  3956. /// @sa dnnl_primitive_attr_set_scales_mask
  3957. ///
  3958. /// @param arg Parameter argument index as passed to the
  3959. /// primitive::execute() call.
  3960. /// @param mask Scaling factors correspondence mask that defines the
  3961. /// correspondence between the tensor dimensions and the @p scales
  3962. /// vector. The set i-th bit indicates that a dedicated scaling factor
  3963. /// is used for each index along that dimension. Set the mask to 0 to
  3964. /// use a common scaling factor for the whole output tensor.
  3965. void set_scales_mask(int arg, int mask) {
  3966. error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
  3967. "could not set scales primitive attribute");
  3968. }
  3969. /// Sets scaling factors for primitive operations for a given memory
  3970. /// argument. The scaling factors must be passed at execution time
  3971. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  3972. ///
  3973. /// @note If `is_on_host` is true, sets a single host-side scalar scaling
  3974. /// factor for the specified memory argument. The scaling factor should be
  3975. /// passed as a host scalar memory object at execution time with index
  3976. /// #DNNL_ARG_ATTR_SCALES | arg.
  3977. ///
  3978. /// @sa dnnl_primitive_attr_set_scales_v2
  3979. ///
  3980. /// @param arg Parameter argument index as passed to the
  3981. /// primitive::execute() call.
  3982. /// @param mask Scales correspondence mask that defines the
  3983. /// correspondence between the tensor dimensions and the @p
  3984. /// scales vector. The set i-th bit indicates that a dedicated
  3985. /// scale is used for each index along that dimension. Set the
  3986. /// mask to 0 to use a common scale for the whole output tensor.
  3987. /// @param groups Scaling factors correspondence groups that define the
  3988. /// correspondence between the tensor dimensions and the scales array.
  3989. /// The set i-th dimension indicates a number of groups of scaling
  3990. /// factors used for that logical dimension in a memory indicated by @p arg.
  3991. /// @param data_type Scaling factors data_type.
  3992. /// @param is_on_host Indicates whether the scaling factor is a host-side scalar.
  3993. void set_scales(int arg, int mask, const memory::dims &groups,
  3994. memory::data_type data_type = memory::data_type::f32,
  3995. bool is_on_host = false) {
  3996. error::wrap_c_api(dnnl_primitive_attr_set_scales_v2(get(), arg, mask,
  3997. (int)groups.size(), groups.data(),
  3998. memory::convert_to_c(data_type), is_on_host),
  3999. "could not set scales primitive attribute");
  4000. }
  4001. /// Sets a single host-side scalar scaling
  4002. /// factor for the specified memory argument. The scaling factor should be
  4003. /// passed as a host scalar memory object at execution time with index
  4004. /// #DNNL_ARG_ATTR_SCALES | arg.
  4005. ///
  4006. /// @note Using this API to set the scaling factor implies that the scales
  4007. /// attribute has `mask == 0` and an empty groups vector.
  4008. ///
  4009. /// @sa dnnl_primitive_attr_set_scales_v2
  4010. ///
  4011. /// @param arg Parameter argument index as passed to the
  4012. /// primitive::execute() call.
  4013. /// @param data_type Scaling factors data_type.
  4014. void set_host_scale(
  4015. int arg, memory::data_type data_type = memory::data_type::f32) {
  4016. error::wrap_c_api(dnnl_primitive_attr_set_scales_v2(get(), arg, 0, 0,
  4017. nullptr, memory::convert_to_c(data_type), 1),
  4018. "could not set scales primitive attribute");
  4019. }
  4020. /// Sets zero points for primitive operations for a given memory argument.
  4021. /// The zero points must be passed at execution time as an argument with
  4022. /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  4023. ///
  4024. /// @sa dnnl_primitive_attr_set_zero_points_mask
  4025. ///
  4026. /// @param arg Parameter argument index as passed to the
  4027. /// primitive::execute() call.
  4028. /// @param mask Zero point correspondence mask that defines the
  4029. /// correspondence between the tensor dimensions and the @p
  4030. /// zero_points vector. The set i-th bit indicates that a dedicated
  4031. /// zero point is used for each index along that dimension. Set the
  4032. /// mask to 0 to use a common zero point for the whole output tensor.
  4033. void set_zero_points_mask(int arg, int mask) {
  4034. error::wrap_c_api(
  4035. dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
  4036. "could not set zero points primitive attribute");
  4037. }
  4038. /// Sets zero points for primitive operations for a given memory argument.
  4039. /// The zero points must be passed at execution time as an argument with
  4040. /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  4041. ///
  4042. /// @note If `is_on_host` is true, sets a single host-side zero point
  4043. /// for the specified memory argument. The zero point should be
  4044. /// passed as a host scalar memory object at execution time with index
  4045. /// #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  4046. ///
  4047. /// @sa dnnl_primitive_attr_set_zero_points
  4048. ///
  4049. /// @param arg Parameter argument index as passed to the
  4050. /// primitive::execute() call.
  4051. /// @param mask Zero point correspondence mask that defines the
  4052. /// correspondence between the tensor dimensions and the zero points
  4053. /// vector. The set i-th bit indicates that a dedicated zero point is
  4054. /// used for each index along that dimension. Set the mask to 0 to use
  4055. /// a common zero point for the whole output tensor.
  4056. /// @param groups Zero point factors correspondence groups that define the
  4057. /// correspondence between the tensor dimensions and the zero points
  4058. /// array.
  4059. /// The set i-th dimension indicates a number of groups of zero point
  4060. /// factors used for that logical dimension in a memory indicated by
  4061. /// @p arg.
  4062. /// @param data_type Zero point factors data_type.
  4063. /// @param is_on_host Indicates whether the zero point is a host-side scalar.
  4064. void set_zero_points(int arg, int mask, const memory::dims &groups,
  4065. memory::data_type data_type = memory::data_type::s32,
  4066. bool is_on_host = false) {
  4067. error::wrap_c_api(dnnl_primitive_attr_set_zero_points_v2(get(), arg,
  4068. mask, (int)groups.size(), groups.data(),
  4069. memory::convert_to_c(data_type), is_on_host),
  4070. "could not set zero points primitive attribute");
  4071. }
  4072. /// Sets a single host-side zero point for the specified memory argument.
  4073. /// The zero point should be passed as a host scalar memory object at
  4074. /// execution time with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  4075. ///
  4076. /// @note Using this API to set the zero point implies that the zero
  4077. /// point attribute has `mask == 0` and an empty groups vector.
  4078. ///
  4079. /// @sa dnnl_primitive_attr_set_zero_points_v2
  4080. ///
  4081. /// @param arg Parameter argument index as passed to the
  4082. /// primitive::execute() call.
  4083. /// @param data_type Zero point data type.
  4084. void set_host_zero_point(
  4085. int arg, memory::data_type data_type = memory::data_type::s32) {
  4086. error::wrap_c_api(
  4087. dnnl_primitive_attr_set_zero_points_v2(get(), arg, 0, 0,
  4088. nullptr, memory::convert_to_c(data_type), 1),
  4089. "could not set zero points primitive attribute");
  4090. }
  4091. /// Sets precomputed reductions for primitive operations for a given memory
  4092. /// argument. The precomputed reductions must be passed at execution time as
  4093. /// an argument with index #DNNL_ARG_ATTR_PRECOMPUTED_REDUCTIONS | arg.
  4094. ///
  4095. /// @sa dnnl_primitive_attr_set_precomputed_reductions
  4096. ///
  4097. /// @param arg Parameter argument index as passed to the
  4098. /// primitive::execute() call.
  4099. /// @param mask Precomputed reductions correspondence mask that defines the
  4100. /// correspondence between the tensor dimensions and the precomputed
  4101. /// reductions vector. The set i-th bit indicates that a dedicated
  4102. /// precomputed reduction point is used for each index along that
  4103. /// dimension.
  4104. /// @param groups Precomputed reduction factors correspondence groups that
  4105. /// define the correspondence between the tensor dimensions and the
  4106. /// precomputed reductions array.
  4107. /// The set i-th dimension indicates a number of groups of precomputed
  4108. /// reduction factors used for that logical dimension in a memory
  4109. /// indicated by @p arg.
  4110. /// @param data_type Precomputed reduction factors data_type.
  4111. void set_precomputed_reductions(int arg, int mask,
  4112. const memory::dims &groups,
  4113. memory::data_type data_type = memory::data_type::s32) {
  4114. error::wrap_c_api(dnnl_primitive_attr_set_precomputed_reductions(get(),
  4115. arg, mask, (int)groups.size(), groups.data(),
  4116. memory::convert_to_c(data_type)),
  4117. "could not set precomputed reductions primitive attribute");
  4118. }
  4119. /// Returns post-ops previously set via set_post_ops().
  4120. ///
  4121. /// @returns Post-ops.
  4122. post_ops get_post_ops() const {
  4123. const_dnnl_post_ops_t const_c_post_ops;
  4124. error::wrap_c_api(
  4125. dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops),
  4126. "could not get post-ops primitive attribute");
  4127. dnnl_post_ops_t c_post_ops;
  4128. error::wrap_c_api(dnnl_post_ops_clone(&c_post_ops, const_c_post_ops),
  4129. "could not clone post-ops primitive attribute");
  4130. return post_ops(c_post_ops);
  4131. }
  4132. /// Sets post-ops.
  4133. ///
  4134. /// @note
  4135. /// There is no way to check whether the post-ops would be supported
  4136. /// by the target primitive. Any error will be reported
  4137. /// by the respective primitive descriptor constructor.
  4138. ///
  4139. /// @param ops Post-ops object to copy post-ops from.
  4140. void set_post_ops(const post_ops &ops) {
  4141. error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()),
  4142. "could not set post-ops primitive attribute");
  4143. }
  4144. /// Sets quantization scale and shift parameters for RNN data tensors.
  4145. ///
  4146. /// For performance reasons, the low-precision configuration of the RNN
  4147. /// primitives expect input activations to have the unsigned 8-bit integer
  4148. /// data type. The scale and shift parameters are used to quantize
  4149. /// floating-point data to unsigned integer and must be passed to the RNN
  4150. /// primitive using attributes.
  4151. ///
  4152. /// The quantization formula is `scale * data + shift`.
  4153. ///
  4154. /// Example usage:
  4155. /// @code
  4156. /// // RNN parameters
  4157. /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
  4158. /// // Activations quantization parameters
  4159. /// float scale = 63.f, shift = 64.f;
  4160. ///
  4161. /// primitive_attr attr;
  4162. ///
  4163. /// // Set scale and shift for int8 quantization of activation
  4164. /// attr.set_rnn_data_qparams(scale, shift);
  4165. ///
  4166. /// // Create an RNN primitive descriptor.
  4167. /// vanilla_rnn_forward::primitive_desc rnn_d(
  4168. /// engine, /* arguments */, attr);
  4169. /// @endcode
  4170. ///
  4171. /// @note
  4172. /// Quantization scale and shift are common for src_layer, src_iter,
  4173. /// dst_iter, and dst_layer.
  4174. ///
  4175. /// @param scale The value to scale the data by.
  4176. /// @param shift The value to shift the data by.
  4177. void set_rnn_data_qparams(float scale, float shift) {
  4178. error::wrap_c_api(
  4179. dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
  4180. "could not set RNN data quantization parameters primitive "
  4181. "attribute");
  4182. }
  4183. /// Returns the quantization scale and shift parameters for RNN data
  4184. /// tensors.
  4185. ///
  4186. /// @note
  4187. /// Quantization scale and shift are common for src_layer, src_iter,
  4188. /// dst_iter, and dst_layer.
  4189. ///
  4190. /// @param scale The value to scale the data by.
  4191. /// @param shift The value to shift the data by.
  4192. void get_rnn_data_qparams(float &scale, float &shift) {
  4193. float c_scale, c_shift;
  4194. error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams(
  4195. get(), &c_scale, &c_shift),
  4196. "could not set RNN data quantization parameters primitive "
  4197. "attribute");
  4198. scale = c_scale;
  4199. shift = c_shift;
  4200. }
  4201. /// Sets quantization scaling factors for RNN weights tensors. The
  4202. /// low-precision configuration of the RNN primitives expect input weights
  4203. /// to use the signed 8-bit integer data type. The scaling factors are
  4204. /// used to quantize floating-point data to signed integer and must be
  4205. /// passed to RNN primitives using attributes.
  4206. ///
  4207. /// @note
  4208. /// The dimension order is always native and does not depend on the
  4209. /// actual layout used. For example, five-dimensional weights always
  4210. /// have (l, d, i, g, o) logical dimension ordering.
  4211. ///
  4212. /// @note
  4213. /// Quantization scales are common for weights_layer and
  4214. /// weights_iteration
  4215. ///
  4216. /// @param mask Scaling factors correspondence mask that defines the
  4217. /// correspondence between the output tensor dimensions and the @p
  4218. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4219. /// factor should be used each index along that dimension. Set the
  4220. /// mask to 0 to use a common scaling factor for the whole output
  4221. /// tensor.
  4222. /// @param scales Constant vector of output scaling factors. The following
  4223. /// equality must hold:
  4224. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4225. /// Violations can only be detected when the attributes are used to
  4226. /// create a primitive descriptor.
  4227. void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
  4228. error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(),
  4229. (int)scales.size(), mask, scales.data()),
  4230. "could not set RNN weights quantization parameters primitive "
  4231. "attribute");
  4232. }
  4233. /// Returns the quantization scaling factors for RNN projection weights
  4234. /// tensors.
  4235. ///
  4236. /// @note
  4237. /// The dimension order is always native and does not depend on the
  4238. /// actual layout used. For example, five-dimensional weights always
  4239. /// have (l, d, i, g, o) logical dimension ordering.
  4240. ///
  4241. /// @param mask Scaling factors correspondence mask that defines the
  4242. /// correspondence between the output tensor dimensions and the @p
  4243. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4244. /// factor should be used each index along that dimension. Set the
  4245. /// mask to 0 to use a common scaling factor for the whole output
  4246. /// tensor.
  4247. /// @param scales Constant vector of output scaling factors. The following
  4248. /// equality must hold:
  4249. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4250. /// Violations can only be detected when the attributes are used to
  4251. /// create a primitive descriptor.
  4252. void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
  4253. dnnl_dim_t count;
  4254. int c_mask;
  4255. const float *c_scales;
  4256. error::wrap_c_api(dnnl_primitive_attr_get_rnn_weights_qparams(
  4257. get(), &count, &c_mask, &c_scales),
  4258. "could not get primitive RNN weights quantization "
  4259. "parameters attributes");
  4260. scales.resize(count);
  4261. mask = c_mask;
  4262. for (dnnl_dim_t c = 0; c < count; c++)
  4263. scales[c] = c_scales[c];
  4264. }
  4265. /// Sets quantization scaling factors for RNN projection weights tensors.
  4266. // The low-precision configuration of the RNN primitives expect input
  4267. // weights to use the signed 8-bit integer data type. The scaling factors
  4268. // are used to quantize floating-point data to signed integer and must be
  4269. /// passed to RNN primitives using attributes.
  4270. ///
  4271. /// @note
  4272. /// The dimension order is always native and does not depend on the
  4273. /// actual layout used. For example, five-dimensional weights always
  4274. /// have (l, d, i, g, o) logical dimension ordering.
  4275. ///
  4276. /// @note
  4277. /// Quantization scales are common for weights_layer and
  4278. /// weights_iteration
  4279. ///
  4280. /// @param mask Scaling factors correspondence mask that defines the
  4281. /// correspondence between the output tensor dimensions and the @p
  4282. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4283. /// factor should be used each index along that dimension. Set the
  4284. /// mask to 0 to use a common scaling factor for the whole output
  4285. /// tensor.
  4286. /// @param scales Constant vector of output scaling factors. The following
  4287. /// equality must hold:
  4288. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4289. /// Violations can only be detected when the attributes are used to
  4290. /// create a primitive descriptor.
  4291. void set_rnn_weights_projection_qparams(
  4292. int mask, const std::vector<float> &scales) {
  4293. error::wrap_c_api(
  4294. dnnl_primitive_attr_set_rnn_weights_projection_qparams(
  4295. get(), (int)scales.size(), mask, scales.data()),
  4296. "could not set primitive RNN weights projection quantization "
  4297. "parameters attributes");
  4298. }
  4299. /// Returns the quantization scaling factors for RNN projection weights
  4300. /// tensors.
  4301. ///
  4302. /// @note
  4303. /// The dimension order is always native and does not depend on the
  4304. /// actual layout used. For example, five-dimensional weights always
  4305. /// have (l, d, i, g, o) logical dimension ordering.
  4306. ///
  4307. /// @param mask Scaling factors correspondence mask that defines the
  4308. /// correspondence between the output tensor dimensions and the @p
  4309. /// scales vector. The set i-th bit indicates that a dedicated scaling
  4310. /// factor should be used each index along that dimension. Set the
  4311. /// mask to 0 to use a common scaling factor for the whole output
  4312. /// tensor.
  4313. /// @param scales Constant vector of output scaling factors. The following
  4314. /// equality must hold:
  4315. /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
  4316. /// Violations can only be detected when the attributes are used to
  4317. /// create a primitive descriptor.
  4318. void get_rnn_weights_projection_qparams(
  4319. int &mask, std::vector<float> &scales) {
  4320. dnnl_dim_t count;
  4321. int c_mask;
  4322. const float *c_scales;
  4323. error::wrap_c_api(
  4324. dnnl_primitive_attr_get_rnn_weights_projection_qparams(
  4325. get(), &count, &c_mask, &c_scales),
  4326. "could not get primitive RNN weights projection quantization "
  4327. "parameters attributes");
  4328. scales.resize(count);
  4329. mask = c_mask;
  4330. for (dnnl_dim_t c = 0; c < count; c++)
  4331. scales[c] = c_scales[c];
  4332. }
  4333. };
  4334. /// @} dnnl_api_attributes
  4335. /// @addtogroup dnnl_api_primitives_common
  4336. /// @{
  4337. /// Base class for all primitive descriptors.
  4338. struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
  // Inherit the base handle constructors so a primitive descriptor can be
  // constructed directly from a C API ::dnnl_primitive_desc_t handle.
  using handle<dnnl_primitive_desc_t>::handle;
  /// Default constructor. Produces an empty object that does not wrap any
  /// C API primitive descriptor.
  primitive_desc_base() = default;
  4342. /// Returns the engine of the primitive descriptor.
  4343. /// @returns The engine of the primitive descriptor.
  4344. engine get_engine() const { return query_engine(query::engine); }
  4345. /// Returns implementation name.
  4346. /// @returns The implementation name.
  4347. const char *impl_info_str() const {
  4348. const char *res;
  4349. error::wrap_c_api(dnnl_primitive_desc_query(
  4350. get(), dnnl_query_impl_info_str, 0, &res),
  4351. "could not retrieve implementation info string from a "
  4352. "primitive descriptor");
  4353. return res;
  4354. }
  4355. /// Returns a memory::dim value (same as int64_t).
  4356. /// @param what The value to query.
  4357. /// @returns The result of the query.
  4358. memory::dim query_s64(query what) const {
  4359. memory::dim res;
  4360. dnnl_status_t status = dnnl_primitive_desc_query(
  4361. get(), dnnl::convert_to_c(what), 0, &res);
  4362. return status == dnnl_success ? res : 0;
  4363. }
  4364. /// Returns strides.
  4365. /// @returns Strides.
  4366. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4367. /// a strides parameter.
  4368. memory::dims get_strides() const { return query_dims(query::strides); }
  4369. /// Returns dilations.
  4370. /// @returns Dilations.
  4371. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4372. /// a dilations parameter.
  4373. memory::dims get_dilations() const { return query_dims(query::dilations); }
  4374. /// Returns a left padding.
  4375. /// @returns A left padding.
  4376. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4377. /// a left padding parameter.
  4378. memory::dims get_padding_l() const { return query_dims(query::padding_l); }
  4379. /// Returns a right padding.
  4380. /// @returns A right padding.
  4381. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4382. /// a right padding parameter.
  4383. memory::dims get_padding_r() const { return query_dims(query::padding_r); }
  4384. /// Returns an epsilon.
  4385. /// @returns An epsilon.
  4386. /// @returns Zero if the primitive does not have an epsilon parameter.
  4387. float get_epsilon() const { return query_f32(query::epsilon_f32); }
  4388. /// Returns flags.
  4389. /// @tparam T Flags enumeration type.
  4390. /// @returns Flags.
  4391. /// @returns Zero if the primitive does not have a flags parameter.
  4392. template <typename T = unsigned>
  4393. T get_flags() const {
  4394. unsigned res;
  4395. dnnl_status_t status
  4396. = dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res);
  4397. return static_cast<T>(status == dnnl_success ? res : 0x0U);
  4398. }
  4399. /// Returns an algorithm kind.
  4400. /// @returns An algorithm kind.
  4401. /// @returns #dnnl::algorithm::undef if the primitive does not have an
  4402. /// algorithm parameter.
  4403. dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); }
  4404. /// Returns an alpha.
  4405. /// @returns An alpha.
  4406. /// @returns Zero if the primitive does not have an alpha parameter.
  4407. float get_alpha() const { return query_f32(query::alpha_f32); }
  4408. /// Returns a beta.
  4409. /// @returns A beta.
  4410. /// @returns Zero if the primitive does not have a beta parameter.
  4411. float get_beta() const { return query_f32(query::beta_f32); }
  4412. /// Returns an axis.
  4413. /// @returns An axis.
  4414. /// @returns A negative number if the primitive does not have an axis
  4415. /// parameter.
  4416. int get_axis() const {
  4417. int res;
  4418. dnnl_status_t status = dnnl_primitive_desc_query(
  4419. get(), dnnl_query_axis_s32, 0, &res);
  4420. return status == dnnl_success ? res : -1;
  4421. }
  4422. /// Returns an LRN local size parameter.
  4423. /// @returns An LRN local size parameter.
  4424. /// @returns Zero if the primitive does not have an LRN local size
  4425. /// parameter.
  4426. memory::dim get_local_size() const {
  4427. return query_s64(query::local_size_s64);
  4428. }
  4429. /// Returns an LRN K parameter.
  4430. /// @returns An LRN K parameter.
  4431. /// @returns Zero if the primitive does not have an LRN K parameter.
  4432. float get_k() const { return query_f32(query::k_f32); }
  4433. /// Returns a reduction P parameter.
  4434. /// @returns A reduction P parameter.
  4435. /// @returns Zero if the primitive does not have a reduction P parameter.
  4436. float get_p() const { return query_f32(query::p_f32); }
  4437. /// Returns a resampling factors parameters.
  4438. /// @returns A vector of factors.
  4439. /// @returns An empty vector if the primitive does not have a resampling
  4440. /// factors parameter.
  4441. std::vector<float> get_factors() const {
  4442. float *factors;
  4443. dnnl_status_t status = dnnl_primitive_desc_query(
  4444. get(), dnnl_query_factors, 0, &factors);
  4445. const bool is_backward = get_prop_kind() != prop_kind::forward_training
  4446. && get_prop_kind() != prop_kind::forward_inference;
  4447. const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
  4448. is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
  4449. int ndims;
  4450. error::wrap_c_api(
  4451. dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
  4452. "could not query ndims from a memory descriptor");
  4453. return status == dnnl_success
  4454. ? std::vector<float>(factors, factors + (ndims - 2))
  4455. : std::vector<float> {};
  4456. }
  4457. /// Returns an RNN cell kind parameter.
  4458. /// @returns An RNN cell kind parameter.
  4459. /// @returns #dnnl::algorithm::undef if the primitive does not have an
  4460. /// RNN cell kind parameter.
  4461. dnnl::algorithm get_cell_kind() const {
  4462. return query_alg(query::cell_kind);
  4463. }
  4464. /// Returns an RNN direction parameter.
  4465. /// @returns An RNN direction parameter.
  4466. /// @returns #dnnl::rnn_direction::undef if the primitive does not have
  4467. /// an RNN direction parameter.
  4468. dnnl::rnn_direction get_direction() const {
  4469. dnnl_rnn_direction_t direction;
  4470. dnnl_status_t status = dnnl_primitive_desc_query(
  4471. get(), dnnl_query_direction, 0, &direction);
  4472. return status == dnnl_success
  4473. ? static_cast<dnnl::rnn_direction>(direction)
  4474. : dnnl::rnn_direction::undef;
  4475. }
  4476. /// Returns an RNN activation kind parameter.
  4477. /// @returns An RNN activation kind parameter.
  4478. /// @returns #dnnl::algorithm::undef if the primitive does not have an
  4479. /// RNN activation kind parameter.
  4480. dnnl::algorithm get_activation_kind() const {
  4481. return query_alg(query::activation_kind);
  4482. }
  4483. /// Returns a pooling kernel parameter.
  4484. /// @returns A pooling kernel parameter.
  4485. /// @returns An empty #dnnl::memory::dims if the primitive does not have
  4486. /// a pooling kernel parameter.
  4487. memory::dims get_kernel() const { return query_dims(query::kernel); }
  4488. /// Returns a group size parameter.
  4489. /// @returns A group size parameter.
  4490. /// @returns Zero if the primitive does not have a group size
  4491. /// parameter.
  4492. memory::dim get_group_size() const {
  4493. return query_s64(query::group_size_s64);
  4494. }
  4495. /// Returns a propagation kind.
  4496. /// @returns A propagation kind.
  4497. /// @returns #dnnl::prop_kind::undef if the primitive does not have
  4498. /// a propagation parameter.
  4499. dnnl::prop_kind get_prop_kind() const {
  4500. dnnl_prop_kind_t prop_kind;
  4501. dnnl_status_t status = dnnl_primitive_desc_query(
  4502. get(), dnnl_query_prop_kind, 0, &prop_kind);
  4503. return status == dnnl_success ? static_cast<dnnl::prop_kind>(prop_kind)
  4504. : dnnl::prop_kind::undef;
  4505. }
  4506. /// Returns a memory descriptor.
  4507. ///
  4508. /// @note
  4509. /// There are also convenience methods
  4510. /// #dnnl::primitive_desc_base::src_desc(),
  4511. /// #dnnl::primitive_desc_base::dst_desc(), and others.
  4512. ///
  4513. /// @param what The kind of parameter to query; can be
  4514. /// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
  4515. /// @param idx Index of the parameter. For example, convolution bias can
  4516. /// be queried with what = #dnnl::query::weights_md and idx = 1.
  4517. /// @returns The requested memory descriptor.
  4518. /// @returns A zero memory descriptor if the primitive does not have a
  4519. /// parameter of the specified kind or index.
  4520. memory::desc query_md(query what, int idx = 0) const {
  4521. std::vector<query> valid_q {query::src_md, query::diff_src_md,
  4522. query::weights_md, query::diff_weights_md, query::dst_md,
  4523. query::diff_dst_md, query::workspace_md, query::scratchpad_md,
  4524. query::exec_arg_md};
  4525. if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
  4526. [=](query q) { return what == q; }))
  4527. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4528. "memory descriptor query is invalid");
  4529. const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
  4530. get(), dnnl::convert_to_c(what), idx);
  4531. if (!cdesc) return memory::desc();
  4532. dnnl_memory_desc_t cloned_md = nullptr;
  4533. error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
  4534. "could not clone a memory descriptor");
  4535. return memory::desc(cloned_md);
  4536. }
  4537. /// Returns a source memory descriptor.
  4538. /// @param idx Source index.
  4539. /// @returns Source memory descriptor.
  4540. /// @returns A zero memory descriptor if the primitive does not have a
  4541. /// source parameter with index @p idx.
  4542. memory::desc src_desc(int idx) const {
  4543. return query_md(query::src_md, idx);
  4544. }
  4545. /// Returns a destination memory descriptor.
  4546. /// @param idx Destination index.
  4547. /// @returns Destination memory descriptor.
  4548. /// @returns A zero memory descriptor if the primitive does not have a
  4549. /// destination parameter with index @p idx.
  4550. memory::desc dst_desc(int idx) const {
  4551. return query_md(query::dst_md, idx);
  4552. }
  4553. /// Returns a weights memory descriptor.
  4554. /// @param idx Weights index.
  4555. /// @returns Weights memory descriptor.
  4556. /// @returns A zero memory descriptor if the primitive does not have a
  4557. /// weights parameter with index @p idx.
  4558. memory::desc weights_desc(int idx) const {
  4559. return query_md(query::weights_md, idx);
  4560. }
  4561. /// Returns a diff source memory descriptor.
  4562. /// @param idx Diff source index.
  4563. /// @returns Diff source memory descriptor.
  4564. /// @returns A zero memory descriptor if the primitive does not have a
  4565. /// diff source parameter with index @p idx.
  4566. memory::desc diff_src_desc(int idx) const {
  4567. return query_md(query::diff_src_md, idx);
  4568. }
  4569. /// Returns a diff destination memory descriptor.
  4570. /// @param idx Diff destination index.
  4571. /// @returns Diff destination memory descriptor.
  4572. /// @returns A zero memory descriptor if the primitive does not have a
  4573. /// diff destination parameter with index @p idx.
  4574. memory::desc diff_dst_desc(int idx) const {
  4575. return query_md(query::diff_dst_md, idx);
  4576. }
  4577. /// Returns a diff weights memory descriptor.
  4578. /// @param idx Diff weights index.
  4579. /// @returns Diff weights memory descriptor.
  4580. /// @returns A zero memory descriptor if the primitive does not have a
  4581. /// diff weights parameter with index @p idx.
  4582. memory::desc diff_weights_desc(int idx) const {
  4583. return query_md(query::diff_weights_md, idx);
  4584. }
  4585. // Separate versions without the index argument for documentation
  4586. // purposes.
  4587. /// Returns a source memory descriptor.
  4588. /// @returns Source memory descriptor.
  4589. /// @returns A zero memory descriptor if the primitive does not have a
  4590. /// source parameter.
  4591. memory::desc src_desc() const { return src_desc(0); }
  4592. /// Returns a destination memory descriptor.
  4593. /// @returns Destination memory descriptor.
  4594. /// @returns A zero memory descriptor if the primitive does not have a
  4595. /// destination parameter.
  4596. memory::desc dst_desc() const { return dst_desc(0); }
  4597. /// Returns a weights memory descriptor.
  4598. /// @returns Weights memory descriptor.
  4599. /// @returns A zero memory descriptor if the primitive does not have a
  4600. /// weights parameter.
  4601. memory::desc weights_desc() const { return weights_desc(0); }
  4602. /// Returns a diff source memory descriptor.
  4603. /// @returns Diff source memory descriptor.
  4604. /// @returns A zero memory descriptor if the primitive does not have a
  4605. /// diff source memory with.
  4606. memory::desc diff_src_desc() const { return diff_src_desc(0); }
  4607. /// Returns a diff destination memory descriptor.
  4608. /// @returns Diff destination memory descriptor.
  4609. /// @returns A zero memory descriptor if the primitive does not have a
  4610. /// diff destination parameter.
  4611. memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
  4612. /// Returns a diff weights memory descriptor.
  4613. /// @returns Diff weights memory descriptor.
  4614. /// @returns A zero memory descriptor if the primitive does not have a
  4615. /// diff weights parameter.
  4616. memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
  4617. /// Returns the workspace memory descriptor.
  4618. /// @returns Workspace memory descriptor.
  4619. /// @returns A zero memory descriptor if the primitive does not require
  4620. /// workspace parameter.
  4621. memory::desc workspace_desc() const {
  4622. return query_md(query::workspace_md, 0);
  4623. }
  4624. /// Returns the scratchpad memory descriptor.
  4625. /// @returns scratchpad memory descriptor.
  4626. /// @returns A zero memory descriptor if the primitive does not require
  4627. /// scratchpad parameter.
  4628. /// @sa @ref dev_guide_attributes_scratchpad
  4629. memory::desc scratchpad_desc() const {
  4630. return query_md(query::scratchpad_md, 0);
  4631. }
  4632. /// Returns the engine on which the scratchpad memory is located.
  4633. /// @returns The engine on which the scratchpad memory is located.
  4634. engine scratchpad_engine() const {
  4635. dnnl_engine_t c_engine;
  4636. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4637. dnnl::convert_to_c(query::scratchpad_engine),
  4638. 0, &c_engine),
  4639. "could not retrieve scratchpad engine from a primitive "
  4640. "descriptor");
  4641. return engine(c_engine, true);
  4642. }
  4643. /// Returns the primitive attributes.
  4644. /// @returns The primitive attributes.
  4645. primitive_attr get_primitive_attr() const {
  4646. const_dnnl_primitive_attr_t const_c_attr;
  4647. error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
  4648. "could not get attributes from a primitive descriptor");
  4649. dnnl_primitive_attr_t c_attr;
  4650. error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
  4651. "could not clone primitive attributes");
  4652. return primitive_attr(c_attr);
  4653. }
  4654. /// Returns the kind of the primitive descriptor.
  4655. /// @returns The kind of the primitive descriptor.
  4656. dnnl::primitive::kind get_kind() const {
  4657. dnnl_primitive_kind_t kind;
  4658. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4659. dnnl_query_primitive_kind, 0, (void *)&kind),
  4660. "could not get primitive kind from a primitive descriptor");
  4661. return static_cast<dnnl::primitive::kind>(kind);
  4662. }
  4663. /// Returns the cache blob ID of the primitive descriptor.
  4664. /// @returns The cache blob ID of the primitive descriptor.
  4665. std::vector<uint8_t> get_cache_blob_id() const {
  4666. dnnl_dim_t count;
  4667. const uint8_t *c_id;
  4668. error::wrap_c_api(
  4669. dnnl_primitive_desc_query(get(),
  4670. dnnl::convert_to_c(query::cache_blob_id_size_s64), 0,
  4671. (void *)&count),
  4672. "could not get size of cache blob ID from a primitive "
  4673. "descriptor");
  4674. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4675. dnnl::convert_to_c(query::cache_blob_id), 0,
  4676. (void **)&c_id),
  4677. "could not get cache blob ID from a primitive descriptor");
  4678. std::vector<uint8_t> id(c_id, c_id + count);
  4679. return id;
  4680. }
  4681. protected:
  4682. /// Returns a float value.
  4683. /// @param what The value to query.
  4684. /// @returns The result of the query.
  4685. /// @returns Zero if the primitive doesn't support the query.
  4686. float query_f32(query what) const {
  4687. float res;
  4688. dnnl_status_t status = dnnl_primitive_desc_query(
  4689. get(), dnnl::convert_to_c(what), 0, &res);
  4690. return status == dnnl_success ? res : 0.0f;
  4691. }
  4692. /// Returns an #dnnl::algorithm value.
  4693. /// @param what The value to query.
  4694. /// @returns The result of the query.
  4695. /// @returns #dnnl::algorithm::undef if the primitive doesn't support
  4696. /// the query.
  4697. algorithm query_alg(query what) const {
  4698. dnnl_alg_kind_t res;
  4699. dnnl_status_t status = dnnl_primitive_desc_query(
  4700. get(), dnnl::convert_to_c(what), 0, &res);
  4701. return status == dnnl_success ? static_cast<dnnl::algorithm>(res)
  4702. : algorithm::undef;
  4703. }
  4704. /// Returns a memory::dims value.
  4705. /// @param what The value to query.
  4706. /// @returns The result of the query.
  4707. /// @returns An empty #dnnl::memory::dims if the primitive doesn't support
  4708. /// the query.
  4709. memory::dims query_dims(query what) const {
  4710. const bool is_backward = get_prop_kind() != prop_kind::forward_training
  4711. && get_prop_kind() != prop_kind::forward_inference;
  4712. const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
  4713. is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
  4714. int nspatial_dims = 0;
  4715. if (md) {
  4716. int ndims;
  4717. error::wrap_c_api(
  4718. dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
  4719. "could not query ndims from a memory descriptor");
  4720. nspatial_dims = ndims - 2;
  4721. }
  4722. dnnl_dims_t *c_dims;
  4723. dnnl_status_t status = dnnl_primitive_desc_query(
  4724. get(), dnnl::convert_to_c(what), 0, &c_dims);
  4725. return status == dnnl_success
  4726. ? memory::dims(*c_dims, *c_dims + nspatial_dims)
  4727. : memory::dims {};
  4728. }
  4729. /// Returns an #dnnl::engine value.
  4730. /// @param what The value to query.
  4731. /// @returns The result of the query.
  4732. /// @returns A weak handle to the engine that the primitive descriptor was
  4733. /// created with.
  4734. engine query_engine(query what) const {
  4735. dnnl_engine_t c_engine;
  4736. error::wrap_c_api(dnnl_primitive_desc_query(get(),
  4737. dnnl::convert_to_c(what), 0, &c_engine),
  4738. "could not get an engine from a primitive_desc");
  4739. return engine(c_engine, true);
  4740. }
  4741. /// Resets the value of the handle to a clone of a C API primitive
  4742. /// descriptor.
  4743. /// @param pd A C API primitive descriptor to clone.
  4744. void reset_with_clone(const_dnnl_primitive_desc_t pd) {
  4745. dnnl_primitive_desc_t new_pd;
  4746. error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd),
  4747. "could not clone a primitive descriptor");
  4748. reset(new_pd);
  4749. }
  4750. /// Constructs a primitive descriptor base object from a clone of a C API
  4751. /// primitive descriptor after verifying that it is what the caller
  4752. /// expects.
  4753. ///
  4754. /// @note
  4755. /// The @p prim_kind should map to a primitive that does not have
  4756. /// different values of propagation kind (e.g. #dnnl::binary).
  4757. /// @note
  4758. /// Primitive descriptor base constructed this way does not support
  4759. /// next_impl() (will throw).
  4760. ///
  4761. /// @param pd C API primitive descriptor to clone.
  4762. /// @param prim_kind Expected primitive kind.
  4763. primitive_desc_base(
  4764. dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
  4765. : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
  4766. /// Constructs a primitive descriptor base object from a clone of a C API
  4767. /// primitive descriptor after verifying that it is what the caller
  4768. /// expects.
  4769. ///
  4770. /// @note
  4771. /// Primitive descriptor base constructed this way does not support
  4772. /// next_impl() (will throw).
  4773. ///
  4774. /// @param pd C API primitive descriptor to clone.
  4775. /// @param prim_kind Expected primitive kind.
  4776. /// @param aprop_kind Expected propagation kind.
  4777. primitive_desc_base(dnnl_primitive_desc_t pd,
  4778. dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
  4779. : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
  4780. /// Constructs a primitive descriptor base object from a clone of a C API
  4781. /// primitive descriptor after verifying that it is what the caller
  4782. /// expects.
  4783. ///
  4784. /// @note
  4785. /// Primitive descriptor base constructed this way does not support
  4786. /// next_impl() (will throw).
  4787. ///
  4788. /// @param pd C API primitive descriptor to clone.
  4789. /// @param prim_kind Expected primitive kind.
  4790. /// @param prop_kind1 Expected propagation kind (option 1).
  4791. /// @param prop_kind2 Expected propagation kind (option 2). This value is
  4792. /// checked if the check with @p prop_kind1 fails.
  4793. primitive_desc_base(dnnl_primitive_desc_t pd,
  4794. dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
  4795. dnnl::prop_kind prop_kind2) {
  4796. // It is OK to pass an empty primitive descriptor
  4797. if (pd == nullptr) return;
  4798. dnnl_status_t rc;
  4799. dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
  4800. dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
  4801. dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
  4802. // Check that primitive kind matches
  4803. dnnl_primitive_kind_t pd_kind;
  4804. rc = dnnl_primitive_desc_query(
  4805. pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
  4806. error::wrap_c_api(
  4807. rc, "could not get primitive kind from a primitive descriptor");
  4808. if (pd_kind != c_prim_kind)
  4809. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4810. "primitive descriptor operation kind mismatch");
  4811. // Check that propagation kind matches
  4812. dnnl_prop_kind_t pd_prop_kind;
  4813. rc = dnnl_primitive_desc_query(
  4814. pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
  4815. // Something went wrong
  4816. if (rc != dnnl_success && rc != dnnl_unimplemented)
  4817. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4818. "could not get propagation kind from the primitive "
  4819. "descriptor");
  4820. // Everything is fine
  4821. if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
  4822. || (rc == dnnl_success
  4823. && (pd_prop_kind == c_prop_kind1
  4824. || pd_prop_kind == c_prop_kind2))) {
  4825. reset_with_clone(pd);
  4826. return;
  4827. }
  4828. // We could get the propagation kind but there is a mismatch
  4829. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  4830. "primitive descriptor propagation kind mismatch");
  4831. }
  4832. /// Returns a constant reference to a static instance of default constructed
  4833. /// primitive attributes
  4834. static const primitive_attr &default_attr() {
  4835. static const primitive_attr attr;
  4836. return attr;
  4837. }
  4838. const_dnnl_memory_desc_t optional_arg(const memory::desc *md) {
  4839. return md ? md->get() : nullptr;
  4840. }
  4841. const dnnl_dim_t *optional_arg(const memory::dims *dims) {
  4842. return dims ? dims->data() : nullptr;
  4843. }
  // Maps an optional float vector to a pointer to its first element, or
  // nullptr if absent.
  const float *optional_arg(const std::vector<float> *arg) {
      if (arg == nullptr) return nullptr;
      return arg->data();
  }
  4847. using base = primitive_desc_base;
  4848. };
  4849. /// @} dnnl_api_primitives_common
  4850. /// @addtogroup dnnl_api_reorder Reorder
  4851. ///
  4852. /// A primitive to copy data between two memory objects. This primitive is
  4853. /// typically used to change the way the data is laid out in memory.
  4854. ///
  4855. /// @sa @ref dev_guide_reorder in developer guide
  4856. ///
  4857. /// @{
  4858. /// Reorder primitive.
  4859. struct reorder : public primitive {
  4860. /// Primitive descriptor for a reorder primitive.
  4861. struct primitive_desc : public primitive_desc_base {
  4862. using primitive_desc_base::primitive_desc_base;
  4863. /// Default constructor. Produces an empty object.
  4864. primitive_desc() = default;
  4865. /// Constructs a primitive descriptor for reorder primitive.
  4866. ///
  4867. /// @note
  4868. /// If @p allow_empty is true, the constructor does not throw if a
  4869. /// primitive descriptor cannot be created.
  4870. ///
  4871. /// @param src_engine Engine on which the source memory object will be
  4872. /// located.
  4873. /// @param src_md Source memory descriptor.
  4874. /// @param dst_engine Engine on which the destination memory object
  4875. /// will be located.
  4876. /// @param dst_md Destination memory descriptor.
  4877. /// @param attr Primitive attributes to use. Attributes are optional
  4878. /// and default to empty attributes.
  4879. /// @param allow_empty A flag signifying whether construction is allowed
  4880. /// to fail without throwing an exception. In this case an empty
  4881. /// object will be produced. This flag is optional and defaults to
  4882. /// false.
  4883. primitive_desc(const engine &src_engine, const memory::desc &src_md,
  4884. const engine &dst_engine, const memory::desc &dst_md,
  4885. const primitive_attr &attr = default_attr(),
  4886. bool allow_empty = false) {
  4887. dnnl_primitive_desc_t result;
  4888. dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
  4889. src_md.get(), src_engine.get(), dst_md.get(),
  4890. dst_engine.get(), attr.get());
  4891. if (!allow_empty)
  4892. error::wrap_c_api(status,
  4893. "could not create a primitive descriptor for "
  4894. "the reorder primitive. Run workload with "
  4895. "environment variable ONEDNN_VERBOSE=all to get "
  4896. "additional diagnostic information.");
  4897. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  4898. }
  4899. /// Constructs a primitive descriptor for reorder primitive.
  4900. ///
  4901. /// @param src Source memory object. It is used to obtain the source
  4902. /// memory descriptor and engine.
  4903. /// @param dst Destination memory object. It is used to obtain the
  4904. /// destination memory descriptor and engine.
  4905. /// @param attr Primitive attributes to use. Attributes are optional
  4906. /// and default to empty attributes.
  4907. /// @param allow_empty A flag signifying whether construction is allowed
  4908. /// to fail without throwing an exception. In this case an empty
  4909. /// object will be produced. This flag is optional and defaults to
  4910. /// false.
  4911. primitive_desc(const memory &src, const memory &dst,
  4912. const primitive_attr &attr = default_attr(),
  4913. bool allow_empty = false) {
  4914. dnnl_primitive_desc_t result;
  4915. auto src_md = src.get_desc();
  4916. auto dst_md = dst.get_desc();
  4917. dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
  4918. src_md.get(), src.get_engine().get(), dst_md.get(),
  4919. dst.get_engine().get(), attr.get());
  4920. if (!allow_empty)
  4921. error::wrap_c_api(status,
  4922. "could not create a primitive descriptor for "
  4923. "the reorder primitive. Run workload with "
  4924. "environment variable ONEDNN_VERBOSE=all to get "
  4925. "additional diagnostic information.");
  4926. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  4927. }
  4928. /// Constructs a primitive descriptor for reorder primitive from a C
  4929. /// API primitive descriptor which must have a matching kind.
  4930. ///
  4931. /// @param pd C API primitive descriptor for reorder primitive.
  4932. primitive_desc(dnnl_primitive_desc_t pd)
  4933. : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
  4934. /// Returns the engine on which the source memory is allocated.
  4935. /// @returns The engine on which the source memory is allocated.
  4936. engine get_src_engine() const {
  4937. return query_engine(dnnl::query::reorder_src_engine);
  4938. }
  4939. /// Returns the engine on which the destination memory is allocated.
  4940. /// @returns The engine on which the destination memory is allocated.
  4941. engine get_dst_engine() const {
  4942. return query_engine(dnnl::query::reorder_dst_engine);
  4943. }
  4944. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  4945. memory::desc src_desc() const { return base::src_desc(0); }
  4946. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  4947. memory::desc dst_desc() const { return base::dst_desc(0); }
  4948. };
  4949. /// Default constructor. Produces an empty object.
  4950. reorder() = default;
  4951. /// Constructs a reorder primitive.
  4952. /// @param pd Primitive descriptor for reorder primitive.
  4953. reorder(const primitive_desc &pd) : primitive(pd.get()) {}
  4954. /// Constructs a reorder primitive from a cache blob.
  4955. /// @param pd Primitive descriptor for reorder primitive.
  4956. /// @param cache_blob Cache blob.
  4957. reorder(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  4958. : primitive(pd.get(), cache_blob) {}
  4959. /// Constructs a reorder primitive that would reorder data between memory
  4960. /// objects having the same memory descriptors as memory objects @p src and
  4961. /// @p dst.
  4962. ///
  4963. /// @param src Source memory object.
  4964. /// @param dst Destination memory object.
  4965. /// @param attr Primitive attributes to use (optional).
  4966. reorder(const memory &src, const memory &dst,
  4967. const primitive_attr &attr = primitive_attr())
  4968. : primitive(primitive_desc(src, dst, attr).get()) {}
  4969. using primitive::execute;
  4970. /// Executes the reorder primitive.
  4971. ///
  4972. /// @param astream Stream object. The stream must belong to the same engine
  4973. /// as the primitive.
  4974. /// @param src Source memory object.
  4975. /// @param dst Destination memory object.
  4976. void execute(const stream &astream, memory &src, memory &dst) const {
  4977. primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
  4978. }
  4979. };
  4980. /// @} dnnl_api_reorder
  4981. /// @addtogroup dnnl_api_concat Concat
  4982. ///
  /// A primitive to concatenate data along an arbitrary dimension.
  4984. ///
  4985. /// @sa @ref dev_guide_concat in developer guide
  4986. ///
  4987. /// @{
  4988. /// @cond DO_NOT_DOCUMENT_THIS
  4989. inline std::vector<const_dnnl_memory_desc_t> convert_to_c(
  4990. const std::vector<memory::desc> &mds) {
  4991. std::vector<const_dnnl_memory_desc_t> c_mds;
  4992. c_mds.reserve(mds.size());
  4993. for (const auto &md : mds)
  4994. c_mds.push_back(md.get());
  4995. return c_mds;
  4996. }
  4997. /// @endcond
  4998. /// Tensor concatenation (concat) primitive.
  4999. struct concat : public primitive {
  5000. /// Primitive descriptor for a concat primitive.
  5001. struct primitive_desc : public primitive_desc_base {
  5002. using primitive_desc_base::primitive_desc_base;
  5003. /// Default constructor. Produces an empty object.
  5004. primitive_desc() = default;
  5005. /// Constructs a primitive descriptor for an out-of-place concatenation
  5006. /// primitive.
  5007. ///
  5008. /// @param aengine Engine to perform the operation on.
  5009. /// @param dst Destination memory descriptor.
  5010. /// @param concat_dimension Source tensors will be concatenated over
  5011. /// dimension with this index. Note that order of dimensions does
  5012. /// not depend on memory format.
  5013. /// @param srcs Vector of source memory descriptors.
  5014. /// @param attr Primitive attributes to use. Attributes are optional
  5015. /// and default to empty attributes.
  5016. /// @param allow_empty A flag signifying whether construction is
  5017. /// allowed to fail without throwing an exception. In this case an
  5018. /// empty object will be produced. This flag is optional and
  5019. /// defaults to false.
  5020. primitive_desc(const engine &aengine, const memory::desc &dst,
  5021. int concat_dimension, const std::vector<memory::desc> &srcs,
  5022. const primitive_attr &attr = default_attr(),
  5023. bool allow_empty = false) {
  5024. auto c_srcs = convert_to_c(srcs);
  5025. dnnl_primitive_desc_t result;
  5026. dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
  5027. aengine.get(), dst.get(), (int)c_srcs.size(),
  5028. concat_dimension, c_srcs.data(), attr.get());
  5029. if (!allow_empty)
  5030. error::wrap_c_api(status,
  5031. "could not create a primitive descriptor for "
  5032. "the concat primitive. Run workload with "
  5033. "environment variable ONEDNN_VERBOSE=all to get "
  5034. "additional diagnostic information.");
  5035. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  5036. }
  5037. /// Constructs a primitive descriptor for an out-of-place concatenation
  5038. /// primitive.
  5039. ///
  5040. /// This version derives the destination memory descriptor
  5041. /// automatically.
  5042. ///
  5043. /// @param aengine Engine to perform the operation on.
  5044. /// @param concat_dimension Source tensors will be concatenated over
  5045. /// dimension with this index. Note that order of dimensions does
  5046. /// not depend on memory format.
  5047. /// @param srcs Vector of source memory descriptors.
  5048. /// @param attr Primitive attributes to use. Attributes are optional
  5049. /// and default to empty attributes.
  5050. /// @param allow_empty A flag signifying whether construction is
  5051. /// allowed to fail without throwing an exception. In this case an
  5052. /// empty object will be produced. This flag is optional and
  5053. /// defaults to false.
  5054. primitive_desc(const engine &aengine, int concat_dimension,
  5055. const std::vector<memory::desc> &srcs,
  5056. const primitive_attr &attr = default_attr(),
  5057. bool allow_empty = false) {
  5058. auto c_api_srcs = convert_to_c(srcs);
  5059. dnnl_primitive_desc_t result;
  5060. dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
  5061. aengine.get(), nullptr, (int)c_api_srcs.size(),
  5062. concat_dimension, c_api_srcs.data(), attr.get());
  5063. if (!allow_empty)
  5064. error::wrap_c_api(status,
  5065. "could not create a primitive descriptor for "
  5066. "the concat primitive. Run workload with "
  5067. "environment variable ONEDNN_VERBOSE=all to get "
  5068. "additional diagnostic information.");
  5069. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  5070. }
  5071. /// Constructs a primitive descriptor for concat primitive from a C
  5072. /// API primitive descriptor which must have a matching kind.
  5073. ///
  5074. /// @param pd C API primitive descriptor for concat primitive.
  5075. primitive_desc(dnnl_primitive_desc_t pd)
  5076. : primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
  5077. /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
  5078. memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
  5079. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  5080. memory::desc dst_desc() const { return base::dst_desc(0); }
  5081. };
  5082. /// Default constructor. Produces an empty object.
  5083. concat() = default;
  5084. /// Constructs a concatenation primitive.
  5085. /// @param pd Primitive descriptor for concatenation primitive.
  5086. concat(const primitive_desc &pd) : primitive(pd.get()) {}
  5087. /// Constructs a concatenation primitive from a cache blob.
  5088. /// @param pd Primitive descriptor for concatenation primitive.
  5089. /// @param cache_blob Cache blob.
  5090. concat(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5091. : primitive(pd.get(), cache_blob) {}
  5092. };
  5093. /// @} dnnl_api_concat
  5094. /// @addtogroup dnnl_api_sum Sum
  5095. ///
  5096. /// A primitive to sum multiple tensors.
  5097. ///
  5098. /// @sa @ref dev_guide_sum in developer guide
  5099. ///
  5100. /// @{
  5101. /// Out-of-place summation (sum) primitive.
  5102. struct sum : public primitive {
  5103. /// Primitive descriptor for a sum primitive.
  5104. struct primitive_desc : public primitive_desc_base {
  5105. using primitive_desc_base::primitive_desc_base;
  5106. /// Default constructor. Produces an empty object.
  5107. primitive_desc() = default;
  5108. /// Constructs a primitive descriptor for a sum primitive.
  5109. ///
  5110. /// @param aengine Engine to perform the operation on.
  5111. /// @param dst Destination memory descriptor.
  5112. /// @param scales Vector of scales to multiply data in each source
  5113. /// memory by.
  5114. /// @param srcs Vector of source memory descriptors.
  5115. /// @param attr Primitive attributes to use. Attributes are optional
  5116. /// and default to empty attributes.
  5117. /// @param allow_empty A flag signifying whether construction is
  5118. /// allowed to fail without throwing an exception. In this case an
  5119. /// empty object will be produced. This flag is optional and
  5120. /// defaults to false.
  5121. primitive_desc(const engine &aengine, const memory::desc &dst,
  5122. const std::vector<float> &scales,
  5123. const std::vector<memory::desc> &srcs,
  5124. const primitive_attr &attr = default_attr(),
  5125. bool allow_empty = false) {
  5126. validate_container_size(scales,
  5127. "counts of scales and sources are not equal",
  5128. (int)srcs.size(), (int)srcs.size());
  5129. auto c_api_srcs = convert_to_c(srcs);
  5130. dnnl_primitive_desc_t result;
  5131. dnnl_status_t status = dnnl_sum_primitive_desc_create(&result,
  5132. aengine.get(), dst.get(), (int)c_api_srcs.size(),
  5133. scales.data(), c_api_srcs.data(), attr.get());
  5134. if (!allow_empty)
  5135. error::wrap_c_api(status,
  5136. "could not create a primitive descriptor for "
  5137. "the sum primitive. Run workload with "
  5138. "environment variable ONEDNN_VERBOSE=all to get "
  5139. "additional diagnostic information.");
  5140. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  5141. }
  5142. /// Constructs a primitive descriptor for a sum primitive.
  5143. ///
  5144. /// This version derives the destination memory descriptor
  5145. /// automatically.
  5146. ///
  5147. /// @param aengine Engine on which to perform the operation.
  5148. /// @param scales Vector of scales by which to multiply data in each
  5149. /// source memory object.
  5150. /// @param srcs Vector of source memory descriptors.
  5151. /// @param attr Primitive attributes to use. Attributes are optional
  5152. /// and default to empty attributes.
  5153. /// @param allow_empty A flag signifying whether construction is
  5154. /// allowed to fail without throwing an exception. In this case an
  5155. /// empty object will be produced. This flag is optional and
  5156. /// defaults to false.
  5157. primitive_desc(const engine &aengine, const std::vector<float> &scales,
  5158. const std::vector<memory::desc> &srcs,
  5159. const primitive_attr &attr = default_attr(),
  5160. bool allow_empty = false) {
  5161. validate_container_size(scales,
  5162. "counts of scales and sources are not equal",
  5163. (int)srcs.size(), (int)srcs.size());
  5164. auto c_api_srcs = convert_to_c(srcs);
  5165. dnnl_primitive_desc_t result;
  5166. dnnl_status_t status = dnnl_sum_primitive_desc_create(&result,
  5167. aengine.get(), nullptr, (int)c_api_srcs.size(),
  5168. scales.data(), c_api_srcs.data(), attr.get());
  5169. if (!allow_empty)
  5170. error::wrap_c_api(status,
  5171. "could not create a primitive descriptor for "
  5172. "the sum primitive. Run workload with "
  5173. "environment variable ONEDNN_VERBOSE=all to get "
  5174. "additional diagnostic information.");
  5175. reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
  5176. }
  5177. /// Constructs a primitive descriptor for sum primitive from a C API
  5178. /// primitive descriptor which must have a matching kind.
  5179. ///
  5180. /// @param pd C API primitive descriptor for sum primitive.
  5181. primitive_desc(dnnl_primitive_desc_t pd)
  5182. : primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
  5183. /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
  5184. memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
  5185. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  5186. memory::desc dst_desc() const { return base::dst_desc(0); }
  5187. };
  5188. /// Default constructor. Produces an empty object.
  5189. sum() = default;
  5190. /// Constructs a sum primitive.
  5191. /// @param pd Primitive descriptor for sum primitive.
  5192. sum(const primitive_desc &pd) : primitive(pd.get()) {}
  5193. /// Constructs a sum primitive from a cache blob.
  5194. /// @param pd Primitive descriptor for sum primitive.
  5195. /// @param cache_blob Cache blob.
  5196. sum(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5197. : primitive(pd.get(), cache_blob) {}
  5198. };
  5199. /// @} dnnl_api_sum
  5200. /// @addtogroup dnnl_api_primitives_common
  5201. /// @{
  5202. /// A base class for descriptors of all primitives that support iteration
  5203. /// over multiple implementations.
  5204. struct primitive_desc : public primitive_desc_base {
  5205. using primitive_desc_base::primitive_desc_base;
  5206. primitive_desc() = default;
  5207. /// Changes the primitive descriptor to point to the next available
  5208. /// implementation.
  5209. ///
  5210. /// @returns @c true on success and @c false if the last available
  5211. /// implementation has already been reached. In the latter case, the
  5212. /// primitive descriptor itself is kept unchanged.
  5213. bool next_impl() {
  5214. dnnl_status_t status = dnnl_primitive_desc_next_impl(get());
  5215. if (status == dnnl_last_impl_reached) return false;
  5216. error::wrap_c_api(status, "last available implementation is reached");
  5217. return true;
  5218. }
  5219. };
  5220. /// @} dnnl_api_primitives_common
  5221. /// @addtogroup dnnl_api_convolution Convolution
  5222. ///
  5223. /// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
  5224. /// forward propagation, backward propagation, and weights gradient with or
  5225. /// without bias.
  5226. ///
  5227. /// @sa @ref dev_guide_convolution in developer guide
  5228. ///
  5229. /// @{
  5230. /// Convolution forward propagation primitive.
  5231. struct convolution_forward : public primitive {
  5232. /// Primitive descriptor for a convolution forward propagation primitive.
  5233. struct primitive_desc : public dnnl::primitive_desc {
  5234. /// Default constructor. Produces an empty object.
  5235. primitive_desc() = default;
  5236. /// Constructs a primitive descriptor for a convolution forward
  5237. /// propagation primitive with bias.
  5238. ///
  5239. /// @note
  5240. /// All the memory descriptors may be initialized with the
  5241. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5242. ///
  5243. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5244. /// for spatial dimensions only and hence must have the same number of
  5245. /// elements as there are spatial dimensions. The order of values is
  5246. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5247. /// and 2D tensors), and width.
  5248. ///
  5249. /// @param aengine Engine to use.
  5250. /// @param aprop_kind Propagation kind. Possible values are
  5251. /// #dnnl::prop_kind::forward_training, and
  5252. /// #dnnl::prop_kind::forward_inference.
  5253. /// @param aalgorithm Convolution algorithm. Possible values are
  5254. /// #dnnl::algorithm::convolution_direct,
  5255. /// #dnnl::algorithm::convolution_winograd, and
  5256. /// #dnnl::algorithm::convolution_auto.
  5257. /// @param src_desc Source memory descriptor.
  5258. /// @param weights_desc Weights memory descriptor.
  5259. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5260. /// descriptor disables the bias term.
  5261. /// @param dst_desc Destination memory descriptor.
  5262. /// @param strides Strides for each spatial dimension.
  5263. /// @param padding_l Vector of padding values for low indices for each
  5264. /// spatial dimension `([[front,] top,] left)`.
  5265. /// @param padding_r Vector of padding values for high indices for
  5266. /// each spatial dimension `([[back,] bottom,] right)`.
  5267. /// @param attr Primitive attributes to use. Attributes are optional
  5268. /// and default to empty attributes.
  5269. /// @param allow_empty A flag signifying whether construction is
  5270. /// allowed to fail without throwing an exception. In this case an
  5271. /// empty object will be produced. This flag is optional and
  5272. /// defaults to false.
  5273. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5274. algorithm aalgorithm, const memory::desc &src_desc,
  5275. const memory::desc &weights_desc, const memory::desc &bias_desc,
  5276. const memory::desc &dst_desc, const memory::dims &strides,
  5277. const memory::dims &padding_l, const memory::dims &padding_r,
  5278. const primitive_attr &attr = default_attr(),
  5279. bool allow_empty = false)
  5280. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5281. weights_desc, &bias_desc, dst_desc, strides, nullptr,
  5282. padding_l, padding_r, attr, allow_empty) {}
  5283. /// Constructs a primitive descriptor for a convolution forward
  5284. /// propagation primitive without bias.
  5285. ///
  5286. /// @note
  5287. /// All the memory descriptors may be initialized with the
  5288. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5289. ///
  5290. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5291. /// for spatial dimensions only and hence must have the same number of
  5292. /// elements as there are spatial dimensions. The order of values is
  5293. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5294. /// and 2D tensors), and width.
  5295. ///
  5296. /// @param aengine Engine to use.
  5297. /// @param aprop_kind Propagation kind. Possible values are
  5298. /// #dnnl::prop_kind::forward_training, and
  5299. /// #dnnl::prop_kind::forward_inference.
  5300. /// @param aalgorithm Convolution algorithm. Possible values are
  5301. /// #dnnl::algorithm::convolution_direct,
  5302. /// #dnnl::algorithm::convolution_winograd, and
  5303. /// #dnnl::algorithm::convolution_auto.
  5304. /// @param src_desc Source memory descriptor.
  5305. /// @param weights_desc Weights memory descriptor.
  5306. /// @param dst_desc Destination memory descriptor.
  5307. /// @param strides Strides for each spatial dimension.
  5308. /// @param padding_l Vector of padding values for low indices for each
  5309. /// spatial dimension `([[front,] top,] left)`.
  5310. /// @param padding_r Vector of padding values for high indices for
  5311. /// each spatial dimension `([[back,] bottom,] right)`.
  5312. /// @param attr Primitive attributes to use. Attributes are optional
  5313. /// and default to empty attributes.
  5314. /// @param allow_empty A flag signifying whether construction is
  5315. /// allowed to fail without throwing an exception. In this case an
  5316. /// empty object will be produced. This flag is optional and
  5317. /// defaults to false.
  5318. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5319. algorithm aalgorithm, const memory::desc &src_desc,
  5320. const memory::desc &weights_desc, const memory::desc &dst_desc,
  5321. const memory::dims &strides, const memory::dims &padding_l,
  5322. const memory::dims &padding_r,
  5323. const primitive_attr &attr = default_attr(),
  5324. bool allow_empty = false)
  5325. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5326. weights_desc, nullptr, dst_desc, strides, nullptr,
  5327. padding_l, padding_r, attr, allow_empty) {}
  5328. /// Constructs a primitive descriptor for a convolution forward
  5329. /// propagation primitive with bias.
  5330. ///
  5331. /// @note
  5332. /// All the memory descriptors may be initialized with the
  5333. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5334. ///
  5335. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5336. /// contain values for spatial dimensions only and hence must have the
  5337. /// same number of elements as there are spatial dimensions. The order
  5338. /// of values is the same as in the tensor: depth (for 3D tensors),
  5339. /// height (for 3D and 2D tensors), and width.
  5340. ///
  5341. /// @param aengine Engine to use.
  5342. /// @param aprop_kind Propagation kind. Possible values are
  5343. /// #dnnl::prop_kind::forward_training, and
  5344. /// #dnnl::prop_kind::forward_inference.
  5345. /// @param aalgorithm Convolution algorithm. Possible values are
  5346. /// #dnnl::algorithm::convolution_direct,
  5347. /// #dnnl::algorithm::convolution_winograd, and
  5348. /// #dnnl::algorithm::convolution_auto.
  5349. /// @param src_desc Source memory descriptor.
  5350. /// @param weights_desc Weights memory descriptor.
  5351. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5352. /// descriptor disables the bias term.
  5353. /// @param dst_desc Destination memory descriptor.
  5354. /// @param strides Strides for each spatial dimension.
  5355. /// @param dilates Dilations for each spatial dimension. A zero value
  5356. /// means no dilation in the corresponding dimension.
  5357. /// @param padding_l Vector of padding values for low indices for each
  5358. /// spatial dimension `([[front,] top,] left)`.
  5359. /// @param padding_r Vector of padding values for high indices for
  5360. /// each spatial dimension `([[back,] bottom,] right)`.
  5361. /// @param attr Primitive attributes to use. Attributes are optional
  5362. /// and default to empty attributes.
  5363. /// @param allow_empty A flag signifying whether construction is
  5364. /// allowed to fail without throwing an exception. In this case an
  5365. /// empty object will be produced. This flag is optional and
  5366. /// defaults to false.
  5367. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5368. algorithm aalgorithm, const memory::desc &src_desc,
  5369. const memory::desc &weights_desc, const memory::desc &bias_desc,
  5370. const memory::desc &dst_desc, const memory::dims &strides,
  5371. const memory::dims &dilates, const memory::dims &padding_l,
  5372. const memory::dims &padding_r,
  5373. const primitive_attr &attr = default_attr(),
  5374. bool allow_empty = false)
  5375. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5376. weights_desc, &bias_desc, dst_desc, strides, &dilates,
  5377. padding_l, padding_r, attr, allow_empty) {}
  5378. /// Constructs a primitive descriptor for a convolution forward
  5379. /// propagation primitive without bias.
  5380. ///
  5381. /// @note
  5382. /// All the memory descriptors may be initialized with the
  5383. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5384. ///
  5385. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5386. /// contain values for spatial dimensions only and hence must have the
  5387. /// same number of elements as there are spatial dimensions. The order
  5388. /// of values is the same as in the tensor: depth (for 3D tensors),
  5389. /// height (for 3D and 2D tensors), and width.
  5390. ///
  5391. /// @param aengine Engine to use.
  5392. /// @param aprop_kind Propagation kind. Possible values are
  5393. /// #dnnl::prop_kind::forward_training, and
  5394. /// #dnnl::prop_kind::forward_inference.
  5395. /// @param aalgorithm Convolution algorithm. Possible values are
  5396. /// #dnnl::algorithm::convolution_direct,
  5397. /// #dnnl::algorithm::convolution_winograd, and
  5398. /// #dnnl::algorithm::convolution_auto.
  5399. /// @param src_desc Source memory descriptor.
  5400. /// @param weights_desc Weights memory descriptor.
  5401. /// @param dst_desc Destination memory descriptor.
  5402. /// @param strides Strides for each spatial dimension.
  5403. /// @param dilates Dilations for each spatial dimension. A zero value
  5404. /// means no dilation in the corresponding dimension.
  5405. /// @param padding_l Vector of padding values for low indices for each
  5406. /// spatial dimension `([[front,] top,] left)`.
  5407. /// @param padding_r Vector of padding values for high indices for
  5408. /// each spatial dimension `([[back,] bottom,] right)`.
  5409. /// @param attr Primitive attributes to use. Attributes are optional
  5410. /// and default to empty attributes.
  5411. /// @param allow_empty A flag signifying whether construction is
  5412. /// allowed to fail without throwing an exception. In this case an
  5413. /// empty object will be produced. This flag is optional and
  5414. /// defaults to false.
  5415. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5416. algorithm aalgorithm, const memory::desc &src_desc,
  5417. const memory::desc &weights_desc, const memory::desc &dst_desc,
  5418. const memory::dims &strides, const memory::dims &dilates,
  5419. const memory::dims &padding_l, const memory::dims &padding_r,
  5420. const primitive_attr &attr = default_attr(),
  5421. bool allow_empty = false)
  5422. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  5423. weights_desc, nullptr, dst_desc, strides, &dilates,
  5424. padding_l, padding_r, attr, allow_empty) {}
  5425. /// Constructs a primitive descriptor for a convolution forward
  5426. /// propagation primitive from a C API primitive descriptor that must
  5427. /// have a matching kind.
  5428. ///
  5429. /// @param pd C API primitive descriptor for a convolution forward
  5430. /// propagation primitive.
  5431. primitive_desc(dnnl_primitive_desc_t pd)
  5432. : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
  5433. dnnl::prop_kind::forward_training,
  5434. dnnl::prop_kind::forward_inference) {}
  5435. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  5436. memory::desc src_desc() const { return base::src_desc(0); }
  5437. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  5438. memory::desc weights_desc() const { return base::weights_desc(0); }
  5439. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  5440. memory::desc dst_desc() const { return base::dst_desc(0); }
  5441. /// Returns the bias memory descriptor.
  5442. /// @returns The bias memory descriptor.
  5443. /// @returns A zero memory descriptor of the primitive does not have a
  5444. /// bias parameter.
  5445. memory::desc bias_desc() const { return base::weights_desc(1); }
  5446. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  5447. algorithm get_algorithm() const { return base::get_algorithm(); }
  5448. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  5449. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  5450. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  5451. memory::dims get_strides() const { return base::get_strides(); }
  5452. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  5453. memory::dims get_dilations() const { return base::get_dilations(); }
  5454. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  5455. memory::dims get_padding_l() const { return base::get_padding_l(); }
  5456. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  5457. memory::dims get_padding_r() const { return base::get_padding_r(); }
  5458. private:
  5459. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  5460. algorithm aalgorithm, const memory::desc &src_desc,
  5461. const memory::desc &weights_desc, const memory::desc *bias_desc,
  5462. const memory::desc &dst_desc, const memory::dims &strides,
  5463. const memory::dims *dilates, const memory::dims &padding_l,
  5464. const memory::dims &padding_r, const primitive_attr &attr,
  5465. bool allow_empty) {
  5466. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  5467. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  5468. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  5469. if (dilates)
  5470. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  5471. dnnl_primitive_desc_t pd = nullptr;
  5472. dnnl_status_t status
  5473. = dnnl_convolution_forward_primitive_desc_create(&pd,
  5474. aengine.get(), dnnl::convert_to_c(aprop_kind),
  5475. convert_to_c(aalgorithm), src_desc.get(),
  5476. weights_desc.get(), optional_arg(bias_desc),
  5477. dst_desc.get(), &strides[0], optional_arg(dilates),
  5478. &padding_l[0], &padding_r[0], attr.get());
  5479. if (!allow_empty)
  5480. error::wrap_c_api(status,
  5481. "could not create a primitive descriptor for "
  5482. "the convolution forward propagation primitive. Run "
  5483. "workload with environment variable ONEDNN_VERBOSE=all "
  5484. "to get additional diagnostic information.");
  5485. reset(pd);
  5486. }
  5487. };
  5488. /// Default constructor. Produces an empty object.
  5489. convolution_forward() = default;
  5490. /// Constructs a convolution forward propagation primitive.
  5491. /// @param pd Primitive descriptor for a convolution forward propagation
  5492. /// primitive.
  5493. convolution_forward(const primitive_desc &pd) : primitive(pd) {}
  5494. /// Constructs a convolution forward propagation primitive from a cache
  5495. /// blob.
  5496. /// @param pd Primitive descriptor for a convolution forward propagation
  5497. /// primitive.
  5498. /// @param cache_blob Cache blob.
  5499. convolution_forward(
  5500. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5501. : primitive(pd, cache_blob) {}
  5502. };
  5503. /// Convolution backward propagation primitive.
  5504. struct convolution_backward_data : public primitive {
  5505. /// Primitive descriptor for a convolution backward propagation primitive.
  5506. struct primitive_desc : public dnnl::primitive_desc {
  5507. /// Default constructor. Produces an empty object.
  5508. primitive_desc() = default;
  5509. /// Constructs a primitive descriptor for a convolution backward
  5510. /// propagation primitive.
  5511. ///
  5512. /// @note
  5513. /// All the memory descriptors may be initialized with the
  5514. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5515. ///
  5516. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5517. /// for spatial dimensions only and hence must have the same number of
  5518. /// elements as there are spatial dimensions. The order of values is
  5519. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5520. /// and 2D tensors), and width.
  5521. ///
  5522. /// @param aengine Engine to use.
  5523. /// @param aalgorithm Convolution algorithm. Possible values are
  5524. /// #dnnl::algorithm::convolution_direct,
  5525. /// #dnnl::algorithm::convolution_winograd, and
  5526. /// #dnnl::algorithm::convolution_auto.
  5527. /// @param diff_src_desc Diff source memory descriptor.
  5528. /// @param weights_desc Weights memory descriptor.
  5529. /// @param diff_dst_desc Diff destination memory descriptor.
  5530. /// @param strides Strides for each spatial dimension.
  5531. /// @param padding_l Vector of padding values for low indices for each
  5532. /// spatial dimension `([[front,] top,] left)`.
  5533. /// @param padding_r Vector of padding values for high indices for
  5534. /// each spatial dimension `([[back,] bottom,] right)`.
  5535. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5536. /// forward propagation primitive. It is used as a hint for
  5537. /// deciding which memory format to use.
  5538. /// @param attr Primitive attributes to use. Attributes are optional
  5539. /// and default to empty attributes.
  5540. /// @param allow_empty A flag signifying whether construction is
  5541. /// allowed to fail without throwing an exception. In this case an
  5542. /// empty object will be produced. This flag is optional and
  5543. /// defaults to false.
  5544. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5545. const memory::desc &diff_src_desc,
  5546. const memory::desc &weights_desc,
  5547. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5548. const memory::dims &padding_l, const memory::dims &padding_r,
  5549. const convolution_forward::primitive_desc &hint_fwd_pd,
  5550. const primitive_attr &attr = default_attr(),
  5551. bool allow_empty = false)
  5552. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  5553. diff_dst_desc, strides, nullptr, padding_l, padding_r,
  5554. hint_fwd_pd, attr, allow_empty) {}
  5555. /// Constructs a primitive descriptor for a convolution backward
  5556. /// propagation primitive.
  5557. ///
  5558. /// @note
  5559. /// All the memory descriptors may be initialized with the
  5560. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5561. ///
  5562. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5563. /// contain values for spatial dimensions only and hence must have the
  5564. /// same number of elements as there are spatial dimensions. The order
  5565. /// of values is the same as in the tensor: depth (for 3D tensors),
  5566. /// height (for 3D and 2D tensors), and width.
  5567. ///
  5568. /// @param aengine Engine to use.
  5569. /// @param aalgorithm Convolution algorithm. Possible values are
  5570. /// #dnnl::algorithm::convolution_direct,
  5571. /// #dnnl::algorithm::convolution_winograd, and
  5572. /// #dnnl::algorithm::convolution_auto.
  5573. /// @param diff_src_desc Diff source memory descriptor.
  5574. /// @param weights_desc Weights memory descriptor.
  5575. /// @param diff_dst_desc Diff destination memory descriptor.
  5576. /// @param strides Strides for each spatial dimension.
  5577. /// @param dilates Dilations for each spatial dimension. A zero value
  5578. /// means no dilation in the corresponding dimension.
  5579. /// @param padding_l Vector of padding values for low indices for each
  5580. /// spatial dimension `([[front,] top,] left)`.
  5581. /// @param padding_r Vector of padding values for high indices for
  5582. /// each spatial dimension `([[back,] bottom,] right)`.
  5583. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5584. /// forward propagation primitive. It is used as a hint for
  5585. /// deciding which memory format to use.
  5586. /// @param attr Primitive attributes to use. Attributes are optional
  5587. /// and default to empty attributes.
  5588. /// @param allow_empty A flag signifying whether construction is
  5589. /// allowed to fail without throwing an exception. In this case an
  5590. /// empty object will be produced. This flag is optional and
  5591. /// defaults to false.
  5592. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5593. const memory::desc &diff_src_desc,
  5594. const memory::desc &weights_desc,
  5595. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5596. const memory::dims &dilates, const memory::dims &padding_l,
  5597. const memory::dims &padding_r,
  5598. const convolution_forward::primitive_desc &hint_fwd_pd,
  5599. const primitive_attr &attr = default_attr(),
  5600. bool allow_empty = false)
  5601. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  5602. diff_dst_desc, strides, &dilates, padding_l, padding_r,
  5603. hint_fwd_pd, attr, allow_empty) {}
  5604. /// Constructs a primitive descriptor for a convolution backward
  5605. /// propagation primitive from a C API primitive descriptor that must
  5606. /// have a matching kind.
  5607. ///
  5608. /// @param pd C API primitive descriptor for a convolution backward
  5609. /// propagation primitive.
  5610. primitive_desc(dnnl_primitive_desc_t pd)
  5611. : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
  5612. dnnl::prop_kind::backward_data) {}
  5613. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  5614. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  5615. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  5616. memory::desc weights_desc() const { return base::weights_desc(0); }
  5617. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  5618. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  5619. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  5620. algorithm get_algorithm() const { return base::get_algorithm(); }
  5621. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  5622. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  5623. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  5624. memory::dims get_strides() const { return base::get_strides(); }
  5625. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  5626. memory::dims get_dilations() const { return base::get_dilations(); }
  5627. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  5628. memory::dims get_padding_l() const { return base::get_padding_l(); }
  5629. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  5630. memory::dims get_padding_r() const { return base::get_padding_r(); }
  5631. private:
  5632. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5633. const memory::desc &diff_src_desc,
  5634. const memory::desc &weights_desc,
  5635. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5636. const memory::dims *dilates, const memory::dims &padding_l,
  5637. const memory::dims &padding_r,
  5638. const convolution_forward::primitive_desc &hint_fwd_pd,
  5639. const primitive_attr &attr, bool allow_empty) {
  5640. memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
  5641. memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
  5642. memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
  5643. if (dilates)
  5644. memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
  5645. dnnl_primitive_desc_t pd = nullptr;
  5646. dnnl_status_t status
  5647. = dnnl_convolution_backward_data_primitive_desc_create(&pd,
  5648. aengine.get(), convert_to_c(aalgorithm),
  5649. diff_src_desc.get(), weights_desc.get(),
  5650. diff_dst_desc.get(), &strides[0],
  5651. optional_arg(dilates), &padding_l[0], &padding_r[0],
  5652. hint_fwd_pd.get(), attr.get());
  5653. if (!allow_empty)
  5654. error::wrap_c_api(status,
  5655. "could not create a primitive descriptor for "
  5656. "the convolution backward propagation primitive. Run "
  5657. "workload with environment variable ONEDNN_VERBOSE=all "
  5658. "to get additional diagnostic information.");
  5659. reset(pd);
  5660. }
  5661. };
  5662. /// Default constructor. Produces an empty object.
  5663. convolution_backward_data() = default;
  5664. /// Constructs a convolution backward propagation primitive.
  5665. /// @param pd Primitive descriptor for a convolution backward propagation
  5666. /// primitive.
  5667. convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
  5668. /// Constructs a convolution backward propagation primitive from a cache
  5669. /// blob.
  5670. /// @param pd Primitive descriptor for a convolution backward propagation
  5671. /// primitive.
  5672. /// @param cache_blob Cache blob.
  5673. convolution_backward_data(
  5674. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5675. : primitive(pd, cache_blob) {}
  5676. };
  5677. /// Convolution weights gradient primitive.
  5678. struct convolution_backward_weights : public primitive {
  5679. /// Primitive descriptor for a convolution weights gradient primitive.
  5680. struct primitive_desc : public dnnl::primitive_desc {
  5681. /// Default constructor. Produces an empty object.
  5682. primitive_desc() = default;
  5683. /// Constructs a primitive descriptor for a convolution weights gradient
  5684. /// primitive with bias.
  5685. ///
  5686. /// @note
  5687. /// All the memory descriptors may be initialized with the
  5688. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5689. ///
  5690. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5691. /// for spatial dimensions only and hence must have the same number of
  5692. /// elements as there are spatial dimensions. The order of values is
  5693. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5694. /// and 2D tensors), and width.
  5695. ///
  5696. /// @param aengine Engine to use.
  5697. /// @param aalgorithm Convolution algorithm. Possible values are
  5698. /// #dnnl::algorithm::convolution_direct,
  5699. /// #dnnl::algorithm::convolution_winograd, and
  5700. /// #dnnl::algorithm::convolution_auto.
  5701. /// @param src_desc Source memory descriptor.
  5702. /// @param diff_weights_desc Diff weights memory descriptor.
  5703. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  5704. /// memory descriptor disables the bias term.
  5705. /// @param diff_dst_desc Diff destination memory descriptor.
  5706. /// @param strides Strides for each spatial dimension.
  5707. /// @param padding_l Vector of padding values for low indices for each
  5708. /// spatial dimension `([[front,] top,] left)`.
  5709. /// @param padding_r Vector of padding values for high indices for
  5710. /// each spatial dimension `([[back,] bottom,] right)`.
  5711. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5712. /// forward propagation primitive. It is used as a hint for
  5713. /// deciding which memory format to use.
  5714. /// @param attr Primitive attributes to use. Attributes are optional
  5715. /// and default to empty attributes.
  5716. /// @param allow_empty A flag signifying whether construction is
  5717. /// allowed to fail without throwing an exception. In this case an
  5718. /// empty object will be produced. This flag is optional and
  5719. /// defaults to false.
  5720. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5721. const memory::desc &src_desc,
  5722. const memory::desc &diff_weights_desc,
  5723. const memory::desc &diff_bias_desc,
  5724. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5725. const memory::dims &padding_l, const memory::dims &padding_r,
  5726. const convolution_forward::primitive_desc &hint_fwd_pd,
  5727. const primitive_attr &attr = default_attr(),
  5728. bool allow_empty = false)
  5729. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5730. &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
  5731. padding_r, hint_fwd_pd, attr, allow_empty) {}
  5732. /// Constructs a primitive descriptor for a convolution weights gradient
  5733. /// primitive without bias.
  5734. ///
  5735. /// @note
  5736. /// All the memory descriptors may be initialized with the
  5737. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5738. ///
  5739. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5740. /// for spatial dimensions only and hence must have the same number of
  5741. /// elements as there are spatial dimensions. The order of values is
  5742. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5743. /// and 2D tensors), and width.
  5744. ///
  5745. /// @param aengine Engine to use.
  5746. /// @param aalgorithm Convolution algorithm. Possible values are
  5747. /// #dnnl::algorithm::convolution_direct,
  5748. /// #dnnl::algorithm::convolution_winograd, and
  5749. /// #dnnl::algorithm::convolution_auto.
  5750. /// @param src_desc Source memory descriptor.
  5751. /// @param diff_weights_desc Diff weights memory descriptor.
  5752. /// @param diff_dst_desc Diff destination memory descriptor.
  5753. /// @param strides Strides for each spatial dimension.
  5754. /// @param padding_l Vector of padding values for low indices for each
  5755. /// spatial dimension `([[front,] top,] left)`.
  5756. /// @param padding_r Vector of padding values for high indices for
  5757. /// each spatial dimension `([[back,] bottom,] right)`.
  5758. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5759. /// forward propagation primitive. It is used as a hint for
  5760. /// deciding which memory format to use.
  5761. /// @param attr Primitive attributes to use. Attributes are optional
  5762. /// and default to empty attributes.
  5763. /// @param allow_empty A flag signifying whether construction is
  5764. /// allowed to fail without throwing an exception. In this case an
  5765. /// empty object will be produced. This flag is optional and
  5766. /// defaults to false.
  5767. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5768. const memory::desc &src_desc,
  5769. const memory::desc &diff_weights_desc,
  5770. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5771. const memory::dims &padding_l, const memory::dims &padding_r,
  5772. const convolution_forward::primitive_desc &hint_fwd_pd,
  5773. const primitive_attr &attr = default_attr(),
  5774. bool allow_empty = false)
  5775. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5776. nullptr, diff_dst_desc, strides, nullptr, padding_l,
  5777. padding_r, hint_fwd_pd, attr, allow_empty) {}
  5778. /// Constructs a primitive descriptor for a convolution weights
  5779. /// gradient primitive with bias.
  5780. ///
  5781. /// @note
  5782. /// All the memory descriptors may be initialized with the
  5783. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5784. ///
  5785. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5786. /// contain values for spatial dimensions only and hence must have the
  5787. /// same number of elements as there are spatial dimensions. The order
  5788. /// of values is the same as in the tensor: depth (for 3D tensors),
  5789. /// height (for 3D and 2D tensors), and width.
  5790. ///
  5791. /// @param aengine Engine to use.
  5792. /// @param aalgorithm Convolution algorithm. Possible values are
  5793. /// #dnnl::algorithm::convolution_direct,
  5794. /// #dnnl::algorithm::convolution_winograd, and
  5795. /// #dnnl::algorithm::convolution_auto.
  5796. /// @param src_desc Source memory descriptor.
  5797. /// @param diff_weights_desc Diff weights memory descriptor.
  5798. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  5799. /// memory descriptor disables the bias term.
  5800. /// @param diff_dst_desc Diff destination memory descriptor.
  5801. /// @param strides Strides for each spatial dimension.
  5802. /// @param dilates Dilations for each spatial dimension. A zero value
  5803. /// means no dilation in the corresponding dimension.
  5804. /// @param padding_l Vector of padding values for low indices for each
  5805. /// spatial dimension `([[front,] top,] left)`.
  5806. /// @param padding_r Vector of padding values for high indices for
  5807. /// each spatial dimension `([[back,] bottom,] right)`.
  5808. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5809. /// forward propagation primitive. It is used as a hint for
  5810. /// deciding which memory format to use.
  5811. /// @param attr Primitive attributes to use. Attributes are optional
  5812. /// and default to empty attributes.
  5813. /// @param allow_empty A flag signifying whether construction is
  5814. /// allowed to fail without throwing an exception. In this case an
  5815. /// empty object will be produced. This flag is optional and
  5816. /// defaults to false.
  5817. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5818. const memory::desc &src_desc,
  5819. const memory::desc &diff_weights_desc,
  5820. const memory::desc &diff_bias_desc,
  5821. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5822. const memory::dims &dilates, const memory::dims &padding_l,
  5823. const memory::dims &padding_r,
  5824. const convolution_forward::primitive_desc &hint_fwd_pd,
  5825. const primitive_attr &attr = default_attr(),
  5826. bool allow_empty = false)
  5827. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5828. &diff_bias_desc, diff_dst_desc, strides, &dilates,
  5829. padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
  5830. /// Constructs a primitive descriptor for a convolution weights
  5831. /// gradient primitive without bias.
  5832. ///
  5833. /// @note
  5834. /// All the memory descriptors may be initialized with the
  5835. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5836. ///
  5837. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  5838. /// contain values for spatial dimensions only and hence must have the
  5839. /// same number of elements as there are spatial dimensions. The order
  5840. /// of values is the same as in the tensor: depth (for 3D tensors),
  5841. /// height (for 3D and 2D tensors), and width.
  5842. ///
  5843. /// @param aengine Engine to use.
  5844. /// @param aalgorithm Convolution algorithm. Possible values are
  5845. /// #dnnl::algorithm::convolution_direct,
  5846. /// #dnnl::algorithm::convolution_winograd, and
  5847. /// #dnnl::algorithm::convolution_auto.
  5848. /// @param src_desc Source memory descriptor.
  5849. /// @param diff_weights_desc Diff weights memory descriptor.
  5850. /// @param diff_dst_desc Diff destination memory descriptor.
  5851. /// @param strides Strides for each spatial dimension.
  5852. /// @param dilates Dilations for each spatial dimension. A zero value
  5853. /// means no dilation in the corresponding dimension.
  5854. /// @param padding_l Vector of padding values for low indices for each
  5855. /// spatial dimension `([[front,] top,] left)`.
  5856. /// @param padding_r Vector of padding values for high indices for
  5857. /// each spatial dimension `([[back,] bottom,] right)`.
  5858. /// @param hint_fwd_pd Primitive descriptor for a convolution
  5859. /// forward propagation primitive. It is used as a hint for
  5860. /// deciding which memory format to use.
  5861. /// @param attr Primitive attributes to use. Attributes are optional
  5862. /// and default to empty attributes.
  5863. /// @param allow_empty A flag signifying whether construction is
  5864. /// allowed to fail without throwing an exception. In this case an
  5865. /// empty object will be produced. This flag is optional and
  5866. /// defaults to false.
  5867. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5868. const memory::desc &src_desc,
  5869. const memory::desc &diff_weights_desc,
  5870. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5871. const memory::dims &dilates, const memory::dims &padding_l,
  5872. const memory::dims &padding_r,
  5873. const convolution_forward::primitive_desc &hint_fwd_pd,
  5874. const primitive_attr &attr = default_attr(),
  5875. bool allow_empty = false)
  5876. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  5877. nullptr, diff_dst_desc, strides, &dilates, padding_l,
  5878. padding_r, hint_fwd_pd, attr, allow_empty) {}
  5879. /// Constructs a primitive descriptor for a convolution weights gradient
  5880. /// primitive from a C API primitive descriptor that must have a
  5881. /// matching kind.
  5882. ///
  5883. /// @param pd C API primitive descriptor for a convolution weights
  5884. /// gradient primitive.
  5885. primitive_desc(dnnl_primitive_desc_t pd)
  5886. : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
  5887. dnnl::prop_kind::backward_weights) {}
  5888. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  5889. memory::desc src_desc() const { return base::src_desc(0); }
  5890. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  5891. memory::desc diff_weights_desc() const {
  5892. return base::diff_weights_desc(0);
  5893. }
  5894. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  5895. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  5896. /// Returns the diff bias memory descriptor.
  5897. /// @returns The diff bias memory descriptor.
  5898. /// @returns A zero memory descriptor of the primitive does not have a
  5899. /// diff bias parameter.
  5900. memory::desc diff_bias_desc() const {
  5901. return base::diff_weights_desc(1);
  5902. }
  5903. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  5904. algorithm get_algorithm() const { return base::get_algorithm(); }
  5905. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  5906. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  5907. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  5908. memory::dims get_strides() const { return base::get_strides(); }
  5909. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  5910. memory::dims get_dilations() const { return base::get_dilations(); }
  5911. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  5912. memory::dims get_padding_l() const { return base::get_padding_l(); }
  5913. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  5914. memory::dims get_padding_r() const { return base::get_padding_r(); }
  5915. private:
  5916. primitive_desc(const engine &aengine, algorithm aalgorithm,
  5917. const memory::desc &src_desc,
  5918. const memory::desc &diff_weights_desc,
  5919. const memory::desc *diff_bias_desc,
  5920. const memory::desc &diff_dst_desc, const memory::dims &strides,
  5921. const memory::dims *dilates, const memory::dims &padding_l,
  5922. const memory::dims &padding_r,
  5923. const convolution_forward::primitive_desc &hint_fwd_pd,
  5924. const primitive_attr &attr, bool allow_empty) {
  5925. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  5926. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  5927. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  5928. if (dilates)
  5929. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  5930. dnnl_primitive_desc_t pd = nullptr;
  5931. dnnl_status_t status
  5932. = dnnl_convolution_backward_weights_primitive_desc_create(
  5933. &pd, aengine.get(), convert_to_c(aalgorithm),
  5934. src_desc.get(), diff_weights_desc.get(),
  5935. optional_arg(diff_bias_desc), diff_dst_desc.get(),
  5936. &strides[0], optional_arg(dilates), &padding_l[0],
  5937. &padding_r[0], hint_fwd_pd.get(), attr.get());
  5938. if (!allow_empty)
  5939. error::wrap_c_api(status,
  5940. "could not create a primitive descriptor for "
  5941. "the convolution weights update primitive. Run "
  5942. "workload with environment variable ONEDNN_VERBOSE=all "
  5943. "to get additional diagnostic information.");
  5944. reset(pd);
  5945. }
  5946. };
  5947. /// Default constructor. Produces an empty object.
  5948. convolution_backward_weights() = default;
  5949. /// Constructs a convolution weights gradient primitive.
  5950. /// @param pd Primitive descriptor for a convolution weights gradient
  5951. /// primitive.
  5952. convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
  5953. /// Constructs a convolution weights gradient primitive from a cache blob.
  5954. /// @param pd Primitive descriptor for a convolution weights gradient
  5955. /// primitive.
  5956. /// @param cache_blob Cache blob.
  5957. convolution_backward_weights(
  5958. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  5959. : primitive(pd, cache_blob) {}
  5960. };
  5961. /// @} dnnl_api_convolution
  5962. //
  5963. /// @addtogroup dnnl_api_deconvolution Deconvolution
  5964. ///
  5965. /// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
  5966. /// forward propagation, backward propagation, and weights gradient with or
  5967. /// without bias.
  5968. ///
  5969. /// @{
  5970. /// Deconvolution forward propagation primitive.
  5971. struct deconvolution_forward : public primitive {
  5972. /// Primitive descriptor for a deconvolution forward propagation primitive.
  5973. struct primitive_desc : public dnnl::primitive_desc {
  5974. /// Default constructor. Produces an empty object.
  5975. primitive_desc() = default;
  5976. /// Constructs a primitive descriptor for a deconvolution forward
  5977. /// propagation primitive with bias.
  5978. ///
  5979. /// @note
  5980. /// All the memory descriptors may be initialized with the
  5981. /// #dnnl::memory::format_tag::any value of @p format_tag.
  5982. ///
  5983. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  5984. /// for spatial dimensions only and hence must have the same number of
  5985. /// elements as there are spatial dimensions. The order of values is
  5986. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  5987. /// and 2D tensors), and width.
  5988. ///
  5989. /// @param aengine Engine to use.
  5990. /// @param aprop_kind Propagation kind. Possible values are
  5991. /// #dnnl::prop_kind::forward_training, and
  5992. /// #dnnl::prop_kind::forward_inference.
  5993. /// @param aalgorithm Deconvolution algorithm:
  5994. /// #dnnl::algorithm::deconvolution_direct, and
  5995. /// #dnnl::algorithm::deconvolution_winograd.
  5996. /// @param src_desc Source memory descriptor.
  5997. /// @param weights_desc Weights memory descriptor.
  5998. /// @param bias_desc Bias memory descriptor. Passing zero memory
  5999. /// descriptor disables the bias term.
  6000. /// @param dst_desc Destination memory descriptor.
  6001. /// @param strides Vector of strides for spatial dimension.
  6002. /// @param padding_l Vector of padding values for low indices for each
  6003. /// spatial dimension `([[front,] top,] left)`.
  6004. /// @param padding_r Vector of padding values for high indices for
  6005. /// each spatial dimension `([[back,] bottom,] right)`.
  6006. /// @param attr Primitive attributes to use. Attributes are optional
  6007. /// and default to empty attributes.
  6008. /// @param allow_empty A flag signifying whether construction is
  6009. /// allowed to fail without throwing an exception. In this case an
  6010. /// empty object will be produced. This flag is optional and
  6011. /// defaults to false.
  6012. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6013. algorithm aalgorithm, const memory::desc &src_desc,
  6014. const memory::desc &weights_desc, const memory::desc &bias_desc,
  6015. const memory::desc &dst_desc, const memory::dims &strides,
  6016. const memory::dims &padding_l, const memory::dims &padding_r,
  6017. const primitive_attr &attr = default_attr(),
  6018. bool allow_empty = false)
  6019. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6020. weights_desc, &bias_desc, dst_desc, strides, nullptr,
  6021. padding_l, padding_r, attr, allow_empty) {}
  6022. /// Constructs a primitive descriptor for a deconvolution forward
  6023. /// propagation primitive without bias.
  6024. ///
  6025. /// @note
  6026. /// All the memory descriptors may be initialized with the
  6027. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6028. ///
  6029. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6030. /// for spatial dimensions only and hence must have the same number of
  6031. /// elements as there are spatial dimensions. The order of values is
  6032. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6033. /// and 2D tensors), and width.
  6034. ///
  6035. /// @param aengine Engine to use.
  6036. /// @param aprop_kind Propagation kind. Possible values are
  6037. /// #dnnl::prop_kind::forward_training, and
  6038. /// #dnnl::prop_kind::forward_inference.
  6039. /// @param aalgorithm Deconvolution algorithm:
  6040. /// #dnnl::algorithm::deconvolution_direct, and
  6041. /// #dnnl::algorithm::deconvolution_winograd.
  6042. /// @param src_desc Source memory descriptor.
  6043. /// @param weights_desc Weights memory descriptor.
  6044. /// @param dst_desc Destination memory descriptor.
  6045. /// @param strides Vector of strides for spatial dimension.
  6046. /// @param padding_l Vector of padding values for low indices for each
  6047. /// spatial dimension `([[front,] top,] left)`.
  6048. /// @param padding_r Vector of padding values for high indices for
  6049. /// each spatial dimension `([[back,] bottom,] right)`.
  6050. /// @param attr Primitive attributes to use. Attributes are optional
  6051. /// and default to empty attributes.
  6052. /// @param allow_empty A flag signifying whether construction is
  6053. /// allowed to fail without throwing an exception. In this case an
  6054. /// empty object will be produced. This flag is optional and
  6055. /// defaults to false.
  6056. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6057. algorithm aalgorithm, const memory::desc &src_desc,
  6058. const memory::desc &weights_desc, const memory::desc &dst_desc,
  6059. const memory::dims &strides, const memory::dims &padding_l,
  6060. const memory::dims &padding_r,
  6061. const primitive_attr &attr = default_attr(),
  6062. bool allow_empty = false)
  6063. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6064. weights_desc, nullptr, dst_desc, strides, nullptr,
  6065. padding_l, padding_r, attr, allow_empty) {}
  6066. /// Constructs a primitive descriptor for a deconvolution forward
  6067. /// propagation primitive with bias.
  6068. ///
  6069. /// @note
  6070. /// All the memory descriptors may be initialized with the
  6071. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6072. ///
  6073. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6074. /// contain values for spatial dimensions only and hence must have the
  6075. /// same number of elements as there are spatial dimensions. The order
  6076. /// of values is the same as in the tensor: depth (for 3D tensors),
  6077. /// height (for 3D and 2D tensors), and width.
  6078. ///
  6079. /// @param aengine Engine to use.
  6080. /// @param aprop_kind Propagation kind. Possible values are
  6081. /// #dnnl::prop_kind::forward_training, and
  6082. /// #dnnl::prop_kind::forward_inference.
  6083. /// @param aalgorithm Deconvolution algorithm:
  6084. /// #dnnl::algorithm::deconvolution_direct, and
  6085. /// #dnnl::algorithm::deconvolution_winograd.
  6086. /// @param src_desc Source memory descriptor.
  6087. /// @param weights_desc Weights memory descriptor.
  6088. /// @param bias_desc Bias memory descriptor. Passing zero memory
  6089. /// descriptor disables the bias term.
  6090. /// @param dst_desc Destination memory descriptor.
  6091. /// @param strides Vector of strides for spatial dimension.
  6092. /// @param dilates Dilations for each spatial dimension. A zero value
  6093. /// means no dilation in the corresponding dimension.
  6094. /// @param padding_l Vector of padding values for low indices for each
  6095. /// spatial dimension `([[front,] top,] left)`.
  6096. /// @param padding_r Vector of padding values for high indices for
  6097. /// each spatial dimension `([[back,] bottom,] right)`.
  6098. /// @param attr Primitive attributes to use. Attributes are optional
  6099. /// and default to empty attributes.
  6100. /// @param allow_empty A flag signifying whether construction is
  6101. /// allowed to fail without throwing an exception. In this case an
  6102. /// empty object will be produced. This flag is optional and
  6103. /// defaults to false.
  6104. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6105. algorithm aalgorithm, const memory::desc &src_desc,
  6106. const memory::desc &weights_desc, const memory::desc &bias_desc,
  6107. const memory::desc &dst_desc, const memory::dims &strides,
  6108. const memory::dims &dilates, const memory::dims &padding_l,
  6109. const memory::dims &padding_r,
  6110. const primitive_attr &attr = default_attr(),
  6111. bool allow_empty = false)
  6112. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6113. weights_desc, &bias_desc, dst_desc, strides, &dilates,
  6114. padding_l, padding_r, attr, allow_empty) {}
  6115. /// Constructs a primitive descriptor for a deconvolution forward
  6116. /// propagation primitive without bias.
  6117. ///
  6118. /// @note
  6119. /// All the memory descriptors may be initialized with the
  6120. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6121. ///
  6122. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6123. /// contain values for spatial dimensions only and hence must have the
  6124. /// same number of elements as there are spatial dimensions. The order
  6125. /// of values is the same as in the tensor: depth (for 3D tensors),
  6126. /// height (for 3D and 2D tensors), and width.
  6127. ///
  6128. /// @param aengine Engine to use.
  6129. /// @param aprop_kind Propagation kind. Possible values are
  6130. /// #dnnl::prop_kind::forward_training, and
  6131. /// #dnnl::prop_kind::forward_inference.
  6132. /// @param aalgorithm Deconvolution algorithm:
  6133. /// #dnnl::algorithm::deconvolution_direct, and
  6134. /// #dnnl::algorithm::deconvolution_winograd.
  6135. /// @param src_desc Source memory descriptor.
  6136. /// @param weights_desc Weights memory descriptor.
  6137. /// @param dst_desc Destination memory descriptor.
  6138. /// @param strides Vector of strides for spatial dimension.
  6139. /// @param dilates Dilations for each spatial dimension. A zero value
  6140. /// means no dilation in the corresponding dimension.
  6141. /// @param padding_l Vector of padding values for low indices for each
  6142. /// spatial dimension `([[front,] top,] left)`.
  6143. /// @param padding_r Vector of padding values for high indices for
  6144. /// each spatial dimension `([[back,] bottom,] right)`.
  6145. /// @param attr Primitive attributes to use. Attributes are optional
  6146. /// and default to empty attributes.
  6147. /// @param allow_empty A flag signifying whether construction is
  6148. /// allowed to fail without throwing an exception. In this case an
  6149. /// empty object will be produced. This flag is optional and
  6150. /// defaults to false.
  6151. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6152. algorithm aalgorithm, const memory::desc &src_desc,
  6153. const memory::desc &weights_desc, const memory::desc &dst_desc,
  6154. const memory::dims &strides, const memory::dims &dilates,
  6155. const memory::dims &padding_l, const memory::dims &padding_r,
  6156. const primitive_attr &attr = default_attr(),
  6157. bool allow_empty = false)
  6158. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6159. weights_desc, nullptr, dst_desc, strides, &dilates,
  6160. padding_l, padding_r, attr, allow_empty) {}
  6161. /// Constructs a primitive descriptor for a deconvolution forward
  6162. /// propagation primitive from a C API primitive descriptor that must
  6163. /// have a matching kind.
  6164. ///
  6165. /// @param pd C API primitive descriptor for a deconvolution forward
  6166. /// propagation primitive.
  6167. primitive_desc(dnnl_primitive_desc_t pd)
  6168. : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
  6169. dnnl::prop_kind::forward_training,
  6170. dnnl::prop_kind::forward_inference) {}
  6171. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6172. memory::desc src_desc() const { return base::src_desc(0); }
  6173. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  6174. memory::desc weights_desc() const { return base::weights_desc(0); }
  6175. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  6176. memory::desc dst_desc() const { return base::dst_desc(0); }
  6177. /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
  6178. memory::desc bias_desc() const { return base::weights_desc(1); }
  6179. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6180. algorithm get_algorithm() const { return base::get_algorithm(); }
  6181. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6182. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6183. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  6184. memory::dims get_strides() const { return base::get_strides(); }
  6185. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  6186. memory::dims get_dilations() const { return base::get_dilations(); }
  6187. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  6188. memory::dims get_padding_l() const { return base::get_padding_l(); }
  6189. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  6190. memory::dims get_padding_r() const { return base::get_padding_r(); }
  6191. private:
  6192. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6193. algorithm aalgorithm, const memory::desc &src_desc,
  6194. const memory::desc &weights_desc, const memory::desc *bias_desc,
  6195. const memory::desc &dst_desc, const memory::dims &strides,
  6196. const memory::dims *dilates, const memory::dims &padding_l,
  6197. const memory::dims &padding_r, const primitive_attr &attr,
  6198. bool allow_empty) {
  6199. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  6200. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  6201. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  6202. if (dilates)
  6203. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  6204. dnnl_primitive_desc_t pd = nullptr;
  6205. dnnl_status_t status
  6206. = dnnl_deconvolution_forward_primitive_desc_create(&pd,
  6207. aengine.get(), dnnl::convert_to_c(aprop_kind),
  6208. convert_to_c(aalgorithm), src_desc.get(),
  6209. weights_desc.get(), optional_arg(bias_desc),
  6210. dst_desc.get(), &strides[0], optional_arg(dilates),
  6211. &padding_l[0], &padding_r[0], attr.get());
  6212. if (!allow_empty)
  6213. error::wrap_c_api(status,
  6214. "could not create a primitive descriptor for "
  6215. "the deconvolution forward propagation primitive. Run "
  6216. "workload with environment variable ONEDNN_VERBOSE=all "
  6217. "to get additional diagnostic information.");
  6218. reset(pd);
  6219. }
  6220. };
  6221. /// Default constructor. Produces an empty object.
  6222. deconvolution_forward() = default;
  6223. /// Constructs a deconvolution forward propagation primitive.
  6224. /// @param pd Primitive descriptor for a deconvolution forward propagation
  6225. /// primitive.
  6226. deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
  6227. /// Constructs a deconvolution forward propagation primitive from a cache
  6228. /// blob.
  6229. /// @param pd Primitive descriptor for a deconvolution forward propagation
  6230. /// primitive.
  6231. /// @param cache_blob Cache blob.
  6232. deconvolution_forward(
  6233. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6234. : primitive(pd, cache_blob) {}
  6235. };
  6236. /// Deconvolution backward propagation primitive.
  6237. struct deconvolution_backward_data : public primitive {
  6238. /// Primitive descriptor for a deconvolution backward propagation primitive.
  6239. struct primitive_desc : public dnnl::primitive_desc {
  6240. /// Default constructor. Produces an empty object.
  6241. primitive_desc() = default;
  6242. /// Constructs a primitive descriptor for a deconvolution backward
  6243. /// propagation primitive.
  6244. ///
  6245. /// @note
  6246. /// All the memory descriptors may be initialized with the
  6247. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6248. ///
  6249. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6250. /// for spatial dimensions only and hence must have the same number of
  6251. /// elements as there are spatial dimensions. The order of values is
  6252. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6253. /// and 2D tensors), and width.
  6254. ///
  6255. /// @param aengine Engine to use.
  6256. /// @param aalgorithm Deconvolution algorithm
  6257. /// (#dnnl::algorithm::convolution_direct,
  6258. /// #dnnl::algorithm::convolution_winograd).
  6259. /// @param diff_src_desc Diff source memory descriptor.
  6260. /// @param weights_desc Weights memory descriptor.
  6261. /// @param diff_dst_desc Diff destination memory descriptor.
  6262. /// @param strides Strides for each spatial dimension.
  6263. /// @param padding_l Vector of padding values for low indices for each
  6264. /// spatial dimension `([[front,] top,] left)`.
  6265. /// @param padding_r Vector of padding values for high indices for
  6266. /// each spatial dimension `([[back,] bottom,] right)`.
  6267. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6268. /// forward propagation primitive. It is used as a hint for
  6269. /// deciding which memory format to use.
  6270. /// @param attr Primitive attributes to use. Attributes are optional
  6271. /// and default to empty attributes.
  6272. /// @param allow_empty A flag signifying whether construction is
  6273. /// allowed to fail without throwing an exception. In this case an
  6274. /// empty object will be produced. This flag is optional and
  6275. /// defaults to false.
  6276. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6277. const memory::desc &diff_src_desc,
  6278. const memory::desc &weights_desc,
  6279. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6280. const memory::dims &padding_l, const memory::dims &padding_r,
  6281. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6282. const primitive_attr &attr = default_attr(),
  6283. bool allow_empty = false)
  6284. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  6285. diff_dst_desc, strides, nullptr, padding_l, padding_r,
  6286. hint_fwd_pd, attr, allow_empty) {}
  6287. /// Constructs a primitive descriptor for a deconvolution backward
  6288. /// propagation primitive.
  6289. ///
  6290. /// @note
  6291. /// All the memory descriptors may be initialized with the
  6292. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6293. ///
  6294. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6295. /// contain values for spatial dimensions only and hence must have the
  6296. /// same number of elements as there are spatial dimensions. The order
  6297. /// of values is the same as in the tensor: depth (for 3D tensors),
  6298. /// height (for 3D and 2D tensors), and width.
  6299. ///
  6300. /// @param aengine Engine to use.
  6301. /// @param aalgorithm Deconvolution algorithm
  6302. /// (#dnnl::algorithm::convolution_direct,
  6303. /// #dnnl::algorithm::convolution_winograd).
  6304. /// @param diff_src_desc Diff source memory descriptor.
  6305. /// @param weights_desc Weights memory descriptor.
  6306. /// @param diff_dst_desc Diff destination memory descriptor.
  6307. /// @param strides Strides for each spatial dimension.
  6308. /// @param dilates Dilations for each spatial dimension. A zero value
  6309. /// means no dilation in the corresponding dimension.
  6310. /// @param padding_l Vector of padding values for low indices for each
  6311. /// spatial dimension `([[front,] top,] left)`.
  6312. /// @param padding_r Vector of padding values for high indices for
  6313. /// each spatial dimension `([[back,] bottom,] right)`.
  6314. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6315. /// forward propagation primitive. It is used as a hint for
  6316. /// deciding which memory format to use.
  6317. /// @param attr Primitive attributes to use. Attributes are optional
  6318. /// and default to empty attributes.
  6319. /// @param allow_empty A flag signifying whether construction is
  6320. /// allowed to fail without throwing an exception. In this case an
  6321. /// empty object will be produced. This flag is optional and
  6322. /// defaults to false.
  6323. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6324. const memory::desc &diff_src_desc,
  6325. const memory::desc &weights_desc,
  6326. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6327. const memory::dims &dilates, const memory::dims &padding_l,
  6328. const memory::dims &padding_r,
  6329. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6330. const primitive_attr &attr = default_attr(),
  6331. bool allow_empty = false)
  6332. : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
  6333. diff_dst_desc, strides, &dilates, padding_l, padding_r,
  6334. hint_fwd_pd, attr, allow_empty) {}
  6335. /// Constructs a primitive descriptor for a deconvolution backward
  6336. /// propagation primitive from a C API primitive descriptor that must
  6337. /// have a matching kind.
  6338. ///
  6339. /// @param pd C API primitive descriptor for a deconvolution backward
  6340. /// propagation primitive.
  6341. primitive_desc(dnnl_primitive_desc_t pd)
  6342. : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
  6343. dnnl::prop_kind::backward_data) {}
  6344. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  6345. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  6346. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  6347. memory::desc weights_desc() const { return base::weights_desc(0); }
  6348. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  6349. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  6350. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6351. algorithm get_algorithm() const { return base::get_algorithm(); }
  6352. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6353. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6354. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  6355. memory::dims get_strides() const { return base::get_strides(); }
  6356. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  6357. memory::dims get_dilations() const { return base::get_dilations(); }
  6358. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  6359. memory::dims get_padding_l() const { return base::get_padding_l(); }
  6360. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  6361. memory::dims get_padding_r() const { return base::get_padding_r(); }
  6362. private:
  6363. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6364. const memory::desc &diff_src_desc,
  6365. const memory::desc &weights_desc,
  6366. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6367. const memory::dims *dilates, const memory::dims &padding_l,
  6368. const memory::dims &padding_r,
  6369. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6370. const primitive_attr &attr, bool allow_empty) {
  6371. memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
  6372. memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
  6373. memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
  6374. if (dilates)
  6375. memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
  6376. dnnl_primitive_desc_t pd = nullptr;
  6377. dnnl_status_t status
  6378. = dnnl_deconvolution_backward_data_primitive_desc_create(
  6379. &pd, aengine.get(), convert_to_c(aalgorithm),
  6380. diff_src_desc.get(), weights_desc.get(),
  6381. diff_dst_desc.get(), &strides[0],
  6382. optional_arg(dilates), &padding_l[0], &padding_r[0],
  6383. hint_fwd_pd.get(), attr.get());
  6384. if (!allow_empty)
  6385. error::wrap_c_api(status,
  6386. "could not create a primitive descriptor for "
  6387. "the deconvolution backward propagation primitive. Run "
  6388. "workload with environment variable ONEDNN_VERBOSE=all "
  6389. "to get additional diagnostic information.");
  6390. reset(pd);
  6391. }
  6392. };
  6393. /// Default constructor. Produces an empty object.
  6394. deconvolution_backward_data() = default;
  6395. /// Constructs a deconvolution backward propagation primitive.
  6396. /// @param pd Primitive descriptor for a deconvolution backward propagation
  6397. /// primitive.
  6398. deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
  6399. /// Constructs a deconvolution backward propagation primitive from a cache
  6400. /// blob.
  6401. /// @param pd Primitive descriptor for a deconvolution backward propagation
  6402. /// primitive.
  6403. /// @param cache_blob Cache blob.
  6404. deconvolution_backward_data(
  6405. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6406. : primitive(pd, cache_blob) {}
  6407. };
  6408. /// Deconvolution weights gradient primitive.
  6409. struct deconvolution_backward_weights : public primitive {
  6410. /// Primitive descriptor for a deconvolution weights gradient primitive.
  6411. struct primitive_desc : public dnnl::primitive_desc {
  6412. /// Default constructor. Produces an empty object.
  6413. primitive_desc() = default;
  6414. /// Constructs a primitive descriptor for a deconvolution weights
  6415. /// gradient primitive with bias.
  6416. ///
  6417. /// @note
  6418. /// All the memory descriptors may be initialized with the
  6419. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6420. ///
  6421. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6422. /// for spatial dimensions only and hence must have the same number of
  6423. /// elements as there are spatial dimensions. The order of values is
  6424. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6425. /// and 2D tensors), and width.
  6426. ///
  6427. /// @param aengine Engine to use.
  6428. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6429. /// #dnnl::algorithm::deconvolution_direct, and
  6430. /// #dnnl::algorithm::deconvolution_winograd.
  6431. /// @param src_desc Source memory descriptor.
  6432. /// @param diff_weights_desc Diff weights memory descriptor.
  6433. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  6434. /// memory descriptor disables the bias term.
  6435. /// @param diff_dst_desc Diff destination memory descriptor.
  6436. /// @param strides Strides for each spatial dimension.
  6437. /// @param padding_l Vector of padding values for low indices for each
  6438. /// spatial dimension `([[front,] top,] left)`.
  6439. /// @param padding_r Vector of padding values for high indices for
  6440. /// each spatial dimension `([[back,] bottom,] right)`.
  6441. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6442. /// forward propagation primitive. It is used as a hint for
  6443. /// deciding which memory format to use.
  6444. /// @param attr Primitive attributes to use. Attributes are optional
  6445. /// and default to empty attributes.
  6446. /// @param allow_empty A flag signifying whether construction is
  6447. /// allowed to fail without throwing an exception. In this case an
  6448. /// empty object will be produced. This flag is optional and
  6449. /// defaults to false.
  6450. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6451. const memory::desc &src_desc,
  6452. const memory::desc &diff_weights_desc,
  6453. const memory::desc &diff_bias_desc,
  6454. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6455. const memory::dims &padding_l, const memory::dims &padding_r,
  6456. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6457. const primitive_attr &attr = default_attr(),
  6458. bool allow_empty = false)
  6459. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6460. &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
  6461. padding_r, hint_fwd_pd, attr, allow_empty) {}
  6462. /// Constructs a primitive descriptor for a deconvolution weights
  6463. /// gradient primitive without bias.
  6464. ///
  6465. /// @note
  6466. /// All the memory descriptors may be initialized with the
  6467. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6468. ///
  6469. /// Arrays @p strides, @p padding_l, and @p padding_r contain values
  6470. /// for spatial dimensions only and hence must have the same number of
  6471. /// elements as there are spatial dimensions. The order of values is
  6472. /// the same as in the tensor: depth (for 3D tensors), height (for 3D
  6473. /// and 2D tensors), and width.
  6474. ///
  6475. /// @param aengine Engine to use.
  6476. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6477. /// #dnnl::algorithm::deconvolution_direct, and
  6478. /// #dnnl::algorithm::deconvolution_winograd.
  6479. /// @param src_desc Source memory descriptor.
  6480. /// @param diff_weights_desc Diff weights memory descriptor.
  6481. /// @param diff_dst_desc Diff destination memory descriptor.
  6482. /// @param strides Strides for each spatial dimension.
  6483. /// @param padding_l Vector of padding values for low indices for each
  6484. /// spatial dimension `([[front,] top,] left)`.
  6485. /// @param padding_r Vector of padding values for high indices for
  6486. /// each spatial dimension `([[back,] bottom,] right)`.
  6487. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6488. /// forward propagation primitive. It is used as a hint for
  6489. /// deciding which memory format to use.
  6490. /// @param attr Primitive attributes to use. Attributes are optional
  6491. /// and default to empty attributes.
  6492. /// @param allow_empty A flag signifying whether construction is
  6493. /// allowed to fail without throwing an exception. In this case an
  6494. /// empty object will be produced. This flag is optional and
  6495. /// defaults to false.
  6496. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6497. const memory::desc &src_desc,
  6498. const memory::desc &diff_weights_desc,
  6499. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6500. const memory::dims &padding_l, const memory::dims &padding_r,
  6501. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6502. const primitive_attr &attr = default_attr(),
  6503. bool allow_empty = false)
  6504. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6505. nullptr, diff_dst_desc, strides, nullptr, padding_l,
  6506. padding_r, hint_fwd_pd, attr, allow_empty) {}
  6507. /// Constructs a primitive descriptor for a deconvolution weights
  6508. /// gradient primitive with bias.
  6509. ///
  6510. /// @note
  6511. /// All the memory descriptors may be initialized with the
  6512. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6513. ///
  6514. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6515. /// contain values for spatial dimensions only and hence must have the
  6516. /// same number of elements as there are spatial dimensions. The order
  6517. /// of values is the same as in the tensor: depth (for 3D tensors),
  6518. /// height (for 3D and 2D tensors), and width.
  6519. ///
  6520. /// @param aengine Engine to use.
  6521. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6522. /// #dnnl::algorithm::deconvolution_direct, and
  6523. /// #dnnl::algorithm::deconvolution_winograd.
  6524. /// @param src_desc Source memory descriptor.
  6525. /// @param diff_weights_desc Diff weights memory descriptor.
  6526. /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
  6527. /// memory descriptor disables the bias term.
  6528. /// @param diff_dst_desc Diff destination memory descriptor.
  6529. /// @param strides Strides for each spatial dimension.
  6530. /// @param dilates Dilations for each spatial dimension. A zero value
  6531. /// means no dilation in the corresponding dimension.
  6532. /// @param padding_l Vector of padding values for low indices for each
  6533. /// spatial dimension `([[front,] top,] left)`.
  6534. /// @param padding_r Vector of padding values for high indices for
  6535. /// each spatial dimension `([[back,] bottom,] right)`.
  6536. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6537. /// forward propagation primitive. It is used as a hint for
  6538. /// deciding which memory format to use.
  6539. /// @param attr Primitive attributes to use. Attributes are optional
  6540. /// and default to empty attributes.
  6541. /// @param allow_empty A flag signifying whether construction is
  6542. /// allowed to fail without throwing an exception. In this case an
  6543. /// empty object will be produced. This flag is optional and
  6544. /// defaults to false.
  6545. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6546. const memory::desc &src_desc,
  6547. const memory::desc &diff_weights_desc,
  6548. const memory::desc &diff_bias_desc,
  6549. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6550. const memory::dims &dilates, const memory::dims &padding_l,
  6551. const memory::dims &padding_r,
  6552. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6553. const primitive_attr &attr = default_attr(),
  6554. bool allow_empty = false)
  6555. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6556. &diff_bias_desc, diff_dst_desc, strides, &dilates,
  6557. padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
  6558. /// Constructs a primitive descriptor for a deconvolution weights
  6559. /// gradient primitive without bias.
  6560. ///
  6561. /// @note
  6562. /// All the memory descriptors may be initialized with the
  6563. /// #dnnl::memory::format_tag::any value of @p format_tag.
  6564. ///
  6565. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
  6566. /// contain values for spatial dimensions only and hence must have the
  6567. /// same number of elements as there are spatial dimensions. The order
  6568. /// of values is the same as in the tensor: depth (for 3D tensors),
  6569. /// height (for 3D and 2D tensors), and width.
  6570. ///
  6571. /// @param aengine Engine to use.
  6572. /// @param aalgorithm Deconvolution algorithm. Possible values are
  6573. /// #dnnl::algorithm::deconvolution_direct, and
  6574. /// #dnnl::algorithm::deconvolution_winograd.
  6575. /// @param src_desc Source memory descriptor.
  6576. /// @param diff_weights_desc Diff weights memory descriptor.
  6577. /// @param diff_dst_desc Diff destination memory descriptor.
  6578. /// @param strides Strides for each spatial dimension.
  6579. /// @param dilates Dilations for each spatial dimension. A zero value
  6580. /// means no dilation in the corresponding dimension.
  6581. /// @param padding_l Vector of padding values for low indices for each
  6582. /// spatial dimension `([[front,] top,] left)`.
  6583. /// @param padding_r Vector of padding values for high indices for
  6584. /// each spatial dimension `([[back,] bottom,] right)`.
  6585. /// @param hint_fwd_pd Primitive descriptor for a deconvolution
  6586. /// forward propagation primitive. It is used as a hint for
  6587. /// deciding which memory format to use.
  6588. /// @param attr Primitive attributes to use. Attributes are optional
  6589. /// and default to empty attributes.
  6590. /// @param allow_empty A flag signifying whether construction is
  6591. /// allowed to fail without throwing an exception. In this case an
  6592. /// empty object will be produced. This flag is optional and
  6593. /// defaults to false.
  6594. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6595. const memory::desc &src_desc,
  6596. const memory::desc &diff_weights_desc,
  6597. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6598. const memory::dims &dilates, const memory::dims &padding_l,
  6599. const memory::dims &padding_r,
  6600. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6601. const primitive_attr &attr = default_attr(),
  6602. bool allow_empty = false)
  6603. : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
  6604. nullptr, diff_dst_desc, strides, &dilates, padding_l,
  6605. padding_r, hint_fwd_pd, attr, allow_empty) {}
  6606. /// Constructs a primitive descriptor for a deconvolution weights
  6607. /// gradient primitive from a C API primitive descriptor that must
  6608. /// have a matching kind.
  6609. ///
  6610. /// @param pd C API primitive descriptor for a deconvolution weights
  6611. /// gradient primitive.
  6612. primitive_desc(dnnl_primitive_desc_t pd)
  6613. : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
  6614. dnnl::prop_kind::backward_weights) {}
  6615. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6616. memory::desc src_desc() const { return base::src_desc(0); }
  6617. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  6618. memory::desc diff_weights_desc() const {
  6619. return base::diff_weights_desc(0);
  6620. }
  6621. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  6622. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  6623. /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
  6624. memory::desc diff_bias_desc() const {
  6625. return base::diff_weights_desc(1);
  6626. }
  6627. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6628. algorithm get_algorithm() const { return base::get_algorithm(); }
  6629. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6630. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6631. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  6632. memory::dims get_strides() const { return base::get_strides(); }
  6633. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  6634. memory::dims get_dilations() const { return base::get_dilations(); }
  6635. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  6636. memory::dims get_padding_l() const { return base::get_padding_l(); }
  6637. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  6638. memory::dims get_padding_r() const { return base::get_padding_r(); }
  6639. private:
  6640. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6641. const memory::desc &src_desc,
  6642. const memory::desc &diff_weights_desc,
  6643. const memory::desc *diff_bias_desc,
  6644. const memory::desc &diff_dst_desc, const memory::dims &strides,
  6645. const memory::dims *dilates, const memory::dims &padding_l,
  6646. const memory::dims &padding_r,
  6647. const deconvolution_forward::primitive_desc &hint_fwd_pd,
  6648. const primitive_attr &attr, bool allow_empty) {
  6649. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  6650. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  6651. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  6652. if (dilates)
  6653. memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
  6654. dnnl_primitive_desc_t pd = nullptr;
  6655. dnnl_status_t status
  6656. = dnnl_deconvolution_backward_weights_primitive_desc_create(
  6657. &pd, aengine.get(), convert_to_c(aalgorithm),
  6658. src_desc.get(), diff_weights_desc.get(),
  6659. optional_arg(diff_bias_desc), diff_dst_desc.get(),
  6660. &strides[0], optional_arg(dilates), &padding_l[0],
  6661. &padding_r[0], hint_fwd_pd.get(), attr.get());
  6662. if (!allow_empty)
  6663. error::wrap_c_api(status,
  6664. "could not create a primitive descriptor for "
  6665. "the deconvolution weights update primitive. Run "
  6666. "workload with environment variable ONEDNN_VERBOSE=all "
  6667. "to get additional diagnostic information.");
  6668. reset(pd);
  6669. }
  6670. };
  6671. /// Default constructor. Produces an empty object.
  6672. deconvolution_backward_weights() = default;
  6673. /// Constructs a deconvolution weights gradient primitive.
  6674. /// @param pd Primitive descriptor for a deconvolution weights gradient
  6675. /// primitive.
  6676. deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
  6677. /// Constructs a deconvolution weights gradient primitive from a cache
  6678. /// blob.
  6679. /// @param pd Primitive descriptor for a deconvolution weights gradient
  6680. /// primitive.
  6681. /// @param cache_blob Cache blob.
  6682. deconvolution_backward_weights(
  6683. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6684. : primitive(pd, cache_blob) {}
  6685. };
  6686. /// @} dnnl_api_deconvolution
  6687. /// @addtogroup dnnl_api_lrn LRN
  6688. ///
  6689. /// A primitive to perform local response normalization (LRN) across or within
  6690. /// channels.
  6691. ///
  6692. /// @sa @ref dev_guide_lrn in developer guide
  6693. ///
  6694. /// @{
  6695. /// Local response normalization (LRN) forward propagation primitive.
  6696. struct lrn_forward : public primitive {
  6697. /// Primitive descriptor for an LRN forward propagation primitive.
  6698. struct primitive_desc : public dnnl::primitive_desc {
  6699. /// Default constructor. Produces an empty object.
  6700. primitive_desc() = default;
  6701. /// Constructs a primitive descriptor for an LRN forward propagation
  6702. /// primitive.
  6703. ///
  6704. /// @param aengine Engine to use.
  6705. /// @param aprop_kind Propagation kind. Possible values are
  6706. /// #dnnl::prop_kind::forward_training, and
  6707. /// #dnnl::prop_kind::forward_inference.
  6708. /// @param aalgorithm LRN algorithm kind: either
  6709. /// #dnnl::algorithm::lrn_across_channels, or
  6710. /// #dnnl::algorithm::lrn_within_channel.
  6711. /// @param src_desc Source memory descriptor.
  6712. /// @param dst_desc Destination memory descriptor.
  6713. /// @param local_size Regularization local size.
  6714. /// @param alpha The alpha regularization parameter.
  6715. /// @param beta The beta regularization parameter.
  6716. /// @param k The k regularization parameter.
  6717. /// @param attr Primitive attributes to use. Attributes are optional
  6718. /// and default to empty attributes.
  6719. /// @param allow_empty A flag signifying whether construction is
  6720. /// allowed to fail without throwing an exception. In this case an
  6721. /// empty object will be produced. This flag is optional and
  6722. /// defaults to false.
  6723. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6724. algorithm aalgorithm, const memory::desc &src_desc,
  6725. const memory::desc &dst_desc, memory::dim local_size,
  6726. float alpha, float beta, float k,
  6727. const primitive_attr &attr = default_attr(),
  6728. bool allow_empty = false) {
  6729. dnnl_primitive_desc_t pd = nullptr;
  6730. dnnl_status_t status = dnnl_lrn_forward_primitive_desc_create(&pd,
  6731. aengine.get(), dnnl::convert_to_c(aprop_kind),
  6732. convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
  6733. local_size, alpha, beta, k, attr.get());
  6734. if (!allow_empty)
  6735. error::wrap_c_api(status,
  6736. "could not create a primitive descriptor for "
  6737. "the lrn forward propagation primitive. Run workload "
  6738. "with environment variable ONEDNN_VERBOSE=all to get "
  6739. "additional diagnostic information.");
  6740. reset(pd);
  6741. }
  6742. /// Constructs a primitive descriptor for an LRN forward propagation
  6743. /// primitive from a C API primitive descriptor that must have a
  6744. /// matching kind.
  6745. ///
  6746. /// @param pd C API primitive descriptor for an LRN forward
  6747. /// propagation primitive.
  6748. primitive_desc(dnnl_primitive_desc_t pd)
  6749. : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
  6750. dnnl::prop_kind::forward_training,
  6751. dnnl::prop_kind::forward_inference) {}
  6752. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6753. memory::desc src_desc() const { return base::src_desc(0); }
  6754. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  6755. memory::desc dst_desc() const { return base::dst_desc(0); }
  6756. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  6757. memory::desc workspace_desc() const { return base::workspace_desc(); }
  6758. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6759. algorithm get_algorithm() const { return base::get_algorithm(); }
  6760. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6761. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6762. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  6763. float get_alpha() const { return base::get_alpha(); }
  6764. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  6765. float get_beta() const { return base::get_beta(); }
  6766. /// @copydoc dnnl::primitive_desc_base::get_local_size()const
  6767. memory::dim get_local_size() const { return base::get_local_size(); }
  6768. /// @copydoc dnnl::primitive_desc_base::get_k()const
  6769. float get_k() const { return base::get_k(); }
  6770. };
  6771. /// Default constructor. Produces an empty object.
  6772. lrn_forward() = default;
  6773. /// Constructs an LRN forward propagation primitive.
  6774. /// @param pd Primitive descriptor for an LRN forward propagation
  6775. /// primitive.
  6776. lrn_forward(const primitive_desc &pd) : primitive(pd) {}
  6777. /// Constructs an LRN forward propagation primitive from a cache blob.
  6778. /// @param pd Primitive descriptor for an LRN forward propagation
  6779. /// primitive.
  6780. /// @param cache_blob Cache blob.
  6781. lrn_forward(
  6782. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6783. : primitive(pd, cache_blob) {}
  6784. };
  6785. /// Local response normalization (LRN) backward propagation primitive.
  6786. struct lrn_backward : public primitive {
  6787. /// Primitive descriptor for an LRN backward propagation primitive.
  6788. struct primitive_desc : public dnnl::primitive_desc {
  6789. /// Default constructor. Produces an empty object.
  6790. primitive_desc() = default;
  6791. /// Constructs a primitive descriptor for an LRN backward propagation
  6792. /// primitive.
  6793. ///
  6794. /// @param aengine Engine to use.
  6795. /// @param aalgorithm LRN algorithm kind: either
  6796. /// #dnnl::algorithm::lrn_across_channels, or
  6797. /// #dnnl::algorithm::lrn_within_channel.
  6798. /// @param diff_src_desc Diff source memory descriptor.
  6799. /// @param diff_dst_desc Diff destination memory descriptor.
  6800. /// @param src_desc Source memory descriptor.
  6801. /// @param local_size Regularization local size.
  6802. /// @param alpha The alpha regularization parameter.
  6803. /// @param beta The beta regularization parameter.
  6804. /// @param k The k regularization parameter.
  6805. /// @param hint_fwd_pd Primitive descriptor for an LRN forward
  6806. /// propagation primitive. It is used as a hint for deciding which
  6807. /// memory format to use.
  6808. /// @param attr Primitive attributes to use. Attributes are optional
  6809. /// and default to empty attributes.
  6810. /// @param allow_empty A flag signifying whether construction is
  6811. /// allowed to fail without throwing an exception. In this case an
  6812. /// empty object will be produced. This flag is optional and
  6813. /// defaults to false.
  6814. primitive_desc(const engine &aengine, algorithm aalgorithm,
  6815. const memory::desc &diff_src_desc,
  6816. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  6817. memory::dim local_size, float alpha, float beta, float k,
  6818. const lrn_forward::primitive_desc &hint_fwd_pd,
  6819. const primitive_attr &attr = default_attr(),
  6820. bool allow_empty = false) {
  6821. dnnl_primitive_desc_t pd = nullptr;
  6822. dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd,
  6823. aengine.get(), convert_to_c(aalgorithm),
  6824. diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(),
  6825. local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get());
  6826. if (!allow_empty)
  6827. error::wrap_c_api(status,
  6828. "could not create a primitive descriptor for "
  6829. "the lrn backward propagation primitive. Run workload "
  6830. "with environment variable ONEDNN_VERBOSE=all to get "
  6831. "additional diagnostic information.");
  6832. reset(pd);
  6833. }
  6834. /// Constructs a primitive descriptor for an LRN backward propagation
  6835. /// primitive from a C API primitive descriptor that must have a
  6836. /// matching kind.
  6837. ///
  6838. /// @param pd C API primitive descriptor for an LRN backward
  6839. /// propagation primitive.
  6840. primitive_desc(dnnl_primitive_desc_t pd)
  6841. : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
  6842. dnnl::prop_kind::backward_data) {}
  6843. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6844. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  6845. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  6846. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  6847. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  6848. memory::desc workspace_desc() const { return base::workspace_desc(); }
  6849. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6850. algorithm get_algorithm() const { return base::get_algorithm(); }
  6851. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6852. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6853. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  6854. float get_alpha() const { return base::get_alpha(); }
  6855. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  6856. float get_beta() const { return base::get_beta(); }
  6857. /// @copydoc dnnl::primitive_desc_base::get_local_size()const
  6858. memory::dim get_local_size() const { return base::get_local_size(); }
  6859. /// @copydoc dnnl::primitive_desc_base::get_k()const
  6860. float get_k() const { return base::get_k(); }
  6861. };
  6862. /// Default constructor. Produces an empty object.
  6863. lrn_backward() = default;
  6864. /// Constructs an LRN backward propagation primitive.
  6865. /// @param pd Primitive descriptor for an LRN backward propagation
  6866. /// primitive.
  6867. lrn_backward(const primitive_desc &pd) : primitive(pd) {}
  6868. /// Constructs an LRN backward propagation primitive from a cache blob.
  6869. /// @param pd Primitive descriptor for an LRN backward propagation
  6870. /// primitive.
  6871. /// @param cache_blob Cache blob.
  6872. lrn_backward(
  6873. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  6874. : primitive(pd, cache_blob) {}
  6875. };
  6876. /// @} dnnl_api_lrn
  6877. /// @addtogroup dnnl_api_eltwise Eltwise
  6878. ///
  6879. /// A primitive to perform elementwise operations such as the
  6880. /// rectified linear unit (ReLU).
  6881. ///
  6882. /// Both forward and backward propagation primitives support in-place
  6883. /// operation; that is, src and dst can refer to the same memory for forward
  6884. /// propagation, and diff_dst and diff_src can refer to the same memory for
  6885. /// backward propagation.
  6886. ///
  6887. /// @warning
  6888. /// Because the original source data is required for backward propagation,
  6889. /// in-place forward propagation is not generally supported in the
  6890. /// training mode. However, for algorithms supporting destination as input
  6891. /// memory, dst can be used for the backward propagation, which makes it
  6892. /// possible to get performance benefit even in the training mode.
  6893. ///
  6894. /// @sa @ref dev_guide_eltwise in developer guide
  6895. ///
  6896. /// @{
  6897. /// Elementwise unary operation forward propagation primitive.
  6898. struct eltwise_forward : public primitive {
  6899. /// Primitive descriptor for an elementwise forward propagation primitive.
  6900. struct primitive_desc : public dnnl::primitive_desc {
  6901. /// Default constructor. Produces an empty object.
  6902. primitive_desc() = default;
  6903. /// Constructs a primitive descriptor for an elementwise forward
  6904. /// propagation primitive.
  6905. ///
  6906. /// @param aengine Engine to use.
  6907. /// @param aprop_kind Propagation kind. Possible values are
  6908. /// #dnnl::prop_kind::forward_training, and
  6909. /// #dnnl::prop_kind::forward_inference.
  6910. /// @param aalgorithm Elementwise algorithm kind.
  6911. /// @param src_desc Source memory descriptor.
  6912. /// @param dst_desc Destination memory descriptor.
  6913. /// @param attr Primitive attributes to use. Attributes are optional
  6914. /// and default to empty attributes.
  6915. /// @param allow_empty A flag signifying whether construction is
  6916. /// allowed to fail without throwing an exception. In this case an
  6917. /// empty object will be produced. This flag is optional and
  6918. /// defaults to false.
  6919. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6920. algorithm aalgorithm, const memory::desc &src_desc,
  6921. const memory::desc &dst_desc,
  6922. const primitive_attr &attr = default_attr(),
  6923. bool allow_empty = false)
  6924. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6925. dst_desc, nullptr, nullptr, attr, allow_empty) {}
  6926. /// Constructs a primitive descriptor for an elementwise forward
  6927. /// propagation primitive with an alpha parameter.
  6928. ///
  6929. /// @param aengine Engine to use.
  6930. /// @param aprop_kind Propagation kind. Possible values are
  6931. /// #dnnl::prop_kind::forward_training, and
  6932. /// #dnnl::prop_kind::forward_inference.
  6933. /// @param aalgorithm Elementwise algorithm kind.
  6934. /// @param src_desc Source memory descriptor.
  6935. /// @param dst_desc Destination memory descriptor.
  6936. /// @param alpha The alpha parameter for the elementwise operation.
  6937. /// Specific meaning depends on the algorithm.
  6938. /// @param attr Primitive attributes to use. Attributes are optional
  6939. /// and default to empty attributes.
  6940. /// @param allow_empty A flag signifying whether construction is
  6941. /// allowed to fail without throwing an exception. In this case an
  6942. /// empty object will be produced. This flag is optional and
  6943. /// defaults to false.
  6944. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6945. algorithm aalgorithm, const memory::desc &src_desc,
  6946. const memory::desc &dst_desc, float alpha,
  6947. const primitive_attr &attr = default_attr(),
  6948. bool allow_empty = false)
  6949. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6950. dst_desc, &alpha, nullptr, attr, allow_empty) {}
  6951. /// Constructs a primitive descriptor for an elementwise forward
  6952. /// propagation primitive with an alpha and beta parameters.
  6953. ///
  6954. /// @param aengine Engine to use.
  6955. /// @param aprop_kind Propagation kind. Possible values are
  6956. /// #dnnl::prop_kind::forward_training, and
  6957. /// #dnnl::prop_kind::forward_inference.
  6958. /// @param aalgorithm Elementwise algorithm kind.
  6959. /// @param src_desc Source memory descriptor.
  6960. /// @param dst_desc Destination memory descriptor.
  6961. /// @param alpha The alpha parameter for the elementwise operation.
  6962. /// Specific meaning depends on the algorithm.
  6963. /// @param beta The beta parameter for the elementwise operation.
  6964. /// Specific meaning depends on the algorithm.
  6965. /// @param attr Primitive attributes to use. Attributes are optional
  6966. /// and default to empty attributes.
  6967. /// @param allow_empty A flag signifying whether construction is
  6968. /// allowed to fail without throwing an exception. In this case an
  6969. /// empty object will be produced. This flag is optional and
  6970. /// defaults to false.
  6971. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  6972. algorithm aalgorithm, const memory::desc &src_desc,
  6973. const memory::desc &dst_desc, float alpha, float beta,
  6974. const primitive_attr &attr = default_attr(),
  6975. bool allow_empty = false)
  6976. : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
  6977. dst_desc, &alpha, &beta, attr, allow_empty) {}
  6978. /// Constructs a primitive descriptor for an eltwise forward
  6979. /// propagation primitive from a C API primitive descriptor that must
  6980. /// have a matching kind.
  6981. ///
  6982. /// @param pd C API primitive descriptor for an eltwise forward
  6983. /// propagation primitive.
  6984. primitive_desc(dnnl_primitive_desc_t pd)
  6985. : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
  6986. dnnl::prop_kind::forward_training,
  6987. dnnl::prop_kind::forward_inference) {}
  6988. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  6989. memory::desc src_desc() const { return base::src_desc(0); }
  6990. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  6991. memory::desc dst_desc() const { return base::dst_desc(0); }
  6992. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  6993. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  6994. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  6995. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  6996. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  6997. float get_alpha() const { return base::get_alpha(); }
  6998. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  6999. float get_beta() const { return base::get_beta(); }
  7000. private:
  7001. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7002. algorithm aalgorithm, const memory::desc &src_desc,
  7003. const memory::desc &dst_desc, const float *alpha,
  7004. const float *beta, const primitive_attr &attr,
  7005. bool allow_empty) {
  7006. dnnl_primitive_desc_t pd = nullptr;
  7007. dnnl_status_t status = dnnl_eltwise_forward_primitive_desc_create(
  7008. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7009. dnnl::convert_to_c(aalgorithm), src_desc.get(),
  7010. dst_desc.get(), alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
  7011. attr.get());
  7012. if (!allow_empty)
  7013. error::wrap_c_api(status,
  7014. "could not create a primitive descriptor for "
  7015. "the eltwise forward propagation primitive. Run "
  7016. "workload with environment variable ONEDNN_VERBOSE=all "
  7017. "to get additional diagnostic information.");
  7018. reset(pd);
  7019. }
  7020. };
  7021. /// Default constructor. Produces an empty object.
  7022. eltwise_forward() = default;
  7023. /// Constructs an eltwise forward propagation primitive.
  7024. /// @param pd Primitive descriptor for an eltwise forward propagation
  7025. /// primitive.
  7026. eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
  7027. /// Constructs an eltwise forward propagation primitive from a cache blob.
  7028. /// @param pd Primitive descriptor for an eltwise forward propagation
  7029. /// primitive.
  7030. /// @param cache_blob Cache blob.
  7031. eltwise_forward(
  7032. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7033. : primitive(pd, cache_blob) {}
  7034. };
  7035. /// Elementwise unary operation backward propagation primitive.
  7036. struct eltwise_backward : public primitive {
  7037. /// Primitive descriptor for eltwise backward propagation.
  7038. struct primitive_desc : public dnnl::primitive_desc {
  7039. /// Default constructor. Produces an empty object.
  7040. primitive_desc() = default;
  7041. /// Constructs a primitive descriptor for an elementwise backward
  7042. /// propagation primitive with an alpha parameter.
  7043. ///
  7044. /// @param aengine Engine to use.
  7045. /// @param aalgorithm Elementwise algorithm kind.
  7046. /// @param diff_src_desc Diff source memory descriptor.
  7047. /// @param diff_dst_desc Diff destination memory descriptor.
  7048. /// @param data_desc Destination memory descriptor if one of the
  7049. /// "use_dst_for_bwd" algorithms are used (such as
  7050. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
  7051. /// otherwise.
  7052. /// @param hint_fwd_pd Primitive descriptor for an elementwise
  7053. /// forward propagation primitive. It is used as a hint for
  7054. /// deciding which memory format to use.
  7055. /// @param attr Primitive attributes to use. Attributes are optional
  7056. /// and default to empty attributes.
  7057. /// @param allow_empty A flag signifying whether construction is
  7058. /// allowed to fail without throwing an exception. In this case an
  7059. /// empty object will be produced. This flag is optional and
  7060. /// defaults to false.
  7061. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7062. const memory::desc &diff_src_desc,
  7063. const memory::desc &diff_dst_desc,
  7064. const memory::desc &data_desc,
  7065. const eltwise_forward::primitive_desc &hint_fwd_pd,
  7066. const primitive_attr &attr = default_attr(),
  7067. bool allow_empty = false)
  7068. : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
  7069. data_desc, nullptr, nullptr, hint_fwd_pd, attr,
  7070. allow_empty) {}
  7071. /// Constructs a primitive descriptor for an elementwise backward
  7072. /// propagation primitive with an alpha parameter.
  7073. ///
  7074. /// @param aengine Engine to use.
  7075. /// @param aalgorithm Elementwise algorithm kind.
  7076. /// @param diff_src_desc Diff source memory descriptor.
  7077. /// @param diff_dst_desc Diff destination memory descriptor.
  7078. /// @param data_desc Destination memory descriptor if one of the
  7079. /// "use_dst_for_bwd" algorithms are used (such as
  7080. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
  7081. /// otherwise.
  7082. /// @param alpha The alpha parameter for the elementwise operation.
  7083. /// Specific meaning depends on the algorithm.
  7084. /// @param hint_fwd_pd Primitive descriptor for an elementwise
  7085. /// forward propagation primitive. It is used as a hint for
  7086. /// deciding which memory format to use.
  7087. /// @param attr Primitive attributes to use. Attributes are optional
  7088. /// and default to empty attributes.
  7089. /// @param allow_empty A flag signifying whether construction is
  7090. /// allowed to fail without throwing an exception. In this case an
  7091. /// empty object will be produced. This flag is optional and
  7092. /// defaults to false.
  7093. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7094. const memory::desc &diff_src_desc,
  7095. const memory::desc &diff_dst_desc,
  7096. const memory::desc &data_desc, float alpha,
  7097. const eltwise_forward::primitive_desc &hint_fwd_pd,
  7098. const primitive_attr &attr = default_attr(),
  7099. bool allow_empty = false)
  7100. : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
  7101. data_desc, &alpha, nullptr, hint_fwd_pd, attr,
  7102. allow_empty) {}
  7103. /// Constructs a primitive descriptor for an elementwise backward
  7104. /// propagation primitive with an alpha and beta parameters.
  7105. ///
  7106. /// @param aengine Engine to use.
  7107. /// @param aalgorithm Elementwise algorithm kind.
  7108. /// @param diff_src_desc Diff source memory descriptor.
  7109. /// @param diff_dst_desc Diff destination memory descriptor.
  7110. /// @param data_desc Destination memory descriptor if one of the
  7111. /// "use_dst_for_bwd" algorithms are used (such as
  7112. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
  7113. /// otherwise.
  7114. /// @param alpha The alpha parameter for the elementwise operation.
  7115. /// Specific meaning depends on the algorithm.
  7116. /// @param beta The beta parameter for the elementwise operation.
  7117. /// Specific meaning depends on the algorithm.
  7118. /// @param hint_fwd_pd Primitive descriptor for an elementwise
  7119. /// forward propagation primitive. It is used as a hint for
  7120. /// deciding which memory format to use.
  7121. /// @param attr Primitive attributes to use. Attributes are optional
  7122. /// and default to empty attributes.
  7123. /// @param allow_empty A flag signifying whether construction is
  7124. /// allowed to fail without throwing an exception. In this case an
  7125. /// empty object will be produced. This flag is optional and
  7126. /// defaults to false.
  7127. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7128. const memory::desc &diff_src_desc,
  7129. const memory::desc &diff_dst_desc,
  7130. const memory::desc &data_desc, float alpha, float beta,
  7131. const eltwise_forward::primitive_desc &hint_fwd_pd,
  7132. const primitive_attr &attr = default_attr(),
  7133. bool allow_empty = false)
  7134. : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
  7135. data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {}
  7136. /// Constructs a primitive descriptor for an eltwise backward
  7137. /// propagation primitive from a C API primitive descriptor that must
  7138. /// have a matching kind.
  7139. ///
  7140. /// @param pd C API primitive descriptor for an eltwise backward
  7141. /// propagation primitive.
  7142. primitive_desc(dnnl_primitive_desc_t pd)
  7143. : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
  7144. dnnl::prop_kind::backward_data) {}
  7145. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7146. memory::desc src_desc() const { return base::src_desc(0); }
  7147. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7148. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7149. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  7150. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7151. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  7152. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  7153. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7154. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7155. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  7156. float get_alpha() const { return base::get_alpha(); }
  7157. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  7158. float get_beta() const { return base::get_beta(); }
  7159. private:
  7160. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7161. const memory::desc &diff_src_desc,
  7162. const memory::desc &diff_dst_desc,
  7163. const memory::desc &data_desc, const float *alpha,
  7164. const float *beta,
  7165. const eltwise_forward::primitive_desc &hint_fwd_pd,
  7166. const primitive_attr &attr, bool allow_empty) {
  7167. dnnl_primitive_desc_t pd = nullptr;
  7168. dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create(
  7169. &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
  7170. diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(),
  7171. alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
  7172. hint_fwd_pd.get(), attr.get());
  7173. if (!allow_empty)
  7174. error::wrap_c_api(status,
  7175. "could not create a primitive descriptor for "
  7176. "the eltwise backward propagation primitive. Run "
  7177. "workload with environment variable ONEDNN_VERBOSE=all "
  7178. "to get additional diagnostic information.");
  7179. reset(pd);
  7180. }
  7181. };
  7182. /// Default constructor. Produces an empty object.
  7183. eltwise_backward() = default;
  7184. /// Constructs an eltwise backward propagation primitive.
  7185. /// @param pd Primitive descriptor for an eltwise backward propagation
  7186. /// primitive.
  7187. eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
  7188. /// Constructs an eltwise backward propagation primitive from a cache blob.
  7189. /// @param pd Primitive descriptor for an eltwise backward propagation
  7190. /// primitive.
  7191. /// @param cache_blob Cache blob.
  7192. eltwise_backward(
  7193. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7194. : primitive(pd, cache_blob) {}
  7195. };
  7196. /// @} dnnl_api_eltwise
  7197. /// @addtogroup dnnl_api_softmax Softmax
  7198. ///
  7199. /// A primitive to perform softmax.
  7200. ///
  7201. /// @sa @ref dev_guide_softmax in developer guide
  7202. ///
  7203. /// @{
  7204. /// Softmax forward propagation primitive.
  7205. struct softmax_forward : public primitive {
  7206. /// Primitive descriptor for a softmax forward propagation primitive.
  7207. struct primitive_desc : public dnnl::primitive_desc {
  7208. /// Default constructor. Produces an empty object.
  7209. primitive_desc() = default;
  7210. /// Constructs a primitive descriptor for a softmax forward propagation
  7211. /// primitive.
  7212. ///
  7213. /// @param aengine Engine to use.
  7214. /// @param aprop_kind Propagation kind. Possible values are
  7215. /// #dnnl::prop_kind::forward_training, and
  7216. /// #dnnl::prop_kind::forward_inference.
  7217. /// @param aalgorithm Softmax algorithm kind: either
  7218. /// #dnnl::algorithm::softmax_accurate,
  7219. /// or #dnnl::algorithm::softmax_log.
  7220. /// @param src_desc Source memory descriptor.
  7221. /// @param dst_desc Destination memory descriptor.
  7222. /// @param axis Axis over which softmax is computed.
  7223. /// @param attr Primitive attributes to use. Attributes are optional
  7224. /// and default to empty attributes.
  7225. /// @param allow_empty A flag signifying whether construction is
  7226. /// allowed to fail without throwing an exception. In this case an
  7227. /// empty object will be produced. This flag is optional and
  7228. /// defaults to false.
  7229. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7230. algorithm aalgorithm, const memory::desc &src_desc,
  7231. const memory::desc &dst_desc, int axis,
  7232. const primitive_attr &attr = default_attr(),
  7233. bool allow_empty = false) {
  7234. dnnl_primitive_desc_t pd = nullptr;
  7235. dnnl_status_t status = dnnl_softmax_forward_primitive_desc_create(
  7236. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7237. dnnl::convert_to_c(aalgorithm), src_desc.get(),
  7238. dst_desc.get(), axis, attr.get());
  7239. if (!allow_empty)
  7240. error::wrap_c_api(status,
  7241. "could not create a primitive descriptor for "
  7242. "the softmax forward propagation primitive. Run "
  7243. "workload with environment variable ONEDNN_VERBOSE=all "
  7244. "to get additional diagnostic information.");
  7245. reset(pd);
  7246. }
  7247. /// Constructs a primitive descriptor for a softmax forward
  7248. /// propagation primitive from a C API primitive descriptor that must
  7249. /// have a matching kind.
  7250. ///
  7251. /// @param pd C API primitive descriptor for a softmax forward
  7252. /// propagation primitive.
  7253. primitive_desc(dnnl_primitive_desc_t pd)
  7254. : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
  7255. dnnl::prop_kind::forward_training,
  7256. dnnl::prop_kind::forward_inference) {}
  7257. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7258. memory::desc src_desc() const { return base::src_desc(0); }
  7259. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7260. memory::desc dst_desc() const { return base::dst_desc(0); }
  7261. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  7262. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  7263. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7264. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7265. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  7266. int get_axis() const { return base::get_axis(); }
  7267. };
  7268. /// Default constructor. Produces an empty object.
  7269. softmax_forward() = default;
  7270. /// Constructs a softmax forward propagation primitive.
  7271. /// @param pd Primitive descriptor for a softmax forward propagation
  7272. /// primitive.
  7273. softmax_forward(const primitive_desc &pd) : primitive(pd) {}
  7274. /// Constructs a softmax forward propagation primitive from a cache blob.
  7275. /// @param pd Primitive descriptor for a softmax forward propagation
  7276. /// primitive.
  7277. /// @param cache_blob Cache blob.
  7278. softmax_forward(
  7279. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7280. : primitive(pd, cache_blob) {}
  7281. };
  7282. /// Softmax backward propagation primitive.
  7283. struct softmax_backward : public primitive {
  7284. /// Primitive descriptor for a softmax backward propagation primitive.
  7285. struct primitive_desc : public dnnl::primitive_desc {
  7286. /// Default constructor. Produces an empty object.
  7287. primitive_desc() = default;
  7288. /// Constructs a primitive descriptor for a softmax backward propagation
  7289. /// primitive.
  7290. ///
  7291. /// @param aengine Engine to use.
  7292. /// @param aalgorithm Softmax algorithm kind: either
  7293. /// #dnnl::algorithm::softmax_accurate,
  7294. /// or #dnnl::algorithm::softmax_log.
  7295. /// @param diff_src_desc Diff source memory descriptor.
  7296. /// @param diff_dst_desc Diff destination memory descriptor.
  7297. /// @param dst_desc Destination memory descriptor.
  7298. /// @param axis Axis over which softmax is computed.
  7299. /// @param hint_fwd_pd Primitive descriptor for a softmax
  7300. /// forward propagation primitive. It is used as a hint for
  7301. /// deciding which memory format to use.
  7302. /// @param attr Primitive attributes to use. Attributes are optional
  7303. /// and default to empty attributes.
  7304. /// @param allow_empty A flag signifying whether construction is
  7305. /// allowed to fail without throwing an exception. In this case an
  7306. /// empty object will be produced. This flag is optional and
  7307. /// defaults to false.
  7308. primitive_desc(const engine &aengine, algorithm aalgorithm,
  7309. const memory::desc &diff_src_desc,
  7310. const memory::desc &diff_dst_desc, const memory::desc &dst_desc,
  7311. int axis, const softmax_forward::primitive_desc &hint_fwd_pd,
  7312. const primitive_attr &attr = default_attr(),
  7313. bool allow_empty = false) {
  7314. dnnl_primitive_desc_t pd = nullptr;
  7315. dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create(
  7316. &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
  7317. diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(),
  7318. axis, hint_fwd_pd.get(), attr.get());
  7319. if (!allow_empty)
  7320. error::wrap_c_api(status,
  7321. "could not create a primitive descriptor for "
  7322. "the softmax backward propagation primitive. Run "
  7323. "workload with environment variable ONEDNN_VERBOSE=all "
  7324. "to get additional diagnostic information.");
  7325. reset(pd);
  7326. }
  7327. /// Constructs a primitive descriptor for a softmax backward
  7328. /// propagation primitive from a C API primitive descriptor that must
  7329. /// have a matching kind.
  7330. ///
  7331. /// @param pd C API primitive descriptor for a softmax backward
  7332. /// propagation primitive.
  7333. primitive_desc(dnnl_primitive_desc_t pd)
  7334. : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
  7335. dnnl::prop_kind::backward_data) {}
  7336. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7337. memory::desc dst_desc() const { return base::dst_desc(0); }
  7338. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7339. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7340. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7341. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7342. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  7343. dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
  7344. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7345. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7346. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  7347. int get_axis() const { return base::get_axis(); }
  7348. };
  7349. /// Default constructor. Produces an empty object.
  7350. softmax_backward() = default;
  7351. /// Constructs a softmax backward propagation primitive.
  7352. /// @param pd Primitive descriptor for a softmax backward propagation
  7353. /// primitive.
  7354. softmax_backward(const primitive_desc &pd) : primitive(pd) {}
  7355. /// Constructs a softmax backward propagation primitive from a cache blob.
  7356. /// @param pd Primitive descriptor for a softmax backward propagation
  7357. /// primitive.
  7358. /// @param cache_blob Cache blob.
  7359. softmax_backward(
  7360. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7361. : primitive(pd, cache_blob) {}
  7362. };
  7363. /// @} dnnl_api_softmax
  7364. /// @addtogroup dnnl_api_batch_normalization Batch Normalization
  7365. ///
  7366. /// A primitive to perform batch normalization.
  7367. ///
  7368. /// Both forward and backward propagation primitives support in-place
  7369. /// operation; that is, src and dst can refer to the same memory for forward
  7370. /// propagation, and diff_dst and diff_src can refer to the same memory for
  7371. /// backward propagation.
  7372. ///
  7373. /// The computations of the batch normalization primitives can be controlled
  7374. /// by specifying different @ref dnnl::normalization_flags values. For
  7375. /// example, batch normalization forward propagation can be configured to
  7376. /// either compute the mean and variance or take them as arguments. It can
  7377. /// optionally perform scaling and shifting using the gamma and beta
  7378. /// parameters. It can also perform a fused ReLU, which in the case of
  7379. /// training also requires a workspace.
  7380. ///
  7381. /// @sa @ref dev_guide_batch_normalization in developer guide
  7382. ///
  7383. /// @{
  7384. /// Batch normalization forward propagation primitive.
  7385. struct batch_normalization_forward : public primitive {
  7386. /// Primitive descriptor for a batch normalization forward propagation
  7387. /// primitive.
  7388. struct primitive_desc : public dnnl::primitive_desc {
  7389. /// Default constructor. Produces an empty object.
  7390. primitive_desc() = default;
  7391. /// Constructs a primitive descriptor for a batch normalization forward
  7392. /// propagation primitive.
  7393. ///
  7394. /// @note
  7395. /// In-place operation is supported: the dst can refer to the same
  7396. /// memory as the src.
  7397. ///
  7398. /// @param aengine Engine to use.
  7399. /// @param aprop_kind Propagation kind. Possible values are
  7400. /// #dnnl::prop_kind::forward_training and
  7401. /// #dnnl::prop_kind::forward_inference.
  7402. /// @param src_desc Source memory descriptor.
  7403. /// @param dst_desc Destination memory descriptor.
  7404. /// @param epsilon Batch normalization epsilon parameter.
  7405. /// @param flags Batch normalization flags (@ref
  7406. /// dnnl::normalization_flags).
  7407. /// @param attr Primitive attributes to use. Attributes are optional
  7408. /// and default to empty attributes.
  7409. /// @param allow_empty A flag signifying whether construction is
  7410. /// allowed to fail without throwing an exception. In this case an
  7411. /// empty object will be produced. This flag is optional and
  7412. /// defaults to false.
  7413. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7414. const memory::desc &src_desc, const memory::desc &dst_desc,
  7415. float epsilon, normalization_flags flags,
  7416. const primitive_attr &attr = default_attr(),
  7417. bool allow_empty = false) {
  7418. dnnl_primitive_desc_t pd = nullptr;
  7419. dnnl_status_t status
  7420. = dnnl_batch_normalization_forward_primitive_desc_create(
  7421. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7422. src_desc.get(), dst_desc.get(), epsilon,
  7423. convert_to_c(flags), attr.get());
  7424. if (!allow_empty)
  7425. error::wrap_c_api(status,
  7426. "could not create a primitive descriptor for "
  7427. "the batch normalization forward propagation "
  7428. "primitive. Run workload with environment variable "
  7429. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7430. "information.");
  7431. reset(pd);
  7432. }
  7433. /// Constructs a primitive descriptor for a batch normalization
  7434. /// forward propagation primitive from a C API primitive descriptor
  7435. /// that must have a matching kind.
  7436. ///
  7437. /// @param pd C API primitive descriptor for a batch normalization
  7438. /// forward propagation primitive.
  7439. primitive_desc(dnnl_primitive_desc_t pd)
  7440. : dnnl::primitive_desc(pd,
  7441. dnnl::primitive::kind::batch_normalization,
  7442. dnnl::prop_kind::forward_training,
  7443. dnnl::prop_kind::forward_inference) {}
  7444. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7445. memory::desc src_desc() const { return base::src_desc(0); }
  7446. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7447. memory::desc dst_desc() const { return base::dst_desc(0); }
  7448. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7449. memory::desc weights_desc() const { return base::weights_desc(0); }
  7450. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7451. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7452. /// Returns memory descriptor for mean.
  7453. /// @returns Memory descriptor for mean.
  7454. memory::desc mean_desc() const { return stat_desc(mean); }
  7455. /// Returns memory descriptor for variance.
  7456. /// @returns Memory descriptor for variance.
  7457. memory::desc variance_desc() const { return stat_desc(var); }
  7458. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7459. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7460. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7461. float get_epsilon() const { return base::get_epsilon(); }
  7462. /// Returns normalization flags.
  7463. /// @return Normalization flags.
  7464. normalization_flags get_flags() const {
  7465. return base::get_flags<normalization_flags>();
  7466. }
  7467. private:
  7468. enum {
  7469. mean = 1,
  7470. var = 2,
  7471. };
  7472. memory::desc stat_desc(int kind) const {
  7473. const bool use_global_stats
  7474. = (get_flags() & normalization_flags::use_global_stats)
  7475. != normalization_flags::none;
  7476. return query_md(
  7477. use_global_stats ? query::src_md : query::dst_md, kind);
  7478. }
  7479. };
  7480. /// Default constructor. Produces an empty object.
  7481. batch_normalization_forward() = default;
  7482. /// Constructs a batch normalization forward propagation primitive.
  7483. /// @param pd Primitive descriptor for a batch normalization forward
  7484. /// propagation primitive.
  7485. batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
  7486. /// Constructs a batch normalization forward propagation primitive from
  7487. /// a cache blob.
  7488. /// @param pd Primitive descriptor for a batch normalization forward
  7489. /// propagation primitive.
  7490. /// @param cache_blob Cache blob.
  7491. batch_normalization_forward(
  7492. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7493. : primitive(pd, cache_blob) {}
  7494. };
  7495. /// Batch normalization backward propagation primitive.
  7496. struct batch_normalization_backward : public primitive {
  7497. /// Primitive descriptor for a batch normalization backward propagation
  7498. /// primitive.
  7499. struct primitive_desc : public dnnl::primitive_desc {
  7500. /// Default constructor. Produces an empty object.
  7501. primitive_desc() = default;
  7502. /// Constructs a primitive descriptor for a batch normalization backward
  7503. /// propagation primitive.
  7504. ///
  7505. /// @param aengine Engine to use.
  7506. /// @param aprop_kind Propagation kind. Possible values are
  7507. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  7508. /// (diffs for all parameters are computed in this case).
  7509. /// @param diff_src_desc Diff source memory descriptor.
  7510. /// @param diff_dst_desc Diff destination memory descriptor.
  7511. /// @param src_desc Source memory descriptor.
  7512. /// @param epsilon Batch normalization epsilon parameter.
  7513. /// @param flags Batch normalization flags (@ref
  7514. /// dnnl::normalization_flags).
  7515. /// @param hint_fwd_pd Primitive descriptor for a batch normalization
  7516. /// forward propagation primitive. It is used as a hint for
  7517. /// deciding which memory format to use.
  7518. /// @param attr Primitive attributes to use. Attributes are optional
  7519. /// and default to empty attributes.
  7520. /// @param allow_empty A flag signifying whether construction is
  7521. /// allowed to fail without throwing an exception. In this case an
  7522. /// empty object will be produced. This flag is optional and
  7523. /// defaults to false.
  7524. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7525. const memory::desc &diff_src_desc,
  7526. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  7527. float epsilon, normalization_flags flags,
  7528. const batch_normalization_forward::primitive_desc &hint_fwd_pd,
  7529. const primitive_attr &attr = default_attr(),
  7530. bool allow_empty = false) {
  7531. dnnl_primitive_desc_t pd = nullptr;
  7532. dnnl_status_t status
  7533. = dnnl_batch_normalization_backward_primitive_desc_create(
  7534. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7535. diff_src_desc.get(), diff_dst_desc.get(),
  7536. src_desc.get(), epsilon, convert_to_c(flags),
  7537. hint_fwd_pd.get(), attr.get());
  7538. if (!allow_empty)
  7539. error::wrap_c_api(status,
  7540. "could not create a primitive descriptor for "
  7541. "the batch normalization backward propagation "
  7542. "primitive. Run workload with environment variable "
  7543. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7544. "information.");
  7545. reset(pd);
  7546. }
  7547. /// Constructs a primitive descriptor for a batch normalization
  7548. /// backward propagation primitive from a C API primitive descriptor
  7549. /// that must have a matching kind.
  7550. ///
  7551. /// @param pd C API primitive descriptor for a batch normalization
  7552. /// backward propagation primitive.
  7553. primitive_desc(dnnl_primitive_desc_t pd)
  7554. : dnnl::primitive_desc(pd,
  7555. dnnl::primitive::kind::batch_normalization,
  7556. dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
  7557. }
  7558. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7559. memory::desc src_desc() const { return base::src_desc(0); }
  7560. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7561. memory::desc weights_desc() const { return base::weights_desc(0); }
  7562. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7563. memory::desc dst_desc() const { return base::dst_desc(0); }
  7564. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7565. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7566. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  7567. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7568. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  7569. memory::desc diff_weights_desc() const {
  7570. return base::diff_weights_desc(0);
  7571. }
  7572. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
  7573. memory::desc mean_desc() const { return query_md(query::src_md, 1); }
  7574. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
  7575. memory::desc variance_desc() const {
  7576. return query_md(query::src_md, 2);
  7577. }
  7578. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7579. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7580. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7581. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7582. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7583. float get_epsilon() const { return base::get_epsilon(); }
  7584. /// Returns normalization flags.
  7585. /// @return Normalization flags.
  7586. normalization_flags get_flags() const {
  7587. return base::get_flags<normalization_flags>();
  7588. }
  7589. };
  7590. /// Default constructor. Produces an empty object.
  7591. batch_normalization_backward() = default;
  7592. /// Constructs a batch normalization backward propagation primitive.
  7593. /// @param pd Primitive descriptor for a batch normalization backward
  7594. /// propagation primitive.
  7595. batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
  7596. /// Constructs a batch normalization backward propagation primitive from
  7597. /// a cache blob.
  7598. /// @param pd Primitive descriptor for a batch normalization backward
  7599. /// propagation primitive.
  7600. /// @param cache_blob Cache blob.
  7601. batch_normalization_backward(
  7602. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7603. : primitive(pd, cache_blob) {}
  7604. };
  7605. /// @} dnnl_api_batch_normalization
  7606. /// @addtogroup dnnl_api_group_normalization Group Normalization
  7607. ///
  7608. /// A primitive to perform group normalization.
  7609. ///
  7610. /// Both forward and backward propagation primitives support in-place
  7611. /// operation; that is, src and dst can refer to the same memory for forward
  7612. /// propagation, and diff_dst and diff_src can refer to the same memory for
  7613. /// backward propagation.
  7614. ///
/// Computations of the group normalization primitives can be controlled by
  7616. /// specifying different @ref dnnl::normalization_flags values. For example,
  7617. /// group normalization forward propagation can be configured to either
  7618. /// compute the mean and variance or take them as arguments. It can either
  7619. /// perform scaling and shifting using gamma and beta parameters or not.
  7620. ///
  7621. /// @sa @ref dev_guide_group_normalization in developer guide
  7622. ///
  7623. /// @{
  7624. /// Group normalization forward propagation primitive.
  7625. struct group_normalization_forward : public primitive {
  7626. /// Primitive descriptor for a group normalization forward propagation
  7627. /// primitive.
  7628. struct primitive_desc : public dnnl::primitive_desc {
  7629. /// Default constructor. Produces an empty object.
  7630. primitive_desc() = default;
  7631. /// Constructs a primitive descriptor for a group normalization forward
  7632. /// propagation primitive.
  7633. ///
  7634. /// @note
  7635. /// In-place operation is supported: the dst can refer to the same
  7636. /// memory as the src.
  7637. ///
  7638. /// @param aengine Engine to use.
  7639. /// @param aprop_kind Propagation kind. Possible values are
  7640. /// #dnnl::prop_kind::forward_training and
  7641. /// #dnnl::prop_kind::forward_inference.
  7642. /// @param src_desc Source memory descriptor.
  7643. /// @param dst_desc Destination memory descriptor.
  7644. /// @param groups Group normalization groups parameter.
  7645. /// @param epsilon Group normalization epsilon parameter.
  7646. /// @param flags Group normalization flags (@ref
  7647. /// dnnl::normalization_flags).
  7648. /// @param attr Primitive attributes to use. Attributes are optional
  7649. /// and default to empty attributes.
  7650. /// @param allow_empty A flag signifying whether construction is
  7651. /// allowed to fail without throwing an exception. In this case an
  7652. /// empty object will be produced. This flag is optional and
  7653. /// defaults to false.
  7654. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7655. const memory::desc &src_desc, const memory::desc &dst_desc,
  7656. memory::dim groups, float epsilon, normalization_flags flags,
  7657. const primitive_attr &attr = default_attr(),
  7658. bool allow_empty = false) {
  7659. dnnl_primitive_desc_t pd = nullptr;
  7660. dnnl_status_t status
  7661. = dnnl_group_normalization_forward_primitive_desc_create(
  7662. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7663. src_desc.get(), dst_desc.get(), groups, epsilon,
  7664. convert_to_c(flags), attr.get());
  7665. if (!allow_empty)
  7666. error::wrap_c_api(status,
  7667. "could not create a primitive descriptor for "
  7668. "the group normalization forward propagation "
  7669. "primitive. Run workload with environment variable "
  7670. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7671. "information.");
  7672. reset(pd);
  7673. }
  7674. /// Constructs a primitive descriptor for a group normalization
  7675. /// forward propagation primitive from a C API primitive descriptor
  7676. /// that must have a matching kind.
  7677. ///
  7678. /// @param pd C API primitive descriptor for a group normalization
  7679. /// forward propagation primitive.
  7680. primitive_desc(dnnl_primitive_desc_t pd)
  7681. : dnnl::primitive_desc(pd,
  7682. dnnl::primitive::kind::group_normalization,
  7683. dnnl::prop_kind::forward_training,
  7684. dnnl::prop_kind::forward_inference) {}
  7685. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7686. memory::desc src_desc() const { return base::src_desc(0); }
  7687. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7688. memory::desc dst_desc() const { return base::dst_desc(0); }
  7689. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7690. memory::desc weights_desc() const { return base::weights_desc(0); }
  7691. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7692. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7693. /// Returns memory descriptor for mean.
  7694. /// @returns Memory descriptor for mean.
  7695. memory::desc mean_desc() const { return stat_desc(mean); }
  7696. /// Returns memory descriptor for variance.
  7697. /// @returns Memory descriptor for variance.
  7698. memory::desc variance_desc() const { return stat_desc(var); }
  7699. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7700. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7701. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  7702. memory::dim get_group_size() const { return base::get_group_size(); }
  7703. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7704. float get_epsilon() const { return base::get_epsilon(); }
  7705. /// Returns normalization flags.
  7706. /// @return Normalization flags.
  7707. normalization_flags get_flags() const {
  7708. return base::get_flags<normalization_flags>();
  7709. }
  7710. private:
  7711. enum {
  7712. mean = 1,
  7713. var = 2,
  7714. };
  7715. memory::desc stat_desc(int kind) const {
  7716. const bool use_global_stats
  7717. = (get_flags() & normalization_flags::use_global_stats)
  7718. != normalization_flags::none;
  7719. return query_md(
  7720. use_global_stats ? query::src_md : query::dst_md, kind);
  7721. }
  7722. };
  7723. /// Default constructor. Produces an empty object.
  7724. group_normalization_forward() = default;
  7725. /// Constructs a group normalization forward propagation primitive.
  7726. /// @param pd Primitive descriptor for a group normalization forward
  7727. /// propagation primitive.
  7728. group_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
  7729. /// Constructs a group normalization forward propagation primitive from
  7730. /// a cache blob.
  7731. /// @param pd Primitive descriptor for a group normalization forward
  7732. /// propagation primitive.
  7733. /// @param cache_blob Cache blob.
  7734. group_normalization_forward(
  7735. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7736. : primitive(pd, cache_blob) {}
  7737. };
  7738. /// Group normalization backward propagation primitive.
  7739. struct group_normalization_backward : public primitive {
  7740. /// Primitive descriptor for a group normalization backward propagation
  7741. /// primitive.
  7742. struct primitive_desc : public dnnl::primitive_desc {
  7743. /// Default constructor. Produces an empty object.
  7744. primitive_desc() = default;
  7745. /// Constructs a primitive descriptor for a group normalization backward
  7746. /// propagation primitive.
  7747. ///
  7748. /// @param aengine Engine to use.
  7749. /// @param aprop_kind Propagation kind. Possible values are
  7750. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  7751. /// (diffs for all parameters are computed in this case).
  7752. /// @param diff_src_desc Diff source memory descriptor.
  7753. /// @param diff_dst_desc Diff destination memory descriptor.
  7754. /// @param src_desc Source memory descriptor.
  7755. /// @param groups Group normalization groups parameter.
  7756. /// @param epsilon Group normalization epsilon parameter.
  7757. /// @param flags Group normalization flags (@ref
  7758. /// dnnl::normalization_flags).
  7759. /// @param hint_fwd_pd Primitive descriptor for a group normalization
  7760. /// forward propagation primitive. It is used as a hint for
  7761. /// deciding which memory format to use.
  7762. /// @param attr Primitive attributes to use. Attributes are optional
  7763. /// and default to empty attributes.
  7764. /// @param allow_empty A flag signifying whether construction is
  7765. /// allowed to fail without throwing an exception. In this case an
  7766. /// empty object will be produced. This flag is optional and
  7767. /// defaults to false.
  7768. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7769. const memory::desc &diff_src_desc,
  7770. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  7771. memory::dim groups, float epsilon, normalization_flags flags,
  7772. const group_normalization_forward::primitive_desc &hint_fwd_pd,
  7773. const primitive_attr &attr = default_attr(),
  7774. bool allow_empty = false) {
  7775. dnnl_primitive_desc_t pd = nullptr;
  7776. dnnl_status_t status
  7777. = dnnl_group_normalization_backward_primitive_desc_create(
  7778. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  7779. diff_src_desc.get(), diff_dst_desc.get(),
  7780. src_desc.get(), groups, epsilon,
  7781. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  7782. if (!allow_empty)
  7783. error::wrap_c_api(status,
  7784. "could not create a primitive descriptor for "
  7785. "the group normalization backward propagation "
  7786. "primitive. Run workload with environment variable "
  7787. "ONEDNN_VERBOSE=all to get additional diagnostic "
  7788. "information.");
  7789. reset(pd);
  7790. }
  7791. /// Constructs a primitive descriptor for a group normalization
  7792. /// backward propagation primitive from a C API primitive descriptor
  7793. /// that must have a matching kind.
  7794. ///
  7795. /// @param pd C API primitive descriptor for a group normalization
  7796. /// backward propagation primitive.
  7797. primitive_desc(dnnl_primitive_desc_t pd)
  7798. : dnnl::primitive_desc(pd,
  7799. dnnl::primitive::kind::group_normalization,
  7800. dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
  7801. }
  7802. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  7803. memory::desc src_desc() const { return base::src_desc(0); }
  7804. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  7805. memory::desc weights_desc() const { return base::weights_desc(0); }
  7806. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  7807. memory::desc dst_desc() const { return base::dst_desc(0); }
  7808. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  7809. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  7810. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  7811. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  7812. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  7813. memory::desc diff_weights_desc() const {
  7814. return base::diff_weights_desc(0);
  7815. }
  7816. /// @copydoc dnnl::group_normalization_forward::primitive_desc::mean_desc()const
  7817. memory::desc mean_desc() const { return query_md(query::src_md, 1); }
  7818. /// @copydoc dnnl::group_normalization_forward::primitive_desc::variance_desc()const
  7819. memory::desc variance_desc() const {
  7820. return query_md(query::src_md, 2);
  7821. }
  7822. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  7823. memory::desc workspace_desc() const { return base::workspace_desc(); }
  7824. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  7825. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  7826. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  7827. memory::dim get_group_size() const { return base::get_group_size(); }
  7828. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  7829. float get_epsilon() const { return base::get_epsilon(); }
  7830. /// Returns normalization flags.
  7831. /// @return Normalization flags.
  7832. normalization_flags get_flags() const {
  7833. return base::get_flags<normalization_flags>();
  7834. }
  7835. };
  7836. /// Default constructor. Produces an empty object.
  7837. group_normalization_backward() = default;
  7838. /// Constructs a group normalization backward propagation primitive.
  7839. /// @param pd Primitive descriptor for a group normalization backward
  7840. /// propagation primitive.
  7841. group_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
  7842. /// Constructs a group normalization backward propagation primitive from
  7843. /// a cache blob.
  7844. /// @param pd Primitive descriptor for a group normalization backward
  7845. /// propagation primitive.
  7846. /// @param cache_blob Cache blob.
  7847. group_normalization_backward(
  7848. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  7849. : primitive(pd, cache_blob) {}
  7850. };
  7851. /// @} dnnl_api_group_normalization
  7852. /// @addtogroup dnnl_api_layer_normalization Layer Normalization
  7853. ///
  7854. /// A primitive to perform layer normalization. Normalization is performed
/// within the last logical dimension of the data tensor.
  7856. ///
  7857. /// Both forward and backward propagation primitives support in-place
  7858. /// operation; that is, src and dst can refer to the same memory for forward
  7859. /// propagation, and diff_dst and diff_src can refer to the same memory for
  7860. /// backward propagation.
  7861. ///
/// Computations of the layer normalization primitives can be controlled by
  7863. /// specifying different @ref dnnl::normalization_flags values. For example,
  7864. /// layer normalization forward propagation can be configured to either
  7865. /// compute the mean and variance or take them as arguments. It can either
  7866. /// perform scaling and shifting using gamma and beta parameters or not.
  7867. ///
  7868. /// @sa @ref dev_guide_layer_normalization in developer guide
  7869. ///
  7870. /// @{
  7871. /// Layer normalization forward propagation primitive.
  7872. struct layer_normalization_forward : public primitive {
  7873. /// Primitive descriptor for a layer normalization forward propagation
  7874. /// primitive.
  7875. struct primitive_desc : public dnnl::primitive_desc {
  7876. /// Default constructor. Produces an empty object.
  7877. primitive_desc() = default;
  7878. /// Constructs a primitive descriptor for a layer normalization forward
  7879. /// propagation primitive.
  7880. ///
  7881. /// @param aengine Engine to use.
  7882. /// @param aprop_kind Propagation kind. Possible values are
  7883. /// #dnnl::prop_kind::forward_training, and
  7884. /// #dnnl::prop_kind::forward_inference.
  7885. /// @param src_desc Source memory descriptor.
  7886. /// @param dst_desc Destination memory descriptor.
  7887. /// @param stat_desc Statistics memory descriptors.
  7888. /// @param epsilon Layer normalization epsilon parameter.
  7889. /// @param flags Layer normalization flags (@ref
  7890. /// dnnl::normalization_flags).
  7891. /// @param attr Primitive attributes to use. Attributes are optional
  7892. /// and default to empty attributes.
  7893. /// @param allow_empty A flag signifying whether construction is
  7894. /// allowed to fail without throwing an exception. In this case an
  7895. /// empty object will be produced. This flag is optional and
  7896. /// defaults to false.
  7897. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7898. const memory::desc &src_desc, const memory::desc &dst_desc,
  7899. const memory::desc &stat_desc, float epsilon,
  7900. normalization_flags flags,
  7901. const primitive_attr &attr = default_attr(),
  7902. bool allow_empty = false)
  7903. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
  7904. &stat_desc, memory::data_type::f32, epsilon, flags, attr,
  7905. allow_empty) {}
  7906. /// Constructs a primitive descriptor for a layer normalization forward
  7907. /// propagation primitive.
  7908. ///
  7909. /// @param aengine Engine to use.
  7910. /// @param aprop_kind Propagation kind. Possible values are
  7911. /// #dnnl::prop_kind::forward_training, and
  7912. /// #dnnl::prop_kind::forward_inference.
  7913. /// @param src_desc Source memory descriptor.
  7914. /// @param dst_desc Destination memory descriptor.
  7915. /// @param epsilon Layer normalization epsilon parameter.
  7916. /// @param flags Layer normalization flags (@ref
  7917. /// dnnl::normalization_flags).
  7918. /// @param attr Primitive attributes to use. Attributes are optional
  7919. /// and default to empty attributes.
  7920. /// @param allow_empty A flag signifying whether construction is
  7921. /// allowed to fail without throwing an exception. In this case an
  7922. /// empty object will be produced. This flag is optional and
  7923. /// defaults to false.
  7924. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7925. const memory::desc &src_desc, const memory::desc &dst_desc,
  7926. float epsilon, normalization_flags flags,
  7927. const primitive_attr &attr = default_attr(),
  7928. bool allow_empty = false)
  7929. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
  7930. memory::data_type::f32, epsilon, flags, attr, allow_empty) {
  7931. }
  7932. /// Constructs a primitive descriptor for a layer normalization forward
  7933. /// propagation primitive with a user-provided data type for the scale
  7934. /// and shift memory objects.
  7935. ///
  7936. /// @param aengine Engine to use.
  7937. /// @param aprop_kind Propagation kind. Possible values are
  7938. /// #dnnl::prop_kind::forward_training, and
  7939. /// #dnnl::prop_kind::forward_inference.
  7940. /// @param src_desc Source memory descriptor.
  7941. /// @param dst_desc Destination memory descriptor.
  7942. /// @param stat_desc Statistics memory descriptors.
  7943. /// @param scale_shift_data_type Data type of scale and shift memory.
  7944. /// If neither scale nor shift flag are specified the parameter
  7945. /// is ignored.
  7946. /// @param epsilon Layer normalization epsilon parameter.
  7947. /// @param flags Layer normalization flags (@ref
  7948. /// dnnl::normalization_flags).
  7949. /// @param attr Primitive attributes to use. Attributes are optional
  7950. /// and default to empty attributes.
  7951. /// @param allow_empty A flag signifying whether construction is
  7952. /// allowed to fail without throwing an exception. In this case an
  7953. /// empty object will be produced. This flag is optional and
  7954. /// defaults to false.
  7955. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7956. const memory::desc &src_desc, const memory::desc &dst_desc,
  7957. const memory::desc &stat_desc,
  7958. memory::data_type scale_shift_data_type, float epsilon,
  7959. normalization_flags flags,
  7960. const primitive_attr &attr = default_attr(),
  7961. bool allow_empty = false)
  7962. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
  7963. &stat_desc, scale_shift_data_type, epsilon, flags, attr,
  7964. allow_empty) {}
  7965. /// Constructs a primitive descriptor for a layer normalization forward
  7966. /// propagation primitive with a user-provided data type for the scale
  7967. /// and shift memory objects.
  7968. ///
  7969. /// @param aengine Engine to use.
  7970. /// @param aprop_kind Propagation kind. Possible values are
  7971. /// #dnnl::prop_kind::forward_training, and
  7972. /// #dnnl::prop_kind::forward_inference.
  7973. /// @param src_desc Source memory descriptor.
  7974. /// @param dst_desc Destination memory descriptor.
  7975. /// @param scale_shift_data_type Data type of scale and shift memory.
  7976. /// If neither scale nor shift flag are specified the parameter
  7977. /// is ignored.
  7978. /// @param epsilon Layer normalization epsilon parameter.
  7979. /// @param flags Layer normalization flags (@ref
  7980. /// dnnl::normalization_flags).
  7981. /// @param attr Primitive attributes to use. Attributes are optional
  7982. /// and default to empty attributes.
  7983. /// @param allow_empty A flag signifying whether construction is
  7984. /// allowed to fail without throwing an exception. In this case an
  7985. /// empty object will be produced. This flag is optional and
  7986. /// defaults to false.
  7987. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  7988. const memory::desc &src_desc, const memory::desc &dst_desc,
  7989. memory::data_type scale_shift_data_type, float epsilon,
  7990. normalization_flags flags,
  7991. const primitive_attr &attr = default_attr(),
  7992. bool allow_empty = false)
  7993. : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
  7994. scale_shift_data_type, epsilon, flags, attr, allow_empty) {}
  7995. /// Constructs a primitive descriptor for a layer normalization
  7996. /// forward propagation primitive from a C API primitive descriptor
  7997. /// that must have a matching kind.
  7998. ///
  7999. /// @param pd C API primitive descriptor for a layer normalization
  8000. /// forward propagation primitive.
  8001. primitive_desc(dnnl_primitive_desc_t pd)
  8002. : dnnl::primitive_desc(pd,
  8003. dnnl::primitive::kind::layer_normalization,
  8004. dnnl::prop_kind::forward_training,
  8005. dnnl::prop_kind::forward_inference) {}
  8006. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8007. memory::desc src_desc() const { return base::src_desc(0); }
  8008. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  8009. memory::desc dst_desc() const { return base::dst_desc(0); }
  8010. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8011. memory::desc weights_desc() const { return base::weights_desc(0); }
  8012. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  8013. memory::desc workspace_desc() const { return base::workspace_desc(); }
  8014. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
  8015. memory::desc mean_desc() const { return stat_desc(mean); }
  8016. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
  8017. memory::desc variance_desc() const { return stat_desc(var); }
  8018. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8019. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8020. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  8021. float get_epsilon() const { return base::get_epsilon(); }
  8022. /// Returns normalization flags.
  8023. /// @return Normalization flags.
  8024. normalization_flags get_flags() const {
  8025. return base::get_flags<normalization_flags>();
  8026. }
  8027. private:
  8028. enum {
  8029. mean = 1,
  8030. var = 2,
  8031. };
  8032. memory::desc stat_desc(int kind) const {
  8033. const bool use_global_stats
  8034. = (get_flags() & normalization_flags::use_global_stats)
  8035. != normalization_flags::none;
  8036. return query_md(
  8037. use_global_stats ? query::src_md : query::dst_md, kind);
  8038. }
  8039. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8040. const memory::desc &src_desc, const memory::desc &dst_desc,
  8041. const memory::desc *stat_desc,
  8042. memory::data_type scale_shift_data_type, float epsilon,
  8043. normalization_flags flags, const primitive_attr &attr,
  8044. bool allow_empty) {
  8045. dnnl_primitive_desc_t pd = nullptr;
  8046. dnnl_status_t status
  8047. = dnnl_layer_normalization_forward_primitive_desc_create_v2(
  8048. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  8049. src_desc.get(), dst_desc.get(),
  8050. optional_arg(stat_desc),
  8051. memory::convert_to_c(scale_shift_data_type),
  8052. epsilon, convert_to_c(flags), attr.get());
  8053. if (!allow_empty)
  8054. error::wrap_c_api(status,
  8055. "could not create a primitive descriptor for "
  8056. "the layer normalization forward propagation "
  8057. "primitive. Run workload with environment variable "
  8058. "ONEDNN_VERBOSE=all to get additional diagnostic "
  8059. "information.");
  8060. reset(pd);
  8061. }
  8062. };
  8063. /// Default constructor. Produces an empty object.
  8064. layer_normalization_forward() = default;
  8065. /// Constructs a layer normalization forward propagation primitive.
  8066. /// @param pd Primitive descriptor for a layer normalization forward
  8067. /// propagation primitive.
  8068. layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
  8069. /// Constructs a layer normalization forward propagation primitive from
  8070. /// a cache blob.
  8071. /// @param pd Primitive descriptor for a layer normalization forward
  8072. /// propagation primitive.
  8073. /// @param cache_blob Cache blob.
  8074. layer_normalization_forward(
  8075. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8076. : primitive(pd, cache_blob) {}
  8077. };
  8078. /// Layer normalization backward propagation primitive.
  8079. struct layer_normalization_backward : public primitive {
  8080. /// Primitive descriptor for a layer normalization backward propagation
  8081. /// primitive.
  8082. struct primitive_desc : public dnnl::primitive_desc {
  8083. /// Default constructor. Produces an empty object.
  8084. primitive_desc() = default;
  8085. /// Constructs a primitive descriptor for a layer normalization backward
  8086. /// propagation primitive.
  8087. ///
  8088. /// @param aengine Engine to use.
  8089. /// @param aprop_kind Propagation kind. Possible values are
  8090. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  8091. /// (diffs for all parameters are computed in this case).
  8092. /// @param diff_src_desc Diff source memory descriptor.
  8093. /// @param diff_dst_desc Diff destination memory descriptor.
  8094. /// @param src_desc Source memory descriptor.
  8095. /// @param stat_desc Statistics memory descriptors.
  8096. /// @param epsilon Layer normalization epsilon parameter.
  8097. /// @param flags Layer normalization flags (@ref
  8098. /// dnnl::normalization_flags).
  8099. /// @param attr Primitive attributes to use. Attributes are optional
  8100. /// and default to empty attributes.
  8101. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8102. /// forward propagation primitive. It is used as a hint for
  8103. /// deciding which memory format to use.
  8104. /// @param allow_empty A flag signifying whether construction is
  8105. /// allowed to fail without throwing an exception. In this case an
  8106. /// empty object will be produced. This flag is optional and
  8107. /// defaults to false.
  8108. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8109. const memory::desc &diff_src_desc,
  8110. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8111. const memory::desc &stat_desc, float epsilon,
  8112. normalization_flags flags,
  8113. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8114. const primitive_attr &attr = default_attr(),
  8115. bool allow_empty = false)
  8116. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8117. src_desc, &stat_desc, memory::data_type::f32,
  8118. memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
  8119. allow_empty) {}
  8120. /// Constructs a primitive descriptor for a layer normalization backward
  8121. /// propagation primitive.
  8122. ///
  8123. /// @param aengine Engine to use.
  8124. /// @param aprop_kind Propagation kind. Possible values are
  8125. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  8126. /// (diffs for all parameters are computed in this case).
  8127. /// @param diff_src_desc Diff source memory descriptor.
  8128. /// @param diff_dst_desc Diff destination memory descriptor.
  8129. /// @param src_desc Source memory descriptor.
  8130. /// @param epsilon Layer normalization epsilon parameter.
  8131. /// @param flags Layer normalization flags (@ref
  8132. /// dnnl::normalization_flags).
  8133. /// @param attr Primitive attributes to use. Attributes are optional
  8134. /// and default to empty attributes.
  8135. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8136. /// forward propagation primitive. It is used as a hint for
  8137. /// deciding which memory format to use.
  8138. /// @param allow_empty A flag signifying whether construction is
  8139. /// allowed to fail without throwing an exception. In this case an
  8140. /// empty object will be produced. This flag is optional and
  8141. /// defaults to false.
  8142. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8143. const memory::desc &diff_src_desc,
  8144. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8145. float epsilon, normalization_flags flags,
  8146. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8147. const primitive_attr &attr = default_attr(),
  8148. bool allow_empty = false)
  8149. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8150. src_desc, nullptr, memory::data_type::f32,
  8151. memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
  8152. allow_empty) {}
  8153. /// Constructs a primitive descriptor for a layer normalization backward
  8154. /// propagation primitive with a user-provided data type for the scale
  8155. /// and shift memory objects.
  8156. ///
  8157. /// @param aengine Engine to use.
  8158. /// @param aprop_kind Propagation kind. Possible values are
  8159. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  8160. /// (diffs for all parameters are computed in this case).
  8161. /// @param diff_src_desc Diff source memory descriptor.
  8162. /// @param diff_dst_desc Diff destination memory descriptor.
  8163. /// @param src_desc Source memory descriptor.
  8164. /// @param stat_desc Statistics memory descriptors.
  8165. /// @param diff_scale_shift_data_type Data type of diff scale and shift
  8166. /// memory. If neither scale nor shift flag are specified the
  8167. /// parameter is ignored.
  8168. /// @param scale_shift_data_type Data type of scale and shift memory.
  8169. /// If neither scale nor shift flag are specified the parameter
  8170. /// is ignored.
  8171. /// @param epsilon Layer normalization epsilon parameter.
  8172. /// @param flags Layer normalization flags (@ref
  8173. /// dnnl::normalization_flags).
  8174. /// @param attr Primitive attributes to use. Attributes are optional
  8175. /// and default to empty attributes.
  8176. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8177. /// forward propagation primitive. It is used as a hint for
  8178. /// deciding which memory format to use.
  8179. /// @param allow_empty A flag signifying whether construction is
  8180. /// allowed to fail without throwing an exception. In this case an
  8181. /// empty object will be produced. This flag is optional and
  8182. /// defaults to false.
  8183. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8184. const memory::desc &diff_src_desc,
  8185. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8186. const memory::desc &stat_desc,
  8187. memory::data_type diff_scale_shift_data_type,
  8188. memory::data_type scale_shift_data_type, float epsilon,
  8189. normalization_flags flags,
  8190. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8191. const primitive_attr &attr = default_attr(),
  8192. bool allow_empty = false)
  8193. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8194. src_desc, &stat_desc, diff_scale_shift_data_type,
  8195. scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
  8196. allow_empty) {}
  8197. /// Constructs a primitive descriptor for a layer normalization backward
  8198. /// propagation primitive with a user-provided data type for the scale
  8199. /// and shift memory objects.
  8200. ///
  8201. /// @param aengine Engine to use.
  8202. /// @param aprop_kind Propagation kind. Possible values are
  8203. /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
  8204. /// (diffs for all parameters are computed in this case).
  8205. /// @param diff_src_desc Diff source memory descriptor.
  8206. /// @param diff_dst_desc Diff destination memory descriptor.
  8207. /// @param src_desc Source memory descriptor.
  8208. /// @param diff_scale_shift_data_type Data type of diff scale and shift
  8209. /// memory. If neither scale nor shift flag are specified the
  8210. /// parameter is ignored.
  8211. /// @param scale_shift_data_type Data type of scale and shift memory.
  8212. /// If neither scale nor shift flag are specified the parameter
  8213. /// is ignored.
  8214. /// @param epsilon Layer normalization epsilon parameter.
  8215. /// @param flags Layer normalization flags (@ref
  8216. /// dnnl::normalization_flags).
  8217. /// @param attr Primitive attributes to use. Attributes are optional
  8218. /// and default to empty attributes.
  8219. /// @param hint_fwd_pd Primitive descriptor for a layer normalization
  8220. /// forward propagation primitive. It is used as a hint for
  8221. /// deciding which memory format to use.
  8222. /// @param allow_empty A flag signifying whether construction is
  8223. /// allowed to fail without throwing an exception. In this case an
  8224. /// empty object will be produced. This flag is optional and
  8225. /// defaults to false.
  8226. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8227. const memory::desc &diff_src_desc,
  8228. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8229. memory::data_type diff_scale_shift_data_type,
  8230. memory::data_type scale_shift_data_type, float epsilon,
  8231. normalization_flags flags,
  8232. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8233. const primitive_attr &attr = default_attr(),
  8234. bool allow_empty = false)
  8235. : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
  8236. src_desc, nullptr, diff_scale_shift_data_type,
  8237. scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
  8238. allow_empty) {}
  8239. /// Constructs a primitive descriptor for a layer normalization
  8240. /// backward propagation primitive from a C API primitive descriptor
  8241. /// that must have a matching kind.
  8242. ///
  8243. /// @param pd C API primitive descriptor for a layer normalization
  8244. /// backward propagation primitive.
  8245. primitive_desc(dnnl_primitive_desc_t pd)
  8246. : dnnl::primitive_desc(pd,
  8247. dnnl::primitive::kind::layer_normalization,
  8248. dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
  8249. }
  8250. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8251. memory::desc src_desc() const { return base::src_desc(0); }
  8252. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8253. memory::desc weights_desc() const { return base::weights_desc(0); }
  8254. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  8255. memory::desc dst_desc() const { return base::dst_desc(0); }
  8256. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  8257. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  8258. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  8259. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  8260. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  8261. memory::desc diff_weights_desc() const {
  8262. return base::diff_weights_desc(0);
  8263. }
  8264. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
  8265. memory::desc mean_desc() const { return query_md(query::src_md, 1); }
  8266. /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
  8267. memory::desc variance_desc() const {
  8268. return query_md(query::src_md, 2);
  8269. }
  8270. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  8271. memory::desc workspace_desc() const { return base::workspace_desc(); }
  8272. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8273. dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8274. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  8275. float get_epsilon() const { return base::get_epsilon(); }
  8276. /// Returns normalization flags.
  8277. /// @return Normalization flags.
  8278. normalization_flags get_flags() const {
  8279. return base::get_flags<normalization_flags>();
  8280. }
  8281. private:
  8282. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8283. const memory::desc &diff_src_desc,
  8284. const memory::desc &diff_dst_desc, const memory::desc &src_desc,
  8285. const memory::desc *stat_desc,
  8286. memory::data_type diff_scale_shift_data_type,
  8287. memory::data_type scale_shift_data_type, float epsilon,
  8288. normalization_flags flags,
  8289. const layer_normalization_forward::primitive_desc &hint_fwd_pd,
  8290. const primitive_attr &attr, bool allow_empty) {
  8291. dnnl_primitive_desc_t pd = nullptr;
  8292. dnnl_status_t status
  8293. = dnnl_layer_normalization_backward_primitive_desc_create_v2(
  8294. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  8295. diff_src_desc.get(), diff_dst_desc.get(),
  8296. src_desc.get(), optional_arg(stat_desc),
  8297. memory::convert_to_c(diff_scale_shift_data_type),
  8298. memory::convert_to_c(scale_shift_data_type),
  8299. epsilon, convert_to_c(flags), hint_fwd_pd.get(),
  8300. attr.get());
  8301. if (!allow_empty)
  8302. error::wrap_c_api(status,
  8303. "could not create a primitive descriptor for "
  8304. "the layer normalization backward propagation "
  8305. "primitive. Run workload with environment variable "
  8306. "ONEDNN_VERBOSE=all to get additional diagnostic "
  8307. "information.");
  8308. reset(pd);
  8309. }
  8310. };
  8311. /// Default constructor. Produces an empty object.
  8312. layer_normalization_backward() = default;
  8313. /// Constructs a layer normalization backward propagation primitive.
  8314. /// @param pd Primitive descriptor for a layer normalization backward
  8315. /// propagation primitive.
  8316. layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
  8317. /// Constructs a layer normalization backward propagation primitive from
  8318. /// a cache blob.
  8319. /// @param pd Primitive descriptor for a layer normalization backward
  8320. /// propagation primitive.
  8321. /// @param cache_blob Cache blob.
  8322. layer_normalization_backward(
  8323. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8324. : primitive(pd, cache_blob) {}
  8325. };
  8326. /// @} dnnl_api_layer_normalization
  8327. /// @addtogroup dnnl_api_inner_product Inner Product
  8328. ///
  8329. /// A primitive to compute an inner product.
  8330. ///
  8331. /// @sa @ref dev_guide_inner_product in developer guide
  8332. ///
  8333. /// @{
  8334. /// Inner product forward propagation primitive.
  8335. struct inner_product_forward : public primitive {
  8336. /// Primitive descriptor for an inner product forward propagation primitive.
  8337. struct primitive_desc : public dnnl::primitive_desc {
  8338. /// Default constructor. Produces an empty object.
  8339. primitive_desc() = default;
  8340. /// Constructs a primitive descriptor for an inner product forward
  8341. /// propagation primitive with bias.
  8342. ///
  8343. /// @note
  8344. /// All the memory descriptors may be initialized with the
  8345. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8346. ///
  8347. /// @param aengine Engine to use.
  8348. /// @param aprop_kind Propagation kind. Possible values are
  8349. /// #dnnl::prop_kind::forward_training, and
  8350. /// #dnnl::prop_kind::forward_inference.
  8351. /// @param src_desc Memory descriptor for src.
  8352. /// @param weights_desc Memory descriptor for weights.
  8353. /// @param bias_desc Memory descriptor for bias.
  8354. /// @param dst_desc Memory descriptor for dst.
  8355. /// @param attr Primitive attributes to use. Attributes are optional
  8356. /// and default to empty attributes.
  8357. /// @param allow_empty A flag signifying whether construction is
  8358. /// allowed to fail without throwing an exception. In this case an
  8359. /// empty object will be produced. This flag is optional and
  8360. /// defaults to false.
  8361. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8362. const memory::desc &src_desc, const memory::desc &weights_desc,
  8363. const memory::desc &bias_desc, const memory::desc &dst_desc,
  8364. const primitive_attr &attr = default_attr(),
  8365. bool allow_empty = false)
  8366. : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
  8367. &bias_desc, dst_desc, attr, allow_empty) {}
  8368. /// Constructs a primitive descriptor for an inner product forward
  8369. /// propagation primitive.
  8370. ///
  8371. /// @note
  8372. /// All the memory descriptors may be initialized with the
  8373. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8374. ///
  8375. /// @param aengine Engine to use.
  8376. /// @param aprop_kind Propagation kind. Possible values are
  8377. /// #dnnl::prop_kind::forward_training, and
  8378. /// #dnnl::prop_kind::forward_inference.
  8379. /// @param src_desc Memory descriptor for src.
  8380. /// @param weights_desc Memory descriptor for weights.
  8381. /// @param dst_desc Memory descriptor for dst.
  8382. /// @param attr Primitive attributes to use. Attributes are optional
  8383. /// and default to empty attributes.
  8384. /// @param allow_empty A flag signifying whether construction is
  8385. /// allowed to fail without throwing an exception. In this case an
  8386. /// empty object will be produced. This flag is optional and
  8387. /// defaults to false.
  8388. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8389. const memory::desc &src_desc, const memory::desc &weights_desc,
  8390. const memory::desc &dst_desc,
  8391. const primitive_attr &attr = default_attr(),
  8392. bool allow_empty = false)
  8393. : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
  8394. nullptr, dst_desc, attr, allow_empty) {}
  8395. /// Constructs a primitive descriptor for an inner product forward
  8396. /// propagation primitive from a C API primitive descriptor that must
  8397. /// have a matching kind.
  8398. ///
  8399. /// @param pd C API primitive descriptor for an inner product forward
  8400. /// propagation primitive.
  8401. primitive_desc(dnnl_primitive_desc_t pd)
  8402. : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
  8403. dnnl::prop_kind::forward_training,
  8404. dnnl::prop_kind::forward_inference) {}
  8405. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8406. memory::desc src_desc() const { return base::src_desc(0); }
  8407. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8408. memory::desc weights_desc() const { return base::weights_desc(0); }
  8409. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  8410. memory::desc dst_desc() const { return base::dst_desc(0); }
  8411. /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
  8412. memory::desc bias_desc() const { return base::weights_desc(1); }
  8413. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8414. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8415. private:
  8416. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  8417. const memory::desc &src_desc, const memory::desc &weights_desc,
  8418. const memory::desc *bias_desc, const memory::desc &dst_desc,
  8419. const primitive_attr &attr, bool allow_empty) {
  8420. dnnl_primitive_desc_t pd = nullptr;
  8421. dnnl_status_t status
  8422. = dnnl_inner_product_forward_primitive_desc_create(&pd,
  8423. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8424. src_desc.get(), weights_desc.get(),
  8425. optional_arg(bias_desc), dst_desc.get(),
  8426. attr.get());
  8427. if (!allow_empty)
  8428. error::wrap_c_api(status,
  8429. "could not create a primitive descriptor for "
  8430. "the inner product forward propagation primitive. Run "
  8431. "workload with environment variable ONEDNN_VERBOSE=all "
  8432. "to get additional diagnostic information.");
  8433. reset(pd);
  8434. }
  8435. };
  8436. /// Default constructor. Produces an empty object.
  8437. inner_product_forward() = default;
  8438. /// Constructs an inner product forward propagation primitive.
  8439. /// @param pd Primitive descriptor for an inner product forward
  8440. /// propagation primitive.
  8441. inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
  8442. /// Constructs an inner product forward propagation primitive from
  8443. /// a cache blob.
  8444. /// @param pd Primitive descriptor for an inner product forward
  8445. /// propagation primitive.
  8446. /// @param cache_blob Cache blob.
  8447. inner_product_forward(
  8448. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8449. : primitive(pd, cache_blob) {}
  8450. };
  8451. /// Inner product backward propagation primitive.
  8452. struct inner_product_backward_data : public primitive {
  8453. /// Primitive descriptor for an inner product backward propagation
  8454. /// primitive.
  8455. struct primitive_desc : public dnnl::primitive_desc {
  8456. /// Default constructor. Produces an empty object.
  8457. primitive_desc() = default;
  8458. /// Constructs a primitive descriptor for an inner product backward
  8459. /// propagation primitive.
  8460. ///
  8461. /// @note
  8462. /// All the memory descriptors may be initialized with the
  8463. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8464. ///
  8465. /// @param aengine Engine to use.
  8466. /// @param diff_src_desc Memory descriptor for diff src.
  8467. /// @param weights_desc Memory descriptor for weights.
  8468. /// @param diff_dst_desc Memory descriptor for diff dst.
  8469. /// @param hint_fwd_pd Primitive descriptor for an inner product
  8470. /// forward propagation primitive. It is used as a hint for
  8471. /// deciding which memory format to use.
  8472. /// @param attr Primitive attributes to use. Attributes are optional
  8473. /// and default to empty attributes.
  8474. /// @param allow_empty A flag signifying whether construction is
  8475. /// allowed to fail without throwing an exception. In this case an
  8476. /// empty object will be produced. This flag is optional and
  8477. /// defaults to false.
  8478. primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
  8479. const memory::desc &weights_desc,
  8480. const memory::desc &diff_dst_desc,
  8481. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8482. const primitive_attr &attr = default_attr(),
  8483. bool allow_empty = false) {
  8484. dnnl_primitive_desc_t pd = nullptr;
  8485. dnnl_status_t status
  8486. = dnnl_inner_product_backward_data_primitive_desc_create(
  8487. &pd, aengine.get(), diff_src_desc.get(),
  8488. weights_desc.get(), diff_dst_desc.get(),
  8489. hint_fwd_pd.get(), attr.get());
  8490. if (!allow_empty)
  8491. error::wrap_c_api(status,
  8492. "could not create a primitive descriptor for "
  8493. "the inner product backward propagation primitive. Run "
  8494. "workload with environment variable ONEDNN_VERBOSE=all "
  8495. "to get additional diagnostic information.");
  8496. reset(pd);
  8497. }
  8498. /// Constructs a primitive descriptor for an inner product backward
  8499. /// propagation primitive from a C API primitive descriptor that must
  8500. /// have a matching kind.
  8501. ///
  8502. /// @param pd C API primitive descriptor for an inner product backward
  8503. /// propagation primitive.
  8504. primitive_desc(dnnl_primitive_desc_t pd)
  8505. : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
  8506. dnnl::prop_kind::backward_data) {}
  8507. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  8508. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  8509. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  8510. memory::desc weights_desc() const { return base::weights_desc(0); }
  8511. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  8512. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  8513. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8514. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8515. };
  8516. /// Default constructor. Produces an empty object.
  8517. inner_product_backward_data() = default;
  8518. /// Constructs an inner product backward propagation primitive.
  8519. /// @param pd Primitive descriptor for an inner product backward
  8520. /// propagation primitive.
  8521. inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
  8522. /// Constructs an inner product backward propagation primitive from
  8523. /// a cache blob.
  8524. /// @param pd Primitive descriptor for an inner product backward
  8525. /// propagation primitive.
  8526. /// @param cache_blob Cache blob.
  8527. inner_product_backward_data(
  8528. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8529. : primitive(pd, cache_blob) {}
  8530. };
  8531. /// Inner product weights gradient primitive.
  8532. struct inner_product_backward_weights : public primitive {
  8533. /// Primitive descriptor for an inner product weights gradient primitive.
  8534. struct primitive_desc : public dnnl::primitive_desc {
  8535. /// Default constructor. Produces an empty object.
  8536. primitive_desc() = default;
  8537. /// Constructs a primitive descriptor for an inner product weights
  8538. /// update primitive with bias.
  8539. ///
  8540. /// @note
  8541. /// All the memory descriptors may be initialized with the
  8542. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8543. ///
  8544. /// @param aengine Engine to use.
  8545. /// @param src_desc Memory descriptor for src.
  8546. /// @param diff_weights_desc Memory descriptor for diff weights.
  8547. /// @param diff_bias_desc Memory descriptor for diff bias.
  8548. /// @param diff_dst_desc Memory descriptor for diff dst.
  8549. /// @param hint_fwd_pd Primitive descriptor for an inner product
  8550. /// forward propagation primitive. It is used as a hint for
  8551. /// deciding which memory format to use.
  8552. /// @param attr Primitive attributes to use. Attributes are optional
  8553. /// and default to empty attributes.
  8554. /// @param allow_empty A flag signifying whether construction is
  8555. /// allowed to fail without throwing an exception. In this case an
  8556. /// empty object will be produced. This flag is optional and
  8557. /// defaults to false.
  8558. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  8559. const memory::desc &diff_weights_desc,
  8560. const memory::desc &diff_bias_desc,
  8561. const memory::desc &diff_dst_desc,
  8562. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8563. const primitive_attr &attr = default_attr(),
  8564. bool allow_empty = false)
  8565. : primitive_desc(aengine, src_desc, diff_weights_desc,
  8566. &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr,
  8567. allow_empty) {}
  8568. /// Constructs a primitive descriptor for an inner product weights
  8569. /// update primitive.
  8570. ///
  8571. /// @note
  8572. /// All the memory descriptors may be initialized with the
  8573. /// #dnnl::memory::format_tag::any value of @p format_tag.
  8574. ///
  8575. /// @param aengine Engine to use.
  8576. /// @param src_desc Memory descriptor for src.
  8577. /// @param diff_weights_desc Memory descriptor for diff weights.
  8578. /// @param diff_dst_desc Memory descriptor for diff dst.
  8579. /// @param attr Primitive attributes to use. Attributes are optional
  8580. /// and default to empty attributes.
  8581. /// @param hint_fwd_pd Primitive descriptor for an inner product
  8582. /// forward propagation primitive. It is used as a hint for
  8583. /// deciding which memory format to use.
  8584. /// @param allow_empty A flag signifying whether construction is
  8585. /// allowed to fail without throwing an exception. In this case an
  8586. /// empty object will be produced. This flag is optional and
  8587. /// defaults to false.
  8588. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  8589. const memory::desc &diff_weights_desc,
  8590. const memory::desc &diff_dst_desc,
  8591. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8592. const primitive_attr &attr = default_attr(),
  8593. bool allow_empty = false)
  8594. : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr,
  8595. diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
  8596. /// Constructs a primitive descriptor for an inner product weights
  8597. /// update primitive from a C API primitive descriptor that must
  8598. /// have a matching kind.
  8599. ///
  8600. /// @param pd C API primitive descriptor for an inner product weights
  8601. /// gradient primitive.
  8602. primitive_desc(dnnl_primitive_desc_t pd)
  8603. : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
  8604. dnnl::prop_kind::backward_weights) {}
  8605. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  8606. memory::desc src_desc() const { return base::src_desc(0); }
  8607. /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
  8608. memory::desc diff_weights_desc() const {
  8609. return base::diff_weights_desc(0);
  8610. }
  8611. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  8612. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  8613. /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
  8614. memory::desc diff_bias_desc() const {
  8615. return base::diff_weights_desc(1);
  8616. }
  8617. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  8618. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  8619. private:
  8620. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  8621. const memory::desc &diff_weights_desc,
  8622. const memory::desc *diff_bias_desc,
  8623. const memory::desc &diff_dst_desc,
  8624. const inner_product_forward::primitive_desc &hint_fwd_pd,
  8625. const primitive_attr &attr, bool allow_empty) {
  8626. dnnl_primitive_desc_t pd = nullptr;
  8627. dnnl_status_t status
  8628. = dnnl_inner_product_backward_weights_primitive_desc_create(
  8629. &pd, aengine.get(), src_desc.get(),
  8630. diff_weights_desc.get(),
  8631. optional_arg(diff_bias_desc), diff_dst_desc.get(),
  8632. hint_fwd_pd.get(), attr.get());
  8633. if (!allow_empty)
  8634. error::wrap_c_api(status,
  8635. "could not create a primitive descriptor for "
  8636. "the inner product weights gradient primitive. Run "
  8637. "workload with environment variable ONEDNN_VERBOSE=all "
  8638. "to get additional diagnostic information.");
  8639. reset(pd);
  8640. }
  8641. };
  8642. /// Default constructor. Produces an empty object.
  8643. inner_product_backward_weights() = default;
  8644. /// Constructs an inner product weights gradient primitive.
  8645. /// @param pd Primitive descriptor for an inner product weights gradient
  8646. /// primitive.
  8647. inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
  8648. /// Constructs an inner product weights gradient primitive from a cache
  8649. /// blob.
  8650. /// @param pd Primitive descriptor for an inner product weights gradient
  8651. /// primitive.
  8652. /// @param cache_blob Cache blob.
  8653. inner_product_backward_weights(
  8654. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  8655. : primitive(pd, cache_blob) {}
  8656. };
  8657. /// @} dnnl_api_inner_product
  8658. /// @addtogroup dnnl_api_rnn RNN
  8659. ///
  8660. /// A primitive to compute recurrent neural network layers.
  8661. ///
  8662. /// @sa @ref dev_guide_rnn in developer guide
  8663. ///
  8664. /// @{
  8665. /// Base class for primitive descriptors for RNN primitives.
  8666. struct rnn_primitive_desc_base : public primitive_desc {
  8667. using primitive_desc::primitive_desc;
  8668. /// Default constructor. Produces an empty object.
  8669. rnn_primitive_desc_base() = default;
  8670. /// Constructs an RNN primitive descriptor base from a C API primitive
  8671. /// descriptor while checking that it actually describes the expected
  8672. /// primitive by comparing propagation and primitive kinds.
  8673. ///
  8674. /// @param pd C API primitive descriptor.
  8675. /// @param aprop_kind Expected propagation kind.
  8676. /// @param cell_kind Expected cell kind.
  8677. rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
  8678. dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
  8679. : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
  8680. /// Returns source layer memory descriptor.
  8681. /// @returns Source layer memory descriptor.
  8682. memory::desc src_layer_desc() const {
  8683. return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
  8684. }
  8685. /// Returns AUGRU attention memory descriptor.
  8686. /// @returns AUGRU attention memory descriptor.
  8687. memory::desc augru_attention_desc() const {
  8688. return base::query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION);
  8689. }
  8690. /// Returns source iteration memory descriptor.
  8691. /// @returns Source iteration memory descriptor.
  8692. /// @returns A zero memory descriptor if the primitive does not have a
  8693. /// source iteration parameter.
  8694. memory::desc src_iter_desc() const {
  8695. return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
  8696. }
  8697. /// Returns source recurrent cell state memory descriptor.
  8698. /// @returns Source recurrent cell state memory descriptor.
  8699. memory::desc src_iter_c_desc() const {
  8700. return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
  8701. }
  8702. /// Returns weights layer memory descriptor.
  8703. /// @returns Weights layer memory descriptor.
  8704. memory::desc weights_layer_desc() const {
  8705. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
  8706. }
  8707. /// Returns weights iteration memory descriptor.
  8708. /// @returns Weights iteration memory descriptor.
  8709. memory::desc weights_iter_desc() const {
  8710. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
  8711. }
  8712. /// Returns weights peephole memory descriptor.
  8713. /// @returns Weights peephole memory descriptor.
  8714. memory::desc weights_peephole_desc() const {
  8715. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
  8716. }
  8717. /// Returns weights projection memory descriptor.
  8718. /// @returns Weights projection memory descriptor.
  8719. memory::desc weights_projection_desc() const {
  8720. return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
  8721. }
  8722. /// Returns bias memory descriptor.
  8723. /// @returns Bias memory descriptor.
  8724. /// @returns A zero memory descriptor if the primitive does not have a
  8725. /// bias parameter.
  8726. memory::desc bias_desc() const {
  8727. return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
  8728. }
  8729. /// Returns destination layer memory descriptor.
  8730. /// @returns Destination layer memory descriptor.
  8731. memory::desc dst_layer_desc() const {
  8732. return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
  8733. }
  8734. /// Returns destination iteration memory descriptor.
  8735. /// @returns Destination iteration memory descriptor.
  8736. /// @returns A zero memory descriptor if the primitive does not have a
  8737. /// destination iteration parameter.
  8738. memory::desc dst_iter_desc() const {
  8739. return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
  8740. }
  8741. /// Returns destination recurrent cell state memory descriptor.
  8742. /// @returns Destination recurrent cell state memory descriptor.
  8743. memory::desc dst_iter_c_desc() const {
  8744. return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
  8745. }
  8746. /// Returns diff source layer memory descriptor.
  8747. /// @returns Diff source layer memory descriptor.
  8748. memory::desc diff_src_layer_desc() const {
  8749. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
  8750. }
  8751. /// Returns diff AUGRU attention memory descriptor.
  8752. /// @returns Diff AUGRU attention memory descriptor.
  8753. memory::desc diff_augru_attention_desc() const {
  8754. return base::query_md(
  8755. query::exec_arg_md, DNNL_ARG_DIFF_AUGRU_ATTENTION);
  8756. }
  8757. /// Returns diff source iteration memory descriptor.
  8758. /// @returns Diff source iteration memory descriptor.
  8759. /// @returns A zero memory descriptor if the primitive does not have a
  8760. /// diff source iteration parameter.
  8761. memory::desc diff_src_iter_desc() const {
  8762. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
  8763. }
  8764. /// Returns diff source recurrent cell state memory descriptor.
  8765. /// @returns Diff source recurrent cell state memory descriptor.
  8766. memory::desc diff_src_iter_c_desc() const {
  8767. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
  8768. }
  8769. /// Returns diff weights layer memory descriptor.
  8770. /// @returns Diff weights layer memory descriptor.
  8771. memory::desc diff_weights_layer_desc() const {
  8772. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
  8773. }
  8774. /// Returns diff weights iteration memory descriptor.
  8775. /// @returns Diff weights iteration memory descriptor.
  8776. memory::desc diff_weights_iter_desc() const {
  8777. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
  8778. }
  8779. /// Returns diff weights peephole memory descriptor.
  8780. /// @returns Diff weights peephole memory descriptor.
  8781. memory::desc diff_weights_peephole_desc() const {
  8782. return base::query_md(
  8783. query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
  8784. }
  8785. /// Returns diff weights projection memory descriptor.
  8786. /// @returns Diff weights projection memory descriptor.
  8787. memory::desc diff_weights_projection_desc() const {
  8788. return base::query_md(
  8789. query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
  8790. }
  8791. /// Returns diff bias memory descriptor.
  8792. /// @returns Diff bias memory descriptor.
  8793. /// @returns A zero memory descriptor if the primitive does not have a
  8794. /// diff bias parameter.
  8795. memory::desc diff_bias_desc() const {
  8796. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
  8797. }
  8798. /// Returns diff destination layer memory descriptor.
  8799. /// @returns Diff destination layer memory descriptor.
  8800. memory::desc diff_dst_layer_desc() const {
  8801. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
  8802. }
  8803. /// Returns diff destination iteration memory descriptor.
  8804. /// @returns Diff destination iteration memory descriptor.
  8805. /// @returns A zero memory descriptor if the primitive does not have a
  8806. /// diff destination iteration parameter.
  8807. memory::desc diff_dst_iter_desc() const {
  8808. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
  8809. }
  8810. /// Returns diff destination recurrent cell state memory descriptor.
  8811. /// @returns Diff destination recurrent cell state memory descriptor.
  8812. memory::desc diff_dst_iter_c_desc() const {
  8813. return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
  8814. }
  8815. protected:
  8816. using rnn_base = rnn_primitive_desc_base;
  8817. // (Deliberately not using doxygen comments)
  8818. //
  8819. // Constructs an RNN primitive descriptor base from a C API primitive
  8820. // descriptor while checking that it actually describes the expected
  8821. // primitive by comparing propagation and primitive kinds. Caller can
  8822. // pass two options propagation kinds. This is typically used to check
  8823. // that propagation kind is inference or training forward propagation.
  8824. //
  8825. // @param pd C API primitive descriptor.
  8826. // @param prop_kind1 Expected propagation kind.
  8827. // @param prop_kind2 Expected propagation kind.
  8828. // @param cell_kind Expected cell kind.
  8829. rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
  8830. dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
  8831. dnnl::algorithm cell_kind) {
  8832. dnnl_status_t rc;
  8833. dnnl_primitive_kind_t q_primitive_kind;
  8834. rc = dnnl_primitive_desc_query(
  8835. pd, dnnl_query_primitive_kind, 0, &q_primitive_kind);
  8836. error::wrap_c_api(rc,
  8837. "could not retrieve a primitive kind from a primitive "
  8838. "descriptor for an RNN primitive");
  8839. dnnl_prop_kind_t q_prop_kind;
  8840. rc = dnnl_primitive_desc_query(
  8841. pd, dnnl_query_prop_kind, 0, &q_prop_kind);
  8842. error::wrap_c_api(rc,
  8843. "could not retrieve a propagation kind from a primitive "
  8844. "descriptor for an RNN primitive");
  8845. dnnl_alg_kind_t q_cell_kind;
  8846. rc = dnnl_primitive_desc_query(
  8847. pd, dnnl_query_cell_kind, 0, &q_cell_kind);
  8848. error::wrap_c_api(rc,
  8849. "could not retrieve a cell kind from a primitive descriptor "
  8850. "for an RNN primitive");
  8851. dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
  8852. dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
  8853. dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
  8854. bool ok = q_primitive_kind == dnnl_rnn
  8855. && (q_prop_kind == c_prop_kind1 || q_prop_kind == c_prop_kind2)
  8856. && q_cell_kind == c_cell_kind;
  8857. if (!ok)
  8858. DNNL_THROW_ERROR(dnnl_invalid_arguments,
  8859. "mismatch between expected and provided descriptors for an "
  8860. "RNN primitive");
  8861. reset_with_clone(pd);
  8862. }
  8863. // Constructs an RNN forward propagation primitive descriptor base for
  8864. // any cell kind.
  8865. rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
  8866. prop_kind aprop_kind, algorithm activation, rnn_direction direction,
  8867. const memory::desc &src_layer_desc,
  8868. const memory::desc &src_iter_desc,
  8869. const memory::desc *src_iter_c_desc,
  8870. const memory::desc *attention_desc,
  8871. const memory::desc &weights_layer_desc,
  8872. const memory::desc &weights_iter_desc,
  8873. const memory::desc *weights_peephole_desc,
  8874. const memory::desc *weights_projection_desc,
  8875. const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
  8876. const memory::desc &dst_iter_desc,
  8877. const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha,
  8878. float beta, const primitive_attr &attr, bool allow_empty) {
  8879. dnnl_status_t status = dnnl_success;
  8880. const char *msg
  8881. = "could not create a primitive descriptor for a requested "
  8882. "cell kind";
  8883. dnnl_primitive_desc_t pd = nullptr;
  8884. switch (cell_kind) {
  8885. case algorithm::vanilla_rnn:
  8886. status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd,
  8887. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8888. dnnl::convert_to_c(activation),
  8889. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8890. src_iter_desc.get(), weights_layer_desc.get(),
  8891. weights_iter_desc.get(), bias_desc.get(),
  8892. dst_layer_desc.get(), dst_iter_desc.get(),
  8893. convert_to_c(flags), alpha, beta, attr.get());
  8894. msg = "could not create a primitive descriptor for "
  8895. "the vanilla RNN forward propagation primitive. Run "
  8896. "workload with environment variable ONEDNN_VERBOSE=all "
  8897. "to get additional diagnostic information.";
  8898. break;
  8899. case algorithm::vanilla_lstm:
  8900. status = dnnl_lstm_forward_primitive_desc_create(&pd,
  8901. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8902. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8903. src_iter_desc.get(), optional_arg(src_iter_c_desc),
  8904. weights_layer_desc.get(), weights_iter_desc.get(),
  8905. optional_arg(weights_peephole_desc),
  8906. optional_arg(weights_projection_desc), bias_desc.get(),
  8907. dst_layer_desc.get(), dst_iter_desc.get(),
  8908. optional_arg(dst_iter_c_desc), convert_to_c(flags),
  8909. attr.get());
  8910. msg = "could not create a primitive descriptor for "
  8911. "the LSTM forward propagation primitive. Run workload "
  8912. "with environment variable ONEDNN_VERBOSE=all to get "
  8913. "additional diagnostic information.";
  8914. break;
  8915. case algorithm::vanilla_gru:
  8916. status = dnnl_gru_forward_primitive_desc_create(&pd,
  8917. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8918. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8919. src_iter_desc.get(), weights_layer_desc.get(),
  8920. weights_iter_desc.get(), bias_desc.get(),
  8921. dst_layer_desc.get(), dst_iter_desc.get(),
  8922. convert_to_c(flags), attr.get());
  8923. msg = "could not create a primitive descriptor for "
  8924. "the GRU forward propagation primitive. Run workload "
  8925. "with environment variable ONEDNN_VERBOSE=all to get "
  8926. "additional diagnostic information.";
  8927. break;
  8928. case algorithm::lbr_gru:
  8929. status = dnnl_lbr_gru_forward_primitive_desc_create(&pd,
  8930. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8931. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8932. src_iter_desc.get(), weights_layer_desc.get(),
  8933. weights_iter_desc.get(), bias_desc.get(),
  8934. dst_layer_desc.get(), dst_iter_desc.get(),
  8935. convert_to_c(flags), attr.get());
  8936. msg = "could not create a primitive descriptor for "
  8937. "the LBR GRU forward propagation primitive. Run workload "
  8938. "with environment variable ONEDNN_VERBOSE=all to get "
  8939. "additional diagnostic information.";
  8940. break;
  8941. case algorithm::vanilla_augru:
  8942. status = dnnl_augru_forward_primitive_desc_create(&pd,
  8943. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8944. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8945. src_iter_desc.get(), optional_arg(attention_desc),
  8946. weights_layer_desc.get(), weights_iter_desc.get(),
  8947. bias_desc.get(), dst_layer_desc.get(),
  8948. dst_iter_desc.get(), convert_to_c(flags), attr.get());
  8949. msg = "could not create a primitive descriptor for "
  8950. "the AUGRU forward propagation primitive. Run workload "
  8951. "with environment variable ONEDNN_VERBOSE=all to get "
  8952. "additional diagnostic information.";
  8953. break;
  8954. case algorithm::lbr_augru:
  8955. status = dnnl_lbr_augru_forward_primitive_desc_create(&pd,
  8956. aengine.get(), dnnl::convert_to_c(aprop_kind),
  8957. dnnl::convert_to_c(direction), src_layer_desc.get(),
  8958. src_iter_desc.get(), optional_arg(attention_desc),
  8959. weights_layer_desc.get(), weights_iter_desc.get(),
  8960. bias_desc.get(), dst_layer_desc.get(),
  8961. dst_iter_desc.get(), convert_to_c(flags), attr.get());
  8962. msg = "could not create a primitive descriptor for "
  8963. "the LBR AUGRU forward propagation primitive. Run "
  8964. "workload with environment variable ONEDNN_VERBOSE=all "
  8965. "to get additional diagnostic information.";
  8966. break;
  8967. default: status = dnnl_unimplemented;
  8968. }
  8969. if (!allow_empty) error::wrap_c_api(status, msg);
  8970. reset(pd);
  8971. }
  8972. // Constructs an RNN backward propagation primitive descriptor base for
  8973. // any cell kind.
  8974. rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
  8975. prop_kind aprop_kind, algorithm activation, rnn_direction direction,
  8976. const memory::desc &src_layer_desc,
  8977. const memory::desc &src_iter_desc,
  8978. const memory::desc *src_iter_c_desc,
  8979. const memory::desc *attention_desc,
  8980. const memory::desc &weights_layer_desc,
  8981. const memory::desc &weights_iter_desc,
  8982. const memory::desc *weights_peephole_desc,
  8983. const memory::desc *weights_projection_desc,
  8984. const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
  8985. const memory::desc &dst_iter_desc,
  8986. const memory::desc *dst_iter_c_desc,
  8987. const memory::desc &diff_src_layer_desc,
  8988. const memory::desc &diff_src_iter_desc,
  8989. const memory::desc *diff_src_iter_c_desc,
  8990. const memory::desc *diff_attention_desc,
  8991. const memory::desc &diff_weights_layer_desc,
  8992. const memory::desc &diff_weights_iter_desc,
  8993. const memory::desc *diff_weights_peephole_desc,
  8994. const memory::desc *diff_weights_projection_desc,
  8995. const memory::desc &diff_bias_desc,
  8996. const memory::desc &diff_dst_layer_desc,
  8997. const memory::desc &diff_dst_iter_desc,
  8998. const memory::desc *diff_dst_iter_c_desc, rnn_flags flags,
  8999. float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd,
  9000. const primitive_attr &attr, bool allow_empty) {
  9001. dnnl_status_t status = dnnl_success;
  9002. const char *msg = "";
  9003. dnnl_primitive_desc_t pd = nullptr;
  9004. switch (cell_kind) {
  9005. case algorithm::vanilla_rnn:
  9006. status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd,
  9007. aengine.get(), dnnl::convert_to_c(aprop_kind),
  9008. dnnl::convert_to_c(activation),
  9009. dnnl::convert_to_c(direction), src_layer_desc.get(),
  9010. src_iter_desc.get(), weights_layer_desc.get(),
  9011. weights_iter_desc.get(), bias_desc.get(),
  9012. dst_layer_desc.get(), dst_iter_desc.get(),
  9013. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  9014. diff_weights_layer_desc.get(),
  9015. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  9016. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  9017. convert_to_c(flags), alpha, beta, hint_fwd_pd.get(),
  9018. attr.get());
  9019. msg = "could not create a primitive descriptor for "
  9020. "the vanilla RNN backward propagation primitive. Run "
  9021. "workload with environment variable ONEDNN_VERBOSE=all "
  9022. "to get additional diagnostic information.";
  9023. break;
  9024. case algorithm::vanilla_lstm:
  9025. status = dnnl_lstm_backward_primitive_desc_create(&pd,
  9026. aengine.get(), dnnl::convert_to_c(aprop_kind),
  9027. dnnl::convert_to_c(direction), src_layer_desc.get(),
  9028. src_iter_desc.get(), optional_arg(src_iter_c_desc),
  9029. weights_layer_desc.get(), weights_iter_desc.get(),
  9030. optional_arg(weights_peephole_desc),
  9031. optional_arg(weights_projection_desc), bias_desc.get(),
  9032. dst_layer_desc.get(), dst_iter_desc.get(),
  9033. optional_arg(dst_iter_c_desc),
  9034. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  9035. optional_arg(diff_src_iter_c_desc),
  9036. diff_weights_layer_desc.get(),
  9037. diff_weights_iter_desc.get(),
  9038. optional_arg(diff_weights_peephole_desc),
  9039. optional_arg(diff_weights_projection_desc),
  9040. diff_bias_desc.get(), diff_dst_layer_desc.get(),
  9041. diff_dst_iter_desc.get(),
  9042. optional_arg(diff_dst_iter_c_desc), convert_to_c(flags),
  9043. hint_fwd_pd.get(), attr.get());
  9044. msg = "could not create a primitive descriptor for "
  9045. "the LSTM backward propagation primitive. Run workload "
  9046. "with environment variable ONEDNN_VERBOSE=all to get "
  9047. "additional diagnostic information.";
  9048. break;
  9049. case algorithm::vanilla_gru:
  9050. status = dnnl_gru_backward_primitive_desc_create(&pd,
  9051. aengine.get(), dnnl::convert_to_c(aprop_kind),
  9052. dnnl::convert_to_c(direction), src_layer_desc.get(),
  9053. src_iter_desc.get(), weights_layer_desc.get(),
  9054. weights_iter_desc.get(), bias_desc.get(),
  9055. dst_layer_desc.get(), dst_iter_desc.get(),
  9056. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  9057. diff_weights_layer_desc.get(),
  9058. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  9059. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  9060. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  9061. msg = "could not create a primitive descriptor for "
  9062. "the GRU backward propagation primitive. Run workload "
  9063. "with environment variable ONEDNN_VERBOSE=all to get "
  9064. "additional diagnostic information.";
  9065. break;
  9066. case algorithm::lbr_gru:
  9067. status = dnnl_lbr_gru_backward_primitive_desc_create(&pd,
  9068. aengine.get(), dnnl::convert_to_c(aprop_kind),
  9069. dnnl::convert_to_c(direction), src_layer_desc.get(),
  9070. src_iter_desc.get(), weights_layer_desc.get(),
  9071. weights_iter_desc.get(), bias_desc.get(),
  9072. dst_layer_desc.get(), dst_iter_desc.get(),
  9073. diff_src_layer_desc.get(), diff_src_iter_desc.get(),
  9074. diff_weights_layer_desc.get(),
  9075. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  9076. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  9077. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  9078. msg = "could not create a primitive descriptor for "
  9079. "the LBR GRU backward propagation primitive. Run "
  9080. "workload with environment variable ONEDNN_VERBOSE=all "
  9081. "to get additional diagnostic information.";
  9082. break;
  9083. case algorithm::vanilla_augru:
  9084. status = dnnl_augru_backward_primitive_desc_create(&pd,
  9085. aengine.get(), dnnl::convert_to_c(aprop_kind),
  9086. dnnl::convert_to_c(direction), src_layer_desc.get(),
  9087. src_iter_desc.get(), optional_arg(attention_desc),
  9088. weights_layer_desc.get(), weights_iter_desc.get(),
  9089. bias_desc.get(), dst_layer_desc.get(),
  9090. dst_iter_desc.get(), diff_src_layer_desc.get(),
  9091. diff_src_iter_desc.get(),
  9092. optional_arg(diff_attention_desc),
  9093. diff_weights_layer_desc.get(),
  9094. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  9095. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  9096. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  9097. msg = "could not create a primitive descriptor for "
  9098. "the AUGRU backward propagation primitive. Run workload "
  9099. "with environment variable ONEDNN_VERBOSE=all to get "
  9100. "additional diagnostic information.";
  9101. break;
  9102. case algorithm::lbr_augru:
  9103. status = dnnl_lbr_augru_backward_primitive_desc_create(&pd,
  9104. aengine.get(), dnnl::convert_to_c(aprop_kind),
  9105. dnnl::convert_to_c(direction), src_layer_desc.get(),
  9106. src_iter_desc.get(), optional_arg(attention_desc),
  9107. weights_layer_desc.get(), weights_iter_desc.get(),
  9108. bias_desc.get(), dst_layer_desc.get(),
  9109. dst_iter_desc.get(), diff_src_layer_desc.get(),
  9110. diff_src_iter_desc.get(),
  9111. optional_arg(diff_attention_desc),
  9112. diff_weights_layer_desc.get(),
  9113. diff_weights_iter_desc.get(), diff_bias_desc.get(),
  9114. diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
  9115. convert_to_c(flags), hint_fwd_pd.get(), attr.get());
  9116. msg = "could not create a primitive descriptor for "
  9117. "the LBR AUGRU backward propagation primitive. Run "
  9118. "workload with environment variable ONEDNN_VERBOSE=all "
  9119. "to get additional diagnostic information.";
  9120. break;
  9121. default: status = dnnl_unimplemented;
  9122. }
  9123. if (!allow_empty) error::wrap_c_api(status, msg);
  9124. reset(pd);
  9125. }
  9126. };
  9127. /// Vanilla RNN forward propagation primitive.
  9128. struct vanilla_rnn_forward : public primitive {
  9129. /// Primitive descriptor for a vanilla RNN forward propagation primitive.
  9130. struct primitive_desc : public rnn_primitive_desc_base {
  9131. /// Default constructor. Produces an empty object.
  9132. primitive_desc() = default;
  9133. /// Constructs a primitive descriptor for a vanilla RNN forward
  9134. /// propagation primitive.
  9135. ///
  9136. /// The following arguments may point to a zero memory descriptor:
  9137. /// - @p src_iter_desc,
  9138. /// - @p bias_desc,
  9139. /// - @p dst_iter_desc.
  9140. ///
  9141. /// This would then indicate that the RNN forward propagation primitive
  9142. /// should not use them and should default to zero values instead.
  9143. ///
  9144. /// @note
  9145. /// All memory descriptors except @p src_iter_desc can be
  9146. /// initialized with an #dnnl::memory::format_tag::any value of @p
  9147. /// format_tag.
  9148. ///
  9149. /// @param aengine Engine to use.
  9150. /// @param aprop_kind Propagation kind. Possible values are
  9151. /// #dnnl::prop_kind::forward_training, and
  9152. /// #dnnl::prop_kind::forward_inference.
  9153. /// @param activation Activation kind. Possible values are
  9154. /// #dnnl::algorithm::eltwise_relu,
  9155. /// #dnnl::algorithm::eltwise_tanh, or
  9156. /// #dnnl::algorithm::eltwise_logistic.
  9157. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9158. /// more info.
  9159. /// @param src_layer_desc Memory descriptor for the input vector.
  9160. /// @param src_iter_desc Memory descriptor for the input recurrent
  9161. /// hidden state vector.
  9162. /// @param weights_layer_desc Memory descriptor for the weights
  9163. /// applied to the layer input.
  9164. /// @param weights_iter_desc Memory descriptor for the weights applied
  9165. /// to the recurrent input.
  9166. /// @param bias_desc Bias memory descriptor.
  9167. /// @param dst_layer_desc Memory descriptor for the output vector.
  9168. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9169. /// hidden state vector.
  9170. /// @param attr Primitive attributes to use. Attributes are optional
  9171. /// and default to empty attributes.
  9172. /// @param allow_empty A flag signifying whether construction is
  9173. /// allowed to fail without throwing an exception. In this case an
  9174. /// empty object will be produced. This flag is optional and
  9175. /// defaults to false.
  9176. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9177. algorithm activation, rnn_direction direction,
  9178. const memory::desc &src_layer_desc,
  9179. const memory::desc &src_iter_desc,
  9180. const memory::desc &weights_layer_desc,
  9181. const memory::desc &weights_iter_desc,
  9182. const memory::desc &bias_desc,
  9183. const memory::desc &dst_layer_desc,
  9184. const memory::desc &dst_iter_desc,
  9185. const primitive_attr &attr = default_attr(),
  9186. bool allow_empty = false)
  9187. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9188. aprop_kind, activation, direction, src_layer_desc,
  9189. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9190. weights_iter_desc, nullptr, nullptr, bias_desc,
  9191. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  9192. 0.0f, 0.0f, attr, allow_empty) {}
  9193. /// Constructs a primitive descriptor for a vanilla RNN forward
  9194. /// propagation primitive with alpha parameter.
  9195. ///
  9196. /// The following arguments may point to a zero memory descriptor:
  9197. /// - @p src_iter_desc,
  9198. /// - @p bias_desc,
  9199. /// - @p dst_iter_desc.
  9200. ///
  9201. /// This would then indicate that the RNN forward propagation primitive
  9202. /// should not use them and should default to zero values instead.
  9203. ///
  9204. /// @note
  9205. /// All memory descriptors except @p src_iter_desc can be
  9206. /// initialized with an #dnnl::memory::format_tag::any value of @p
  9207. /// format_tag.
  9208. ///
  9209. /// @param aengine Engine to use.
  9210. /// @param aprop_kind Propagation kind. Possible values are
  9211. /// #dnnl::prop_kind::forward_training, and
  9212. /// #dnnl::prop_kind::forward_inference.
  9213. /// @param activation Activation kind. Possible values are
  9214. /// #dnnl::algorithm::eltwise_relu,
  9215. /// #dnnl::algorithm::eltwise_tanh, or
  9216. /// #dnnl::algorithm::eltwise_logistic.
  9217. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9218. /// more info.
  9219. /// @param src_layer_desc Memory descriptor for the input vector.
  9220. /// @param src_iter_desc Memory descriptor for the input recurrent
  9221. /// hidden state vector.
  9222. /// @param weights_layer_desc Memory descriptor for the weights
  9223. /// applied to the layer input.
  9224. /// @param weights_iter_desc Memory descriptor for the weights applied
  9225. /// to the recurrent input.
  9226. /// @param bias_desc Bias memory descriptor.
  9227. /// @param dst_layer_desc Memory descriptor for the output vector.
  9228. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9229. /// hidden state vector.
  9230. /// @param alpha Negative slope if activation is
  9231. /// #dnnl::algorithm::eltwise_relu.
  9232. /// @param attr Primitive attributes to use. Attributes are optional
  9233. /// and default to empty attributes.
  9234. /// @param allow_empty A flag signifying whether construction is
  9235. /// allowed to fail without throwing an exception. In this case an
  9236. /// empty object will be produced. This flag is optional and
  9237. /// defaults to false.
  9238. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9239. algorithm activation, rnn_direction direction,
  9240. const memory::desc &src_layer_desc,
  9241. const memory::desc &src_iter_desc,
  9242. const memory::desc &weights_layer_desc,
  9243. const memory::desc &weights_iter_desc,
  9244. const memory::desc &bias_desc,
  9245. const memory::desc &dst_layer_desc,
  9246. const memory::desc &dst_iter_desc, float alpha,
  9247. const primitive_attr &attr = default_attr(),
  9248. bool allow_empty = false)
  9249. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9250. aprop_kind, activation, direction, src_layer_desc,
  9251. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9252. weights_iter_desc, nullptr, nullptr, bias_desc,
  9253. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  9254. alpha, 0.0f, attr, allow_empty) {}
  9255. /// Constructs a primitive descriptor for a vanilla RNN forward
  9256. /// propagation primitive from a C API primitive descriptor that must
  9257. /// have a matching kind.
  9258. ///
  9259. /// @param pd C API primitive descriptor for a vanilla RNN forward
  9260. /// propagation primitive.
  9261. primitive_desc(dnnl_primitive_desc_t pd)
  9262. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  9263. dnnl::prop_kind::forward_inference,
  9264. dnnl::algorithm::vanilla_rnn) {}
  9265. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  9266. memory::desc src_layer_desc() const {
  9267. return rnn_base::src_layer_desc();
  9268. }
  9269. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9270. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  9271. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  9272. memory::desc weights_layer_desc() const {
  9273. return rnn_base::weights_layer_desc();
  9274. }
  9275. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  9276. memory::desc weights_iter_desc() const {
  9277. return rnn_base::weights_iter_desc();
  9278. }
  9279. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  9280. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  9281. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  9282. memory::desc dst_layer_desc() const {
  9283. return rnn_base::dst_layer_desc();
  9284. }
  9285. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  9286. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  9287. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  9288. memory::desc workspace_desc() const {
  9289. return rnn_base::workspace_desc();
  9290. }
  9291. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  9292. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  9293. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  9294. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  9295. /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
  9296. algorithm get_activation_kind() const {
  9297. return base::get_activation_kind();
  9298. }
  9299. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  9300. rnn_direction get_direction() const { return base::get_direction(); }
  9301. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  9302. float get_alpha() const { return base::get_alpha(); }
  9303. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  9304. float get_beta() const { return base::get_beta(); }
  9305. };
  9306. /// Default constructor. Produces an empty object.
  9307. vanilla_rnn_forward() = default;
  9308. /// Constructs a vanilla RNN forward propagation primitive.
  9309. /// @param pd Primitive descriptor for a vanilla RNN forward
  9310. /// propagation primitive.
  9311. vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
  9312. /// Constructs a vanilla RNN forward propagation primitive from
  9313. /// a cache blob.
  9314. /// @param pd Primitive descriptor for a vanilla RNN forward
  9315. /// propagation primitive.
  9316. /// @param cache_blob Cache blob.
  9317. vanilla_rnn_forward(
  9318. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  9319. : primitive(pd, cache_blob) {}
  9320. };
  9321. /// Vanilla RNN backward propagation primitive.
  9322. struct vanilla_rnn_backward : public primitive {
  9323. /// Primitive descriptor for an RNN backward propagation primitive.
  9324. struct primitive_desc : public rnn_primitive_desc_base {
  9325. /// Default constructor. Produces an empty object.
  9326. primitive_desc() = default;
  9327. /// Constructs a primitive descriptor for a vanilla RNN backward
  9328. /// propagation primitive.
  9329. ///
  9330. /// The following arguments may point to a zero memory descriptor:
  9331. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  9332. /// - @p bias_desc together with @p diff_bias_desc,
  9333. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  9334. ///
  9335. /// This would then indicate that the RNN backward propagation
  9336. /// primitive should not use the respective data and should use zero
  9337. /// values instead.
  9338. ///
  9339. /// @note
  9340. /// All the memory descriptors may be initialized with the
  9341. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9342. ///
  9343. /// @param aengine Engine to use.
  9344. /// @param aprop_kind Propagation kind. Must be
  9345. /// #dnnl::prop_kind::backward.
  9346. /// @param activation Activation kind. Possible values are
  9347. /// #dnnl::algorithm::eltwise_relu,
  9348. /// #dnnl::algorithm::eltwise_tanh, or
  9349. /// #dnnl::algorithm::eltwise_logistic.
  9350. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9351. /// more info.
  9352. /// @param src_layer_desc Memory descriptor for the input vector.
  9353. /// @param src_iter_desc Memory descriptor for the input recurrent
  9354. /// hidden state vector.
  9355. /// @param weights_layer_desc Memory descriptor for the weights
  9356. /// applied to the layer input.
  9357. /// @param weights_iter_desc Memory descriptor for the weights applied
  9358. /// to the recurrent input.
  9359. /// @param bias_desc Bias memory descriptor.
  9360. /// @param dst_layer_desc Memory descriptor for the output vector.
  9361. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9362. /// hidden state vector.
  9363. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9364. /// vector.
  9365. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9366. /// recurrent hidden state vector.
  9367. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9368. /// weights applied to the layer input.
  9369. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9370. /// weights applied to the recurrent input.
  9371. /// @param diff_bias_desc Diff bias memory descriptor.
  9372. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9373. /// output vector.
  9374. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9375. /// recurrent hidden state vector.
  9376. /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
  9377. /// forward propagation primitive. It is used as a hint for
  9378. /// deciding which memory format to use.
  9379. /// @param attr Primitive attributes to use. Attributes are optional
  9380. /// and default to empty attributes.
  9381. /// @param allow_empty A flag signifying whether construction is
  9382. /// allowed to fail without throwing an exception. In this case an
  9383. /// empty object will be produced. This flag is optional and
  9384. /// defaults to false.
  9385. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9386. algorithm activation, rnn_direction direction,
  9387. const memory::desc &src_layer_desc,
  9388. const memory::desc &src_iter_desc,
  9389. const memory::desc &weights_layer_desc,
  9390. const memory::desc &weights_iter_desc,
  9391. const memory::desc &bias_desc,
  9392. const memory::desc &dst_layer_desc,
  9393. const memory::desc &dst_iter_desc,
  9394. const memory::desc &diff_src_layer_desc,
  9395. const memory::desc &diff_src_iter_desc,
  9396. const memory::desc &diff_weights_layer_desc,
  9397. const memory::desc &diff_weights_iter_desc,
  9398. const memory::desc &diff_bias_desc,
  9399. const memory::desc &diff_dst_layer_desc,
  9400. const memory::desc &diff_dst_iter_desc,
  9401. const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
  9402. const primitive_attr &attr = default_attr(),
  9403. bool allow_empty = false)
  9404. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9405. aprop_kind, activation, direction, src_layer_desc,
  9406. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9407. weights_iter_desc, nullptr, nullptr, bias_desc,
  9408. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  9409. diff_src_iter_desc, nullptr, nullptr,
  9410. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  9411. nullptr, diff_bias_desc, diff_dst_layer_desc,
  9412. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  9413. hint_fwd_pd, attr, allow_empty) {}
  9414. /// Constructs a primitive descriptor for a vanilla RNN backward
  9415. /// propagation primitive with an alpha parameter.
  9416. ///
  9417. /// The following arguments may point to a zero memory descriptor:
  9418. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  9419. /// - @p bias_desc together with @p diff_bias_desc,
  9420. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  9421. ///
  9422. /// This would then indicate that the RNN backward propagation
  9423. /// primitive should not use the respective data and should use zero
  9424. /// values instead.
  9425. ///
  9426. /// @note
  9427. /// All the memory descriptors may be initialized with the
  9428. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9429. ///
  9430. /// @param aengine Engine to use.
  9431. /// @param aprop_kind Propagation kind. Must be
  9432. /// #dnnl::prop_kind::backward.
  9433. /// @param activation Activation kind. Possible values are
  9434. /// #dnnl::algorithm::eltwise_relu,
  9435. /// #dnnl::algorithm::eltwise_tanh, or
  9436. /// #dnnl::algorithm::eltwise_logistic.
  9437. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9438. /// more info.
  9439. /// @param src_layer_desc Memory descriptor for the input vector.
  9440. /// @param src_iter_desc Memory descriptor for the input recurrent
  9441. /// hidden state vector.
  9442. /// @param weights_layer_desc Memory descriptor for the weights
  9443. /// applied to the layer input.
  9444. /// @param weights_iter_desc Memory descriptor for the weights applied
  9445. /// to the recurrent input.
  9446. /// @param bias_desc Bias memory descriptor.
  9447. /// @param dst_layer_desc Memory descriptor for the output vector.
  9448. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9449. /// hidden state vector.
  9450. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9451. /// vector.
  9452. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9453. /// recurrent hidden state vector.
  9454. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9455. /// weights applied to the layer input.
  9456. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9457. /// weights applied to the recurrent input.
  9458. /// @param diff_bias_desc Diff bias memory descriptor.
  9459. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9460. /// output vector.
  9461. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9462. /// recurrent hidden state vector.
  9463. /// @param alpha Negative slope if activation is
  9464. /// #dnnl::algorithm::eltwise_relu.
  9465. /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
  9466. /// forward propagation primitive. It is used as a hint for
  9467. /// deciding which memory format to use.
  9468. /// @param attr Primitive attributes to use. Attributes are optional
  9469. /// and default to empty attributes.
  9470. /// @param allow_empty A flag signifying whether construction is
  9471. /// allowed to fail without throwing an exception. In this case an
  9472. /// empty object will be produced. This flag is optional and
  9473. /// defaults to false.
  9474. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9475. algorithm activation, rnn_direction direction,
  9476. const memory::desc &src_layer_desc,
  9477. const memory::desc &src_iter_desc,
  9478. const memory::desc &weights_layer_desc,
  9479. const memory::desc &weights_iter_desc,
  9480. const memory::desc &bias_desc,
  9481. const memory::desc &dst_layer_desc,
  9482. const memory::desc &dst_iter_desc,
  9483. const memory::desc &diff_src_layer_desc,
  9484. const memory::desc &diff_src_iter_desc,
  9485. const memory::desc &diff_weights_layer_desc,
  9486. const memory::desc &diff_weights_iter_desc,
  9487. const memory::desc &diff_bias_desc,
  9488. const memory::desc &diff_dst_layer_desc,
  9489. const memory::desc &diff_dst_iter_desc, float alpha,
  9490. const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
  9491. const primitive_attr &attr = default_attr(),
  9492. bool allow_empty = false)
  9493. : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
  9494. aprop_kind, activation, direction, src_layer_desc,
  9495. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  9496. weights_iter_desc, nullptr, nullptr, bias_desc,
  9497. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  9498. diff_src_iter_desc, nullptr, nullptr,
  9499. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  9500. nullptr, diff_bias_desc, diff_dst_layer_desc,
  9501. diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f,
  9502. hint_fwd_pd, attr, allow_empty) {}
  9503. /// Constructs a primitive descriptor for a vanilla RNN backward
  9504. /// propagation primitive from a C API primitive descriptor that must
  9505. /// have a matching kind.
  9506. ///
  9507. /// @param pd C API primitive descriptor for a vanilla RNN backward
  9508. /// propagation primitive.
  9509. primitive_desc(dnnl_primitive_desc_t pd)
  9510. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  9511. dnnl::algorithm::vanilla_rnn) {}
  9512. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  9513. memory::desc src_layer_desc() const {
  9514. return rnn_base::src_layer_desc();
  9515. }
  9516. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9517. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  9518. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  9519. memory::desc weights_layer_desc() const {
  9520. return rnn_base::weights_layer_desc();
  9521. }
  9522. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  9523. memory::desc weights_iter_desc() const {
  9524. return rnn_base::weights_iter_desc();
  9525. }
  9526. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  9527. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  9528. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  9529. memory::desc dst_layer_desc() const {
  9530. return rnn_base::dst_layer_desc();
  9531. }
  9532. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  9533. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  9534. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  9535. memory::desc workspace_desc() const {
  9536. return rnn_base::workspace_desc();
  9537. }
  9538. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  9539. memory::desc diff_src_layer_desc() const {
  9540. return rnn_base::diff_src_layer_desc();
  9541. }
  9542. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  9543. memory::desc diff_src_iter_desc() const {
  9544. return rnn_base::diff_src_iter_desc();
  9545. }
  9546. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  9547. memory::desc diff_weights_layer_desc() const {
  9548. return rnn_base::diff_weights_layer_desc();
  9549. }
  9550. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  9551. memory::desc diff_weights_iter_desc() const {
  9552. return rnn_base::diff_weights_iter_desc();
  9553. }
  9554. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  9555. memory::desc diff_bias_desc() const {
  9556. return rnn_base::diff_bias_desc();
  9557. }
  9558. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  9559. memory::desc diff_dst_layer_desc() const {
  9560. return rnn_base::diff_dst_layer_desc();
  9561. }
  9562. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  9563. memory::desc diff_dst_iter_desc() const {
  9564. return rnn_base::diff_dst_iter_desc();
  9565. }
  9566. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  9567. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  9568. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  9569. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  9570. /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
  9571. algorithm get_activation_kind() const {
  9572. return base::get_activation_kind();
  9573. }
  9574. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  9575. rnn_direction get_direction() const { return base::get_direction(); }
  9576. /// @copydoc dnnl::primitive_desc_base::get_alpha()const
  9577. float get_alpha() const { return base::get_alpha(); }
  9578. /// @copydoc dnnl::primitive_desc_base::get_beta()const
  9579. float get_beta() const { return base::get_beta(); }
  9580. };
  9581. /// Default constructor. Produces an empty object.
  9582. vanilla_rnn_backward() = default;
  9583. /// Constructs a vanilla RNN backward propagation primitive.
  9584. /// @param pd Primitive descriptor for a vanilla RNN backward
  9585. /// propagation primitive.
  9586. vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
  9587. /// Constructs a vanilla RNN backward propagation primitive from
  9588. /// a cache blob.
  9589. /// @param pd Primitive descriptor for a vanilla RNN backward
  9590. /// propagation primitive.
  9591. /// @param cache_blob Cache blob.
  9592. vanilla_rnn_backward(
  9593. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  9594. : primitive(pd, cache_blob) {}
  9595. };
  9596. /// LSTM forward propagation primitive.
  9597. struct lstm_forward : public primitive {
  9598. /// Primitive descriptor for an LSTM forward propagation primitive.
  9599. struct primitive_desc : public rnn_primitive_desc_base {
  9600. /// Default constructor. Produces an empty object.
  9601. primitive_desc() = default;
  9602. /// Constructs a primitive descriptor for an LSTM (with or without
  9603. /// peephole and with or without projection) forward propagation
  9604. /// primitive.
  9605. ///
  9606. /// The following arguments may point to a zero memory descriptor:
  9607. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9608. /// - @p weights_peephole_desc,
  9609. /// - @p bias_desc,
  9610. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  9611. ///
  9612. /// This would then indicate that the LSTM forward propagation
  9613. /// primitive should not use them and should default to zero values
  9614. /// instead.
  9615. ///
  9616. /// The @p weights_projection_desc may point to a zero memory
  9617. /// descriptor. This would then indicate that the LSTM doesn't have
  9618. /// recurrent projection layer.
  9619. ///
  9620. /// @note
  9621. /// All memory descriptors can be initialized with an
  9622. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9623. ///
  9624. /// @param aengine Engine to use.
  9625. /// @param aprop_kind Propagation kind. Possible values are
  9626. /// #dnnl::prop_kind::forward_training, and
  9627. /// #dnnl::prop_kind::forward_inference.
  9628. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9629. /// more info.
  9630. /// @param src_layer_desc Memory descriptor for the input vector.
  9631. /// @param src_iter_desc Memory descriptor for the input recurrent
  9632. /// hidden state vector.
  9633. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9634. /// cell state vector.
  9635. /// @param weights_layer_desc Memory descriptor for the weights
  9636. /// applied to the layer input.
  9637. /// @param weights_iter_desc Memory descriptor for the weights applied
  9638. /// to the recurrent input.
  9639. /// @param weights_peephole_desc Memory descriptor for the weights
  9640. /// applied to the cell states (according to the Peephole LSTM
  9641. /// formula).
  9642. /// @param weights_projection_desc Memory descriptor for the weights
  9643. /// applied to the hidden states to get the recurrent projection
  9644. /// (according to the Projection LSTM formula).
  9645. /// @param bias_desc Bias memory descriptor.
  9646. /// @param dst_layer_desc Memory descriptor for the output vector.
  9647. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9648. /// hidden state vector.
  9649. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9650. /// cell state vector.
  9651. /// @param attr Primitive attributes to use. Attributes are optional
  9652. /// and default to empty attributes.
  9653. /// @param allow_empty A flag signifying whether construction is
  9654. /// allowed to fail without throwing an exception. In this case an
  9655. /// empty object will be produced. This flag is optional and
  9656. /// defaults to false.
  9657. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9658. rnn_direction direction, const memory::desc &src_layer_desc,
  9659. const memory::desc &src_iter_desc,
  9660. const memory::desc &src_iter_c_desc,
  9661. const memory::desc &weights_layer_desc,
  9662. const memory::desc &weights_iter_desc,
  9663. const memory::desc &weights_peephole_desc,
  9664. const memory::desc &weights_projection_desc,
  9665. const memory::desc &bias_desc,
  9666. const memory::desc &dst_layer_desc,
  9667. const memory::desc &dst_iter_desc,
  9668. const memory::desc &dst_iter_c_desc,
  9669. const primitive_attr &attr = default_attr(),
  9670. bool allow_empty = false)
  9671. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9672. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9673. src_iter_desc, &src_iter_c_desc, nullptr,
  9674. weights_layer_desc, weights_iter_desc,
  9675. &weights_peephole_desc, &weights_projection_desc, bias_desc,
  9676. dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  9677. rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
  9678. /// Constructs a primitive descriptor for an LSTM (with or without
  9679. /// peephole) forward propagation primitive.
  9680. ///
  9681. /// The following arguments may point to a zero memory descriptor:
  9682. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9683. /// - @p weights_peephole_desc,
  9684. /// - @p bias_desc,
  9685. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  9686. ///
  9687. /// This would then indicate that the LSTM forward propagation
  9688. /// primitive should not use them and should default to zero values
  9689. /// instead.
  9690. ///
  9691. /// @note
  9692. /// All memory descriptors can be initialized with an
  9693. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9694. ///
  9695. /// @param aengine Engine to use.
  9696. /// @param aprop_kind Propagation kind. Possible values are
  9697. /// #dnnl::prop_kind::forward_training, and
  9698. /// #dnnl::prop_kind::forward_inference.
  9699. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9700. /// more info.
  9701. /// @param src_layer_desc Memory descriptor for the input vector.
  9702. /// @param src_iter_desc Memory descriptor for the input recurrent
  9703. /// hidden state vector.
  9704. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9705. /// cell state vector.
  9706. /// @param weights_layer_desc Memory descriptor for the weights
  9707. /// applied to the layer input.
  9708. /// @param weights_iter_desc Memory descriptor for the weights applied
  9709. /// to the recurrent input.
  9710. /// @param weights_peephole_desc Memory descriptor for the weights
  9711. /// applied to the cell states (according to the Peephole LSTM
  9712. /// formula).
  9713. /// @param bias_desc Bias memory descriptor.
  9714. /// @param dst_layer_desc Memory descriptor for the output vector.
  9715. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9716. /// hidden state vector.
  9717. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9718. /// cell state vector.
  9719. /// @param attr Primitive attributes to use. Attributes are optional
  9720. /// and default to empty attributes.
  9721. /// @param allow_empty A flag signifying whether construction is
  9722. /// allowed to fail without throwing an exception. In this case an
  9723. /// empty object will be produced. This flag is optional and
  9724. /// defaults to false.
  9725. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9726. rnn_direction direction, const memory::desc &src_layer_desc,
  9727. const memory::desc &src_iter_desc,
  9728. const memory::desc &src_iter_c_desc,
  9729. const memory::desc &weights_layer_desc,
  9730. const memory::desc &weights_iter_desc,
  9731. const memory::desc &weights_peephole_desc,
  9732. const memory::desc &bias_desc,
  9733. const memory::desc &dst_layer_desc,
  9734. const memory::desc &dst_iter_desc,
  9735. const memory::desc &dst_iter_c_desc,
  9736. const primitive_attr &attr = default_attr(),
  9737. bool allow_empty = false)
  9738. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9739. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9740. src_iter_desc, &src_iter_c_desc, nullptr,
  9741. weights_layer_desc, weights_iter_desc,
  9742. &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
  9743. dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f,
  9744. 0.0f, attr, allow_empty) {}
  9745. /// Constructs a primitive descriptor for an LSTM forward propagation
  9746. /// primitive.
  9747. ///
  9748. /// The following arguments may point to a zero memory descriptor:
  9749. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9750. /// - @p bias_desc,
  9751. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  9752. ///
  9753. /// This would then indicate that the LSTM forward propagation
  9754. /// primitive should not use them and should default to zero values
  9755. /// instead.
  9756. ///
  9757. /// @note
  9758. /// All memory descriptors can be initialized with an
  9759. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9760. ///
  9761. /// @param aengine Engine to use.
  9762. /// @param aprop_kind Propagation kind. Possible values are
  9763. /// #dnnl::prop_kind::forward_training, and
  9764. /// #dnnl::prop_kind::forward_inference.
  9765. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9766. /// more info.
  9767. /// @param src_layer_desc Memory descriptor for the input vector.
  9768. /// @param src_iter_desc Memory descriptor for the input recurrent
  9769. /// hidden state vector.
  9770. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9771. /// cell state vector.
  9772. /// @param weights_layer_desc Memory descriptor for the weights
  9773. /// applied to the layer input.
  9774. /// @param weights_iter_desc Memory descriptor for the weights applied
  9775. /// to the recurrent input.
  9776. /// @param bias_desc Bias memory descriptor.
  9777. /// @param dst_layer_desc Memory descriptor for the output vector.
  9778. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9779. /// hidden state vector.
  9780. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9781. /// cell state vector.
  9782. /// @param attr Primitive attributes to use. Attributes are optional
  9783. /// and default to empty attributes.
  9784. /// @param allow_empty A flag signifying whether construction is
  9785. /// allowed to fail without throwing an exception. In this case an
  9786. /// empty object will be produced. This flag is optional and
  9787. /// defaults to false.
  9788. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9789. rnn_direction direction, const memory::desc &src_layer_desc,
  9790. const memory::desc &src_iter_desc,
  9791. const memory::desc &src_iter_c_desc,
  9792. const memory::desc &weights_layer_desc,
  9793. const memory::desc &weights_iter_desc,
  9794. const memory::desc &bias_desc,
  9795. const memory::desc &dst_layer_desc,
  9796. const memory::desc &dst_iter_desc,
  9797. const memory::desc &dst_iter_c_desc,
  9798. const primitive_attr &attr = default_attr(),
  9799. bool allow_empty = false)
  9800. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9801. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9802. src_iter_desc, &src_iter_c_desc, nullptr,
  9803. weights_layer_desc, weights_iter_desc, nullptr, nullptr,
  9804. bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  9805. rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
  9806. /// Constructs a primitive descriptor for an LSTM forward propagation
  9807. /// primitive from a C API primitive descriptor that must have a
  9808. /// matching kind.
  9809. ///
  9810. /// @param pd C API primitive descriptor for an LSTM forward
  9811. /// propagation primitive.
  9812. primitive_desc(dnnl_primitive_desc_t pd)
  9813. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  9814. dnnl::prop_kind::forward_inference,
  9815. dnnl::algorithm::vanilla_lstm) {}
  9816. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  9817. memory::desc src_layer_desc() const {
  9818. return rnn_base::src_layer_desc();
  9819. }
  9820. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9821. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  9822. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9823. memory::desc src_iter_c_desc() const {
  9824. return rnn_base::src_iter_c_desc();
  9825. }
  9826. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  9827. memory::desc weights_layer_desc() const {
  9828. return rnn_base::weights_layer_desc();
  9829. }
  9830. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  9831. memory::desc weights_iter_desc() const {
  9832. return rnn_base::weights_iter_desc();
  9833. }
  9834. /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
  9835. memory::desc weights_peephole_desc() const {
  9836. return rnn_base::weights_peephole_desc();
  9837. }
  9838. /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
  9839. memory::desc weights_projection_desc() const {
  9840. return rnn_base::weights_projection_desc();
  9841. }
  9842. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  9843. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  9844. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  9845. memory::desc dst_layer_desc() const {
  9846. return rnn_base::dst_layer_desc();
  9847. }
  9848. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  9849. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  9850. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  9851. memory::desc dst_iter_c_desc() const {
  9852. return rnn_base::dst_iter_c_desc();
  9853. }
  9854. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  9855. memory::desc workspace_desc() const {
  9856. return rnn_base::workspace_desc();
  9857. }
  9858. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  9859. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  9860. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  9861. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  9862. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  9863. rnn_direction get_direction() const { return base::get_direction(); }
  9864. };
  9865. /// Default constructor. Produces an empty object.
  9866. lstm_forward() = default;
  9867. /// Constructs an LSTM forward propagation primitive.
  9868. /// @param pd Primitive descriptor for an LSTM forward propagation
  9869. /// primitive.
  9870. lstm_forward(const primitive_desc &pd) : primitive(pd) {}
  9871. /// Constructs an LSTM forward propagation primitive from a cache blob.
  9872. /// @param pd Primitive descriptor for an LSTM forward propagation
  9873. /// primitive.
  9874. /// @param cache_blob Cache blob.
  9875. lstm_forward(
  9876. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  9877. : primitive(pd, cache_blob) {}
  9878. };
  9879. /// LSTM backward propagation primitive.
  9880. struct lstm_backward : public primitive {
  9881. /// Primitive descriptor for an LSTM backward propagation primitive.
  9882. struct primitive_desc : public rnn_primitive_desc_base {
  9883. /// Default constructor. Produces an empty object.
  9884. primitive_desc() = default;
  9885. /// Constructs an LSTM (with or without peephole and with or without
  9886. /// projection) primitive descriptor for backward propagation
  9887. /// using @p prop_kind, @p direction, and memory descriptors.
  9888. ///
  9889. /// The following arguments may point to a zero memory descriptor:
  9890. /// - @p src_iter_desc together with @p src_iter_c_desc,
  9891. /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
  9892. /// - @p weights_peephole_desc together with
  9893. /// @p diff_weights_peephole_desc
  9894. /// - @p bias_desc together with @p diff_bias_desc,
  9895. /// - @p dst_iter_desc together with @p dst_iter_c_desc,
  9896. /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
  9897. ///
  9898. /// This would then indicate that the LSTM backward propagation
  9899. /// primitive should not use them and should default to zero values
  9900. /// instead.
  9901. ///
  9902. /// The @p weights_projection_desc together with @p
  9903. /// diff_weights_projection_desc may point to a zero memory descriptor.
  9904. /// This would then indicate that the LSTM doesn't have recurrent
  9905. /// projection layer.
  9906. ///
  9907. /// @note
  9908. /// All memory descriptors can be initialized with
  9909. /// #dnnl::memory::format_tag::any value of @p format_tag.
  9910. ///
  9911. /// @param aengine Engine to use.
  9912. /// @param aprop_kind Propagation kind. Must be
  9913. /// #dnnl::prop_kind::backward.
  9914. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  9915. /// more info.
  9916. /// @param src_layer_desc Memory descriptor for the input vector.
  9917. /// @param src_iter_desc Memory descriptor for the input recurrent
  9918. /// hidden state vector.
  9919. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  9920. /// cell state vector.
  9921. /// @param weights_layer_desc Memory descriptor for the weights
  9922. /// applied to the layer input.
  9923. /// @param weights_iter_desc Memory descriptor for the weights applied
  9924. /// to the recurrent input.
  9925. /// @param weights_peephole_desc Memory descriptor for the weights
  9926. /// applied to the cell states (according to the Peephole LSTM
  9927. /// formula).
  9928. /// @param weights_projection_desc Memory descriptor for the weights
  9929. /// applied to the hidden states to get the recurrent projection
  9930. /// (according to the Projection LSTM formula).
  9931. /// @param bias_desc Bias memory descriptor.
  9932. /// @param dst_layer_desc Memory descriptor for the output vector.
  9933. /// @param dst_iter_desc Memory descriptor for the output recurrent
  9934. /// hidden state vector.
  9935. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  9936. /// cell state vector.
  9937. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  9938. /// vector.
  9939. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  9940. /// recurrent hidden state vector.
  9941. /// @param diff_src_iter_c_desc Memory descriptor for the diff of
  9942. /// input recurrent cell state vector.
  9943. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  9944. /// weights applied to the layer input.
  9945. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  9946. /// weights applied to the recurrent input.
  9947. /// @param diff_weights_peephole_desc Memory descriptor for the diff of
  9948. /// weights applied to the cell states (according to the Peephole
  9949. /// LSTM formula).
  9950. /// @param diff_weights_projection_desc Memory descriptor for the diff
  9951. /// of weights applied to the hidden states to get the recurrent
  9952. /// projection (according to the Projection LSTM formula).
  9953. /// @param diff_bias_desc Diff bias memory descriptor.
  9954. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  9955. /// output vector.
  9956. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  9957. /// recurrent hidden state vector.
  9958. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
  9959. /// output recurrent cell state vector.
  9960. /// @param hint_fwd_pd Primitive descriptor for an LSTM
  9961. /// forward propagation primitive. It is used as a hint for
  9962. /// deciding which memory format to use.
  9963. /// @param attr Primitive attributes to use. Attributes are optional
  9964. /// and default to empty attributes.
  9965. /// @param allow_empty A flag signifying whether construction is
  9966. /// allowed to fail without throwing an exception. In this case an
  9967. /// empty object will be produced. This flag is optional and
  9968. /// defaults to false.
  9969. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  9970. rnn_direction direction, const memory::desc &src_layer_desc,
  9971. const memory::desc &src_iter_desc,
  9972. const memory::desc &src_iter_c_desc,
  9973. const memory::desc &weights_layer_desc,
  9974. const memory::desc &weights_iter_desc,
  9975. const memory::desc &weights_peephole_desc,
  9976. const memory::desc &weights_projection_desc,
  9977. const memory::desc &bias_desc,
  9978. const memory::desc &dst_layer_desc,
  9979. const memory::desc &dst_iter_desc,
  9980. const memory::desc &dst_iter_c_desc,
  9981. const memory::desc &diff_src_layer_desc,
  9982. const memory::desc &diff_src_iter_desc,
  9983. const memory::desc &diff_src_iter_c_desc,
  9984. const memory::desc &diff_weights_layer_desc,
  9985. const memory::desc &diff_weights_iter_desc,
  9986. const memory::desc &diff_weights_peephole_desc,
  9987. const memory::desc &diff_weights_projection_desc,
  9988. const memory::desc &diff_bias_desc,
  9989. const memory::desc &diff_dst_layer_desc,
  9990. const memory::desc &diff_dst_iter_desc,
  9991. const memory::desc &diff_dst_iter_c_desc,
  9992. const lstm_forward::primitive_desc &hint_fwd_pd,
  9993. const primitive_attr &attr = default_attr(),
  9994. bool allow_empty = false)
  9995. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  9996. aprop_kind, algorithm::undef, direction, src_layer_desc,
  9997. src_iter_desc, &src_iter_c_desc, nullptr,
  9998. weights_layer_desc, weights_iter_desc,
  9999. &weights_peephole_desc, &weights_projection_desc, bias_desc,
  10000. dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  10001. diff_src_layer_desc, diff_src_iter_desc,
  10002. &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
  10003. diff_weights_iter_desc, &diff_weights_peephole_desc,
  10004. &diff_weights_projection_desc, diff_bias_desc,
  10005. diff_dst_layer_desc, diff_dst_iter_desc,
  10006. &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
  10007. hint_fwd_pd, attr, allow_empty) {}
  10008. /// Constructs an LSTM (with or without peephole) primitive descriptor
  10009. /// for backward propagation using @p prop_kind, @p direction,
  10010. /// and memory descriptors.
  10011. ///
  10012. /// The following arguments may point to a zero memory descriptor:
  10013. /// - @p src_iter_desc together with @p src_iter_c_desc,
  10014. /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
  10015. /// - @p weights_peephole_desc together with
  10016. /// @p diff_weights_peephole_desc
  10017. /// - @p bias_desc together with @p diff_bias_desc,
  10018. /// - @p dst_iter_desc together with @p dst_iter_c_desc,
  10019. /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
  10020. ///
  10021. /// This would then indicate that the LSTM backward propagation
  10022. /// primitive should not use them and should default to zero values
  10023. /// instead.
  10024. ///
  10025. /// @note
  10026. /// All memory descriptors may be initialized with
  10027. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10028. ///
  10029. /// @param aengine Engine to use.
  10030. /// @param aprop_kind Propagation kind. Must be
  10031. /// #dnnl::prop_kind::backward.
  10032. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10033. /// more info.
  10034. /// @param src_layer_desc Memory descriptor for the input vector.
  10035. /// @param src_iter_desc Memory descriptor for the input recurrent
  10036. /// hidden state vector.
  10037. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  10038. /// cell state vector.
  10039. /// @param weights_layer_desc Memory descriptor for the weights
  10040. /// applied to the layer input.
  10041. /// @param weights_iter_desc Memory descriptor for the weights applied
  10042. /// to the recurrent input.
  10043. /// @param weights_peephole_desc Memory descriptor for the weights
  10044. /// applied to the cell states (according to the Peephole LSTM
  10045. /// formula).
  10046. /// @param bias_desc Bias memory descriptor.
  10047. /// @param dst_layer_desc Memory descriptor for the output vector.
  10048. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10049. /// hidden state vector.
  10050. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  10051. /// cell state vector.
  10052. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10053. /// vector.
  10054. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10055. /// recurrent hidden state vector.
  10056. /// @param diff_src_iter_c_desc Memory descriptor for the diff of
  10057. /// input recurrent cell state vector.
  10058. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10059. /// weights applied to the layer input.
  10060. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10061. /// weights applied to the recurrent input.
  10062. /// @param diff_weights_peephole_desc Memory descriptor for the diff of
  10063. /// weights applied to the cell states (according to the Peephole
  10064. /// LSTM formula).
  10065. /// @param diff_bias_desc Diff bias memory descriptor.
  10066. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10067. /// output vector.
  10068. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10069. /// recurrent hidden state vector.
  10070. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
  10071. /// output recurrent cell state vector.
  10072. /// @param hint_fwd_pd Primitive descriptor for an LSTM
  10073. /// forward propagation primitive. It is used as a hint for
  10074. /// deciding which memory format to use.
  10075. /// @param attr Primitive attributes to use. Attributes are optional
  10076. /// and default to empty attributes.
  10077. /// @param allow_empty A flag signifying whether construction is
  10078. /// allowed to fail without throwing an exception. In this case an
  10079. /// empty object will be produced. This flag is optional and
  10080. /// defaults to false.
  10081. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10082. rnn_direction direction, const memory::desc &src_layer_desc,
  10083. const memory::desc &src_iter_desc,
  10084. const memory::desc &src_iter_c_desc,
  10085. const memory::desc &weights_layer_desc,
  10086. const memory::desc &weights_iter_desc,
  10087. const memory::desc &weights_peephole_desc,
  10088. const memory::desc &bias_desc,
  10089. const memory::desc &dst_layer_desc,
  10090. const memory::desc &dst_iter_desc,
  10091. const memory::desc &dst_iter_c_desc,
  10092. const memory::desc &diff_src_layer_desc,
  10093. const memory::desc &diff_src_iter_desc,
  10094. const memory::desc &diff_src_iter_c_desc,
  10095. const memory::desc &diff_weights_layer_desc,
  10096. const memory::desc &diff_weights_iter_desc,
  10097. const memory::desc &diff_weights_peephole_desc,
  10098. const memory::desc &diff_bias_desc,
  10099. const memory::desc &diff_dst_layer_desc,
  10100. const memory::desc &diff_dst_iter_desc,
  10101. const memory::desc &diff_dst_iter_c_desc,
  10102. const lstm_forward::primitive_desc &hint_fwd_pd,
  10103. const primitive_attr &attr = default_attr(),
  10104. bool allow_empty = false)
  10105. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  10106. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10107. src_iter_desc, &src_iter_c_desc, nullptr,
  10108. weights_layer_desc, weights_iter_desc,
  10109. &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
  10110. dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc,
  10111. diff_src_iter_desc, &diff_src_iter_c_desc, nullptr,
  10112. diff_weights_layer_desc, diff_weights_iter_desc,
  10113. &diff_weights_peephole_desc, nullptr, diff_bias_desc,
  10114. diff_dst_layer_desc, diff_dst_iter_desc,
  10115. &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
  10116. hint_fwd_pd, attr, allow_empty) {}
  10117. /// Constructs an LSTM primitive descriptor for backward propagation
  10118. /// using @p prop_kind, @p direction, and memory descriptors.
  10119. ///
  10120. /// The following arguments may point to a zero memory descriptor:
  10121. /// - @p src_iter_desc together with @p src_iter_c_desc,
  10122. /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
  10123. /// - @p bias_desc together with @p diff_bias_desc,
  10124. /// - @p dst_iter_desc together with @p dst_iter_c_desc,
  10125. /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
  10126. ///
  10127. /// This would then indicate that the LSTM backward propagation
  10128. /// primitive should not use them and should default to zero values
  10129. /// instead.
  10130. ///
  10131. /// @note
  10132. /// All memory descriptors may be initialized with
  10133. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10134. ///
  10135. /// @param aengine Engine to use.
  10136. /// @param aprop_kind Propagation kind. Must be
  10137. /// #dnnl::prop_kind::backward.
  10138. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10139. /// more info.
  10140. /// @param src_layer_desc Memory descriptor for the input vector.
  10141. /// @param src_iter_desc Memory descriptor for the input recurrent
  10142. /// hidden state vector.
  10143. /// @param src_iter_c_desc Memory descriptor for the input recurrent
  10144. /// cell state vector.
  10145. /// @param weights_layer_desc Memory descriptor for the weights
  10146. /// applied to the layer input.
  10147. /// @param weights_iter_desc Memory descriptor for the weights applied
  10148. /// to the recurrent input.
  10149. /// @param bias_desc Bias memory descriptor.
  10150. /// @param dst_layer_desc Memory descriptor for the output vector.
  10151. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10152. /// hidden state vector.
  10153. /// @param dst_iter_c_desc Memory descriptor for the output recurrent
  10154. /// cell state vector.
  10155. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10156. /// vector.
  10157. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10158. /// recurrent hidden state vector.
  10159. /// @param diff_src_iter_c_desc Memory descriptor for the diff of
  10160. /// input recurrent cell state vector.
  10161. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10162. /// weights applied to the layer input.
  10163. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10164. /// weights applied to the recurrent input.
  10165. /// @param diff_bias_desc Diff bias memory descriptor.
  10166. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10167. /// output vector.
  10168. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10169. /// recurrent hidden state vector.
  10170. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
  10171. /// output recurrent cell state vector.
  10172. /// @param hint_fwd_pd Primitive descriptor for a convolution
  10173. /// forward propagation primitive. It is used as a hint for
  10174. /// deciding which memory format to use.
  10175. /// @param attr Primitive attributes to use. Attributes are optional
  10176. /// and default to empty attributes.
  10177. /// @param allow_empty A flag signifying whether construction is
  10178. /// allowed to fail without throwing an exception. In this case an
  10179. /// empty object will be produced. This flag is optional and
  10180. /// defaults to false.
  10181. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10182. rnn_direction direction, const memory::desc &src_layer_desc,
  10183. const memory::desc &src_iter_desc,
  10184. const memory::desc &src_iter_c_desc,
  10185. const memory::desc &weights_layer_desc,
  10186. const memory::desc &weights_iter_desc,
  10187. const memory::desc &bias_desc,
  10188. const memory::desc &dst_layer_desc,
  10189. const memory::desc &dst_iter_desc,
  10190. const memory::desc &dst_iter_c_desc,
  10191. const memory::desc &diff_src_layer_desc,
  10192. const memory::desc &diff_src_iter_desc,
  10193. const memory::desc &diff_src_iter_c_desc,
  10194. const memory::desc &diff_weights_layer_desc,
  10195. const memory::desc &diff_weights_iter_desc,
  10196. const memory::desc &diff_bias_desc,
  10197. const memory::desc &diff_dst_layer_desc,
  10198. const memory::desc &diff_dst_iter_desc,
  10199. const memory::desc &diff_dst_iter_c_desc,
  10200. const lstm_forward::primitive_desc &hint_fwd_pd,
  10201. const primitive_attr &attr = default_attr(),
  10202. bool allow_empty = false)
  10203. : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
  10204. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10205. src_iter_desc, &src_iter_c_desc, nullptr,
  10206. weights_layer_desc, weights_iter_desc, nullptr, nullptr,
  10207. bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
  10208. diff_src_layer_desc, diff_src_iter_desc,
  10209. &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
  10210. diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
  10211. diff_dst_layer_desc, diff_dst_iter_desc,
  10212. &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
  10213. hint_fwd_pd, attr, allow_empty) {}
  10214. /// Constructs a primitive descriptor for an LSTM backward propagation
  10215. /// primitive from a C API primitive descriptor that must have a
  10216. /// matching kind.
  10217. ///
  10218. /// @param pd C API primitive descriptor for an LSTM backward
  10219. /// propagation primitive.
  10220. primitive_desc(dnnl_primitive_desc_t pd)
  10221. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  10222. dnnl::algorithm::vanilla_lstm) {}
  10223. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10224. memory::desc src_layer_desc() const {
  10225. return rnn_base::src_layer_desc();
  10226. }
  10227. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10228. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10229. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10230. memory::desc src_iter_c_desc() const {
  10231. return rnn_base::src_iter_c_desc();
  10232. }
  10233. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10234. memory::desc weights_layer_desc() const {
  10235. return rnn_base::weights_layer_desc();
  10236. }
  10237. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10238. memory::desc weights_iter_desc() const {
  10239. return rnn_base::weights_iter_desc();
  10240. }
  10241. /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
  10242. memory::desc weights_peephole_desc() const {
  10243. return rnn_base::weights_peephole_desc();
  10244. }
  10245. /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
  10246. memory::desc weights_projection_desc() const {
  10247. return rnn_base::weights_projection_desc();
  10248. }
  10249. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10250. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10251. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10252. memory::desc dst_layer_desc() const {
  10253. return rnn_base::dst_layer_desc();
  10254. }
  10255. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10256. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10257. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10258. memory::desc dst_iter_c_desc() const {
  10259. return rnn_base::dst_iter_c_desc();
  10260. }
  10261. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10262. memory::desc workspace_desc() const {
  10263. return rnn_base::workspace_desc();
  10264. }
  10265. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  10266. memory::desc diff_src_layer_desc() const {
  10267. return rnn_base::diff_src_layer_desc();
  10268. }
  10269. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  10270. memory::desc diff_src_iter_desc() const {
  10271. return rnn_base::diff_src_iter_desc();
  10272. }
  10273. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
  10274. memory::desc diff_src_iter_c_desc() const {
  10275. return rnn_base::diff_src_iter_c_desc();
  10276. }
  10277. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  10278. memory::desc diff_weights_layer_desc() const {
  10279. return rnn_base::diff_weights_layer_desc();
  10280. }
  10281. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  10282. memory::desc diff_weights_iter_desc() const {
  10283. return rnn_base::diff_weights_iter_desc();
  10284. }
  10285. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
  10286. memory::desc diff_weights_peephole_desc() const {
  10287. return rnn_base::diff_weights_peephole_desc();
  10288. }
  10289. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
  10290. memory::desc diff_weights_projection_desc() const {
  10291. return rnn_base::diff_weights_projection_desc();
  10292. }
  10293. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  10294. memory::desc diff_bias_desc() const {
  10295. return rnn_base::diff_bias_desc();
  10296. }
  10297. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  10298. memory::desc diff_dst_layer_desc() const {
  10299. return rnn_base::diff_dst_layer_desc();
  10300. }
  10301. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  10302. memory::desc diff_dst_iter_desc() const {
  10303. return rnn_base::diff_dst_iter_desc();
  10304. }
  10305. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
  10306. memory::desc diff_dst_iter_c_desc() const {
  10307. return rnn_base::diff_dst_iter_c_desc();
  10308. }
  10309. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10310. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10311. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10312. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10313. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10314. rnn_direction get_direction() const { return base::get_direction(); }
  10315. };
  10316. /// Default constructor. Produces an empty object.
  10317. lstm_backward() = default;
  10318. /// Constructs an LSTM backward propagation primitive.
  10319. /// @param pd Primitive descriptor for an LSTM backward propagation
  10320. /// primitive.
  10321. lstm_backward(const primitive_desc &pd) : primitive(pd) {}
  10322. /// Constructs an LSTM backward propagation primitive from a cache blob.
  10323. /// @param pd Primitive descriptor for an LSTM backward propagation
  10324. /// primitive.
  10325. /// @param cache_blob Cache blob.
  10326. lstm_backward(
  10327. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10328. : primitive(pd, cache_blob) {}
  10329. };
  10330. /// GRU forward propagation primitive.
  10331. struct gru_forward : public primitive {
  10332. /// Primitive descriptor for a GRU forward propagation primitive.
  10333. struct primitive_desc : public rnn_primitive_desc_base {
  10334. /// Default constructor. Produces an empty object.
  10335. primitive_desc() = default;
  10336. /// Constructs a primitive descriptor for a GRU forward propagation
  10337. /// primitive.
  10338. ///
  10339. /// The following arguments may point to a zero memory descriptor:
  10340. /// - @p src_iter_desc,
  10341. /// - @p bias_desc,
  10342. /// - @p dst_iter_desc.
  10343. ///
  10344. /// This would then indicate that the GRU forward propagation primitive
  10345. /// should not use them and should default to zero values instead.
  10346. ///
  10347. /// @note
  10348. /// All memory descriptors except @p src_iter_desc may be
  10349. /// initialized with an #dnnl::memory::format_tag::any value of @p
  10350. /// format_tag.
  10351. ///
  10352. /// @param aengine Engine to use.
  10353. /// @param aprop_kind Propagation kind. Possible values are
  10354. /// #dnnl::prop_kind::forward_training, and
  10355. /// #dnnl::prop_kind::forward_inference.
  10356. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10357. /// more info.
  10358. /// @param src_layer_desc Memory descriptor for the input vector.
  10359. /// @param src_iter_desc Memory descriptor for the input recurrent
  10360. /// hidden state vector.
  10361. /// @param weights_layer_desc Memory descriptor for the weights
  10362. /// applied to the layer input.
  10363. /// @param weights_iter_desc Memory descriptor for the weights applied
  10364. /// to the recurrent input.
  10365. /// @param bias_desc Bias memory descriptor.
  10366. /// @param dst_layer_desc Memory descriptor for the output vector.
  10367. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10368. /// hidden state vector.
  10369. /// @param attr Primitive attributes to use. Attributes are optional
  10370. /// and default to empty attributes.
  10371. /// @param allow_empty A flag signifying whether construction is
  10372. /// allowed to fail without throwing an exception. In this case an
  10373. /// empty object will be produced. This flag is optional and
  10374. /// defaults to false.
  10375. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10376. rnn_direction direction, const memory::desc &src_layer_desc,
  10377. const memory::desc &src_iter_desc,
  10378. const memory::desc &weights_layer_desc,
  10379. const memory::desc &weights_iter_desc,
  10380. const memory::desc &bias_desc,
  10381. const memory::desc &dst_layer_desc,
  10382. const memory::desc &dst_iter_desc,
  10383. const primitive_attr &attr = default_attr(),
  10384. bool allow_empty = false)
  10385. : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
  10386. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10387. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  10388. weights_iter_desc, nullptr, nullptr, bias_desc,
  10389. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  10390. 0.0f, 0.0f, attr, allow_empty) {}
  10391. /// Constructs a primitive descriptor for a GRU forward propagation
  10392. /// primitive from a C API primitive descriptor that must have a
  10393. /// matching kind.
  10394. ///
  10395. /// @param pd C API primitive descriptor for a GRU forward
  10396. /// propagation primitive.
  10397. primitive_desc(dnnl_primitive_desc_t pd)
  10398. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  10399. dnnl::prop_kind::forward_inference,
  10400. dnnl::algorithm::vanilla_gru) {}
  10401. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10402. memory::desc src_layer_desc() const {
  10403. return rnn_base::src_layer_desc();
  10404. }
  10405. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10406. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10407. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10408. memory::desc weights_layer_desc() const {
  10409. return rnn_base::weights_layer_desc();
  10410. }
  10411. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10412. memory::desc weights_iter_desc() const {
  10413. return rnn_base::weights_iter_desc();
  10414. }
  10415. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10416. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10417. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10418. memory::desc dst_layer_desc() const {
  10419. return rnn_base::dst_layer_desc();
  10420. }
  10421. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10422. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10423. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10424. memory::desc workspace_desc() const {
  10425. return rnn_base::workspace_desc();
  10426. }
  10427. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10428. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10429. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10430. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10431. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10432. rnn_direction get_direction() const { return base::get_direction(); }
  10433. };
  10434. /// Default constructor. Produces an empty object.
  10435. gru_forward() = default;
  10436. /// Constructs a GRU forward propagation primitive.
  10437. /// @param pd Primitive descriptor for a GRU forward propagation
  10438. /// primitive.
  10439. gru_forward(const primitive_desc &pd) : primitive(pd) {}
  10440. /// Constructs a GRU forward propagation primitive from a cache blob.
  10441. /// @param pd Primitive descriptor for a GRU forward propagation
  10442. /// primitive.
  10443. /// @param cache_blob Cache blob.
  10444. gru_forward(
  10445. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10446. : primitive(pd, cache_blob) {}
  10447. };
  10448. /// GRU backward propagation primitive.
  10449. struct gru_backward : public primitive {
  10450. /// Primitive descriptor for a GRU backward propagation primitive.
  10451. struct primitive_desc : public rnn_primitive_desc_base {
  10452. /// Default constructor. Produces an empty object.
  10453. primitive_desc() = default;
  10454. /// Constructs a primitive descriptor for a GRU backward propagation
  10455. /// primitive.
  10456. ///
  10457. /// The following arguments may point to a zero memory descriptor:
  10458. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  10459. /// - @p bias_desc together with @p diff_bias_desc,
  10460. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  10461. ///
  10462. /// This would then indicate that the GRU backward propagation
  10463. /// primitive should not use them and should default to zero values
  10464. /// instead.
  10465. ///
  10466. /// @note
  10467. /// All memory descriptors may be initialized with
  10468. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10469. ///
  10470. /// @param aengine Engine to use.
  10471. /// @param aprop_kind Propagation kind. Must be
  10472. /// #dnnl::prop_kind::backward.
  10473. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10474. /// more info.
  10475. /// @param src_layer_desc Memory descriptor for the input vector.
  10476. /// @param src_iter_desc Memory descriptor for the input recurrent
  10477. /// hidden state vector.
  10478. /// @param weights_layer_desc Memory descriptor for the weights
  10479. /// applied to the layer input.
  10480. /// @param weights_iter_desc Memory descriptor for the weights applied
  10481. /// to the recurrent input.
  10482. /// @param bias_desc Bias memory descriptor.
  10483. /// @param dst_layer_desc Memory descriptor for the output vector.
  10484. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10485. /// hidden state vector.
  10486. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10487. /// vector.
  10488. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10489. /// recurrent hidden state vector.
  10490. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10491. /// weights applied to the layer input.
  10492. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10493. /// weights applied to the recurrent input.
  10494. /// @param diff_bias_desc Diff bias memory descriptor.
  10495. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10496. /// output vector.
  10497. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10498. /// recurrent hidden state vector.
  10499. /// @param hint_fwd_pd Primitive descriptor for a GRU
  10500. /// forward propagation primitive. It is used as a hint for
  10501. /// deciding which memory format to use.
  10502. /// @param attr Primitive attributes to use. Attributes are optional
  10503. /// and default to empty attributes.
  10504. /// @param allow_empty A flag signifying whether construction is
  10505. /// allowed to fail without throwing an exception. In this case an
  10506. /// empty object will be produced. This flag is optional and
  10507. /// defaults to false.
  10508. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10509. rnn_direction direction, const memory::desc &src_layer_desc,
  10510. const memory::desc &src_iter_desc,
  10511. const memory::desc &weights_layer_desc,
  10512. const memory::desc &weights_iter_desc,
  10513. const memory::desc &bias_desc,
  10514. const memory::desc &dst_layer_desc,
  10515. const memory::desc &dst_iter_desc,
  10516. const memory::desc &diff_src_layer_desc,
  10517. const memory::desc &diff_src_iter_desc,
  10518. const memory::desc &diff_weights_layer_desc,
  10519. const memory::desc &diff_weights_iter_desc,
  10520. const memory::desc &diff_bias_desc,
  10521. const memory::desc &diff_dst_layer_desc,
  10522. const memory::desc &diff_dst_iter_desc,
  10523. const gru_forward::primitive_desc &hint_fwd_pd,
  10524. const primitive_attr &attr = default_attr(),
  10525. bool allow_empty = false)
  10526. : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
  10527. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10528. src_iter_desc, nullptr, nullptr, weights_layer_desc,
  10529. weights_iter_desc, nullptr, nullptr, bias_desc,
  10530. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  10531. diff_src_iter_desc, nullptr, nullptr,
  10532. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  10533. nullptr, diff_bias_desc, diff_dst_layer_desc,
  10534. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  10535. hint_fwd_pd, attr, allow_empty) {}
  10536. /// Constructs a primitive descriptor for a GRU backward propagation
  10537. /// primitive from a C API primitive descriptor that must have a
  10538. /// matching kind.
  10539. ///
  10540. /// @param pd C API primitive descriptor for a GRU backward
  10541. /// propagation primitive.
  10542. primitive_desc(dnnl_primitive_desc_t pd)
  10543. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  10544. dnnl::algorithm::vanilla_gru) {}
  10545. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10546. memory::desc src_layer_desc() const {
  10547. return rnn_base::src_layer_desc();
  10548. }
  10549. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10550. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10551. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10552. memory::desc weights_layer_desc() const {
  10553. return rnn_base::weights_layer_desc();
  10554. }
  10555. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10556. memory::desc weights_iter_desc() const {
  10557. return rnn_base::weights_iter_desc();
  10558. }
  10559. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10560. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10561. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10562. memory::desc dst_layer_desc() const {
  10563. return rnn_base::dst_layer_desc();
  10564. }
  10565. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10566. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10567. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10568. memory::desc workspace_desc() const {
  10569. return rnn_base::workspace_desc();
  10570. }
  10571. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  10572. memory::desc diff_src_layer_desc() const {
  10573. return rnn_base::diff_src_layer_desc();
  10574. }
  10575. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  10576. memory::desc diff_src_iter_desc() const {
  10577. return rnn_base::diff_src_iter_desc();
  10578. }
  10579. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  10580. memory::desc diff_weights_layer_desc() const {
  10581. return rnn_base::diff_weights_layer_desc();
  10582. }
  10583. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  10584. memory::desc diff_weights_iter_desc() const {
  10585. return rnn_base::diff_weights_iter_desc();
  10586. }
  10587. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  10588. memory::desc diff_bias_desc() const {
  10589. return rnn_base::diff_bias_desc();
  10590. }
  10591. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  10592. memory::desc diff_dst_layer_desc() const {
  10593. return rnn_base::diff_dst_layer_desc();
  10594. }
  10595. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  10596. memory::desc diff_dst_iter_desc() const {
  10597. return rnn_base::diff_dst_iter_desc();
  10598. }
  10599. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10600. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10601. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10602. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10603. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10604. rnn_direction get_direction() const { return base::get_direction(); }
  10605. };
  10606. /// Default constructor. Produces an empty object.
  10607. gru_backward() = default;
  10608. /// Constructs a GRU backward propagation primitive.
  10609. /// @param pd Primitive descriptor for a GRU backward propagation
  10610. /// primitive.
  10611. gru_backward(const primitive_desc &pd) : primitive(pd) {}
  10612. /// Constructs a GRU backward propagation primitive from a cache blob.
  10613. /// @param pd Primitive descriptor for a GRU backward propagation
  10614. /// primitive.
  10615. /// @param cache_blob Cache blob.
  10616. gru_backward(
  10617. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10618. : primitive(pd, cache_blob) {}
  10619. };
  10620. /// LBR GRU forward propagation primitive.
  10621. struct lbr_gru_forward : public primitive {
  10622. /// Primitive descriptor for an LBR GRU forward propagation primitive.
  10623. struct primitive_desc : public rnn_primitive_desc_base {
  10624. /// Default constructor. Produces an empty object.
  10625. primitive_desc() = default;
  10626. /// Constructs a primitive descriptor for LBR GRU forward propagation
  10627. /// primitive.
  10628. ///
  10629. /// The following arguments may point to a zero memory descriptor:
  10630. /// - @p src_iter_desc,
  10631. /// - @p bias_desc,
  10632. /// - @p dst_iter_desc.
  10633. ///
  10634. /// This would then indicate that the LBR GRU forward propagation
  10635. /// primitive should not use them and should default to zero values
  10636. /// instead.
  10637. ///
  10638. /// @note
  10639. /// All memory descriptors except @p src_iter_desc may be
  10640. /// initialized with an #dnnl::memory::format_tag::any value of @p
  10641. /// format_tag.
  10642. ///
  10643. /// @param aengine Engine to use.
  10644. /// @param aprop_kind Propagation kind. Possible values are
  10645. /// #dnnl::prop_kind::forward_training, and
  10646. /// #dnnl::prop_kind::forward_inference.
  10647. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10648. /// more info.
  10649. /// @param src_layer_desc Memory descriptor for the input vector.
  10650. /// @param src_iter_desc Memory descriptor for the input recurrent
  10651. /// hidden state vector.
  10652. /// @param weights_layer_desc Memory descriptor for the weights
  10653. /// applied to the layer input.
  10654. /// @param weights_iter_desc Memory descriptor for the weights applied
  10655. /// to the recurrent input.
  10656. /// @param bias_desc Bias memory descriptor.
  10657. /// @param dst_layer_desc Memory descriptor for the output vector.
  10658. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10659. /// hidden state vector.
  10660. /// @param attr Primitive attributes to use. Attributes are optional
  10661. /// and default to empty attributes.
  10662. /// @param allow_empty A flag signifying whether construction is
  10663. /// allowed to fail without throwing an exception. In this case an
  10664. /// empty object will be produced. This flag is optional and
  10665. /// defaults to false.
  10666. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10667. rnn_direction direction, const memory::desc &src_layer_desc,
  10668. const memory::desc &src_iter_desc,
  10669. const memory::desc &weights_layer_desc,
  10670. const memory::desc &weights_iter_desc,
  10671. const memory::desc &bias_desc,
  10672. const memory::desc &dst_layer_desc,
  10673. const memory::desc &dst_iter_desc,
  10674. const primitive_attr &attr = default_attr(),
  10675. bool allow_empty = false)
  10676. : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
  10677. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  10678. nullptr, nullptr, weights_layer_desc, weights_iter_desc,
  10679. nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
  10680. nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
  10681. /// Constructs a primitive descriptor for a LBR GRU forward propagation
  10682. /// primitive from a C API primitive descriptor that must have a
  10683. /// matching kind.
  10684. ///
  10685. /// @param pd C API primitive descriptor for a LBR GRU forward
  10686. /// propagation primitive.
  10687. primitive_desc(dnnl_primitive_desc_t pd)
  10688. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  10689. dnnl::prop_kind::forward_inference,
  10690. dnnl::algorithm::lbr_gru) {}
  10691. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10692. memory::desc src_layer_desc() const {
  10693. return rnn_base::src_layer_desc();
  10694. }
  10695. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10696. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10697. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10698. memory::desc weights_layer_desc() const {
  10699. return rnn_base::weights_layer_desc();
  10700. }
  10701. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10702. memory::desc weights_iter_desc() const {
  10703. return rnn_base::weights_iter_desc();
  10704. }
  10705. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10706. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10707. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10708. memory::desc dst_layer_desc() const {
  10709. return rnn_base::dst_layer_desc();
  10710. }
  10711. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10712. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10713. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10714. memory::desc workspace_desc() const {
  10715. return rnn_base::workspace_desc();
  10716. }
  10717. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10718. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10719. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10720. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10721. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10722. rnn_direction get_direction() const { return base::get_direction(); }
  10723. };
  10724. /// Default constructor. Produces an empty object.
  10725. lbr_gru_forward() = default;
  10726. /// Constructs an LBR GRU forward propagation primitive.
  10727. /// @param pd Primitive descriptor for an LBR GRU forward propagation
  10728. /// primitive.
  10729. lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
  10730. /// Constructs an LBR GRU forward propagation primitive from a cache blob.
  10731. /// @param pd Primitive descriptor for an LBR GRU forward propagation
  10732. /// primitive.
  10733. /// @param cache_blob Cache blob.
  10734. lbr_gru_forward(
  10735. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10736. : primitive(pd, cache_blob) {}
  10737. };
  10738. /// LBR GRU backward propagation primitive.
  10739. struct lbr_gru_backward : public primitive {
  10740. /// Primitive descriptor for an LBR GRU backward propagation primitive.
  10741. struct primitive_desc : public rnn_primitive_desc_base {
  10742. /// Default constructor. Produces an empty object.
  10743. primitive_desc() = default;
  10744. /// Constructs a primitive descriptor for LBR GRU backward propagation
  10745. /// primitive.
  10746. ///
  10747. /// The following arguments may point to a zero memory descriptor:
  10748. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  10749. /// - @p bias_desc together with @p diff_bias_desc,
  10750. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  10751. ///
  10752. /// This would then indicate that the LBR GRU backward propagation
  10753. /// primitive should not use them and should default to zero values
  10754. /// instead.
  10755. ///
  10756. /// @note
  10757. /// All memory descriptors may be initialized with
  10758. /// #dnnl::memory::format_tag::any value of @p format_tag.
  10759. ///
  10760. /// @param aengine Engine to use.
  10761. /// @param aprop_kind Propagation kind. Must be
  10762. /// #dnnl::prop_kind::backward.
  10763. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10764. /// more info.
  10765. /// @param src_layer_desc Memory descriptor for the input vector.
  10766. /// @param src_iter_desc Memory descriptor for the input recurrent
  10767. /// hidden state vector.
  10768. /// @param weights_layer_desc Memory descriptor for the weights
  10769. /// applied to the layer input.
  10770. /// @param weights_iter_desc Memory descriptor for the weights applied
  10771. /// to the recurrent input.
  10772. /// @param bias_desc Bias memory descriptor.
  10773. /// @param dst_layer_desc Memory descriptor for the output vector.
  10774. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10775. /// hidden state vector.
  10776. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  10777. /// vector.
  10778. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  10779. /// recurrent hidden state vector.
  10780. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  10781. /// weights applied to the layer input.
  10782. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  10783. /// weights applied to the recurrent input.
  10784. /// @param diff_bias_desc Diff bias memory descriptor.
  10785. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  10786. /// output vector.
  10787. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  10788. /// recurrent hidden state vector.
  10789. /// @param hint_fwd_pd Primitive descriptor for an LBR GRU
  10790. /// forward propagation primitive. It is used as a hint for
  10791. /// deciding which memory format to use.
  10792. /// @param attr Primitive attributes to use. Attributes are optional
  10793. /// and default to empty attributes.
  10794. /// @param allow_empty A flag signifying whether construction is
  10795. /// allowed to fail without throwing an exception. In this case an
  10796. /// empty object will be produced. This flag is optional and
  10797. /// defaults to false.
  10798. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10799. rnn_direction direction, const memory::desc &src_layer_desc,
  10800. const memory::desc &src_iter_desc,
  10801. const memory::desc &weights_layer_desc,
  10802. const memory::desc &weights_iter_desc,
  10803. const memory::desc &bias_desc,
  10804. const memory::desc &dst_layer_desc,
  10805. const memory::desc &dst_iter_desc,
  10806. const memory::desc &diff_src_layer_desc,
  10807. const memory::desc &diff_src_iter_desc,
  10808. const memory::desc &diff_weights_layer_desc,
  10809. const memory::desc &diff_weights_iter_desc,
  10810. const memory::desc &diff_bias_desc,
  10811. const memory::desc &diff_dst_layer_desc,
  10812. const memory::desc &diff_dst_iter_desc,
  10813. const lbr_gru_forward::primitive_desc &hint_fwd_pd,
  10814. const primitive_attr &attr = default_attr(),
  10815. bool allow_empty = false)
  10816. : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
  10817. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  10818. nullptr, nullptr, weights_layer_desc, weights_iter_desc,
  10819. nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
  10820. nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr,
  10821. nullptr, diff_weights_layer_desc, diff_weights_iter_desc,
  10822. nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc,
  10823. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  10824. hint_fwd_pd, attr, allow_empty) {}
  10825. /// Constructs a primitive descriptor for a LBR GRU backward propagation
  10826. /// primitive from a C API primitive descriptor that must have a
  10827. /// matching kind.
  10828. ///
  10829. /// @param pd C API primitive descriptor for a LBR GRU backward
  10830. /// propagation primitive.
  10831. primitive_desc(dnnl_primitive_desc_t pd)
  10832. : rnn_primitive_desc_base(
  10833. pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
  10834. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10835. memory::desc src_layer_desc() const {
  10836. return rnn_base::src_layer_desc();
  10837. }
  10838. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10839. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10840. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10841. memory::desc weights_layer_desc() const {
  10842. return rnn_base::weights_layer_desc();
  10843. }
  10844. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10845. memory::desc weights_iter_desc() const {
  10846. return rnn_base::weights_iter_desc();
  10847. }
  10848. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  10849. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  10850. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  10851. memory::desc dst_layer_desc() const {
  10852. return rnn_base::dst_layer_desc();
  10853. }
  10854. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  10855. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  10856. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  10857. memory::desc workspace_desc() const {
  10858. return rnn_base::workspace_desc();
  10859. }
  10860. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  10861. memory::desc diff_src_layer_desc() const {
  10862. return rnn_base::diff_src_layer_desc();
  10863. }
  10864. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  10865. memory::desc diff_src_iter_desc() const {
  10866. return rnn_base::diff_src_iter_desc();
  10867. }
  10868. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  10869. memory::desc diff_weights_layer_desc() const {
  10870. return rnn_base::diff_weights_layer_desc();
  10871. }
  10872. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  10873. memory::desc diff_weights_iter_desc() const {
  10874. return rnn_base::diff_weights_iter_desc();
  10875. }
  10876. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  10877. memory::desc diff_bias_desc() const {
  10878. return rnn_base::diff_bias_desc();
  10879. }
  10880. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  10881. memory::desc diff_dst_layer_desc() const {
  10882. return rnn_base::diff_dst_layer_desc();
  10883. }
  10884. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  10885. memory::desc diff_dst_iter_desc() const {
  10886. return rnn_base::diff_dst_iter_desc();
  10887. }
  10888. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  10889. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  10890. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  10891. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  10892. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  10893. rnn_direction get_direction() const { return base::get_direction(); }
  10894. };
  10895. /// Default constructor. Produces an empty object.
  10896. lbr_gru_backward() = default;
  10897. /// Constructs an LBR GRU backward propagation primitive.
  10898. /// @param pd Primitive descriptor for an LBR GRU backward propagation
  10899. /// primitive.
  10900. lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
  10901. /// Constructs an LBR GRU backward propagation primitive from a cache blob.
  10902. /// @param pd Primitive descriptor for an LBR GRU backward propagation
  10903. /// primitive.
  10904. /// @param cache_blob Cache blob.
  10905. lbr_gru_backward(
  10906. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  10907. : primitive(pd, cache_blob) {}
  10908. };
  10909. /// AUGRU forward propagation primitive.
  10910. struct augru_forward : public primitive {
  10911. /// Primitive descriptor for an AUGRU forward propagation primitive.
  10912. struct primitive_desc : public rnn_primitive_desc_base {
  10913. /// Default constructor. Produces an empty object.
  10914. primitive_desc() = default;
  10915. /// Constructs a primitive descriptor for an AUGRU forward propagation
  10916. /// primitive.
  10917. ///
  10918. /// The following arguments may point to a zero memory descriptor:
  10919. /// - @p src_iter_desc,
  10920. /// - @p bias_desc,
  10921. /// - @p dst_iter_desc.
  10922. ///
  10923. /// This would then indicate that the AUGRU forward propagation
  10924. /// primitive should not use them and should default to zero values
  10925. /// instead.
  10926. ///
  10927. /// @note
  10928. /// All memory descriptors except @p src_iter_desc may be
  10929. /// initialized with an #dnnl::memory::format_tag::any value of @p
  10930. /// format_tag.
  10931. ///
  10932. /// @param aengine Engine to use.
  10933. /// @param aprop_kind Propagation kind. Possible values are
  10934. /// #dnnl::prop_kind::forward_training, and
  10935. /// #dnnl::prop_kind::forward_inference.
  10936. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  10937. /// more info.
  10938. /// @param src_layer_desc Memory descriptor for the input vector.
  10939. /// @param src_iter_desc Memory descriptor for the input recurrent
  10940. /// hidden state vector.
  10941. /// @param attention_desc Memory descriptor for the attention vector.
  10942. /// @param weights_layer_desc Memory descriptor for the weights
  10943. /// applied to the layer input.
  10944. /// @param weights_iter_desc Memory descriptor for the weights applied
  10945. /// to the recurrent input.
  10946. /// @param bias_desc Bias memory descriptor.
  10947. /// @param dst_layer_desc Memory descriptor for the output vector.
  10948. /// @param dst_iter_desc Memory descriptor for the output recurrent
  10949. /// hidden state vector.
  10950. /// @param attr Primitive attributes to use. Attributes are optional
  10951. /// and default to empty attributes.
  10952. /// @param allow_empty A flag signifying whether construction is
  10953. /// allowed to fail without throwing an exception. In this case an
  10954. /// empty object will be produced. This flag is optional and
  10955. /// defaults to false.
  10956. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  10957. rnn_direction direction, const memory::desc &src_layer_desc,
  10958. const memory::desc &src_iter_desc,
  10959. const memory::desc &attention_desc,
  10960. const memory::desc &weights_layer_desc,
  10961. const memory::desc &weights_iter_desc,
  10962. const memory::desc &bias_desc,
  10963. const memory::desc &dst_layer_desc,
  10964. const memory::desc &dst_iter_desc,
  10965. const primitive_attr &attr = default_attr(),
  10966. bool allow_empty = false)
  10967. : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
  10968. aprop_kind, algorithm::undef, direction, src_layer_desc,
  10969. src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
  10970. weights_iter_desc, nullptr, nullptr, bias_desc,
  10971. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  10972. 0.0f, 0.0f, attr, allow_empty) {}
  10973. /// Constructs a primitive descriptor for an AUGRU forward propagation
  10974. /// primitive from a C API primitive descriptor that must have a
  10975. /// matching kind.
  10976. ///
  10977. /// @param pd C API primitive descriptor for an AUGRU forward
  10978. /// propagation primitive.
  10979. primitive_desc(dnnl_primitive_desc_t pd)
  10980. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  10981. dnnl::prop_kind::forward_inference,
  10982. dnnl::algorithm::vanilla_augru) {}
  10983. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  10984. memory::desc src_layer_desc() const {
  10985. return rnn_base::src_layer_desc();
  10986. }
  10987. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  10988. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  10989. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  10990. memory::desc attention_desc() const {
  10991. return rnn_base::augru_attention_desc();
  10992. }
  10993. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  10994. memory::desc weights_layer_desc() const {
  10995. return rnn_base::weights_layer_desc();
  10996. }
  10997. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  10998. memory::desc weights_iter_desc() const {
  10999. return rnn_base::weights_iter_desc();
  11000. }
  11001. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11002. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11003. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11004. memory::desc dst_layer_desc() const {
  11005. return rnn_base::dst_layer_desc();
  11006. }
  11007. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11008. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11009. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11010. memory::desc workspace_desc() const {
  11011. return rnn_base::workspace_desc();
  11012. }
  11013. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11014. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11015. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11016. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11017. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11018. rnn_direction get_direction() const { return base::get_direction(); }
  11019. };
  11020. /// Default constructor. Produces an empty object.
  11021. augru_forward() = default;
  11022. /// Constructs an AUGRU forward propagation primitive.
  11023. /// @param pd Primitive descriptor for an AUGRU forward propagation
  11024. /// primitive.
  11025. augru_forward(const primitive_desc &pd) : primitive(pd) {}
  11026. /// Constructs an AUGRU forward propagation primitive from a cache blob.
  11027. /// @param pd Primitive descriptor for an AUGRU forward propagation
  11028. /// primitive.
  11029. /// @param cache_blob Cache blob.
  11030. augru_forward(
  11031. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11032. : primitive(pd, cache_blob) {}
  11033. };
  11034. /// AUGRU backward propagation primitive.
  11035. struct augru_backward : public primitive {
  11036. /// Descriptor for an AUGRU backward propagation primitive.
  11037. /// Primitive descriptor for an AUGRU backward propagation primitive.
  11038. struct primitive_desc : public rnn_primitive_desc_base {
  11039. /// Default constructor. Produces an empty object.
  11040. primitive_desc() = default;
  11041. /// Constructs a primitive descriptor for an AUGRU backward propagation
  11042. /// primitive.
  11043. ///
  11044. /// The following arguments may point to a zero memory descriptor:
  11045. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  11046. /// - @p bias_desc together with @p diff_bias_desc,
  11047. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  11048. ///
  11049. /// This would then indicate that the AUGRU backward propagation
  11050. /// primitive should not use them and should default to zero values
  11051. /// instead.
  11052. ///
  11053. /// @note
  11054. /// All memory descriptors may be initialized with
  11055. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11056. ///
  11057. /// @param aengine Engine to use.
  11058. /// @param aprop_kind Propagation kind. Must be
  11059. /// #dnnl::prop_kind::backward.
  11060. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  11061. /// more info.
  11062. /// @param src_layer_desc Memory descriptor for the input vector.
  11063. /// @param src_iter_desc Memory descriptor for the input recurrent
  11064. /// hidden state vector.
  11065. /// @param attention_desc Memory descriptor for the attention vector.
  11066. /// @param weights_layer_desc Memory descriptor for the weights
  11067. /// applied to the layer input.
  11068. /// @param weights_iter_desc Memory descriptor for the weights applied
  11069. /// to the recurrent input.
  11070. /// @param bias_desc Bias memory descriptor.
  11071. /// @param dst_layer_desc Memory descriptor for the output vector.
  11072. /// @param dst_iter_desc Memory descriptor for the output recurrent
  11073. /// hidden state vector.
  11074. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  11075. /// vector.
  11076. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  11077. /// recurrent hidden state vector.
  11078. /// @param diff_attention_desc Memory descriptor for the diff of
  11079. /// attention vector.
  11080. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  11081. /// weights applied to the layer input.
  11082. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  11083. /// weights applied to the recurrent input.
  11084. /// @param diff_bias_desc Diff bias memory descriptor.
  11085. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  11086. /// output vector.
  11087. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  11088. /// recurrent hidden state vector.
  11089. /// @param hint_fwd_pd Primitive descriptor for an AUGRU
  11090. /// forward propagation primitive. It is used as a hint for
  11091. /// deciding which memory format to use.
  11092. /// @param attr Primitive attributes to use. Attributes are optional
  11093. /// and default to empty attributes.
  11094. /// @param allow_empty A flag signifying whether construction is
  11095. /// allowed to fail without throwing an exception. In this case an
  11096. /// empty object will be produced. This flag is optional and
  11097. /// defaults to false.
  11098. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11099. rnn_direction direction, const memory::desc &src_layer_desc,
  11100. const memory::desc &src_iter_desc,
  11101. const memory::desc &attention_desc,
  11102. const memory::desc &weights_layer_desc,
  11103. const memory::desc &weights_iter_desc,
  11104. const memory::desc &bias_desc,
  11105. const memory::desc &dst_layer_desc,
  11106. const memory::desc &dst_iter_desc,
  11107. const memory::desc &diff_src_layer_desc,
  11108. const memory::desc &diff_src_iter_desc,
  11109. const memory::desc &diff_attention_desc,
  11110. const memory::desc &diff_weights_layer_desc,
  11111. const memory::desc &diff_weights_iter_desc,
  11112. const memory::desc &diff_bias_desc,
  11113. const memory::desc &diff_dst_layer_desc,
  11114. const memory::desc &diff_dst_iter_desc,
  11115. const augru_forward::primitive_desc &hint_fwd_pd,
  11116. const primitive_attr &attr = default_attr(),
  11117. bool allow_empty = false)
  11118. : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
  11119. aprop_kind, algorithm::undef, direction, src_layer_desc,
  11120. src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
  11121. weights_iter_desc, nullptr, nullptr, bias_desc,
  11122. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  11123. diff_src_iter_desc, nullptr, &diff_attention_desc,
  11124. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  11125. nullptr, diff_bias_desc, diff_dst_layer_desc,
  11126. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  11127. hint_fwd_pd, attr, allow_empty) {}
  11128. /// Constructs a primitive descriptor for an AUGRU backward propagation
  11129. /// primitive from a C API primitive descriptor that must have a
  11130. /// matching kind.
  11131. ///
  11132. /// @param pd C API primitive descriptor for an AUGRU backward
  11133. /// propagation primitive.
  11134. primitive_desc(dnnl_primitive_desc_t pd)
  11135. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  11136. dnnl::algorithm::vanilla_augru) {}
  11137. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  11138. memory::desc src_layer_desc() const {
  11139. return rnn_base::src_layer_desc();
  11140. }
  11141. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  11142. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  11143. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  11144. memory::desc attention_desc() const {
  11145. return rnn_base::augru_attention_desc();
  11146. }
  11147. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  11148. memory::desc weights_layer_desc() const {
  11149. return rnn_base::weights_layer_desc();
  11150. }
  11151. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  11152. memory::desc weights_iter_desc() const {
  11153. return rnn_base::weights_iter_desc();
  11154. }
  11155. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11156. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11157. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11158. memory::desc dst_layer_desc() const {
  11159. return rnn_base::dst_layer_desc();
  11160. }
  11161. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11162. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11163. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11164. memory::desc workspace_desc() const {
  11165. return rnn_base::workspace_desc();
  11166. }
  11167. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  11168. memory::desc diff_src_layer_desc() const {
  11169. return rnn_base::diff_src_layer_desc();
  11170. }
  11171. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  11172. memory::desc diff_src_iter_desc() const {
  11173. return rnn_base::diff_src_iter_desc();
  11174. }
  11175. /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
  11176. memory::desc diff_attention_desc() const {
  11177. return rnn_base::diff_augru_attention_desc();
  11178. }
  11179. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  11180. memory::desc diff_weights_layer_desc() const {
  11181. return rnn_base::diff_weights_layer_desc();
  11182. }
  11183. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  11184. memory::desc diff_weights_iter_desc() const {
  11185. return rnn_base::diff_weights_iter_desc();
  11186. }
  11187. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  11188. memory::desc diff_bias_desc() const {
  11189. return rnn_base::diff_bias_desc();
  11190. }
  11191. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  11192. memory::desc diff_dst_layer_desc() const {
  11193. return rnn_base::diff_dst_layer_desc();
  11194. }
  11195. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  11196. memory::desc diff_dst_iter_desc() const {
  11197. return rnn_base::diff_dst_iter_desc();
  11198. }
  11199. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11200. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11201. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11202. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11203. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11204. rnn_direction get_direction() const { return base::get_direction(); }
  11205. };
  11206. /// Default constructor. Produces an empty object.
  11207. augru_backward() = default;
  11208. /// Constructs an AUGRU backward propagation primitive.
  11209. /// @param pd Primitive descriptor for an AUGRU backward propagation
  11210. /// primitive.
  11211. augru_backward(const primitive_desc &pd) : primitive(pd) {}
  11212. /// Constructs an AUGRU backward propagation primitive from a cache blob.
  11213. /// @param pd Primitive descriptor for an AUGRU backward propagation
  11214. /// primitive.
  11215. /// @param cache_blob Cache blob.
  11216. augru_backward(
  11217. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11218. : primitive(pd, cache_blob) {}
  11219. };
  11220. /// LBR AUGRU forward propagation primitive.
  11221. struct lbr_augru_forward : public primitive {
  11222. /// Descriptor for an LBR AUGRU forward propagation primitive.
  11223. /// Primitive descriptor for an LBR AUGRU forward propagation primitive.
  11224. struct primitive_desc : public rnn_primitive_desc_base {
  11225. /// Default constructor. Produces an empty object.
  11226. primitive_desc() = default;
  11227. /// Constructs a primitive descriptor for LBR AUGRU forward propagation
  11228. /// primitive.
  11229. ///
  11230. /// The following arguments may point to a zero memory descriptor:
  11231. /// - @p src_iter_desc,
  11232. /// - @p bias_desc,
  11233. /// - @p dst_iter_desc.
  11234. ///
  11235. /// This would then indicate that the LBR AUGRU forward propagation
  11236. /// primitive should not use them and should default to zero values
  11237. /// instead.
  11238. ///
  11239. /// @note
  11240. /// All memory descriptors except @p src_iter_desc may be
  11241. /// initialized with an #dnnl::memory::format_tag::any value of @p
  11242. /// format_tag.
  11243. ///
  11244. /// @param aengine Engine to use.
  11245. /// @param aprop_kind Propagation kind. Possible values are
  11246. /// #dnnl::prop_kind::forward_training, and
  11247. /// #dnnl::prop_kind::forward_inference.
  11248. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  11249. /// more info.
  11250. /// @param src_layer_desc Memory descriptor for the input vector.
  11251. /// @param src_iter_desc Memory descriptor for the input recurrent
  11252. /// hidden state vector.
  11253. /// @param attention_desc Memory descriptor for the attention vector.
  11254. /// @param weights_layer_desc Memory descriptor for the weights
  11255. /// applied to the layer input.
  11256. /// @param weights_iter_desc Memory descriptor for the weights applied
  11257. /// to the recurrent input.
  11258. /// @param bias_desc Bias memory descriptor.
  11259. /// @param dst_layer_desc Memory descriptor for the output vector.
  11260. /// @param dst_iter_desc Memory descriptor for the output recurrent
  11261. /// hidden state vector.
  11262. /// @param attr Primitive attributes to use. Attributes are optional
  11263. /// and default to empty attributes.
  11264. /// @param allow_empty A flag signifying whether construction is
  11265. /// allowed to fail without throwing an exception. In this case an
  11266. /// empty object will be produced. This flag is optional and
  11267. /// defaults to false.
  11268. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11269. rnn_direction direction, const memory::desc &src_layer_desc,
  11270. const memory::desc &src_iter_desc,
  11271. const memory::desc &attention_desc,
  11272. const memory::desc &weights_layer_desc,
  11273. const memory::desc &weights_iter_desc,
  11274. const memory::desc &bias_desc,
  11275. const memory::desc &dst_layer_desc,
  11276. const memory::desc &dst_iter_desc,
  11277. const primitive_attr &attr = default_attr(),
  11278. bool allow_empty = false)
  11279. : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
  11280. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  11281. nullptr, &attention_desc, weights_layer_desc,
  11282. weights_iter_desc, nullptr, nullptr, bias_desc,
  11283. dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
  11284. 0.0f, 0.0f, attr, allow_empty) {}
  11285. /// Constructs a primitive descriptor for an LBR AUGRU forward propagation
  11286. /// primitive from a C API primitive descriptor that must have a
  11287. /// matching kind.
  11288. ///
  11289. /// @param pd C API primitive descriptor for an LBR AUGRU forward
  11290. /// propagation primitive.
  11291. primitive_desc(dnnl_primitive_desc_t pd)
  11292. : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
  11293. dnnl::prop_kind::forward_inference,
  11294. dnnl::algorithm::lbr_augru) {}
  11295. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  11296. memory::desc src_layer_desc() const {
  11297. return rnn_base::src_layer_desc();
  11298. }
  11299. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  11300. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  11301. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  11302. memory::desc attention_desc() const {
  11303. return rnn_base::augru_attention_desc();
  11304. }
  11305. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  11306. memory::desc weights_layer_desc() const {
  11307. return rnn_base::weights_layer_desc();
  11308. }
  11309. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  11310. memory::desc weights_iter_desc() const {
  11311. return rnn_base::weights_iter_desc();
  11312. }
  11313. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11314. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11315. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11316. memory::desc dst_layer_desc() const {
  11317. return rnn_base::dst_layer_desc();
  11318. }
  11319. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11320. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11321. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11322. memory::desc workspace_desc() const {
  11323. return rnn_base::workspace_desc();
  11324. }
  11325. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11326. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11327. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11328. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11329. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11330. rnn_direction get_direction() const { return base::get_direction(); }
  11331. };
  11332. /// Default constructor. Produces an empty object.
  11333. lbr_augru_forward() = default;
  11334. /// Constructs an LBR AUGRU forward propagation primitive.
  11335. /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
  11336. /// primitive.
  11337. lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {}
  11338. /// Constructs an LBR AUGRU forward propagation primitive from a cache blob.
  11339. /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
  11340. /// primitive.
  11341. /// @param cache_blob Cache blob.
  11342. lbr_augru_forward(
  11343. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11344. : primitive(pd, cache_blob) {}
  11345. };
  11346. /// LBR AUGRU backward propagation primitive.
  11347. struct lbr_augru_backward : public primitive {
  11348. /// Primitive descriptor for an LBR AUGRU backward propagation primitive.
  11349. struct primitive_desc : public rnn_primitive_desc_base {
  11350. /// Default constructor. Produces an empty object.
  11351. primitive_desc() = default;
  11352. /// Constructs a primitive descriptor for LBR AUGRU backward propagation
  11353. /// primitive.
  11354. ///
  11355. /// The following arguments may point to a zero memory descriptor:
  11356. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  11357. /// - @p bias_desc together with @p diff_bias_desc,
  11358. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  11359. ///
  11360. /// This would then indicate that the LBR AUGRU backward propagation
  11361. /// primitive should not use them and should default to zero values
  11362. /// instead.
  11363. ///
  11364. /// @note
  11365. /// All memory descriptors may be initialized with
  11366. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11367. ///
  11368. /// @param aengine Engine to use.
  11369. /// @param aprop_kind Propagation kind. Must be
  11370. /// #dnnl::prop_kind::backward.
  11371. /// @param direction RNN direction. See @ref dnnl::rnn_direction for
  11372. /// more info.
  11373. /// @param src_layer_desc Memory descriptor for the input vector.
  11374. /// @param src_iter_desc Memory descriptor for the input recurrent
  11375. /// hidden state vector.
  11376. /// @param attention_desc Memory descriptor for the attention vector.
  11377. /// @param weights_layer_desc Memory descriptor for the weights
  11378. /// applied to the layer input.
  11379. /// @param weights_iter_desc Memory descriptor for the weights applied
  11380. /// to the recurrent input.
  11381. /// @param bias_desc Bias memory descriptor.
  11382. /// @param dst_layer_desc Memory descriptor for the output vector.
  11383. /// @param dst_iter_desc Memory descriptor for the output recurrent
  11384. /// hidden state vector.
  11385. /// @param diff_src_layer_desc Memory descriptor for the diff of input
  11386. /// vector.
  11387. /// @param diff_src_iter_desc Memory descriptor for the diff of input
  11388. /// recurrent hidden state vector.
  11389. /// @param diff_attention_desc Memory descriptor for the diff of
  11390. /// attention vector.
  11391. /// @param diff_weights_layer_desc Memory descriptor for the diff of
  11392. /// weights applied to the layer input.
  11393. /// @param diff_weights_iter_desc Memory descriptor for the diff of
  11394. /// weights applied to the recurrent input.
  11395. /// @param diff_bias_desc Diff bias memory descriptor.
  11396. /// @param diff_dst_layer_desc Memory descriptor for the diff of
  11397. /// output vector.
  11398. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  11399. /// recurrent hidden state vector.
  11400. /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU
  11401. /// forward propagation primitive. It is used as a hint for
  11402. /// deciding which memory format to use.
  11403. /// @param attr Primitive attributes to use. Attributes are optional
  11404. /// and default to empty attributes.
  11405. /// @param allow_empty A flag signifying whether construction is
  11406. /// allowed to fail without throwing an exception. In this case an
  11407. /// empty object will be produced. This flag is optional and
  11408. /// defaults to false.
  11409. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11410. rnn_direction direction, const memory::desc &src_layer_desc,
  11411. const memory::desc &src_iter_desc,
  11412. const memory::desc &attention_desc,
  11413. const memory::desc &weights_layer_desc,
  11414. const memory::desc &weights_iter_desc,
  11415. const memory::desc &bias_desc,
  11416. const memory::desc &dst_layer_desc,
  11417. const memory::desc &dst_iter_desc,
  11418. const memory::desc &diff_src_layer_desc,
  11419. const memory::desc &diff_src_iter_desc,
  11420. const memory::desc &diff_attention_desc,
  11421. const memory::desc &diff_weights_layer_desc,
  11422. const memory::desc &diff_weights_iter_desc,
  11423. const memory::desc &diff_bias_desc,
  11424. const memory::desc &diff_dst_layer_desc,
  11425. const memory::desc &diff_dst_iter_desc,
  11426. const lbr_augru_forward::primitive_desc &hint_fwd_pd,
  11427. const primitive_attr &attr = default_attr(),
  11428. bool allow_empty = false)
  11429. : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
  11430. algorithm::undef, direction, src_layer_desc, src_iter_desc,
  11431. nullptr, &attention_desc, weights_layer_desc,
  11432. weights_iter_desc, nullptr, nullptr, bias_desc,
  11433. dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
  11434. diff_src_iter_desc, nullptr, &diff_attention_desc,
  11435. diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
  11436. nullptr, diff_bias_desc, diff_dst_layer_desc,
  11437. diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
  11438. hint_fwd_pd, attr, allow_empty) {}
  11439. /// Constructs a primitive descriptor for an LBR AUGRU backward
  11440. /// propagation primitive from a C API primitive descriptor that must
  11441. /// have a matching kind.
  11442. ///
  11443. /// @param pd C API primitive descriptor for an LBR AUGRU backward
  11444. /// propagation primitive.
  11445. primitive_desc(dnnl_primitive_desc_t pd)
  11446. : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
  11447. dnnl::algorithm::lbr_augru) {}
  11448. /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
  11449. memory::desc src_layer_desc() const {
  11450. return rnn_base::src_layer_desc();
  11451. }
  11452. /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
  11453. memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
  11454. /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
  11455. memory::desc attention_desc() const {
  11456. return rnn_base::augru_attention_desc();
  11457. }
  11458. /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
  11459. memory::desc weights_layer_desc() const {
  11460. return rnn_base::weights_layer_desc();
  11461. }
  11462. /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
  11463. memory::desc weights_iter_desc() const {
  11464. return rnn_base::weights_iter_desc();
  11465. }
  11466. /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
  11467. memory::desc bias_desc() const { return rnn_base::bias_desc(); }
  11468. /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
  11469. memory::desc dst_layer_desc() const {
  11470. return rnn_base::dst_layer_desc();
  11471. }
  11472. /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
  11473. memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
  11474. /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
  11475. memory::desc workspace_desc() const {
  11476. return rnn_base::workspace_desc();
  11477. }
  11478. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
  11479. memory::desc diff_src_layer_desc() const {
  11480. return rnn_base::diff_src_layer_desc();
  11481. }
  11482. /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
  11483. memory::desc diff_src_iter_desc() const {
  11484. return rnn_base::diff_src_iter_desc();
  11485. }
  11486. /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
  11487. memory::desc diff_attention_desc() const {
  11488. return rnn_base::diff_augru_attention_desc();
  11489. }
  11490. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
  11491. memory::desc diff_weights_layer_desc() const {
  11492. return rnn_base::diff_weights_layer_desc();
  11493. }
  11494. /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
  11495. memory::desc diff_weights_iter_desc() const {
  11496. return rnn_base::diff_weights_iter_desc();
  11497. }
  11498. /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
  11499. memory::desc diff_bias_desc() const {
  11500. return rnn_base::diff_bias_desc();
  11501. }
  11502. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
  11503. memory::desc diff_dst_layer_desc() const {
  11504. return rnn_base::diff_dst_layer_desc();
  11505. }
  11506. /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
  11507. memory::desc diff_dst_iter_desc() const {
  11508. return rnn_base::diff_dst_iter_desc();
  11509. }
  11510. /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
  11511. algorithm get_cell_kind() const { return base::get_cell_kind(); }
  11512. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11513. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11514. /// @copydoc dnnl::primitive_desc_base::get_direction()const
  11515. rnn_direction get_direction() const { return base::get_direction(); }
  11516. };
  11517. /// Default constructor. Produces an empty object.
  11518. lbr_augru_backward() = default;
  11519. /// Constructs an LBR AUGRU backward propagation primitive.
  11520. /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
  11521. /// primitive.
  11522. lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {}
  11523. /// Constructs an LBR AUGRU backward propagation primitive from a cache blob.
  11524. /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
  11525. /// primitive.
  11526. /// @param cache_blob Cache blob.
  11527. lbr_augru_backward(
  11528. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11529. : primitive(pd, cache_blob) {}
  11530. };
  11531. /// @} dnnl_api_rnn
  11532. /// @addtogroup dnnl_api_shuffle Shuffle
  11533. ///
  11534. /// A primitive to shuffle tensor data along an axis.
  11535. ///
  11536. /// @sa @ref dev_guide_shuffle in developer guide
  11537. ///
  11538. /// @{
  11539. /// Shuffle forward propagation primitive.
  11540. struct shuffle_forward : public primitive {
  11541. /// Primitive descriptor for a shuffle forward propagation primitive.
  11542. struct primitive_desc : public dnnl::primitive_desc {
  11543. /// Default constructor. Produces an empty object.
  11544. primitive_desc() = default;
  11545. /// Constructs a primitive descriptor for a shuffle forward propagation
  11546. /// primitive.
  11547. ///
  11548. /// @param aengine Engine to use.
  11549. /// @param aprop_kind Propagation kind. Possible values are
  11550. /// #dnnl::prop_kind::forward_training, and
  11551. /// #dnnl::prop_kind::forward_inference.
  11552. /// @param src_desc Source memory descriptor.
  11553. /// @param dst_desc Destination memory descriptor.
  11554. /// @param axis The axis along which the data is shuffled.
  11555. /// @param group_size Shuffle group size.
  11556. /// @param attr Primitive attributes to use. Attributes are optional
  11557. /// and default to empty attributes.
  11558. /// @param allow_empty A flag signifying whether construction is
  11559. /// allowed to fail without throwing an exception. In this case an
  11560. /// empty object will be produced. This flag is optional and
  11561. /// defaults to false.
  11562. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11563. const memory::desc &src_desc, const memory::desc &dst_desc,
  11564. int axis, int group_size,
  11565. const primitive_attr &attr = default_attr(),
  11566. bool allow_empty = false) {
  11567. dnnl_primitive_desc_t pd = nullptr;
  11568. dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create(
  11569. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  11570. src_desc.get(), dst_desc.get(), axis, group_size,
  11571. attr.get());
  11572. if (!allow_empty)
  11573. error::wrap_c_api(status,
  11574. "could not create a primitive descriptor for "
  11575. "the shuffle forward propagation primitive. Run "
  11576. "workload with environment variable ONEDNN_VERBOSE=all "
  11577. "to get additional diagnostic information.");
  11578. reset(pd);
  11579. }
  11580. /// Constructs a primitive descriptor for a shuffle forward propagation
  11581. /// primitive from a C API primitive descriptor that must have a
  11582. /// matching kind.
  11583. ///
  11584. /// @param pd C API primitive descriptor for a shuffle forward
  11585. /// propagation primitive.
  11586. primitive_desc(dnnl_primitive_desc_t pd)
  11587. : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
  11588. dnnl::prop_kind::forward_training,
  11589. dnnl::prop_kind::forward_inference) {}
  11590. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  11591. memory::desc src_desc() const { return base::src_desc(0); }
  11592. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11593. memory::desc dst_desc() const { return base::dst_desc(0); }
  11594. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11595. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11596. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  11597. int get_axis() const { return base::get_axis(); }
  11598. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  11599. memory::dim get_group_size() const { return base::get_group_size(); }
  11600. };
  11601. /// Default constructor. Produces an empty object.
  11602. shuffle_forward() = default;
  11603. /// Constructs a shuffle forward propagation primitive.
  11604. /// @param pd Primitive descriptor for a shuffle forward propagation
  11605. /// primitive.
  11606. shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
  11607. /// Constructs a shuffle forward propagation primitive from a cache blob.
  11608. /// @param pd Primitive descriptor for a shuffle forward propagation
  11609. /// primitive.
  11610. /// @param cache_blob Cache blob.
  11611. shuffle_forward(
  11612. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11613. : primitive(pd, cache_blob) {}
  11614. };
  11615. /// Shuffle backward propagation primitive.
  11616. struct shuffle_backward : public primitive {
  11617. /// Primitive descriptor for a shuffle backward propagation primitive.
  11618. struct primitive_desc : public dnnl::primitive_desc {
  11619. /// Default constructor. Produces an empty object.
  11620. primitive_desc() = default;
  11621. /// Constructs a primitive descriptor for a shuffle backward propagation
  11622. /// primitive.
  11623. ///
  11624. /// @param aengine Engine to use.
  11625. /// @param diff_src_desc Diff source memory descriptor.
  11626. /// @param diff_dst_desc Diff destination memory descriptor.
  11627. /// @param axis The axis along which the data is shuffled.
  11628. /// @param group_size Shuffle group size.
  11629. /// @param hint_fwd_pd Primitive descriptor for a shuffle forward
  11630. /// propagation primitive. It is used as a hint for deciding which
  11631. /// memory format to use.
  11632. /// @param attr Primitive attributes to use. Attributes are optional
  11633. /// and default to empty attributes.
  11634. /// @param allow_empty A flag signifying whether construction is
  11635. /// allowed to fail without throwing an exception. In this case an
  11636. /// empty object will be produced. This flag is optional and
  11637. /// defaults to false.
  11638. primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
  11639. const memory::desc &diff_dst_desc, int axis, int group_size,
  11640. const shuffle_forward::primitive_desc &hint_fwd_pd,
  11641. const primitive_attr &attr = default_attr(),
  11642. bool allow_empty = false) {
  11643. dnnl_primitive_desc_t pd = nullptr;
  11644. dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create(
  11645. &pd, aengine.get(), diff_src_desc.get(),
  11646. diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(),
  11647. attr.get());
  11648. if (!allow_empty)
  11649. error::wrap_c_api(status,
  11650. "could not create a primitive descriptor for "
  11651. "the shuffle backward propagation primitive. Run "
  11652. "workload with environment variable ONEDNN_VERBOSE=all "
  11653. "to get additional diagnostic information.");
  11654. reset(pd);
  11655. }
  11656. /// Constructs a primitive descriptor for a shuffle backward
  11657. /// propagation primitive from a C API primitive descriptor that must
  11658. /// have a matching kind.
  11659. ///
  11660. /// @param pd C API primitive descriptor for a shuffle backward
  11661. /// propagation primitive.
  11662. primitive_desc(dnnl_primitive_desc_t pd)
  11663. : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
  11664. dnnl::prop_kind::backward_data) {}
  11665. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  11666. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  11667. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  11668. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  11669. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  11670. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  11671. /// @copydoc dnnl::primitive_desc_base::get_axis()const
  11672. int get_axis() const { return base::get_axis(); }
  11673. /// @copydoc dnnl::primitive_desc_base::get_group_size()const
  11674. memory::dim get_group_size() const { return base::get_group_size(); }
  11675. };
  11676. /// Default constructor. Produces an empty object.
  11677. shuffle_backward() = default;
  11678. /// Constructs a shuffle backward propagation primitive.
  11679. /// @param pd Primitive descriptor for a shuffle backward propagation
  11680. /// primitive.
  11681. shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
  11682. /// Constructs a shuffle backward propagation primitive from a cache blob.
  11683. /// @param pd Primitive descriptor for a shuffle backward propagation
  11684. /// primitive.
  11685. /// @param cache_blob Cache blob.
  11686. shuffle_backward(
  11687. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11688. : primitive(pd, cache_blob) {}
  11689. };
  11690. /// @} dnnl_api_shuffle
  11691. /// @addtogroup dnnl_api_binary Binary
  11692. ///
/// A primitive to perform elementwise operations over two source tensors
/// (or three, for ternary algorithms).
  11694. ///
  11695. /// @sa @ref dev_guide_binary in developer guide
  11696. ///
  11697. /// @{
  11698. /// Elementwise binary operator primitive.
  11699. struct binary : public primitive {
  11700. /// Primitive descriptor for an elementwise binary operator primitive.
  11701. struct primitive_desc : public dnnl::primitive_desc {
  11702. /// Default constructor. Produces an empty object.
  11703. primitive_desc() = default;
  11704. /// Constructs a primitive descriptor for an elementwise binary operator
  11705. /// primitive.
  11706. ///
  11707. /// @param aengine Engine to use.
  11708. /// @param aalgorithm Elementwise binary algorithm.
  11709. /// @param src0 Memory descriptor for source tensor #0.
  11710. /// @param src1 Memory descriptor for source tensor #1.
  11711. /// @param dst Memory descriptor for destination tensor.
  11712. /// @param attr Primitive attributes to use. Attributes are optional
  11713. /// and default to empty attributes.
  11714. /// @param allow_empty A flag signifying whether construction is
  11715. /// allowed to fail without throwing an exception. In this case an
  11716. /// empty object will be produced. This flag is optional and
  11717. /// defaults to false.
  11718. primitive_desc(const engine &aengine, algorithm aalgorithm,
  11719. const memory::desc &src0, const memory::desc &src1,
  11720. const memory::desc &dst,
  11721. const primitive_attr &attr = default_attr(),
  11722. bool allow_empty = false) {
  11723. dnnl_primitive_desc_t pd = nullptr;
  11724. dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd,
  11725. aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
  11726. src1.get(), dst.get(), attr.get());
  11727. if (!allow_empty)
  11728. error::wrap_c_api(status,
  11729. "could not create a primitive descriptor for "
  11730. "the binary operation primitive. Run workload with "
  11731. "environment variable ONEDNN_VERBOSE=all to get "
  11732. "additional diagnostic information.");
  11733. reset(pd);
  11734. }
  11735. /// Constructs a primitive descriptor for an elementwise binary operator
  11736. /// primitive with support of ternary operators.
  11737. ///
  11738. /// @param aengine Engine to use.
  11739. /// @param aalgorithm Elementwise binary algorithm.
  11740. /// @param src0 Memory descriptor for source tensor #0.
  11741. /// @param src1 Memory descriptor for source tensor #1.
  11742. /// @param src2 Memory descriptor for source tensor #2 for ternary
  11743. /// operations. Might be empty.
  11744. /// @param dst Memory descriptor for destination tensor.
  11745. /// @param attr Primitive attributes to use. Attributes are optional
  11746. /// and default to empty attributes.
  11747. /// @param allow_empty A flag signifying whether construction is
  11748. /// allowed to fail without throwing an exception. In this case an
  11749. /// empty object will be produced. This flag is optional and
  11750. /// defaults to false.
  11751. primitive_desc(const engine &aengine, algorithm aalgorithm,
  11752. const memory::desc &src0, const memory::desc &src1,
  11753. const memory::desc &src2, const memory::desc &dst,
  11754. const primitive_attr &attr = default_attr(),
  11755. bool allow_empty = false) {
  11756. dnnl_primitive_desc_t pd = nullptr;
  11757. dnnl_status_t status = dnnl_binary_primitive_desc_create_v2(&pd,
  11758. aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
  11759. src1.get(), src2.get(), dst.get(), attr.get());
  11760. if (!allow_empty)
  11761. error::wrap_c_api(status,
  11762. "could not create a primitive descriptor for "
  11763. "the binary v2 operation primitive. Run workload with "
  11764. "environment variable ONEDNN_VERBOSE=all to get "
  11765. "additional diagnostic information.");
  11766. reset(pd);
  11767. }
  11768. /// Constructs a primitive descriptor for a binary primitive from a C
  11769. /// API primitive descriptor that must have a matching kind.
  11770. ///
  11771. /// @param pd C API primitive descriptor for a binary primitive.
  11772. primitive_desc(dnnl_primitive_desc_t pd)
  11773. : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}
  11774. /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
  11775. memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
  11776. /// Returns the memory descriptor for source #0.
  11777. memory::desc src0_desc() const { return base::src_desc(0); }
  11778. /// Returns the memory descriptor for source #1.
  11779. memory::desc src1_desc() const { return base::src_desc(1); }
  11780. /// Returns the memory descriptor for source #2.
  11781. memory::desc src2_desc() const { return base::src_desc(2); }
  11782. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11783. memory::desc dst_desc() const { return base::dst_desc(0); }
  11784. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  11785. algorithm get_algorithm() const { return base::get_algorithm(); }
  11786. };
  11787. /// Default constructor. Produces an empty object.
  11788. binary() = default;
  11789. /// Constructs an elementwise binary operation primitive.
  11790. /// @param pd Primitive descriptor for an elementwise binary operation
  11791. /// primitive.
  11792. binary(const primitive_desc &pd) : primitive(pd) {}
  11793. /// Constructs an elementwise binary operation primitive from a cache blob.
  11794. /// @param pd Primitive descriptor for an elementwise binary operation
  11795. /// primitive.
  11796. /// @param cache_blob Cache blob.
  11797. binary(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11798. : primitive(pd, cache_blob) {}
  11799. };
  11800. /// @} dnnl_api_binary
  11801. /// @addtogroup dnnl_api_matmul Matrix Multiplication
  11802. ///
  11803. /// A primitive to perform matrix-matrix multiplication. The batched mode
  11804. /// is supported with 3D tensors.
  11805. ///
  11806. /// @sa @ref dev_guide_matmul in developer guide
  11807. ///
  11808. ///
  11809. /// @{
  11810. /// Matrix multiplication (matmul) primitive.
  11811. struct matmul : public primitive {
  11812. /// Primitive descriptor for a matmul primitive.
  11813. struct primitive_desc : public dnnl::primitive_desc {
  11814. /// Default constructor. Produces an empty object.
  11815. primitive_desc() = default;
  11816. /// Constructs a primitive descriptor for a matmul primitive
  11817. /// without bias.
  11818. ///
  11819. /// @param aengine Engine to use.
  11820. /// @param src_desc Memory descriptor for source (matrix A).
  11821. /// @param weights_desc Memory descriptor for weights (matrix B).
  11822. /// @param dst_desc Memory descriptor for destination (matrix C).
  11823. /// @param attr Primitive attributes to use. Attributes are optional
  11824. /// and default to empty attributes.
  11825. /// @param allow_empty A flag signifying whether construction is
  11826. /// allowed to fail without throwing an exception. In this case an
  11827. /// empty object will be produced. This flag is optional and
  11828. /// defaults to false.
  11829. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  11830. const memory::desc &weights_desc, const memory::desc &dst_desc,
  11831. const primitive_attr &attr = default_attr(),
  11832. bool allow_empty = false)
  11833. : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc,
  11834. attr, allow_empty) {}
  11835. /// Constructs a primitive descriptor for a matmul primitive with bias.
  11836. ///
  11837. /// @param aengine Engine to use.
  11838. /// @param src_desc Memory descriptor for source (matrix A).
  11839. /// @param weights_desc Memory descriptor for weights (matrix B).
  11840. /// @param dst_desc Memory descriptor for destination (matrix C).
  11841. /// @param bias_desc Memory descriptor for bias.
  11842. /// @param attr Primitive attributes to use. Attributes are optional
  11843. /// and default to empty attributes.
  11844. /// @param allow_empty A flag signifying whether construction is
  11845. /// allowed to fail without throwing an exception. In this case an
  11846. /// empty object will be produced. This flag is optional and
  11847. /// defaults to false.
  11848. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  11849. const memory::desc &weights_desc, const memory::desc &bias_desc,
  11850. const memory::desc &dst_desc,
  11851. const primitive_attr &attr = default_attr(),
  11852. bool allow_empty = false)
  11853. : primitive_desc(aengine, src_desc, weights_desc, &bias_desc,
  11854. dst_desc, attr, allow_empty) {}
  11855. /// Constructs a primitive descriptor for a matmul primitive from a C
  11856. /// API primitive descriptor that must have a matching kind.
  11857. ///
  11858. /// @param pd C API primitive descriptor for a matmul primitive.
  11859. primitive_desc(dnnl_primitive_desc_t pd)
  11860. : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}
  11861. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  11862. memory::desc src_desc() const { return query_md(query::src_md, 0); }
  11863. /// @copydoc dnnl::primitive_desc_base::weights_desc()const
  11864. memory::desc weights_desc() const {
  11865. return query_md(query::weights_md, 0);
  11866. }
  11867. /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
  11868. memory::desc bias_desc() const {
  11869. return query_md(query::weights_md, 1);
  11870. }
  11871. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  11872. memory::desc dst_desc() const { return query_md(query::dst_md, 0); }
  11873. private:
  11874. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  11875. const memory::desc &weights_desc, const memory::desc *bias_desc,
  11876. const memory::desc &dst_desc, const primitive_attr &attr,
  11877. bool allow_empty) {
  11878. dnnl_primitive_desc_t pd = nullptr;
  11879. dnnl_status_t status = dnnl_matmul_primitive_desc_create(&pd,
  11880. aengine.get(), src_desc.get(), weights_desc.get(),
  11881. optional_arg(bias_desc), dst_desc.get(), attr.get());
  11882. if (!allow_empty)
  11883. error::wrap_c_api(status,
  11884. "could not create a primitive descriptor for "
  11885. "the matmul primitive. Run workload with "
  11886. "environment variable ONEDNN_VERBOSE=all to get "
  11887. "additional diagnostic information.");
  11888. reset(pd);
  11889. }
  11890. };
  11891. /// Default constructor. Produces an empty object.
  11892. matmul() = default;
  11893. /// Constructs a matmul primitive.
  11894. /// @param pd Primitive descriptor for a matmul primitive.
  11895. matmul(const primitive_desc &pd) : primitive(pd) {}
  11896. /// Constructs a matmul primitive from a cache blob.
  11897. /// @param pd Primitive descriptor for a matmul primitive.
  11898. /// @param cache_blob Cache blob.
  11899. matmul(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  11900. : primitive(pd, cache_blob) {}
  11901. };
  11902. /// @} dnnl_api_matmul
  11903. /// @addtogroup dnnl_api_resampling Resampling
  11904. ///
/// A primitive to compute a resampling operation on a 1D, 2D, or 3D data
/// tensor using either Nearest Neighbor or Linear (Bilinear, Trilinear)
/// interpolation.
  11908. ///
  11909. /// @sa @ref dev_guide_resampling in developer guide
  11910. ///
  11911. /// @{
  11912. /// Resampling forward propagation.
  11913. struct resampling_forward : public primitive {
  11914. /// Primitive descriptor for a resampling forward propagation primitive.
  11915. struct primitive_desc : public dnnl::primitive_desc {
  11916. /// Default constructor. Produces an empty object.
  11917. primitive_desc() = default;
  11918. /// Constructs a primitive descriptor for a resampling forward
  11919. /// propagation primitive using source and destination memory
  11920. /// descriptors.
  11921. ///
  11922. /// @note
  11923. /// Destination memory descriptor may be initialized with
  11924. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11925. ///
  11926. /// @param aengine Engine to use.
  11927. /// @param aprop_kind Propagation kind. Possible values are
  11928. /// #dnnl::prop_kind::forward_training, and
  11929. /// #dnnl::prop_kind::forward_inference.
  11930. /// @param aalgorithm resampling algorithm kind: either
  11931. /// #dnnl::algorithm::resampling_nearest, or
  11932. /// #dnnl::algorithm::resampling_linear
  11933. /// @param src_desc Source memory descriptor.
  11934. /// @param dst_desc Destination memory descriptor.
  11935. /// @param attr Primitive attributes to use. Attributes are optional
  11936. /// and default to empty attributes.
  11937. /// @param allow_empty A flag signifying whether construction is
  11938. /// allowed to fail without throwing an exception. In this case an
  11939. /// empty object will be produced. This flag is optional and
  11940. /// defaults to false.
  11941. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11942. algorithm aalgorithm, const memory::desc &src_desc,
  11943. const memory::desc &dst_desc,
  11944. const primitive_attr &attr = default_attr(),
  11945. bool allow_empty = false)
  11946. : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc,
  11947. &dst_desc, attr, allow_empty) {}
  11948. /// Constructs a primitive descriptor for a resampling forward
  11949. /// propagation primitive using source memory descriptor and
  11950. /// factors.
  11951. ///
  11952. /// @param aengine Engine to use.
  11953. /// @param aprop_kind Propagation kind. Possible values are
  11954. /// #dnnl::prop_kind::forward_training, and
  11955. /// #dnnl::prop_kind::forward_inference.
  11956. /// @param aalgorithm resampling algorithm kind: either
  11957. /// #dnnl::algorithm::resampling_nearest, or
  11958. /// #dnnl::algorithm::resampling_linear
  11959. /// @param factors Vector of scaling factors for spatial dimension.
  11960. /// @param src_desc Source memory descriptor.
  11961. /// @param attr Primitive attributes to use. Attributes are optional
  11962. /// and default to empty attributes.
  11963. /// @param allow_empty A flag signifying whether construction is
  11964. /// allowed to fail without throwing an exception. In this case an
  11965. /// empty object will be produced. This flag is optional and
  11966. /// defaults to false.
  11967. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11968. algorithm aalgorithm, const std::vector<float> &factors,
  11969. const memory::desc &src_desc,
  11970. const primitive_attr &attr = default_attr(),
  11971. bool allow_empty = false)
  11972. : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
  11973. src_desc, nullptr, attr, allow_empty) {}
  11974. /// Constructs a primitive descriptor for a resampling forward
  11975. /// propagation primitive.
  11976. ///
  11977. /// @note
  11978. /// The destination memory descriptor may be initialized with
  11979. /// #dnnl::memory::format_tag::any value of @p format_tag.
  11980. ///
  11981. /// @param aengine Engine to use.
  11982. /// @param aprop_kind Propagation kind. Possible values are
  11983. /// #dnnl::prop_kind::forward_training, and
  11984. /// #dnnl::prop_kind::forward_inference.
  11985. /// @param aalgorithm resampling algorithm kind: either
  11986. /// #dnnl::algorithm::resampling_nearest, or
  11987. /// #dnnl::algorithm::resampling_linear
  11988. /// @param factors Vector of scaling factors for spatial dimension.
  11989. /// @param src_desc Source memory descriptor.
  11990. /// @param dst_desc Destination memory descriptor.
  11991. /// @param attr Primitive attributes to use. Attributes are optional
  11992. /// and default to empty attributes.
  11993. /// @param allow_empty A flag signifying whether construction is
  11994. /// allowed to fail without throwing an exception. In this case an
  11995. /// empty object will be produced. This flag is optional and
  11996. /// defaults to false.
  11997. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  11998. algorithm aalgorithm, const std::vector<float> &factors,
  11999. const memory::desc &src_desc, const memory::desc &dst_desc,
  12000. const primitive_attr &attr = default_attr(),
  12001. bool allow_empty = false)
  12002. : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
  12003. src_desc, &dst_desc, attr, allow_empty) {}
  12004. /// Constructs a primitive descriptor for a resampling forward
  12005. /// propagation primitive from a C API primitive descriptor that must
  12006. /// have a matching kind.
  12007. ///
  12008. /// @param pd C API primitive descriptor for a resampling forward
  12009. /// propagation primitive.
  12010. primitive_desc(dnnl_primitive_desc_t pd)
  12011. : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
  12012. dnnl::prop_kind::forward_training,
  12013. dnnl::prop_kind::forward_inference) {}
  12014. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12015. memory::desc src_desc() const { return base::src_desc(0); }
  12016. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12017. memory::desc dst_desc() const { return base::dst_desc(0); }
  12018. private:
  12019. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  12020. algorithm aalgorithm, const std::vector<float> *factors,
  12021. const memory::desc &src_desc, const memory::desc *dst_desc,
  12022. const primitive_attr &attr, bool allow_empty) {
  12023. if (factors)
  12024. memory::validate_dims(*factors, src_desc.get_ndims() - 2);
  12025. dnnl_primitive_desc_t pd = nullptr;
  12026. dnnl_status_t status
  12027. = dnnl_resampling_forward_primitive_desc_create(&pd,
  12028. aengine.get(), dnnl::convert_to_c(aprop_kind),
  12029. convert_to_c(aalgorithm), optional_arg(factors),
  12030. src_desc.get(), optional_arg(dst_desc), attr.get());
  12031. if (!allow_empty)
  12032. error::wrap_c_api(status,
  12033. "could not create a primitive descriptor for "
  12034. "the resampling forward propagation primitive. Run "
  12035. "workload with environment variable ONEDNN_VERBOSE=all "
  12036. "to get additional diagnostic information.");
  12037. reset(pd);
  12038. }
  12039. };
  12040. /// Default constructor. Produces an empty object.
  12041. resampling_forward() = default;
  12042. /// Constructs a resampling forward propagation primitive.
  12043. /// @param pd Primitive descriptor for a resampling forward propagation
  12044. /// primitive.
  12045. resampling_forward(const primitive_desc &pd) : primitive(pd) {}
  12046. /// Constructs a resampling forward propagation primitive from a cache
  12047. /// blob.
  12048. /// @param pd Primitive descriptor for a resampling forward propagation
  12049. /// primitive.
  12050. /// @param cache_blob Cache blob.
  12051. resampling_forward(
  12052. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12053. : primitive(pd, cache_blob) {}
  12054. };
  12055. /// Resampling backward propagation primitive.
  12056. struct resampling_backward : public primitive {
  12057. /// Primitive descriptor for resampling backward propagation primitive.
  12058. struct primitive_desc : public dnnl::primitive_desc {
  12059. /// Default constructor. Produces an empty object.
  12060. primitive_desc() = default;
  12061. /// Constructs a primitive descriptor for a resampling backward
  12062. /// propagation primitive using source and destination memory
  12063. /// descriptors.
  12064. ///
  12065. /// @param aengine Engine to use.
  12066. /// @param aalgorithm resampling algorithm kind: either
  12067. /// #dnnl::algorithm::resampling_nearest, or
  12068. /// #dnnl::algorithm::resampling_linear
  12069. /// @param diff_src_desc Diff source memory descriptor.
  12070. /// @param diff_dst_desc Diff destination memory descriptor.
  12071. /// @param hint_fwd_pd Primitive descriptor for a resampling
  12072. /// forward propagation primitive. It is used as a hint for
  12073. /// deciding which memory format to use.
  12074. /// @param attr Primitive attributes to use. Attributes are optional
  12075. /// and default to empty attributes.
  12076. /// @param allow_empty A flag signifying whether construction is
  12077. /// allowed to fail without throwing an exception. In this case an
  12078. /// empty object will be produced. This flag is optional and
  12079. /// defaults to false.
  12080. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12081. const memory::desc &diff_src_desc,
  12082. const memory::desc &diff_dst_desc,
  12083. const resampling_forward::primitive_desc &hint_fwd_pd,
  12084. const primitive_attr &attr = default_attr(),
  12085. bool allow_empty = false)
  12086. : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc,
  12087. diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
  12088. /// Constructs a primitive descriptor for resampling backward
  12089. /// propagation primitive.
  12090. ///
  12091. /// @param aengine Engine to use.
  12092. /// @param aalgorithm resampling algorithm kind: either
  12093. /// #dnnl::algorithm::resampling_nearest, or
  12094. /// #dnnl::algorithm::resampling_linear
  12095. /// @param factors Vector of scaling factors for spatial dimension.
  12096. /// @param diff_src_desc Diff source memory descriptor.
  12097. /// @param diff_dst_desc Diff destination memory descriptor.
  12098. /// @param hint_fwd_pd Primitive descriptor for a resampling
  12099. /// forward propagation primitive. It is used as a hint for
  12100. /// deciding which memory format to use.
  12101. /// @param attr Primitive attributes to use. Attributes are optional
  12102. /// and default to empty attributes.
  12103. /// @param allow_empty A flag signifying whether construction is
  12104. /// allowed to fail without throwing an exception. In this case an
  12105. /// empty object will be produced. This flag is optional and
  12106. /// defaults to false.
  12107. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12108. const std::vector<float> &factors,
  12109. const memory::desc &diff_src_desc,
  12110. const memory::desc &diff_dst_desc,
  12111. const resampling_forward::primitive_desc &hint_fwd_pd,
  12112. const primitive_attr &attr = default_attr(),
  12113. bool allow_empty = false)
  12114. : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc,
  12115. diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
  12116. /// Constructs a primitive descriptor for a resampling backward
  12117. /// propagation primitive from a C API primitive descriptor that must
  12118. /// have a matching kind.
  12119. ///
  12120. /// @param pd C API primitive descriptor for a resampling backward
  12121. /// propagation primitive.
  12122. primitive_desc(dnnl_primitive_desc_t pd)
  12123. : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
  12124. dnnl::prop_kind::backward_data) {}
  12125. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  12126. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  12127. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  12128. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  12129. private:
  12130. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12131. const std::vector<float> *factors,
  12132. const memory::desc &diff_src_desc,
  12133. const memory::desc &diff_dst_desc,
  12134. const resampling_forward::primitive_desc &hint_fwd_pd,
  12135. const primitive_attr &attr, bool allow_empty) {
  12136. if (factors)
  12137. memory::validate_dims(*factors, diff_src_desc.get_ndims() - 2);
  12138. dnnl_primitive_desc_t pd = nullptr;
  12139. dnnl_status_t status
  12140. = dnnl_resampling_backward_primitive_desc_create(&pd,
  12141. aengine.get(), convert_to_c(aalgorithm),
  12142. optional_arg(factors), diff_src_desc.get(),
  12143. diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
  12144. if (!allow_empty)
  12145. error::wrap_c_api(status,
  12146. "could not create a primitive descriptor for "
  12147. "the resampling backward propagation primitive. Run "
  12148. "workload with environment variable ONEDNN_VERBOSE=all "
  12149. "to get additional diagnostic information.");
  12150. reset(pd);
  12151. }
  12152. };
  12153. /// Default constructor. Produces an empty object.
  12154. resampling_backward() = default;
  12155. /// Constructs a resampling backward propagation primitive.
  12156. /// @param pd Primitive descriptor for a resampling backward propagation
  12157. /// primitive.
  12158. resampling_backward(const primitive_desc &pd) : primitive(pd) {}
  12159. /// Constructs a resampling backward propagation primitive from a cache
  12160. /// blob.
  12161. /// @param pd Primitive descriptor for a resampling backward propagation
  12162. /// primitive.
  12163. /// @param cache_blob Cache blob.
  12164. resampling_backward(
  12165. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12166. : primitive(pd, cache_blob) {}
  12167. };
  12168. /// @} dnnl_api_resampling
  12169. /// @addtogroup dnnl_api_pooling Pooling
  12170. ///
  12171. /// A primitive to perform max or average pooling with dilation.
  12172. ///
  12173. /// @sa @ref dev_guide_pooling in developer guide
  12174. ///
  12175. /// @{
  12176. /// Pooling forward propagation primitive.
  12177. struct pooling_forward : public primitive {
  12178. /// Primitive descriptor for a pooling forward propagation primitive.
  12179. struct primitive_desc : public dnnl::primitive_desc {
  12180. /// Default constructor. Produces an empty object.
  12181. primitive_desc() = default;
  12182. /// Constructs a primitive descriptor for pooling forward propagation
  12183. /// primitive.
  12184. ///
  12185. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
  12186. /// and @p padding_r contain values for spatial dimensions only and
  12187. /// hence must have the same number of elements as there are spatial
  12188. /// dimensions. The order of values is the same as in the tensor:
  12189. /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
  12190. ///
  12191. /// @param aengine Engine to use.
  12192. /// @param aprop_kind Propagation kind. Possible values are
  12193. /// #dnnl::prop_kind::forward_training, and
  12194. /// #dnnl::prop_kind::forward_inference.
  12195. /// @param aalgorithm Pooling algorithm kind: either
  12196. /// #dnnl::algorithm::pooling_max,
  12197. /// #dnnl::algorithm::pooling_avg_include_padding,
  12198. /// or #dnnl::algorithm::pooling_avg_exclude_padding.
  12199. /// @param src_desc Source memory descriptor.
  12200. /// @param dst_desc Destination memory descriptor.
  12201. /// @param strides Vector of strides for spatial dimension.
  12202. /// @param kernel Vector of kernel spatial dimensions.
  12203. /// @param dilation Array of dilations for spatial dimension.
  12204. /// @param padding_l Vector of padding values for low indices for each
  12205. /// spatial dimension `([[front,] top,] left)`.
  12206. /// @param padding_r Vector of padding values for high indices for
  12207. /// each spatial dimension `([[back,] bottom,] right)`.
  12208. /// @param attr Primitive attributes to use. Attributes are optional
  12209. /// and default to empty attributes.
  12210. /// @param allow_empty A flag signifying whether construction is
  12211. /// allowed to fail without throwing an exception. In this case an
  12212. /// empty object will be produced. This flag is optional and
  12213. /// defaults to false.
  12214. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  12215. algorithm aalgorithm, const memory::desc &src_desc,
  12216. const memory::desc &dst_desc, const memory::dims &strides,
  12217. const memory::dims &kernel, const memory::dims &dilation,
  12218. const memory::dims &padding_l, const memory::dims &padding_r,
  12219. const primitive_attr &attr = default_attr(),
  12220. bool allow_empty = false) {
  12221. memory::validate_dims(strides, src_desc.get_ndims() - 2);
  12222. memory::validate_dims(kernel, src_desc.get_ndims() - 2);
  12223. memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
  12224. memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
  12225. memory::validate_dims(dilation, src_desc.get_ndims() - 2);
  12226. dnnl_primitive_desc_t pd = nullptr;
  12227. dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create(
  12228. &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
  12229. convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
  12230. &strides[0], &kernel[0], &dilation[0], &padding_l[0],
  12231. &padding_r[0], attr.get());
  12232. if (!allow_empty)
  12233. error::wrap_c_api(status,
  12234. "could not create a descriptor for a pooling forward "
  12235. "propagation primitive");
  12236. reset(pd);
  12237. }
  12238. /// Constructs a primitive descriptor for a pooling forward propagation
  12239. /// primitive from a C API primitive descriptor that must have a
  12240. /// matching kind.
  12241. ///
  12242. /// @param pd C API primitive descriptor for a pooling forward
  12243. /// propagation primitive.
  12244. primitive_desc(dnnl_primitive_desc_t pd)
  12245. : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
  12246. dnnl::prop_kind::forward_training,
  12247. dnnl::prop_kind::forward_inference) {}
  12248. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12249. memory::desc src_desc() const { return base::src_desc(0); }
  12250. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12251. memory::desc dst_desc() const { return base::dst_desc(0); }
  12252. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  12253. memory::desc workspace_desc() const { return base::workspace_desc(); }
  12254. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  12255. algorithm get_algorithm() const { return base::get_algorithm(); }
  12256. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12257. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12258. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  12259. memory::dims get_strides() const { return base::get_strides(); }
  12260. /// @copydoc dnnl::primitive_desc_base::get_kernel()const
  12261. memory::dims get_kernel() const { return base::get_kernel(); }
  12262. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  12263. memory::dims get_dilations() const { return base::get_dilations(); }
  12264. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  12265. memory::dims get_padding_l() const { return base::get_padding_l(); }
  12266. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  12267. memory::dims get_padding_r() const { return base::get_padding_r(); }
  12268. };
  12269. /// Default constructor. Produces an empty object.
  12270. pooling_forward() = default;
  12271. /// Constructs a pooling forward propagation primitive.
  12272. ///
  12273. /// @param pd Primitive descriptor for a pooling forward propagation
  12274. /// primitive.
  12275. pooling_forward(const primitive_desc &pd) : primitive(pd) {}
  12276. /// Constructs a pooling forward propagation primitive from a cache blob.
  12277. ///
  12278. /// @param pd Primitive descriptor for a pooling forward propagation
  12279. /// primitive.
  12280. /// @param cache_blob Cache blob.
  12281. pooling_forward(
  12282. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12283. : primitive(pd, cache_blob) {}
  12284. };
  12285. /// Pooling backward propagation primitive.
  12286. struct pooling_backward : public primitive {
  12287. /// Primitive descriptor for a pooling backward propagation primitive.
  12288. struct primitive_desc : public dnnl::primitive_desc {
  12289. /// Default constructor. Produces an empty object.
  12290. primitive_desc() = default;
  12291. /// Constructs a primitive descriptor for a pooling backward propagation
  12292. /// primitive.
  12293. ///
  12294. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
  12295. /// and @p padding_r contain values for spatial dimensions only and
  12296. /// hence must have the same number of elements as there are spatial
  12297. /// dimensions. The order of values is the same as in the tensor:
  12298. /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
  12299. ///
  12300. /// @param aengine Engine to use.
  12301. /// @param aalgorithm Pooling algorithm kind: either
  12302. /// #dnnl::algorithm::pooling_max,
  12303. /// #dnnl::algorithm::pooling_avg_include_padding,
  12304. /// or #dnnl::algorithm::pooling_avg_exclude_padding.
  12305. /// @param diff_src_desc Diff source memory descriptor.
  12306. /// @param diff_dst_desc Diff destination memory descriptor.
  12307. /// @param strides Vector of strides for spatial dimension.
  12308. /// @param kernel Vector of kernel spatial dimensions.
  12309. /// @param dilation Array of dilations for spatial dimension.
  12310. /// @param padding_l Vector of padding values for low indices for each
  12311. /// spatial dimension `([[front,] top,] left)`.
  12312. /// @param padding_r Vector of padding values for high indices for
  12313. /// each spatial dimension `([[back,] bottom,] right)`.
  12314. /// @param hint_fwd_pd Primitive descriptor for a pooling
  12315. /// forward propagation primitive. It is used as a hint for
  12316. /// deciding which memory format to use.
  12317. /// @param attr Primitive attributes to use. Attributes are optional
  12318. /// and default to empty attributes.
  12319. /// @param allow_empty A flag signifying whether construction is
  12320. /// allowed to fail without throwing an exception. In this case an
  12321. /// empty object will be produced. This flag is optional and
  12322. /// defaults to false.
  12323. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12324. const memory::desc &diff_src_desc,
  12325. const memory::desc &diff_dst_desc, const memory::dims &strides,
  12326. const memory::dims &kernel, const memory::dims &dilation,
  12327. const memory::dims &padding_l, const memory::dims &padding_r,
  12328. const pooling_forward::primitive_desc &hint_fwd_pd,
  12329. const primitive_attr &attr = default_attr(),
  12330. bool allow_empty = false) {
  12331. memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
  12332. memory::validate_dims(kernel, diff_src_desc.get_ndims() - 2);
  12333. memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
  12334. memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
  12335. memory::validate_dims(dilation, diff_src_desc.get_ndims() - 2);
  12336. dnnl_primitive_desc_t pd = nullptr;
  12337. dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create(
  12338. &pd, aengine.get(), convert_to_c(aalgorithm),
  12339. diff_src_desc.get(), diff_dst_desc.get(), &strides[0],
  12340. &kernel[0], &dilation[0], &padding_l[0], &padding_r[0],
  12341. hint_fwd_pd.get(), attr.get());
  12342. if (!allow_empty)
  12343. error::wrap_c_api(status,
  12344. "could not create a descriptor for a pooling backward "
  12345. "propagation primitive");
  12346. reset(pd);
  12347. }
  12348. /// Constructs a primitive descriptor for a pooling backward propagation
  12349. /// primitive from a C API primitive descriptor that must have a
  12350. /// matching kind.
  12351. ///
  12352. /// @param pd C API primitive descriptor for a pooling backward
  12353. /// propagation primitive.
  12354. primitive_desc(dnnl_primitive_desc_t pd)
  12355. : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
  12356. dnnl::prop_kind::backward_data) {}
  12357. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12358. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  12359. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  12360. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  12361. /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
  12362. memory::desc workspace_desc() const { return base::workspace_desc(); }
  12363. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  12364. algorithm get_algorithm() const { return base::get_algorithm(); }
  12365. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12366. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12367. /// @copydoc dnnl::primitive_desc_base::get_strides()const
  12368. memory::dims get_strides() const { return base::get_strides(); }
  12369. /// @copydoc dnnl::primitive_desc_base::get_kernel()const
  12370. memory::dims get_kernel() const { return base::get_kernel(); }
  12371. /// @copydoc dnnl::primitive_desc_base::get_dilations()const
  12372. memory::dims get_dilations() const { return base::get_dilations(); }
  12373. /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
  12374. memory::dims get_padding_l() const { return base::get_padding_l(); }
  12375. /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
  12376. memory::dims get_padding_r() const { return base::get_padding_r(); }
  12377. };
  12378. /// Default constructor. Produces an empty object.
  12379. pooling_backward() = default;
  12380. /// Constructs a pooling backward propagation primitive.
  12381. ///
  12382. /// @param pd Primitive descriptor for a pooling backward propagation
  12383. /// primitive.
  12384. pooling_backward(const primitive_desc &pd) : primitive(pd) {}
  12385. /// Constructs a pooling backward propagation primitive from a cache blob.
  12386. ///
  12387. /// @param pd Primitive descriptor for a pooling backward propagation
  12388. /// primitive.
  12389. /// @param cache_blob Cache blob.
  12390. pooling_backward(
  12391. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12392. : primitive(pd, cache_blob) {}
  12393. };
  12394. /// @} dnnl_api_pooling
  12395. /// @addtogroup dnnl_api_prelu PReLU
  12396. ///
/// PReLU primitive.
/// A primitive to perform PReLU (leaky ReLU with a trainable alpha
/// parameter).
  12399. ///
  12400. /// @sa @ref dev_guide_prelu in developer guide
  12401. ///
  12402. /// @{
  12403. /// PReLU forward propagation primitive.
  12404. struct prelu_forward : public primitive {
  12405. /// Primitive descriptor for a PReLU forward propagation primitive.
  12406. struct primitive_desc : public dnnl::primitive_desc {
  12407. /// Default constructor. Produces an empty object.
  12408. primitive_desc() = default;
  12409. /// Constructs a primitive descriptor for a PReLU forward propagation
  12410. /// primitive.
  12411. ///
  12412. /// @param aengine Engine to use.
  12413. /// @param aprop_kind Propagation kind. Possible values are
  12414. /// #dnnl::prop_kind::forward_training, and
  12415. /// #dnnl::prop_kind::forward_inference.
  12416. /// @param src_desc Source memory descriptor.
  12417. /// @param weight_desc Alpha parameters memory descriptor.
  12418. /// @param dst_desc Destination memory descriptor.
  12419. /// @param attr Primitive attributes to use. Attributes are optional
  12420. /// and default to empty attributes.
  12421. /// @param allow_empty A flag signifying whether construction is
  12422. /// allowed to fail without throwing an exception. In this case an
  12423. /// empty object will be produced. This flag is optional and
  12424. /// defaults to false.
  12425. primitive_desc(const engine &aengine, prop_kind aprop_kind,
  12426. const memory::desc &src_desc, const memory::desc &weight_desc,
  12427. const memory::desc &dst_desc,
  12428. const primitive_attr &attr = default_attr(),
  12429. bool allow_empty = false) {
  12430. dnnl_primitive_desc_t pd = nullptr;
  12431. dnnl_status_t status = dnnl_prelu_forward_primitive_desc_create(&pd,
  12432. aengine.get(), dnnl::convert_to_c(aprop_kind),
  12433. src_desc.get(), weight_desc.get(), dst_desc.get(),
  12434. attr.get());
  12435. if (!allow_empty)
  12436. error::wrap_c_api(status,
  12437. "could not create a primitive descriptor for "
  12438. "the prelu forward propagation primitive. Run workload "
  12439. "with environment variable ONEDNN_VERBOSE=all to get "
  12440. "additional diagnostic information.");
  12441. reset(pd);
  12442. }
  12443. /// Constructs a primitive descriptor for a prelu forward
  12444. /// propagation primitive from a C API primitive descriptor that must
  12445. /// have a matching kind.
  12446. ///
  12447. /// @param pd C API primitive descriptor for a prelu forward
  12448. /// propagation primitive.
  12449. primitive_desc(dnnl_primitive_desc_t pd)
  12450. : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
  12451. dnnl::prop_kind::forward_training,
  12452. dnnl::prop_kind::forward_inference) {}
  12453. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12454. memory::desc src_desc() const { return base::src_desc(0); }
  12455. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12456. memory::desc dst_desc() const { return base::dst_desc(0); }
  12457. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12458. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12459. };
  12460. /// Default constructor. Produces an empty object.
  12461. prelu_forward() = default;
  12462. /// Constructs a prelu forward propagation primitive.
  12463. /// @param pd Primitive descriptor for a prelu forward propagation
  12464. /// primitive.
  12465. prelu_forward(const primitive_desc &pd) : primitive(pd) {}
  12466. /// Constructs a prelu forward propagation primitive from a cache blob.
  12467. /// @param pd Primitive descriptor for a prelu forward propagation
  12468. /// primitive.
  12469. /// @param cache_blob Cache blob.
  12470. prelu_forward(
  12471. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12472. : primitive(pd, cache_blob) {}
  12473. };
  12474. /// PReLU backward propagation primitive.
  12475. struct prelu_backward : public primitive {
  12476. /// Primitive descriptor for prelu backward propagation.
  12477. struct primitive_desc : public dnnl::primitive_desc {
  12478. /// Default constructor. Produces an empty object.
  12479. primitive_desc() = default;
  12480. /// Constructs a descriptor for a PReLU backward propagation
  12481. /// primitive.
  12482. ///
  12483. /// @param aengine Engine to use.
  12484. /// @param src_desc Source memory descriptor.
  12485. /// @param weight_desc Alpha parameters memory descriptor.
  12486. /// @param diff_src_desc Diff source memory descriptor.
  12487. /// @param diff_weights_desc Diff alpha parameters memory descriptor.
  12488. /// @param diff_dst_desc Diff destination memory descriptor.
  12489. /// @param hint_fwd_pd Primitive descriptor for a PReLU
  12490. /// forward propagation primitive. It is used as a hint for
  12491. /// deciding which memory format to use.
  12492. /// @param attr Primitive attributes to use. Attributes are optional
  12493. /// and default to empty attributes.
  12494. /// @param allow_empty A flag signifying whether construction is
  12495. /// allowed to fail without throwing an exception. In this case an
  12496. /// empty object will be produced. This flag is optional and
  12497. /// defaults to false.
  12498. primitive_desc(const engine &aengine, const memory::desc &src_desc,
  12499. const memory::desc &weight_desc,
  12500. const memory::desc &diff_src_desc,
  12501. const memory::desc &diff_weights_desc,
  12502. const memory::desc &diff_dst_desc,
  12503. const prelu_forward::primitive_desc &hint_fwd_pd,
  12504. const primitive_attr &attr = default_attr(),
  12505. bool allow_empty = false) {
  12506. dnnl_primitive_desc_t pd = nullptr;
  12507. dnnl_status_t status = dnnl_prelu_backward_primitive_desc_create(
  12508. &pd, aengine.get(), src_desc.get(), weight_desc.get(),
  12509. diff_src_desc.get(), diff_weights_desc.get(),
  12510. diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
  12511. if (!allow_empty)
  12512. error::wrap_c_api(status,
  12513. "could not create a primitive descriptor for "
  12514. "the prelu backward propagation primitive. Run "
  12515. "workload with environment variable ONEDNN_VERBOSE=all "
  12516. "to get additional diagnostic information.");
  12517. reset(pd);
  12518. }
  12519. /// Constructs a primitive descriptor for a prelu backward
  12520. /// propagation primitive from a C API primitive descriptor that must
  12521. /// have a matching kind.
  12522. ///
  12523. /// @param pd C API primitive descriptor for a prelu backward
  12524. /// propagation primitive.
  12525. primitive_desc(dnnl_primitive_desc_t pd)
  12526. : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
  12527. dnnl::prop_kind::backward) {}
  12528. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12529. memory::desc src_desc() const { return base::src_desc(0); }
  12530. /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
  12531. memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
  12532. /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
  12533. memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
  12534. /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
  12535. prop_kind get_prop_kind() const { return base::get_prop_kind(); }
  12536. };
  12537. /// Default constructor. Produces an empty object.
  12538. prelu_backward() = default;
  12539. /// Constructs a prelu backward propagation primitive.
  12540. /// @param pd Primitive descriptor for a prelu backward propagation
  12541. /// primitive.
  12542. prelu_backward(const primitive_desc &pd) : primitive(pd) {}
  12543. /// Constructs a prelu backward propagation primitive from a cache blob.
  12544. /// @param pd Primitive descriptor for a prelu backward propagation
  12545. /// primitive.
  12546. /// @param cache_blob Cache blob.
  12547. prelu_backward(
  12548. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12549. : primitive(pd, cache_blob) {}
  12550. };
  12551. /// @} dnnl_api_prelu
  12552. /// @addtogroup dnnl_api_reduction Reduction
  12553. ///
/// A primitive to compute a reduction operation on a data tensor
/// using min, max, mul, sum, mean, and norm_lp operations.
  12556. ///
  12557. /// @sa @ref dev_guide_reduction in developer guide
  12558. ///
  12559. /// @{
  12560. /// Reduction.
  12561. struct reduction : public primitive {
  12562. /// Primitive descriptor for a reduction primitive.
  12563. struct primitive_desc : public dnnl::primitive_desc {
  12564. /// Default constructor. Produces an empty object.
  12565. primitive_desc() = default;
  12566. /// Constructs a primitive descriptor for a reduction primitive using
  12567. /// algorithm specific parameters, source and destination memory
  12568. /// descriptors.
  12569. ///
  12570. /// @note
  12571. /// Destination memory descriptor may be initialized with
  12572. /// #dnnl::memory::format_tag::any value of @p format_tag.
  12573. ///
  12574. /// @param aengine Engine to use.
  12575. /// @param aalgorithm reduction algorithm kind. Possible values:
  12576. /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
  12577. /// #dnnl_reduction_mul, #dnnl_reduction_mean,
  12578. /// #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum,
  12579. /// #dnnl_reduction_norm_lp_power_p_max,
  12580. /// #dnnl_reduction_norm_lp_power_p_sum.
  12581. /// @param p algorithm specific parameter.
  12582. /// @param eps algorithm specific parameter.
  12583. /// @param src_desc Source memory descriptor.
  12584. /// @param dst_desc Destination memory descriptor.
  12585. /// @param attr Primitive attributes to use. Attributes are optional
  12586. /// and default to empty attributes.
  12587. /// @param allow_empty A flag signifying whether construction is
  12588. /// allowed to fail without throwing an exception. In this case an
  12589. /// empty object will be produced. This flag is optional and
  12590. /// defaults to false.
  12591. primitive_desc(const engine &aengine, algorithm aalgorithm,
  12592. const memory::desc &src_desc, const memory::desc &dst_desc,
  12593. float p, float eps, const primitive_attr &attr = default_attr(),
  12594. bool allow_empty = false) {
  12595. dnnl_primitive_desc_t pd = nullptr;
  12596. dnnl_status_t status = dnnl_reduction_primitive_desc_create(&pd,
  12597. aengine.get(), convert_to_c(aalgorithm), src_desc.get(),
  12598. dst_desc.get(), p, eps, attr.get());
  12599. if (!allow_empty)
  12600. error::wrap_c_api(status,
  12601. "could not create a primitive descriptor for "
  12602. "the reduction primitive. Run workload with "
  12603. "environment variable ONEDNN_VERBOSE=all to get "
  12604. "additional diagnostic information.");
  12605. reset(pd);
  12606. }
  12607. /// Constructs a primitive descriptor for a reduction primitive from a C
  12608. /// API primitive descriptor that must have a matching kind.
  12609. ///
  12610. /// @param pd C API primitive descriptor for a reduction primitive.
  12611. primitive_desc(dnnl_primitive_desc_t pd)
  12612. : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {}
  12613. /// @copydoc dnnl::primitive_desc_base::src_desc()const
  12614. memory::desc src_desc() const { return base::src_desc(0); }
  12615. /// @copydoc dnnl::primitive_desc_base::dst_desc()const
  12616. memory::desc dst_desc() const { return base::dst_desc(0); }
  12617. /// @copydoc dnnl::primitive_desc_base::get_p()const
  12618. float get_p() const { return base::get_p(); }
  12619. /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
  12620. float get_epsilon() const { return base::get_epsilon(); }
  12621. /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
  12622. algorithm get_algorithm() const { return base::get_algorithm(); }
  12623. };
  12624. /// Default constructor. Produces an empty object.
  12625. reduction() = default;
  12626. /// Constructs a reduction primitive.
  12627. /// @param pd Primitive descriptor for a reduction primitive.
  12628. reduction(const primitive_desc &pd) : primitive(pd) {}
  12629. /// Constructs a reduction primitive from a cache blob.
  12630. /// @param pd Primitive descriptor for a reduction primitive.
  12631. /// @param cache_blob Cache blob.
  12632. reduction(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12633. : primitive(pd, cache_blob) {}
  12634. };
  12635. /// @} dnnl_api_reduction
  12636. /// @} dnnl_api_primitives
  12637. /// @addtogroup dnnl_api_service Service
  12638. ///
  12639. /// A set of functions that aid in oneDNN debugging and profiling.
  12640. ///
  12641. /// @{
/// @copydoc dnnl_version_t
// C++ spelling of the C API version structure; dnnl::version() returns a
// pointer to one of these.
using version_t = dnnl_version_t;
/// Status values returned by the library functions.
///
/// Each enumerator is defined as the C API status value of the matching
/// name (for example, `success` is `dnnl_success`). This lets the service
/// wrappers in this group turn C return codes into this type with a plain
/// static_cast.
enum class status {
    /// @copydoc dnnl_success
    success = dnnl_success,
    /// @copydoc dnnl_out_of_memory
    out_of_memory = dnnl_out_of_memory,
    /// @copydoc dnnl_invalid_arguments
    invalid_arguments = dnnl_invalid_arguments,
    /// @copydoc dnnl_unimplemented
    unimplemented = dnnl_unimplemented,
    /// @copydoc dnnl_last_impl_reached
    last_impl_reached = dnnl_last_impl_reached,
    /// @copydoc dnnl_runtime_error
    runtime_error = dnnl_runtime_error,
    /// @copydoc dnnl_not_required
    not_required = dnnl_not_required,
};
  12661. /// @copydoc dnnl_set_verbose()
  12662. inline status set_verbose(int level) {
  12663. return static_cast<status>(dnnl_set_verbose(level));
  12664. }
  12665. /// @copydoc dnnl_version()
  12666. inline const version_t *version() {
  12667. return dnnl_version();
  12668. }
  12669. /// Returns the floating-point math mode that will be used by default
  12670. /// for all subsequently created primitives.
  12671. ///
  12672. /// @returns Output FP math mode.
  12673. inline fpmath_mode get_default_fpmath_mode() {
  12674. dnnl_fpmath_mode_t mode;
  12675. error::wrap_c_api(dnnl_get_default_fpmath_mode(&mode),
  12676. "could not get a default fpmath mode");
  12677. return static_cast<fpmath_mode>(mode);
  12678. }
  12679. /// @copydoc dnnl_set_default_fpmath_mode()
  12680. inline status set_default_fpmath_mode(fpmath_mode mode) {
  12681. return static_cast<status>(
  12682. dnnl_set_default_fpmath_mode(convert_to_c(mode)));
  12683. }
  12684. /// @copydoc dnnl_set_jit_dump()
  12685. inline status set_jit_dump(int enable) {
  12686. return static_cast<status>(dnnl_set_jit_dump(enable));
  12687. }
  12688. /// @copydoc dnnl_set_jit_profiling_flags()
  12689. inline status set_jit_profiling_flags(unsigned flags) {
  12690. return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
  12691. }
  12692. /// @copydoc dnnl_set_jit_profiling_jitdumpdir()
  12693. inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
  12694. return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
  12695. }
/// @copydoc dnnl_cpu_isa_t
// Each enumerator takes the value of the C API constant with the matching
// name, so set_max_cpu_isa() and get_effective_cpu_isa() convert between
// this type and dnnl_cpu_isa_t with a plain static_cast.
enum class cpu_isa {
    /// @copydoc dnnl_cpu_isa_default
    isa_default = dnnl_cpu_isa_default,
    /// @copydoc dnnl_cpu_isa_sse41
    sse41 = dnnl_cpu_isa_sse41,
    /// @copydoc dnnl_cpu_isa_avx
    avx = dnnl_cpu_isa_avx,
    /// @copydoc dnnl_cpu_isa_avx2
    avx2 = dnnl_cpu_isa_avx2,
    /// @copydoc dnnl_cpu_isa_avx2_vnni
    avx2_vnni = dnnl_cpu_isa_avx2_vnni,
    /// @copydoc dnnl_cpu_isa_avx2_vnni_2
    avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2,
    /// @copydoc dnnl_cpu_isa_avx512_core
    avx512_core = dnnl_cpu_isa_avx512_core,
    /// @copydoc dnnl_cpu_isa_avx512_core_vnni
    avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
    /// @copydoc dnnl_cpu_isa_avx512_core_bf16
    avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
    /// @copydoc dnnl_cpu_isa_avx10_1_512
    avx10_1_512 = dnnl_cpu_isa_avx10_1_512,
    /// @copydoc dnnl_cpu_isa_avx512_core_fp16
    avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16,
    /// @copydoc dnnl_cpu_isa_avx10_1_512_amx
    avx10_1_512_amx = dnnl_cpu_isa_avx10_1_512_amx,
    /// @copydoc dnnl_cpu_isa_avx512_core_amx
    avx512_core_amx = dnnl_cpu_isa_avx512_core_amx,
    /// @copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16
    avx10_1_512_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16,
    /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16
    avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16,
    /// @copydoc dnnl_cpu_isa_avx10_2_512
    avx10_2_512 = dnnl_cpu_isa_avx10_2_512,
    /// @copydoc dnnl_cpu_isa_avx10_2_512_amx_2
    avx10_2_512_amx_2 = dnnl_cpu_isa_avx10_2_512_amx_2,
};
  12733. /// @copydoc dnnl_set_max_cpu_isa()
  12734. inline status set_max_cpu_isa(cpu_isa isa) {
  12735. return static_cast<status>(
  12736. dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
  12737. }
  12738. /// @copydoc dnnl_get_effective_cpu_isa()
  12739. inline cpu_isa get_effective_cpu_isa() {
  12740. return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
  12741. }
/// @copydoc dnnl_cpu_isa_hints_t
// Enumerators mirror the dnnl_cpu_isa_hints_t constants of the same name;
// set_cpu_isa_hints() and get_cpu_isa_hints() static_cast between the two.
enum class cpu_isa_hints {
    /// @copydoc dnnl_cpu_isa_no_hints
    no_hints = dnnl_cpu_isa_no_hints,
    /// @copydoc dnnl_cpu_isa_prefer_ymm
    prefer_ymm = dnnl_cpu_isa_prefer_ymm,
};
  12749. /// @copydoc dnnl_set_cpu_isa_hints()
  12750. inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) {
  12751. return static_cast<status>(dnnl_set_cpu_isa_hints(
  12752. static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
  12753. }
  12754. /// @copydoc dnnl_get_cpu_isa_hints()
  12755. inline cpu_isa_hints get_cpu_isa_hints() {
  12756. return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
  12757. }
  12758. /// @} dnnl_api_service
  12759. #ifdef DNNL_EXPERIMENTAL_PROFILING
  12760. /// @addtogroup dnnl_api_profiling Profiling
  12761. /// @{
/// Profiling data kind.
///
/// Enumerators take the values of the matching C API constants, so
/// get_profiling_data() can static_cast this type to
/// dnnl_profiling_data_kind_t.
enum class profiling_data_kind {
    /// Undefined profiling data kind.
    undef = dnnl_profiling_data_kind_undef,
    /// Data kind to query an execution time in nanoseconds.
    time = dnnl_profiling_data_kind_time,
};
  12769. /// Resets a profiler's state.
  12770. ///
  12771. /// @param stream Stream associated with the profiler.
  12772. inline void reset_profiling(stream &stream) {
  12773. error::wrap_c_api(
  12774. dnnl_reset_profiling(stream.get()), "could not reset profiling");
  12775. }
  12776. /// Returns requested profiling data. The profiling data accumulates for each
  12777. /// primitive execution. The size of the vector will be equal to the number
  12778. /// of executions since the last `dnnl::reset_profiling` call.
  12779. ///
  12780. /// The profiling data can be reset by calling #dnnl::reset_profiling.
  12781. ///
  12782. /// @note
  12783. /// It is required to wait for all submitted primitives to complete
  12784. /// using #dnnl::stream::wait prior to querying profiling data.
  12785. ///
  12786. /// @param stream Stream that was used for executing a primitive that
  12787. /// is being profiled.
  12788. /// @param data_kind Profiling data kind to query.
  12789. ///
  12790. /// @returns A vector with the requested profiling data.
  12791. inline std::vector<uint64_t> get_profiling_data(
  12792. stream &stream, profiling_data_kind data_kind) {
  12793. int num_entries = 0;
  12794. error::wrap_c_api(
  12795. dnnl_query_profiling_data(stream.get(),
  12796. static_cast<dnnl_profiling_data_kind_t>(data_kind),
  12797. &num_entries, nullptr),
  12798. "could not get number of entries for profiling data");
  12799. if (num_entries == 0) return {};
  12800. std::vector<uint64_t> data(num_entries);
  12801. error::wrap_c_api(
  12802. dnnl_query_profiling_data(stream.get(),
  12803. static_cast<dnnl_profiling_data_kind_t>(data_kind),
  12804. &num_entries, data.data()),
  12805. "could not get profiling data");
  12806. return data;
  12807. }
  12808. /// @} dnnl_api_profiling
  12809. #endif
  12810. /// @addtogroup dnnl_api_primitive_cache Primitive Cache
  12811. ///
  12812. /// A set of functions that provide primitive cache control.
  12813. ///
  12814. /// @{
  12815. /// Returns the number of primitives that can be held in the primitive cache
  12816. /// at the same time.
  12817. inline int get_primitive_cache_capacity() {
  12818. int result = 0;
  12819. error::wrap_c_api(dnnl_get_primitive_cache_capacity(&result),
  12820. "could not get primitive cache capacity");
  12821. return result;
  12822. }
  12823. /// @copydoc dnnl_set_primitive_cache_capacity(int capacity)
  12824. inline void set_primitive_cache_capacity(int capacity) {
  12825. error::wrap_c_api(dnnl_set_primitive_cache_capacity(capacity),
  12826. "could not set primitive cache capacity");
  12827. }
  12828. /// @} dnnl_api_primitive_cache
  12829. /// @addtogroup dnnl_api_blas BLAS functions
  12830. ///
  12831. /// A subset of Basic Linear Algebra (BLAS) functions that perform
  12832. /// matrix-matrix multiplication.
  12833. ///
  12834. /// @{
  12835. /// @copydoc dnnl_sgemm()
  12836. inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
  12837. dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
  12838. const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
  12839. return static_cast<status>(dnnl_sgemm(
  12840. transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
  12841. }
  12842. /// @copydoc dnnl_gemm_u8s8s32()
  12843. inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
  12844. dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
  12845. dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  12846. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
  12847. return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
  12848. K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
  12849. }
  12850. /// @copydoc dnnl_gemm_s8s8s32()
  12851. inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
  12852. dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
  12853. dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  12854. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
  12855. return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
  12856. K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
  12857. }
  12858. /// @} dnnl_api_blas
  12859. // implementation section
  12860. /// @cond DO_NOT_DOCUMENT_THIS
  12861. inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
  12862. dnnl_primitive_t result;
  12863. error::wrap_c_api(dnnl_primitive_create(&result, c_pd),
  12864. "could not create a primitive");
  12865. reset(result);
  12866. }
  12867. inline primitive::primitive(const_dnnl_primitive_desc_t c_pd,
  12868. const std::vector<uint8_t> &cache_blob) {
  12869. dnnl_primitive_t result;
  12870. size_t size = cache_blob.size();
  12871. const uint8_t *cache_blob_data = cache_blob.data();
  12872. error::wrap_c_api(dnnl_primitive_create_from_cache_blob(
  12873. &result, c_pd, size, cache_blob_data),
  12874. "could not create a primitive from a cache blob");
  12875. reset(result);
  12876. }
  12877. inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
  12878. inline primitive::primitive(
  12879. const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
  12880. : primitive(pd.get(), cache_blob) {}
  12881. inline void primitive::execute(const stream &astream,
  12882. const std::unordered_map<int, memory> &args) const {
  12883. std::vector<dnnl_exec_arg_t> c_args;
  12884. c_args.reserve(args.size());
  12885. for (const auto &a : args)
  12886. c_args.push_back({a.first, a.second.get(true)});
  12887. error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
  12888. (int)c_args.size(), c_args.data()),
  12889. "could not execute a primitive");
  12890. }
  12891. /// @endcond
  12892. #undef DNNL_DEFINE_BITMASK_OPS
  12893. } // namespace dnnl
/// The oneAPI namespace.
/// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace.
namespace oneapi {
// Note: without this guard, doxygen warns about a potentially recursive
// namespace.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
/// oneDNN alias namespace
namespace dnnl = ::dnnl;
#endif
} // namespace oneapi
  12904. /// @} dnnl_api
  12905. // NOLINTEND(readability-identifier-naming)
  12906. #endif /* ONEAPI_DNNL_DNNL_HPP */
  12907. #else
  12908. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  12909. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)