| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
27762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766
77767786779678067816782678367846785678667876788678967906791679267936794679567966797679867996800680168026803680468056806680768086809681068116812681368146815681668176818681968206821682268236824682568266827682868296830683168326833683468356836683768386839684068416842684368446845684668476848684968506851685268536854685568566857685868596860686168626863686468656866686768686869687068716872687368746875687668776878687968806881688268836884688568866887688868896890689168926893689468956896689768986899690069016902690369046905690669076908690969106911691269136914691569166917691869196920692169226923692469256926692769286929693069316932693369346935693669376938693969406941694269436944694569466947694869496950695169526953695469556956695769586959696069616962696369646965696669676968696969706971697269736974697569766977697869796980698169826983698469856986698769886989699069916992699369946995699669976998699970007001700270037004700570067007700870097010701170127013701470157016701770187019702070217022702370247025702670277028702970307031703270337034703570367037703870397040704170427043704470457046704770487049705070517052705370547055705670577058705970607061706270637064706570667067706870697070707170727073707470757076707770787079708070817082708370847085708670877088708970907091709270937094709570967097709870997100710171027103710471057106710771087109711071117112711371147115711671177118711971207121712271237124712571267127712871297130713171327133713471357136713771387139714071417142714371447145714671477148714971507151715271537154715571567157715871597160716171627163716471657166716771687169717071717172717371747175717671777178717971807181718271837184718571867187718871897190719171927193719471957196719771987199720072017202720372047205720672077208720972107211721272137214721572167217721872197220722172227223722472257226722772287229723072317232723372347235723672377238723972407241724272437244724572467247724872497250725172527253725472557256725772587259726072617262726372647265726672677268726972707271727272737274727572767
27772787279728072817282728372847285728672877288728972907291729272937294729572967297729872997300730173027303730473057306730773087309731073117312731373147315731673177318731973207321732273237324732573267327732873297330733173327333733473357336733773387339734073417342734373447345734673477348734973507351735273537354735573567357735873597360736173627363736473657366736773687369737073717372737373747375737673777378737973807381738273837384738573867387738873897390739173927393739473957396739773987399740074017402740374047405740674077408740974107411741274137414741574167417741874197420742174227423742474257426742774287429743074317432743374347435743674377438743974407441744274437444744574467447744874497450745174527453745474557456745774587459746074617462746374647465746674677468746974707471747274737474747574767477747874797480748174827483748474857486748774887489749074917492749374947495749674977498749975007501750275037504750575067507750875097510751175127513751475157516751775187519752075217522752375247525752675277528752975307531753275337534753575367537753875397540754175427543754475457546754775487549755075517552755375547555755675577558755975607561756275637564756575667567756875697570757175727573757475757576757775787579758075817582758375847585758675877588758975907591759275937594759575967597759875997600760176027603760476057606760776087609761076117612761376147615761676177618761976207621762276237624762576267627762876297630763176327633763476357636763776387639764076417642764376447645764676477648764976507651765276537654765576567657765876597660766176627663766476657666766776687669767076717672767376747675767676777678767976807681768276837684768576867687768876897690769176927693769476957696769776987699770077017702770377047705770677077708770977107711771277137714771577167717771877197720772177227723772477257726772777287729773077317732773377347735773677377738773977407741774277437744774577467747774877497750775177527753775477557756775777587759776077617762776377647765776677677768776977707771777277737774777577767
77777787779778077817782778377847785778677877788778977907791779277937794779577967797779877997800780178027803780478057806780778087809781078117812781378147815781678177818781978207821782278237824782578267827782878297830783178327833783478357836783778387839784078417842784378447845784678477848784978507851785278537854785578567857785878597860786178627863786478657866786778687869787078717872787378747875787678777878787978807881788278837884788578867887788878897890789178927893789478957896789778987899790079017902790379047905790679077908790979107911791279137914791579167917791879197920792179227923792479257926792779287929793079317932793379347935793679377938793979407941794279437944794579467947794879497950795179527953795479557956795779587959796079617962796379647965796679677968796979707971797279737974797579767977797879797980798179827983798479857986798779887989799079917992799379947995799679977998799980008001800280038004800580068007800880098010801180128013801480158016801780188019802080218022802380248025802680278028802980308031803280338034803580368037803880398040804180428043804480458046804780488049805080518052805380548055805680578058805980608061806280638064806580668067806880698070807180728073807480758076807780788079808080818082808380848085808680878088808980908091809280938094809580968097809880998100810181028103810481058106810781088109811081118112811381148115811681178118811981208121812281238124812581268127812881298130813181328133813481358136813781388139814081418142814381448145814681478148814981508151815281538154815581568157815881598160816181628163816481658166816781688169817081718172817381748175817681778178817981808181818281838184818581868187818881898190819181928193819481958196819781988199820082018202820382048205820682078208820982108211821282138214821582168217821882198220822182228223822482258226822782288229823082318232823382348235823682378238823982408241824282438244824582468247824882498250825182528253825482558256825782588259826082618262826382648265826682678268826982708271827282738274827582768
27782788279828082818282828382848285828682878288828982908291829282938294829582968297829882998300830183028303830483058306830783088309831083118312831383148315831683178318831983208321832283238324832583268327832883298330833183328333833483358336833783388339834083418342834383448345834683478348834983508351835283538354835583568357835883598360836183628363836483658366836783688369837083718372837383748375837683778378837983808381838283838384838583868387838883898390839183928393839483958396839783988399840084018402840384048405840684078408840984108411841284138414841584168417841884198420842184228423842484258426842784288429843084318432843384348435843684378438843984408441844284438444844584468447844884498450845184528453845484558456845784588459846084618462846384648465846684678468846984708471847284738474847584768477847884798480848184828483848484858486848784888489849084918492849384948495849684978498849985008501850285038504850585068507850885098510851185128513851485158516851785188519852085218522852385248525852685278528852985308531853285338534853585368537853885398540854185428543854485458546854785488549855085518552855385548555855685578558855985608561856285638564856585668567856885698570857185728573857485758576857785788579858085818582858385848585858685878588858985908591859285938594859585968597859885998600860186028603860486058606860786088609861086118612861386148615861686178618861986208621862286238624862586268627862886298630863186328633863486358636863786388639864086418642864386448645864686478648864986508651865286538654865586568657865886598660866186628663866486658666866786688669867086718672867386748675867686778678867986808681868286838684868586868687868886898690869186928693869486958696869786988699870087018702870387048705870687078708870987108711871287138714871587168717871887198720872187228723872487258726872787288729873087318732873387348735873687378738873987408741874287438744874587468747874887498750875187528753875487558756875787588759876087618762876387648765876687678768876987708771877287738774877587768
77787788779878087818782878387848785878687878788878987908791879287938794879587968797879887998800880188028803880488058806880788088809881088118812881388148815881688178818881988208821882288238824882588268827882888298830883188328833883488358836883788388839884088418842884388448845884688478848884988508851885288538854885588568857885888598860886188628863886488658866886788688869887088718872887388748875887688778878887988808881888288838884888588868887888888898890889188928893889488958896889788988899890089018902890389048905890689078908890989108911891289138914891589168917891889198920892189228923892489258926892789288929893089318932893389348935893689378938893989408941894289438944894589468947894889498950895189528953895489558956895789588959896089618962896389648965896689678968896989708971897289738974897589768977897889798980898189828983898489858986898789888989899089918992899389948995899689978998899990009001900290039004900590069007900890099010901190129013901490159016901790189019902090219022902390249025902690279028902990309031903290339034903590369037903890399040904190429043904490459046904790489049905090519052905390549055905690579058905990609061906290639064906590669067906890699070907190729073907490759076907790789079908090819082908390849085908690879088908990909091909290939094909590969097909890999100910191029103910491059106910791089109911091119112911391149115911691179118911991209121912291239124912591269127912891299130913191329133913491359136913791389139914091419142914391449145914691479148914991509151915291539154915591569157915891599160916191629163916491659166916791689169917091719172917391749175917691779178917991809181918291839184918591869187918891899190919191929193919491959196919791989199920092019202920392049205920692079208920992109211921292139214921592169217921892199220922192229223922492259226922792289229923092319232923392349235923692379238923992409241924292439244924592469247924892499250925192529253925492559256925792589259926092619262926392649265926692679268926992709271927292739274927592769
27792789279928092819282928392849285928692879288928992909291929292939294929592969297929892999300930193029303930493059306930793089309931093119312931393149315931693179318931993209321932293239324932593269327932893299330933193329333933493359336933793389339934093419342934393449345934693479348934993509351935293539354935593569357935893599360936193629363936493659366936793689369937093719372937393749375937693779378937993809381938293839384938593869387938893899390939193929393939493959396939793989399940094019402940394049405940694079408940994109411941294139414941594169417941894199420942194229423942494259426942794289429943094319432943394349435943694379438943994409441944294439444944594469447944894499450945194529453945494559456945794589459946094619462946394649465946694679468946994709471947294739474947594769477947894799480948194829483948494859486948794889489949094919492949394949495949694979498949995009501950295039504950595069507950895099510951195129513951495159516951795189519952095219522952395249525952695279528952995309531953295339534953595369537953895399540954195429543954495459546954795489549955095519552955395549555955695579558955995609561956295639564956595669567956895699570957195729573957495759576957795789579958095819582958395849585958695879588958995909591959295939594959595969597959895999600960196029603960496059606960796089609961096119612961396149615961696179618961996209621962296239624962596269627962896299630963196329633963496359636963796389639964096419642964396449645964696479648964996509651965296539654965596569657965896599660966196629663966496659666966796689669967096719672967396749675967696779678967996809681968296839684968596869687968896899690969196929693969496959696969796989699970097019702970397049705970697079708970997109711971297139714971597169717971897199720972197229723972497259726972797289729973097319732973397349735973697379738973997409741974297439744974597469747974897499750975197529753975497559756975797589759976097619762976397649765976697679768976997709771977297739774977597769
77797789779978097819782978397849785978697879788978997909791979297939794979597969797979897999800980198029803980498059806980798089809981098119812981398149815981698179818981998209821982298239824982598269827982898299830983198329833983498359836983798389839984098419842984398449845984698479848984998509851985298539854985598569857985898599860986198629863986498659866986798689869987098719872987398749875987698779878987998809881988298839884988598869887988898899890989198929893989498959896989798989899990099019902990399049905990699079908990999109911991299139914991599169917991899199920992199229923992499259926992799289929993099319932993399349935993699379938993999409941994299439944994599469947994899499950995199529953995499559956995799589959996099619962996399649965996699679968996999709971997299739974997599769977997899799980998199829983998499859986998799889989999099919992999399949995999699979998999910000100011000210003100041000510006100071000810009100101001110012100131001410015100161001710018100191002010021100221002310024100251002610027100281002910030100311003210033100341003510036100371003810039100401004110042100431004410045100461004710048100491005010051100521005310054100551005610057100581005910060100611006210063100641006510066100671006810069100701007110072100731007410075100761007710078100791008010081100821008310084100851008610087100881008910090100911009210093100941009510096100971009810099101001010110102101031010410105101061010710108101091011010111101121011310114101151011610117101181011910120101211012210123101241012510126101271012810129101301013110132101331013410135101361013710138101391014010141101421014310144101451014610147101481014910150101511015210153101541015510156101571015810159101601016110162101631016410165101661016710168101691017010171101721017310174101751017610177101781017910180101811018210183101841018510186101871018810189101901019110192101931019410195101961019710198101991020010201102021020310204102051020610207102081020910210102111021210213102141021510216102171021810219102201022
11022210223102241022510226102271022810229102301023110232102331023410235102361023710238102391024010241102421024310244102451024610247102481024910250102511025210253102541025510256102571025810259102601026110262102631026410265102661026710268102691027010271102721027310274102751027610277102781027910280102811028210283102841028510286102871028810289102901029110292102931029410295102961029710298102991030010301103021030310304103051030610307103081030910310103111031210313103141031510316103171031810319103201032110322103231032410325103261032710328103291033010331103321033310334103351033610337103381033910340103411034210343103441034510346103471034810349103501035110352103531035410355103561035710358103591036010361103621036310364103651036610367103681036910370103711037210373103741037510376103771037810379103801038110382103831038410385103861038710388103891039010391103921039310394103951039610397103981039910400104011040210403104041040510406104071040810409104101041110412104131041410415104161041710418104191042010421104221042310424104251042610427104281042910430104311043210433104341043510436104371043810439104401044110442104431044410445104461044710448104491045010451104521045310454104551045610457104581045910460104611046210463104641046510466104671046810469104701047110472104731047410475104761047710478104791048010481104821048310484104851048610487104881048910490104911049210493104941049510496104971049810499105001050110502105031050410505105061050710508105091051010511105121051310514105151051610517105181051910520105211052210523105241052510526105271052810529105301053110532105331053410535105361053710538105391054010541105421054310544105451054610547105481054910550105511055210553105541055510556105571055810559105601056110562105631056410565105661056710568105691057010571105721057310574105751057610577105781057910580105811058210583105841058510586105871058810589105901059110592105931059410595105961059710598105991060010601106021060310604106051060610607106081060910610106111061210613106141061510616106171061810619106201062
11062210623106241062510626106271062810629106301063110632106331063410635106361063710638106391064010641106421064310644106451064610647106481064910650106511065210653106541065510656106571065810659106601066110662106631066410665106661066710668106691067010671106721067310674106751067610677106781067910680106811068210683106841068510686106871068810689106901069110692106931069410695106961069710698106991070010701107021070310704107051070610707107081070910710107111071210713107141071510716107171071810719107201072110722107231072410725107261072710728107291073010731107321073310734107351073610737107381073910740107411074210743107441074510746107471074810749107501075110752107531075410755107561075710758107591076010761107621076310764107651076610767107681076910770107711077210773107741077510776107771077810779107801078110782107831078410785107861078710788107891079010791107921079310794107951079610797107981079910800108011080210803108041080510806108071080810809108101081110812108131081410815108161081710818108191082010821108221082310824108251082610827108281082910830108311083210833108341083510836108371083810839108401084110842108431084410845108461084710848108491085010851108521085310854108551085610857108581085910860108611086210863108641086510866108671086810869108701087110872108731087410875108761087710878108791088010881108821088310884108851088610887108881088910890108911089210893108941089510896108971089810899109001090110902109031090410905109061090710908109091091010911109121091310914109151091610917109181091910920109211092210923109241092510926109271092810929109301093110932109331093410935109361093710938109391094010941109421094310944109451094610947109481094910950109511095210953109541095510956109571095810959109601096110962109631096410965109661096710968109691097010971109721097310974109751097610977109781097910980109811098210983109841098510986109871098810989109901099110992109931099410995109961099710998109991100011001110021100311004110051100611007110081100911010110111101211013110141101511016110171101811019110201102
11102211023110241102511026110271102811029110301103111032110331103411035110361103711038110391104011041110421104311044110451104611047110481104911050110511105211053110541105511056110571105811059110601106111062110631106411065110661106711068110691107011071110721107311074110751107611077110781107911080110811108211083110841108511086110871108811089110901109111092110931109411095110961109711098110991110011101111021110311104111051110611107111081110911110111111111211113111141111511116111171111811119111201112111122111231112411125111261112711128111291113011131111321113311134111351113611137111381113911140111411114211143111441114511146111471114811149111501115111152111531115411155111561115711158111591116011161111621116311164111651116611167111681116911170111711117211173111741117511176111771117811179111801118111182111831118411185111861118711188111891119011191111921119311194111951119611197111981119911200112011120211203112041120511206112071120811209112101121111212112131121411215112161121711218112191122011221112221122311224112251122611227112281122911230112311123211233112341123511236112371123811239112401124111242112431124411245112461124711248112491125011251112521125311254112551125611257112581125911260112611126211263112641126511266112671126811269112701127111272112731127411275112761127711278112791128011281112821128311284112851128611287112881128911290112911129211293112941129511296112971129811299113001130111302113031130411305113061130711308113091131011311113121131311314113151131611317113181131911320113211132211323113241132511326113271132811329113301133111332113331133411335113361133711338113391134011341113421134311344113451134611347113481134911350113511135211353113541135511356113571135811359113601136111362113631136411365113661136711368113691137011371113721137311374113751137611377113781137911380113811138211383113841138511386113871138811389113901139111392113931139411395113961139711398113991140011401114021140311404114051140611407114081140911410114111141211413114141141511416114171141811419114201142
11142211423114241142511426114271142811429114301143111432114331143411435114361143711438114391144011441114421144311444114451144611447114481144911450114511145211453114541145511456114571145811459114601146111462114631146411465114661146711468114691147011471114721147311474114751147611477114781147911480114811148211483114841148511486114871148811489114901149111492114931149411495114961149711498114991150011501115021150311504115051150611507115081150911510115111151211513115141151511516115171151811519115201152111522115231152411525115261152711528115291153011531115321153311534115351153611537115381153911540115411154211543115441154511546115471154811549115501155111552115531155411555115561155711558115591156011561115621156311564115651156611567115681156911570115711157211573115741157511576115771157811579115801158111582115831158411585115861158711588115891159011591115921159311594115951159611597115981159911600116011160211603116041160511606116071160811609116101161111612116131161411615116161161711618116191162011621116221162311624116251162611627116281162911630116311163211633116341163511636116371163811639116401164111642116431164411645116461164711648116491165011651116521165311654116551165611657116581165911660116611166211663116641166511666116671166811669116701167111672116731167411675116761167711678116791168011681116821168311684116851168611687116881168911690116911169211693116941169511696116971169811699117001170111702117031170411705117061170711708117091171011711117121171311714117151171611717117181171911720117211172211723117241172511726117271172811729117301173111732117331173411735117361173711738117391174011741117421174311744117451174611747117481174911750117511175211753117541175511756117571175811759117601176111762117631176411765117661176711768117691177011771117721177311774117751177611777117781177911780117811178211783117841178511786117871178811789117901179111792117931179411795117961179711798117991180011801118021180311804118051180611807118081180911810118111181211813118141181511816118171181811819118201182
11182211823118241182511826118271182811829118301183111832118331183411835118361183711838118391184011841118421184311844118451184611847118481184911850118511185211853118541185511856118571185811859118601186111862118631186411865118661186711868118691187011871118721187311874118751187611877118781187911880118811188211883118841188511886118871188811889118901189111892118931189411895118961189711898118991190011901119021190311904119051190611907119081190911910119111191211913119141191511916119171191811919119201192111922119231192411925119261192711928119291193011931119321193311934119351193611937119381193911940119411194211943119441194511946119471194811949119501195111952119531195411955119561195711958119591196011961119621196311964119651196611967119681196911970119711197211973119741197511976119771197811979119801198111982119831198411985119861198711988119891199011991119921199311994119951199611997119981199912000120011200212003120041200512006120071200812009120101201112012120131201412015120161201712018120191202012021120221202312024120251202612027120281202912030120311203212033120341203512036120371203812039120401204112042120431204412045120461204712048120491205012051120521205312054120551205612057120581205912060120611206212063120641206512066120671206812069120701207112072120731207412075120761207712078120791208012081120821208312084120851208612087120881208912090120911209212093120941209512096120971209812099121001210112102121031210412105121061210712108121091211012111121121211312114121151211612117121181211912120121211212212123121241212512126121271212812129121301213112132121331213412135121361213712138121391214012141121421214312144121451214612147121481214912150121511215212153121541215512156121571215812159121601216112162121631216412165121661216712168121691217012171121721217312174121751217612177121781217912180121811218212183121841218512186121871218812189121901219112192121931219412195121961219712198121991220012201122021220312204122051220612207122081220912210122111221212213122141221512216122171221812219122201222
11222212223122241222512226122271222812229122301223112232122331223412235122361223712238122391224012241122421224312244122451224612247122481224912250122511225212253122541225512256122571225812259122601226112262122631226412265122661226712268122691227012271122721227312274122751227612277122781227912280122811228212283122841228512286122871228812289122901229112292122931229412295122961229712298122991230012301123021230312304123051230612307123081230912310123111231212313123141231512316123171231812319123201232112322123231232412325123261232712328123291233012331123321233312334123351233612337123381233912340123411234212343123441234512346123471234812349123501235112352123531235412355123561235712358123591236012361123621236312364123651236612367123681236912370123711237212373123741237512376123771237812379123801238112382123831238412385123861238712388123891239012391123921239312394123951239612397123981239912400124011240212403124041240512406124071240812409124101241112412124131241412415124161241712418124191242012421124221242312424124251242612427124281242912430124311243212433124341243512436124371243812439124401244112442124431244412445124461244712448124491245012451124521245312454124551245612457124581245912460124611246212463124641246512466124671246812469124701247112472124731247412475124761247712478124791248012481124821248312484124851248612487124881248912490124911249212493124941249512496124971249812499125001250112502125031250412505125061250712508125091251012511125121251312514125151251612517125181251912520125211252212523125241252512526125271252812529125301253112532125331253412535125361253712538125391254012541125421254312544125451254612547125481254912550125511255212553125541255512556125571255812559125601256112562125631256412565125661256712568125691257012571125721257312574125751257612577125781257912580125811258212583125841258512586125871258812589125901259112592125931259412595125961259712598125991260012601126021260312604126051260612607126081260912610126111261212613126141261512616126171261812619126201262
11262212623126241262512626126271262812629126301263112632126331263412635126361263712638126391264012641126421264312644126451264612647126481264912650126511265212653126541265512656126571265812659126601266112662126631266412665126661266712668126691267012671126721267312674126751267612677126781267912680126811268212683126841268512686126871268812689126901269112692126931269412695126961269712698126991270012701127021270312704127051270612707127081270912710127111271212713127141271512716127171271812719127201272112722127231272412725127261272712728127291273012731127321273312734127351273612737127381273912740127411274212743127441274512746127471274812749127501275112752127531275412755127561275712758127591276012761127621276312764127651276612767127681276912770127711277212773127741277512776127771277812779127801278112782127831278412785127861278712788127891279012791127921279312794127951279612797127981279912800128011280212803128041280512806128071280812809128101281112812128131281412815128161281712818128191282012821128221282312824128251282612827128281282912830128311283212833128341283512836128371283812839128401284112842128431284412845128461284712848128491285012851128521285312854128551285612857128581285912860128611286212863128641286512866128671286812869128701287112872128731287412875128761287712878128791288012881128821288312884128851288612887128881288912890128911289212893128941289512896128971289812899129001290112902129031290412905129061290712908129091291012911129121291312914129151291612917129181291912920129211292212923129241292512926129271292812929129301293112932129331293412935129361293712938129391294012941129421294312944129451294612947129481294912950129511295212953129541295512956129571295812959129601296112962129631296412965129661296712968129691297012971129721297312974129751297612977129781297912980129811298212983129841298512986129871298812989129901299112992129931299412995129961299712998129991300013001130021300313004130051300613007130081300913010130111301213013130141301513016130171301813019130201302
11302213023130241302513026130271302813029130301303113032130331303413035130361303713038130391304013041130421304313044130451304613047130481304913050130511305213053130541305513056130571305813059130601306113062130631306413065130661306713068130691307013071130721307313074130751307613077130781307913080130811308213083130841308513086130871308813089130901309113092130931309413095130961309713098130991310013101131021310313104131051310613107131081310913110131111311213113131141311513116131171311813119131201312113122131231312413125131261312713128131291313013131131321313313134131351313613137131381313913140131411314213143131441314513146131471314813149131501315113152131531315413155131561315713158131591316013161131621316313164131651316613167131681316913170131711317213173131741317513176131771317813179131801318113182131831318413185131861318713188131891319013191131921319313194131951319613197131981319913200132011320213203132041320513206132071320813209132101321113212132131321413215132161321713218132191322013221132221322313224132251322613227132281322913230132311323213233132341323513236132371323813239132401324113242132431324413245132461324713248132491325013251132521325313254132551325613257132581325913260132611326213263132641326513266132671326813269132701327113272132731327413275132761327713278132791328013281132821328313284132851328613287132881328913290132911329213293132941329513296132971329813299133001330113302133031330413305133061330713308133091331013311133121331313314133151331613317133181331913320133211332213323133241332513326133271332813329133301333113332133331333413335133361333713338133391334013341133421334313344133451334613347133481334913350133511335213353133541335513356133571335813359133601336113362133631336413365133661336713368133691337013371133721337313374133751337613377133781337913380133811338213383133841338513386133871338813389133901339113392133931339413395133961339713398133991340013401134021340313404134051340613407134081340913410134111341213413134141341513416134171341813419134201342
11342213423134241342513426134271342813429134301343113432134331343413435134361343713438134391344013441134421344313444134451344613447134481344913450134511345213453134541345513456134571345813459134601346113462134631346413465134661346713468134691347013471134721347313474134751347613477134781347913480134811348213483134841348513486134871348813489134901349113492134931349413495134961349713498134991350013501135021350313504135051350613507135081350913510135111351213513135141351513516135171351813519135201352113522135231352413525135261352713528135291353013531135321353313534135351353613537135381353913540135411354213543135441354513546135471354813549135501355113552135531355413555135561355713558135591356013561135621356313564135651356613567135681356913570135711357213573135741357513576135771357813579135801358113582135831358413585135861358713588135891359013591135921359313594135951359613597135981359913600136011360213603136041360513606136071360813609136101361113612136131361413615136161361713618136191362013621136221362313624136251362613627136281362913630136311363213633136341363513636136371363813639136401364113642136431364413645136461364713648136491365013651136521365313654136551365613657136581365913660136611366213663136641366513666136671366813669136701367113672136731367413675136761367713678136791368013681136821368313684136851368613687136881368913690136911369213693136941369513696136971369813699137001370113702137031370413705137061370713708137091371013711137121371313714137151371613717137181371913720137211372213723137241372513726137271372813729137301373113732137331373413735137361373713738137391374013741137421374313744137451374613747137481374913750137511375213753137541375513756137571375813759137601376113762137631376413765137661376713768137691377013771137721377313774137751377613777137781377913780137811378213783137841378513786137871378813789137901379113792137931379413795137961379713798137991380013801138021380313804138051380613807138081380913810138111381213813138141381513816138171381813819138201382
11382213823138241382513826138271382813829138301383113832138331383413835138361383713838138391384013841138421384313844138451384613847138481384913850138511385213853138541385513856138571385813859138601386113862138631386413865138661386713868138691387013871138721387313874138751387613877138781387913880138811388213883138841388513886138871388813889138901389113892138931389413895138961389713898138991390013901139021390313904139051390613907139081390913910139111391213913139141391513916139171391813919139201392113922139231392413925139261392713928139291393013931139321393313934139351393613937139381393913940139411394213943139441394513946139471394813949139501395113952139531395413955139561395713958139591396013961139621396313964139651396613967139681396913970139711397213973139741397513976139771397813979139801398113982139831398413985139861398713988139891399013991139921399313994139951399613997139981399914000140011400214003140041400514006140071400814009140101401114012140131401414015140161401714018140191402014021140221402314024140251402614027140281402914030140311403214033140341403514036140371403814039140401404114042140431404414045140461404714048140491405014051140521405314054140551405614057140581405914060140611406214063140641406514066140671406814069140701407114072140731407414075140761407714078140791408014081140821408314084140851408614087140881408914090140911409214093140941409514096140971409814099141001410114102141031410414105141061410714108141091411014111141121411314114141151411614117141181411914120141211412214123141241412514126141271412814129141301413114132141331413414135141361413714138141391414014141141421414314144141451414614147141481414914150141511415214153141541415514156141571415814159141601416114162141631416414165141661416714168141691417014171141721417314174141751417614177141781417914180141811418214183141841418514186141871418814189141901419114192141931419414195141961419714198141991420014201142021420314204142051420614207142081420914210142111421214213142141421514216142171421814219 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- /*******************************************************************************
- * Copyright 2016-2025 Intel Corporation
- * Copyright 2024-2025 FUJITSU LIMITED
- * Copyright 2025 Arm Ltd. and affiliates
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
- /// @file
- /// C++ API
- #ifndef ONEAPI_DNNL_DNNL_HPP
- #define ONEAPI_DNNL_DNNL_HPP
- // NOLINTBEGIN(readability-identifier-naming)
- #include "oneapi/dnnl/dnnl_config.h"
- /// @cond DO_NOT_DOCUMENT_THIS
- #include <algorithm>
- #include <cstdlib>
- #include <iterator>
- #include <memory>
- #include <string>
- #include <vector>
- #include <unordered_map>
- #include "oneapi/dnnl/dnnl.h"
- #include "oneapi/dnnl/dnnl_common.hpp"
- /// @endcond
- /// @addtogroup dnnl_api oneDNN API
- /// @{
- /// oneDNN namespace
- namespace dnnl {
- /// @addtogroup dnnl_api_utils Utilities
- /// Utility types and definitions.
- /// @{
- /// @cond DO_NOT_DOCUMENT_THIS
- /// Checks that the number of elements in @p v lies between @p min_size and
- /// @p max_size (inclusive) and reports an error through DNNL_THROW_ERROR
- /// with status dnnl_invalid_arguments otherwise.
- ///
- /// @tparam T Container type; only its size() member is used here.
- /// @param v Container whose element count is validated.
- /// @param error_message Message forwarded to DNNL_THROW_ERROR on failure.
- /// @param min_size Smallest accepted element count (defaults to 1).
- /// @param max_size Largest accepted element count; a negative value (the
- /// default) disables the upper-bound check.
- template <typename T>
- void validate_container_size(const T &v, const char *error_message,
- int min_size = 1, int max_size = -1) {
- // The element count is converted to int so it can be compared with the
- // signed bounds.
- const auto count = static_cast<int>(v.size());
- const bool below_min = count < min_size;
- const bool above_max = !(max_size < 0) && max_size < count;
- if (below_min || above_max) {
- DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
- }
- }
- /// @endcond
- /// @cond DO_NOT_DOCUMENT_THIS
- /// Specialization of handle_traits for dnnl_memory_desc_t; it supplies the
- /// destructor function for this handle type.
- template <>
- struct handle_traits<dnnl_memory_desc_t> {
- /// Passes @p p to dnnl_memory_desc_destroy() and returns the status
- /// that call reports.
- static dnnl_status_t destructor(dnnl_memory_desc_t p) {
- const dnnl_status_t status = dnnl_memory_desc_destroy(p);
- return status;
- }
- };
- /// Specialization of handle_traits for dnnl_memory_t; it supplies the
- /// destructor function for this handle type.
- template <>
- struct handle_traits<dnnl_memory_t> {
- /// Passes @p p to dnnl_memory_destroy() and returns the status that
- /// call reports.
- static dnnl_status_t destructor(dnnl_memory_t p) {
- const dnnl_status_t status = dnnl_memory_destroy(p);
- return status;
- }
- };
- /// Specialization of handle_traits for dnnl_primitive_desc_t; it supplies
- /// the destructor function for this handle type.
- template <>
- struct handle_traits<dnnl_primitive_desc_t> {
- /// Passes @p p to dnnl_primitive_desc_destroy() and returns the status
- /// that call reports.
- static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
- const dnnl_status_t status = dnnl_primitive_desc_destroy(p);
- return status;
- }
- };
- /// Specialization of handle_traits for dnnl_primitive_t; it supplies the
- /// destructor function for this handle type.
- template <>
- struct handle_traits<dnnl_primitive_t> {
- /// Passes @p p to dnnl_primitive_destroy() and returns the status that
- /// call reports.
- static dnnl_status_t destructor(dnnl_primitive_t p) {
- const dnnl_status_t status = dnnl_primitive_destroy(p);
- return status;
- }
- };
- /// @endcond
- /// @} dnnl_api_utils
- // Forward declarations: they let these type names be used before their
- // definitions appear (for example, primitive_desc is a parameter type of
- // the primitive constructors declared further down).
- struct stream;
- struct memory;
- struct primitive_desc;
- /// @addtogroup dnnl_api_primitives Primitives
- /// Compute primitives
- /// @sa @ref dev_guide_basic_concepts
- /// @{
- /// @addtogroup dnnl_api_primitives_common Common
- /// Common operations to create, destroy and inspect primitives
- /// @{
- /// Base class for all computational primitives.
- struct primitive : public handle<dnnl_primitive_t> {
- /// Kinds of primitives supported by the library.
- enum class kind {
- /// Undefined primitive.
- undef = dnnl_undefined_primitive,
- /// A reorder primitive.
- reorder = dnnl_reorder,
- /// A shuffle primitive.
- shuffle = dnnl_shuffle,
- /// A (out-of-place) tensor concatenation primitive.
- concat = dnnl_concat,
- /// A summation primitive.
- sum = dnnl_sum,
- /// A convolution primitive.
- convolution = dnnl_convolution,
- /// A deconvolution primitive.
- deconvolution = dnnl_deconvolution,
- /// An element-wise primitive.
- eltwise = dnnl_eltwise,
- /// An LRN primitive.
- lrn = dnnl_lrn,
- /// A batch normalization primitive.
- batch_normalization = dnnl_batch_normalization,
- /// An inner product primitive.
- inner_product = dnnl_inner_product,
- /// An RNN primitive.
- rnn = dnnl_rnn,
- /// A binary primitive.
- binary = dnnl_binary,
- /// A matmul (matrix multiplication) primitive.
- matmul = dnnl_matmul,
- /// A resampling primitive.
- resampling = dnnl_resampling,
- /// A pooling primitive.
- pooling = dnnl_pooling,
- /// A reduction primitive.
- reduction = dnnl_reduction,
- /// A PReLU primitive.
- prelu = dnnl_prelu,
- /// A softmax primitive.
- softmax = dnnl_softmax,
- /// A layer normalization primitive.
- layer_normalization = dnnl_layer_normalization,
- /// A group normalization primitive
- group_normalization = dnnl_group_normalization,
- };
- using handle::handle;
- /// Default constructor. Constructs an empty object.
- primitive() = default;
- /// Constructs a primitive from a C API primitive descriptor.
- ///
- /// @param c_pd C API primitive descriptor.
- primitive(const_dnnl_primitive_desc_t c_pd);
- /// Constructs a primitive from a C API primitive descriptor and a cache blob.
- ///
- /// @param c_pd C API primitive descriptor.
- /// @param cache_blob Cache blob.
- primitive(const_dnnl_primitive_desc_t c_pd,
- const std::vector<uint8_t> &cache_blob);
- /// Constructs a primitive from a primitive descriptor.
- ///
- /// @param pd Primitive descriptor.
- primitive(const primitive_desc &pd);
- /// Constructs a primitive from a primitive descriptor and a cache blob.
- ///
- /// @param pd Primitive descriptor.
- /// @param cache_blob Cache blob.
- primitive(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob);
- /// Returns the C API primitive descriptor of the underlying C API
- /// primitive.
- ///
- /// @returns The underlying C API primitive descriptor.
- inline const_dnnl_primitive_desc_t get_primitive_desc() const;
- /// Returns the kind of the primitive.
- ///
- /// @returns The primitive kind.
- inline kind get_kind() const;
- /// Returns a cache blob for the primitive.
- ///
- /// @returns Vector containing the cache blob.
- ///
- /// @note The cache blob can be empty. It's the user's responsibility to
- /// check whether it's empty prior to passing it to the primitive
- /// constructor.
- inline std::vector<uint8_t> get_cache_blob() const;
- /// Executes computations specified by the primitive in a specified stream.
- ///
- /// Arguments are passed via an arguments map containing <index,
- /// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
- /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
- /// matching the one returned by
- /// primitive_desc::query_md(#query::exec_arg_md, index) unless using
- /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
- ///
- /// @param astream Stream object. The stream must belong to the same engine
- /// as the primitive.
- /// @param args Arguments map.
- void execute(const stream &astream,
- const std::unordered_map<int, memory> &args) const;
- };
- /// Maps a C++ API primitive kind onto its C API counterpart.
- ///
- /// @param akind Primitive kind expressed as a C++ API enum value.
- /// @returns The same primitive kind expressed as a C API enum value.
- inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
- const auto c_kind = static_cast<dnnl_primitive_kind_t>(akind);
- return c_kind;
- }
- const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
- // Ask the C API for the descriptor of the wrapped primitive and let
- // error::wrap_c_api() handle the returned status.
- const_dnnl_primitive_desc_t c_pd;
- const auto status = dnnl_primitive_get_primitive_desc(get(), &c_pd);
- error::wrap_c_api(
- status, "could not get a primitive descriptor from a primitive");
- return c_pd;
- }
- dnnl::primitive::kind primitive::get_kind() const {
- // TODO (Roma): the descriptor query below is only needed because
- // get_primitive_desc returns a C type.
- const_dnnl_primitive_desc_t c_pd = get_primitive_desc();
- dnnl_primitive_kind_t c_kind;
- const auto status = dnnl_primitive_desc_query(
- c_pd, dnnl_query_primitive_kind, 0, static_cast<void *>(&c_kind));
- error::wrap_c_api(status,
- "could not get a primitive kind from a primitive descriptor");
- // Translate the C API value into the C++ API enum.
- return static_cast<dnnl::primitive::kind>(c_kind);
- }
- std::vector<uint8_t> primitive::get_cache_blob() const {
- // First call passes a null buffer to obtain the blob size.
- size_t blob_size;
- const auto size_status
- = dnnl_primitive_get_cache_blob(get(), &blob_size, nullptr);
- error::wrap_c_api(
- size_status, "could not get cache blob size from a primitive");
- // Second call receives the blob into a buffer of that size.
- std::vector<uint8_t> blob(blob_size);
- const auto blob_status
- = dnnl_primitive_get_cache_blob(get(), &blob_size, blob.data());
- error::wrap_c_api(
- blob_status, "could not get a cache blob from a primitive");
- return blob;
- }
- /// @} dnnl_api_primitives_common
- /// @addtogroup dnnl_api_attributes
- ///
- /// A container for parameters that extend the behavior of primitives.
- ///
- /// Attributes can also contain Post-ops, which are computations executed
- /// after the primitive.
- ///
- /// @sa @ref dev_guide_attributes
- /// @sa @ref dev_guide_attributes_post_ops
- ///
- /// @{
- /// Scratchpad mode.
- enum class scratchpad_mode {
- /// The library manages the scratchpad allocation according to the policy
- /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
- /// [build option](@ref dev_guide_build_options) (default).
- ///
- /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
- /// scratchpad is common to all primitives to reduce the memory footprint.
- /// This configuration comes with limited thread-safety properties: namely,
- /// primitives can be created and executed in parallel but cannot migrate
- /// between threads (in other words, each primitive should be executed in
- /// the same thread it was created in).
- ///
- /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
- /// private to each primitive. The memory footprint is larger than when
- /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF`, but different primitives can be
- /// created and run concurrently (although the same primitive cannot be run
- /// concurrently from two different threads).
- library = dnnl_scratchpad_mode_library,
- /// The user manages the scratchpad allocation by querying and providing
- /// the scratchpad memory to primitives. This mode is thread-safe as long
- /// as the scratchpad buffers are not used concurrently by two primitive
- /// executions.
- user = dnnl_scratchpad_mode_user,
- };
- /// Maps a C++ API scratchpad mode onto its C API counterpart.
- ///
- /// @param mode Scratchpad mode expressed as a C++ API enum value.
- /// @returns The same scratchpad mode expressed as a C API enum value.
- inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
- const auto c_mode = static_cast<dnnl_scratchpad_mode_t>(mode);
- return c_mode;
- }
- /// Rounding mode.
- enum class rounding_mode {
- /// Rounding mode dictated by the floating-point environment.
- environment = dnnl_rounding_mode_environment,
- /// Stochastic rounding mode, where a random bias is added to the
- /// trailing mantissa bits before conversion.
- stochastic = dnnl_rounding_mode_stochastic
- };
- /// Maps a C++ API rounding mode onto its C API counterpart.
- ///
- /// @param mode Rounding mode expressed as a C++ API enum value.
- /// @returns The same rounding mode expressed as a C API enum value.
- inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) {
- const auto c_mode = static_cast<dnnl_rounding_mode_t>(mode);
- return c_mode;
- }
- /// Propagation kind.
- enum class prop_kind {
- /// Undefined propagation kind.
- undef = dnnl_prop_kind_undef,
- /// Forward data propagation (training mode). In this mode, primitives
- /// perform computations necessary for subsequent backward propagation.
- forward_training = dnnl_forward_training,
- /// Forward data propagation (inference mode). In this mode, primitives
- /// perform only computations that are necessary for inference and omit
- /// computations that are necessary only for backward propagation.
- forward_inference = dnnl_forward_inference,
- /// Forward data propagation;
- /// an alias for #dnnl::prop_kind::forward_training.
- forward = dnnl_forward,
- /// Backward propagation (with respect to all parameters).
- backward = dnnl_backward,
- /// Backward data propagation.
- backward_data = dnnl_backward_data,
- /// Backward weights propagation.
- backward_weights = dnnl_backward_weights,
- /// Backward bias propagation.
- backward_bias = dnnl_backward_bias
- };
- /// Maps a C++ API propagation kind onto its C API counterpart.
- ///
- /// @param akind Propagation kind expressed as a C++ API enum value.
- /// @returns The same propagation kind expressed as a C API enum value.
- inline dnnl_prop_kind_t convert_to_c(prop_kind akind) {
- const auto c_kind = static_cast<dnnl_prop_kind_t>(akind);
- return c_kind;
- }
- /// Kinds of algorithms.
- enum class algorithm {
- /// Undefined algorithm
- undef = dnnl_alg_kind_undef,
- /// Convolution algorithm that is automatically chosen to be either direct
- /// or Winograd
- convolution_auto = dnnl_convolution_auto,
- /// Direct convolution
- convolution_direct = dnnl_convolution_direct,
- /// Winograd convolution
- convolution_winograd = dnnl_convolution_winograd,
- /// Direct deconvolution
- deconvolution_direct = dnnl_deconvolution_direct,
- /// Winograd deconvolution
- deconvolution_winograd = dnnl_deconvolution_winograd,
- /// Elementwise: rectified linear unit (ReLU)
- eltwise_relu = dnnl_eltwise_relu,
- /// Elementwise: hyperbolic tangent non-linearity (tanh)
- eltwise_tanh = dnnl_eltwise_tanh,
- /// Elementwise: exponential linear unit (ELU)
- eltwise_elu = dnnl_eltwise_elu,
- /// Elementwise: square
- eltwise_square = dnnl_eltwise_square,
- /// Elementwise: abs
- eltwise_abs = dnnl_eltwise_abs,
- /// Elementwise: square root
- eltwise_sqrt = dnnl_eltwise_sqrt,
- /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
- eltwise_swish = dnnl_eltwise_swish,
- /// Elementwise: linear
- eltwise_linear = dnnl_eltwise_linear,
- /// Elementwise: soft_relu
- eltwise_soft_relu = dnnl_eltwise_soft_relu,
- /// Elementwise: mish
- eltwise_mish = dnnl_eltwise_mish,
- /// Elementwise: logistic
- eltwise_logistic = dnnl_eltwise_logistic,
- /// Elementwise: exponent
- eltwise_exp = dnnl_eltwise_exp,
- /// Elementwise: tanh-based gelu
- eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
- /// Elementwise: erf-based gelu
- eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
- /// Elementwise: natural logarithm
- eltwise_log = dnnl_eltwise_log,
- /// Elementwise: clip
- eltwise_clip = dnnl_eltwise_clip,
- /// Elementwise: clip version 2
- eltwise_clip_v2 = dnnl_eltwise_clip_v2,
- /// Elementwise: pow
- eltwise_pow = dnnl_eltwise_pow,
- /// Elementwise: round
- eltwise_round = dnnl_eltwise_round,
- /// Elementwise: hardswish
- eltwise_hardswish = dnnl_eltwise_hardswish,
- /// Elementwise: hardsigmoid
- eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid,
- /// Elementwise: rectified linear unit (ReLU) (dst for backward)
- eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
- /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
- eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
- /// Elementwise: exponential linear unit (ELU) (dst for backward)
- eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
- /// Elementwise: square root (dst for backward)
- eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
- /// Elementwise: logistic (dst for backward)
- eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
- /// Elementwise: exponent (dst for backward)
- eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
- /// Elementwise: clip version 2 (dst for backward)
- eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd,
- /// Local response normalization (LRN) across multiple channels
- lrn_across_channels = dnnl_lrn_across_channels,
- /// LRN within a single channel
- lrn_within_channel = dnnl_lrn_within_channel,
- /// Max pooling
- pooling_max = dnnl_pooling_max,
- /// Average pooling that includes padding
- pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
- /// Average pooling that excludes padding
- pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
- /// RNN cell
- vanilla_rnn = dnnl_vanilla_rnn,
- /// LSTM cell
- vanilla_lstm = dnnl_vanilla_lstm,
- /// GRU cell
- vanilla_gru = dnnl_vanilla_gru,
- /// GRU cell with linear before reset (LBR). Differs from the vanilla GRU
- /// in how the new memory gate is calculated:
- /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
- /// LBR GRU expects 4 bias tensors on input:
- /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
- lbr_gru = dnnl_lbr_gru,
- /// AUGRU cell
- vanilla_augru = dnnl_vanilla_augru,
- /// AUGRU cell with linear before reset
- lbr_augru = dnnl_lbr_augru,
- /// Binary add
- binary_add = dnnl_binary_add,
- /// Binary mul
- binary_mul = dnnl_binary_mul,
- /// Binary max
- binary_max = dnnl_binary_max,
- /// Binary min
- binary_min = dnnl_binary_min,
- /// Binary div
- binary_div = dnnl_binary_div,
- /// Binary sub
- binary_sub = dnnl_binary_sub,
- /// Binary greater than or equal
- binary_ge = dnnl_binary_ge,
- /// Binary greater than
- binary_gt = dnnl_binary_gt,
- /// Binary less than or equal
- binary_le = dnnl_binary_le,
- /// Binary less than
- binary_lt = dnnl_binary_lt,
- /// Binary equal
- binary_eq = dnnl_binary_eq,
- /// Binary not equal
- binary_ne = dnnl_binary_ne,
- /// Binary select
- binary_select = dnnl_binary_select,
- /// Nearest-neighbor resampling method
- resampling_nearest = dnnl_resampling_nearest,
- /// Linear (bilinear, trilinear) resampling method
- resampling_linear = dnnl_resampling_linear,
- /// Reduction using max operation
- reduction_max = dnnl_reduction_max,
- /// Reduction using min operation
- reduction_min = dnnl_reduction_min,
- /// Reduction using sum operation
- reduction_sum = dnnl_reduction_sum,
- /// Reduction using mul operation
- reduction_mul = dnnl_reduction_mul,
- /// Reduction using mean operation
- reduction_mean = dnnl_reduction_mean,
- /// Reduction using norm_lp_max operation
- reduction_norm_lp_max = dnnl_reduction_norm_lp_max,
- /// Reduction using norm_lp_sum operation
- reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum,
- /// Reduction using norm_lp_power_p_max operation
- reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max,
- /// Reduction using norm_lp_power_p_sum operation
- reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum,
- /// Softmax, numerically stable
- softmax_accurate = dnnl_softmax_accurate,
- /// LogSoftmax, numerically stable
- softmax_log = dnnl_softmax_log,
- };
- /// Maps a C++ API algorithm kind onto its C API counterpart.
- ///
- /// @param aalgorithm Algorithm kind expressed as a C++ API enum value.
- /// @returns The same algorithm kind expressed as a C API enum value.
- inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
- const auto c_alg = static_cast<dnnl_alg_kind_t>(aalgorithm);
- return c_alg;
- }
- /// @} dnnl_api_attributes
- /// @addtogroup dnnl_api_primitives_common
- /// @{
- /// Flags for normalization primitives.
- enum class normalization_flags : unsigned {
- /// Use no normalization flags. If specified, the library computes mean and
- /// variance on forward propagation for training and inference, outputs
- /// them on forward propagation for training, and computes the respective
- /// derivatives on backward propagation.
- ///
- /// @note
- /// Backward propagation of type #dnnl::prop_kind::backward_data has
- /// the same behavior as #dnnl::prop_kind::backward.
- none = dnnl_normalization_flags_none,
- /// Use global statistics. If specified, the library uses mean and
- /// variance provided by the user as an input on forward propagation and
- /// does not compute their derivatives on backward propagation. Otherwise,
- /// the library computes mean and variance on forward propagation for
- /// training and inference, outputs them on forward propagation for
- /// training, and computes the respective derivatives on backward
- /// propagation.
- use_global_stats = dnnl_use_global_stats,
- /// Use scale parameter. If specified, the user is expected to pass scale as
- /// input on forward propagation. On backward propagation of type
- /// #dnnl::prop_kind::backward, the library computes its derivative.
- use_scale = dnnl_use_scale,
- /// Use shift parameter. If specified, the user is expected to pass shift as
- /// input on forward propagation. On backward propagation of type
- /// #dnnl::prop_kind::backward, the library computes its derivative.
- use_shift = dnnl_use_shift,
- /// Fuse normalization with ReLU. On training, normalization will require
- /// the workspace to implement backward propagation. On inference, the
- /// workspace is not required and behavior is the same as when normalization
- /// is fused with ReLU using the post-ops API.
- ///
- /// @note
- /// The flag implies a negative slope of 0. On training, this is the only
- /// supported configuration. For inference, to use a non-zero negative
- /// slope, consider using @ref dev_guide_attributes_post_ops.
- fuse_norm_relu = dnnl_fuse_norm_relu,
- /// Fuse normalization with an elementwise binary Add operation
- /// followed by ReLU.
- /// During training, normalization will require a workspace to implement
- /// backward propagation. For inference, the workspace is not needed.
- /// On forward propagation, an elementwise binary Add operation is applied
- /// to the normalization results with an additional input tensor, followed
- /// by ReLU with a negative slope of 0.
- /// On backward propagation, the result of the backward ReLU operation
- /// with the input tensor and workspace from the forward pass is saved
- /// to an extra output tensor, and backward normalization is performed.
- fuse_norm_add_relu = dnnl_fuse_norm_add_relu,
- /// Use Root Mean Square (RMS) Normalization. In forward propagation,
- /// the mean is considered zero, and RMS norm is used instead of variance
- /// for scaling. Only the RMS norm is output during forward propagation for
- /// training. In backward propagation, the library calculates the derivative
- /// with respect to the RMS norm only, assuming the mean is zero.
- ///
- /// @note
- /// When used with #dnnl::normalization_flags::use_global_stats,
- /// only the RMS norm is required to be provided as input.
- rms_norm = dnnl_rms_norm,
- };
- /// Maps C++ API normalization flags onto their C API counterpart.
- ///
- /// @param flags Normalization flags expressed as a C++ API enum value.
- /// @returns The same flags expressed as a C API enum value.
- inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
- const auto c_flags = static_cast<dnnl_normalization_flags_t>(flags);
- return c_flags;
- }
- /// @} dnnl_api_primitives_common
- /// @addtogroup dnnl_api_rnn
- /// @{
- /// RNN cell flags.
- enum class rnn_flags : unsigned {
- /// Undefined RNN flags.
- undef = dnnl_rnn_flags_undef,
- /// Do not add the weights gradient to the existing diff_weights memory.
- diff_weights_overwrite = dnnl_rnn_flags_diff_weights_overwrite,
- };
- /// Maps C++ API RNN cell flags onto their C API counterpart.
- ///
- /// @param flags RNN cell flags expressed as a C++ API enum value.
- /// @returns The same flags expressed as a C API enum value.
- inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
- const auto c_flags = static_cast<dnnl_rnn_flags_t>(flags);
- return c_flags;
- }
- DNNL_DEFINE_BITMASK_OPS(normalization_flags) // NOTE(review): presumably generates bitwise operators for normalization_flags; the macro is defined outside this chunk -- confirm.
- DNNL_DEFINE_BITMASK_OPS(rnn_flags) // NOTE(review): presumably generates bitwise operators for rnn_flags; the macro is defined outside this chunk -- confirm.
- /// A direction of RNN primitive execution.
- enum class rnn_direction {
- /// Undefined RNN direction.
- undef = dnnl_rnn_direction_undef,
- /// Unidirectional execution of RNN primitive from left to right.
- unidirectional_left2right = dnnl_unidirectional_left2right,
- /// Unidirectional execution of RNN primitive from right to left.
- unidirectional_right2left = dnnl_unidirectional_right2left,
- /// Bidirectional execution of RNN primitive with concatenation of the
- /// results.
- bidirectional_concat = dnnl_bidirectional_concat,
- /// Bidirectional execution of RNN primitive with summation of the
- /// results.
- bidirectional_sum = dnnl_bidirectional_sum,
- };
- /// Maps a C++ API RNN direction onto its C API counterpart.
- ///
- /// @param dir RNN direction expressed as a C++ API enum value.
- /// @returns The same RNN direction expressed as a C API enum value.
- inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
- const auto c_dir = static_cast<dnnl_rnn_direction_t>(dir);
- return c_dir;
- }
- /// @} dnnl_api_rnn
- /// @addtogroup dnnl_api_primitives_common
- /// @{
- /// Primitive descriptor query specification.
- ///
- /// In general, queries are not used with the C++ API because most queries are
- /// implemented as class members.
- ///
- /// See @ref dnnl_query_t for more information.
- enum class query {
- /// No query
- undef = dnnl_query_undef,
- /// Execution engine
- engine = dnnl_query_engine,
- /// Primitive kind
- primitive_kind = dnnl_query_primitive_kind,
- /// Number of inputs expected
- num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
- /// Number of outputs expected
- num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,
- /// Runtime estimation (seconds), unimplemented
- time_estimate_f64 = dnnl_query_time_estimate_f64,
- /// Memory required for scratchpad (bytes)
- ///
- /// @sa @ref dev_guide_attributes_scratchpad
- memory_consumption_s64 = dnnl_query_memory_consumption_s64,
- /// Scratchpad engine
- ///
- /// Engine to be used for creating scratchpad memory
- scratchpad_engine = dnnl_query_scratchpad_engine,
- /// Reorder source engine
- reorder_src_engine = dnnl_query_reorder_src_engine,
- /// Reorder destination engine
- reorder_dst_engine = dnnl_query_reorder_dst_engine,
- /// Implementation name
- impl_info_str = dnnl_query_impl_info_str,
- /// Propagation kind
- prop_kind = dnnl_query_prop_kind,
- /// Size of cache blob ID in bytes
- cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64,
- /// Cache blob ID (pointer to array)
- cache_blob_id = dnnl_query_cache_blob_id,
- /// Strides
- strides = dnnl_query_strides,
- /// Dilations
- dilations = dnnl_query_dilations,
- /// Left padding
- padding_l = dnnl_query_padding_l,
- /// Right padding
- padding_r = dnnl_query_padding_r,
- /// Epsilon
- epsilon_f32 = dnnl_query_epsilon_f32,
- /// Flags
- flags = dnnl_query_flags,
- /// Algorithm kind
- alg_kind = dnnl_query_alg_kind,
- /// Alpha
- alpha_f32 = dnnl_query_alpha_f32,
- /// Beta
- beta_f32 = dnnl_query_beta_f32,
- /// Axis
- axis_s32 = dnnl_query_axis_s32,
- /// LRN parameter local size
- local_size_s64 = dnnl_query_local_size_s64,
- /// LRN parameter K
- k_f32 = dnnl_query_k_f32,
- /// Reduction parameter P
- p_f32 = dnnl_query_p_f32,
- /// Resampling parameter factors
- factors = dnnl_query_factors,
- /// RNN parameter cell kind
- cell_kind = dnnl_query_cell_kind,
- /// RNN parameter direction
- direction = dnnl_query_direction,
- /// RNN parameter activation kind
- activation_kind = dnnl_query_activation_kind,
- /// Pooling parameter kernel
- kernel = dnnl_query_kernel,
- /// Shuffle parameter group size
- group_size_s64 = dnnl_query_group_size_s64,
- /// Source memory desc
- src_md = dnnl_query_src_md,
- /// Source gradient (diff) memory desc
- diff_src_md = dnnl_query_diff_src_md,
- /// Weights memory desc
- weights_md = dnnl_query_weights_md,
- /// Weights gradient (diff) memory desc
- diff_weights_md = dnnl_query_diff_weights_md,
- /// Destination memory desc
- dst_md = dnnl_query_dst_md,
- /// Destination gradient (diff) memory desc
- diff_dst_md = dnnl_query_diff_dst_md,
- /// Workspace memory desc
- workspace_md = dnnl_query_workspace_md,
- /// Scratchpad memory desc
- scratchpad_md = dnnl_query_scratchpad_md,
- /// Memory desc of an execute argument
- exec_arg_md = dnnl_query_exec_arg_md,
- /// Number of dimensions
- ndims_s32 = dnnl_query_ndims_s32,
- /// Vector of dimensions
- dims = dnnl_query_dims,
- /// Data type
- data_type = dnnl_query_data_type,
- /// Submemory offset
- submemory_offset_s64 = dnnl_query_submemory_offset_s64,
- /// Vector of padded dimensions
- padded_dims = dnnl_query_padded_dims,
- /// Vector of padded offsets
- padded_offsets = dnnl_query_padded_offsets,
- /// Format kind
- format_kind = dnnl_query_format_kind,
- /// Number of innermost blocks
- inner_nblks_s32 = dnnl_query_inner_nblks_s32,
- /// Vector of sizes of the innermost blocks
- inner_blks = dnnl_query_inner_blks,
- /// Vector of logical indices of the blocks
- inner_idxs = dnnl_query_inner_idxs,
- /// Sparse encoding
- sparse_encoding = dnnl_query_sparse_encoding,
- /// Number of non-zero entries
- nnz_s64 = dnnl_query_nnz_s64,
- /// Number of buffers required for a memory descriptor
- num_handles_s32 = dnnl_query_num_handles_s32,
- };
- /// Maps a C++ API query onto its C API counterpart.
- ///
- /// @param aquery Query expressed as a C++ API enum value.
- /// @returns The same query expressed as a C API enum value.
- inline dnnl_query_t convert_to_c(query aquery) {
- const auto c_query = static_cast<dnnl_query_t>(aquery);
- return c_query;
- }
- /// @} dnnl_api_primitives_common
- /// @} dnnl_api_primitives
- /// @addtogroup dnnl_api_memory Memory
- ///
- /// A container that describes and stores data. Memory objects can contain
- /// data of various types and formats. There are two levels of abstraction:
- ///
- /// 1. **Memory descriptor** -- engine-agnostic logical description of data
- /// (number of dimensions, dimension sizes, and data type), and,
- /// optionally, the information about the physical format of data in
- /// memory. If this information is not known yet, a memory descriptor can
- /// be created with #dnnl::memory::format_tag::any. This allows
- /// compute-intensive primitives to choose the best format for
- /// computation. The user is responsible for reordering the data into the
- /// chosen format when formats do not match.
- ///
- /// A memory descriptor can be initialized either by specifying dimensions
- /// and a memory format tag or strides for each of them, or by
- /// manipulating the dnnl_memory_desc_t structure directly.
- ///
- /// @warning
- /// The latter approach requires understanding how the physical data
- /// representation is mapped to the structure and is discouraged. This
- /// topic is discussed in @ref dev_guide_understanding_memory_formats.
- ///
- /// The user can query the amount of memory required by a memory
- /// descriptor using the #dnnl::memory::desc::get_size() function. The
- /// size of data in general cannot be computed as the product of
- /// dimensions multiplied by the size of the data type. So users are
- /// required to use this function for better code portability.
- ///
- /// Two memory descriptors can be compared using the equality and
- /// inequality operators. The comparison is especially useful when
- /// checking whether it is necessary to reorder data from the user's data
- /// format to a primitive's format.
- ///
- /// 2. **Memory object** -- an engine-specific object that handles the memory
- /// buffer and its description (a memory descriptor). For the CPU engine or
- /// with USM, the memory buffer handle is simply a pointer to @c void. The
- /// memory buffer can be queried using #dnnl::memory::get_data_handle() and
- /// set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer,
- /// when used, can be queried using #dnnl::sycl_interop::get_buffer and set
- /// using #dnnl::sycl_interop::set_buffer. A memory object can also be
- /// queried for the underlying memory descriptor and for its engine using
- /// #dnnl::memory::get_desc() and #dnnl::memory::get_engine().
- ///
- /// Along with ordinary memory descriptors with all dimensions being positive,
- /// the library supports *zero-volume* memory descriptors with one or more
- /// dimensions set to zero. This is used to support the NumPy\* convention.
- /// If a zero-volume memory is passed to a primitive, the primitive typically
- /// does not perform any computations with this memory. For example:
- ///
- /// - A concatenation primitive would ignore all memory objects with zeroes in
- /// the concat dimension / axis.
- ///
- /// - A forward convolution with a source memory object with zero in the
- /// minibatch dimension would always produce a destination memory object
- /// with a zero in the minibatch dimension and perform no computations.
- ///
- /// - However, a forward convolution with a zero in one of the weights
- /// dimensions is ill-defined and is considered to be an error by the
- /// library because there is no clear definition of what the output values
- /// should be.
- ///
- /// Memory buffer of a zero-volume memory is never accessed.
- ///
- /// @{
- /// Memory object.
- ///
- /// A memory object encapsulates a handle to a memory buffer allocated on a
- /// specific engine, tensor dimensions, data type, and memory format, which is
- /// the way tensor indices map to offsets in linear memory space. Memory
- /// objects are passed to primitives during execution.
- struct memory : public handle<dnnl_memory_t> {
- using handle::handle;
- /// Integer type for representing dimension sizes and indices.
- using dim = dnnl_dim_t;
- /// Vector of dimensions. Implementations are free to force a limit on the
- /// vector's length.
- using dims = std::vector<dim>;
- /// Helper function that validates that an `std::vector` of dimensions can
- /// be safely converted to the C API array ::dnnl_dims_t. Throws if
- /// validation fails.
- ///
- /// @param v Vector of dimensions.
- /// @param min_size Minimum expected size of the vector.
- template <typename T>
- static void validate_dims(const std::vector<T> &v, int min_size = 0) {
- validate_container_size(
- v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
- }
- /// Data type specification.
- enum class data_type {
- /// Undefined data type (used for empty memory descriptors).
- undef = dnnl_data_type_undef,
- /// 4-bit float data type with 3-bit exponent and 0 bit mantissa.
- f4_e3m0 = dnnl_f4_e3m0,
- /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1 bit mantissa.
- f4_e2m1 = dnnl_f4_e2m1,
- /// [MX-compliant 8-bit scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
- e8m0 = dnnl_e8m0,
- /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
- /// with a 5-bit exponent and a 2-bit mantissa.
- f8_e5m2 = dnnl_f8_e5m2,
- /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
- /// with a 4-bit exponent and a 3-bit mantissa.
- f8_e4m3 = dnnl_f8_e4m3,
- /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format).
- f16 = dnnl_f16,
- /// non-standard
- /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format).
- bf16 = dnnl_bf16,
- /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format).
- f32 = dnnl_f32,
- /// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format).
- f64 = dnnl_f64,
- /// 32-bit signed integer.
- s32 = dnnl_s32,
- /// 8-bit signed integer.
- s8 = dnnl_s8,
- /// 8-bit unsigned integer.
- u8 = dnnl_u8,
- /// 4-bit signed integer.
- s4 = dnnl_s4,
- /// 4-bit unsigned integer.
- u4 = dnnl_u4,
- };
- /// Returns size of data type in bytes.
- /// @returns The number of bytes occupied by data type.
- static size_t data_type_size(data_type adata_type) {
- return dnnl_data_type_size(convert_to_c(adata_type));
- }
- /// Memory format kind
- enum class format_kind {
- /// Undefined memory format kind, used for empty memory descriptors.
- undef = dnnl_format_kind_undef,
- /// A special format kind that indicates that the actual format will be
- /// selected by a primitive automatically.
- any = dnnl_format_kind_any,
- /// A tensor in a generic format described by the stride and blocking
- /// values in each dimension.
- blocked = dnnl_blocked,
- /// Format kind for sparse tensors.
- sparse = dnnl_format_kind_sparse,
- /// Format kind for host scalars.
- host_scalar = dnnl_format_kind_host_scalar,
- /// A special format kind that indicates that tensor format is opaque.
- opaque = dnnl_format_kind_opaque,
- };
- /// Sparse encodings; each enumerator takes the value of the `dnnl_*` constant it is initialized from.
- /// @sa @ref dev_guide_sparsity
- enum class sparse_encoding {
- /// Undefined sparse encoding kind, used for empty memory descriptors.
- undef = dnnl_sparse_encoding_undef,
- /// Compressed Sparse Row (CSR) encoding.
- csr = dnnl_csr,
- /// An encoding used for an opaque storage schema for tensors with
- /// unstructured sparsity. A memory descriptor with the packed
- /// encoding cannot be used to create a memory object; it can only
- /// be used to create a primitive descriptor in order to query the
- /// actual memory descriptor (similar to the format tag `any`).
- packed = dnnl_packed,
- /// Coordinate Sparse (COO) encoding.
- coo = dnnl_coo,
- };
- /// Memory format tag specification.
- ///
- /// Memory format tags can be further divided into two categories:
- ///
- /// - Domain-agnostic names, i.e. names that do not depend on the tensor
- /// usage in the specific primitive. These names use letters from `a`
- /// to `l` to denote logical dimensions and form the order in which the
- /// dimensions are laid in memory. For example,
- /// #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
- /// second logical dimension (denoted as `b`) is the innermost, i.e.
- /// has stride = 1, and the first logical dimension (`a`) is laid out in
- /// memory with stride equal to the size of the second dimension. On the
- /// other hand, #dnnl::memory::format_tag::ba is the transposed version
- /// of the same tensor: the outermost dimension (`a`) becomes the
- /// innermost one.
- ///
- /// - Domain-specific names, i.e. names that make sense only in the
- /// context of a certain domain, such as CNN. These names are
- /// aliases to the corresponding domain-agnostic tags and used mostly
- /// for convenience. For example, #dnnl::memory::format_tag::nc
- /// is used to denote 2D CNN activations tensor memory format, where
- /// the channels dimension is the innermost one and the batch dimension
- /// is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
- /// an alias for #dnnl::memory::format_tag::ab, because for
- /// CNN primitives the logical dimensions of activations tensors come
- /// in order: batch, channels, spatial. In other words, batch
- /// corresponds to the first logical dimension (`a`), and channels
- /// correspond to the second one (`b`).
- ///
- /// The following domain-specific notation applies to memory format tags:
- /// - @c 'n' denotes the mini-batch dimension
- /// - @c 'c' denotes a channels dimension
- /// - When there are multiple channel dimensions (for example,
- /// in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
- /// of input and output channels
- /// - @c 'g' denotes a groups dimension for convolution weights
- /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
- /// respectively
- ///
- /// See @ref dnnl_format_tag_t for a detailed description.
- enum class format_tag {
- /// Undefined memory format tag
- undef = dnnl_format_tag_undef,
- /// Placeholder memory format tag. Used to instruct the primitive to
- /// select a format automatically.
- any = dnnl_format_tag_any,
- /// plain 1D tensor
- a = dnnl_a,
- /// plain 2D tensor
- ab = dnnl_ab,
- /// permuted 2D tensor
- ba = dnnl_ba,
- /// plain 3D tensor
- abc = dnnl_abc,
- /// permuted 3D tensor
- acb = dnnl_acb,
- /// permuted 3D tensor
- bac = dnnl_bac,
- /// permuted 3D tensor
- bca = dnnl_bca,
- /// permuted 3D tensor
- cba = dnnl_cba,
- /// plain 4D tensor
- abcd = dnnl_abcd,
- /// permuted 4D tensor
- abdc = dnnl_abdc,
- /// permuted 4D tensor
- acbd = dnnl_acbd,
- /// permuted 4D tensor
- acdb = dnnl_acdb,
- /// permuted 4D tensor
- adbc = dnnl_adbc,
- /// permuted 4D tensor
- bacd = dnnl_bacd,
- /// permuted 4D tensor
- bcda = dnnl_bcda,
- /// permuted 4D tensor
- cdba = dnnl_cdba,
- /// permuted 4D tensor
- dcab = dnnl_dcab,
- /// plain 5D tensor
- abcde = dnnl_abcde,
- /// permuted 5D tensor
- abdec = dnnl_abdec,
- /// permuted 5D tensor
- acbde = dnnl_acbde,
- /// permuted 5D tensor
- acdeb = dnnl_acdeb,
- /// permuted 5D tensor
- bacde = dnnl_bacde,
- /// permuted 5D tensor
- bcdea = dnnl_bcdea,
- /// permuted 5D tensor
- cdeba = dnnl_cdeba,
- /// permuted 5D tensor
- decab = dnnl_decab,
- /// permuted 5D tensor
- abced = dnnl_abced,
- /// plain 6D tensor
- abcdef = dnnl_abcdef,
- /// permuted 6D tensor
- abdfce = dnnl_abdfce,
- /// permuted 6D tensor
- acbdef = dnnl_acbdef,
- /// permuted 6D tensor
- abdefc = dnnl_abdefc,
- /// permuted 6D tensor
- defcab = dnnl_defcab,
- /// permuted 6D tensor
- abcdfe = dnnl_abcdfe,
- /// plain 7D tensor
- abcdefg = dnnl_abcdefg,
- /// permuted 7D tensor
- abcdegf = dnnl_abcdegf,
- /// plain 8D tensor
- abcdefgh = dnnl_abcdefgh,
- /// permuted 8D tensor
- abcdefhg = dnnl_abcdefhg,
- /// plain 9D tensor
- abcdefghi = dnnl_abcdefghi,
- /// permuted 9D tensor
- abcdefgih = dnnl_abcdefgih,
- /// plain 10D tensor
- abcdefghij = dnnl_abcdefghij,
- /// permuted 10D tensor
- abcdefghji = dnnl_abcdefghji,
- /// plain 11D tensor
- abcdefghijk = dnnl_abcdefghijk,
- /// permuted 11D tensor
- abcdefghikj = dnnl_abcdefghikj,
- /// plain 12D tensor
- abcdefghijkl = dnnl_abcdefghijkl,
- /// permuted 12D tensor
- abcdefghijlk = dnnl_abcdefghijlk,
- /// 1D tensor; an alias for #dnnl::memory::format_tag::a
- x = a,
- /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
- nc = ab,
- /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
- cn = ba,
- /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
- tn = ab,
- /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
- nt = ba,
- /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
- ncw = abc,
- /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
- nwc = acb,
- /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
- nchw = abcd,
- /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
- nhwc = acdb,
- /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
- chwn = bcda,
- /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
- ncdhw = abcde,
- /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
- ndhwc = acdeb,
- /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
- oi = ab,
- /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
- io = ba,
- /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
- oiw = abc,
- /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
- owi = acb,
- /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
- wio = cba,
- /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
- iwo = bca,
- /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
- oihw = abcd,
- /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
- hwio = cdba,
- /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
- ohwi = acdb,
- /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
- ihwo = bcda,
- /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
- iohw = bacd,
- /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
- oidhw = abcde,
- /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
- dhwio = cdeba,
- /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
- odhwi = acdeb,
- /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde
- iodhw = bacde,
- /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
- idhwo = bcdea,
- /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
- goiw = abcd,
- /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc
- gowi = abdc,
- /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
- wigo = dcab,
- /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec
- gohwi = abdec,
- /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
- goihw = abcde,
- /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
- hwigo = decab,
- /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
- giohw = acbde,
- /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
- goidhw = abcdef,
- /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
- giodhw = acbdef,
- /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc
- godhwi = abdefc,
- /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
- dhwigo = defcab,
- /// 3D RNN data tensor in the format (seq_length, batch, input
- /// channels); an alias for #dnnl::memory::format_tag::abc.
- tnc = abc,
- /// 3D RNN data tensor in the format (batch, seq_length, input
- /// channels); an alias for #dnnl::memory::format_tag::bac.
- ntc = bac,
- /// 4D RNN states tensor in the format (num_layers, num_directions,
- /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd.
- ldnc = abcd,
- /// 5D RNN weights tensor in the format (num_layers, num_directions,
- /// input_channels, num_gates, output_channels);
- /// an alias for #dnnl::memory::format_tag::abcde.
- ///
- /// - For LSTM cells, the gates order is input, forget, candidate
- /// and output gate.
- /// - For GRU cells, the gates order is update, reset and output gate.
- ldigo = abcde,
- /// 5D RNN weights tensor in the format (num_layers, num_directions,
- /// num_gates, output_channels, input_channels);
- /// an alias for #dnnl::memory::format_tag::abdec.
- ///
- /// - For LSTM cells, the gates order is input, forget, candidate
- /// and output gate.
- /// - For GRU cells, the gates order is update, reset and output gate.
- ldgoi = abdec,
- /// 4D LSTM projection tensor in the format (num_layers, num_directions,
- /// num_channels_in_hidden_state, num_channels_in_recurrent_projection);
- /// an alias for #dnnl::memory::format_tag::abcd.
- ldio = abcd,
- /// 4D LSTM projection tensor in the format (num_layers, num_directions,
- /// num_channels_in_recurrent_projection, num_channels_in_hidden_state);
- /// an alias for #dnnl::memory::format_tag::abdc.
- ldoi = abdc,
- /// 4D RNN bias tensor in the format (num_layers, num_directions,
- /// num_gates, output_channels);
- /// an alias for #dnnl::memory::format_tag::abcd.
- ///
- /// - For LSTM cells, the gates order is input, forget, candidate
- /// and output gate.
- /// - For GRU cells, the gates order is update, reset and output gate.
- ldgo = abcd,
- // Opaque blocked formats
- AB16b16a = dnnl_AB16b16a,
- AB16b32a = dnnl_AB16b32a,
- AB16b48a = dnnl_AB16b48a,
- AB16b64a = dnnl_AB16b64a,
- AB8b16a2b = dnnl_AB8b16a2b,
- AB8b32a2b = dnnl_AB8b32a2b,
- AB8b64a2b = dnnl_AB8b64a2b,
- AB4b16a4b = dnnl_AB4b16a4b,
- AB4b32a4b = dnnl_AB4b32a4b,
- AB4b64a4b = dnnl_AB4b64a4b,
- AB16b16a4b = dnnl_AB16b16a4b,
- AB16b32a4b = dnnl_AB16b32a4b,
- AB16b48a4b = dnnl_AB16b48a4b,
- AB16b64a4b = dnnl_AB16b64a4b,
- AB16b16a2b = dnnl_AB16b16a2b,
- AB16b32a2b = dnnl_AB16b32a2b,
- AB16b48a2b = dnnl_AB16b48a2b,
- AB16b64a2b = dnnl_AB16b64a2b,
- Ab4a = dnnl_Ab4a,
- Ab8a = dnnl_Ab8a,
- Ab32a = dnnl_Ab32a,
- Abc16a = dnnl_Abc16a,
- ABc16a16b = dnnl_ABc16a16b,
- ABc4a4b = dnnl_ABc4a4b,
- aBc16b = dnnl_aBc16b,
- aBc32b = dnnl_aBc32b,
- ABc16b16a = dnnl_ABc16b16a,
- AcB16b16a = dnnl_AcB16b16a,
- ABc16b32a = dnnl_ABc16b32a,
- AcB16b32a = dnnl_AcB16b32a,
- ABc16b48a = dnnl_ABc16b48a,
- AcB16b48a = dnnl_AcB16b48a,
- ABc16b64a = dnnl_ABc16b64a,
- AcB16b64a = dnnl_AcB16b64a,
- Abc4a = dnnl_Abc4a,
- aBc4b = dnnl_aBc4b,
- ABc4b16a4b = dnnl_ABc4b16a4b,
- AcB4b16a4b = dnnl_AcB4b16a4b,
- ABc4b32a4b = dnnl_ABc4b32a4b,
- AcB4b32a4b = dnnl_AcB4b32a4b,
- ABc4b64a4b = dnnl_ABc4b64a4b,
- AcB4b64a4b = dnnl_AcB4b64a4b,
- ABc2b8a4b = dnnl_ABc2b8a4b,
- ABc16a16b2a = dnnl_ABc16a16b2a,
- ABc16b16a4b = dnnl_ABc16b16a4b,
- ABc16b32a4b = dnnl_ABc16b32a4b,
- ABc16b48a4b = dnnl_ABc16b48a4b,
- ABc16b64a4b = dnnl_ABc16b64a4b,
- ABc16b16a2b = dnnl_ABc16b16a2b,
- ABc16b32a2b = dnnl_ABc16b32a2b,
- ABc16b48a2b = dnnl_ABc16b48a2b,
- ABc16b64a2b = dnnl_ABc16b64a2b,
- ABc4b4a = dnnl_ABc4b4a,
- ABc8a16b2a = dnnl_ABc8a16b2a,
- ABc8a8b = dnnl_ABc8a8b,
- ABc8a4b = dnnl_ABc8a4b,
- aBc8b = dnnl_aBc8b,
- ABc8b16a2b = dnnl_ABc8b16a2b,
- AcB8b16a2b = dnnl_AcB8b16a2b,
- ABc8b32a2b = dnnl_ABc8b32a2b,
- AcB8b32a2b = dnnl_AcB8b32a2b,
- ABc8b64a2b = dnnl_ABc8b64a2b,
- AcB8b64a2b = dnnl_AcB8b64a2b,
- ABc8b8a = dnnl_ABc8b8a,
- AcB8b8a = dnnl_AcB8b8a,
- Abcd8a = dnnl_Abcd8a,
- Abcd16a = dnnl_Abcd16a,
- Abcd32a = dnnl_Abcd32a,
- ABcd16a16b = dnnl_ABcd16a16b,
- aBcd16b = dnnl_aBcd16b,
- aBcd32b = dnnl_aBcd32b,
- ABcd16b16a = dnnl_ABcd16b16a,
- AcdB16b16a = dnnl_AcdB16b16a,
- ABcd16b32a = dnnl_ABcd16b32a,
- AcdB16b32a = dnnl_AcdB16b32a,
- ABcd16b48a = dnnl_ABcd16b48a,
- AcdB16b48a = dnnl_AcdB16b48a,
- ABcd16b64a = dnnl_ABcd16b64a,
- AcdB16b64a = dnnl_AcdB16b64a,
- aBCd16b16c = dnnl_aBCd16b16c,
- aBCd16c16b = dnnl_aBCd16c16b,
- Abcd4a = dnnl_Abcd4a,
- aBcd4b = dnnl_aBcd4b,
- ABcd4b16a4b = dnnl_ABcd4b16a4b,
- AcdB4b16a4b = dnnl_AcdB4b16a4b,
- ABcd4b32a4b = dnnl_ABcd4b32a4b,
- AcdB4b32a4b = dnnl_AcdB4b32a4b,
- ABcd4b64a4b = dnnl_ABcd4b64a4b,
- AcdB4b64a4b = dnnl_AcdB4b64a4b,
- ABcd2b8a4b = dnnl_ABcd2b8a4b,
- ABcd4b4a = dnnl_ABcd4b4a,
- ABcd4a4b = dnnl_ABcd4a4b,
- aBCd4c16b4c = dnnl_aBCd4c16b4c,
- aBCd2c8b4c = dnnl_aBCd2c8b4c,
- ABcd16a16b2a = dnnl_ABcd16a16b2a,
- ABcd16b16a4b = dnnl_ABcd16b16a4b,
- ABcd16b32a4b = dnnl_ABcd16b32a4b,
- ABcd16b48a4b = dnnl_ABcd16b48a4b,
- ABcd16b64a4b = dnnl_ABcd16b64a4b,
- ABcd16b16a2b = dnnl_ABcd16b16a2b,
- ABcd16b32a2b = dnnl_ABcd16b32a2b,
- ABcd16b48a2b = dnnl_ABcd16b48a2b,
- ABcd16b64a2b = dnnl_ABcd16b64a2b,
- aBCd16b16c2b = dnnl_aBCd16b16c2b,
- aBCd16c16b4c = dnnl_aBCd16c16b4c,
- aBCd16c16b2c = dnnl_aBCd16c16b2c,
- aBCd4c4b = dnnl_aBCd4c4b,
- aBCd4b4c = dnnl_aBCd4b4c,
- ABcd8a16b2a = dnnl_ABcd8a16b2a,
- ABcd8a8b = dnnl_ABcd8a8b,
- ABcd8a4b = dnnl_ABcd8a4b,
- ABcd8a2b = dnnl_ABcd8a2b,
- /// 4D tensor blocked by 2nd dimension with block size 8
- aBcd8b = dnnl_aBcd8b,
- ABcd8b16a2b = dnnl_ABcd8b16a2b,
- AcdB8b16a2b = dnnl_AcdB8b16a2b,
- ABcd8b32a2b = dnnl_ABcd8b32a2b,
- AcdB8b32a2b = dnnl_AcdB8b32a2b,
- ABcd8b64a2b = dnnl_ABcd8b64a2b,
- AcdB8b64a2b = dnnl_AcdB8b64a2b,
- aBCd8b16c2b = dnnl_aBCd8b16c2b,
- /// 4D tensor blocked by 1st and 2nd dimension with block size 8
- ABcd8b8a = dnnl_ABcd8b8a,
- AcdB8b8a = dnnl_AcdB8b8a,
- aBCd8b8c = dnnl_aBCd8b8c,
- aBCd8b4c = dnnl_aBCd8b4c,
- aBCd8c16b2c = dnnl_aBCd8c16b2c,
- aBCd8c8b = dnnl_aBCd8c8b,
- Abcde16a = dnnl_Abcde16a,
- Abcde32a = dnnl_Abcde32a,
- ABcde16a16b = dnnl_ABcde16a16b,
- aBcde16b = dnnl_aBcde16b,
- aBcde32b = dnnl_aBcde32b,
- ABcde16b16a = dnnl_ABcde16b16a,
- AcdeB16b16a = dnnl_AcdeB16b16a,
- ABcde16b32a = dnnl_ABcde16b32a,
- AcdeB16b32a = dnnl_AcdeB16b32a,
- ABcde16b48a = dnnl_ABcde16b48a,
- AcdeB16b48a = dnnl_AcdeB16b48a,
- ABcde16b64a = dnnl_ABcde16b64a,
- AcdeB16b64a = dnnl_AcdeB16b64a,
- aBCde16b16c = dnnl_aBCde16b16c,
- aBCde16c16b = dnnl_aBCde16c16b,
- aBCde2c8b4c = dnnl_aBCde2c8b4c,
- Abcde4a = dnnl_Abcde4a,
- aBcde4b = dnnl_aBcde4b,
- ABcde4b4a = dnnl_ABcde4b4a,
- ABcde4a4b = dnnl_ABcde4a4b,
- aBCde4b4c = dnnl_aBCde4b4c,
- aBCde4c16b4c = dnnl_aBCde4c16b4c,
- aBCde16b16c2b = dnnl_aBCde16b16c2b,
- aBCde16c16b4c = dnnl_aBCde16c16b4c,
- aBCde16c16b2c = dnnl_aBCde16c16b2c,
- aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
- aBCde4c4b = dnnl_aBCde4c4b,
- Abcde8a = dnnl_Abcde8a,
- ABcde8a8b = dnnl_ABcde8a8b,
- ABcde8a4b = dnnl_ABcde8a4b,
- aBcde8b = dnnl_aBcde8b,
- ABcde8b16a2b = dnnl_ABcde8b16a2b,
- AcdeB8b16a2b = dnnl_AcdeB8b16a2b,
- ABcde8b32a2b = dnnl_ABcde8b32a2b,
- AcdeB8b32a2b = dnnl_AcdeB8b32a2b,
- ABcde8b64a2b = dnnl_ABcde8b64a2b,
- AcdeB8b64a2b = dnnl_AcdeB8b64a2b,
- ABcde4b16a4b = dnnl_ABcde4b16a4b,
- AcdeB4b16a4b = dnnl_AcdeB4b16a4b,
- ABcde4b32a4b = dnnl_ABcde4b32a4b,
- AcdeB4b32a4b = dnnl_AcdeB4b32a4b,
- ABcde4b64a4b = dnnl_ABcde4b64a4b,
- AcdeB4b64a4b = dnnl_AcdeB4b64a4b,
- ABcde16b16a4b = dnnl_ABcde16b16a4b,
- ABcde16b32a4b = dnnl_ABcde16b32a4b,
- ABcde16b48a4b = dnnl_ABcde16b48a4b,
- ABcde16b64a4b = dnnl_ABcde16b64a4b,
- ABcde16b16a2b = dnnl_ABcde16b16a2b,
- ABcde16b32a2b = dnnl_ABcde16b32a2b,
- ABcde16b48a2b = dnnl_ABcde16b48a2b,
- ABcde16b64a2b = dnnl_ABcde16b64a2b,
- ABcde2b8a4b = dnnl_ABcde2b8a4b,
- aBCde8b16c2b = dnnl_aBCde8b16c2b,
- ABcde8b8a = dnnl_ABcde8b8a,
- AcdeB8b8a = dnnl_AcdeB8b8a,
- aBCde8b8c = dnnl_aBCde8b8c,
- aBCde8b4c = dnnl_aBCde8b4c,
- ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
- ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
- aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
- aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
- aBCde8c16b2c = dnnl_aBCde8c16b2c,
- aBCde8c8b = dnnl_aBCde8c8b,
- aBcdef16b = dnnl_aBcdef16b,
- aBCdef16b16c = dnnl_aBCdef16b16c,
- aBCdef16c16b = dnnl_aBCdef16c16b,
- aBcdef4b = dnnl_aBcdef4b,
- aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
- aBCdef4c4b = dnnl_aBCdef4c4b,
- aBCdef4b4c = dnnl_aBCdef4b4c,
- aBCdef8b8c = dnnl_aBCdef8b8c,
- aBCdef8b4c = dnnl_aBCdef8b4c,
- aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
- aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
- aBCdef8c8b = dnnl_aBCdef8c8b,
- aBdc16b = dnnl_aBdc16b,
- aBdc4b = dnnl_aBdc4b,
- aBdc8b = dnnl_aBdc8b,
- aBdC8b2c = dnnl_aBdC8b2c,
- aBdC8b4c = dnnl_aBdC8b4c,
- aBdec16b = dnnl_aBdec16b,
- aBdec4b = dnnl_aBdec4b,
- aBdec8b = dnnl_aBdec8b,
- aBdeC8b2c = dnnl_aBdeC8b2c,
- aBdeC8b4c = dnnl_aBdeC8b4c,
- aBdefc16b = dnnl_aBdefc16b,
- aCBdef16c16b = dnnl_aCBdef16c16b,
- aCBdef8b8c = dnnl_aCBdef8b8c,
- aCBdef16b16c = dnnl_aCBdef16b16c,
- aBdefc4b = dnnl_aBdefc4b,
- aBdefc8b = dnnl_aBdefc8b,
- aBdefC8b2c = dnnl_aBdefC8b2c,
- aBdefC8b4c = dnnl_aBdefC8b4c,
- Acb16a = dnnl_Acb16a,
- Acb4a = dnnl_Acb4a,
- Acb8a = dnnl_Acb8a,
- AcB8a2b = dnnl_AcB8a2b,
- AcB8a4b = dnnl_AcB8a4b,
- aCBd8b8c = dnnl_aCBd8b8c,
- aCBd16b16c = dnnl_aCBd16b16c,
- aCBd16c16b = dnnl_aCBd16c16b,
- aCBde8b8c = dnnl_aCBde8b8c,
- aCBde16b16c = dnnl_aCBde16b16c,
- aCBde16c16b = dnnl_aCBde16c16b,
- Acdb16a = dnnl_Acdb16a,
- Acdb4a = dnnl_Acdb4a,
- Acdb8a = dnnl_Acdb8a,
- AcdB8a2b = dnnl_AcdB8a2b,
- AcdB8a4b = dnnl_AcdB8a4b,
- Acdeb16a = dnnl_Acdeb16a,
- Acdeb4a = dnnl_Acdeb4a,
- Acdeb8a = dnnl_Acdeb8a,
- AcdeB8a2b = dnnl_AcdeB8a2b,
- AcdeB8a4b = dnnl_AcdeB8a4b,
- BAc8a8b = dnnl_BAc8a8b,
- BAc16a16b = dnnl_BAc16a16b,
- BAc16b16a = dnnl_BAc16b16a,
- BAcd8a8b = dnnl_BAcd8a8b,
- BAcd16a16b = dnnl_BAcd16a16b,
- BAcd16b16a = dnnl_BAcd16b16a,
- ABcd32a32b = dnnl_ABcd32a32b,
- BAcde16b16a = dnnl_BAcde16b16a,
- BAcde8a8b = dnnl_BAcde8a8b,
- BAcde16a16b = dnnl_BAcde16a16b,
- aBdec32b = dnnl_aBdec32b,
- Abcdef16a = dnnl_Abcdef16a,
- Abcdef32a = dnnl_Abcdef32a,
- Acdb32a = dnnl_Acdb32a,
- aBCd2b4c2b = dnnl_aBCd2b4c2b,
- aBCde2b4c2b = dnnl_aBCde2b4c2b,
- aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
- aBCd2c4b2c = dnnl_aBCd2c4b2c,
- aBCde2c4b2c = dnnl_aBCde2c4b2c,
- aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
- aBCd4b8c2b = dnnl_aBCd4b8c2b,
- aBCde4b8c2b = dnnl_aBCde4b8c2b,
- aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
- aBCd4c8b2c = dnnl_aBCd4c8b2c,
- aBCde4c8b2c = dnnl_aBCde4c8b2c,
- aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
- AB32a32b8a4b = dnnl_AB32a32b8a4b,
- AB32a32b8a2b = dnnl_AB32a32b8a2b,
- AB8a4b = dnnl_AB8a4b,
- AB8a2b = dnnl_AB8a2b,
- abDc16d = dnnl_abDc16d,
- abDc32d = dnnl_abDc32d,
- abDC16d4c = dnnl_abDC16d4c,
- abDC32d4c = dnnl_abDC32d4c,
- abCd32c = dnnl_abCd32c,
- abdEc16e = dnnl_abdEc16e,
- abdEc32e = dnnl_abdEc32e,
- abdEC16e4c = dnnl_abdEC16e4c,
- abdEC32e2c = dnnl_abdEC32e2c,
- abdEC32e4c = dnnl_abdEC32e4c,
- abdCe16c = dnnl_abdCe16c,
- abdCe32c = dnnl_abdCe32c,
- abdCE32c2e = dnnl_abdCE32c2e,
- aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
- aBdC16b4c = dnnl_aBdC16b4c,
- aBdeC16b4c = dnnl_aBdeC16b4c,
- AcB16a4b = dnnl_AcB16a4b,
- AcdB16a2b = dnnl_AcdB16a2b,
- aBdefC16b4c = dnnl_aBdefC16b4c,
- AcdeB16a4b = dnnl_AcdeB16a4b,
- Acb32a = dnnl_Acb32a,
- AcB32a2b = dnnl_AcB32a2b,
- AcB32a4b = dnnl_AcB32a4b,
- Acb48a = dnnl_Acb48a,
- AcB48a2b = dnnl_AcB48a2b,
- AcB48a4b = dnnl_AcB48a4b,
- Acb64a = dnnl_Acb64a,
- AcB64a2b = dnnl_AcB64a2b,
- AcB64a4b = dnnl_AcB64a4b,
- cBa2b = dnnl_cBa2b,
- cBa4b = dnnl_cBa4b,
- aBdc32b = dnnl_aBdc32b,
- aBdC32b2c = dnnl_aBdC32b2c,
- aBdC32b4c = dnnl_aBdC32b4c,
- aBdc48b = dnnl_aBdc48b,
- aBdC48b2c = dnnl_aBdC48b2c,
- aBdC48b4c = dnnl_aBdC48b4c,
- aBdc64b = dnnl_aBdc64b,
- aBdC64b2c = dnnl_aBdC64b2c,
- aBdC64b4c = dnnl_aBdC64b4c,
- adcb = dnnl_adcb,
- adCb2c = dnnl_adCb2c,
- adCb4c = dnnl_adCb4c,
- AcdB32a2b = dnnl_AcdB32a2b,
- AcdB32a4b = dnnl_AcdB32a4b,
- Acdb48a = dnnl_Acdb48a,
- AcdB48a2b = dnnl_AcdB48a2b,
- AcdB48a4b = dnnl_AcdB48a4b,
- Acdb64a = dnnl_Acdb64a,
- AcdB64a2b = dnnl_AcdB64a2b,
- AcdB64a4b = dnnl_AcdB64a4b,
- cdBa2b = dnnl_cdBa2b,
- cdBa4b = dnnl_cdBa4b,
- aBdeC32b2c = dnnl_aBdeC32b2c,
- aBdeC32b4c = dnnl_aBdeC32b4c,
- aBdec48b = dnnl_aBdec48b,
- aBdeC48b2c = dnnl_aBdeC48b2c,
- aBdeC48b4c = dnnl_aBdeC48b4c,
- aBdec64b = dnnl_aBdec64b,
- aBdeC64b2c = dnnl_aBdeC64b2c,
- aBdeC64b4c = dnnl_aBdeC64b4c,
- adecb = dnnl_adecb,
- adeCb2c = dnnl_adeCb2c,
- adeCb4c = dnnl_adeCb4c,
- Acdeb32a = dnnl_Acdeb32a,
- AcdeB32a2b = dnnl_AcdeB32a2b,
- AcdeB32a4b = dnnl_AcdeB32a4b,
- Acdeb48a = dnnl_Acdeb48a,
- AcdeB48a2b = dnnl_AcdeB48a2b,
- AcdeB48a4b = dnnl_AcdeB48a4b,
- Acdeb64a = dnnl_Acdeb64a,
- AcdeB64a2b = dnnl_AcdeB64a2b,
- AcdeB64a4b = dnnl_AcdeB64a4b,
- cdeBa2b = dnnl_cdeBa2b,
- cdeBa4b = dnnl_cdeBa4b,
- aBdefc32b = dnnl_aBdefc32b,
- aBdefC32b2c = dnnl_aBdefC32b2c,
- aBdefC32b4c = dnnl_aBdefC32b4c,
- aBdefc48b = dnnl_aBdefc48b,
- aBdefC48b2c = dnnl_aBdefC48b2c,
- aBdefC48b4c = dnnl_aBdefC48b4c,
- aBdefc64b = dnnl_aBdefc64b,
- aBdefC64b2c = dnnl_aBdefC64b2c,
- aBdefC64b4c = dnnl_aBdefC64b4c,
- adefcb = dnnl_adefcb,
- adefCb2c = dnnl_adefCb2c,
- adefCb4c = dnnl_adefCb4c,
- ABc32a32b = dnnl_ABc32a32b,
- BAc8a16b2a = dnnl_BAc8a16b2a,
- BAcd8a16b2a = dnnl_BAcd8a16b2a,
- ABcde8a16b2a = dnnl_ABcde8a16b2a,
- aCBd8b16c2b = dnnl_aCBd8b16c2b,
- BAcde8a16b2a = dnnl_BAcde8a16b2a,
- aCBde8b16c2b = dnnl_aCBde8b16c2b,
- ABcde32a32b = dnnl_ABcde32a32b,
- ABc4a8b8a4b = dnnl_ABc4a8b8a4b,
- ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b,
- BAc4b8a8b4a = dnnl_BAc4b8a8b4a,
- BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a,
- BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a,
- aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c,
- aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c,
- aBCdef8b16c2b = dnnl_aBCdef8b16c2b,
- aCBdef8b16c2b = dnnl_aCBdef8b16c2b,
- aBdC16b2c = dnnl_aBdC16b2c,
- aBdeC16b2c = dnnl_aBdeC16b2c,
- aBdefC16b2c = dnnl_aBdefC16b2c,
- aBedc16b = dnnl_aBedc16b,
- AcB16a2b = dnnl_AcB16a2b,
- AcdB16a4b = dnnl_AcdB16a4b,
- AcdeB16a2b = dnnl_AcdeB16a2b,
- Adcb16a = dnnl_Adcb16a,
- aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b,
- aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b,
- aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b,
- ABc32a16b = dnnl_ABc32a16b,
- ABcd16a32b = dnnl_ABcd16a32b,
- ABcd32a16b = dnnl_ABcd32a16b,
- ABcde32a16b = dnnl_ABcde32a16b,
- AB48a16b = dnnl_AB48a16b,
- AB48a32b = dnnl_AB48a32b,
- ABc40a16b = dnnl_ABc40a16b,
- ABc40a32b = dnnl_ABc40a32b,
- aBC48b16c = dnnl_aBC48b16c,
- aBC48b32c = dnnl_aBC48b32c,
- ABcd40a16b = dnnl_ABcd40a16b,
- ABcd40a32b = dnnl_ABcd40a32b,
- BA16a16b = dnnl_BA16a16b,
- BA16a32b = dnnl_BA16a32b,
- BA16a48b = dnnl_BA16a48b,
- BA16a64b = dnnl_BA16a64b,
- BA16a16b2a = dnnl_BA16a16b2a,
- BA16a32b2a = dnnl_BA16a32b2a,
- BA16a48b2a = dnnl_BA16a48b2a,
- BA16a64b2a = dnnl_BA16a64b2a,
- BA16a16b4a = dnnl_BA16a16b4a,
- BA16a32b4a = dnnl_BA16a32b4a,
- BA16a48b4a = dnnl_BA16a48b4a,
- BA16a64b4a = dnnl_BA16a64b4a,
- BA24b8a = dnnl_BA24b8a,
- aCB24c8b = dnnl_aCB24c8b,
- abDC24d8c = dnnl_abDC24d8c,
- decbA16a = dnnl_decbA16a,
- decbA8a = dnnl_decbA8a,
- defcbA16a = dnnl_defcbA16a,
- defcbA8a = dnnl_defcbA8a,
- aCB16b16c = dnnl_aCB16b16c,
- aCB16b32c = dnnl_aCB16b32c,
- aCB16b48c = dnnl_aCB16b48c,
- aCB16b64c = dnnl_aCB16b64c,
- aCB16b16c2b = dnnl_aCB16b16c2b,
- aCB16b32c2b = dnnl_aCB16b32c2b,
- aCB16b48c2b = dnnl_aCB16b48c2b,
- aCB16b64c2b = dnnl_aCB16b64c2b,
- aCB16b16c4b = dnnl_aCB16b16c4b,
- aCB16b32c4b = dnnl_aCB16b32c4b,
- aCB16b48c4b = dnnl_aCB16b48c4b,
- aCB16b64c4b = dnnl_aCB16b64c4b,
- Acb24a = dnnl_Acb24a,
- Acdb24a = dnnl_Acdb24a,
- Acdeb24a = dnnl_Acdeb24a,
- aBdc24b = dnnl_aBdc24b,
- aBdec24b = dnnl_aBdec24b,
- aBdefc24b = dnnl_aBdefc24b,
- AcB24a2b = dnnl_AcB24a2b,
- AcdB24a2b = dnnl_AcdB24a2b,
- AcdeB24a2b = dnnl_AcdeB24a2b,
- aBdC24b2c = dnnl_aBdC24b2c,
- aBdeC24b2c = dnnl_aBdeC24b2c,
- aBdefC24b2c = dnnl_aBdefC24b2c,
- AcB24a4b = dnnl_AcB24a4b,
- AcdB24a4b = dnnl_AcdB24a4b,
- AcdeB24a4b = dnnl_AcdeB24a4b,
- aBdC24b4c = dnnl_aBdC24b4c,
- aBdeC24b4c = dnnl_aBdeC24b4c,
- aBdefC24b4c = dnnl_aBdefC24b4c,
- AB8b32a = dnnl_AB8b32a,
- ABc8b32a = dnnl_ABc8b32a,
- AcB8b32a = dnnl_AcB8b32a,
- ABcd8b32a = dnnl_ABcd8b32a,
- AcdB8b32a = dnnl_AcdB8b32a,
- ABcde8b32a = dnnl_ABcde8b32a,
- AcdeB8b32a = dnnl_AcdeB8b32a,
- AB8b24a = dnnl_AB8b24a,
- ABc8b24a = dnnl_ABc8b24a,
- AcB8b24a = dnnl_AcB8b24a,
- ABcd8b24a = dnnl_ABcd8b24a,
- AcdB8b24a = dnnl_AcdB8b24a,
- ABcde8b24a = dnnl_ABcde8b24a,
- AcdeB8b24a = dnnl_AcdeB8b24a,
- AB8b16a = dnnl_AB8b16a,
- ABc8b16a = dnnl_ABc8b16a,
- AcB8b16a = dnnl_AcB8b16a,
- ABcd8b16a = dnnl_ABcd8b16a,
- AcdB8b16a = dnnl_AcdB8b16a,
- ABcde8b16a = dnnl_ABcde8b16a,
- AcdeB8b16a = dnnl_AcdeB8b16a,
- AB8b8a = dnnl_AB8b8a,
- abDC8d8c = dnnl_abDC8d8c,
- abDC16d8c = dnnl_abDC16d8c,
- aCB8c8b = dnnl_aCB8c8b,
- aCB16c8b = dnnl_aCB16c8b,
- BA8b8a = dnnl_BA8b8a,
- BA16b8a = dnnl_BA16b8a,
- AB2a4b = dnnl_AB2a4b,
- format_tag_last = dnnl_format_tag_last,
- nCdhw16c = dnnl_nCdhw16c,
- nCdhw4c = dnnl_nCdhw4c,
- nCdhw8c = dnnl_nCdhw8c,
- nChw16c = dnnl_nChw16c,
- nChw4c = dnnl_nChw4c,
- nChw8c = dnnl_nChw8c,
- nCw16c = dnnl_nCw16c,
- nCw4c = dnnl_nCw4c,
- nCw8c = dnnl_nCw8c,
- NCw16n16c = dnnl_NCw16n16c,
- NChw16n16c = dnnl_NChw16n16c,
- NCdhw16n16c = dnnl_NCdhw16n16c,
- NCdhw32n32c = dnnl_NCdhw32n32c,
- NChw32n32c = dnnl_NChw32n32c,
- IOhw16i16o = dnnl_IOhw16i16o,
- OI16i16o = dnnl_OI16i16o,
- OI16i32o = dnnl_OI16i32o,
- OI16i48o = dnnl_OI16i48o,
- OI16i64o = dnnl_OI16i64o,
- OI8i16o2i = dnnl_OI8i16o2i,
- OI8i32o2i = dnnl_OI8i32o2i,
- OI8i64o2i = dnnl_OI8i64o2i,
- OI4i8o4i = dnnl_OI4i8o4i,
- OI4i16o4i = dnnl_OI4i16o4i,
- OI4i24o4i = dnnl_OI4i24o4i,
- OI4i32o4i = dnnl_OI4i32o4i,
- OI4i64o4i = dnnl_OI4i64o4i,
- Ohwi32o = dnnl_Ohwi32o,
- IOdhw16i16o = dnnl_IOdhw16i16o,
- gIOhw16i16o = dnnl_gIOhw16i16o,
- gOhwi32o = dnnl_gOhwi32o,
- Goidhw16g = dnnl_Goidhw16g,
- IOw8o8i = dnnl_IOw8o8i,
- IOw16o16i = dnnl_IOw16o16i,
- OIw16i16o = dnnl_OIw16i16o,
- OwI16i16o = dnnl_OwI16i16o,
- OIw16i32o = dnnl_OIw16i32o,
- OwI16i32o = dnnl_OwI16i32o,
- OIw16i48o = dnnl_OIw16i48o,
- OwI16i48o = dnnl_OwI16i48o,
- OIw16i64o = dnnl_OIw16i64o,
- OwI16i64o = dnnl_OwI16i64o,
- IOw16i16o = dnnl_IOw16i16o,
- gIOw16i16o = dnnl_gIOw16i16o,
- OIw16o16i = dnnl_OIw16o16i,
- Oiw16o = dnnl_Oiw16o,
- OIw4i8o4i = dnnl_OIw4i8o4i,
- OwI4i8o4i = dnnl_OwI4i8o4i,
- OIw4i16o4i = dnnl_OIw4i16o4i,
- OwI4i16o4i = dnnl_OwI4i16o4i,
- OIw4i24o4i = dnnl_OIw4i24o4i,
- OwI4i24o4i = dnnl_OwI4i24o4i,
- OIw4i32o4i = dnnl_OIw4i32o4i,
- OwI4i32o4i = dnnl_OwI4i32o4i,
- OIw4i64o4i = dnnl_OIw4i64o4i,
- OwI4i64o4i = dnnl_OwI4i64o4i,
- OIw2i8o4i = dnnl_OIw2i8o4i,
- OIw4i4o = dnnl_OIw4i4o,
- OIw4o4i = dnnl_OIw4o4i,
- Oiw4o = dnnl_Oiw4o,
- OIw8i16o2i = dnnl_OIw8i16o2i,
- OwI8i16o2i = dnnl_OwI8i16o2i,
- OIw8i32o2i = dnnl_OIw8i32o2i,
- OwI8i32o2i = dnnl_OwI8i32o2i,
- OIw8i64o2i = dnnl_OIw8i64o2i,
- OwI8i64o2i = dnnl_OwI8i64o2i,
- OIw8i8o = dnnl_OIw8i8o,
- OwI8i8o = dnnl_OwI8i8o,
- OIw8o16i2o = dnnl_OIw8o16i2o,
- OIw8o8i = dnnl_OIw8o8i,
- OIw8o4i = dnnl_OIw8o4i,
- OIw16i16o4i = dnnl_OIw16i16o4i,
- OIw16i32o4i = dnnl_OIw16i32o4i,
- OIw16i48o4i = dnnl_OIw16i48o4i,
- OIw16i64o4i = dnnl_OIw16i64o4i,
- OIw16i16o2i = dnnl_OIw16i16o2i,
- OIw16i32o2i = dnnl_OIw16i32o2i,
- OIw16i48o2i = dnnl_OIw16i48o2i,
- OIw16i64o2i = dnnl_OIw16i64o2i,
- OIw16o16i2o = dnnl_OIw16o16i2o,
- Owi16o = dnnl_Owi16o,
- OwI16o2i = dnnl_OwI16o2i,
- Iwo16i = dnnl_Iwo16i,
- IwO16i2o = dnnl_IwO16i2o,
- IwO16i4o = dnnl_IwO16i4o,
- Owi4o = dnnl_Owi4o,
- Owi8o = dnnl_Owi8o,
- OwI8o2i = dnnl_OwI8o2i,
- OwI8o4i = dnnl_OwI8o4i,
- IOhw8o8i = dnnl_IOhw8o8i,
- IOhw16o16i = dnnl_IOhw16o16i,
- Ohwi16o = dnnl_Ohwi16o,
- OhwI16o2i = dnnl_OhwI16o2i,
- Ihwo16i = dnnl_Ihwo16i,
- IhwO16i2o = dnnl_IhwO16i2o,
- IhwO16i4o = dnnl_IhwO16i4o,
- Ohwi4o = dnnl_Ohwi4o,
- Ohwi8o = dnnl_Ohwi8o,
- OhwI8o2i = dnnl_OhwI8o2i,
- OhwI8o4i = dnnl_OhwI8o4i,
- OIhw16i16o = dnnl_OIhw16i16o,
- OhwI16i16o = dnnl_OhwI16i16o,
- OIhw16i32o = dnnl_OIhw16i32o,
- OhwI16i32o = dnnl_OhwI16i32o,
- OIhw16i48o = dnnl_OIhw16i48o,
- OhwI16i48o = dnnl_OhwI16i48o,
- OIhw16i64o = dnnl_OIhw16i64o,
- OhwI16i64o = dnnl_OhwI16i64o,
- OIhw16o16i = dnnl_OIhw16o16i,
- Oihw16o = dnnl_Oihw16o,
- OIhw4i8o4i = dnnl_OIhw4i8o4i,
- OhwI4i8o4i = dnnl_OhwI4i8o4i,
- OIhw4i16o4i = dnnl_OIhw4i16o4i,
- OhwI4i16o4i = dnnl_OhwI4i16o4i,
- OIhw4i24o4i = dnnl_OIhw4i24o4i,
- OhwI4i24o4i = dnnl_OhwI4i24o4i,
- OIhw4i32o4i = dnnl_OIhw4i32o4i,
- OhwI4i32o4i = dnnl_OhwI4i32o4i,
- OIhw4i64o4i = dnnl_OIhw4i64o4i,
- OhwI4i64o4i = dnnl_OhwI4i64o4i,
- OIhw4i4o = dnnl_OIhw4i4o,
- OIhw4o4i = dnnl_OIhw4o4i,
- Oihw4o = dnnl_Oihw4o,
- OIhw8i16o2i = dnnl_OIhw8i16o2i,
- OhwI8i16o2i = dnnl_OhwI8i16o2i,
- OIhw8i32o2i = dnnl_OIhw8i32o2i,
- OhwI8i32o2i = dnnl_OhwI8i32o2i,
- OIhw8i64o2i = dnnl_OIhw8i64o2i,
- OhwI8i64o2i = dnnl_OhwI8i64o2i,
- OIhw8i8o = dnnl_OIhw8i8o,
- OhwI8i8o = dnnl_OhwI8i8o,
- OIhw8o16i2o = dnnl_OIhw8o16i2o,
- OIhw8o8i = dnnl_OIhw8o8i,
- OIhw8o4i = dnnl_OIhw8o4i,
- OIhw2i8o4i = dnnl_OIhw2i8o4i,
- IOdhw8o8i = dnnl_IOdhw8o8i,
- IOdhw16o16i = dnnl_IOdhw16o16i,
- Odhwi16o = dnnl_Odhwi16o,
- OdhwI16o2i = dnnl_OdhwI16o2i,
- Idhwo16i = dnnl_Idhwo16i,
- IdhwO16i2o = dnnl_IdhwO16i2o,
- IdhwO16i4o = dnnl_IdhwO16i4o,
- Odhwi4o = dnnl_Odhwi4o,
- Odhwi8o = dnnl_Odhwi8o,
- OdhwI8o2i = dnnl_OdhwI8o2i,
- OdhwI8o4i = dnnl_OdhwI8o4i,
- OIdhw16i16o = dnnl_OIdhw16i16o,
- OdhwI16i16o = dnnl_OdhwI16i16o,
- OIdhw16i32o = dnnl_OIdhw16i32o,
- OdhwI16i32o = dnnl_OdhwI16i32o,
- OIdhw16i48o = dnnl_OIdhw16i48o,
- OdhwI16i48o = dnnl_OdhwI16i48o,
- OIdhw16i64o = dnnl_OIdhw16i64o,
- OdhwI16i64o = dnnl_OdhwI16i64o,
- OIdhw16o16i = dnnl_OIdhw16o16i,
- OIdhw16o16i2o = dnnl_OIdhw16o16i2o,
- Oidhw16o = dnnl_Oidhw16o,
- OIdhw4i4o = dnnl_OIdhw4i4o,
- OIdhw4o4i = dnnl_OIdhw4o4i,
- Oidhw4o = dnnl_Oidhw4o,
- OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
- OdhwI8i16o2i = dnnl_OdhwI8i16o2i,
- OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
- OdhwI8i32o2i = dnnl_OdhwI8i32o2i,
- OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
- OdhwI8i64o2i = dnnl_OdhwI8i64o2i,
- OIdhw4i8o4i = dnnl_OIdhw4i8o4i,
- OdhwI4i8o4i = dnnl_OdhwI4i8o4i,
- OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
- OdhwI4i16o4i = dnnl_OdhwI4i16o4i,
- OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
- OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
- OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
- OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
- OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
- OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
- OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
- OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
- OIdhw4i24o4i = dnnl_OIdhw4i24o4i,
- OdhwI4i24o4i = dnnl_OdhwI4i24o4i,
- OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
- OdhwI4i32o4i = dnnl_OdhwI4i32o4i,
- OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
- OdhwI4i64o4i = dnnl_OdhwI4i64o4i,
- OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
- OIdhw8i8o = dnnl_OIdhw8i8o,
- OdhwI8i8o = dnnl_OdhwI8i8o,
- OIdhw8o8i = dnnl_OIdhw8o8i,
- OIdhw8o4i = dnnl_OIdhw8o4i,
- gIOw8o8i = dnnl_gIOw8o8i,
- gIOw16o16i = dnnl_gIOw16o16i,
- gOIw16i16o = dnnl_gOIw16i16o,
- gOIw16o16i = dnnl_gOIw16o16i,
- gOiw16o = dnnl_gOiw16o,
- gOIw4i16o4i = dnnl_gOIw4i16o4i,
- gOIw2i8o4i = dnnl_gOIw2i8o4i,
- gOIw4i4o = dnnl_gOIw4i4o,
- gOIw4o4i = dnnl_gOIw4o4i,
- gOiw4o = dnnl_gOiw4o,
- gOIw8i16o2i = dnnl_gOIw8i16o2i,
- gOIw8i8o = dnnl_gOIw8i8o,
- gOIw8o16i2o = dnnl_gOIw8o16i2o,
- gOIw8o8i = dnnl_gOIw8o8i,
- gOIw8o4i = dnnl_gOIw8o4i,
- gOIw16i16o4i = dnnl_gOIw16i16o4i,
- gOIw16i16o2i = dnnl_gOIw16i16o2i,
- gOIw16o16i2o = dnnl_gOIw16o16i2o,
- gOwi16o = dnnl_gOwi16o,
- gOwI16o2i = dnnl_gOwI16o2i,
- gIwo16i = dnnl_gIwo16i,
- gIwO16i2o = dnnl_gIwO16i2o,
- gIwO16i4o = dnnl_gIwO16i4o,
- gOwi4o = dnnl_gOwi4o,
- gOwi8o = dnnl_gOwi8o,
- gOwI8o2i = dnnl_gOwI8o2i,
- gOwI8o4i = dnnl_gOwI8o4i,
- Goiw8g = dnnl_Goiw8g,
- Goiw16g = dnnl_Goiw16g,
- gIOhw8o8i = dnnl_gIOhw8o8i,
- gIOhw16o16i = dnnl_gIOhw16o16i,
- gOhwi16o = dnnl_gOhwi16o,
- gOhwI16o2i = dnnl_gOhwI16o2i,
- gIhwo16i = dnnl_gIhwo16i,
- gIhwO16i2o = dnnl_gIhwO16i2o,
- gIhwO16i4o = dnnl_gIhwO16i4o,
- gOhwi4o = dnnl_gOhwi4o,
- gOhwi8o = dnnl_gOhwi8o,
- gOhwI8o2i = dnnl_gOhwI8o2i,
- gOhwI8o4i = dnnl_gOhwI8o4i,
- Goihw16g = dnnl_Goihw16g,
- gOIhw16i16o = dnnl_gOIhw16i16o,
- gOIhw16o16i = dnnl_gOIhw16o16i,
- gOihw16o = dnnl_gOihw16o,
- gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
- gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
- gOIhw4i4o = dnnl_gOIhw4i4o,
- gOIhw4o4i = dnnl_gOIhw4o4i,
- gOihw4o = dnnl_gOihw4o,
- Goihw8g = dnnl_Goihw8g,
- gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
- gOIhw8i8o = dnnl_gOIhw8i8o,
- gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
- OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
- OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
- OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
- OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
- gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
- gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
- gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
- gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
- OIhw16i16o4i = dnnl_OIhw16i16o4i,
- OIhw16i32o4i = dnnl_OIhw16i32o4i,
- OIhw16i48o4i = dnnl_OIhw16i48o4i,
- OIhw16i64o4i = dnnl_OIhw16i64o4i,
- OIhw16i16o2i = dnnl_OIhw16i16o2i,
- OIhw16i32o2i = dnnl_OIhw16i32o2i,
- OIhw16i48o2i = dnnl_OIhw16i48o2i,
- OIhw16i64o2i = dnnl_OIhw16i64o2i,
- OIhw16o16i2o = dnnl_OIhw16o16i2o,
- gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
- gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
- gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
- gOIhw8o8i = dnnl_gOIhw8o8i,
- gOIhw8o4i = dnnl_gOIhw8o4i,
- gIOdhw16i16o = dnnl_gIOdhw16i16o,
- gIOdhw8o8i = dnnl_gIOdhw8o8i,
- gIOdhw16o16i = dnnl_gIOdhw16o16i,
- gOdhwi16o = dnnl_gOdhwi16o,
- gOdhwI16o2i = dnnl_gOdhwI16o2i,
- gIdhwo16i = dnnl_gIdhwo16i,
- gIdhwO16i2o = dnnl_gIdhwO16i2o,
- gIdhwO16i4o = dnnl_gIdhwO16i4o,
- gOdhwi4o = dnnl_gOdhwi4o,
- gOdhwi8o = dnnl_gOdhwi8o,
- gOdhwI8o2i = dnnl_gOdhwI8o2i,
- gOdhwI8o4i = dnnl_gOdhwI8o4i,
- gOIdhw16i16o = dnnl_gOIdhw16i16o,
- gOIdhw16o16i = dnnl_gOIdhw16o16i,
- gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o,
- gOidhw16o = dnnl_gOidhw16o,
- gOIdhw4i4o = dnnl_gOIdhw4i4o,
- gOIdhw4o4i = dnnl_gOIdhw4o4i,
- gOidhw4o = dnnl_gOidhw4o,
- gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
- gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
- gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
- gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
- gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
- gOIdhw8i8o = dnnl_gOIdhw8i8o,
- gOIdhw8o8i = dnnl_gOIdhw8o8i,
- gOIdhw8o4i = dnnl_gOIdhw8o4i,
- gOIw2i4o2i = dnnl_gOIw2i4o2i,
- gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
- gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
- gOIw2o4i2o = dnnl_gOIw2o4i2o,
- gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
- gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
- gOIw4i8o2i = dnnl_gOIw4i8o2i,
- gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
- gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
- gOIw4o8i2o = dnnl_gOIw4o8i2o,
- gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
- gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
- ldOi16o = abDc16d,
- ldOi32o = abDc32d,
- ldOI16o4i = abDC16d4c,
- ldOI32o4i = abDC32d4c,
- ldgOi16o = abdEc16e,
- ldgOI16o4i = abdEC16e4c,
- ldgOi32o = abdEc32e,
- ldgOI32o2i = abdEC32e2c,
- ldgOI32o4i = abdEC32e4c,
- OwI16o4i = dnnl_OwI16o4i,
- OhwI16o4i = dnnl_OhwI16o4i,
- gOwI16o4i = dnnl_gOwI16o4i,
- gOhwI16o4i = dnnl_gOhwI16o4i,
- OdhwI16o4i = dnnl_OdhwI16o4i,
- gOdhwI16o4i = dnnl_gOdhwI16o4i,
- Owi32o = dnnl_Owi32o,
- OwI32o2i = dnnl_OwI32o2i,
- OwI32o4i = dnnl_OwI32o4i,
- Owi48o = dnnl_Owi48o,
- OwI48o2i = dnnl_OwI48o2i,
- OwI48o4i = dnnl_OwI48o4i,
- Owi64o = dnnl_Owi64o,
- OwI64o2i = dnnl_OwI64o2i,
- OwI64o4i = dnnl_OwI64o4i,
- Iwo32i = dnnl_Iwo32i,
- IwO32i2o = dnnl_IwO32i2o,
- IwO32i4o = dnnl_IwO32i4o,
- Iwo48i = dnnl_Iwo48i,
- IwO48i2o = dnnl_IwO48i2o,
- IwO48i4o = dnnl_IwO48i4o,
- Iwo64i = dnnl_Iwo64i,
- IwO64i2o = dnnl_IwO64i2o,
- IwO64i4o = dnnl_IwO64i4o,
- wIo2i = dnnl_wIo2i,
- wIo4i = dnnl_wIo4i,
- gOwi32o = dnnl_gOwi32o,
- gOwI32o2i = dnnl_gOwI32o2i,
- gOwI32o4i = dnnl_gOwI32o4i,
- gOwi48o = dnnl_gOwi48o,
- gOwI48o2i = dnnl_gOwI48o2i,
- gOwI48o4i = dnnl_gOwI48o4i,
- gOwi64o = dnnl_gOwi64o,
- gOwI64o2i = dnnl_gOwI64o2i,
- gOwI64o4i = dnnl_gOwI64o4i,
- gIwo32i = dnnl_gIwo32i,
- gIwO32i2o = dnnl_gIwO32i2o,
- gIwO32i4o = dnnl_gIwO32i4o,
- gIwo48i = dnnl_gIwo48i,
- gIwO48i2o = dnnl_gIwO48i2o,
- gIwO48i4o = dnnl_gIwO48i4o,
- gIwo64i = dnnl_gIwo64i,
- gIwO64i2o = dnnl_gIwO64i2o,
- gIwO64i4o = dnnl_gIwO64i4o,
- gwio = dnnl_gwio,
- gwIo2i = dnnl_gwIo2i,
- gwIo4i = dnnl_gwIo4i,
- OhwI32o = dnnl_OhwI32o,
- OhwI32o2i = dnnl_OhwI32o2i,
- OhwI32o4i = dnnl_OhwI32o4i,
- Ohwi48o = dnnl_Ohwi48o,
- OhwI48o2i = dnnl_OhwI48o2i,
- OhwI48o4i = dnnl_OhwI48o4i,
- Ohwi64o = dnnl_Ohwi64o,
- OhwI64o2i = dnnl_OhwI64o2i,
- OhwI64o4i = dnnl_OhwI64o4i,
- Ihwo32i = dnnl_Ihwo32i,
- IhwO32i2o = dnnl_IhwO32i2o,
- IhwO32i4o = dnnl_IhwO32i4o,
- Ihwo48i = dnnl_Ihwo48i,
- IhwO48i2o = dnnl_IhwO48i2o,
- IhwO48i4o = dnnl_IhwO48i4o,
- Ihwo64i = dnnl_Ihwo64i,
- IhwO64i2o = dnnl_IhwO64i2o,
- IhwO64i4o = dnnl_IhwO64i4o,
- hwIo2i = dnnl_hwIo2i,
- hwIo4i = dnnl_hwIo4i,
- gOhwI32o = dnnl_gOhwI32o,
- gOhwI32o2i = dnnl_gOhwI32o2i,
- gOhwI32o4i = dnnl_gOhwI32o4i,
- gOhwi48o = dnnl_gOhwi48o,
- gOhwI48o2i = dnnl_gOhwI48o2i,
- gOhwI48o4i = dnnl_gOhwI48o4i,
- gOhwi64o = dnnl_gOhwi64o,
- gOhwI64o2i = dnnl_gOhwI64o2i,
- gOhwI64o4i = dnnl_gOhwI64o4i,
- gIhwo32i = dnnl_gIhwo32i,
- gIhwO32i2o = dnnl_gIhwO32i2o,
- gIhwO32i4o = dnnl_gIhwO32i4o,
- gIhwo48i = dnnl_gIhwo48i,
- gIhwO48i2o = dnnl_gIhwO48i2o,
- gIhwO48i4o = dnnl_gIhwO48i4o,
- gIhwo64i = dnnl_gIhwo64i,
- gIhwO64i2o = dnnl_gIhwO64i2o,
- gIhwO64i4o = dnnl_gIhwO64i4o,
- ghwio = dnnl_ghwio,
- ghwIo2i = dnnl_ghwIo2i,
- ghwIo4i = dnnl_ghwIo4i,
- Odhwi32o = dnnl_Odhwi32o,
- OdhwI32o2i = dnnl_OdhwI32o2i,
- OdhwI32o4i = dnnl_OdhwI32o4i,
- Odhwi48o = dnnl_Odhwi48o,
- OdhwI48o2i = dnnl_OdhwI48o2i,
- OdhwI48o4i = dnnl_OdhwI48o4i,
- Odhwi64o = dnnl_Odhwi64o,
- OdhwI64o2i = dnnl_OdhwI64o2i,
- OdhwI64o4i = dnnl_OdhwI64o4i,
- Idhwo32i = dnnl_Idhwo32i,
- IdhwO32i2o = dnnl_IdhwO32i2o,
- IdhwO32i4o = dnnl_IdhwO32i4o,
- Idhwo48i = dnnl_Idhwo48i,
- IdhwO48i2o = dnnl_IdhwO48i2o,
- IdhwO48i4o = dnnl_IdhwO48i4o,
- Idhwo64i = dnnl_Idhwo64i,
- IdhwO64i2o = dnnl_IdhwO64i2o,
- IdhwO64i4o = dnnl_IdhwO64i4o,
- dhwIo2i = dnnl_dhwIo2i,
- dhwIo4i = dnnl_dhwIo4i,
- gOdhwi32o = dnnl_gOdhwi32o,
- gOdhwI32o2i = dnnl_gOdhwI32o2i,
- gOdhwI32o4i = dnnl_gOdhwI32o4i,
- gOdhwi48o = dnnl_gOdhwi48o,
- gOdhwI48o2i = dnnl_gOdhwI48o2i,
- gOdhwI48o4i = dnnl_gOdhwI48o4i,
- gOdhwi64o = dnnl_gOdhwi64o,
- gOdhwI64o2i = dnnl_gOdhwI64o2i,
- gOdhwI64o4i = dnnl_gOdhwI64o4i,
- gIdhwo32i = dnnl_gIdhwo32i,
- gIdhwO32i2o = dnnl_gIdhwO32i2o,
- gIdhwO32i4o = dnnl_gIdhwO32i4o,
- gIdhwo48i = dnnl_gIdhwo48i,
- gIdhwO48i2o = dnnl_gIdhwO48i2o,
- gIdhwO48i4o = dnnl_gIdhwO48i4o,
- gIdhwo64i = dnnl_gIdhwo64i,
- gIdhwO64i2o = dnnl_gIdhwO64i2o,
- gIdhwO64i4o = dnnl_gIdhwO64i4o,
- gdhwio = dnnl_gdhwio,
- gdhwIo2i = dnnl_gdhwIo2i,
- gdhwIo4i = dnnl_gdhwIo4i,
- ldIo32i = dnnl_ldIo32i,
- ldgIo16i = dnnl_ldgIo16i,
- ldgIo32i = dnnl_ldgIo32i,
- ldgIO32i2o = dnnl_ldgIO32i2o,
- nCdhw32c = dnnl_nCdhw32c,
- nChw32c = dnnl_nChw32c,
- nCw32c = dnnl_nCw32c,
- NCw32n16c = dnnl_NCw32n16c,
- NChw32n16c = dnnl_NChw32n16c,
- NCdhw32n16c = dnnl_NCdhw32n16c,
- NCw32n32c = dnnl_NCw32n32c,
- OI16i16o4i = dnnl_OI16i16o4i,
- IOw8o16i2o = dnnl_IOw8o16i2o,
- IOhw8o16i2o = dnnl_IOhw8o16i2o,
- Owhi16o = dnnl_Owhi16o,
- OIdhw8o16i2o = dnnl_OIdhw8o16i2o,
- IOdhw8o16i2o = dnnl_IOdhw8o16i2o,
- Goiw4g = dnnl_Goiw4g,
- gIOw8o16i2o = dnnl_gIOw8o16i2o,
- Goiw32g = dnnl_Goiw32g,
- Goihw4g = dnnl_Goihw4g,
- gIOhw8o16i2o = dnnl_gIOhw8o16i2o,
- Goihw32g = dnnl_Goihw32g,
- gOwhi16o = dnnl_gOwhi16o,
- IOw4i8o8i4o = dnnl_IOw4i8o8i4o,
- IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o,
- IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o,
- gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o,
- gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o,
- gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o,
- gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o,
- gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o,
- Goidhw32g = dnnl_Goidhw32g,
- OI16i32o4i = dnnl_OI16i32o4i,
- OI16i48o4i = dnnl_OI16i48o4i,
- OI16i64o4i = dnnl_OI16i64o4i,
- OI16i16o2i = dnnl_OI16i16o2i,
- OI16i32o2i = dnnl_OI16i32o2i,
- OI16i48o2i = dnnl_OI16i48o2i,
- OI16i64o2i = dnnl_OI16i64o2i,
- aBdeC16c16b4c = dnnl_aBdeC16c16b4c,
- AcB16b16a2b = dnnl_AcB16b16a2b,
- aBdC16c16b2c = dnnl_aBdC16c16b2c,
- AcB16b16a4b = dnnl_AcB16b16a4b,
- aBdC16c16b4c = dnnl_aBdC16c16b4c,
- AcdB16b16a2b = dnnl_AcdB16b16a2b,
- aBdefC16c16b4c = dnnl_aBdefC16c16b4c,
- AcdeB16b16a4b = dnnl_AcdeB16b16a4b,
- AcB16b32a2b = dnnl_AcB16b32a2b,
- AcB16b32a4b = dnnl_AcB16b32a4b,
- AcB16b48a2b = dnnl_AcB16b48a2b,
- AcB16b48a4b = dnnl_AcB16b48a4b,
- AcB16b64a2b = dnnl_AcB16b64a2b,
- AcB16b64a4b = dnnl_AcB16b64a4b,
- aBdC16c32b2c = dnnl_aBdC16c32b2c,
- aBdC16c32b4c = dnnl_aBdC16c32b4c,
- aBdC16c48b2c = dnnl_aBdC16c48b2c,
- aBdC16c48b4c = dnnl_aBdC16c48b4c,
- aBdC16c64b2c = dnnl_aBdC16c64b2c,
- aBdC16c64b4c = dnnl_aBdC16c64b4c,
- AcdB16b32a2b = dnnl_AcdB16b32a2b,
- AcdB16b32a4b = dnnl_AcdB16b32a4b,
- AcdB16b48a2b = dnnl_AcdB16b48a2b,
- AcdB16b48a4b = dnnl_AcdB16b48a4b,
- AcdB16b64a2b = dnnl_AcdB16b64a2b,
- AcdB16b64a4b = dnnl_AcdB16b64a4b,
- aBdeC16c32b2c = dnnl_aBdeC16c32b2c,
- aBdeC16c32b4c = dnnl_aBdeC16c32b4c,
- aBdeC16c48b2c = dnnl_aBdeC16c48b2c,
- aBdeC16c48b4c = dnnl_aBdeC16c48b4c,
- aBdeC16c64b2c = dnnl_aBdeC16c64b2c,
- aBdeC16c64b4c = dnnl_aBdeC16c64b4c,
- AcdeB16b32a2b = dnnl_AcdeB16b32a2b,
- AcdeB16b32a4b = dnnl_AcdeB16b32a4b,
- AcdeB16b48a2b = dnnl_AcdeB16b48a2b,
- AcdeB16b48a4b = dnnl_AcdeB16b48a4b,
- AcdeB16b64a2b = dnnl_AcdeB16b64a2b,
- AcdeB16b64a4b = dnnl_AcdeB16b64a4b,
- aBdefC16c32b2c = dnnl_aBdefC16c32b2c,
- aBdefC16c32b4c = dnnl_aBdefC16c32b4c,
- aBdefC16c48b2c = dnnl_aBdefC16c48b2c,
- aBdefC16c48b4c = dnnl_aBdefC16c48b4c,
- aBdefC16c64b2c = dnnl_aBdefC16c64b2c,
- aBdefC16c64b4c = dnnl_aBdefC16c64b4c,
- OwI16i16o2i = dnnl_OwI16i16o2i,
- gOwI16i16o2i = dnnl_gOwI16i16o2i,
- OhwI16i16o2i = dnnl_OhwI16i16o2i,
- gOhwI16i16o2i = dnnl_gOhwI16i16o2i,
- OdhwI16i16o2i = dnnl_OdhwI16i16o2i,
- gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i,
- OwI16i16o4i = dnnl_OwI16i16o4i,
- gOwI16i16o4i = dnnl_gOwI16i16o4i,
- OhwI16i16o4i = dnnl_OhwI16i16o4i,
- gOhwI16i16o4i = dnnl_gOhwI16i16o4i,
- OdhwI16i16o4i = dnnl_OdhwI16i16o4i,
- gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i,
- OwI16i32o2i = dnnl_OwI16i32o2i,
- OwI16i32o4i = dnnl_OwI16i32o4i,
- OwI16i48o2i = dnnl_OwI16i48o2i,
- OwI16i48o4i = dnnl_OwI16i48o4i,
- OwI16i64o2i = dnnl_OwI16i64o2i,
- OwI16i64o4i = dnnl_OwI16i64o4i,
- gOwI16i32o2i = dnnl_gOwI16i32o2i,
- gOwI16i32o4i = dnnl_gOwI16i32o4i,
- gOwI16i48o2i = dnnl_gOwI16i48o2i,
- gOwI16i48o4i = dnnl_gOwI16i48o4i,
- gOwI16i64o2i = dnnl_gOwI16i64o2i,
- gOwI16i64o4i = dnnl_gOwI16i64o4i,
- OhwI16i32o2i = dnnl_OhwI16i32o2i,
- OhwI16i32o4i = dnnl_OhwI16i32o4i,
- OhwI16i48o2i = dnnl_OhwI16i48o2i,
- OhwI16i48o4i = dnnl_OhwI16i48o4i,
- OhwI16i64o2i = dnnl_OhwI16i64o2i,
- OhwI16i64o4i = dnnl_OhwI16i64o4i,
- gOhwI16i32o2i = dnnl_gOhwI16i32o2i,
- gOhwI16i32o4i = dnnl_gOhwI16i32o4i,
- gOhwI16i48o2i = dnnl_gOhwI16i48o2i,
- gOhwI16i48o4i = dnnl_gOhwI16i48o4i,
- gOhwI16i64o2i = dnnl_gOhwI16i64o2i,
- gOhwI16i64o4i = dnnl_gOhwI16i64o4i,
- OdhwI16i32o2i = dnnl_OdhwI16i32o2i,
- OdhwI16i32o4i = dnnl_OdhwI16i32o4i,
- OdhwI16i48o2i = dnnl_OdhwI16i48o2i,
- OdhwI16i48o4i = dnnl_OdhwI16i48o4i,
- OdhwI16i64o2i = dnnl_OdhwI16i64o2i,
- OdhwI16i64o4i = dnnl_OdhwI16i64o4i,
- IdhwO16o32i2o = dnnl_IdhwO16o32i2o,
- IdhwO16o32i4o = dnnl_IdhwO16o32i4o,
- IdhwO16o48i2o = dnnl_IdhwO16o48i2o,
- IdhwO16o48i4o = dnnl_IdhwO16o48i4o,
- IdhwO16o64i2o = dnnl_IdhwO16o64i2o,
- IdhwO16o64i4o = dnnl_IdhwO16o64i4o,
- gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i,
- gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i,
- gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i,
- gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i,
- gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i,
- gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i,
- gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o,
- gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o,
- gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o,
- gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o,
- gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o,
- gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o,
- IwO16o16i2o = dnnl_IwO16o16i2o,
- IwO16o16i4o = dnnl_IwO16o16i4o,
- IhwO16o16i2o = dnnl_IhwO16o16i2o,
- IhwO16o16i4o = dnnl_IhwO16o16i4o,
- IdhwO16o16i2o = dnnl_IdhwO16o16i2o,
- IdhwO16o16i4o = dnnl_IdhwO16o16i4o,
- gIwO16o16i2o = dnnl_gIwO16o16i2o,
- gIwO16o16i4o = dnnl_gIwO16o16i4o,
- gIhwO16o16i2o = dnnl_gIhwO16o16i2o,
- gIhwO16o16i4o = dnnl_gIhwO16o16i4o,
- gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o,
- gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o,
- IwO16o32i2o = dnnl_IwO16o32i2o,
- IwO16o32i4o = dnnl_IwO16o32i4o,
- IwO16o48i2o = dnnl_IwO16o48i2o,
- IwO16o48i4o = dnnl_IwO16o48i4o,
- IwO16o64i2o = dnnl_IwO16o64i2o,
- IwO16o64i4o = dnnl_IwO16o64i4o,
- gIwO16o32i2o = dnnl_gIwO16o32i2o,
- gIwO16o32i4o = dnnl_gIwO16o32i4o,
- gIwO16o48i2o = dnnl_gIwO16o48i2o,
- gIwO16o48i4o = dnnl_gIwO16o48i4o,
- gIwO16o64i2o = dnnl_gIwO16o64i2o,
- gIwO16o64i4o = dnnl_gIwO16o64i4o,
- IhwO16o32i2o = dnnl_IhwO16o32i2o,
- IhwO16o32i4o = dnnl_IhwO16o32i4o,
- IhwO16o48i2o = dnnl_IhwO16o48i2o,
- IhwO16o48i4o = dnnl_IhwO16o48i4o,
- IhwO16o64i2o = dnnl_IhwO16o64i2o,
- IhwO16o64i4o = dnnl_IhwO16o64i4o,
- gIhwO16o32i2o = dnnl_gIhwO16o32i2o,
- gIhwO16o32i4o = dnnl_gIhwO16o32i4o,
- gIhwO16o48i2o = dnnl_gIhwO16o48i2o,
- gIhwO16o48i4o = dnnl_gIhwO16o48i4o,
- gIhwO16o64i2o = dnnl_gIhwO16o64i2o,
- gIhwO16o64i4o = dnnl_gIhwO16o64i4o,
- aBdeC16c16b2c = dnnl_aBdeC16c16b2c,
- aBdefC16c16b2c = dnnl_aBdefC16c16b2c,
- AcdB16b16a4b = dnnl_AcdB16b16a4b,
- AcdeB16b16a2b = dnnl_AcdeB16b16a2b,
- hwioG16g = dnnl_hwioG16g,
- hwioG8g = dnnl_hwioG8g,
- dhwioG16g = dnnl_dhwioG16g,
- dhwioG8g = dnnl_dhwioG8g,
- ABc4a2b = dnnl_ABc4a2b,
- ABc8a2b = dnnl_ABc8a2b,
- ABcd4a2b = dnnl_ABcd4a2b,
- ABcde4a2b = dnnl_ABcde4a2b,
- ABcde8a2b = dnnl_ABcde8a2b,
- ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b,
- NCdhw40n32c = dnnl_NCdhw40n32c,
- NChw40n32c = dnnl_NChw40n32c,
- NCw40n32c = dnnl_NCw40n32c,
- OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i,
- OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i,
- OIw4o8i8o2i = dnnl_OIw4o8i8o2i,
- gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i,
- gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i,
- gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i,
- IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o,
- IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o,
- IOw4i8o8i2o = dnnl_IOw4i8o8i2o,
- gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o,
- gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o,
- gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o,
- aBCd8b2c = dnnl_aBCd8b2c,
- ABcde40a16b = dnnl_ABcde40a16b,
- ABcde40a32b = dnnl_ABcde40a32b,
- aBCde8b2c = dnnl_aBCde8b2c,
- ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b,
- ABc4a8b8a2b = dnnl_ABc4a8b8a2b,
- aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c,
- aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c,
- aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c,
- BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a,
- BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a,
- BAc4b8a8b2a = dnnl_BAc4b8a8b2a,
- aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b,
- aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b,
- aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b,
- aBCdef8b2c = dnnl_aBCdef8b2c,
- AB32a16b = dnnl_AB32a16b,
- AB32a32b = dnnl_AB32a32b,
- BA4b8a8b2a = dnnl_BA4b8a8b2a,
- BA4b8a8b4a = dnnl_BA4b8a8b4a,
- aBC32b16c = dnnl_aBC32b16c,
- aBC32b32c = dnnl_aBC32b32c,
- aCB4c8b8c2b = dnnl_aCB4c8b8c2b,
- aCB4c8b8c4b = dnnl_aCB4c8b8c4b,
- ABc2b8a16b4a = dnnl_ABc2b8a16b4a,
- ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a,
- ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a,
- ABc2a8b16a4b = dnnl_ABc2a8b16a4b,
- ABc2a8b16a2b = dnnl_ABc2a8b16a2b,
- ABc2b32a8b = dnnl_ABc2b32a8b,
- ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b,
- ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b,
- aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b,
- ABcd2b32a8b = dnnl_ABcd2b32a8b,
- aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b,
- ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b,
- ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b,
- aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b,
- ABcde2b32a8b = dnnl_ABcde2b32a8b,
- aBC2b8c16b2c = dnnl_aBC2b8c16b2c,
- aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c,
- aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c,
- aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c,
- BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a,
- BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a,
- BAc2b8a16b4a = dnnl_BAc2b8a16b4a,
- BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a,
- BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a,
- BAc2b8a16b2a = dnnl_BAc2b8a16b2a,
- aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b,
- aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b,
- aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b,
- aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c,
- aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c,
- NCdhw40n16c = dnnl_NCdhw40n16c,
- NCw40n16c = dnnl_NCw40n16c,
- NChw40n16c = dnnl_NChw40n16c,
- NCw2c32n8c = dnnl_NCw2c32n8c,
- NChw2c32n8c = dnnl_NChw2c32n8c,
- NCdhw2c32n8c = dnnl_NCdhw2c32n8c,
- OIw2i8o16i4o = dnnl_OIw2i8o16i4o,
- OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o,
- OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o,
- OIw2o8i16o4i = dnnl_OIw2o8i16o4i,
- OIw2o8i16o2i = dnnl_OIw2o8i16o2i,
- IOw2i8o16i4o = dnnl_IOw2i8o16i4o,
- IOw2i8o16i2o = dnnl_IOw2i8o16i2o,
- OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i,
- OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i,
- IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o,
- IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o,
- OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i,
- OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i,
- IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o,
- IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o,
- gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i,
- gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o,
- gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o,
- gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o,
- gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i,
- gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i,
- gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i,
- gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i,
- BA4b8a16b2a = dnnl_BA4b8a16b2a,
- BA4b8a16b4a = dnnl_BA4b8a16b4a,
- aCB4c8b16c2b = dnnl_aCB4c8b16c2b,
- aCB4c8b16c4b = dnnl_aCB4c8b16c4b,
- aCB16c2b = dnnl_aCB16c2b,
- aCB16c4b = dnnl_aCB16c4b,
- BA16b2a = dnnl_BA16b2a,
- BA16b4a = dnnl_BA16b4a,
- BA4b4a = dnnl_BA4b4a,
- BA8b4a = dnnl_BA8b4a,
- aBC16b16c = dnnl_aBC16b16c,
- aBC16b32c = dnnl_aBC16b32c,
- AB16a16b = dnnl_AB16a16b,
- AB16a32b = dnnl_AB16a32b,
- ABcde16a16b2a = dnnl_ABcde16a16b2a,
- aBCdef16b16c2b = dnnl_aBCdef16b16c2b,
- Acedb16a = dnnl_Acedb16a,
- aBdfec16b = dnnl_aBdfec16b,
- Odwhi16o = dnnl_Odwhi16o,
- gOdwhi16o = dnnl_gOdwhi16o,
- abdEC64e2c = dnnl_abdEC64e2c,
- abdEC64e4c = dnnl_abdEC64e4c,
- ldgOI64o2i = abdEC64e2c,
- ldgOI64o4i = abdEC64e4c,
- abCd4c = dnnl_abCd4c,
- abCde4c = dnnl_abCde4c,
- abCdef4c = dnnl_abCdef4c,
- abCde32c = dnnl_abCde32c,
- abCdef32c = dnnl_abCdef32c,
- aCdefB16b32c2b = dnnl_aCdefB16b32c2b,
- aCdefB16b32c4b = dnnl_aCdefB16b32c4b,
- aCdefB16b48c2b = dnnl_aCdefB16b48c2b,
- aCdefB16b48c4b = dnnl_aCdefB16b48c4b,
- aCdefB16b64c2b = dnnl_aCdefB16b64c2b,
- aCdefB16b64c4b = dnnl_aCdefB16b64c4b,
- BcdeA16a32b2a = dnnl_BcdeA16a32b2a,
- BcdeA16a32b4a = dnnl_BcdeA16a32b4a,
- BcdeA16a48b2a = dnnl_BcdeA16a48b2a,
- BcdeA16a48b4a = dnnl_BcdeA16a48b4a,
- BcdeA16a64b2a = dnnl_BcdeA16a64b2a,
- BcdeA16a64b4a = dnnl_BcdeA16a64b4a,
- aCdefb32c = dnnl_aCdefb32c,
- aCdefB32c2b = dnnl_aCdefB32c2b,
- aCdefB32c4b = dnnl_aCdefB32c4b,
- aCdefb48c = dnnl_aCdefb48c,
- aCdefB48c2b = dnnl_aCdefB48c2b,
- aCdefB48c4b = dnnl_aCdefB48c4b,
- aCdefb64c = dnnl_aCdefb64c,
- aCdefB64c2b = dnnl_aCdefB64c2b,
- aCdefB64c4b = dnnl_aCdefB64c4b,
- Bcdea32b = dnnl_Bcdea32b,
- BcdeA32b2a = dnnl_BcdeA32b2a,
- BcdeA32b4a = dnnl_BcdeA32b4a,
- Bcdea48b = dnnl_Bcdea48b,
- BcdeA48b2a = dnnl_BcdeA48b2a,
- BcdeA48b4a = dnnl_BcdeA48b4a,
- Bcdea64b = dnnl_Bcdea64b,
- BcdeA64b2a = dnnl_BcdeA64b2a,
- BcdeA64b4a = dnnl_BcdeA64b4a,
- Bca32b = dnnl_Bca32b,
- BcA32b2a = dnnl_BcA32b2a,
- BcA32b4a = dnnl_BcA32b4a,
- Bca48b = dnnl_Bca48b,
- BcA48b2a = dnnl_BcA48b2a,
- BcA48b4a = dnnl_BcA48b4a,
- Bca64b = dnnl_Bca64b,
- BcA64b2a = dnnl_BcA64b2a,
- BcA64b4a = dnnl_BcA64b4a,
- aCdb32c = dnnl_aCdb32c,
- aCdB32c2b = dnnl_aCdB32c2b,
- aCdB32c4b = dnnl_aCdB32c4b,
- aCdb48c = dnnl_aCdb48c,
- aCdB48c2b = dnnl_aCdB48c2b,
- aCdB48c4b = dnnl_aCdB48c4b,
- aCdb64c = dnnl_aCdb64c,
- aCdB64c2b = dnnl_aCdB64c2b,
- aCdB64c4b = dnnl_aCdB64c4b,
- BcA16a16b2a = dnnl_BcA16a16b2a,
- BcA16a16b4a = dnnl_BcA16a16b4a,
- BcdA16a16b2a = dnnl_BcdA16a16b2a,
- BcdA16a16b4a = dnnl_BcdA16a16b4a,
- BcdeA16a16b2a = dnnl_BcdeA16a16b2a,
- BcdeA16a16b4a = dnnl_BcdeA16a16b4a,
- aCdB16b16c2b = dnnl_aCdB16b16c2b,
- aCdB16b16c4b = dnnl_aCdB16b16c4b,
- aCdeB16b16c2b = dnnl_aCdeB16b16c2b,
- aCdeB16b16c4b = dnnl_aCdeB16b16c4b,
- aCdefB16b16c2b = dnnl_aCdefB16b16c2b,
- aCdefB16b16c4b = dnnl_aCdefB16b16c4b,
- BcA16a32b2a = dnnl_BcA16a32b2a,
- BcA16a32b4a = dnnl_BcA16a32b4a,
- BcA16a48b2a = dnnl_BcA16a48b2a,
- BcA16a48b4a = dnnl_BcA16a48b4a,
- BcA16a64b2a = dnnl_BcA16a64b2a,
- BcA16a64b4a = dnnl_BcA16a64b4a,
- aCdB16b32c2b = dnnl_aCdB16b32c2b,
- aCdB16b32c4b = dnnl_aCdB16b32c4b,
- aCdB16b48c2b = dnnl_aCdB16b48c2b,
- aCdB16b48c4b = dnnl_aCdB16b48c4b,
- aCdB16b64c2b = dnnl_aCdB16b64c2b,
- aCdB16b64c4b = dnnl_aCdB16b64c4b,
- BcdA16a32b2a = dnnl_BcdA16a32b2a,
- BcdA16a32b4a = dnnl_BcdA16a32b4a,
- BcdA16a48b2a = dnnl_BcdA16a48b2a,
- BcdA16a48b4a = dnnl_BcdA16a48b4a,
- BcdA16a64b2a = dnnl_BcdA16a64b2a,
- BcdA16a64b4a = dnnl_BcdA16a64b4a,
- aCdeB16b32c2b = dnnl_aCdeB16b32c2b,
- aCdeB16b32c4b = dnnl_aCdeB16b32c4b,
- aCdeB16b48c2b = dnnl_aCdeB16b48c2b,
- aCdeB16b48c4b = dnnl_aCdeB16b48c4b,
- aCdeB16b64c2b = dnnl_aCdeB16b64c2b,
- aCdeB16b64c4b = dnnl_aCdeB16b64c4b,
- Bca16b = dnnl_Bca16b,
- BcA16b2a = dnnl_BcA16b2a,
- BcA16b4a = dnnl_BcA16b4a,
- Bcda16b = dnnl_Bcda16b,
- BcdA16b2a = dnnl_BcdA16b2a,
- BcdA16b4a = dnnl_BcdA16b4a,
- Bcdea16b = dnnl_Bcdea16b,
- BcdeA16b2a = dnnl_BcdeA16b2a,
- BcdeA16b4a = dnnl_BcdeA16b4a,
- aCdb16c = dnnl_aCdb16c,
- aCdB16c2b = dnnl_aCdB16c2b,
- aCdB16c4b = dnnl_aCdB16c4b,
- aCdeb16c = dnnl_aCdeb16c,
- aCdeB16c2b = dnnl_aCdeB16c2b,
- aCdeB16c4b = dnnl_aCdeB16c4b,
- aCdefb16c = dnnl_aCdefb16c,
- aCdefB16c2b = dnnl_aCdefB16c2b,
- aCdefB16c4b = dnnl_aCdefB16c4b,
- Bcda32b = dnnl_Bcda32b,
- BcdA32b2a = dnnl_BcdA32b2a,
- BcdA32b4a = dnnl_BcdA32b4a,
- Bcda48b = dnnl_Bcda48b,
- BcdA48b2a = dnnl_BcdA48b2a,
- BcdA48b4a = dnnl_BcdA48b4a,
- Bcda64b = dnnl_Bcda64b,
- BcdA64b2a = dnnl_BcdA64b2a,
- BcdA64b4a = dnnl_BcdA64b4a,
- aCdeb32c = dnnl_aCdeb32c,
- aCdeB32c2b = dnnl_aCdeB32c2b,
- aCdeB32c4b = dnnl_aCdeB32c4b,
- aCdeb48c = dnnl_aCdeb48c,
- aCdeB48c2b = dnnl_aCdeB48c2b,
- aCdeB48c4b = dnnl_aCdeB48c4b,
- aCdeb64c = dnnl_aCdeb64c,
- aCdeB64c2b = dnnl_aCdeB64c2b,
- aCdeB64c4b = dnnl_aCdeB64c4b,
- NChw16n32c = dnnl_NChw16n32c,
- goIw4i = dnnl_goIw4i,
- goIw32i = dnnl_goIw32i,
- goIhw4i = dnnl_goIhw4i,
- goIhw32i = dnnl_goIhw32i,
- goIdhw4i = dnnl_goIdhw4i,
- goIdhw32i = dnnl_goIdhw32i,
- cab = dnnl_cab,
- cdab = dnnl_cdab,
- cdeab = dnnl_cdeab,
- woi = dnnl_woi,
- hwoi = dnnl_hwoi,
- dhwoi = dnnl_dhwoi,
- Owi24o = dnnl_Owi24o,
- Ohwi24o = dnnl_Ohwi24o,
- Odhwi24o = dnnl_Odhwi24o,
- gOwi24o = dnnl_gOwi24o,
- gOhwi24o = dnnl_gOhwi24o,
- gOdhwi24o = dnnl_gOdhwi24o,
- OwI24o2i = dnnl_OwI24o2i,
- OhwI24o2i = dnnl_OhwI24o2i,
- OdhwI24o2i = dnnl_OdhwI24o2i,
- gOwI24o2i = dnnl_gOwI24o2i,
- gOhwI24o2i = dnnl_gOhwI24o2i,
- gOdhwI24o2i = dnnl_gOdhwI24o2i,
- OwI24o4i = dnnl_OwI24o4i,
- OhwI24o4i = dnnl_OhwI24o4i,
- OdhwI24o4i = dnnl_OdhwI24o4i,
- gOwI24o4i = dnnl_gOwI24o4i,
- gOhwI24o4i = dnnl_gOhwI24o4i,
- gOdhwI24o4i = dnnl_gOdhwI24o4i,
- OI8i32o = dnnl_OI8i32o,
- OIw8i32o = dnnl_OIw8i32o,
- OwI8i32o = dnnl_OwI8i32o,
- OIhw8i32o = dnnl_OIhw8i32o,
- OhwI8i32o = dnnl_OhwI8i32o,
- OIdhw8i32o = dnnl_OIdhw8i32o,
- OdhwI8i32o = dnnl_OdhwI8i32o,
- OI8i24o = dnnl_OI8i24o,
- OIw8i24o = dnnl_OIw8i24o,
- OwI8i24o = dnnl_OwI8i24o,
- OIhw8i24o = dnnl_OIhw8i24o,
- OhwI8i24o = dnnl_OhwI8i24o,
- OIdhw8i24o = dnnl_OIdhw8i24o,
- OdhwI8i24o = dnnl_OdhwI8i24o,
- OI8i16o = dnnl_OI8i16o,
- OIw8i16o = dnnl_OIw8i16o,
- OwI8i16o = dnnl_OwI8i16o,
- OIhw8i16o = dnnl_OIhw8i16o,
- OhwI8i16o = dnnl_OhwI8i16o,
- OIdhw8i16o = dnnl_OIdhw8i16o,
- OdhwI8i16o = dnnl_OdhwI8i16o,
- OI8i8o = dnnl_OI8i8o,
- AB4b8a4b = dnnl_AB4b8a4b,
- AB4b24a4b = dnnl_AB4b24a4b,
- ABc4b8a4b = dnnl_ABc4b8a4b,
- AcB4b8a4b = dnnl_AcB4b8a4b,
- ABc4b24a4b = dnnl_ABc4b24a4b,
- AcB4b24a4b = dnnl_AcB4b24a4b,
- ABcd4b8a4b = dnnl_ABcd4b8a4b,
- AcdB4b8a4b = dnnl_AcdB4b8a4b,
- ABcd4b24a4b = dnnl_ABcd4b24a4b,
- AcdB4b24a4b = dnnl_AcdB4b24a4b,
- ABcde4b8a4b = dnnl_ABcde4b8a4b,
- AcdeB4b8a4b = dnnl_AcdeB4b8a4b,
- ABcde4b24a4b = dnnl_ABcde4b24a4b,
- AcdeB4b24a4b = dnnl_AcdeB4b24a4b,
- Bca8b = dnnl_Bca8b,
- BcA8b2a = dnnl_BcA8b2a,
- Bcda8b = dnnl_Bcda8b,
- BcdA8b2a = dnnl_BcdA8b2a,
- Bcdea8b = dnnl_Bcdea8b,
- BcdeA8b2a = dnnl_BcdeA8b2a,
- aCdb8c = dnnl_aCdb8c,
- aCdB8c2b = dnnl_aCdB8c2b,
- aCdeb8c = dnnl_aCdeb8c,
- aCdeB8c2b = dnnl_aCdeB8c2b,
- aCdefb8c = dnnl_aCdefb8c,
- aCdefB8c2b = dnnl_aCdefB8c2b,
- Bca24b = dnnl_Bca24b,
- BcA24b2a = dnnl_BcA24b2a,
- Bcda24b = dnnl_Bcda24b,
- BcdA24b2a = dnnl_BcdA24b2a,
- Bcdea24b = dnnl_Bcdea24b,
- BcdeA24b2a = dnnl_BcdeA24b2a,
- aCdb24c = dnnl_aCdb24c,
- aCdB24c2b = dnnl_aCdB24c2b,
- aCdeb24c = dnnl_aCdeb24c,
- aCdeB24c2b = dnnl_aCdeB24c2b,
- aCdefb24c = dnnl_aCdefb24c,
- aCdefB24c2b = dnnl_aCdefB24c2b,
- Iwo8i = dnnl_Iwo8i,
- IwO8i2o = dnnl_IwO8i2o,
- Iwo24i = dnnl_Iwo24i,
- IwO24i2o = dnnl_IwO24i2o,
- Ihwo8i = dnnl_Ihwo8i,
- IhwO8i2o = dnnl_IhwO8i2o,
- Ihwo24i = dnnl_Ihwo24i,
- IhwO24i2o = dnnl_IhwO24i2o,
- Idhwo8i = dnnl_Idhwo8i,
- IdhwO8i2o = dnnl_IdhwO8i2o,
- Idhwo24i = dnnl_Idhwo24i,
- IdhwO24i2o = dnnl_IdhwO24i2o,
- gIwo8i = dnnl_gIwo8i,
- gIwO8i2o = dnnl_gIwO8i2o,
- gIwo24i = dnnl_gIwo24i,
- gIwO24i2o = dnnl_gIwO24i2o,
- gIhwo8i = dnnl_gIhwo8i,
- gIhwO8i2o = dnnl_gIhwO8i2o,
- gIhwo24i = dnnl_gIhwo24i,
- gIhwO24i2o = dnnl_gIhwO24i2o,
- gIdhwo8i = dnnl_gIdhwo8i,
- gIdhwO8i2o = dnnl_gIdhwO8i2o,
- gIdhwo24i = dnnl_gIdhwo24i,
- gIdhwO24i2o = dnnl_gIdhwO24i2o,
- OhwI24o = dnnl_OhwI24o,
- gOhwI24o = dnnl_gOhwI24o,
- AB8b24a2b = dnnl_AB8b24a2b,
- ABc8b24a2b = dnnl_ABc8b24a2b,
- AcB8b24a2b = dnnl_AcB8b24a2b,
- ABcd8b24a2b = dnnl_ABcd8b24a2b,
- AcdB8b24a2b = dnnl_AcdB8b24a2b,
- ABcde8b24a2b = dnnl_ABcde8b24a2b,
- AcdeB8b24a2b = dnnl_AcdeB8b24a2b,
- AB8b8a2b = dnnl_AB8b8a2b,
- ABc8b8a2b = dnnl_ABc8b8a2b,
- AcB8b8a2b = dnnl_AcB8b8a2b,
- ABcd8b8a2b = dnnl_ABcd8b8a2b,
- AcdB8b8a2b = dnnl_AcdB8b8a2b,
- ABcde8b8a2b = dnnl_ABcde8b8a2b,
- AcdeB8b8a2b = dnnl_AcdeB8b8a2b,
- OI8i8o2i = dnnl_OI8i8o2i,
- OI8i24o2i = dnnl_OI8i24o2i,
- OIw8i8o2i = dnnl_OIw8i8o2i,
- OwI8i8o2i = dnnl_OwI8i8o2i,
- OIw8i24o2i = dnnl_OIw8i24o2i,
- OwI8i24o2i = dnnl_OwI8i24o2i,
- OIhw8i8o2i = dnnl_OIhw8i8o2i,
- OhwI8i8o2i = dnnl_OhwI8i8o2i,
- OIhw8i24o2i = dnnl_OIhw8i24o2i,
- OhwI8i24o2i = dnnl_OhwI8i24o2i,
- OIdhw8i8o2i = dnnl_OIdhw8i8o2i,
- OdhwI8i8o2i = dnnl_OdhwI8i8o2i,
- OIdhw8i24o2i = dnnl_OIdhw8i24o2i,
- OdhwI8i24o2i = dnnl_OdhwI8i24o2i,
- BcA8b4a = dnnl_BcA8b4a,
- BcdA8b4a = dnnl_BcdA8b4a,
- BcdeA8b4a = dnnl_BcdeA8b4a,
- aCdB8c4b = dnnl_aCdB8c4b,
- aCdeB8c4b = dnnl_aCdeB8c4b,
- aCdefB8c4b = dnnl_aCdefB8c4b,
- BcA24b4a = dnnl_BcA24b4a,
- BcdA24b4a = dnnl_BcdA24b4a,
- BcdeA24b4a = dnnl_BcdeA24b4a,
- aCdB24c4b = dnnl_aCdB24c4b,
- aCdeB24c4b = dnnl_aCdeB24c4b,
- aCdefB24c4b = dnnl_aCdefB24c4b,
- ABc16a4b = dnnl_ABc16a4b,
- ABcd16a4b = dnnl_ABcd16a4b,
- ABcde16a4b = dnnl_ABcde16a4b,
- IwO8i4o = dnnl_IwO8i4o,
- IwO24i4o = dnnl_IwO24i4o,
- IhwO8i4o = dnnl_IhwO8i4o,
- IhwO24i4o = dnnl_IhwO24i4o,
- IdhwO8i4o = dnnl_IdhwO8i4o,
- IdhwO24i4o = dnnl_IdhwO24i4o,
- gIwO8i4o = dnnl_gIwO8i4o,
- gIwO24i4o = dnnl_gIwO24i4o,
- gIhwO8i4o = dnnl_gIhwO8i4o,
- gIhwO24i4o = dnnl_gIhwO24i4o,
- gIdhwO8i4o = dnnl_gIdhwO8i4o,
- gIdhwO24i4o = dnnl_gIdhwO24i4o,
- BA2a24b = dnnl_BA2a24b,
- aCB2b24c = dnnl_aCB2b24c,
- BA2a8b = dnnl_BA2a8b,
- aCB2b8c = dnnl_aCB2b8c,
- BA8a24b = dnnl_BA8a24b,
- aCB8b24c = dnnl_aCB8b24c,
- BA8a16b = dnnl_BA8a16b,
- aCB8b16c = dnnl_aCB8b16c,
- BA8a8b = dnnl_BA8a8b,
- aCB8b8c = dnnl_aCB8b8c,
- bcad = dnnl_bcad,
- cabd = dnnl_cabd,
- dabc = dnnl_dabc,
- decbA4a = dnnl_decbA4a,
- defcbA4a = dnnl_defcbA4a,
- hwioG4g = dnnl_hwioG4g,
- dhwioG4g = dnnl_dhwioG4g,
- aCBd4b4c = dnnl_aCBd4b4c,
- aCBde4b4c = dnnl_aCBde4b4c,
- aCBdef4b4c = dnnl_aCBdef4b4c,
- BAc4a4b = dnnl_BAc4a4b,
- BAcd4a4b = dnnl_BAcd4a4b,
- BAcde4a4b = dnnl_BAcde4a4b,
- IOw4o4i = dnnl_IOw4o4i,
- IOhw4o4i = dnnl_IOhw4o4i,
- IOdhw4o4i = dnnl_IOdhw4o4i,
- gIOw4o4i = dnnl_gIOw4o4i,
- gIOhw4o4i = dnnl_gIOhw4o4i,
- gIOdhw4o4i = dnnl_gIOdhw4o4i,
- };
- /// A memory descriptor.
- struct desc : public handle<dnnl_memory_desc_t> {
- using handle<dnnl_memory_desc_t>::handle;
- friend struct memory;
- /// Default constructor: builds a zero (empty) memory descriptor. A
- /// descriptor of this kind can stand in for an argument that is absent.
- desc() {
- dnnl_memory_desc_t empty_md = nullptr;
- // Zero dimensions, no dims array, undefined data type and format tag.
- const dnnl_status_t rc = dnnl_memory_desc_create_with_tag(&empty_md, 0,
- nullptr, dnnl_data_type_undef, dnnl_format_tag_undef);
- // Report a failing status through error::wrap_c_api.
- error::wrap_c_api(rc, "could not create a zero memory descriptor");
- reset(empty_md);
- }
- /// Constructs a memory descriptor from dimensions, a data type and a
- /// format tag.
- ///
- /// @note
- /// The logical order of dimensions corresponds to the `abc...`
- /// format tag, and the physical meaning of the dimensions depends
- /// both on the primitive that would operate on this memory and
- /// the operation context.
- ///
- /// @param adims Tensor dimensions.
- /// @param adata_type Data precision/type.
- /// @param aformat_tag Memory format tag.
- /// @param allow_empty A flag signifying whether construction may
- /// fail without throwing an exception, in which case a zero
- /// memory descriptor is constructed instead. Optional; defaults
- /// to false.
- desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
- bool allow_empty = false) {
- validate_dims(adims);
- const int num_dims = static_cast<int>(adims.size());
- dnnl_memory_desc_t raw_md = nullptr;
- const dnnl_status_t rc = dnnl_memory_desc_create_with_tag(&raw_md,
- num_dims, adims.data(), convert_to_c(adata_type),
- convert_to_c(aformat_tag));
- // The status is inspected only when the caller did not opt into an
- // empty result.
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not construct a memory descriptor using a "
- "format tag");
- }
- reset(raw_md);
- }
- /// Constructs a memory descriptor from dimensions, a data type and
- /// per-dimension strides.
- ///
- /// @note
- /// The logical order of dimensions corresponds to the `abc...`
- /// format tag, and the physical meaning of the dimensions depends
- /// both on the primitive that would operate on this memory and
- /// the operation context.
- ///
- /// @param adims Tensor dimensions.
- /// @param adata_type Data precision/type.
- /// @param strides Strides for each dimension. When empty, a null
- /// strides pointer is handed to the C API.
- /// @param allow_empty A flag signifying whether construction may
- /// fail without throwing an exception, in which case a zero
- /// memory descriptor is constructed instead. Optional; defaults
- /// to false.
- desc(const dims &adims, data_type adata_type, const dims &strides,
- bool allow_empty = false) {
- validate_dims(adims);
- const int num_dims = static_cast<int>(adims.size());
- const bool has_strides = !strides.empty();
- // Strides are validated against the number of dimensions only when
- // they were actually supplied.
- if (has_strides) { validate_dims(strides, num_dims); }
- dnnl_memory_desc_t raw_md = nullptr;
- const dnnl_status_t rc = dnnl_memory_desc_create_with_strides(&raw_md,
- num_dims, adims.data(), convert_to_c(adata_type),
- has_strides ? strides.data() : nullptr);
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not construct a memory descriptor using "
- "strides");
- }
- reset(raw_md);
- }
- /// Creates a memory descriptor that uses the CSR sparse encoding.
- ///
- /// A memory object described by the result contains 3 buffers. Their
- /// meaning and assigned numbers (index) are:
- /// - 0: values
- /// - 1: indices
- /// - 2: pointers
- ///
- /// @param adims Tensor dimensions.
- /// @param adata_type Data precision/type.
- /// @param nnz Number of non-zero entries.
- /// @param index_dt Data type of indices.
- /// @param pointer_dt Data type of pointers.
- /// @param allow_empty A flag signifying whether construction may
- /// fail without throwing an exception, in which case a zero
- /// memory descriptor is constructed instead. Optional; defaults
- /// to false.
- /// @sa @ref dev_guide_sparsity
- static desc csr(const dims &adims, data_type adata_type, dim nnz,
- data_type index_dt, data_type pointer_dt,
- bool allow_empty = false) {
- validate_dims(adims);
- const int num_dims = static_cast<int>(adims.size());
- const auto c_data_type = convert_to_c(adata_type);
- const auto c_index_type = convert_to_c(index_dt);
- const auto c_pointer_type = convert_to_c(pointer_dt);
- dnnl_memory_desc_t raw_md = nullptr;
- const dnnl_status_t rc = dnnl_memory_desc_create_with_csr_encoding(
- &raw_md, num_dims, adims.data(), c_data_type, nnz,
- c_index_type, c_pointer_type);
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a memory descriptor for CSR sparse "
- "encoding");
- }
- return desc {raw_md};
- }
- /// Creates a memory descriptor that uses the COO sparse encoding.
- ///
- /// A memory object described by the result contains n+1 buffers for
- /// an n-dimensional tensor. Their meaning and assigned numbers
- /// (index) are:
- /// - 0: values
- /// - 1: indices for dimension 0
- /// - 2: indices for dimension 1 ...
- /// - n: indices for dimension n-1
- ///
- /// @param adims Tensor dimensions.
- /// @param adata_type Data precision/type.
- /// @param nnz Number of non-zero entries.
- /// @param index_dt Data type of indices.
- /// @param allow_empty A flag signifying whether construction may
- /// fail without throwing an exception, in which case a zero
- /// memory descriptor is constructed instead. Optional; defaults
- /// to false.
- /// @sa @ref dev_guide_sparsity
- static desc coo(const dims &adims, data_type adata_type, dim nnz,
- data_type index_dt, bool allow_empty = false) {
- validate_dims(adims);
- const int num_dims = static_cast<int>(adims.size());
- const auto c_data_type = convert_to_c(adata_type);
- const auto c_index_type = convert_to_c(index_dt);
- dnnl_memory_desc_t raw_md = nullptr;
- const dnnl_status_t rc = dnnl_memory_desc_create_with_coo_encoding(
- &raw_md, num_dims, adims.data(), c_data_type, nnz,
- c_index_type);
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a memory descriptor for COO sparse "
- "encoding");
- }
- return desc {raw_md};
- }
- /// Function for creating a memory descriptor for packed sparse
- /// encoding.
- ///
- /// The created memory descriptor cannot be used to create a memory
- /// object. It can only be used to create a primitive descriptor to
- /// query the actual memory descriptor (similar to the format tag
- /// `any`).
- ///
- /// @warning
- /// The meaning and content of the handles of the memory object that
- /// is created using the queried memory descriptor are unspecified
- /// therefore using the content is an undefined behavior.
- ///
- /// @param adims Tensor dimensions.
- /// @param adata_type Data precision/type.
- /// @param nnz Number of non-zero entries.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case a
- /// zero memory descriptor will be constructed. This flag is
- /// optional and defaults to false.
- /// @sa @ref dev_guide_sparsity
- static desc packed(const dims &adims, data_type adata_type, dim nnz,
- bool allow_empty = false) {
- validate_dims(adims);
- dnnl_memory_desc_t md = nullptr;
- dnnl_status_t status = dnnl_memory_desc_create_with_packed_encoding(
- &md, (int)adims.size(), adims.data(),
- convert_to_c(adata_type), nnz);
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a memory descriptor for packed "
- "sparse encoding");
- return desc {md};
- }
- /// Creates a memory descriptor for a scalar value that resides on the host.
- ///
- /// @param adata_type Data type of the scalar.
- /// @returns A memory descriptor for host-side scalar input.
- static desc host_scalar(data_type adata_type) {
- dnnl_memory_desc_t md = nullptr;
- error::wrap_c_api(dnnl_memory_desc_create_host_scalar(
- &md, convert_to_c(adata_type)),
- "could not create a memory descriptor describing host side "
- "scalar");
- return desc {md};
- }
- /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t
- /// handle. The resulting handle is not weak and the C handle will be
- /// destroyed during the destruction of the C++ object.
- ///
- /// @param md The C API memory descriptor.
- desc(dnnl_memory_desc_t md) : handle<dnnl_memory_desc_t>(md) {}
- /// Construct a memory descriptor from a binary blob.
- ///
- /// @param blob A binary blob previously queried from a memory descriptor.
- desc(const std::vector<uint8_t> &blob) {
- dnnl_memory_desc_t md = nullptr;
- error::wrap_c_api(
- dnnl_memory_desc_create_with_blob(&md, blob.data()),
- "could not create a memory descriptor from blob");
- reset(md);
- }
- /// Constructs a memory descriptor for a region inside an area
- /// described by this memory descriptor.
- //
- /// @param adims Sizes of the region.
- /// @param offsets Offsets to the region from the encompassing
- /// memory object in each dimension.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case a
- /// zero memory descriptor will be returned. This flag is optional
- /// and defaults to false.
- /// @returns A memory descriptor for the region.
- desc submemory_desc(const dims &adims, const dims &offsets,
- bool allow_empty = false) const {
- validate_dims(adims, get_ndims());
- validate_dims(offsets, get_ndims());
- dnnl_memory_desc_t sub_md = nullptr;
- dnnl_status_t status = dnnl_memory_desc_create_submemory(
- &sub_md, get(), adims.data(), offsets.data());
- if (!allow_empty)
- error::wrap_c_api(status, "could not construct a sub-memory");
- return desc(sub_md);
- }
- /// Constructs a memory descriptor by reshaping an existing one. The
- /// new memory descriptor inherits the data type. This operation is
- /// valid only for memory descriptors that have format_kind set to
- /// #dnnl::memory::format_kind::blocked or
- /// #dnnl::memory::format_kind::any.
- ///
- /// The operation ensures that the transformation of the physical memory
- /// format corresponds to the transformation of the logical dimensions.
- /// If such transformation is impossible, the function either throws an
- /// exception (default) or returns a zero memory descriptor depending on
- /// the `allow_empty` flag.
- ///
- /// The reshape operation can be described as a combination of the
- /// following basic operations:
- /// 1. Add a dimension of size `1`. This is always possible.
- /// 2. Remove a dimension of size `1`. This is possible only if the
- /// dimension has no padding (i.e.
- /// `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
- /// 3. Split a dimension into multiple ones. This is possible only if
- /// the product of all tensor dimensions stays constant and the
- /// dimension being split does not have padding (i.e.
- /// `padded_dims[dim] = dims[dim]`).
- /// 4. Join multiple consecutive dimensions into a single one. As in
- /// the cases above, this requires that the dimensions do not have
- /// padding and that the memory format is such that in physical
- /// memory these dimensions are dense and have the same order as
- /// their logical counterparts. This also assumes that these
- /// dimensions are not blocked.
- /// - Here, 'dense' means:
- /// `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
- /// - And 'same order' means:
- /// `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
- ///
- /// @warning
- /// Some combinations of physical memory layout and/or offsets or
- /// dimensions may result in a failure to make a reshape.
- ///
- /// @param adims New dimensions. The product of dimensions must
- /// remain constant.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case a
- /// zero memory descriptor will be returned. This flag is optional
- /// and defaults to false.
- /// @returns A new memory descriptor with new dimensions.
- desc reshape(const dims &adims, bool allow_empty = false) const {
- if (get_ndims()) validate_dims(adims, 1);
- dnnl_memory_desc_t out_md = nullptr;
- dnnl_status_t status = dnnl_memory_desc_reshape(
- &out_md, get(), (int)adims.size(), adims.data());
- if (!allow_empty)
- error::wrap_c_api(
- status, "could not reshape a memory descriptor");
- return desc(out_md);
- }
- /// Constructs a memory descriptor by permuting axes in an existing
- /// one.
- ///
- /// The physical memory layout representation is adjusted accordingly
- /// to maintain the consistency between the logical and physical parts
- /// of the memory descriptor. The new memory descriptor inherits the
- /// data type.
- ///
- /// The new memory descriptor inherits the data type. This operation is
- /// valid only for memory descriptors that have format_kind set to
- /// #dnnl::memory::format_kind::blocked or
- /// #dnnl::memory::format_kind::any.
- ///
- /// The logical axes will be permuted in the following manner:
- /// @code
- /// for (i = 0; i < get_ndims(); i++)
- /// new_desc.dims()[permutation[i]] = dims()[i];
- /// @endcode
- ///
- /// Example:
- /// @code
- /// std::vector<int> permutation = {1, 0}; // swap the first and
- /// // the second axes
- /// dnnl::memory::desc in_md(
- /// {2, 3}, data_type, memory::format_tag::ab);
- /// dnnl::memory::desc expect_out_md(
- /// {3, 2}, data_type, memory::format_tag::ba);
- ///
- /// assert(in_md.permute_axes(permutation) == expect_out_md);
- /// @endcode
- ///
- /// @param permutation Axes permutation.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case a
- /// zero memory descriptor will be returned. This flag is optional
- /// and defaults to false.
- /// @returns A new memory descriptor with new dimensions.
- desc permute_axes(const std::vector<int> &permutation,
- bool allow_empty = false) const {
- validate_dims(permutation, get_ndims());
- dnnl_memory_desc_t out_md = nullptr;
- dnnl_status_t status = dnnl_memory_desc_permute_axes(
- &out_md, get(), permutation.data());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not permute axes of a memory descriptor");
- return desc(out_md);
- }
- /// Returns a number of dimensions of the memory descriptor.
- ///
- /// @returns A number of dimensions.
- int get_ndims() const { return query_s32(query::ndims_s32); }
- /// Returns padded dimensions of the memory descriptor.
- ///
- /// @returns A copy of the padded dimensions vector.
- memory::dims get_padded_dims() const {
- return query_dims(query::padded_dims);
- }
- /// Returns padded offsets of the memory descriptor.
- ///
- /// @returns A copy of the padded offsets vector.
- memory::dims get_padded_offsets() const {
- return query_dims(query::padded_offsets);
- }
- /// Returns a submemory offset of the memory descriptor.
- ///
- /// @returns A submemory offset.
- memory::dim get_submemory_offset() const {
- dnnl_dim_t submemory_offset;
- dnnl_status_t status = dnnl_memory_desc_query(
- get(), dnnl_query_submemory_offset_s64, &submemory_offset);
- return status == dnnl_success ? submemory_offset : 0;
- }
- /// Returns strides of the memory descriptor.
- ///
- /// @note
- /// This API is only applicable to memory descriptors with format
- /// kind #dnnl_blocked.
- ///
- /// @returns A copy of the strides vector.
- /// @returns An empty #dnnl::memory::dims if the memory descriptor
- /// does not have strides.
- memory::dims get_strides() const { return query_dims(query::strides); }
- /// Returns a number of inner blocks of the memory descriptor.
- ///
- /// @note
- /// This API is only applicable to memory descriptors with format
- /// kind #dnnl_blocked.
- ///
- /// @returns A number of inner blocks.
- int get_inner_nblks() const {
- return query_s32(query::inner_nblks_s32);
- }
- /// Returns inner blocks of the memory descriptor.
- ///
- /// @note
- /// This API is only applicable to memory descriptors with format
- /// kind #dnnl_blocked.
- ///
- /// @returns A copy of the inner blocks vector.
- /// @returns An empty #dnnl::memory::dims if the memory descriptor
- /// does not have inner blocks.
- memory::dims get_inner_blks() const {
- return query_dims(query::inner_blks);
- }
- /// Returns inner indices of the memory descriptor.
- ///
- /// @note
- /// This API is only applicable to memory descriptors with format
- /// kind #dnnl_blocked.
- ///
- /// @returns A copy of the inner indices vector.
- /// @returns An empty #dnnl::memory::dims if the memory descriptor
- /// does not have inner indices.
- memory::dims get_inner_idxs() const {
- return query_dims(query::inner_idxs);
- }
- /// Returns number of handles.
- ///
- /// @returns A number of handles.
- int get_num_handles() const {
- int nhandles;
- dnnl_status_t status = dnnl_memory_desc_query_v2(
- get(), dnnl_query_num_handles_s32, 0, &nhandles);
- return status == dnnl_success ? nhandles : 0;
- }
- /// Returns a number of non-zero entries of the memory descriptor.
- ///
- /// @returns A number non-zero entries.
- dim get_nnz() const {
- dnnl_dim_t nnz;
- dnnl_status_t status = dnnl_memory_desc_query_v2(
- get(), dnnl_query_nnz_s64, 0, &nnz);
- return status == dnnl_success ? nnz : 0;
- }
- /// Returns the sparse encoding of the memory descriptor.
- ///
- /// @returns the sparse encoding kind.
- /// @sa @ref dev_guide_sparsity
- memory::sparse_encoding get_sparse_encoding() const {
- dnnl_sparse_encoding_t sparse_encoding;
- dnnl_status_t status = dnnl_memory_desc_query_v2(
- get(), dnnl_query_sparse_encoding, 0, &sparse_encoding);
- return status == dnnl_success
- ? static_cast<dnnl::memory::sparse_encoding>(
- sparse_encoding)
- : dnnl::memory::sparse_encoding::undef;
- }
- /// Returns the data type of the memory descriptor.
- ///
- /// @returns The data type.
- memory::data_type get_data_type(int index = 0) const {
- return query_data_type(query::data_type, index);
- }
- /// Returns the format kind of the memory descriptor.
- ///
- /// @returns the format kind.
- memory::format_kind get_format_kind() const {
- dnnl_format_kind_t format_kind;
- dnnl_status_t status = dnnl_memory_desc_query(
- get(), dnnl_query_format_kind, &format_kind);
- return status == dnnl_success
- ? static_cast<dnnl::memory::format_kind>(format_kind)
- : dnnl::memory::format_kind::undef;
- }
- /// Returns dimensions of the memory descriptor.
- ///
- /// Potentially expensive due to the data copy involved.
- /// @returns A copy of the dimensions vector.
- memory::dims get_dims() const { return query_dims(query::dims); }
- /// Returns size of the memory descriptor in bytes.
- /// @param index Data index. Defaults to 0.
- /// @returns The number of bytes required to allocate a memory buffer
- /// for data with a particular @p index described by this memory
- /// descriptor including the padding area.
- size_t get_size(int index = 0) const {
- return dnnl_memory_desc_get_size_v2(get(), index);
- }
- /// Returns a binary blob associated with the given memory descriptor
- /// @returns The memory descriptor blob associated with the memory descriptor
- std::vector<uint8_t> get_blob() {
- size_t size;
- dnnl_status_t status
- = dnnl_memory_desc_get_blob(nullptr, &size, get());
- error::wrap_c_api(
- status, "could not get memory descriptor blob size");
- std::vector<uint8_t> out_blob(size);
- status = dnnl_memory_desc_get_blob(out_blob.data(), &size, get());
- error::wrap_c_api(status, "could not get memory descriptor blob");
- return out_blob;
- }
- /// Checks whether the memory descriptor is zero (empty).
- /// @returns @c true if the memory descriptor describes an empty
- /// memory and @c false otherwise.
- bool is_zero() const { return get_ndims() == 0; }
- /// An equality operator.
- /// @param other Another memory descriptor.
- /// @returns Whether this and the other memory descriptors have
- /// the same format tag, dimensions, strides, blocking, etc.
- bool operator==(const desc &other) const {
- return dnnl_memory_desc_equal(get(), other.get()) != 0;
- }
- /// An inequality operator.
- /// @param other Another memory descriptor.
- /// @returns Whether this and the other memory descriptors describe
- /// different memory.
- bool operator!=(const desc &other) const { return !operator==(other); }
- private:
- memory::data_type query_data_type(query what, int index) const {
- dnnl_data_type_t data_type;
- dnnl_status_t status = dnnl_memory_desc_query_v2(
- get(), dnnl::convert_to_c(what), index, &data_type);
- return status == dnnl_success
- ? static_cast<dnnl::memory::data_type>(data_type)
- : dnnl::memory::data_type::undef;
- }
- int query_s32(query what) const {
- int res;
- dnnl_status_t status = dnnl_memory_desc_query(
- get(), dnnl::convert_to_c(what), &res);
- return status == dnnl_success ? res : 0;
- }
- memory::dims query_dims(query what) const {
- dnnl_dims_t *c_dims;
- dnnl_status_t status = dnnl_memory_desc_query(
- get(), dnnl::convert_to_c(what), &c_dims);
- const int ndims
- = (what == query::inner_idxs || what == query::inner_blks)
- ? get_inner_nblks()
- : get_ndims();
- return status == dnnl_success
- ? memory::dims(*c_dims, *c_dims + ndims)
- : memory::dims {};
- }
- };
- /// Default constructor.
- ///
- /// Constructs an empty memory object, which can be used to indicate
- /// absence of a parameter.
- memory() = default;
- /// Constructs a memory object.
- ///
- /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
- /// object will have the underlying buffer set. In this case, the buffer
- /// will be initialized as if #dnnl::memory::set_data_handle() had been
- /// called.
- ///
- /// @sa memory::set_data_handle()
- ///
- /// @param md Memory descriptor.
- /// @param aengine Engine to store the data on.
- /// @param handle Handle of the memory buffer to use.
- /// - A pointer to the user-allocated buffer. In this case the library
- /// doesn't own the buffer.
- /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
- /// allocate the buffer for the memory object. In this case the
- /// library owns the buffer.
- /// - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
- /// buffer.
- memory(const desc &md, const engine &aengine, void *handle)
- : memory(md, aengine, std::vector<void *> {handle}) {}
- /// Constructs a memory object with multiple handles.
- ///
- /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
- /// object will have the underlying buffer set. In this case, the buffer
- /// will be initialized as if #dnnl::memory::set_data_handle() had been
- /// called.
- ///
- /// @sa memory::set_data_handle()
- ///
- /// @param md Memory descriptor.
- /// @param aengine Engine to store the data on.
- /// @param handles Handles of the memory buffers to use.
- /// For each element of the @p handles vector the following applies:
- /// - A pointer to the user-allocated buffer. In this case the library
- /// doesn't own the buffer.
- /// - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
- /// allocate the buffer for the memory object. In this case the
- /// library owns the buffer.
- /// - #DNNL_MEMORY_NONE Instructs the library to skip allocation of the
- /// memory buffer.
- memory(const desc &md, const engine &aengine, std::vector<void *> handles) {
- dnnl_memory_t result;
- dnnl_status_t status = dnnl_memory_create_v2(&result, md.get(),
- aengine.get(), (int)handles.size(), handles.data());
- error::wrap_c_api(status, "could not create a memory object");
- reset(result);
- }
- /// Constructs a memory object.
- ///
- /// The underlying buffer(s) for the memory will be allocated by the
- /// library.
- /// @param md Memory descriptor.
- /// @param aengine Engine to store the data on.
- memory(const desc &md, const engine &aengine) {
- dnnl_status_t status;
- dnnl_memory_t result;
- const int nhandles = md.get_num_handles();
- std::vector<void *> handles(nhandles, DNNL_MEMORY_ALLOCATE);
- status = dnnl_memory_create_v2(&result, md.get(), aengine.get(),
- (int)handles.size(), handles.data());
- error::wrap_c_api(status, "could not create a memory object");
- reset(result);
- }
- /// Constructs a memory object that wraps a host-side scalar value.
- ///
- /// @note The scalar value is copied into the newly allocated memory storage,
- /// so the user does not need to manage the lifetime of the original scalar data.
- ///
- /// @tparam T Type of the scalar value.
- /// @param md Memory descriptor describing a scalar value residing on the host.
- /// @param value The scalar value to be wrapped by the memory object.
- ///
- /// @throws error if the memory object could not be created.
- template <typename T>
- memory(const desc &md, const T value) {
- dnnl_memory_t result;
- // Check that the data type of T matches the memory descriptor's data type
- // For host-side scalars, md.get_size() is data_type size
- if (sizeof(T) != md.get_size()) {
- DNNL_THROW_ERROR(dnnl_invalid_arguments,
- "scalar type size does not match memory descriptor data "
- "type size");
- } else {
- dnnl_status_t status = dnnl_memory_create_host_scalar(
- &result, md.get(), (void *)&value);
- error::wrap_c_api(status, "could not create a memory object");
- }
- reset(result);
- }
- /// Returns the associated memory descriptor.
- desc get_desc() const {
- const_dnnl_memory_desc_t cdesc;
- error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
- "could not get a memory descriptor from a memory object");
- dnnl_memory_desc_t cloned_md = nullptr;
- error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
- "could not clone a memory descriptor");
- return desc(cloned_md);
- }
- /// Returns the associated engine.
- engine get_engine() const {
- dnnl_engine_t c_engine;
- error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
- "could not get an engine from a memory object");
- return engine(c_engine, true);
- }
- /// Returns an underlying memory buffer that corresponds to the given index.
- ///
- /// On the CPU engine, or when using USM, this is a pointer to the
- /// allocated memory.
- void *get_data_handle(int index = 0) const {
- void *handle;
- error::wrap_c_api(dnnl_memory_get_data_handle_v2(get(), &handle, index),
- "could not get a native handle from a memory object");
- return handle;
- }
- /// Sets an underlying memory buffer that corresponds to the given index.
- ///
- /// @param handle Memory buffer to use. On the CPU engine or when USM is
- /// used, the memory buffer is a pointer to the actual data. For OpenCL
- /// it is a cl_mem. It must have at least
- /// #dnnl::memory::desc::get_size() bytes allocated.
- /// @param index Memory index to attach the buffer. Defaults to 0.
- void set_data_handle(void *handle, int index = 0) const {
- error::wrap_c_api(dnnl_memory_set_data_handle_v2(get(), handle, index),
- "could not set native handle of a memory object");
- }
- /// Returns the scalar value stored in the memory object as type T.
- ///
- /// @tparam T Type to cast the scalar value to.
- template <typename T>
- T get_host_scalar_value() const {
- const_dnnl_memory_desc_t cdesc;
- error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
- "could not get memory descriptor");
- if (sizeof(T) != dnnl_memory_desc_get_size_v2(cdesc, 0)) {
- DNNL_THROW_ERROR(dnnl_invalid_arguments,
- "scalar type size does not match memory descriptor data "
- "type size");
- }
- T value;
- error::wrap_c_api(dnnl_memory_get_host_scalar_value(get(), &value),
- "could not get host scalar value from a memory object");
- return value;
- }
- /// Sets the scalar value stored in the memory object.
- ///
- /// @note The scalar value is copied into the memory storage, so the user
- /// does not need to manage the lifetime of the original scalar data.
- ///
- /// @param value Pointer to the scalar value to set.
- template <typename T>
- void set_host_scalar_value(const T value) const {
- const_dnnl_memory_desc_t cdesc;
- error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
- "could not get memory descriptor from a memory object");
- if (sizeof(T) != dnnl_memory_desc_get_size_v2(cdesc, 0)) {
- DNNL_THROW_ERROR(dnnl_invalid_arguments,
- "scalar type size does not match memory descriptor data "
- "type size");
- }
- error::wrap_c_api(dnnl_memory_set_host_scalar_value(get(), &value),
- "could not set host scalar value to a memory object");
- }
- /// Maps a memory object and returns a host-side pointer to a memory
- /// buffer with a copy of its contents. The memory buffer corresponds to
- /// the given index.
- ///
- /// Mapping enables read/write directly from/to the memory contents for
- /// engines that do not support direct memory access.
- ///
- /// Mapping is an exclusive operation - a memory object cannot be used in
- /// other operations until it is unmapped via #dnnl::memory::unmap_data()
- /// call.
- ///
- /// @note
- /// Any primitives working with the memory should be completed before
- /// the memory is mapped. Use #dnnl::stream::wait() to synchronize the
- /// corresponding execution stream.
- ///
- /// @note
- /// The map_data and unmap_data functions are provided mainly for
- /// debug and testing purposes and their performance may be suboptimal.
- ///
- /// @tparam T Data type to return a pointer to.
- /// @param index Index of the buffer. Defaults to 0.
- /// @returns Pointer to the mapped memory.
- template <typename T = void>
- T *map_data(int index = 0) const {
- void *mapped_ptr;
- error::wrap_c_api(dnnl_memory_map_data_v2(get(), &mapped_ptr, index),
- "could not map memory object data");
- return static_cast<T *>(mapped_ptr);
- }
- /// Unmaps a memory object and writes back any changes made to the
- /// previously mapped memory buffer. The memory buffer corresponds to
- /// the given index.
- ///
- /// @note
- /// The map_data and unmap_data functions are provided mainly for
- /// debug and testing purposes and their performance may be
- /// suboptimal.
- ///
- /// @param mapped_ptr A pointer previously returned by
- /// #dnnl::memory::map_data().
- /// @param index Index of the buffer. Defaults to 0.
- void unmap_data(void *mapped_ptr, int index = 0) const {
- error::wrap_c_api(dnnl_memory_unmap_data_v2(get(), mapped_ptr, index),
- "could not unmap memory object data");
- }
- static dnnl_data_type_t convert_to_c(data_type adata_type) {
- return static_cast<dnnl_data_type_t>(adata_type);
- }
- static dnnl_format_tag_t convert_to_c(format_tag format) {
- return static_cast<dnnl_format_tag_t>(format);
- }
- };
- inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
- return a == memory::convert_to_c(b);
- }
- inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
- return !(a == b);
- }
- inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
- return b == a;
- }
- inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
- return !(a == b);
- }
- inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
- return a == memory::convert_to_c(b);
- }
- inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
- return !(a == b);
- }
- inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
- return b == a;
- }
- inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
- return !(a == b);
- }
- /// @} dnnl_api_memory
- /// @addtogroup dnnl_api_primitives
- /// @{
- /// @addtogroup dnnl_api_attributes Attributes
- ///
- /// A container for parameters that extend primitives behavior.
- ///
- /// @{
- /// @cond DO_NOT_DOCUMENT_THIS
- template <>
- struct handle_traits<dnnl_post_ops_t> {
- static dnnl_status_t destructor(dnnl_post_ops_t p) {
- return dnnl_post_ops_destroy(p);
- }
- };
- /// @endcond
- /// Post-ops.
- ///
- /// Post-ops are computations executed after the main primitive computations
- /// and are attached to the primitive via primitive attributes.
- ///
- /// @sa @ref dev_guide_attributes_post_ops
- ///
- struct post_ops : public handle<dnnl_post_ops_t> {
- using handle<dnnl_post_ops_t>::handle;
- /// Constructs an empty sequence of post-ops.
- post_ops() {
- dnnl_post_ops_t result;
- error::wrap_c_api(
- dnnl_post_ops_create(&result), "could not create post-ops");
- reset(result);
- }
- /// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t
- /// handle. The resulting handle is not weak and the C handle will be
- /// destroyed during the destruction of the C++ object.
- ///
- /// @param post_ops The C API post-ops primitive attribute.
- post_ops(dnnl_post_ops_t post_ops) : handle<dnnl_post_ops_t>(post_ops) {}
- /// Returns the number of post-ops entries.
- int len() const { return dnnl_post_ops_len(get()); }
- /// Returns the primitive kind of post-op at entry with a certain index.
- /// @param index Index of the post-op to return the kind for.
- /// @returns Primitive kind of the post-op at the specified index.
- primitive::kind kind(int index) const {
- error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
- "post-ops index is out of range");
- return static_cast<primitive::kind>(
- dnnl_post_ops_get_kind(get(), index));
- }
- /// Appends an accumulation (sum) post-op. Prior to accumulating the
- /// result, the previous value will be will be reduced by zero point
- /// @p zero_point and multiplied by a scaling factor @p scale.
- ///
- /// The kind of this post-op is #dnnl::primitive::kind::sum.
- ///
- /// This feature may improve performance for cases like dequantize the
- /// asymmetrically quantized sum's src1 tensor to f32 domain before
- /// performing the sum operation by subtracting @p zero_point before the
- /// scaling.
- ///
- /// In the simplest case when the accumulation is the only post-op,
- /// the computations will be `dst[:] := scale * (dst[:] - zero_point) +
- /// op(...)` instead of `dst[:] := op(...)`.
- ///
- /// If @p data_type is specified, the original dst tensor will be
- /// reinterpreted as a tensor with the provided data type. Because it is a
- /// reinterpretation, data_type and dst data type should have the same size.
- /// As a result, computations will be `dst[:] <- scale *
- /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of
- /// `dst[:] <- op(...)`.
- ///
- /// @note
- /// This post-op executes in-place and does not change the
- /// destination layout.
- ///
- /// @param scale Scaling factor.
- /// @param zero_point Zero point.
- /// @param data_type Data type.
- void append_sum(float scale = 1.f, int32_t zero_point = 0,
- memory::data_type data_type = memory::data_type::undef) {
- error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point,
- memory::convert_to_c(data_type)),
- "could not append a sum post-op");
- }
- /// Returns the parameters of an accumulation (sum) post-op.
- ///
- /// @param index Index of the sum post-op.
- /// @param scale Scaling factor of the sum post-op.
- void get_params_sum(int index, float &scale) const {
- error::wrap_c_api(dnnl_post_ops_get_params_sum(
- get(), index, &scale, nullptr, nullptr),
- "could not get parameters of a sum post-op");
- }
- /// Returns the parameters of an accumulation (sum) post-op.
- ///
- /// @param index Index of the sum post-op.
- /// @param scale Scaling factor of the sum post-op.
- /// @param data_type Data type of the sum post-op.
- void get_params_sum(
- int index, float &scale, memory::data_type &data_type) const {
- dnnl_data_type_t c_data_type;
- error::wrap_c_api(dnnl_post_ops_get_params_sum(
- get(), index, &scale, nullptr, &c_data_type),
- "could not get parameters of a sum post-op");
- data_type = static_cast<memory::data_type>(c_data_type);
- }
- /// Returns the parameters of an accumulation (sum) post-op.
- ///
- /// @param index Index of the sum post-op.
- /// @param scale Scaling factor of the sum post-op.
- /// @param zero_point Single scalar int32_t value of zeropoint.
- /// @param data_type Data type of the sum post-op.
- void get_params_sum(int index, float &scale, int32_t &zero_point,
- memory::data_type &data_type) const {
- dnnl_data_type_t c_data_type;
- error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale,
- &zero_point, &c_data_type),
- "could not get parameters of a sum post-op");
- data_type = static_cast<memory::data_type>(c_data_type);
- }
- /// Appends an elementwise post-op.
- ///
- /// The kind of this post-op is #dnnl::primitive::kind::eltwise.
- ///
- /// In the simplest case when the elementwise is the only post-op, the
- /// computations would be `dst[:] := eltwise_op (op(...))` instead
- /// of `dst[:] <- op(...)`, where eltwise_op is configured with the given
- /// parameters.
- ///
- /// @param aalgorithm Elementwise algorithm.
- /// @param alpha Alpha parameter for the elementwise algorithm.
- /// @param beta Beta parameter for the elementwise algorithm.
- void append_eltwise(algorithm aalgorithm, float alpha, float beta) {
-     // Translate the C++ algorithm to its C counterpart, then append.
-     const auto c_alg = convert_to_c(aalgorithm);
-     const auto status
-             = dnnl_post_ops_append_eltwise(get(), c_alg, alpha, beta);
-     error::wrap_c_api(status, "could not append an elementwise post-op");
- }
- /// Returns parameters of an elementwise post-op.
- ///
- /// @param index Index of the post-op.
- /// @param aalgorithm Output elementwise algorithm kind.
- /// @param alpha Output alpha parameter for the elementwise algorithm.
- /// @param beta Output beta parameter for the elementwise algorithm.
- void get_params_eltwise(
-         int index, algorithm &aalgorithm, float &alpha, float &beta) const {
-     // alpha/beta are filled in place; the algorithm kind is received as
-     // a C enum and converted afterwards.
-     dnnl_alg_kind_t raw_alg;
-     const auto status = dnnl_post_ops_get_params_eltwise(
-             get(), index, &raw_alg, &alpha, &beta);
-     error::wrap_c_api(
-             status, "could not get parameters of an elementwise post-op");
-     aalgorithm = static_cast<dnnl::algorithm>(raw_alg);
- }
- /// Appends a depthwise post-op convolution.
- ///
- /// This post-op can only be fused with a 2D 1x1 convolution (convolution
- /// with weights spatial dimension equal to 1 i.e., kh=kw=1).
- ///
- /// The kind of this post-op is #dnnl_convolution.
- ///
- /// The number of outputs for primitive remain same as before. The output
- /// spatial size can be derived as below:
- ///
- /// output_height = ceil(output_height_1x1_convolution, stride)
- /// output_width = ceil(output_width_1x1_convolution, stride)
- ///
- /// See @ref dev_guide_attributes_post_ops_depthwise and
- /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
- ///
- /// @param weights_data_type Weights data type of depthwise post-op
- /// @param bias_data_type Bias data type of depthwise post-op
- /// @param dst_data_type Output data type of depthwise post-op
- /// @param kernel_size Size of kernel of depthwise post-op
- /// @param stride_size Size of stride of depthwise post-op
- /// @param padding_l_size Size of left and top paddings of depthwise post-op
- void append_dw(memory::data_type weights_data_type,
-         memory::data_type bias_data_type, memory::data_type dst_data_type,
-         memory::dim kernel_size, memory::dim stride_size,
-         memory::dim padding_l_size) {
-     // Convert the three C++ data types to their C equivalents up front.
-     const auto c_wei_dt = memory::convert_to_c(weights_data_type);
-     const auto c_bias_dt = memory::convert_to_c(bias_data_type);
-     const auto c_dst_dt = memory::convert_to_c(dst_data_type);
-     const auto status = dnnl_post_ops_append_dw(get(), c_wei_dt, c_bias_dt,
-             c_dst_dt, kernel_size, stride_size, padding_l_size);
-     error::wrap_c_api(status, "could not append depthwise post-op");
- }
- /// Returns the parameters of a depthwise post-op.
- ///
- /// @param index Index of the depthwise post-op.
- /// @param weights_data_type Weights data type of depthwise post-op
- /// @param bias_data_type Bias data type of depthwise post-op
- /// @param dst_data_type Output data type of depthwise post-op
- /// @param kernel_size Size of kernel of depthwise post-op
- /// @param stride_size Size of stride of depthwise post-op
- /// @param padding_l_size Size of left and top paddings of depthwise post-op
- void get_params_dw(int index, memory::data_type &weights_data_type,
-         memory::data_type &bias_data_type, memory::data_type &dst_data_type,
-         memory::dim &kernel_size, memory::dim &stride_size,
-         memory::dim &padding_l_size) const {
-     // C-typed receivers; the caller's outputs are written only after the
-     // status of the query has been passed to error::wrap_c_api.
-     dnnl_data_type_t raw_wei_dt;
-     dnnl_data_type_t raw_bias_dt;
-     dnnl_data_type_t raw_dst_dt;
-     dnnl_dim_t raw_kernel;
-     dnnl_dim_t raw_stride;
-     dnnl_dim_t raw_pad_l;
-     const auto status = dnnl_post_ops_get_params_dw(get(), index,
-             &raw_wei_dt, &raw_bias_dt, &raw_dst_dt, &raw_kernel,
-             &raw_stride, &raw_pad_l);
-     error::wrap_c_api(
-             status, "could not get parameters of depthwise post-op");
-     // Data types need an enum conversion.
-     weights_data_type = static_cast<memory::data_type>(raw_wei_dt);
-     bias_data_type = static_cast<memory::data_type>(raw_bias_dt);
-     dst_data_type = static_cast<memory::data_type>(raw_dst_dt);
-     // Sizes are assigned as-is.
-     kernel_size = raw_kernel;
-     stride_size = raw_stride;
-     padding_l_size = raw_pad_l;
- }
- /// Appends a binary post-op.
- ///
- /// This post operation is categorized as #dnnl_binary.
- ///
- /// In the simplest case when the binary is the only post operation, the
- /// computations will be:
- ///
- /// dst[:] <- binary_op (dst[:], another_input[:])
- ///
- /// where binary_op is configured with the given parameters. binary_op
- /// supports broadcast semantics for a second operand.
- ///
- /// @param aalgorithm Binary algorithm for the post-op.
- /// @param src1_desc Memory descriptor of a second operand.
- void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
-     // The C API takes the C algorithm kind and the raw descriptor handle.
-     const auto c_alg = convert_to_c(aalgorithm);
-     const auto status
-             = dnnl_post_ops_append_binary(get(), c_alg, src1_desc.get());
-     error::wrap_c_api(status, "could not append a binary post-op");
- }
- /// Appends a binary post-op with ternary operators.
- ///
- /// This post operation is categorized as #dnnl_binary.
- ///
- /// In the simplest case when this is the only post operation, the
- /// computations will be:
- ///
- /// dst[:] <- binary_op (dst[:], another_input1[:], another_input2[:])
- ///
- /// where binary_op is configured with the given parameters. binary_op
- /// supports broadcast semantics only for the second operand and not for the
- /// third operand.
- ///
- /// @param aalgorithm Binary algorithm for the post-op.
- /// @param src1_desc Memory descriptor of the second operand.
- /// @param src2_desc Memory descriptor of the third operand. If the specified
- /// algorithm is not one that requires a ternary input, src2_desc will be
- /// ignored.
- void append_binary(algorithm aalgorithm, const memory::desc &src1_desc,
-         const memory::desc &src2_desc) {
-     // Same as the two-operand form, but goes through the _v2 entry point
-     // that also accepts the third operand's descriptor.
-     const auto c_alg = convert_to_c(aalgorithm);
-     const auto status = dnnl_post_ops_append_binary_v2(
-             get(), c_alg, src1_desc.get(), src2_desc.get());
-     error::wrap_c_api(status,
-             "could not append a binary post-op with ternary operators");
- }
- /// Returns the parameters of a binary post-op.
- ///
- /// @param index Index of the binary post-op.
- /// @param aalgorithm Output binary algorithm kind.
- /// @param src1_desc Output memory descriptor of a second operand.
- void get_params_binary(
-         int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
-     dnnl_alg_kind_t raw_alg;
-     const_dnnl_memory_desc_t borrowed_md;
-     const auto query_status = dnnl_post_ops_get_params_binary(
-             get(), index, &raw_alg, &borrowed_md);
-     error::wrap_c_api(
-             query_status, "could not get parameters of a binary post-op");
-     aalgorithm = static_cast<dnnl::algorithm>(raw_alg);
-
-     // The query yields a const descriptor handle; clone it so the
-     // memory::desc handed back wraps its own copy.
-     dnnl_memory_desc_t md_copy = nullptr;
-     const auto clone_status = dnnl_memory_desc_clone(&md_copy, borrowed_md);
-     error::wrap_c_api(clone_status, "could not clone a memory descriptor");
-     src1_desc = memory::desc(md_copy);
- }
- /// Returns the parameters of a binary post-op with ternary operators.
- ///
- /// @param index Index of the binary post-op.
- /// @param aalgorithm Output binary algorithm kind.
- /// @param src1_desc Output memory descriptor of the second operand.
- /// @param src2_desc Output memory descriptor of the third operand.
- void get_params_binary(int index, algorithm &aalgorithm,
-         memory::desc &src1_desc, memory::desc &src2_desc) const {
-     dnnl_alg_kind_t raw_alg;
-     const_dnnl_memory_desc_t borrowed_md1, borrowed_md2;
-     const auto query_status = dnnl_post_ops_get_params_binary_v2(
-             get(), index, &raw_alg, &borrowed_md1, &borrowed_md2);
-     error::wrap_c_api(query_status,
-             "could not get parameters of a binary post-op with ternary "
-             "operators");
-     aalgorithm = static_cast<dnnl::algorithm>(raw_alg);
-
-     dnnl_memory_desc_t md1_copy = nullptr;
-     dnnl_memory_desc_t md2_copy = nullptr;
-
-     // Second operand: clone, then publish to the caller.
-     const auto clone1_status
-             = dnnl_memory_desc_clone(&md1_copy, borrowed_md1);
-     error::wrap_c_api(clone1_status, "could not clone a memory descriptor");
-     src1_desc = memory::desc(md1_copy);
-
-     // Third operand: same steps, performed after src1_desc is assigned.
-     const auto clone2_status
-             = dnnl_memory_desc_clone(&md2_copy, borrowed_md2);
-     error::wrap_c_api(clone2_status, "could not clone a memory descriptor");
-     src2_desc = memory::desc(md2_copy);
- }
- /// Appends a prelu forward post-op.
- ///
- /// The kind of this post-op is #dnnl::primitive::kind::prelu.
- ///
- /// The post-op can be defined as:
- ///
- /// dst[:] <- prelu(dst[:], weights[:])
- /// prelu:
- /// dst[:] <- dst[:] if dst[:] > 0
- /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
- ///
- ///
- /// Example usage:
- /// @code
- /// int mb = 32, oc = 32,
- /// oh = 14, ow = 14; // convolution output params
- /// // unique weights per output channel
- /// vector<float> weights = { ... };
- /// int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
- ///
- /// // construct a convolution descriptor
- /// dnnl::convolution::desc conv_d;
- ///
- /// dnnl::primitive_attr attr;
- /// attr.append_prelu(1 << oc_dim);
- ///
- /// dnnl::primitive_desc conv_pd(conv_d, attr, engine);
- /// memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data());
- ///
- /// std::unordered_map<int, memory> conv_args;
- ///
- /// conv_args.insert(
- /// {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS, prelu_weights})
- /// @endcode
- ///
- /// @note
- /// The order of dimensions does not depend on how elements are laid
- /// out in memory. For example:
- /// - for a 2D CNN activations tensor the order is always (n, c)
- /// - for a 4D CNN activations tensor the order is always (n, c, h, w)
- /// - for a 5D CNN weights tensor the order is always
- /// (g, oc, ic, kh, kw)
- ///
- /// Prelu weights tensor is passed in runtime execution phase. Prelu
- /// weights tensor data type is implicitly assumed as f32 using plain
- /// layout (a, ab, acb, acdb, acdeb).
- ///
- /// @param mask Defines the correspondence between the output tensor
- /// dimensions and the prelu weights tensor. The set i-th bit indicates
- /// that a dedicated weights value is used for each index along that
- /// dimension. Set the mask to 0 to use a common weights value
- /// for the whole output tensor.
- void append_prelu(int mask) {
-     // Forward the mask unchanged to the C API and report its status.
-     const auto status = dnnl_post_ops_append_prelu(get(), mask);
-     error::wrap_c_api(status, "could not append a prelu post-op");
- }
- /// Returns the parameters of a prelu post-op.
- ///
- /// @param index Index of the prelu post-op.
- /// @param mask Weights mask of prelu post-op.
- void get_params_prelu(int index, int &mask) const {
-     // The mask is written directly into the caller's reference by the
-     // C API. The failure message names the prelu post-op (it previously
-     // said "binary", copied from the binary getter).
-     error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
-             "could not get parameters of a prelu post-op");
- }
- };
- /// @cond DO_NOT_DOCUMENT_THIS
- template <>
- struct handle_traits<dnnl_primitive_attr_t> {
-     // Forwards destruction of the C primitive-attributes handle to
-     // dnnl_primitive_attr_destroy and returns the status it reports.
-     static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
-         const dnnl_status_t status = dnnl_primitive_attr_destroy(p);
-         return status;
-     }
- };
- /// @endcond
- /// Primitive attributes.
- ///
- /// @sa @ref dev_guide_attributes
- struct primitive_attr : public handle<dnnl_primitive_attr_t> {
- using handle<dnnl_primitive_attr_t>::handle;
- /// Constructs default (empty) primitive attributes.
- primitive_attr() {
-     // Create the C attributes object first, then let this handle take it.
-     dnnl_primitive_attr_t c_attr;
-     const auto status = dnnl_primitive_attr_create(&c_attr);
-     error::wrap_c_api(status, "could not create primitive attribute");
-     reset(c_attr);
- }
- /// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
- /// handle. The resulting handle is not weak and the C handle will be
- /// destroyed during the destruction of the C++ object.
- ///
- /// @param attr The C API primitive attributes.
- primitive_attr(dnnl_primitive_attr_t attr)
-     // Hands the C handle straight to the base handle constructor.
-     : handle<dnnl_primitive_attr_t>(attr) {
- }
- /// Returns the parameters of a dropout attribute.
- ///
- /// @param mask_desc Output memory descriptor of a dropout mask.
- void get_dropout(memory::desc &mask_desc) const {
-     const_dnnl_memory_desc_t borrowed_md;
-     const auto query_status
-             = dnnl_primitive_attr_get_dropout(get(), &borrowed_md);
-     error::wrap_c_api(
-             query_status, "could not get parameters of a dropout attribute");
-
-     // Clone the const descriptor so the returned memory::desc wraps its
-     // own copy.
-     dnnl_memory_desc_t md_copy = nullptr;
-     const auto clone_status = dnnl_memory_desc_clone(&md_copy, borrowed_md);
-     error::wrap_c_api(clone_status, "could not clone a memory descriptor");
-     mask_desc = memory::desc(md_copy);
- }
- /// Sets dropout probability.
- ///
- /// @param mask_desc Output memory descriptor of a dropout mask.
- void set_dropout(const memory::desc &mask_desc) {
-     // Pass the underlying C descriptor handle of the mask to the C API.
-     const auto status
-             = dnnl_primitive_attr_set_dropout(get(), mask_desc.get());
-     error::wrap_c_api(status, "could not set dropout primitive attribute");
- }
- /// Returns the fpmath mode
- fpmath_mode get_fpmath_mode() const {
-     dnnl_fpmath_mode_t c_mode;
-     const auto status = dnnl_primitive_attr_get_fpmath_mode(get(), &c_mode);
-     error::wrap_c_api(
-             status, "could not get fpmath mode primitive attribute");
-     // Convert the C value to the C++ fpmath_mode type.
-     return static_cast<fpmath_mode>(c_mode);
- }
- /// Returns the fpmath mode
- ///
- /// @param mode Specified fpmath mode.
- /// @param apply_to_int Use floating-point arithmetic for integer primitives.
- void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const {
-     dnnl_fpmath_mode_t raw_mode;
-     int raw_flag;
-     const auto status = dnnl_primitive_attr_get_fpmath_mode_v2(
-             get(), &raw_mode, &raw_flag);
-     error::wrap_c_api(
-             status, "could not get fpmath mode primitive attribute");
-     mode = static_cast<fpmath_mode>(raw_mode);
-     // The C API reports the flag as an int; any non-zero value is true.
-     apply_to_int = (raw_flag != 0);
- }
- /// Sets fpmath mode.
- ///
- /// @param mode Specified fpmath mode.
- /// @param apply_to_int Boolean. Use of floating-point arithmetic for integer primitives.
- void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) {
-     // Both values go to the _v2 setter; the mode is converted to its C
-     // representation first.
-     const auto c_mode = dnnl::convert_to_c(mode);
-     const auto status = dnnl_primitive_attr_set_fpmath_mode_v2(
-             get(), c_mode, apply_to_int);
-     error::wrap_c_api(
-             status, "could not set fpmath mode primitive attribute");
- }
- /// Returns the accumulation mode
- accumulation_mode get_accumulation_mode() const {
-     dnnl_accumulation_mode_t c_mode;
-     const auto status
-             = dnnl_primitive_attr_get_accumulation_mode(get(), &c_mode);
-     error::wrap_c_api(
-             status, "could not get accumulation mode primitive attribute");
-     // Convert the C value to the C++ accumulation_mode type.
-     return static_cast<accumulation_mode>(c_mode);
- }
- /// Sets accumulation mode.
- ///
- /// @param mode Specified accumulation mode.
- void set_accumulation_mode(accumulation_mode mode) {
-     // Convert to the C representation, then forward to the C setter.
-     const auto c_mode = dnnl::convert_to_c(mode);
-     const auto status
-             = dnnl_primitive_attr_set_accumulation_mode(get(), c_mode);
-     error::wrap_c_api(
-             status, "could not set accumulation mode primitive attribute");
- }
- /// Returns the deterministic attribute value
- bool get_deterministic() const {
-     int raw_flag;
-     const auto status
-             = dnnl_primitive_attr_get_deterministic(get(), &raw_flag);
-     error::wrap_c_api(
-             status, "could not get deterministic primitive attribute");
-     // The C API reports the flag as an int; any non-zero value is true.
-     return raw_flag != 0;
- }
- /// Sets deterministic attribute value
- ///
- /// @param value Specified deterministic mode.
- void set_deterministic(bool value) {
-     // The C setter takes the flag as an int.
-     const int c_flag = static_cast<int>(value);
-     const auto status = dnnl_primitive_attr_set_deterministic(get(), c_flag);
-     error::wrap_c_api(
-             status, "could not set deterministic primitive attribute");
- }
- /// Returns the rounding mode attribute value
- ///
- /// @param arg Argument for which rounding mode query applies.
- /// @returns The rounding mode applied to the specified argument.
- rounding_mode get_rounding_mode(int arg) const {
-     dnnl_rounding_mode_t c_mode;
-     const auto status = dnnl_primitive_attr_get_rounding(get(), arg, &c_mode);
-     error::wrap_c_api(
-             status, "could not get rounding mode primitive attribute");
-     // Convert the C value to the C++ rounding_mode type.
-     return static_cast<rounding_mode>(c_mode);
- }
- /// Sets the rounding mode attribute value for a given argument
- ///
- /// @param arg Argument for which to set rounding mode.
- /// @param mode Rounding mode to apply.
- void set_rounding_mode(int arg, rounding_mode mode) {
-     // Convert to the C representation, then forward with the argument id.
-     const auto c_mode = convert_to_c(mode);
-     const auto status = dnnl_primitive_attr_set_rounding(get(), arg, c_mode);
-     error::wrap_c_api(
-             status, "could not set rounding mode primitive attribute");
- }
- /// Returns the scratchpad mode.
- scratchpad_mode get_scratchpad_mode() const {
-     dnnl_scratchpad_mode_t c_mode;
-     const auto status
-             = dnnl_primitive_attr_get_scratchpad_mode(get(), &c_mode);
-     error::wrap_c_api(
-             status, "could not get scratchpad mode primitive attribute");
-     // Convert the C value to the C++ scratchpad_mode type.
-     return static_cast<scratchpad_mode>(c_mode);
- }
- /// Sets scratchpad mode.
- ///
- /// @param mode Specified scratchpad mode.
- void set_scratchpad_mode(scratchpad_mode mode) {
-     // Convert to the C representation, then forward to the C setter.
-     const auto c_mode = dnnl::convert_to_c(mode);
-     const auto status
-             = dnnl_primitive_attr_set_scratchpad_mode(get(), c_mode);
-     error::wrap_c_api(
-             status, "could not set scratchpad mode primitive attribute");
- }
- /// Sets scaling factors for primitive operations for a given memory
- /// argument. The scaling factors must be passed at execution time
- /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
- ///
- /// @sa dnnl_primitive_attr_set_scales_mask
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param mask Scaling factors correspondence mask that defines the
- /// correspondence between the tensor dimensions and the @p scales
- /// vector. The set i-th bit indicates that a dedicated scaling factor
- /// is used for each index along that dimension. Set the mask to 0 to
- /// use a common scaling factor for the whole output tensor.
- void set_scales_mask(int arg, int mask) {
-     // Argument id and mask are forwarded unchanged to the C API.
-     const auto status
-             = dnnl_primitive_attr_set_scales_mask(get(), arg, mask);
-     error::wrap_c_api(status, "could not set scales primitive attribute");
- }
- /// Sets scaling factors for primitive operations for a given memory
- /// argument. The scaling factors must be passed at execution time
- /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
- ///
- /// @note If `is_on_host` is true, sets a single host-side scalar scaling
- /// factor for the specified memory argument. The scaling factor should be
- /// passed as a host scalar memory object at execution time with index
- /// #DNNL_ARG_ATTR_SCALES | arg.
- ///
- /// @sa dnnl_primitive_attr_set_scales_v2
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param mask Scales correspondence mask that defines the
- /// correspondence between the tensor dimensions and the @p
- /// scales vector. The set i-th bit indicates that a dedicated
- /// scale is used for each index along that dimension. Set the
- /// mask to 0 to use a common scale for the whole output tensor.
- /// @param groups Scaling factors correspondence groups that define the
- /// correspondence between the tensor dimensions and the scales array.
- /// The set i-th dimension indicates a number of groups of scaling
- /// factors used for that logical dimension in a memory indicated by @p arg.
- /// @param data_type Scaling factors data_type.
- /// @param is_on_host Indicates whether the scaling factor is a host-side scalar.
- void set_scales(int arg, int mask, const memory::dims &groups,
-         memory::data_type data_type = memory::data_type::f32,
-         bool is_on_host = false) {
-     // The C API receives the groups as a (count, pointer) pair.
-     const int group_ndims = static_cast<int>(groups.size());
-     const auto c_dt = memory::convert_to_c(data_type);
-     const auto status = dnnl_primitive_attr_set_scales_v2(get(), arg, mask,
-             group_ndims, groups.data(), c_dt, is_on_host);
-     error::wrap_c_api(status, "could not set scales primitive attribute");
- }
- /// Sets a single host-side scalar scaling
- /// factor for the specified memory argument. The scaling factor should be
- /// passed as a host scalar memory object at execution time with index
- /// #DNNL_ARG_ATTR_SCALES | arg.
- ///
- /// @note Using this API to set the scaling factor implies that the scales
- /// attribute has `mask == 0` and an empty groups vector.
- ///
- /// @sa dnnl_primitive_attr_set_scales_v2
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param data_type Scaling factors data_type.
- void set_host_scale(
-         int arg, memory::data_type data_type = memory::data_type::f32) {
-     // Same C entry point as set_scales(), called with a zero mask, no
-     // groups, and the trailing host flag set to 1.
-     const auto c_dt = memory::convert_to_c(data_type);
-     const auto status = dnnl_primitive_attr_set_scales_v2(
-             get(), arg, 0, 0, nullptr, c_dt, 1);
-     error::wrap_c_api(status, "could not set scales primitive attribute");
- }
- /// Sets zero points for primitive operations for a given memory argument.
- /// The zero points must be passed at execution time as an argument with
- /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
- ///
- /// @sa dnnl_primitive_attr_set_zero_points_mask
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param mask Zero point correspondence mask that defines the
- /// correspondence between the tensor dimensions and the @p
- /// zero_points vector. The set i-th bit indicates that a dedicated
- /// zero point is used for each index along that dimension. Set the
- /// mask to 0 to use a common zero point for the whole output tensor.
- void set_zero_points_mask(int arg, int mask) {
-     // Argument id and mask are forwarded unchanged to the C API.
-     const auto status
-             = dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask);
-     error::wrap_c_api(
-             status, "could not set zero points primitive attribute");
- }
- /// Sets zero points for primitive operations for a given memory argument.
- /// The zero points must be passed at execution time as an argument with
- /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
- ///
- /// @note If `is_on_host` is true, sets a single host-side zero point
- /// for the specified memory argument. The zero point should be
- /// passed as a host scalar memory object at execution time with index
- /// #DNNL_ARG_ATTR_ZERO_POINTS | arg.
- ///
- /// @sa dnnl_primitive_attr_set_zero_points_v2
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param mask Zero point correspondence mask that defines the
- /// correspondence between the tensor dimensions and the zero points
- /// vector. The set i-th bit indicates that a dedicated zero point is
- /// used for each index along that dimension. Set the mask to 0 to use
- /// a common zero point for the whole output tensor.
- /// @param groups Zero point factors correspondence groups that define the
- /// correspondence between the tensor dimensions and the zero points
- /// array.
- /// The set i-th dimension indicates a number of groups of zero point
- /// factors used for that logical dimension in a memory indicated by
- /// @p arg.
- /// @param data_type Zero point factors data_type.
- /// @param is_on_host Indicates whether the zero point is a host-side scalar.
- void set_zero_points(int arg, int mask, const memory::dims &groups,
-         memory::data_type data_type = memory::data_type::s32,
-         bool is_on_host = false) {
-     // The C API receives the groups as a (count, pointer) pair.
-     const int group_ndims = static_cast<int>(groups.size());
-     const auto c_dt = memory::convert_to_c(data_type);
-     const auto status = dnnl_primitive_attr_set_zero_points_v2(get(), arg,
-             mask, group_ndims, groups.data(), c_dt, is_on_host);
-     error::wrap_c_api(
-             status, "could not set zero points primitive attribute");
- }
- /// Sets a single host-side zero point for the specified memory argument.
- /// The zero point should be passed as a host scalar memory object at
- /// execution time with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
- ///
- /// @note Using this API to set the zero point implies that the zero
- /// point attribute has `mask == 0` and an empty groups vector.
- ///
- /// @sa dnnl_primitive_attr_set_zero_points_v2
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param data_type Zero point data type.
- void set_host_zero_point(
-         int arg, memory::data_type data_type = memory::data_type::s32) {
-     // Same C entry point as set_zero_points(), called with a zero mask,
-     // no groups, and the trailing host flag set to 1.
-     const auto c_dt = memory::convert_to_c(data_type);
-     const auto status = dnnl_primitive_attr_set_zero_points_v2(
-             get(), arg, 0, 0, nullptr, c_dt, 1);
-     error::wrap_c_api(
-             status, "could not set zero points primitive attribute");
- }
- /// Sets precomputed reductions for primitive operations for a given memory
- /// argument. The precomputed reductions must be passed at execution time as
- /// an argument with index #DNNL_ARG_ATTR_PRECOMPUTED_REDUCTIONS | arg.
- ///
- /// @sa dnnl_primitive_attr_set_precomputed_reductions
- ///
- /// @param arg Parameter argument index as passed to the
- /// primitive::execute() call.
- /// @param mask Precomputed reductions correspondence mask that defines the
- /// correspondence between the tensor dimensions and the precomputed
- /// reductions vector. The set i-th bit indicates that a dedicated
- /// precomputed reduction point is used for each index along that
- /// dimension.
- /// @param groups Precomputed reduction factors correspondence groups that
- /// define the correspondence between the tensor dimensions and the
- /// precomputed reductions array.
- /// The set i-th dimension indicates a number of groups of precomputed
- /// reduction factors used for that logical dimension in a memory
- /// indicated by @p arg.
- /// @param data_type Precomputed reduction factors data_type.
- void set_precomputed_reductions(int arg, int mask,
-         const memory::dims &groups,
-         memory::data_type data_type = memory::data_type::s32) {
-     // The C API receives the groups as a (count, pointer) pair.
-     const int group_ndims = static_cast<int>(groups.size());
-     const auto c_dt = memory::convert_to_c(data_type);
-     const auto status = dnnl_primitive_attr_set_precomputed_reductions(
-             get(), arg, mask, group_ndims, groups.data(), c_dt);
-     error::wrap_c_api(status,
-             "could not set precomputed reductions primitive attribute");
- }
- /// Returns post-ops previously set via set_post_ops().
- ///
- /// @returns Post-ops.
- post_ops get_post_ops() const {
-     // Step 1: obtain the const post-ops handle from the attributes.
-     const_dnnl_post_ops_t borrowed_ops;
-     const auto query_status
-             = dnnl_primitive_attr_get_post_ops(get(), &borrowed_ops);
-     error::wrap_c_api(
-             query_status, "could not get post-ops primitive attribute");
-
-     // Step 2: clone it so the returned post_ops object wraps its own copy.
-     dnnl_post_ops_t ops_copy;
-     const auto clone_status = dnnl_post_ops_clone(&ops_copy, borrowed_ops);
-     error::wrap_c_api(
-             clone_status, "could not clone post-ops primitive attribute");
-     return post_ops(ops_copy);
- }
- /// Sets post-ops.
- ///
- /// @note
- /// There is no way to check whether the post-ops would be supported
- /// by the target primitive. Any error will be reported
- /// by the respective primitive descriptor constructor.
- ///
- /// @param ops Post-ops object to copy post-ops from.
- void set_post_ops(const post_ops &ops) {
-     // Pass the underlying C post-ops handle to the C setter.
-     const auto status = dnnl_primitive_attr_set_post_ops(get(), ops.get());
-     error::wrap_c_api(status, "could not set post-ops primitive attribute");
- }
- /// Sets quantization scale and shift parameters for RNN data tensors.
- ///
- /// For performance reasons, the low-precision configuration of the RNN
- /// primitives expect input activations to have the unsigned 8-bit integer
- /// data type. The scale and shift parameters are used to quantize
- /// floating-point data to unsigned integer and must be passed to the RNN
- /// primitive using attributes.
- ///
- /// The quantization formula is `scale * data + shift`.
- ///
- /// Example usage:
- /// @code
- /// // RNN parameters
- /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
- /// // Activations quantization parameters
- /// float scale = 63.f, shift = 64.f;
- ///
- /// primitive_attr attr;
- ///
- /// // Set scale and shift for int8 quantization of activation
- /// attr.set_rnn_data_qparams(scale, shift);
- ///
- /// // Create an RNN primitive descriptor.
- /// vanilla_rnn_forward::primitive_desc rnn_d(
- /// engine, /* arguments */, attr);
- /// @endcode
- ///
- /// @note
- /// Quantization scale and shift are common for src_layer, src_iter,
- /// dst_iter, and dst_layer.
- ///
- /// @param scale The value to scale the data by.
- /// @param shift The value to shift the data by.
- void set_rnn_data_qparams(float scale, float shift) {
-     // Scale and shift are forwarded unchanged to the C API.
-     const auto status
-             = dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift);
-     error::wrap_c_api(status,
-             "could not set RNN data quantization parameters primitive "
-             "attribute");
- }
- /// Returns the quantization scale and shift parameters for RNN data
- /// tensors.
- ///
- /// @note
- /// Quantization scale and shift are common for src_layer, src_iter,
- /// dst_iter, and dst_layer.
- ///
- /// @param scale The value to scale the data by.
- /// @param shift The value to shift the data by.
- void get_rnn_data_qparams(float &scale, float &shift) {
-     // Read into temporaries so the caller's references are assigned only
-     // after the status has been passed to error::wrap_c_api.
-     float c_scale, c_shift;
-     // The failure message says "get": this is the getter (it previously
-     // reused the setter's "could not set" wording).
-     error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams(
-                               get(), &c_scale, &c_shift),
-             "could not get RNN data quantization parameters primitive "
-             "attribute");
-     scale = c_scale;
-     shift = c_shift;
- }
- /// Sets quantization scaling factors for RNN weights tensors. The
- /// low-precision configuration of the RNN primitives expect input weights
- /// to use the signed 8-bit integer data type. The scaling factors are
- /// used to quantize floating-point data to signed integer and must be
- /// passed to RNN primitives using attributes.
- ///
- /// @note
- /// The dimension order is always native and does not depend on the
- /// actual layout used. For example, five-dimensional weights always
- /// have (l, d, i, g, o) logical dimension ordering.
- ///
- /// @note
- /// Quantization scales are common for weights_layer and
- /// weights_iteration
- ///
- /// @param mask Scaling factors correspondence mask that defines the
- /// correspondence between the output tensor dimensions and the @p
- /// scales vector. The set i-th bit indicates that a dedicated scaling
- /// factor should be used for each index along that dimension. Set the
- /// mask to 0 to use a common scaling factor for the whole output
- /// tensor.
- /// @param scales Constant vector of output scaling factors. The following
- /// equality must hold:
- /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
- /// Violations can only be detected when the attributes are used to
- /// create a primitive descriptor.
- void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
-     // The C API receives the scales as a (count, mask, pointer) triple.
-     const int n_scales = static_cast<int>(scales.size());
-     const auto status = dnnl_primitive_attr_set_rnn_weights_qparams(
-             get(), n_scales, mask, scales.data());
-     error::wrap_c_api(status,
-             "could not set RNN weights quantization parameters primitive "
-             "attribute");
- }
- /// Returns the quantization scaling factors for RNN weights tensors.
- ///
- /// @note
- /// The dimension order is always native and does not depend on the
- /// actual layout used. For example, five-dimensional weights always
- /// have (l, d, i, g, o) logical dimension ordering.
- ///
- /// @param mask Scaling factors correspondence mask that defines the
- /// correspondence between the output tensor dimensions and the @p
- /// scales vector. The set i-th bit indicates that a dedicated scaling
- /// factor should be used for each index along that dimension. Set the
- /// mask to 0 to use a common scaling factor for the whole output
- /// tensor.
- /// @param scales Constant vector of output scaling factors. The following
- /// equality must hold:
- /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
- /// Violations can only be detected when the attributes are used to
- /// create a primitive descriptor.
- void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
-     dnnl_dim_t n_scales;
-     int raw_mask;
-     const float *raw_scales;
-     const auto status = dnnl_primitive_attr_get_rnn_weights_qparams(
-             get(), &n_scales, &raw_mask, &raw_scales);
-     error::wrap_c_api(status,
-             "could not get primitive RNN weights quantization "
-             "parameters attributes");
-     // Size the output to the reported count, then copy element-wise from
-     // the array returned by the C API.
-     scales.resize(n_scales);
-     mask = raw_mask;
-     dnnl_dim_t i = 0;
-     while (i < n_scales) {
-         scales[i] = raw_scales[i];
-         ++i;
-     }
- }
- /// Sets quantization scaling factors for RNN projection weights tensors.
- /// The low-precision configuration of the RNN primitives expect input
- /// weights to use the signed 8-bit integer data type. The scaling factors
- /// are used to quantize floating-point data to signed integer and must be
- /// passed to RNN primitives using attributes.
- ///
- /// @note
- /// The dimension order is always native and does not depend on the
- /// actual layout used. For example, five-dimensional weights always
- /// have (l, d, i, g, o) logical dimension ordering.
- ///
- /// @note
- /// Quantization scales are common for weights_layer and
- /// weights_iteration
- ///
- /// @param mask Scaling factors correspondence mask that defines the
- /// correspondence between the output tensor dimensions and the @p
- /// scales vector. The set i-th bit indicates that a dedicated scaling
- /// factor should be used for each index along that dimension. Set the
- /// mask to 0 to use a common scaling factor for the whole output
- /// tensor.
- /// @param scales Constant vector of output scaling factors. The following
- /// equality must hold:
- /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
- /// Violations can only be detected when the attributes are used to
- /// create a primitive descriptor.
- void set_rnn_weights_projection_qparams(
-         int mask, const std::vector<float> &scales) {
-     // The C API receives the scales as a (count, mask, pointer) triple.
-     const int n_scales = static_cast<int>(scales.size());
-     const auto status
-             = dnnl_primitive_attr_set_rnn_weights_projection_qparams(
-                     get(), n_scales, mask, scales.data());
-     error::wrap_c_api(status,
-             "could not set primitive RNN weights projection quantization "
-             "parameters attributes");
- }
- /// Returns the quantization scaling factors for RNN projection weights
- /// tensors.
- ///
- /// @note
- /// The dimension order is always native and does not depend on the
- /// actual layout used. For example, five-dimensional weights always
- /// have (l, d, i, g, o) logical dimension ordering.
- ///
- /// @param mask Scaling factors correspondence mask that defines the
- /// correspondence between the output tensor dimensions and the @p
- /// scales vector. The set i-th bit indicates that a dedicated scaling
- /// factor should be used for each index along that dimension. Set the
- /// mask to 0 to use a common scaling factor for the whole output
- /// tensor.
- /// @param scales Constant vector of output scaling factors. The following
- /// equality must hold:
- /// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
- /// Violations can only be detected when the attributes are used to
- /// create a primitive descriptor.
- void get_rnn_weights_projection_qparams(
- int &mask, std::vector<float> &scales) {
- dnnl_dim_t count;
- int c_mask;
- const float *c_scales;
- error::wrap_c_api(
- dnnl_primitive_attr_get_rnn_weights_projection_qparams(
- get(), &count, &c_mask, &c_scales),
- "could not get primitive RNN weights projection quantization "
- "parameters attributes");
- scales.resize(count);
- mask = c_mask;
- for (dnnl_dim_t c = 0; c < count; c++)
- scales[c] = c_scales[c];
- }
- };
- /// @} dnnl_api_attributes
- /// @addtogroup dnnl_api_primitives_common
- /// @{
- /// Base class for all primitive descriptors.
- struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
-     using handle<dnnl_primitive_desc_t>::handle;
-     /// Default constructor. Produces an empty object.
-     primitive_desc_base() = default;
-     /// Returns the engine of the primitive descriptor.
-     /// @returns The engine of the primitive descriptor.
-     engine get_engine() const { return query_engine(query::engine); }
-     /// Returns implementation name.
-     /// @returns The implementation name.
-     const char *impl_info_str() const {
-         const char *res = nullptr;
-         error::wrap_c_api(dnnl_primitive_desc_query(
-                                   get(), dnnl_query_impl_info_str, 0, &res),
-                 "could not retrieve implementation info string from a "
-                 "primitive descriptor");
-         return res;
-     }
-     /// Returns a memory::dim value (same as int64_t).
-     /// @param what The value to query.
-     /// @returns The result of the query.
-     /// @returns Zero if the query does not succeed.
-     memory::dim query_s64(query what) const {
-         memory::dim res = 0;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl::convert_to_c(what), 0, &res);
-         return status == dnnl_success ? res : 0;
-     }
-     /// Returns strides.
-     /// @returns Strides.
-     /// @returns An empty #dnnl::memory::dims if the primitive does not have
-     ///     a strides parameter.
-     memory::dims get_strides() const { return query_dims(query::strides); }
-     /// Returns dilations.
-     /// @returns Dilations.
-     /// @returns An empty #dnnl::memory::dims if the primitive does not have
-     ///     a dilations parameter.
-     memory::dims get_dilations() const { return query_dims(query::dilations); }
-     /// Returns a left padding.
-     /// @returns A left padding.
-     /// @returns An empty #dnnl::memory::dims if the primitive does not have
-     ///     a left padding parameter.
-     memory::dims get_padding_l() const { return query_dims(query::padding_l); }
-     /// Returns a right padding.
-     /// @returns A right padding.
-     /// @returns An empty #dnnl::memory::dims if the primitive does not have
-     ///     a right padding parameter.
-     memory::dims get_padding_r() const { return query_dims(query::padding_r); }
-     /// Returns an epsilon.
-     /// @returns An epsilon.
-     /// @returns Zero if the primitive does not have an epsilon parameter.
-     float get_epsilon() const { return query_f32(query::epsilon_f32); }
-     /// Returns flags.
-     /// @tparam T Flags enumeration type.
-     /// @returns Flags.
-     /// @returns Zero if the primitive does not have a flags parameter.
-     template <typename T = unsigned>
-     T get_flags() const {
-         unsigned res = 0;
-         dnnl_status_t status
-                 = dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res);
-         return static_cast<T>(status == dnnl_success ? res : 0x0U);
-     }
-     /// Returns an algorithm kind.
-     /// @returns An algorithm kind.
-     /// @returns #dnnl::algorithm::undef if the primitive does not have an
-     ///     algorithm parameter.
-     dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); }
-     /// Returns an alpha.
-     /// @returns An alpha.
-     /// @returns Zero if the primitive does not have an alpha parameter.
-     float get_alpha() const { return query_f32(query::alpha_f32); }
-     /// Returns a beta.
-     /// @returns A beta.
-     /// @returns Zero if the primitive does not have a beta parameter.
-     float get_beta() const { return query_f32(query::beta_f32); }
-     /// Returns an axis.
-     /// @returns An axis.
-     /// @returns A negative number if the primitive does not have an axis
-     ///     parameter.
-     int get_axis() const {
-         int res = -1;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl_query_axis_s32, 0, &res);
-         return status == dnnl_success ? res : -1;
-     }
-     /// Returns an LRN local size parameter.
-     /// @returns An LRN local size parameter.
-     /// @returns Zero if the primitive does not have an LRN local size
-     ///     parameter.
-     memory::dim get_local_size() const {
-         return query_s64(query::local_size_s64);
-     }
-     /// Returns an LRN K parameter.
-     /// @returns An LRN K parameter.
-     /// @returns Zero if the primitive does not have an LRN K parameter.
-     float get_k() const { return query_f32(query::k_f32); }
-     /// Returns a reduction P parameter.
-     /// @returns A reduction P parameter.
-     /// @returns Zero if the primitive does not have a reduction P parameter.
-     float get_p() const { return query_f32(query::p_f32); }
-     /// Returns the resampling factors parameter.
-     /// @returns A vector of factors.
-     /// @returns An empty vector if the primitive does not have a resampling
-     ///     factors parameter.
-     std::vector<float> get_factors() const {
-         float *factors = nullptr;
-         const dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl_query_factors, 0, &factors);
-         // No factors parameter: produce the documented empty result right
-         // away instead of running the (throwing) ndims query below.
-         if (status != dnnl_success) return {};
-         const dnnl::prop_kind pkind = get_prop_kind();
-         const bool is_backward = pkind != prop_kind::forward_training
-                 && pkind != prop_kind::forward_inference;
-         const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
-                 is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
-         int ndims = 0;
-         error::wrap_c_api(
-                 dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
-                 "could not query ndims from a memory descriptor");
-         // The number of factors is ndims - 2 (same count that query_dims()
-         // uses for spatial dimensions); never form a negative range.
-         const int nfactors = ndims > 2 ? ndims - 2 : 0;
-         return std::vector<float>(factors, factors + nfactors);
-     }
-     /// Returns an RNN cell kind parameter.
-     /// @returns An RNN cell kind parameter.
-     /// @returns #dnnl::algorithm::undef if the primitive does not have an
-     ///     RNN cell kind parameter.
-     dnnl::algorithm get_cell_kind() const {
-         return query_alg(query::cell_kind);
-     }
-     /// Returns an RNN direction parameter.
-     /// @returns An RNN direction parameter.
-     /// @returns #dnnl::rnn_direction::undef if the primitive does not have
-     ///     an RNN direction parameter.
-     dnnl::rnn_direction get_direction() const {
-         dnnl_rnn_direction_t direction;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl_query_direction, 0, &direction);
-         return status == dnnl_success
-                 ? static_cast<dnnl::rnn_direction>(direction)
-                 : dnnl::rnn_direction::undef;
-     }
-     /// Returns an RNN activation kind parameter.
-     /// @returns An RNN activation kind parameter.
-     /// @returns #dnnl::algorithm::undef if the primitive does not have an
-     ///     RNN activation kind parameter.
-     dnnl::algorithm get_activation_kind() const {
-         return query_alg(query::activation_kind);
-     }
-     /// Returns a pooling kernel parameter.
-     /// @returns A pooling kernel parameter.
-     /// @returns An empty #dnnl::memory::dims if the primitive does not have
-     ///     a pooling kernel parameter.
-     memory::dims get_kernel() const { return query_dims(query::kernel); }
-     /// Returns a group size parameter.
-     /// @returns A group size parameter.
-     /// @returns Zero if the primitive does not have a group size
-     ///     parameter.
-     memory::dim get_group_size() const {
-         return query_s64(query::group_size_s64);
-     }
-     /// Returns a propagation kind.
-     /// @returns A propagation kind.
-     /// @returns #dnnl::prop_kind::undef if the primitive does not have
-     ///     a propagation parameter.
-     dnnl::prop_kind get_prop_kind() const {
-         dnnl_prop_kind_t prop_kind;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl_query_prop_kind, 0, &prop_kind);
-         return status == dnnl_success ? static_cast<dnnl::prop_kind>(prop_kind)
-                                       : dnnl::prop_kind::undef;
-     }
-     /// Returns a memory descriptor.
-     ///
-     /// @note
-     ///     There are also convenience methods
-     ///     #dnnl::primitive_desc_base::src_desc(),
-     ///     #dnnl::primitive_desc_base::dst_desc(), and others.
-     ///
-     /// @param what The kind of parameter to query; can be
-     ///     #dnnl::query::src_md, #dnnl::query::dst_md, etc.
-     /// @param idx Index of the parameter. For example, convolution bias can
-     ///     be queried with what = #dnnl::query::weights_md and idx = 1.
-     /// @returns The requested memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     parameter of the specified kind or index.
-     memory::desc query_md(query what, int idx = 0) const {
-         // Only memory-descriptor queries are accepted. The check is a plain
-         // comparison chain so that no container is allocated per call.
-         const bool is_md_query = what == query::src_md
-                 || what == query::diff_src_md || what == query::weights_md
-                 || what == query::diff_weights_md || what == query::dst_md
-                 || what == query::diff_dst_md || what == query::workspace_md
-                 || what == query::scratchpad_md || what == query::exec_arg_md;
-         if (!is_md_query)
-             DNNL_THROW_ERROR(dnnl_invalid_arguments,
-                     "memory descriptor query is invalid");
-         const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
-                 get(), dnnl::convert_to_c(what), idx);
-         if (!cdesc) return memory::desc();
-         // The returned object wraps a clone, not the queried descriptor
-         // itself.
-         dnnl_memory_desc_t cloned_md = nullptr;
-         error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
-                 "could not clone a memory descriptor");
-         return memory::desc(cloned_md);
-     }
-     /// Returns a source memory descriptor.
-     /// @param idx Source index.
-     /// @returns Source memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     source parameter with index @p idx.
-     memory::desc src_desc(int idx) const {
-         return query_md(query::src_md, idx);
-     }
-     /// Returns a destination memory descriptor.
-     /// @param idx Destination index.
-     /// @returns Destination memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     destination parameter with index @p idx.
-     memory::desc dst_desc(int idx) const {
-         return query_md(query::dst_md, idx);
-     }
-     /// Returns a weights memory descriptor.
-     /// @param idx Weights index.
-     /// @returns Weights memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     weights parameter with index @p idx.
-     memory::desc weights_desc(int idx) const {
-         return query_md(query::weights_md, idx);
-     }
-     /// Returns a diff source memory descriptor.
-     /// @param idx Diff source index.
-     /// @returns Diff source memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     diff source parameter with index @p idx.
-     memory::desc diff_src_desc(int idx) const {
-         return query_md(query::diff_src_md, idx);
-     }
-     /// Returns a diff destination memory descriptor.
-     /// @param idx Diff destination index.
-     /// @returns Diff destination memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     diff destination parameter with index @p idx.
-     memory::desc diff_dst_desc(int idx) const {
-         return query_md(query::diff_dst_md, idx);
-     }
-     /// Returns a diff weights memory descriptor.
-     /// @param idx Diff weights index.
-     /// @returns Diff weights memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     diff weights parameter with index @p idx.
-     memory::desc diff_weights_desc(int idx) const {
-         return query_md(query::diff_weights_md, idx);
-     }
-     // Separate versions without the index argument for documentation
-     // purposes.
-     /// Returns a source memory descriptor.
-     /// @returns Source memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     source parameter.
-     memory::desc src_desc() const { return src_desc(0); }
-     /// Returns a destination memory descriptor.
-     /// @returns Destination memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     destination parameter.
-     memory::desc dst_desc() const { return dst_desc(0); }
-     /// Returns a weights memory descriptor.
-     /// @returns Weights memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     weights parameter.
-     memory::desc weights_desc() const { return weights_desc(0); }
-     /// Returns a diff source memory descriptor.
-     /// @returns Diff source memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     diff source parameter.
-     memory::desc diff_src_desc() const { return diff_src_desc(0); }
-     /// Returns a diff destination memory descriptor.
-     /// @returns Diff destination memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     diff destination parameter.
-     memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
-     /// Returns a diff weights memory descriptor.
-     /// @returns Diff weights memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not have a
-     ///     diff weights parameter.
-     memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
-     /// Returns the workspace memory descriptor.
-     /// @returns Workspace memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not require
-     ///     workspace parameter.
-     memory::desc workspace_desc() const {
-         return query_md(query::workspace_md, 0);
-     }
-     /// Returns the scratchpad memory descriptor.
-     /// @returns Scratchpad memory descriptor.
-     /// @returns A zero memory descriptor if the primitive does not require
-     ///     scratchpad parameter.
-     /// @sa @ref dev_guide_attributes_scratchpad
-     memory::desc scratchpad_desc() const {
-         return query_md(query::scratchpad_md, 0);
-     }
-     /// Returns the engine on which the scratchpad memory is located.
-     /// @returns The engine on which the scratchpad memory is located.
-     engine scratchpad_engine() const {
-         dnnl_engine_t c_engine;
-         error::wrap_c_api(dnnl_primitive_desc_query(get(),
-                                   dnnl::convert_to_c(query::scratchpad_engine),
-                                   0, &c_engine),
-                 "could not retrieve scratchpad engine from a primitive "
-                 "descriptor");
-         return engine(c_engine, true);
-     }
-     /// Returns the primitive attributes.
-     /// @returns A clone of the attributes held by the primitive descriptor.
-     primitive_attr get_primitive_attr() const {
-         const_dnnl_primitive_attr_t const_c_attr;
-         error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
-                 "could not get attributes from a primitive descriptor");
-         dnnl_primitive_attr_t c_attr;
-         error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
-                 "could not clone primitive attributes");
-         return primitive_attr(c_attr);
-     }
-     /// Returns the kind of the primitive descriptor.
-     /// @returns The kind of the primitive descriptor.
-     dnnl::primitive::kind get_kind() const {
-         dnnl_primitive_kind_t kind;
-         error::wrap_c_api(dnnl_primitive_desc_query(get(),
-                                   dnnl_query_primitive_kind, 0, &kind),
-                 "could not get primitive kind from a primitive descriptor");
-         return static_cast<dnnl::primitive::kind>(kind);
-     }
-     /// Returns the cache blob ID of the primitive descriptor.
-     /// @returns The cache blob ID of the primitive descriptor.
-     std::vector<uint8_t> get_cache_blob_id() const {
-         dnnl_dim_t count = 0;
-         const uint8_t *c_id = nullptr;
-         // Two queries: first the size of the ID, then a pointer to its bytes.
-         error::wrap_c_api(
-                 dnnl_primitive_desc_query(get(),
-                         dnnl::convert_to_c(query::cache_blob_id_size_s64), 0,
-                         &count),
-                 "could not get size of cache blob ID from a primitive "
-                 "descriptor");
-         error::wrap_c_api(dnnl_primitive_desc_query(get(),
-                                   dnnl::convert_to_c(query::cache_blob_id), 0,
-                                   &c_id),
-                 "could not get cache blob ID from a primitive descriptor");
-         std::vector<uint8_t> id(c_id, c_id + count);
-         return id;
-     }
- protected:
-     /// Returns a float value.
-     /// @param what The value to query.
-     /// @returns The result of the query.
-     /// @returns Zero if the primitive doesn't support the query.
-     float query_f32(query what) const {
-         float res = 0.0f;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl::convert_to_c(what), 0, &res);
-         return status == dnnl_success ? res : 0.0f;
-     }
-     /// Returns an #dnnl::algorithm value.
-     /// @param what The value to query.
-     /// @returns The result of the query.
-     /// @returns #dnnl::algorithm::undef if the primitive doesn't support
-     ///     the query.
-     algorithm query_alg(query what) const {
-         dnnl_alg_kind_t res;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl::convert_to_c(what), 0, &res);
-         return status == dnnl_success ? static_cast<dnnl::algorithm>(res)
-                                       : algorithm::undef;
-     }
-     /// Returns a memory::dims value.
-     /// @param what The value to query.
-     /// @returns The result of the query.
-     /// @returns An empty #dnnl::memory::dims if the primitive doesn't support
-     ///     the query.
-     memory::dims query_dims(query what) const {
-         const dnnl::prop_kind pkind = get_prop_kind();
-         const bool is_backward = pkind != prop_kind::forward_training
-                 && pkind != prop_kind::forward_inference;
-         const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
-                 is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
-         // The result length is the number of spatial dimensions, taken as
-         // ndims - 2 of the (diff) destination descriptor and never negative.
-         int nspatial_dims = 0;
-         if (md) {
-             int ndims = 0;
-             error::wrap_c_api(
-                     dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
-                     "could not query ndims from a memory descriptor");
-             nspatial_dims = ndims > 2 ? ndims - 2 : 0;
-         }
-         dnnl_dims_t *c_dims;
-         dnnl_status_t status = dnnl_primitive_desc_query(
-                 get(), dnnl::convert_to_c(what), 0, &c_dims);
-         return status == dnnl_success
-                 ? memory::dims(*c_dims, *c_dims + nspatial_dims)
-                 : memory::dims {};
-     }
-     /// Returns an #dnnl::engine value.
-     /// @param what The value to query.
-     /// @returns The result of the query.
-     /// @returns A weak handle to the engine that the primitive descriptor was
-     ///     created with.
-     engine query_engine(query what) const {
-         dnnl_engine_t c_engine;
-         error::wrap_c_api(dnnl_primitive_desc_query(get(),
-                                   dnnl::convert_to_c(what), 0, &c_engine),
-                 "could not get an engine from a primitive_desc");
-         return engine(c_engine, true);
-     }
-     /// Resets the value of the handle to a clone of a C API primitive
-     /// descriptor.
-     /// @param pd A C API primitive descriptor to clone.
-     void reset_with_clone(const_dnnl_primitive_desc_t pd) {
-         dnnl_primitive_desc_t new_pd;
-         error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd),
-                 "could not clone a primitive descriptor");
-         reset(new_pd);
-     }
-     /// Constructs a primitive descriptor base object from a clone of a C API
-     /// primitive descriptor after verifying that it is what the caller
-     /// expects.
-     ///
-     /// @note
-     ///     The @p prim_kind should map to a primitive that does not have
-     ///     different values of propagation kind (e.g. #dnnl::binary).
-     /// @note
-     ///     Primitive descriptor base constructed this way does not support
-     ///     next_impl() (will throw).
-     ///
-     /// @param pd C API primitive descriptor to clone.
-     /// @param prim_kind Expected primitive kind.
-     primitive_desc_base(
-             dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
-         : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
-     /// Constructs a primitive descriptor base object from a clone of a C API
-     /// primitive descriptor after verifying that it is what the caller
-     /// expects.
-     ///
-     /// @note
-     ///     Primitive descriptor base constructed this way does not support
-     ///     next_impl() (will throw).
-     ///
-     /// @param pd C API primitive descriptor to clone.
-     /// @param prim_kind Expected primitive kind.
-     /// @param aprop_kind Expected propagation kind.
-     primitive_desc_base(dnnl_primitive_desc_t pd,
-             dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
-         : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
-     /// Constructs a primitive descriptor base object from a clone of a C API
-     /// primitive descriptor after verifying that it is what the caller
-     /// expects.
-     ///
-     /// @note
-     ///     Primitive descriptor base constructed this way does not support
-     ///     next_impl() (will throw).
-     ///
-     /// @param pd C API primitive descriptor to clone.
-     /// @param prim_kind Expected primitive kind.
-     /// @param prop_kind1 Expected propagation kind (option 1).
-     /// @param prop_kind2 Expected propagation kind (option 2). This value is
-     ///     checked if the check with @p prop_kind1 fails.
-     primitive_desc_base(dnnl_primitive_desc_t pd,
-             dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
-             dnnl::prop_kind prop_kind2) {
-         // It is OK to pass an empty primitive descriptor
-         if (pd == nullptr) return;
-         dnnl_status_t rc;
-         dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
-         dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
-         dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
-         // Check that primitive kind matches
-         dnnl_primitive_kind_t pd_kind;
-         rc = dnnl_primitive_desc_query(
-                 pd, dnnl_query_primitive_kind, 0, &pd_kind);
-         error::wrap_c_api(
-                 rc, "could not get primitive kind from a primitive descriptor");
-         if (pd_kind != c_prim_kind)
-             DNNL_THROW_ERROR(dnnl_invalid_arguments,
-                     "primitive descriptor operation kind mismatch");
-         // Check that propagation kind matches
-         dnnl_prop_kind_t pd_prop_kind;
-         rc = dnnl_primitive_desc_query(
-                 pd, dnnl_query_prop_kind, 0, &pd_prop_kind);
-         // Something went wrong: any status other than success or
-         // "unimplemented" is treated as an error.
-         if (rc != dnnl_success && rc != dnnl_unimplemented)
-             DNNL_THROW_ERROR(dnnl_invalid_arguments,
-                     "could not get propagation kind from the primitive "
-                     "descriptor");
-         // Everything is fine: either the query is unimplemented and the
-         // caller expected an undefined propagation kind, or the reported
-         // kind equals one of the two expected values.
-         if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
-                 || (rc == dnnl_success
-                         && (pd_prop_kind == c_prop_kind1
-                                 || pd_prop_kind == c_prop_kind2))) {
-             reset_with_clone(pd);
-             return;
-         }
-         // We could get the propagation kind but there is a mismatch
-         DNNL_THROW_ERROR(dnnl_invalid_arguments,
-                 "primitive descriptor propagation kind mismatch");
-     }
-     /// Returns a constant reference to a static instance of default constructed
-     /// primitive attributes
-     static const primitive_attr &default_attr() {
-         static const primitive_attr attr;
-         return attr;
-     }
-     /// Maps an optional memory descriptor to its C handle.
-     /// @param md Pointer to a memory descriptor, or nullptr.
-     /// @returns The C handle of @p md, or nullptr if @p md is null.
-     const_dnnl_memory_desc_t optional_arg(const memory::desc *md) {
-         return md ? md->get() : nullptr;
-     }
-     /// Maps an optional dims vector to a pointer to its elements.
-     /// @param dims Pointer to a dims vector, or nullptr.
-     /// @returns The data pointer of @p dims, or nullptr if @p dims is null.
-     const dnnl_dim_t *optional_arg(const memory::dims *dims) {
-         return dims ? dims->data() : nullptr;
-     }
-     /// Maps an optional float vector to a pointer to its elements.
-     /// @param arg Pointer to a float vector, or nullptr.
-     /// @returns The data pointer of @p arg, or nullptr if @p arg is null.
-     const float *optional_arg(const std::vector<float> *arg) {
-         return arg ? arg->data() : nullptr;
-     }
-     using base = primitive_desc_base;
- };
- /// @} dnnl_api_primitives_common
- /// @addtogroup dnnl_api_reorder Reorder
- ///
- /// A primitive to copy data between two memory objects. This primitive is
- /// typically used to change the way the data is laid out in memory.
- ///
- /// @sa @ref dev_guide_reorder in developer guide
- ///
- /// @{
- /// Reorder primitive.
- struct reorder : public primitive {
-     /// Primitive descriptor for a reorder primitive.
-     struct primitive_desc : public primitive_desc_base {
-         using primitive_desc_base::primitive_desc_base;
-         /// Default constructor. Produces an empty object.
-         primitive_desc() = default;
-         /// Creates a reorder primitive descriptor from explicit engines and
-         /// memory descriptors.
-         ///
-         /// @note
-         ///     With @p allow_empty set to true the constructor does not
-         ///     throw when the primitive descriptor cannot be created.
-         ///
-         /// @param src_engine Engine on which the source memory object will
-         ///     be located.
-         /// @param src_md Source memory descriptor.
-         /// @param dst_engine Engine on which the destination memory object
-         ///     will be located.
-         /// @param dst_md Destination memory descriptor.
-         /// @param attr Primitive attributes to use. Optional; defaults to
-         ///     empty attributes.
-         /// @param allow_empty Whether construction may fail without
-         ///     throwing an exception, in which case an empty object is
-         ///     produced. Optional; defaults to false.
-         primitive_desc(const engine &src_engine, const memory::desc &src_md,
-                 const engine &dst_engine, const memory::desc &dst_md,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd;
-             const dnnl_status_t rc = dnnl_reorder_primitive_desc_create(&c_pd,
-                     src_md.get(), src_engine.get(), dst_md.get(),
-                     dst_engine.get(), attr.get());
-             // Failures are reported only when the caller did not opt into
-             // receiving an empty descriptor.
-             if (!allow_empty)
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the reorder primitive. Run workload with "
-                         "environment variable ONEDNN_VERBOSE=all to get "
-                         "additional diagnostic information.");
-             if (rc == dnnl_success)
-                 reset(c_pd);
-             else
-                 reset(dnnl_primitive_desc_t());
-         }
-         /// Creates a reorder primitive descriptor from two memory objects.
-         ///
-         /// @param src Source memory object; supplies the source memory
-         ///     descriptor and engine.
-         /// @param dst Destination memory object; supplies the destination
-         ///     memory descriptor and engine.
-         /// @param attr Primitive attributes to use. Optional; defaults to
-         ///     empty attributes.
-         /// @param allow_empty Whether construction may fail without
-         ///     throwing an exception, in which case an empty object is
-         ///     produced. Optional; defaults to false.
-         primitive_desc(const memory &src, const memory &dst,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd;
-             // Keep the descriptors in named locals so their handles stay
-             // valid for the duration of the C API call.
-             const auto from_md = src.get_desc();
-             const auto to_md = dst.get_desc();
-             const dnnl_status_t rc = dnnl_reorder_primitive_desc_create(&c_pd,
-                     from_md.get(), src.get_engine().get(), to_md.get(),
-                     dst.get_engine().get(), attr.get());
-             if (!allow_empty)
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the reorder primitive. Run workload with "
-                         "environment variable ONEDNN_VERBOSE=all to get "
-                         "additional diagnostic information.");
-             if (rc == dnnl_success)
-                 reset(c_pd);
-             else
-                 reset(dnnl_primitive_desc_t());
-         }
-         /// Creates a reorder primitive descriptor from a C API primitive
-         /// descriptor, which must be of the reorder kind.
-         ///
-         /// @param pd C API primitive descriptor for reorder primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
-         /// Returns the engine on which the source memory is allocated.
-         /// @returns The engine on which the source memory is allocated.
-         engine get_src_engine() const {
-             return query_engine(dnnl::query::reorder_src_engine);
-         }
-         /// Returns the engine on which the destination memory is allocated.
-         /// @returns The engine on which the destination memory is allocated.
-         engine get_dst_engine() const {
-             return query_engine(dnnl::query::reorder_dst_engine);
-         }
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const { return base::src_desc(0); }
-         /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-         memory::desc dst_desc() const { return base::dst_desc(0); }
-     };
-     /// Default constructor. Produces an empty object.
-     reorder() = default;
-     /// Creates a reorder primitive from its primitive descriptor.
-     /// @param pd Primitive descriptor for reorder primitive.
-     reorder(const primitive_desc &pd) : primitive(pd.get()) {}
-     /// Creates a reorder primitive from its primitive descriptor and a
-     /// cache blob.
-     /// @param pd Primitive descriptor for reorder primitive.
-     /// @param cache_blob Cache blob.
-     reorder(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
-         : primitive(pd.get(), cache_blob) {}
-     /// Creates a reorder primitive for data moving between memory objects
-     /// that have the same memory descriptors as @p src and @p dst. A
-     /// temporary primitive descriptor is built from the two memory objects.
-     ///
-     /// @param src Source memory object.
-     /// @param dst Destination memory object.
-     /// @param attr Primitive attributes to use (optional).
-     reorder(const memory &src, const memory &dst,
-             const primitive_attr &attr = primitive_attr())
-         : primitive(primitive_desc(src, dst, attr).get()) {}
-     using primitive::execute;
-     /// Runs the reorder, passing @p src as the DNNL_ARG_FROM argument and
-     /// @p dst as the DNNL_ARG_TO argument.
-     ///
-     /// @param astream Stream object. The stream must belong to the same
-     ///     engine as the primitive.
-     /// @param src Source memory object.
-     /// @param dst Destination memory object.
-     void execute(const stream &astream, memory &src, memory &dst) const {
-         primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
-     }
- };
- /// @} dnnl_api_reorder
- /// @addtogroup dnnl_api_concat Concat
- ///
- /// A primitive to concatenate data by arbitrary dimension.
- ///
- /// @sa @ref dev_guide_concat in developer guide
- ///
- /// @{
- /// @cond DO_NOT_DOCUMENT_THIS
- inline std::vector<const_dnnl_memory_desc_t> convert_to_c(
-         const std::vector<memory::desc> &mds) {
-     // Produce one C handle per C++ memory descriptor, in the same order.
-     std::vector<const_dnnl_memory_desc_t> handles(mds.size());
-     std::transform(mds.begin(), mds.end(), handles.begin(),
-             [](const memory::desc &md) { return md.get(); });
-     return handles;
- }
- /// @endcond
- /// Tensor concatenation (concat) primitive.
- struct concat : public primitive {
- /// Primitive descriptor for a concat primitive.
- struct primitive_desc : public primitive_desc_base {
- using primitive_desc_base::primitive_desc_base;
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an out-of-place concatenation
- /// primitive.
- ///
- /// @param aengine Engine to perform the operation on.
- /// @param dst Destination memory descriptor.
- /// @param concat_dimension Source tensors will be concatenated over
- /// dimension with this index. Note that order of dimensions does
- /// not depend on memory format.
- /// @param srcs Vector of source memory descriptors.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, const memory::desc &dst,
- int concat_dimension, const std::vector<memory::desc> &srcs,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- auto c_srcs = convert_to_c(srcs);
- dnnl_primitive_desc_t result;
- dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
- aengine.get(), dst.get(), (int)c_srcs.size(),
- concat_dimension, c_srcs.data(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the concat primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
- }
- /// Constructs a primitive descriptor for an out-of-place concatenation
- /// primitive.
- ///
- /// This version derives the destination memory descriptor
- /// automatically.
- ///
- /// @param aengine Engine to perform the operation on.
- /// @param concat_dimension Source tensors will be concatenated over
- /// dimension with this index. Note that order of dimensions does
- /// not depend on memory format.
- /// @param srcs Vector of source memory descriptors.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, int concat_dimension,
- const std::vector<memory::desc> &srcs,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- auto c_api_srcs = convert_to_c(srcs);
- dnnl_primitive_desc_t result;
- dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
- aengine.get(), nullptr, (int)c_api_srcs.size(),
- concat_dimension, c_api_srcs.data(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the concat primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
- }
- /// Wraps an existing C API primitive descriptor as a concat primitive
- /// descriptor; @p pd must be of the matching (concat) primitive kind.
- ///
- /// @param pd C API primitive descriptor for concat primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
- memory::desc src_desc(int idx = 0) const {
- return base::src_desc(idx);
- }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const {
- return base::dst_desc(0);
- }
- };
- /// Default constructor. Produces an empty concatenation primitive object.
- concat() = default;
- /// Constructs a concatenation primitive from the handle held by @p pd.
- /// @param pd Primitive descriptor for concatenation primitive.
- concat(const primitive_desc &pd) : primitive(pd.get()) {}
- /// Constructs a concatenation primitive from @p pd and a cache blob.
- /// @param pd Primitive descriptor for concatenation primitive.
- /// @param cache_blob Cache blob; forwarded unchanged to the base class.
- concat(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd.get(), cache_blob) {}
- };
- /// @} dnnl_api_concat
- /// @addtogroup dnnl_api_sum Sum
- ///
- /// A primitive to sum multiple tensors.
- ///
- /// @sa @ref dev_guide_sum in developer guide
- ///
- /// @{
- /// Sum primitive: adds several source tensors, out of place.
- struct sum : public primitive {
- /// Descriptor of a sum primitive.
- struct primitive_desc : public primitive_desc_base {
- using primitive_desc_base::primitive_desc_base;
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for a sum primitive with an
- /// explicitly given destination memory descriptor.
- ///
- /// @param aengine Engine on which the operation is performed.
- /// @param dst Destination memory descriptor.
- /// @param scales Per-source scales; the data of each source memory
- /// is multiplied by the corresponding entry. The number of
- /// scales must equal the number of sources.
- /// @param srcs Vector of source memory descriptors.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, const memory::desc &dst,
- const std::vector<float> &scales,
- const std::vector<memory::desc> &srcs,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- // Lower and upper size bounds coincide: exactly one scale per source.
- const int n_srcs = static_cast<int>(srcs.size());
- validate_container_size(scales,
- "counts of scales and sources are not equal", n_srcs, n_srcs);
- auto c_src_descs = convert_to_c(srcs);
- dnnl_primitive_desc_t c_pd;
- const dnnl_status_t rc = dnnl_sum_primitive_desc_create(&c_pd,
- aengine.get(), dst.get(),
- static_cast<int>(c_src_descs.size()), scales.data(),
- c_src_descs.data(), attr.get());
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a primitive descriptor for "
- "the sum primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- }
- // Adopt the C handle only on success; otherwise stay empty.
- if (rc == dnnl_success)
- reset(c_pd);
- else
- reset(dnnl_primitive_desc_t());
- }
- /// Creates a primitive descriptor for a sum primitive.
- ///
- /// The destination memory descriptor is not supplied by the caller;
- /// it is derived automatically.
- ///
- /// @param aengine Engine on which the operation is performed.
- /// @param scales Per-source scales; the data of each source memory
- /// object is multiplied by the corresponding entry. The number
- /// of scales must equal the number of sources.
- /// @param srcs Vector of source memory descriptors.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, const std::vector<float> &scales,
- const std::vector<memory::desc> &srcs,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- // Lower and upper size bounds coincide: exactly one scale per source.
- const int n_srcs = static_cast<int>(srcs.size());
- validate_container_size(scales,
- "counts of scales and sources are not equal", n_srcs, n_srcs);
- auto c_src_descs = convert_to_c(srcs);
- dnnl_primitive_desc_t c_pd;
- const dnnl_status_t rc = dnnl_sum_primitive_desc_create(&c_pd,
- aengine.get(), nullptr,
- static_cast<int>(c_src_descs.size()), scales.data(),
- c_src_descs.data(), attr.get());
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a primitive descriptor for "
- "the sum primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- }
- // Adopt the C handle only on success; otherwise stay empty.
- if (rc == dnnl_success)
- reset(c_pd);
- else
- reset(dnnl_primitive_desc_t());
- }
- /// Wraps an existing C API primitive descriptor as a sum primitive
- /// descriptor; @p pd must be of the matching (sum) primitive kind.
- ///
- /// @param pd C API primitive descriptor for sum primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
- memory::desc src_desc(int idx = 0) const {
- return base::src_desc(idx);
- }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const {
- return base::dst_desc(0);
- }
- };
- /// Creates an empty sum primitive object.
- sum() = default;
- /// Creates a sum primitive from a primitive descriptor.
- /// @param pd Primitive descriptor for sum primitive.
- sum(const primitive_desc &pd) : primitive(pd.get()) {}
- /// Creates a sum primitive from a primitive descriptor and a cache blob.
- /// @param pd Primitive descriptor for sum primitive.
- /// @param cache_blob Cache blob.
- sum(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd.get(), cache_blob) {}
- };
- /// @} dnnl_api_sum
- /// @addtogroup dnnl_api_primitives_common
- /// @{
- /// A base class for descriptors of all primitives that support iteration
- /// over multiple implementations.
- struct primitive_desc : public primitive_desc_base {
- using primitive_desc_base::primitive_desc_base;
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Changes the primitive descriptor to point to the next available
- /// implementation.
- ///
- /// @returns @c true on success and @c false if the last available
- /// implementation has already been reached. In the latter case, the
- /// primitive descriptor itself is kept unchanged.
- ///
- /// Any other non-success status reported by the C API is passed to
- /// error::wrap_c_api() together with a descriptive message.
- bool next_impl() {
- const dnnl_status_t status = dnnl_primitive_desc_next_impl(get());
- if (status == dnnl_last_impl_reached) return false;
- // "Last implementation reached" was handled above, so a failure
- // here is a different error and must not be reported as such.
- error::wrap_c_api(status,
- "could not advance a primitive descriptor to the next "
- "implementation");
- return true;
- }
- };
- /// @} dnnl_api_primitives_common
- /// @addtogroup dnnl_api_convolution Convolution
- ///
- /// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
- /// forward propagation, backward propagation, and weights gradient with or
- /// without bias.
- ///
- /// @sa @ref dev_guide_convolution in developer guide
- ///
- /// @{
- /// Convolution forward propagation primitive.
- struct convolution_forward : public primitive {
- /// Primitive descriptor for a convolution forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a convolution forward
- /// propagation primitive with bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param bias_desc Bias memory descriptor. Passing zero memory
- /// descriptor disables the bias term.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &bias_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, &bias_desc, dst_desc, strides, nullptr,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution forward
- /// propagation primitive without bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &dst_desc,
- const memory::dims &strides, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, nullptr, dst_desc, strides, nullptr,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution forward
- /// propagation primitive with bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param bias_desc Bias memory descriptor. Passing zero memory
- /// descriptor disables the bias term.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &bias_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims &dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, &bias_desc, dst_desc, strides, &dilates,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution forward
- /// propagation primitive without bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &dst_desc,
- const memory::dims &strides, const memory::dims &dilates,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, nullptr, dst_desc, strides, &dilates,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution forward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a convolution forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// Returns the bias memory descriptor.
- /// @returns The bias memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// bias parameter.
- memory::desc bias_desc() const { return base::weights_desc(1); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const { return base::get_dilations(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const { return base::get_padding_l(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const { return base::get_padding_r(); }
- private:
- /// Common implementation behind the public constructors.
- ///
- /// @p bias_desc and @p dilates are optional: the public constructors
- /// pass a null pointer when the corresponding argument is absent.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc *bias_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims *dilates, const memory::dims &padding_l,
- const memory::dims &padding_r, const primitive_attr &attr,
- bool allow_empty) {
- // Spatial arrays are validated against the number of source
- // dimensions minus 2.
- const auto spatial_ndims = src_desc.get_ndims() - 2;
- memory::validate_dims(strides, spatial_ndims);
- memory::validate_dims(padding_l, spatial_ndims);
- memory::validate_dims(padding_r, spatial_ndims);
- if (dilates) memory::validate_dims(*dilates, spatial_ndims);
- dnnl_primitive_desc_t pd = nullptr;
- // Use data() rather than &v[0]: indexing an empty vector is
- // undefined behavior, whereas data() is always well-defined.
- dnnl_status_t status
- = dnnl_convolution_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- convert_to_c(aalgorithm), src_desc.get(),
- weights_desc.get(), optional_arg(bias_desc),
- dst_desc.get(), strides.data(), optional_arg(dilates),
- padding_l.data(), padding_r.data(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the convolution forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd);
- }
- };
- /// Default constructor. Produces an empty object.
- convolution_forward() = default;
- /// Constructs a convolution forward propagation primitive.
- /// @param pd Primitive descriptor for a convolution forward propagation
- /// primitive.
- convolution_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a convolution forward propagation primitive from a cache
- /// blob.
- /// @param pd Primitive descriptor for a convolution forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- convolution_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Convolution backward propagation primitive.
- struct convolution_backward_data : public primitive {
- /// Primitive descriptor for a convolution backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a convolution backward
- /// propagation primitive.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a convolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
- diff_dst_desc, strides, nullptr, padding_l, padding_r,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution backward
- /// propagation primitive.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a convolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
- diff_dst_desc, strides, &dilates, padding_l, padding_r,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a convolution backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const { return base::get_dilations(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const { return base::get_padding_l(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const { return base::get_padding_r(); }
- private:
- /// Common implementation behind the public constructors.
- ///
- /// @p dilates is optional: the public constructor without dilations
- /// passes a null pointer for it.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims *dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- // Spatial arrays are validated against the number of diff source
- // dimensions minus 2.
- const auto spatial_ndims = diff_src_desc.get_ndims() - 2;
- memory::validate_dims(strides, spatial_ndims);
- memory::validate_dims(padding_l, spatial_ndims);
- memory::validate_dims(padding_r, spatial_ndims);
- if (dilates) memory::validate_dims(*dilates, spatial_ndims);
- dnnl_primitive_desc_t pd = nullptr;
- // Use data() rather than &v[0]: indexing an empty vector is
- // undefined behavior, whereas data() is always well-defined.
- dnnl_status_t status
- = dnnl_convolution_backward_data_primitive_desc_create(&pd,
- aengine.get(), convert_to_c(aalgorithm),
- diff_src_desc.get(), weights_desc.get(),
- diff_dst_desc.get(), strides.data(),
- optional_arg(dilates), padding_l.data(),
- padding_r.data(), hint_fwd_pd.get(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the convolution backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd);
- }
- };
- /// Default constructor. Produces an empty object.
- convolution_backward_data() = default;
- /// Constructs a convolution backward propagation primitive.
- /// @param pd Primitive descriptor for a convolution backward propagation
- /// primitive.
- convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a convolution backward propagation primitive from a cache
- /// blob.
- /// @param pd Primitive descriptor for a convolution backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- convolution_backward_data(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Convolution weights gradient primitive.
- struct convolution_backward_weights : public primitive {
- /// Primitive descriptor for a convolution weights gradient primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a convolution weights gradient
- /// primitive with bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param diff_weights_desc Diff weights memory descriptor.
- /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
- /// memory descriptor disables the bias term.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a convolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) // forwards to the private pointer-taking constructor
- : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
- &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l, // bias by address; nullptr: this overload has no dilates argument
- padding_r, hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution weights gradient
- /// primitive without bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param diff_weights_desc Diff weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a convolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) // forwards to the private pointer-taking constructor
- : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
- nullptr, diff_dst_desc, strides, nullptr, padding_l, // first nullptr: no diff bias descriptor; second nullptr: no dilates argument
- padding_r, hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution weights
- /// gradient primitive with bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param diff_weights_desc Diff weights memory descriptor.
- /// @param diff_bias_desc Diff bias memory descriptor. Passing zero
- /// memory descriptor disables the bias term.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a convolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) // forwards to the private pointer-taking constructor
- : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
- &diff_bias_desc, diff_dst_desc, strides, &dilates, // bias and dilates are both passed by address
- padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution weights
- /// gradient primitive without bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Convolution algorithm. Possible values are
- /// #dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd, and
- /// #dnnl::algorithm::convolution_auto.
- /// @param src_desc Source memory descriptor.
- /// @param diff_weights_desc Diff weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a convolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) // forwards to the private pointer-taking constructor
- : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
- nullptr, diff_dst_desc, strides, &dilates, padding_l, // nullptr: no diff bias descriptor; dilates passed by address
- padding_r, hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a convolution weights gradient
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a convolution weights
- /// gradient primitive.
- primitive_desc(dnnl_primitive_desc_t pd) // not explicit: a C handle converts implicitly
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution, // base ctor receives the expected primitive kind...
- dnnl::prop_kind::backward_weights) {} // ...and the expected propagation kind
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
- memory::desc diff_weights_desc() const {
- return base::diff_weights_desc(0);
- }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// Returns the diff bias memory descriptor, queried from the base class as diff weights descriptor index 1.
- /// @returns The diff bias memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// diff bias parameter.
- memory::desc diff_bias_desc() const {
- return base::diff_weights_desc(1);
- }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const { return base::get_dilations(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const { return base::get_padding_l(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const { return base::get_padding_r(); }
- private:
- /// Shared implementation behind the public constructor overloads.
- ///
- /// @param diff_bias_desc Optional diff bias memory descriptor; nullptr
- /// when the calling overload takes no bias.
- /// @param dilates Optional dilations; nullptr when the calling
- /// overload takes no dilations. Validated only when non-null.
- /// @param allow_empty When true, the status returned by the C API is
- /// not checked and the object is reset with whatever handle the
- /// call produced.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc *diff_bias_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims *dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const convolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- // Every per-dimension vector is validated against the source rank
- // minus two, i.e. the count of spatial dimensions.
- const auto spatial_ndims = src_desc.get_ndims() - 2;
- memory::validate_dims(strides, spatial_ndims);
- memory::validate_dims(padding_l, spatial_ndims);
- memory::validate_dims(padding_r, spatial_ndims);
- if (dilates) memory::validate_dims(*dilates, spatial_ndims);
- dnnl_primitive_desc_t pd = nullptr;
- // data() instead of &v[0]: indexing an empty vector is undefined
- // behavior, while data() may be called on any vector.
- dnnl_status_t status
- = dnnl_convolution_backward_weights_primitive_desc_create(
- &pd, aengine.get(), convert_to_c(aalgorithm),
- src_desc.get(), diff_weights_desc.get(),
- optional_arg(diff_bias_desc), diff_dst_desc.get(),
- strides.data(), optional_arg(dilates), padding_l.data(),
- padding_r.data(), hint_fwd_pd.get(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the convolution weights update primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd);
- }
- };
- /// Default constructor (compiler-generated). Produces an empty object.
- convolution_backward_weights() = default;
- /// Constructs a convolution weights gradient primitive by passing @p pd to the `primitive` constructor.
- /// @param pd Primitive descriptor for a convolution weights gradient
- /// primitive.
- convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a convolution weights gradient primitive from a cache blob.
- /// @param pd Primitive descriptor for a convolution weights gradient
- /// primitive.
- /// @param cache_blob Cache blob, passed to the `primitive` constructor together with @p pd.
- convolution_backward_weights(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_convolution
- //
- /// @addtogroup dnnl_api_deconvolution Deconvolution
- ///
- /// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
- /// forward propagation, backward propagation, and weights gradient with or
- /// without bias.
- ///
- /// @{
- /// Deconvolution forward propagation primitive.
- struct deconvolution_forward : public primitive {
- /// Primitive descriptor for a deconvolution forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a deconvolution forward
- /// propagation primitive with bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Deconvolution algorithm:
- /// #dnnl::algorithm::deconvolution_direct, and
- /// #dnnl::algorithm::deconvolution_winograd.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param bias_desc Bias memory descriptor. Passing zero memory
- /// descriptor disables the bias term.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &bias_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, &bias_desc, dst_desc, strides, nullptr,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a deconvolution forward
- /// propagation primitive without bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Deconvolution algorithm:
- /// #dnnl::algorithm::deconvolution_direct, and
- /// #dnnl::algorithm::deconvolution_winograd.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &dst_desc,
- const memory::dims &strides, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, nullptr, dst_desc, strides, nullptr,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a deconvolution forward
- /// propagation primitive with bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Deconvolution algorithm:
- /// #dnnl::algorithm::deconvolution_direct, and
- /// #dnnl::algorithm::deconvolution_winograd.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param bias_desc Bias memory descriptor. Passing zero memory
- /// descriptor disables the bias term.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &bias_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims &dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, &bias_desc, dst_desc, strides, &dilates,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a deconvolution forward
- /// propagation primitive without bias.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Deconvolution algorithm:
- /// #dnnl::algorithm::deconvolution_direct, and
- /// #dnnl::algorithm::deconvolution_winograd.
- /// @param src_desc Source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &dst_desc,
- const memory::dims &strides, const memory::dims &dilates,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- weights_desc, nullptr, dst_desc, strides, &dilates,
- padding_l, padding_r, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a deconvolution forward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a deconvolution forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
- memory::desc bias_desc() const { return base::weights_desc(1); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const { return base::get_dilations(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const { return base::get_padding_l(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const { return base::get_padding_r(); }
- private:
- /// Shared implementation behind the public constructor overloads.
- ///
- /// @param bias_desc Optional bias memory descriptor; nullptr when the
- /// calling overload takes no bias.
- /// @param dilates Optional dilations; nullptr when the calling
- /// overload takes no dilations. Validated only when non-null.
- /// @param allow_empty When true, the status returned by the C API is
- /// not checked and the object is reset with whatever handle the
- /// call produced.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc *bias_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims *dilates, const memory::dims &padding_l,
- const memory::dims &padding_r, const primitive_attr &attr,
- bool allow_empty) {
- // Every per-dimension vector is validated against the source rank
- // minus two, i.e. the count of spatial dimensions.
- const auto spatial_ndims = src_desc.get_ndims() - 2;
- memory::validate_dims(strides, spatial_ndims);
- memory::validate_dims(padding_l, spatial_ndims);
- memory::validate_dims(padding_r, spatial_ndims);
- if (dilates) memory::validate_dims(*dilates, spatial_ndims);
- dnnl_primitive_desc_t pd = nullptr;
- // data() instead of &v[0]: indexing an empty vector is undefined
- // behavior, while data() may be called on any vector.
- dnnl_status_t status
- = dnnl_deconvolution_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- convert_to_c(aalgorithm), src_desc.get(),
- weights_desc.get(), optional_arg(bias_desc),
- dst_desc.get(), strides.data(), optional_arg(dilates),
- padding_l.data(), padding_r.data(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the deconvolution forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd);
- }
- };
- /// Default constructor. Produces an empty object.
- deconvolution_forward() = default;
- /// Constructs a deconvolution forward propagation primitive.
- /// @param pd Primitive descriptor for a deconvolution forward propagation
- /// primitive.
- deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a deconvolution forward propagation primitive from a cache
- /// blob.
- /// @param pd Primitive descriptor for a deconvolution forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- deconvolution_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Deconvolution backward propagation primitive.
- struct deconvolution_backward_data : public primitive {
- /// Primitive descriptor for a deconvolution backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a deconvolution backward
- /// propagation primitive.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p padding_l, and @p padding_r contain values
- /// for spatial dimensions only and hence must have the same number of
- /// elements as there are spatial dimensions. The order of values is
- /// the same as in the tensor: depth (for 3D tensors), height (for 3D
- /// and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Deconvolution algorithm
- /// (#dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a deconvolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const deconvolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
- diff_dst_desc, strides, nullptr, padding_l, padding_r,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a deconvolution backward
- /// propagation primitive.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
- /// contain values for spatial dimensions only and hence must have the
- /// same number of elements as there are spatial dimensions. The order
- /// of values is the same as in the tensor: depth (for 3D tensors),
- /// height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Deconvolution algorithm
- /// (#dnnl::algorithm::convolution_direct,
- /// #dnnl::algorithm::convolution_winograd).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Dilations for each spatial dimension. A zero value
- /// means no dilation in the corresponding dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a deconvolution
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &dilates, const memory::dims &padding_l,
- const memory::dims &padding_r,
- const deconvolution_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
- diff_dst_desc, strides, &dilates, padding_l, padding_r,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a deconvolution backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a deconvolution backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const {
-     // Re-exposes the base-class query under this descriptor type.
-     return base::get_dilations();
- }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const {
-     // Re-exposes the base-class query under this descriptor type.
-     return base::get_padding_l();
- }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const {
-     // Re-exposes the base-class query under this descriptor type.
-     return base::get_padding_r();
- }
- private:
- /// Shared implementation behind the public constructors.
- ///
- /// Checks @p strides, @p padding_l, @p padding_r and, when given,
- /// @p dilates against the rank of @p diff_src_desc minus two, asks the C
- /// API to create the primitive descriptor, and adopts the resulting
- /// handle.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Deconvolution algorithm.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param weights_desc Weights memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Strides for each spatial dimension.
- /// @param dilates Optional dilations. A null pointer means none were
- ///     supplied; they are then neither validated nor dereferenced here.
- /// @param padding_l Padding values for low indices.
- /// @param padding_r Padding values for high indices.
- /// @param hint_fwd_pd Forward propagation primitive descriptor passed to
- ///     the C API as a hint.
- /// @param attr Primitive attributes to use.
- /// @param allow_empty When true, the creation status is not checked and
- ///     the handle (initialized to a null pointer before the call) is
- ///     adopted as is.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
-         const memory::desc &diff_src_desc,
-         const memory::desc &weights_desc,
-         const memory::desc &diff_dst_desc, const memory::dims &strides,
-         const memory::dims *dilates, const memory::dims &padding_l,
-         const memory::dims &padding_r,
-         const deconvolution_forward::primitive_desc &hint_fwd_pd,
-         const primitive_attr &attr, bool allow_empty) {
-     memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
-     memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
-     memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
-     if (dilates)
-         memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
-     dnnl_primitive_desc_t pd = nullptr;
-     // data() rather than &v[0]: unlike indexing element 0, it is well
-     // defined even if a vector turns out to be empty, so a bad argument
-     // reaches the C API and is reported through the status code.
-     dnnl_status_t status
-             = dnnl_deconvolution_backward_data_primitive_desc_create(
-                     &pd, aengine.get(), convert_to_c(aalgorithm),
-                     diff_src_desc.get(), weights_desc.get(),
-                     diff_dst_desc.get(), strides.data(),
-                     optional_arg(dilates), padding_l.data(),
-                     padding_r.data(), hint_fwd_pd.get(), attr.get());
-     if (!allow_empty)
-         error::wrap_c_api(status,
-                 "could not create a primitive descriptor for "
-                 "the deconvolution backward propagation primitive. Run "
-                 "workload with environment variable ONEDNN_VERBOSE=all "
-                 "to get additional diagnostic information.");
-     reset(pd);
- }
- };
- /// Default constructor. Produces an empty object.
- deconvolution_backward_data() = default;
- /// Constructs a deconvolution backward propagation primitive.
- /// @param pd Primitive descriptor for a deconvolution backward propagation
- /// primitive.
- deconvolution_backward_data(const primitive_desc &pd)
-     // All construction work is done by the primitive base class.
-     : primitive(pd) {
- }
- /// Constructs a deconvolution backward propagation primitive from a cache
- /// blob.
- /// @param pd Primitive descriptor for a deconvolution backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- deconvolution_backward_data(const primitive_desc &pd,
-         const std::vector<uint8_t> &cache_blob)
-     // Both arguments are handed to the primitive base class unchanged.
-     : primitive(pd, cache_blob) {
- }
- };
- /// Deconvolution weights gradient primitive.
- struct deconvolution_backward_weights : public primitive {
-     /// Primitive descriptor for a deconvolution weights gradient primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Default constructor. Produces an empty object.
-         primitive_desc() = default;
-         /// Constructs a primitive descriptor for a deconvolution weights
-         /// gradient primitive with bias and without dilations.
-         ///
-         /// @note
-         ///     All the memory descriptors may be initialized with the
-         ///     #dnnl::memory::format_tag::any value of @p format_tag.
-         ///
-         /// Arrays @p strides, @p padding_l, and @p padding_r contain values
-         /// for spatial dimensions only and hence must have the same number of
-         /// elements as there are spatial dimensions. The order of values is
-         /// the same as in the tensor: depth (for 3D tensors), height (for 3D
-         /// and 2D tensors), and width.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aalgorithm Deconvolution algorithm. Possible values are
-         ///     #dnnl::algorithm::deconvolution_direct, and
-         ///     #dnnl::algorithm::deconvolution_winograd.
-         /// @param src_desc Source memory descriptor.
-         /// @param diff_weights_desc Diff weights memory descriptor.
-         /// @param diff_bias_desc Diff bias memory descriptor. Passing a zero
-         ///     memory descriptor disables the bias term.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param strides Strides for each spatial dimension.
-         /// @param padding_l Vector of padding values for low indices for each
-         ///     spatial dimension `([[front,] top,] left)`.
-         /// @param padding_r Vector of padding values for high indices for
-         ///     each spatial dimension `([[back,] bottom,] right)`.
-         /// @param hint_fwd_pd Primitive descriptor for a deconvolution
-         ///     forward propagation primitive. It is used as a hint for
-         ///     deciding which memory format to use.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case an
-         ///     empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &src_desc,
-                 const memory::desc &diff_weights_desc,
-                 const memory::desc &diff_bias_desc,
-                 const memory::desc &diff_dst_desc, const memory::dims &strides,
-                 const memory::dims &padding_l, const memory::dims &padding_r,
-                 const deconvolution_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false)
-             : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
-                     &diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
-                     padding_r, hint_fwd_pd, attr, allow_empty) {}
-         /// Constructs a primitive descriptor for a deconvolution weights
-         /// gradient primitive without bias and without dilations.
-         ///
-         /// @note
-         ///     All the memory descriptors may be initialized with the
-         ///     #dnnl::memory::format_tag::any value of @p format_tag.
-         ///
-         /// Arrays @p strides, @p padding_l, and @p padding_r contain values
-         /// for spatial dimensions only and hence must have the same number of
-         /// elements as there are spatial dimensions. The order of values is
-         /// the same as in the tensor: depth (for 3D tensors), height (for 3D
-         /// and 2D tensors), and width.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aalgorithm Deconvolution algorithm. Possible values are
-         ///     #dnnl::algorithm::deconvolution_direct, and
-         ///     #dnnl::algorithm::deconvolution_winograd.
-         /// @param src_desc Source memory descriptor.
-         /// @param diff_weights_desc Diff weights memory descriptor.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param strides Strides for each spatial dimension.
-         /// @param padding_l Vector of padding values for low indices for each
-         ///     spatial dimension `([[front,] top,] left)`.
-         /// @param padding_r Vector of padding values for high indices for
-         ///     each spatial dimension `([[back,] bottom,] right)`.
-         /// @param hint_fwd_pd Primitive descriptor for a deconvolution
-         ///     forward propagation primitive. It is used as a hint for
-         ///     deciding which memory format to use.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case an
-         ///     empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &src_desc,
-                 const memory::desc &diff_weights_desc,
-                 const memory::desc &diff_dst_desc, const memory::dims &strides,
-                 const memory::dims &padding_l, const memory::dims &padding_r,
-                 const deconvolution_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false)
-             : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
-                     nullptr, diff_dst_desc, strides, nullptr, padding_l,
-                     padding_r, hint_fwd_pd, attr, allow_empty) {}
-         /// Constructs a primitive descriptor for a deconvolution weights
-         /// gradient primitive with bias and with dilations.
-         ///
-         /// @note
-         ///     All the memory descriptors may be initialized with the
-         ///     #dnnl::memory::format_tag::any value of @p format_tag.
-         ///
-         /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
-         /// contain values for spatial dimensions only and hence must have the
-         /// same number of elements as there are spatial dimensions. The order
-         /// of values is the same as in the tensor: depth (for 3D tensors),
-         /// height (for 3D and 2D tensors), and width.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aalgorithm Deconvolution algorithm. Possible values are
-         ///     #dnnl::algorithm::deconvolution_direct, and
-         ///     #dnnl::algorithm::deconvolution_winograd.
-         /// @param src_desc Source memory descriptor.
-         /// @param diff_weights_desc Diff weights memory descriptor.
-         /// @param diff_bias_desc Diff bias memory descriptor. Passing a zero
-         ///     memory descriptor disables the bias term.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param strides Strides for each spatial dimension.
-         /// @param dilates Dilations for each spatial dimension. A zero value
-         ///     means no dilation in the corresponding dimension.
-         /// @param padding_l Vector of padding values for low indices for each
-         ///     spatial dimension `([[front,] top,] left)`.
-         /// @param padding_r Vector of padding values for high indices for
-         ///     each spatial dimension `([[back,] bottom,] right)`.
-         /// @param hint_fwd_pd Primitive descriptor for a deconvolution
-         ///     forward propagation primitive. It is used as a hint for
-         ///     deciding which memory format to use.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case an
-         ///     empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &src_desc,
-                 const memory::desc &diff_weights_desc,
-                 const memory::desc &diff_bias_desc,
-                 const memory::desc &diff_dst_desc, const memory::dims &strides,
-                 const memory::dims &dilates, const memory::dims &padding_l,
-                 const memory::dims &padding_r,
-                 const deconvolution_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false)
-             : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
-                     &diff_bias_desc, diff_dst_desc, strides, &dilates,
-                     padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
-         /// Constructs a primitive descriptor for a deconvolution weights
-         /// gradient primitive without bias and with dilations.
-         ///
-         /// @note
-         ///     All the memory descriptors may be initialized with the
-         ///     #dnnl::memory::format_tag::any value of @p format_tag.
-         ///
-         /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
-         /// contain values for spatial dimensions only and hence must have the
-         /// same number of elements as there are spatial dimensions. The order
-         /// of values is the same as in the tensor: depth (for 3D tensors),
-         /// height (for 3D and 2D tensors), and width.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aalgorithm Deconvolution algorithm. Possible values are
-         ///     #dnnl::algorithm::deconvolution_direct, and
-         ///     #dnnl::algorithm::deconvolution_winograd.
-         /// @param src_desc Source memory descriptor.
-         /// @param diff_weights_desc Diff weights memory descriptor.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param strides Strides for each spatial dimension.
-         /// @param dilates Dilations for each spatial dimension. A zero value
-         ///     means no dilation in the corresponding dimension.
-         /// @param padding_l Vector of padding values for low indices for each
-         ///     spatial dimension `([[front,] top,] left)`.
-         /// @param padding_r Vector of padding values for high indices for
-         ///     each spatial dimension `([[back,] bottom,] right)`.
-         /// @param hint_fwd_pd Primitive descriptor for a deconvolution
-         ///     forward propagation primitive. It is used as a hint for
-         ///     deciding which memory format to use.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case an
-         ///     empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &src_desc,
-                 const memory::desc &diff_weights_desc,
-                 const memory::desc &diff_dst_desc, const memory::dims &strides,
-                 const memory::dims &dilates, const memory::dims &padding_l,
-                 const memory::dims &padding_r,
-                 const deconvolution_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false)
-             : primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
-                     nullptr, diff_dst_desc, strides, &dilates, padding_l,
-                     padding_r, hint_fwd_pd, attr, allow_empty) {}
-         /// Constructs a primitive descriptor for a deconvolution weights
-         /// gradient primitive from a C API primitive descriptor that must
-         /// have a matching kind.
-         ///
-         /// @param pd C API primitive descriptor for a deconvolution weights
-         ///     gradient primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
-                     dnnl::prop_kind::backward_weights) {}
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const { return base::src_desc(0); }
-         /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
-         memory::desc diff_weights_desc() const {
-             return base::diff_weights_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
-         memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
-         /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
-         memory::desc diff_bias_desc() const {
-             // The diff bias is queried as diff weights descriptor index 1.
-             return base::diff_weights_desc(1);
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
-         algorithm get_algorithm() const { return base::get_algorithm(); }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         prop_kind get_prop_kind() const { return base::get_prop_kind(); }
-         /// @copydoc dnnl::primitive_desc_base::get_strides()const
-         memory::dims get_strides() const { return base::get_strides(); }
-         /// @copydoc dnnl::primitive_desc_base::get_dilations()const
-         memory::dims get_dilations() const { return base::get_dilations(); }
-         /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
-         memory::dims get_padding_l() const { return base::get_padding_l(); }
-         /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
-         memory::dims get_padding_r() const { return base::get_padding_r(); }
-     private:
-         /// Shared implementation behind the public constructors.
-         ///
-         /// Checks @p strides, @p padding_l, @p padding_r and, when given,
-         /// @p dilates against the rank of @p src_desc minus two, asks the C
-         /// API to create the primitive descriptor, and adopts the resulting
-         /// handle.
-         ///
-         /// @param diff_bias_desc Optional diff bias memory descriptor; a null
-         ///     pointer is forwarded through optional_arg().
-         /// @param dilates Optional dilations; a null pointer means none were
-         ///     supplied, in which case they are neither validated nor
-         ///     dereferenced here.
-         /// @param allow_empty When true, the creation status is not checked
-         ///     and the handle (initialized to a null pointer before the call)
-         ///     is adopted as is.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &src_desc,
-                 const memory::desc &diff_weights_desc,
-                 const memory::desc *diff_bias_desc,
-                 const memory::desc &diff_dst_desc, const memory::dims &strides,
-                 const memory::dims *dilates, const memory::dims &padding_l,
-                 const memory::dims &padding_r,
-                 const deconvolution_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr, bool allow_empty) {
-             memory::validate_dims(strides, src_desc.get_ndims() - 2);
-             memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
-             memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
-             if (dilates)
-                 memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
-             dnnl_primitive_desc_t pd = nullptr;
-             // data() rather than &v[0]: unlike indexing element 0, it is
-             // well defined even if a vector turns out to be empty, so a bad
-             // argument reaches the C API and is reported through the status.
-             dnnl_status_t status
-                     = dnnl_deconvolution_backward_weights_primitive_desc_create(
-                             &pd, aengine.get(), convert_to_c(aalgorithm),
-                             src_desc.get(), diff_weights_desc.get(),
-                             optional_arg(diff_bias_desc), diff_dst_desc.get(),
-                             strides.data(), optional_arg(dilates),
-                             padding_l.data(), padding_r.data(),
-                             hint_fwd_pd.get(), attr.get());
-             if (!allow_empty)
-                 error::wrap_c_api(status,
-                         "could not create a primitive descriptor for "
-                         "the deconvolution weights update primitive. Run "
-                         "workload with environment variable ONEDNN_VERBOSE=all "
-                         "to get additional diagnostic information.");
-             reset(pd);
-         }
-     };
-     /// Default constructor. Produces an empty object.
-     deconvolution_backward_weights() = default;
-     /// Constructs a deconvolution weights gradient primitive.
-     /// @param pd Primitive descriptor for a deconvolution weights gradient
-     ///     primitive.
-     deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
-     /// Constructs a deconvolution weights gradient primitive from a cache
-     ///     blob.
-     /// @param pd Primitive descriptor for a deconvolution weights gradient
-     ///     primitive.
-     /// @param cache_blob Cache blob.
-     deconvolution_backward_weights(
-             const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_deconvolution
- /// @addtogroup dnnl_api_lrn LRN
- ///
- /// A primitive to perform local response normalization (LRN) across or within
- /// channels.
- ///
- /// @sa @ref dev_guide_lrn in developer guide
- ///
- /// @{
- /// Local response normalization (LRN) forward propagation primitive.
- struct lrn_forward : public primitive {
-     /// Descriptor type from which an LRN forward propagation primitive is
-     /// created.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Creates an empty primitive descriptor.
-         primitive_desc() = default;
-         /// Creates a primitive descriptor for an LRN forward propagation
-         /// primitive through the C API.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aprop_kind Propagation kind, one of
-         ///     #dnnl::prop_kind::forward_training and
-         ///     #dnnl::prop_kind::forward_inference.
-         /// @param aalgorithm LRN algorithm kind, one of
-         ///     #dnnl::algorithm::lrn_across_channels and
-         ///     #dnnl::algorithm::lrn_within_channel.
-         /// @param src_desc Source memory descriptor.
-         /// @param dst_desc Destination memory descriptor.
-         /// @param local_size Regularization local size.
-         /// @param alpha The alpha regularization parameter.
-         /// @param beta The beta regularization parameter.
-         /// @param k The k regularization parameter.
-         /// @param attr Primitive attributes to use. Optional; defaults to
-         ///     empty attributes.
-         /// @param allow_empty Whether construction may fail without
-         ///     throwing an exception, in which case an empty object is
-         ///     produced. Optional; defaults to false.
-         primitive_desc(const engine &aengine, prop_kind aprop_kind,
-                 algorithm aalgorithm, const memory::desc &src_desc,
-                 const memory::desc &dst_desc, memory::dim local_size,
-                 float alpha, float beta, float k,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd = nullptr;
-             const dnnl_status_t rc = dnnl_lrn_forward_primitive_desc_create(
-                     &c_pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
-                     convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
-                     local_size, alpha, beta, k, attr.get());
-             // The status is inspected only when the caller did not opt in
-             // to receiving an empty object.
-             if (!allow_empty) {
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the lrn forward propagation primitive. Run workload "
-                         "with environment variable ONEDNN_VERBOSE=all to get "
-                         "additional diagnostic information.");
-             }
-             reset(c_pd);
-         }
-         /// Wraps an existing C API primitive descriptor, which must be of a
-         /// matching kind.
-         ///
-         /// @param pd C API primitive descriptor for an LRN forward
-         ///     propagation primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
-                     dnnl::prop_kind::forward_training,
-                     dnnl::prop_kind::forward_inference) {
-         }
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const {
-             return base::src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-         memory::desc dst_desc() const {
-             return base::dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
-         memory::desc workspace_desc() const {
-             return base::workspace_desc();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
-         algorithm get_algorithm() const {
-             return base::get_algorithm();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         prop_kind get_prop_kind() const {
-             return base::get_prop_kind();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_alpha()const
-         float get_alpha() const {
-             return base::get_alpha();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_beta()const
-         float get_beta() const {
-             return base::get_beta();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_local_size()const
-         memory::dim get_local_size() const {
-             return base::get_local_size();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_k()const
-         float get_k() const {
-             return base::get_k();
-         }
-     };
-     /// Creates an empty primitive.
-     lrn_forward() = default;
-     /// Creates an LRN forward propagation primitive.
-     /// @param pd Primitive descriptor for an LRN forward propagation
-     ///     primitive.
-     lrn_forward(const primitive_desc &pd)
-         : primitive(pd) {
-     }
-     /// Creates an LRN forward propagation primitive from a cache blob.
-     /// @param pd Primitive descriptor for an LRN forward propagation
-     ///     primitive.
-     /// @param cache_blob Cache blob.
-     lrn_forward(const primitive_desc &pd,
-             const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {
-     }
- };
- /// Local response normalization (LRN) backward propagation primitive.
- struct lrn_backward : public primitive {
-     /// Primitive descriptor for an LRN backward propagation primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Default constructor. Produces an empty object.
-         primitive_desc() = default;
-         /// Constructs a primitive descriptor for an LRN backward propagation
-         /// primitive.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aalgorithm LRN algorithm kind: either
-         ///     #dnnl::algorithm::lrn_across_channels, or
-         ///     #dnnl::algorithm::lrn_within_channel.
-         /// @param diff_src_desc Diff source memory descriptor.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param src_desc Source memory descriptor.
-         /// @param local_size Regularization local size.
-         /// @param alpha The alpha regularization parameter.
-         /// @param beta The beta regularization parameter.
-         /// @param k The k regularization parameter.
-         /// @param hint_fwd_pd Primitive descriptor for an LRN forward
-         ///     propagation primitive. It is used as a hint for deciding which
-         ///     memory format to use.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case an
-         ///     empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &diff_src_desc,
-                 const memory::desc &diff_dst_desc, const memory::desc &src_desc,
-                 memory::dim local_size, float alpha, float beta, float k,
-                 const lrn_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t pd = nullptr;
-             dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd,
-                     aengine.get(), convert_to_c(aalgorithm),
-                     diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(),
-                     local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get());
-             // The status is checked only when the caller did not allow an
-             // empty object; the handle is adopted in either case.
-             if (!allow_empty)
-                 error::wrap_c_api(status,
-                         "could not create a primitive descriptor for "
-                         "the lrn backward propagation primitive. Run workload "
-                         "with environment variable ONEDNN_VERBOSE=all to get "
-                         "additional diagnostic information.");
-             reset(pd);
-         }
-         /// Constructs a primitive descriptor for an LRN backward propagation
-         /// primitive from a C API primitive descriptor that must have a
-         /// matching kind.
-         ///
-         /// @param pd C API primitive descriptor for an LRN backward
-         ///     propagation primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
-                     dnnl::prop_kind::backward_data) {}
-         /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
-         memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
-         /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
-         memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
-         /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
-         memory::desc workspace_desc() const { return base::workspace_desc(); }
-         /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
-         algorithm get_algorithm() const { return base::get_algorithm(); }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         prop_kind get_prop_kind() const { return base::get_prop_kind(); }
-         /// @copydoc dnnl::primitive_desc_base::get_alpha()const
-         float get_alpha() const { return base::get_alpha(); }
-         /// @copydoc dnnl::primitive_desc_base::get_beta()const
-         float get_beta() const { return base::get_beta(); }
-         /// @copydoc dnnl::primitive_desc_base::get_local_size()const
-         memory::dim get_local_size() const { return base::get_local_size(); }
-         /// @copydoc dnnl::primitive_desc_base::get_k()const
-         float get_k() const { return base::get_k(); }
-     };
-     /// Default constructor. Produces an empty object.
-     lrn_backward() = default;
-     /// Constructs an LRN backward propagation primitive.
-     /// @param pd Primitive descriptor for an LRN backward propagation
-     ///     primitive.
-     lrn_backward(const primitive_desc &pd) : primitive(pd) {}
-     /// Constructs an LRN backward propagation primitive from a cache blob.
-     /// @param pd Primitive descriptor for an LRN backward propagation
-     ///     primitive.
-     /// @param cache_blob Cache blob.
-     lrn_backward(
-             const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_lrn
- /// @addtogroup dnnl_api_eltwise Eltwise
- ///
- /// A primitive to perform elementwise operations such as the
- /// rectified linear unit (ReLU).
- ///
- /// Both forward and backward propagation primitives support in-place
- /// operation; that is, src and dst can refer to the same memory for forward
- /// propagation, and diff_dst and diff_src can refer to the same memory for
- /// backward propagation.
- ///
- /// @warning
- /// Because the original source data is required for backward propagation,
- /// in-place forward propagation is not generally supported in the
- /// training mode. However, for algorithms supporting destination as input
- /// memory, dst can be used for the backward propagation, which makes it
- /// possible to get performance benefit even in the training mode.
- ///
- /// @sa @ref dev_guide_eltwise in developer guide
- ///
- /// @{
- /// Elementwise unary operation forward propagation primitive.
- struct eltwise_forward : public primitive {
- /// Primitive descriptor for an elementwise forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an elementwise forward
- /// propagation primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Elementwise algorithm kind.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- dst_desc, nullptr, nullptr, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an elementwise forward
- /// propagation primitive with an alpha parameter.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Elementwise algorithm kind.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param alpha The alpha parameter for the elementwise operation.
- /// Specific meaning depends on the algorithm.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &dst_desc, float alpha,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- dst_desc, &alpha, nullptr, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an elementwise forward
- /// propagation primitive with an alpha and beta parameters.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Elementwise algorithm kind.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param alpha The alpha parameter for the elementwise operation.
- /// Specific meaning depends on the algorithm.
- /// @param beta The beta parameter for the elementwise operation.
- /// Specific meaning depends on the algorithm.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &dst_desc, float alpha, float beta,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
- dst_desc, &alpha, &beta, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an eltwise forward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for an eltwise forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_alpha()const
- float get_alpha() const { return base::get_alpha(); }
- /// @copydoc dnnl::primitive_desc_base::get_beta()const
- float get_beta() const { return base::get_beta(); }
- private:
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
-         algorithm aalgorithm, const memory::desc &src_desc,
-         const memory::desc &dst_desc, const float *alpha,
-         const float *beta, const primitive_attr &attr,
-         bool allow_empty) {
-     // A null pointer stands for a parameter the caller did not supply;
-     // the C API then receives 0 for it.
-     const float alpha_value = alpha ? *alpha : 0.0f;
-     const float beta_value = beta ? *beta : 0.0f;
-     dnnl_primitive_desc_t c_pd = nullptr;
-     const dnnl_status_t rc = dnnl_eltwise_forward_primitive_desc_create(
-             &c_pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
-             dnnl::convert_to_c(aalgorithm), src_desc.get(),
-             dst_desc.get(), alpha_value, beta_value, attr.get());
-     // The status is checked only when the caller did not opt into
-     // receiving an empty descriptor.
-     if (!allow_empty) {
-         error::wrap_c_api(rc,
-                 "could not create a primitive descriptor for "
-                 "the eltwise forward propagation primitive. Run "
-                 "workload with environment variable ONEDNN_VERBOSE=all "
-                 "to get additional diagnostic information.");
-     }
-     reset(c_pd);
- }
- };
- /// Default constructor. Produces an empty object.
- eltwise_forward() = default;
- /// Constructs an eltwise forward propagation primitive.
- /// @param pd Primitive descriptor for an eltwise forward propagation
- /// primitive.
- eltwise_forward(const primitive_desc &pd)
-     // Construction is delegated entirely to the primitive base class.
-     : primitive(pd) {
- }
- /// Constructs an eltwise forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an eltwise forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- eltwise_forward(const primitive_desc &pd,
-         const std::vector<uint8_t> &cache_blob)
-     // Descriptor and cache blob are both passed through to the
-     // primitive base class.
-     : primitive(pd, cache_blob) {
- }
- };
- /// Elementwise unary operation backward propagation primitive.
- struct eltwise_backward : public primitive {
- /// Primitive descriptor for eltwise backward propagation.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an elementwise backward
- /// propagation primitive without alpha and beta (both are passed as 0).
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Elementwise algorithm kind.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param data_desc Destination memory descriptor if one of the
- /// "use_dst_for_bwd" algorithms is used (such as
- /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
- /// otherwise.
- /// @param hint_fwd_pd Primitive descriptor for an elementwise
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const memory::desc &data_desc,
- const eltwise_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
- data_desc, nullptr, nullptr, hint_fwd_pd, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for an elementwise backward
- /// propagation primitive with an alpha parameter (beta is passed as 0).
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Elementwise algorithm kind.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param data_desc Destination memory descriptor if one of the
- /// "use_dst_for_bwd" algorithms is used (such as
- /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
- /// otherwise.
- /// @param alpha The alpha parameter for the elementwise operation.
- /// Specific meaning depends on the algorithm.
- /// @param hint_fwd_pd Primitive descriptor for an elementwise
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const memory::desc &data_desc, float alpha,
- const eltwise_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
- data_desc, &alpha, nullptr, hint_fwd_pd, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for an elementwise backward
- /// propagation primitive with alpha and beta parameters.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Elementwise algorithm kind.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param data_desc Destination memory descriptor if one of the
- /// "use_dst_for_bwd" algorithms is used (such as
- /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
- /// otherwise.
- /// @param alpha The alpha parameter for the elementwise operation.
- /// Specific meaning depends on the algorithm.
- /// @param beta The beta parameter for the elementwise operation.
- /// Specific meaning depends on the algorithm.
- /// @param hint_fwd_pd Primitive descriptor for an elementwise
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const memory::desc &data_desc, float alpha, float beta,
- const eltwise_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
- data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an eltwise backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for an eltwise backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_alpha()const
- float get_alpha() const { return base::get_alpha(); }
- /// @copydoc dnnl::primitive_desc_base::get_beta()const
- float get_beta() const { return base::get_beta(); }
- private:
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const memory::desc &data_desc, const float *alpha,
- const float *beta,
- const eltwise_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- dnnl_primitive_desc_t pd = nullptr;
- dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create(
- &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
- diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(),
- alpha ? *alpha : 0.0f, beta ? *beta : 0.0f, // a null pointer means "not supplied"; 0 is passed instead
- hint_fwd_pd.get(), attr.get());
- if (!allow_empty) // with allow_empty set, the status is not checked
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the eltwise backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd); // pd is still nullptr unless the C call assigned it
- }
- };
- /// Default constructor. Produces an empty object.
- eltwise_backward() = default;
- /// Constructs an eltwise backward propagation primitive.
- /// @param pd Primitive descriptor for an eltwise backward propagation
- /// primitive.
- eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an eltwise backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an eltwise backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- eltwise_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_eltwise
- /// @addtogroup dnnl_api_softmax Softmax
- ///
- /// A primitive to perform softmax.
- ///
- /// @sa @ref dev_guide_softmax in developer guide
- ///
- /// @{
- /// Primitive that performs softmax forward propagation.
- struct softmax_forward : public primitive {
-     /// Primitive descriptor of a softmax forward propagation primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Creates an empty primitive descriptor.
-         primitive_desc() = default;
-         /// Creates a primitive descriptor for a softmax forward
-         /// propagation primitive.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aprop_kind Propagation kind; either
-         ///     #dnnl::prop_kind::forward_training or
-         ///     #dnnl::prop_kind::forward_inference.
-         /// @param aalgorithm Softmax algorithm kind; either
-         ///     #dnnl::algorithm::softmax_accurate or
-         ///     #dnnl::algorithm::softmax_log.
-         /// @param src_desc Source memory descriptor.
-         /// @param dst_desc Destination memory descriptor.
-         /// @param axis Axis over which softmax is computed.
-         /// @param attr Primitive attributes to use. Optional; empty
-         ///     attributes are used by default.
-         /// @param allow_empty When set, construction is allowed to fail
-         ///     without throwing an exception, and an empty object is
-         ///     produced in that case. Optional; defaults to false.
-         primitive_desc(const engine &aengine, prop_kind aprop_kind,
-                 algorithm aalgorithm, const memory::desc &src_desc,
-                 const memory::desc &dst_desc, int axis,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd = nullptr;
-             const dnnl_status_t rc
-                     = dnnl_softmax_forward_primitive_desc_create(&c_pd,
-                             aengine.get(), dnnl::convert_to_c(aprop_kind),
-                             dnnl::convert_to_c(aalgorithm), src_desc.get(),
-                             dst_desc.get(), axis, attr.get());
-             // The status is checked only when the caller did not opt
-             // into receiving an empty descriptor.
-             if (!allow_empty) {
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the softmax forward propagation primitive. Run "
-                         "workload with environment variable ONEDNN_VERBOSE=all "
-                         "to get additional diagnostic information.");
-             }
-             reset(c_pd);
-         }
-         /// Wraps a C API primitive descriptor, which must describe a
-         /// softmax forward propagation primitive.
-         ///
-         /// @param pd C API primitive descriptor for a softmax forward
-         ///     propagation primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd,
-                     dnnl::primitive::kind::softmax,
-                     dnnl::prop_kind::forward_training,
-                     dnnl::prop_kind::forward_inference) {
-         }
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const {
-             return base::src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-         memory::desc dst_desc() const {
-             return base::dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
-         dnnl::algorithm get_algorithm() const {
-             return base::get_algorithm();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         dnnl::prop_kind get_prop_kind() const {
-             return base::get_prop_kind();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_axis()const
-         int get_axis() const {
-             return base::get_axis();
-         }
-     };
-     /// Creates an empty primitive.
-     softmax_forward() = default;
-     /// Creates a softmax forward propagation primitive.
-     /// @param pd Primitive descriptor for a softmax forward propagation
-     ///     primitive.
-     softmax_forward(const primitive_desc &pd)
-         : primitive(pd) {
-     }
-     /// Creates a softmax forward propagation primitive from a cache blob.
-     /// @param pd Primitive descriptor for a softmax forward propagation
-     ///     primitive.
-     /// @param cache_blob Cache blob.
-     softmax_forward(const primitive_desc &pd,
-             const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {
-     }
- };
- /// Softmax backward propagation primitive.
- struct softmax_backward : public primitive {
- /// Primitive descriptor for a softmax backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a softmax backward propagation
- /// primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Softmax algorithm kind: either
- /// #dnnl::algorithm::softmax_accurate,
- /// or #dnnl::algorithm::softmax_log.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param axis Axis over which softmax is computed.
- /// @param hint_fwd_pd Primitive descriptor for a softmax
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &dst_desc,
- int axis, const softmax_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t pd = nullptr;
- dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create(
- &pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
- diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(),
- axis, hint_fwd_pd.get(), attr.get());
- if (!allow_empty) // with allow_empty set, the status is not checked
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the softmax backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd); // pd is still nullptr unless the C call assigned it
- }
- /// Constructs a primitive descriptor for a softmax backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a softmax backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_axis()const
- int get_axis() const { return base::get_axis(); }
- };
- /// Default constructor. Produces an empty object.
- softmax_backward() = default;
- /// Constructs a softmax backward propagation primitive.
- /// @param pd Primitive descriptor for a softmax backward propagation
- /// primitive.
- softmax_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a softmax backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for a softmax backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- softmax_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_softmax
- /// @addtogroup dnnl_api_batch_normalization Batch Normalization
- ///
- /// A primitive to perform batch normalization.
- ///
- /// Both forward and backward propagation primitives support in-place
- /// operation; that is, src and dst can refer to the same memory for forward
- /// propagation, and diff_dst and diff_src can refer to the same memory for
- /// backward propagation.
- ///
- /// The batch normalization primitives' computations can be controlled by
- /// specifying different @ref dnnl::normalization_flags values. For example,
- /// batch normalization forward propagation can be configured to either
- /// compute the mean and variance or take them as arguments. It can either
- /// perform scaling and shifting using gamma and beta parameters or not.
- /// Optionally, it can also perform a fused ReLU, which in case of training
- /// would also require a workspace.
- ///
- /// @sa @ref dev_guide_batch_normalization in developer guide
- ///
- /// @{
- /// Primitive that performs batch normalization forward propagation.
- struct batch_normalization_forward : public primitive {
-     /// Primitive descriptor of a batch normalization forward propagation
-     /// primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Creates an empty primitive descriptor.
-         primitive_desc() = default;
-         /// Creates a primitive descriptor for a batch normalization
-         /// forward propagation primitive.
-         ///
-         /// @note
-         ///     The operation may run in place: dst is allowed to refer to
-         ///     the same memory as src.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aprop_kind Propagation kind; either
-         ///     #dnnl::prop_kind::forward_training or
-         ///     #dnnl::prop_kind::forward_inference.
-         /// @param src_desc Source memory descriptor.
-         /// @param dst_desc Destination memory descriptor.
-         /// @param epsilon Batch normalization epsilon parameter.
-         /// @param flags Batch normalization flags (@ref
-         ///     dnnl::normalization_flags).
-         /// @param attr Primitive attributes to use. Optional; empty
-         ///     attributes are used by default.
-         /// @param allow_empty When set, construction is allowed to fail
-         ///     without throwing an exception, and an empty object is
-         ///     produced in that case. Optional; defaults to false.
-         primitive_desc(const engine &aengine, prop_kind aprop_kind,
-                 const memory::desc &src_desc, const memory::desc &dst_desc,
-                 float epsilon, normalization_flags flags,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd = nullptr;
-             const dnnl_status_t rc
-                     = dnnl_batch_normalization_forward_primitive_desc_create(
-                             &c_pd, aengine.get(),
-                             dnnl::convert_to_c(aprop_kind), src_desc.get(),
-                             dst_desc.get(), epsilon, convert_to_c(flags),
-                             attr.get());
-             // The status is checked only when the caller did not opt
-             // into receiving an empty descriptor.
-             if (!allow_empty) {
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the batch normalization forward propagation "
-                         "primitive. Run workload with environment variable "
-                         "ONEDNN_VERBOSE=all to get additional diagnostic "
-                         "information.");
-             }
-             reset(c_pd);
-         }
-         /// Wraps a C API primitive descriptor, which must describe a
-         /// batch normalization forward propagation primitive.
-         ///
-         /// @param pd C API primitive descriptor for a batch normalization
-         ///     forward propagation primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd,
-                     dnnl::primitive::kind::batch_normalization,
-                     dnnl::prop_kind::forward_training,
-                     dnnl::prop_kind::forward_inference) {
-         }
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const {
-             return base::src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-         memory::desc dst_desc() const {
-             return base::dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::weights_desc()const
-         memory::desc weights_desc() const {
-             return base::weights_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
-         memory::desc workspace_desc() const {
-             return base::workspace_desc();
-         }
-         /// Queries the memory descriptor for mean.
-         /// @returns Memory descriptor for mean.
-         memory::desc mean_desc() const {
-             return stat_desc(mean);
-         }
-         /// Queries the memory descriptor for variance.
-         /// @returns Memory descriptor for variance.
-         memory::desc variance_desc() const {
-             return stat_desc(var);
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         dnnl::prop_kind get_prop_kind() const {
-             return base::get_prop_kind();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
-         float get_epsilon() const {
-             return base::get_epsilon();
-         }
-         /// Queries the normalization flags.
-         /// @return Normalization flags.
-         normalization_flags get_flags() const {
-             return base::get_flags<normalization_flags>();
-         }
-     private:
-         // Indices passed to query_md() to select a statistics tensor.
-         enum {
-             mean = 1,
-             var = 2,
-         };
-         // Queries the descriptor of the statistics tensor selected by
-         // `which`: as a source when use_global_stats is set, as a
-         // destination otherwise.
-         memory::desc stat_desc(int which) const {
-             const auto global_bit
-                     = get_flags() & normalization_flags::use_global_stats;
-             if (global_bit != normalization_flags::none)
-                 return query_md(query::src_md, which);
-             return query_md(query::dst_md, which);
-         }
-     };
-     /// Creates an empty primitive.
-     batch_normalization_forward() = default;
-     /// Creates a batch normalization forward propagation primitive.
-     /// @param pd Primitive descriptor for a batch normalization forward
-     ///     propagation primitive.
-     batch_normalization_forward(const primitive_desc &pd)
-         : primitive(pd) {
-     }
-     /// Creates a batch normalization forward propagation primitive from
-     /// a cache blob.
-     /// @param pd Primitive descriptor for a batch normalization forward
-     ///     propagation primitive.
-     /// @param cache_blob Cache blob.
-     batch_normalization_forward(const primitive_desc &pd,
-             const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {
-     }
- };
- /// Primitive that performs batch normalization backward propagation.
- struct batch_normalization_backward : public primitive {
-     /// Primitive descriptor of a batch normalization backward propagation
-     /// primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Creates an empty primitive descriptor.
-         primitive_desc() = default;
-         /// Creates a primitive descriptor for a batch normalization
-         /// backward propagation primitive.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aprop_kind Propagation kind; either
-         ///     #dnnl::prop_kind::backward_data or
-         ///     #dnnl::prop_kind::backward (in which case diffs for all
-         ///     parameters are computed).
-         /// @param diff_src_desc Diff source memory descriptor.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param src_desc Source memory descriptor.
-         /// @param epsilon Batch normalization epsilon parameter.
-         /// @param flags Batch normalization flags (@ref
-         ///     dnnl::normalization_flags).
-         /// @param hint_fwd_pd Primitive descriptor for a batch
-         ///     normalization forward propagation primitive, used as a
-         ///     hint for deciding which memory format to use.
-         /// @param attr Primitive attributes to use. Optional; empty
-         ///     attributes are used by default.
-         /// @param allow_empty When set, construction is allowed to fail
-         ///     without throwing an exception, and an empty object is
-         ///     produced in that case. Optional; defaults to false.
-         primitive_desc(const engine &aengine, prop_kind aprop_kind,
-                 const memory::desc &diff_src_desc,
-                 const memory::desc &diff_dst_desc,
-                 const memory::desc &src_desc, float epsilon,
-                 normalization_flags flags,
-                 const batch_normalization_forward::primitive_desc
-                         &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd = nullptr;
-             const dnnl_status_t rc
-                     = dnnl_batch_normalization_backward_primitive_desc_create(
-                             &c_pd, aengine.get(),
-                             dnnl::convert_to_c(aprop_kind),
-                             diff_src_desc.get(), diff_dst_desc.get(),
-                             src_desc.get(), epsilon, convert_to_c(flags),
-                             hint_fwd_pd.get(), attr.get());
-             // The status is checked only when the caller did not opt
-             // into receiving an empty descriptor.
-             if (!allow_empty) {
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the batch normalization backward propagation "
-                         "primitive. Run workload with environment variable "
-                         "ONEDNN_VERBOSE=all to get additional diagnostic "
-                         "information.");
-             }
-             reset(c_pd);
-         }
-         /// Wraps a C API primitive descriptor, which must describe a
-         /// batch normalization backward propagation primitive.
-         ///
-         /// @param pd C API primitive descriptor for a batch normalization
-         ///     backward propagation primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd,
-                     dnnl::primitive::kind::batch_normalization,
-                     dnnl::prop_kind::backward,
-                     dnnl::prop_kind::backward_data) {
-         }
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const {
-             return base::src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::weights_desc()const
-         memory::desc weights_desc() const {
-             return base::weights_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-         memory::desc dst_desc() const {
-             return base::dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
-         memory::desc diff_src_desc() const {
-             return base::diff_src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
-         memory::desc diff_dst_desc() const {
-             return base::diff_dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
-         memory::desc diff_weights_desc() const {
-             return base::diff_weights_desc(0);
-         }
-         /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
-         memory::desc mean_desc() const {
-             // Queried as the source descriptor at index 1.
-             return query_md(query::src_md, 1);
-         }
-         /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
-         memory::desc variance_desc() const {
-             // Queried as the source descriptor at index 2.
-             return query_md(query::src_md, 2);
-         }
-         /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
-         memory::desc workspace_desc() const {
-             return base::workspace_desc();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         dnnl::prop_kind get_prop_kind() const {
-             return base::get_prop_kind();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
-         float get_epsilon() const {
-             return base::get_epsilon();
-         }
-         /// Queries the normalization flags.
-         /// @return Normalization flags.
-         normalization_flags get_flags() const {
-             return base::get_flags<normalization_flags>();
-         }
-     };
-     /// Creates an empty primitive.
-     batch_normalization_backward() = default;
-     /// Creates a batch normalization backward propagation primitive.
-     /// @param pd Primitive descriptor for a batch normalization backward
-     ///     propagation primitive.
-     batch_normalization_backward(const primitive_desc &pd)
-         : primitive(pd) {
-     }
-     /// Creates a batch normalization backward propagation primitive from
-     /// a cache blob.
-     /// @param pd Primitive descriptor for a batch normalization backward
-     ///     propagation primitive.
-     /// @param cache_blob Cache blob.
-     batch_normalization_backward(const primitive_desc &pd,
-             const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {
-     }
- };
- /// @} dnnl_api_batch_normalization
- /// @addtogroup dnnl_api_group_normalization Group Normalization
- ///
- /// A primitive to perform group normalization.
- ///
- /// Both forward and backward propagation primitives support in-place
- /// operation; that is, src and dst can refer to the same memory for forward
- /// propagation, and diff_dst and diff_src can refer to the same memory for
- /// backward propagation.
- ///
- /// The group normalization primitives' computations can be controlled by
- /// specifying different @ref dnnl::normalization_flags values. For example,
- /// group normalization forward propagation can be configured to either
- /// compute the mean and variance or take them as arguments. It can either
- /// perform scaling and shifting using gamma and beta parameters or not.
- ///
- /// @sa @ref dev_guide_group_normalization in developer guide
- ///
- /// @{
- /// Group normalization forward propagation primitive.
- struct group_normalization_forward : public primitive {
- /// Primitive descriptor for a group normalization forward propagation
- /// primitive.
- struct primitive_desc : public dnnl::primitive_desc {
-     /// Creates an empty primitive descriptor.
-     primitive_desc() = default;
-     /// Creates a primitive descriptor for a group normalization forward
-     /// propagation primitive.
-     ///
-     /// @note
-     ///     The operation may run in place: dst is allowed to refer to
-     ///     the same memory as src.
-     ///
-     /// @param aengine Engine to use.
-     /// @param aprop_kind Propagation kind; either
-     ///     #dnnl::prop_kind::forward_training or
-     ///     #dnnl::prop_kind::forward_inference.
-     /// @param src_desc Source memory descriptor.
-     /// @param dst_desc Destination memory descriptor.
-     /// @param groups Group normalization groups parameter.
-     /// @param epsilon Group normalization epsilon parameter.
-     /// @param flags Group normalization flags (@ref
-     ///     dnnl::normalization_flags).
-     /// @param attr Primitive attributes to use. Optional; empty
-     ///     attributes are used by default.
-     /// @param allow_empty When set, construction is allowed to fail
-     ///     without throwing an exception, and an empty object is
-     ///     produced in that case. Optional; defaults to false.
-     primitive_desc(const engine &aengine, prop_kind aprop_kind,
-             const memory::desc &src_desc, const memory::desc &dst_desc,
-             memory::dim groups, float epsilon, normalization_flags flags,
-             const primitive_attr &attr = default_attr(),
-             bool allow_empty = false) {
-         dnnl_primitive_desc_t c_pd = nullptr;
-         const dnnl_status_t rc
-                 = dnnl_group_normalization_forward_primitive_desc_create(
-                         &c_pd, aengine.get(),
-                         dnnl::convert_to_c(aprop_kind), src_desc.get(),
-                         dst_desc.get(), groups, epsilon,
-                         convert_to_c(flags), attr.get());
-         // The status is checked only when the caller did not opt into
-         // receiving an empty descriptor.
-         if (!allow_empty) {
-             error::wrap_c_api(rc,
-                     "could not create a primitive descriptor for "
-                     "the group normalization forward propagation "
-                     "primitive. Run workload with environment variable "
-                     "ONEDNN_VERBOSE=all to get additional diagnostic "
-                     "information.");
-         }
-         reset(c_pd);
-     }
-     /// Wraps a C API primitive descriptor, which must describe a group
-     /// normalization forward propagation primitive.
-     ///
-     /// @param pd C API primitive descriptor for a group normalization
-     ///     forward propagation primitive.
-     primitive_desc(dnnl_primitive_desc_t pd)
-         : dnnl::primitive_desc(pd,
-                 dnnl::primitive::kind::group_normalization,
-                 dnnl::prop_kind::forward_training,
-                 dnnl::prop_kind::forward_inference) {
-     }
-     /// @copydoc dnnl::primitive_desc_base::src_desc()const
-     memory::desc src_desc() const {
-         return base::src_desc(0);
-     }
-     /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-     memory::desc dst_desc() const {
-         return base::dst_desc(0);
-     }
-     /// @copydoc dnnl::primitive_desc_base::weights_desc()const
-     memory::desc weights_desc() const {
-         return base::weights_desc(0);
-     }
-     /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
-     memory::desc workspace_desc() const {
-         return base::workspace_desc();
-     }
-     /// Queries the memory descriptor for mean.
-     /// @returns Memory descriptor for mean.
-     memory::desc mean_desc() const {
-         return stat_desc(mean);
-     }
-     /// Queries the memory descriptor for variance.
-     /// @returns Memory descriptor for variance.
-     memory::desc variance_desc() const {
-         return stat_desc(var);
-     }
-     /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-     dnnl::prop_kind get_prop_kind() const {
-         return base::get_prop_kind();
-     }
-     /// @copydoc dnnl::primitive_desc_base::get_group_size()const
-     memory::dim get_group_size() const {
-         return base::get_group_size();
-     }
-     /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
-     float get_epsilon() const {
-         return base::get_epsilon();
-     }
-     /// Queries the normalization flags.
-     /// @return Normalization flags.
-     normalization_flags get_flags() const {
-         return base::get_flags<normalization_flags>();
-     }
- private:
-     // Indices passed to query_md() to select a statistics tensor.
-     enum {
-         mean = 1,
-         var = 2,
-     };
-     // Queries the descriptor of the statistics tensor selected by
-     // `which`: as a source when use_global_stats is set, as a
-     // destination otherwise.
-     memory::desc stat_desc(int which) const {
-         const auto global_bit
-                 = get_flags() & normalization_flags::use_global_stats;
-         if (global_bit != normalization_flags::none)
-             return query_md(query::src_md, which);
-         return query_md(query::dst_md, which);
-     }
- };
- /// Default constructor. Produces an empty object.
- group_normalization_forward() = default;
- /// Constructs a group normalization forward propagation primitive.
- /// @param pd Primitive descriptor for a group normalization forward
- /// propagation primitive.
- group_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a group normalization forward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for a group normalization forward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- group_normalization_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Group normalization backward propagation primitive.
- struct group_normalization_backward : public primitive {
- /// Primitive descriptor for a group normalization backward propagation
- /// primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a group normalization backward
- /// propagation primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
- /// (diffs for all parameters are computed in this case).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param src_desc Source memory descriptor.
- /// @param groups Group normalization groups parameter.
- /// @param epsilon Group normalization epsilon parameter.
- /// @param flags Group normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param hint_fwd_pd Primitive descriptor for a group normalization
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &src_desc,
- memory::dim groups, float epsilon, normalization_flags flags,
- const group_normalization_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t pd = nullptr;
- dnnl_status_t status
- = dnnl_group_normalization_backward_primitive_desc_create(
- &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
- diff_src_desc.get(), diff_dst_desc.get(),
- src_desc.get(), groups, epsilon,
- convert_to_c(flags), hint_fwd_pd.get(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the group normalization backward propagation "
- "primitive. Run workload with environment variable "
- "ONEDNN_VERBOSE=all to get additional diagnostic "
- "information.");
- reset(pd);
- }
- /// Constructs a primitive descriptor for a group normalization
- /// backward propagation primitive from a C API primitive descriptor
- /// that must have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a group normalization
- /// backward propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd,
- dnnl::primitive::kind::group_normalization,
- dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
- }
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
- memory::desc diff_weights_desc() const {
- return base::diff_weights_desc(0);
- }
- /// @copydoc dnnl::group_normalization_forward::primitive_desc::mean_desc()const
- memory::desc mean_desc() const { return query_md(query::src_md, 1); }
- /// @copydoc dnnl::group_normalization_forward::primitive_desc::variance_desc()const
- memory::desc variance_desc() const {
- return query_md(query::src_md, 2);
- }
- /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const { return base::workspace_desc(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_group_size()const
- memory::dim get_group_size() const { return base::get_group_size(); }
- /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
- float get_epsilon() const { return base::get_epsilon(); }
- /// Returns normalization flags.
- /// @return Normalization flags.
- normalization_flags get_flags() const {
- return base::get_flags<normalization_flags>();
- }
- };
- /// Default constructor. Produces an empty object.
- group_normalization_backward() = default;
- /// Constructs a group normalization backward propagation primitive.
- /// @param pd Primitive descriptor for a group normalization backward
- /// propagation primitive.
- group_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a group normalization backward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for a group normalization backward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- group_normalization_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_group_normalization
- /// @addtogroup dnnl_api_layer_normalization Layer Normalization
- ///
- /// A primitive to perform layer normalization. Normalization is performed
- /// within the last logical dimension of data tensor.
- ///
- /// Both forward and backward propagation primitives support in-place
- /// operation; that is, src and dst can refer to the same memory for forward
- /// propagation, and diff_dst and diff_src can refer to the same memory for
- /// backward propagation.
- ///
- /// The layer normalization primitives computations can be controlled by
- /// specifying different @ref dnnl::normalization_flags values. For example,
- /// layer normalization forward propagation can be configured to either
- /// compute the mean and variance or take them as arguments. It can either
- /// perform scaling and shifting using gamma and beta parameters or not.
- ///
- /// @sa @ref dev_guide_layer_normalization in developer guide
- ///
- /// @{
- /// Layer normalization forward propagation primitive.
- struct layer_normalization_forward : public primitive {
- /// Primitive descriptor for a layer normalization forward propagation
- /// primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a layer normalization forward
- /// propagation primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param stat_desc Statistics memory descriptors.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- const memory::desc &stat_desc, float epsilon,
- normalization_flags flags,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
- &stat_desc, memory::data_type::f32, epsilon, flags, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization forward
- /// propagation primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- float epsilon, normalization_flags flags,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
- memory::data_type::f32, epsilon, flags, attr, allow_empty) {
- }
- /// Constructs a primitive descriptor for a layer normalization forward
- /// propagation primitive with a user-provided data type for the scale
- /// and shift memory objects.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param stat_desc Statistics memory descriptors.
- /// @param scale_shift_data_type Data type of scale and shift memory.
- /// If neither scale nor shift flag are specified the parameter
- /// is ignored.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- const memory::desc &stat_desc,
- memory::data_type scale_shift_data_type, float epsilon,
- normalization_flags flags,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
- &stat_desc, scale_shift_data_type, epsilon, flags, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization forward
- /// propagation primitive with a user-provided data type for the scale
- /// and shift memory objects.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param scale_shift_data_type Data type of scale and shift memory.
- /// If neither scale nor shift flag are specified the parameter
- /// is ignored.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- memory::data_type scale_shift_data_type, float epsilon,
- normalization_flags flags,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
- scale_shift_data_type, epsilon, flags, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization
- /// forward propagation primitive from a C API primitive descriptor
- /// that must have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a layer normalization
- /// forward propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd,
- dnnl::primitive::kind::layer_normalization,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const { return base::workspace_desc(); }
- /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
- memory::desc mean_desc() const { return stat_desc(mean); }
- /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
- memory::desc variance_desc() const { return stat_desc(var); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
- float get_epsilon() const { return base::get_epsilon(); }
- /// Returns normalization flags.
- /// @return Normalization flags.
- normalization_flags get_flags() const {
- return base::get_flags<normalization_flags>();
- }
- private:
- enum {
- mean = 1,
- var = 2,
- };
- memory::desc stat_desc(int kind) const {
- const bool use_global_stats
- = (get_flags() & normalization_flags::use_global_stats)
- != normalization_flags::none;
- return query_md(
- use_global_stats ? query::src_md : query::dst_md, kind);
- }
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- const memory::desc *stat_desc,
- memory::data_type scale_shift_data_type, float epsilon,
- normalization_flags flags, const primitive_attr &attr,
- bool allow_empty) {
- dnnl_primitive_desc_t pd = nullptr;
- dnnl_status_t status
- = dnnl_layer_normalization_forward_primitive_desc_create_v2(
- &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
- src_desc.get(), dst_desc.get(),
- optional_arg(stat_desc),
- memory::convert_to_c(scale_shift_data_type),
- epsilon, convert_to_c(flags), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the layer normalization forward propagation "
- "primitive. Run workload with environment variable "
- "ONEDNN_VERBOSE=all to get additional diagnostic "
- "information.");
- reset(pd);
- }
- };
- /// Default constructor. Produces an empty object.
- layer_normalization_forward() = default;
- /// Constructs a layer normalization forward propagation primitive.
- /// @param pd Primitive descriptor for a layer normalization forward
- /// propagation primitive.
- layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a layer normalization forward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for a layer normalization forward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- layer_normalization_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Layer normalization backward propagation primitive.
- struct layer_normalization_backward : public primitive {
- /// Primitive descriptor for a layer normalization backward propagation
- /// primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a layer normalization backward
- /// propagation primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
- /// (diffs for all parameters are computed in this case).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param src_desc Source memory descriptor.
- /// @param stat_desc Statistics memory descriptors.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param hint_fwd_pd Primitive descriptor for a layer normalization
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &src_desc,
- const memory::desc &stat_desc, float epsilon,
- normalization_flags flags,
- const layer_normalization_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
- src_desc, &stat_desc, memory::data_type::f32,
- memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization backward
- /// propagation primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
- /// (diffs for all parameters are computed in this case).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param src_desc Source memory descriptor.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param hint_fwd_pd Primitive descriptor for a layer normalization
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &src_desc,
- float epsilon, normalization_flags flags,
- const layer_normalization_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
- src_desc, nullptr, memory::data_type::f32,
- memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization backward
- /// propagation primitive with a user-provided data type for the scale
- /// and shift memory objects.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
- /// (diffs for all parameters are computed in this case).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param src_desc Source memory descriptor.
- /// @param stat_desc Statistics memory descriptors.
- /// @param diff_scale_shift_data_type Data type of diff scale and shift
- /// memory. If neither scale nor shift flag are specified the
- /// parameter is ignored.
- /// @param scale_shift_data_type Data type of scale and shift memory.
- /// If neither scale nor shift flag are specified the parameter
- /// is ignored.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param hint_fwd_pd Primitive descriptor for a layer normalization
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &src_desc,
- const memory::desc &stat_desc,
- memory::data_type diff_scale_shift_data_type,
- memory::data_type scale_shift_data_type, float epsilon,
- normalization_flags flags,
- const layer_normalization_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
- src_desc, &stat_desc, diff_scale_shift_data_type,
- scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization backward
- /// propagation primitive with a user-provided data type for the scale
- /// and shift memory objects.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
- /// (diffs for all parameters are computed in this case).
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param src_desc Source memory descriptor.
- /// @param diff_scale_shift_data_type Data type of diff scale and shift
- /// memory. If neither scale nor shift flag are specified the
- /// parameter is ignored.
- /// @param scale_shift_data_type Data type of scale and shift memory.
- /// If neither scale nor shift flag are specified the parameter
- /// is ignored.
- /// @param epsilon Layer normalization epsilon parameter.
- /// @param flags Layer normalization flags (@ref
- /// dnnl::normalization_flags).
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param hint_fwd_pd Primitive descriptor for a layer normalization
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &src_desc,
- memory::data_type diff_scale_shift_data_type,
- memory::data_type scale_shift_data_type, float epsilon,
- normalization_flags flags,
- const layer_normalization_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
- src_desc, nullptr, diff_scale_shift_data_type,
- scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
- allow_empty) {}
- /// Constructs a primitive descriptor for a layer normalization
- /// backward propagation primitive from a C API primitive descriptor
- /// that must have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a layer normalization
- /// backward propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd,
- dnnl::primitive::kind::layer_normalization,
- dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
- }
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
- memory::desc diff_weights_desc() const {
- return base::diff_weights_desc(0);
- }
- /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
- memory::desc mean_desc() const { return query_md(query::src_md, 1); }
- /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
- memory::desc variance_desc() const {
- return query_md(query::src_md, 2);
- }
- /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const { return base::workspace_desc(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
- float get_epsilon() const { return base::get_epsilon(); }
- /// Returns normalization flags.
- /// @return Normalization flags.
- normalization_flags get_flags() const {
- return base::get_flags<normalization_flags>();
- }
- private:
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::desc &src_desc,
- const memory::desc *stat_desc,
- memory::data_type diff_scale_shift_data_type,
- memory::data_type scale_shift_data_type, float epsilon,
- normalization_flags flags,
- const layer_normalization_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- dnnl_primitive_desc_t pd = nullptr;
- dnnl_status_t status
- = dnnl_layer_normalization_backward_primitive_desc_create_v2(
- &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
- diff_src_desc.get(), diff_dst_desc.get(),
- src_desc.get(), optional_arg(stat_desc),
- memory::convert_to_c(diff_scale_shift_data_type),
- memory::convert_to_c(scale_shift_data_type),
- epsilon, convert_to_c(flags), hint_fwd_pd.get(),
- attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the layer normalization backward propagation "
- "primitive. Run workload with environment variable "
- "ONEDNN_VERBOSE=all to get additional diagnostic "
- "information.");
- reset(pd);
- }
- };
- /// Default constructor. Produces an empty object.
- layer_normalization_backward() = default;
- /// Constructs a layer normalization backward propagation primitive.
- /// @param pd Primitive descriptor for a layer normalization backward
- /// propagation primitive.
- layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a layer normalization backward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for a layer normalization backward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- layer_normalization_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_layer_normalization
- /// @addtogroup dnnl_api_inner_product Inner Product
- ///
- /// A primitive to compute an inner product.
- ///
- /// @sa @ref dev_guide_inner_product in developer guide
- ///
- /// @{
- /// Inner product forward propagation primitive.
- struct inner_product_forward : public primitive {
- /// Descriptor of an inner product forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for an inner product forward
- /// propagation primitive that takes a bias.
- ///
- /// @note
- /// Every memory descriptor may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Memory descriptor for src.
- /// @param weights_desc Memory descriptor for weights.
- /// @param bias_desc Memory descriptor for bias.
- /// @param dst_desc Memory descriptor for dst.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &weights_desc,
- const memory::desc &bias_desc, const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
- &bias_desc, dst_desc, attr, allow_empty) {}
- /// Creates a primitive descriptor for an inner product forward
- /// propagation primitive that has no bias.
- ///
- /// @note
- /// Every memory descriptor may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Memory descriptor for src.
- /// @param weights_desc Memory descriptor for weights.
- /// @param dst_desc Memory descriptor for dst.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &weights_desc,
- const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
- nullptr, dst_desc, attr, allow_empty) {}
- /// Wraps a C API primitive descriptor, which must be of a matching
- /// kind, into an inner product forward propagation primitive
- /// descriptor.
- ///
- /// @param pd C API primitive descriptor for an inner product forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
- memory::desc bias_desc() const { return base::weights_desc(1); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- private:
- // Shared implementation behind both public constructors. bias_desc is
- // nullptr when called from the bias-free constructor.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &weights_desc,
- const memory::desc *bias_desc, const memory::desc &dst_desc,
- const primitive_attr &attr, bool allow_empty) {
- const char *const failure_msg
- = "could not create a primitive descriptor for "
- "the inner product forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc
- = dnnl_inner_product_forward_primitive_desc_create(&c_pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- src_desc.get(), weights_desc.get(),
- optional_arg(bias_desc), dst_desc.get(),
- attr.get());
- // The creation status is only checked when an empty result is not
- // acceptable to the caller.
- if (!allow_empty) error::wrap_c_api(rc, failure_msg);
- reset(c_pd);
- }
- };
- /// Creates an empty inner product forward propagation primitive.
- inner_product_forward() = default;
- /// Creates an inner product forward propagation primitive.
- /// @param pd Primitive descriptor for an inner product forward
- /// propagation primitive.
- inner_product_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Creates an inner product forward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for an inner product forward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- inner_product_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Inner product backward propagation primitive.
- struct inner_product_backward_data : public primitive {
- /// Descriptor of an inner product backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for an inner product backward
- /// propagation primitive.
- ///
- /// @note
- /// Every memory descriptor may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param diff_src_desc Memory descriptor for diff src.
- /// @param weights_desc Memory descriptor for weights.
- /// @param diff_dst_desc Memory descriptor for diff dst.
- /// @param hint_fwd_pd Primitive descriptor for an inner product
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
- const memory::desc &weights_desc,
- const memory::desc &diff_dst_desc,
- const inner_product_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- const char *const failure_msg
- = "could not create a primitive descriptor for "
- "the inner product backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc
- = dnnl_inner_product_backward_data_primitive_desc_create(
- &c_pd, aengine.get(), diff_src_desc.get(),
- weights_desc.get(), diff_dst_desc.get(),
- hint_fwd_pd.get(), attr.get());
- // The creation status is only checked when an empty result is not
- // acceptable to the caller.
- if (!allow_empty) error::wrap_c_api(rc, failure_msg);
- reset(c_pd);
- }
- /// Wraps a C API primitive descriptor, which must be of a matching
- /// kind, into an inner product backward propagation primitive
- /// descriptor.
- ///
- /// @param pd C API primitive descriptor for an inner product backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const { return base::weights_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- };
- /// Creates an empty inner product backward propagation primitive.
- inner_product_backward_data() = default;
- /// Creates an inner product backward propagation primitive.
- /// @param pd Primitive descriptor for an inner product backward
- /// propagation primitive.
- inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}
- /// Creates an inner product backward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for an inner product backward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- inner_product_backward_data(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Inner product weights gradient primitive.
- struct inner_product_backward_weights : public primitive {
- /// Descriptor of an inner product weights gradient primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for an inner product weights
- /// update primitive that takes a bias.
- ///
- /// @note
- /// Every memory descriptor may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param src_desc Memory descriptor for src.
- /// @param diff_weights_desc Memory descriptor for diff weights.
- /// @param diff_bias_desc Memory descriptor for diff bias.
- /// @param diff_dst_desc Memory descriptor for diff dst.
- /// @param hint_fwd_pd Primitive descriptor for an inner product
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_desc,
- const inner_product_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, src_desc, diff_weights_desc,
- &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr,
- allow_empty) {}
- /// Creates a primitive descriptor for an inner product weights
- /// update primitive that has no bias.
- ///
- /// @note
- /// Every memory descriptor may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param src_desc Memory descriptor for src.
- /// @param diff_weights_desc Memory descriptor for diff weights.
- /// @param diff_dst_desc Memory descriptor for diff dst.
- /// @param hint_fwd_pd Primitive descriptor for an inner product
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Optional; defaults to
- /// empty attributes.
- /// @param allow_empty When true, a failed construction does not
- /// throw and an empty object is produced instead. Optional;
- /// defaults to false.
- primitive_desc(const engine &aengine, const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc &diff_dst_desc,
- const inner_product_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr,
- diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
- /// Wraps a C API primitive descriptor, which must be of a matching
- /// kind, into an inner product weights update primitive descriptor.
- ///
- /// @param pd C API primitive descriptor for an inner product weights
- /// gradient primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
- dnnl::prop_kind::backward_weights) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
- memory::desc diff_weights_desc() const {
- return base::diff_weights_desc(0);
- }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return base::diff_weights_desc(1);
- }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- private:
- // Shared implementation behind both public constructors. diff_bias_desc
- // is nullptr when called from the bias-free constructor.
- primitive_desc(const engine &aengine, const memory::desc &src_desc,
- const memory::desc &diff_weights_desc,
- const memory::desc *diff_bias_desc,
- const memory::desc &diff_dst_desc,
- const inner_product_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- const char *const failure_msg
- = "could not create a primitive descriptor for "
- "the inner product weights gradient primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc
- = dnnl_inner_product_backward_weights_primitive_desc_create(
- &c_pd, aengine.get(), src_desc.get(),
- diff_weights_desc.get(),
- optional_arg(diff_bias_desc), diff_dst_desc.get(),
- hint_fwd_pd.get(), attr.get());
- // The creation status is only checked when an empty result is not
- // acceptable to the caller.
- if (!allow_empty) error::wrap_c_api(rc, failure_msg);
- reset(c_pd);
- }
- };
- /// Creates an empty inner product weights gradient primitive.
- inner_product_backward_weights() = default;
- /// Creates an inner product weights gradient primitive.
- /// @param pd Primitive descriptor for an inner product weights gradient
- /// primitive.
- inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}
- /// Creates an inner product weights gradient primitive from a cache
- /// blob.
- /// @param pd Primitive descriptor for an inner product weights gradient
- /// primitive.
- /// @param cache_blob Cache blob.
- inner_product_backward_weights(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_inner_product
- /// @addtogroup dnnl_api_rnn RNN
- ///
- /// A primitive to compute recurrent neural network layers.
- ///
- /// @sa @ref dev_guide_rnn in developer guide
- ///
- /// @{
- /// Base class for primitive descriptors for RNN primitives.
- struct rnn_primitive_desc_base : public primitive_desc {
- using primitive_desc::primitive_desc;
- /// Default constructor. Produces an empty object.
- rnn_primitive_desc_base() = default;
- /// Constructs an RNN primitive descriptor base from a C API primitive
- /// descriptor while checking that it actually describes the expected
- /// primitive by comparing propagation and primitive kinds.
- ///
- /// @param pd C API primitive descriptor.
- /// @param aprop_kind Expected propagation kind.
- /// @param cell_kind Expected cell kind.
- rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
- dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
- : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {} // delegates to the two-kind constructor, passing aprop_kind as both accepted propagation kinds
- /// Returns source layer memory descriptor.
- /// @returns Source layer memory descriptor.
- memory::desc src_layer_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_SRC_LAYER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns AUGRU attention memory descriptor.
- /// @returns AUGRU attention memory descriptor.
- memory::desc augru_attention_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_AUGRU_ATTENTION;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns source iteration memory descriptor.
- /// @returns Source iteration memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// source iteration parameter.
- memory::desc src_iter_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_SRC_ITER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns source recurrent cell state memory descriptor.
- /// @returns Source recurrent cell state memory descriptor.
- memory::desc src_iter_c_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_SRC_ITER_C;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns weights layer memory descriptor.
- /// @returns Weights layer memory descriptor.
- memory::desc weights_layer_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_WEIGHTS_LAYER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns weights iteration memory descriptor.
- /// @returns Weights iteration memory descriptor.
- memory::desc weights_iter_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_WEIGHTS_ITER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns weights peephole memory descriptor.
- /// @returns Weights peephole memory descriptor.
- memory::desc weights_peephole_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_WEIGHTS_PEEPHOLE;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns weights projection memory descriptor.
- /// @returns Weights projection memory descriptor.
- memory::desc weights_projection_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_WEIGHTS_PROJECTION;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns bias memory descriptor.
- /// @returns Bias memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// bias parameter.
- memory::desc bias_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_BIAS;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns destination layer memory descriptor.
- /// @returns Destination layer memory descriptor.
- memory::desc dst_layer_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DST_LAYER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns destination iteration memory descriptor.
- /// @returns Destination iteration memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// destination iteration parameter.
- memory::desc dst_iter_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DST_ITER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns destination recurrent cell state memory descriptor.
- /// @returns Destination recurrent cell state memory descriptor.
- memory::desc dst_iter_c_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DST_ITER_C;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff source layer memory descriptor.
- /// @returns Diff source layer memory descriptor.
- memory::desc diff_src_layer_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_SRC_LAYER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff AUGRU attention memory descriptor.
- /// @returns Diff AUGRU attention memory descriptor.
- memory::desc diff_augru_attention_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_AUGRU_ATTENTION;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff source iteration memory descriptor.
- /// @returns Diff source iteration memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// diff source iteration parameter.
- memory::desc diff_src_iter_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_SRC_ITER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff source recurrent cell state memory descriptor.
- /// @returns Diff source recurrent cell state memory descriptor.
- memory::desc diff_src_iter_c_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_SRC_ITER_C;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff weights layer memory descriptor.
- /// @returns Diff weights layer memory descriptor.
- memory::desc diff_weights_layer_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_WEIGHTS_LAYER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff weights iteration memory descriptor.
- /// @returns Diff weights iteration memory descriptor.
- memory::desc diff_weights_iter_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_WEIGHTS_ITER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff weights peephole memory descriptor.
- /// @returns Diff weights peephole memory descriptor.
- memory::desc diff_weights_peephole_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff weights projection memory descriptor.
- /// @returns Diff weights projection memory descriptor.
- memory::desc diff_weights_projection_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_WEIGHTS_PROJECTION;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff bias memory descriptor.
- /// @returns Diff bias memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// diff bias parameter.
- memory::desc diff_bias_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_BIAS;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff destination layer memory descriptor.
- /// @returns Diff destination layer memory descriptor.
- memory::desc diff_dst_layer_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_DST_LAYER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff destination iteration memory descriptor.
- /// @returns Diff destination iteration memory descriptor.
- /// @returns A zero memory descriptor if the primitive does not have a
- /// diff destination iteration parameter.
- memory::desc diff_dst_iter_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_DST_ITER;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- /// Returns diff destination recurrent cell state memory descriptor.
- /// @returns Diff destination recurrent cell state memory descriptor.
- memory::desc diff_dst_iter_c_desc() const {
- // Execution argument whose memory descriptor is requested.
- const int arg_id = DNNL_ARG_DIFF_DST_ITER_C;
- return base::query_md(query::exec_arg_md, arg_id);
- }
- protected:
- using rnn_base = rnn_primitive_desc_base;
- // (Deliberately not using doxygen comments)
- //
- // Constructs an RNN primitive descriptor base from a C API primitive
- // descriptor while checking that it actually describes the expected
- // primitive by comparing propagation and primitive kinds. The caller can
- // pass two acceptable propagation kinds. This is typically used to check
- // that the propagation kind is inference or training forward propagation.
- //
- // @param pd C API primitive descriptor.
- // @param prop_kind1 Expected propagation kind.
- // @param prop_kind2 Expected propagation kind.
- // @param cell_kind Expected cell kind.
- rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
- dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
- dnnl::algorithm cell_kind) {
- // Query the three properties to validate, in the order primitive
- // kind, propagation kind, cell kind. Each query status is passed to
- // error::wrap_c_api together with its own message.
- dnnl_primitive_kind_t actual_primitive_kind;
- error::wrap_c_api(
- dnnl_primitive_desc_query(pd, dnnl_query_primitive_kind, 0,
- &actual_primitive_kind),
- "could not retrieve a primitive kind from a primitive "
- "descriptor for an RNN primitive");
- dnnl_prop_kind_t actual_prop_kind;
- error::wrap_c_api(
- dnnl_primitive_desc_query(
- pd, dnnl_query_prop_kind, 0, &actual_prop_kind),
- "could not retrieve a propagation kind from a primitive "
- "descriptor for an RNN primitive");
- dnnl_alg_kind_t actual_cell_kind;
- error::wrap_c_api(
- dnnl_primitive_desc_query(
- pd, dnnl_query_cell_kind, 0, &actual_cell_kind),
- "could not retrieve a cell kind from a primitive descriptor "
- "for an RNN primitive");
- const dnnl_prop_kind_t accepted_prop_kind1 = convert_to_c(prop_kind1);
- const dnnl_prop_kind_t accepted_prop_kind2 = convert_to_c(prop_kind2);
- const dnnl_alg_kind_t expected_cell_kind = convert_to_c(cell_kind);
- // Reject the descriptor unless it is an RNN primitive with one of the
- // two accepted propagation kinds and the expected cell kind.
- if (actual_primitive_kind != dnnl_rnn
- || (actual_prop_kind != accepted_prop_kind1
- && actual_prop_kind != accepted_prop_kind2)
- || actual_cell_kind != expected_cell_kind)
- DNNL_THROW_ERROR(dnnl_invalid_arguments,
- "mismatch between expected and provided descriptors for an "
- "RNN primitive");
- reset_with_clone(pd);
- }
- // Constructs an RNN forward propagation primitive descriptor base for
- // any cell kind.
- rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
- prop_kind aprop_kind, algorithm activation, rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc *src_iter_c_desc,
- const memory::desc *attention_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc *weights_peephole_desc,
- const memory::desc *weights_projection_desc,
- const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha,
- float beta, const primitive_attr &attr, bool allow_empty) {
- dnnl_status_t status = dnnl_success;
- const char *msg
- = "could not create a primitive descriptor for a requested "
- "cell kind"; // generic message; kept only when no case below matches
- dnnl_primitive_desc_t pd = nullptr;
- switch (cell_kind) { // each case calls the C API create function for that cell kind and sets a cell-specific message
- case algorithm::vanilla_rnn: // the only case that passes activation, alpha and beta
- status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(activation),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), weights_layer_desc.get(),
- weights_iter_desc.get(), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- convert_to_c(flags), alpha, beta, attr.get());
- msg = "could not create a primitive descriptor for "
- "the vanilla RNN forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- break;
- case algorithm::vanilla_lstm: // the only case that passes the src_iter_c, peephole, projection and dst_iter_c pointers
- status = dnnl_lstm_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), optional_arg(src_iter_c_desc),
- weights_layer_desc.get(), weights_iter_desc.get(),
- optional_arg(weights_peephole_desc),
- optional_arg(weights_projection_desc), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- optional_arg(dst_iter_c_desc), convert_to_c(flags),
- attr.get());
- msg = "could not create a primitive descriptor for "
- "the LSTM forward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::vanilla_gru:
- status = dnnl_gru_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), weights_layer_desc.get(),
- weights_iter_desc.get(), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- convert_to_c(flags), attr.get());
- msg = "could not create a primitive descriptor for "
- "the GRU forward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::lbr_gru:
- status = dnnl_lbr_gru_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), weights_layer_desc.get(),
- weights_iter_desc.get(), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- convert_to_c(flags), attr.get());
- msg = "could not create a primitive descriptor for "
- "the LBR GRU forward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::vanilla_augru: // AUGRU cases additionally pass the attention pointer
- status = dnnl_augru_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), optional_arg(attention_desc),
- weights_layer_desc.get(), weights_iter_desc.get(),
- bias_desc.get(), dst_layer_desc.get(),
- dst_iter_desc.get(), convert_to_c(flags), attr.get());
- msg = "could not create a primitive descriptor for "
- "the AUGRU forward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::lbr_augru: // same argument list as vanilla_augru, different create function
- status = dnnl_lbr_augru_forward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), optional_arg(attention_desc),
- weights_layer_desc.get(), weights_iter_desc.get(),
- bias_desc.get(), dst_layer_desc.get(),
- dst_iter_desc.get(), convert_to_c(flags), attr.get());
- msg = "could not create a primitive descriptor for "
- "the LBR AUGRU forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- break;
- default: status = dnnl_unimplemented; // any other algorithm value: no create call is made, pd stays nullptr
- }
- if (!allow_empty) error::wrap_c_api(status, msg); // status is not checked here when allow_empty is true
- reset(pd);
- }
- // Constructs an RNN backward propagation primitive descriptor base for
- // any cell kind.
- rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
- prop_kind aprop_kind, algorithm activation, rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc *src_iter_c_desc,
- const memory::desc *attention_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc *weights_peephole_desc,
- const memory::desc *weights_projection_desc,
- const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc *dst_iter_c_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc *diff_src_iter_c_desc,
- const memory::desc *diff_attention_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc *diff_weights_peephole_desc,
- const memory::desc *diff_weights_projection_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const memory::desc *diff_dst_iter_c_desc, rnn_flags flags,
- float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- dnnl_status_t status = dnnl_success;
- const char *msg = "";
- dnnl_primitive_desc_t pd = nullptr;
- switch (cell_kind) {
- case algorithm::vanilla_rnn:
- status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(activation),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), weights_layer_desc.get(),
- weights_iter_desc.get(), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- diff_src_layer_desc.get(), diff_src_iter_desc.get(),
- diff_weights_layer_desc.get(),
- diff_weights_iter_desc.get(), diff_bias_desc.get(),
- diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
- convert_to_c(flags), alpha, beta, hint_fwd_pd.get(),
- attr.get());
- msg = "could not create a primitive descriptor for "
- "the vanilla RNN backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- break;
- case algorithm::vanilla_lstm:
- status = dnnl_lstm_backward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), optional_arg(src_iter_c_desc),
- weights_layer_desc.get(), weights_iter_desc.get(),
- optional_arg(weights_peephole_desc),
- optional_arg(weights_projection_desc), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- optional_arg(dst_iter_c_desc),
- diff_src_layer_desc.get(), diff_src_iter_desc.get(),
- optional_arg(diff_src_iter_c_desc),
- diff_weights_layer_desc.get(),
- diff_weights_iter_desc.get(),
- optional_arg(diff_weights_peephole_desc),
- optional_arg(diff_weights_projection_desc),
- diff_bias_desc.get(), diff_dst_layer_desc.get(),
- diff_dst_iter_desc.get(),
- optional_arg(diff_dst_iter_c_desc), convert_to_c(flags),
- hint_fwd_pd.get(), attr.get());
- msg = "could not create a primitive descriptor for "
- "the LSTM backward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::vanilla_gru:
- status = dnnl_gru_backward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), weights_layer_desc.get(),
- weights_iter_desc.get(), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- diff_src_layer_desc.get(), diff_src_iter_desc.get(),
- diff_weights_layer_desc.get(),
- diff_weights_iter_desc.get(), diff_bias_desc.get(),
- diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
- convert_to_c(flags), hint_fwd_pd.get(), attr.get());
- msg = "could not create a primitive descriptor for "
- "the GRU backward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::lbr_gru:
- status = dnnl_lbr_gru_backward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), weights_layer_desc.get(),
- weights_iter_desc.get(), bias_desc.get(),
- dst_layer_desc.get(), dst_iter_desc.get(),
- diff_src_layer_desc.get(), diff_src_iter_desc.get(),
- diff_weights_layer_desc.get(),
- diff_weights_iter_desc.get(), diff_bias_desc.get(),
- diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
- convert_to_c(flags), hint_fwd_pd.get(), attr.get());
- msg = "could not create a primitive descriptor for "
- "the LBR GRU backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- break;
- case algorithm::vanilla_augru:
- status = dnnl_augru_backward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), optional_arg(attention_desc),
- weights_layer_desc.get(), weights_iter_desc.get(),
- bias_desc.get(), dst_layer_desc.get(),
- dst_iter_desc.get(), diff_src_layer_desc.get(),
- diff_src_iter_desc.get(),
- optional_arg(diff_attention_desc),
- diff_weights_layer_desc.get(),
- diff_weights_iter_desc.get(), diff_bias_desc.get(),
- diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
- convert_to_c(flags), hint_fwd_pd.get(), attr.get());
- msg = "could not create a primitive descriptor for "
- "the AUGRU backward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.";
- break;
- case algorithm::lbr_augru:
- status = dnnl_lbr_augru_backward_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- dnnl::convert_to_c(direction), src_layer_desc.get(),
- src_iter_desc.get(), optional_arg(attention_desc),
- weights_layer_desc.get(), weights_iter_desc.get(),
- bias_desc.get(), dst_layer_desc.get(),
- dst_iter_desc.get(), diff_src_layer_desc.get(),
- diff_src_iter_desc.get(),
- optional_arg(diff_attention_desc),
- diff_weights_layer_desc.get(),
- diff_weights_iter_desc.get(), diff_bias_desc.get(),
- diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
- convert_to_c(flags), hint_fwd_pd.get(), attr.get());
- msg = "could not create a primitive descriptor for "
- "the LBR AUGRU backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.";
- break;
- default: status = dnnl_unimplemented;
- }
- if (!allow_empty) error::wrap_c_api(status, msg);
- reset(pd);
- }
- };
- /// Vanilla RNN forward propagation primitive.
- struct vanilla_rnn_forward : public primitive {
- /// Primitive descriptor for a vanilla RNN forward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a vanilla RNN forward
- /// propagation primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc.
- ///
- /// This would then indicate that the RNN forward propagation primitive
- /// should not use them and should default to zero values instead.
- ///
- /// @note
- /// All memory descriptors except @p src_iter_desc can be
- /// initialized with an #dnnl::memory::format_tag::any value of @p
- /// format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param activation Activation kind. Possible values are
- /// #dnnl::algorithm::eltwise_relu,
- /// #dnnl::algorithm::eltwise_tanh, or
- /// #dnnl::algorithm::eltwise_logistic.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm activation, rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
- aprop_kind, activation, direction, src_layer_desc,
- src_iter_desc, nullptr, nullptr, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
- 0.0f, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a vanilla RNN forward
- /// propagation primitive with alpha parameter.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc.
- ///
- /// This would then indicate that the RNN forward propagation primitive
- /// should not use them and should default to zero values instead.
- ///
- /// @note
- /// All memory descriptors except @p src_iter_desc can be
- /// initialized with an #dnnl::memory::format_tag::any value of @p
- /// format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param activation Activation kind. Possible values are
- /// #dnnl::algorithm::eltwise_relu,
- /// #dnnl::algorithm::eltwise_tanh, or
- /// #dnnl::algorithm::eltwise_logistic.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param alpha Negative slope if activation is
- /// #dnnl::algorithm::eltwise_relu.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm activation, rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc, float alpha,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
- aprop_kind, activation, direction, src_layer_desc,
- src_iter_desc, nullptr, nullptr, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
- alpha, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a vanilla RNN forward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a vanilla RNN forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference,
- dnnl::algorithm::vanilla_rnn) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
- algorithm get_activation_kind() const {
- return base::get_activation_kind();
- }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- /// @copydoc dnnl::primitive_desc_base::get_alpha()const
- float get_alpha() const { return base::get_alpha(); }
- /// @copydoc dnnl::primitive_desc_base::get_beta()const
- float get_beta() const { return base::get_beta(); }
- };
- /// Default constructor. Produces an empty object.
- vanilla_rnn_forward() = default;
- /// Constructs a vanilla RNN forward propagation primitive.
- /// @param pd Primitive descriptor for a vanilla RNN forward
- /// propagation primitive.
- vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a vanilla RNN forward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for a vanilla RNN forward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- vanilla_rnn_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Vanilla RNN backward propagation primitive.
- struct vanilla_rnn_backward : public primitive {
- /// Primitive descriptor for an RNN backward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a vanilla RNN backward
- /// propagation primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p diff_src_iter_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
- ///
- /// This would then indicate that the RNN backward propagation
- /// primitive should not use the respective data and should use zero
- /// values instead.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param activation Activation kind. Possible values are
- /// #dnnl::algorithm::eltwise_relu,
- /// #dnnl::algorithm::eltwise_tanh, or
- /// #dnnl::algorithm::eltwise_logistic.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm activation, rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
- aprop_kind, activation, direction, src_layer_desc,
- src_iter_desc, nullptr, nullptr, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
- diff_src_iter_desc, nullptr, nullptr,
- diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
- nullptr, diff_bias_desc, diff_dst_layer_desc,
- diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a vanilla RNN backward
- /// propagation primitive with an alpha parameter.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p diff_src_iter_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
- ///
- /// This would then indicate that the RNN backward propagation
- /// primitive should not use the respective data and should use zero
- /// values instead.
- ///
- /// @note
- /// All the memory descriptors may be initialized with the
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param activation Activation kind. Possible values are
- /// #dnnl::algorithm::eltwise_relu,
- /// #dnnl::algorithm::eltwise_tanh, or
- /// #dnnl::algorithm::eltwise_logistic.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param alpha Negative slope if activation is
- /// #dnnl::algorithm::eltwise_relu.
- /// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm activation, rnn_direction direction,
- const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc, float alpha,
- const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
- aprop_kind, activation, direction, src_layer_desc,
- src_iter_desc, nullptr, nullptr, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
- diff_src_iter_desc, nullptr, nullptr,
- diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
- nullptr, diff_bias_desc, diff_dst_layer_desc,
- diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a vanilla RNN backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a vanilla RNN backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
- dnnl::algorithm::vanilla_rnn) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
- memory::desc diff_src_layer_desc() const {
- return rnn_base::diff_src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
- memory::desc diff_src_iter_desc() const {
- return rnn_base::diff_src_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
- memory::desc diff_weights_layer_desc() const {
- return rnn_base::diff_weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
- memory::desc diff_weights_iter_desc() const {
- return rnn_base::diff_weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return rnn_base::diff_bias_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
- memory::desc diff_dst_layer_desc() const {
- return rnn_base::diff_dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
- memory::desc diff_dst_iter_desc() const {
- return rnn_base::diff_dst_iter_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
- algorithm get_activation_kind() const {
- return base::get_activation_kind();
- }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- /// @copydoc dnnl::primitive_desc_base::get_alpha()const
- float get_alpha() const { return base::get_alpha(); }
- /// @copydoc dnnl::primitive_desc_base::get_beta()const
- float get_beta() const { return base::get_beta(); }
- };
- /// Default constructor. Produces an empty object.
- vanilla_rnn_backward() = default;
- /// Constructs a vanilla RNN backward propagation primitive.
- /// @param pd Primitive descriptor for a vanilla RNN backward
- /// propagation primitive.
- vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a vanilla RNN backward propagation primitive from
- /// a cache blob.
- /// @param pd Primitive descriptor for a vanilla RNN backward
- /// propagation primitive.
- /// @param cache_blob Cache blob.
- vanilla_rnn_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// LSTM forward propagation primitive.
- struct lstm_forward : public primitive {
- /// Primitive descriptor for an LSTM forward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an LSTM (with or without
- /// peephole and with or without projection) forward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p src_iter_c_desc,
- /// - @p weights_peephole_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc together with @p dst_iter_c_desc.
- ///
- /// This would then indicate that the LSTM forward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// The @p weights_projection_desc may point to a zero memory
- /// descriptor. This would then indicate that the LSTM doesn't have
- /// recurrent projection layer.
- ///
- /// @note
- /// All memory descriptors can be initialized with an
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param src_iter_c_desc Memory descriptor for the input recurrent
- /// cell state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param weights_peephole_desc Memory descriptor for the weights
- /// applied to the cell states (according to the Peephole LSTM
- /// formula).
- /// @param weights_projection_desc Memory descriptor for the weights
- /// applied to the hidden states to get the recurrent projection
- /// (according to the Projection LSTM formula).
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param dst_iter_c_desc Memory descriptor for the output recurrent
- /// cell state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &src_iter_c_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &weights_peephole_desc,
- const memory::desc &weights_projection_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &dst_iter_c_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, &src_iter_c_desc, nullptr,
- weights_layer_desc, weights_iter_desc,
- &weights_peephole_desc, &weights_projection_desc, bias_desc,
- dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
- rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an LSTM (with or without
- /// peephole) forward propagation primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p src_iter_c_desc,
- /// - @p weights_peephole_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc together with @p dst_iter_c_desc.
- ///
- /// This would then indicate that the LSTM forward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors can be initialized with an
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param src_iter_c_desc Memory descriptor for the input recurrent
- /// cell state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param weights_peephole_desc Memory descriptor for the weights
- /// applied to the cell states (according to the Peephole LSTM
- /// formula).
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param dst_iter_c_desc Memory descriptor for the output recurrent
- /// cell state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &src_iter_c_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &weights_peephole_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &dst_iter_c_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, &src_iter_c_desc, nullptr,
- weights_layer_desc, weights_iter_desc,
- &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
- dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f,
- 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an LSTM forward propagation
- /// primitive that takes no peephole or projection weights descriptors.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p src_iter_c_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc together with @p dst_iter_c_desc.
- ///
- /// This would then indicate that the LSTM forward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors can be initialized with an
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param src_iter_c_desc Memory descriptor for the input recurrent
- /// cell state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param dst_iter_c_desc Memory descriptor for the output recurrent
- /// cell state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &src_iter_c_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &dst_iter_c_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, &src_iter_c_desc, nullptr,
- weights_layer_desc, weights_iter_desc, nullptr, nullptr,
- bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
- rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an LSTM forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind; forward_training, forward_inference, and vanilla_lstm are passed to the base constructor.
- ///
- /// @param pd C API primitive descriptor for an LSTM forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference,
- dnnl::algorithm::vanilla_lstm) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
- memory::desc src_iter_c_desc() const {
- return rnn_base::src_iter_c_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
- memory::desc weights_peephole_desc() const {
- return rnn_base::weights_peephole_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
- memory::desc weights_projection_desc() const {
- return rnn_base::weights_projection_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
- memory::desc dst_iter_c_desc() const {
- return rnn_base::dst_iter_c_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- lstm_forward() = default;
- /// Constructs an LSTM forward propagation primitive.
- /// @param pd Primitive descriptor for an LSTM forward propagation
- /// primitive.
- lstm_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an LSTM forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an LSTM forward propagation
- /// primitive.
- /// @param cache_blob Cache blob bytes forwarded to the primitive constructor.
- lstm_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// LSTM backward propagation primitive.
- struct lstm_backward : public primitive {
- /// Primitive descriptor for an LSTM backward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs an LSTM (with or without peephole and with or without
- /// projection) primitive descriptor for backward propagation
- /// using @p aprop_kind, @p direction, and memory descriptors.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p src_iter_c_desc,
- /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
- /// - @p weights_peephole_desc together with
- /// @p diff_weights_peephole_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p dst_iter_c_desc,
- /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
- ///
- /// This would then indicate that the LSTM backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// The @p weights_projection_desc together with @p
- /// diff_weights_projection_desc may point to a zero memory descriptor.
- /// This would then indicate that the LSTM doesn't have recurrent
- /// projection layer.
- ///
- /// @note
- /// All memory descriptors can be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param src_iter_c_desc Memory descriptor for the input recurrent
- /// cell state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param weights_peephole_desc Memory descriptor for the weights
- /// applied to the cell states (according to the Peephole LSTM
- /// formula).
- /// @param weights_projection_desc Memory descriptor for the weights
- /// applied to the hidden states to get the recurrent projection
- /// (according to the Projection LSTM formula).
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param dst_iter_c_desc Memory descriptor for the output recurrent
- /// cell state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_src_iter_c_desc Memory descriptor for the diff of
- /// input recurrent cell state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_weights_peephole_desc Memory descriptor for the diff of
- /// weights applied to the cell states (according to the Peephole
- /// LSTM formula).
- /// @param diff_weights_projection_desc Memory descriptor for the diff
- /// of weights applied to the hidden states to get the recurrent
- /// projection (according to the Projection LSTM formula).
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
- /// output recurrent cell state vector.
- /// @param hint_fwd_pd Primitive descriptor for an LSTM
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &src_iter_c_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &weights_peephole_desc,
- const memory::desc &weights_projection_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &dst_iter_c_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_src_iter_c_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_weights_peephole_desc,
- const memory::desc &diff_weights_projection_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const memory::desc &diff_dst_iter_c_desc,
- const lstm_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, &src_iter_c_desc, nullptr,
- weights_layer_desc, weights_iter_desc,
- &weights_peephole_desc, &weights_projection_desc, bias_desc,
- dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
- diff_src_layer_desc, diff_src_iter_desc,
- &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
- diff_weights_iter_desc, &diff_weights_peephole_desc,
- &diff_weights_projection_desc, diff_bias_desc,
- diff_dst_layer_desc, diff_dst_iter_desc,
- &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs an LSTM (with or without peephole) primitive descriptor
- /// for backward propagation using @p aprop_kind, @p direction,
- /// and memory descriptors.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p src_iter_c_desc,
- /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
- /// - @p weights_peephole_desc together with
- /// @p diff_weights_peephole_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p dst_iter_c_desc,
- /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
- ///
- /// This would then indicate that the LSTM backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param src_iter_c_desc Memory descriptor for the input recurrent
- /// cell state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param weights_peephole_desc Memory descriptor for the weights
- /// applied to the cell states (according to the Peephole LSTM
- /// formula).
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param dst_iter_c_desc Memory descriptor for the output recurrent
- /// cell state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_src_iter_c_desc Memory descriptor for the diff of
- /// input recurrent cell state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_weights_peephole_desc Memory descriptor for the diff of
- /// weights applied to the cell states (according to the Peephole
- /// LSTM formula).
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
- /// output recurrent cell state vector.
- /// @param hint_fwd_pd Primitive descriptor for an LSTM
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &src_iter_c_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &weights_peephole_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &dst_iter_c_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_src_iter_c_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_weights_peephole_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const memory::desc &diff_dst_iter_c_desc,
- const lstm_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, &src_iter_c_desc, nullptr,
- weights_layer_desc, weights_iter_desc,
- &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
- dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc,
- diff_src_iter_desc, &diff_src_iter_c_desc, nullptr,
- diff_weights_layer_desc, diff_weights_iter_desc,
- &diff_weights_peephole_desc, nullptr, diff_bias_desc,
- diff_dst_layer_desc, diff_dst_iter_desc,
- &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs an LSTM (no peephole, no projection) primitive descriptor for
- /// backward propagation using @p aprop_kind, @p direction, and memory descriptors.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p src_iter_c_desc,
- /// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p dst_iter_c_desc,
- /// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
- ///
- /// This would then indicate that the LSTM backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param src_iter_c_desc Memory descriptor for the input recurrent
- /// cell state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param dst_iter_c_desc Memory descriptor for the output recurrent
- /// cell state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_src_iter_c_desc Memory descriptor for the diff of
- /// input recurrent cell state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
- /// output recurrent cell state vector.
- /// @param hint_fwd_pd Primitive descriptor for an LSTM
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &src_iter_c_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &dst_iter_c_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_src_iter_c_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const memory::desc &diff_dst_iter_c_desc,
- const lstm_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, &src_iter_c_desc, nullptr,
- weights_layer_desc, weights_iter_desc, nullptr, nullptr,
- bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
- diff_src_layer_desc, diff_src_iter_desc,
- &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
- diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
- diff_dst_layer_desc, diff_dst_iter_desc,
- &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an LSTM backward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind; backward and vanilla_lstm are passed to the base constructor.
- ///
- /// @param pd C API primitive descriptor for an LSTM backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
- dnnl::algorithm::vanilla_lstm) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_c_desc()const
- memory::desc src_iter_c_desc() const {
- return rnn_base::src_iter_c_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
- memory::desc weights_peephole_desc() const {
- return rnn_base::weights_peephole_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
- memory::desc weights_projection_desc() const {
- return rnn_base::weights_projection_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_c_desc()const
- memory::desc dst_iter_c_desc() const {
- return rnn_base::dst_iter_c_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
- memory::desc diff_src_layer_desc() const {
- return rnn_base::diff_src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
- memory::desc diff_src_iter_desc() const {
- return rnn_base::diff_src_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
- memory::desc diff_src_iter_c_desc() const {
- return rnn_base::diff_src_iter_c_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
- memory::desc diff_weights_layer_desc() const {
- return rnn_base::diff_weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
- memory::desc diff_weights_iter_desc() const {
- return rnn_base::diff_weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
- memory::desc diff_weights_peephole_desc() const {
- return rnn_base::diff_weights_peephole_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
- memory::desc diff_weights_projection_desc() const {
- return rnn_base::diff_weights_projection_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return rnn_base::diff_bias_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
- memory::desc diff_dst_layer_desc() const {
- return rnn_base::diff_dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
- memory::desc diff_dst_iter_desc() const {
- return rnn_base::diff_dst_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
- memory::desc diff_dst_iter_c_desc() const {
- return rnn_base::diff_dst_iter_c_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- lstm_backward() = default;
- /// Constructs an LSTM backward propagation primitive.
- /// @param pd Primitive descriptor for an LSTM backward propagation
- /// primitive.
- lstm_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an LSTM backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an LSTM backward propagation
- /// primitive.
- /// @param cache_blob Cache blob bytes forwarded to the primitive constructor.
- lstm_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// GRU forward propagation primitive.
- struct gru_forward : public primitive {
- /// Primitive descriptor for a GRU forward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a GRU forward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc.
- ///
- /// This would then indicate that the GRU forward propagation primitive
- /// should not use them and should default to zero values instead.
- ///
- /// @note
- /// All memory descriptors except @p src_iter_desc may be
- /// initialized with an #dnnl::memory::format_tag::any value of @p
- /// format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, nullptr, nullptr, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
- 0.0f, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a GRU forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind; forward_training, forward_inference, and vanilla_gru are passed to the base constructor.
- ///
- /// @param pd C API primitive descriptor for a GRU forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference,
- dnnl::algorithm::vanilla_gru) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- gru_forward() = default;
- /// Constructs a GRU forward propagation primitive.
- /// @param pd Primitive descriptor for a GRU forward propagation
- /// primitive.
- gru_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a GRU forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for a GRU forward propagation
- /// primitive.
- /// @param cache_blob Cache blob bytes forwarded to the primitive constructor.
- gru_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// GRU backward propagation primitive.
- struct gru_backward : public primitive {
- /// Primitive descriptor for a GRU backward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a GRU backward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p diff_src_iter_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
- ///
- /// This would then indicate that the GRU backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param hint_fwd_pd Primitive descriptor for a GRU
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const gru_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, nullptr, nullptr, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
- diff_src_iter_desc, nullptr, nullptr,
- diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
- nullptr, diff_bias_desc, diff_dst_layer_desc,
- diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a GRU backward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a GRU backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
- dnnl::algorithm::vanilla_gru) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
- memory::desc diff_src_layer_desc() const {
- return rnn_base::diff_src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
- memory::desc diff_src_iter_desc() const {
- return rnn_base::diff_src_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
- memory::desc diff_weights_layer_desc() const {
- return rnn_base::diff_weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
- memory::desc diff_weights_iter_desc() const {
- return rnn_base::diff_weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return rnn_base::diff_bias_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
- memory::desc diff_dst_layer_desc() const {
- return rnn_base::diff_dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
- memory::desc diff_dst_iter_desc() const {
- return rnn_base::diff_dst_iter_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- gru_backward() = default;
- /// Constructs a GRU backward propagation primitive.
- /// @param pd Primitive descriptor for a GRU backward propagation
- /// primitive.
- gru_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a GRU backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for a GRU backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- gru_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// LBR GRU forward propagation primitive.
- struct lbr_gru_forward : public primitive {
- /// Primitive descriptor for an LBR GRU forward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for LBR GRU forward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc.
- ///
- /// This would then indicate that the LBR GRU forward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors except @p src_iter_desc may be
- /// initialized with an #dnnl::memory::format_tag::any value of @p
- /// format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
- algorithm::undef, direction, src_layer_desc, src_iter_desc,
- nullptr, nullptr, weights_layer_desc, weights_iter_desc,
- nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
- nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a LBR GRU forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a LBR GRU forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference,
- dnnl::algorithm::lbr_gru) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- lbr_gru_forward() = default;
- /// Constructs an LBR GRU forward propagation primitive.
- /// @param pd Primitive descriptor for an LBR GRU forward propagation
- /// primitive.
- lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an LBR GRU forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an LBR GRU forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- lbr_gru_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// LBR GRU backward propagation primitive.
- struct lbr_gru_backward : public primitive {
- /// Primitive descriptor for an LBR GRU backward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for LBR GRU backward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p diff_src_iter_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
- ///
- /// This would then indicate that the LBR GRU backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param hint_fwd_pd Primitive descriptor for an LBR GRU
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const lbr_gru_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
- algorithm::undef, direction, src_layer_desc, src_iter_desc,
- nullptr, nullptr, weights_layer_desc, weights_iter_desc,
- nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
- nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr,
- nullptr, diff_weights_layer_desc, diff_weights_iter_desc,
- nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc,
- diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for a LBR GRU backward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a LBR GRU backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(
- pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
- memory::desc diff_src_layer_desc() const {
- return rnn_base::diff_src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
- memory::desc diff_src_iter_desc() const {
- return rnn_base::diff_src_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
- memory::desc diff_weights_layer_desc() const {
- return rnn_base::diff_weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
- memory::desc diff_weights_iter_desc() const {
- return rnn_base::diff_weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return rnn_base::diff_bias_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
- memory::desc diff_dst_layer_desc() const {
- return rnn_base::diff_dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
- memory::desc diff_dst_iter_desc() const {
- return rnn_base::diff_dst_iter_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- lbr_gru_backward() = default;
- /// Constructs an LBR GRU backward propagation primitive.
- /// @param pd Primitive descriptor for an LBR GRU backward propagation
- /// primitive.
- lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an LBR GRU backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an LBR GRU backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- lbr_gru_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// AUGRU forward propagation primitive.
- struct augru_forward : public primitive {
- /// Primitive descriptor for an AUGRU forward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an AUGRU forward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc.
- ///
- /// This would then indicate that the AUGRU forward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors except @p src_iter_desc may be
- /// initialized with an #dnnl::memory::format_tag::any value of @p
- /// format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param attention_desc Memory descriptor for the attention vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &attention_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
- 0.0f, 0.0f, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an AUGRU forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for an AUGRU forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference,
- dnnl::algorithm::vanilla_augru) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
- memory::desc attention_desc() const {
- return rnn_base::augru_attention_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- augru_forward() = default;
- /// Constructs an AUGRU forward propagation primitive.
- /// @param pd Primitive descriptor for an AUGRU forward propagation
- /// primitive.
- augru_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an AUGRU forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an AUGRU forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- augru_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// AUGRU backward propagation primitive.
- struct augru_backward : public primitive {
- /// Primitive descriptor for an AUGRU backward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an AUGRU backward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p diff_src_iter_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
- ///
- /// This would then indicate that the AUGRU backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param attention_desc Memory descriptor for the attention vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_attention_desc Memory descriptor for the diff of
- /// attention vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param hint_fwd_pd Primitive descriptor for an AUGRU
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &attention_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_attention_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const augru_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
- aprop_kind, algorithm::undef, direction, src_layer_desc,
- src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
- diff_src_iter_desc, nullptr, &diff_attention_desc,
- diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
- nullptr, diff_bias_desc, diff_dst_layer_desc,
- diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {}
- /// Constructs a primitive descriptor for an AUGRU backward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for an AUGRU backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
- dnnl::algorithm::vanilla_augru) {}
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
- memory::desc attention_desc() const {
- return rnn_base::augru_attention_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
- memory::desc diff_src_layer_desc() const {
- return rnn_base::diff_src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
- memory::desc diff_src_iter_desc() const {
- return rnn_base::diff_src_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
- memory::desc diff_attention_desc() const {
- return rnn_base::diff_augru_attention_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
- memory::desc diff_weights_layer_desc() const {
- return rnn_base::diff_weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
- memory::desc diff_weights_iter_desc() const {
- return rnn_base::diff_weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return rnn_base::diff_bias_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
- memory::desc diff_dst_layer_desc() const {
- return rnn_base::diff_dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
- memory::desc diff_dst_iter_desc() const {
- return rnn_base::diff_dst_iter_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- augru_backward() = default;
- /// Constructs an AUGRU backward propagation primitive.
- /// @param pd Primitive descriptor for an AUGRU backward propagation
- /// primitive.
- augru_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an AUGRU backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an AUGRU backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- augru_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// LBR AUGRU forward propagation primitive (cell kind algorithm::lbr_augru).
- struct lbr_augru_forward : public primitive {
- /// Primitive descriptor for an LBR AUGRU forward propagation primitive.
- /// Its accessors below simply return the corresponding rnn_base:: / base:: results.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an LBR AUGRU forward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc,
- /// - @p bias_desc,
- /// - @p dst_iter_desc.
- ///
- /// This would then indicate that the LBR AUGRU forward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors except @p src_iter_desc may be
- /// initialized with a #dnnl::memory::format_tag::any value of @p
- /// format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param attention_desc Memory descriptor for the attention vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &attention_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
- algorithm::undef, direction, src_layer_desc, src_iter_desc,
- nullptr, &attention_desc, weights_layer_desc, // attention_desc is the one descriptor passed by address
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
- 0.0f, 0.0f, attr, allow_empty) {} // empty body: the base constructor does all the work
- /// Constructs a primitive descriptor for an LBR AUGRU forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for an LBR AUGRU forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference,
- dnnl::algorithm::lbr_augru) {} // NOTE(review): presumably the kinds pd is matched against; confirm in the base class
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
- memory::desc attention_desc() const {
- return rnn_base::augru_attention_desc(); // the base accessor carries the augru_ prefix
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- lbr_augru_forward() = default;
- /// Constructs an LBR AUGRU forward propagation primitive.
- /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
- /// primitive.
- lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an LBR AUGRU forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
- /// primitive.
- /// @param cache_blob Cache blob, passed through to the primitive base constructor.
- lbr_augru_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// LBR AUGRU backward propagation primitive (cell kind algorithm::lbr_augru).
- struct lbr_augru_backward : public primitive {
- /// Primitive descriptor for an LBR AUGRU backward propagation primitive.
- struct primitive_desc : public rnn_primitive_desc_base {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an LBR AUGRU backward propagation
- /// primitive.
- ///
- /// The following arguments may point to a zero memory descriptor:
- /// - @p src_iter_desc together with @p diff_src_iter_desc,
- /// - @p bias_desc together with @p diff_bias_desc,
- /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
- ///
- /// This would then indicate that the LBR AUGRU backward propagation
- /// primitive should not use them and should default to zero values
- /// instead.
- ///
- /// @note
- /// All memory descriptors may be initialized with a
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Must be
- /// #dnnl::prop_kind::backward.
- /// @param direction RNN direction. See @ref dnnl::rnn_direction for
- /// more info.
- /// @param src_layer_desc Memory descriptor for the input vector.
- /// @param src_iter_desc Memory descriptor for the input recurrent
- /// hidden state vector.
- /// @param attention_desc Memory descriptor for the attention vector.
- /// @param weights_layer_desc Memory descriptor for the weights
- /// applied to the layer input.
- /// @param weights_iter_desc Memory descriptor for the weights applied
- /// to the recurrent input.
- /// @param bias_desc Bias memory descriptor.
- /// @param dst_layer_desc Memory descriptor for the output vector.
- /// @param dst_iter_desc Memory descriptor for the output recurrent
- /// hidden state vector.
- /// @param diff_src_layer_desc Memory descriptor for the diff of input
- /// vector.
- /// @param diff_src_iter_desc Memory descriptor for the diff of input
- /// recurrent hidden state vector.
- /// @param diff_attention_desc Memory descriptor for the diff of
- /// attention vector.
- /// @param diff_weights_layer_desc Memory descriptor for the diff of
- /// weights applied to the layer input.
- /// @param diff_weights_iter_desc Memory descriptor for the diff of
- /// weights applied to the recurrent input.
- /// @param diff_bias_desc Diff bias memory descriptor.
- /// @param diff_dst_layer_desc Memory descriptor for the diff of
- /// output vector.
- /// @param diff_dst_iter_desc Memory descriptor for the diff of output
- /// recurrent hidden state vector.
- /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- rnn_direction direction, const memory::desc &src_layer_desc,
- const memory::desc &src_iter_desc,
- const memory::desc &attention_desc,
- const memory::desc &weights_layer_desc,
- const memory::desc &weights_iter_desc,
- const memory::desc &bias_desc,
- const memory::desc &dst_layer_desc,
- const memory::desc &dst_iter_desc,
- const memory::desc &diff_src_layer_desc,
- const memory::desc &diff_src_iter_desc,
- const memory::desc &diff_attention_desc,
- const memory::desc &diff_weights_layer_desc,
- const memory::desc &diff_weights_iter_desc,
- const memory::desc &diff_bias_desc,
- const memory::desc &diff_dst_layer_desc,
- const memory::desc &diff_dst_iter_desc,
- const lbr_augru_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
- algorithm::undef, direction, src_layer_desc, src_iter_desc,
- nullptr, &attention_desc, weights_layer_desc, // attention_desc is passed by address
- weights_iter_desc, nullptr, nullptr, bias_desc,
- dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
- diff_src_iter_desc, nullptr, &diff_attention_desc, // so is diff_attention_desc
- diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
- nullptr, diff_bias_desc, diff_dst_layer_desc,
- diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
- hint_fwd_pd, attr, allow_empty) {} // empty body: the base constructor does all the work
- /// Constructs a primitive descriptor for an LBR AUGRU backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for an LBR AUGRU backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
- dnnl::algorithm::lbr_augru) {} // NOTE(review): presumably the kinds pd is matched against; confirm in the base class
- /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
- memory::desc src_layer_desc() const {
- return rnn_base::src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
- memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
- memory::desc attention_desc() const {
- return rnn_base::augru_attention_desc(); // the base accessor carries the augru_ prefix
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
- memory::desc weights_layer_desc() const {
- return rnn_base::weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
- memory::desc weights_iter_desc() const {
- return rnn_base::weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
- memory::desc bias_desc() const { return rnn_base::bias_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
- memory::desc dst_layer_desc() const {
- return rnn_base::dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
- memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
- /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const {
- return rnn_base::workspace_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
- memory::desc diff_src_layer_desc() const {
- return rnn_base::diff_src_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
- memory::desc diff_src_iter_desc() const {
- return rnn_base::diff_src_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
- memory::desc diff_attention_desc() const {
- return rnn_base::diff_augru_attention_desc(); // the base accessor carries the augru_ infix
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
- memory::desc diff_weights_layer_desc() const {
- return rnn_base::diff_weights_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
- memory::desc diff_weights_iter_desc() const {
- return rnn_base::diff_weights_iter_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
- memory::desc diff_bias_desc() const {
- return rnn_base::diff_bias_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
- memory::desc diff_dst_layer_desc() const {
- return rnn_base::diff_dst_layer_desc();
- }
- /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
- memory::desc diff_dst_iter_desc() const {
- return rnn_base::diff_dst_iter_desc();
- }
- /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
- algorithm get_cell_kind() const { return base::get_cell_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_direction()const
- rnn_direction get_direction() const { return base::get_direction(); }
- };
- /// Default constructor. Produces an empty object.
- lbr_augru_backward() = default;
- /// Constructs an LBR AUGRU backward propagation primitive.
- /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
- /// primitive.
- lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an LBR AUGRU backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
- /// primitive.
- /// @param cache_blob Cache blob, passed through to the primitive base constructor.
- lbr_augru_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_rnn
- /// @addtogroup dnnl_api_shuffle Shuffle
- ///
- /// A primitive to shuffle tensor data along an axis.
- ///
- /// @sa @ref dev_guide_shuffle in developer guide
- ///
- /// @{
- /// Shuffle forward propagation primitive.
- struct shuffle_forward : public primitive {
- /// Primitive descriptor for a shuffle forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a shuffle forward propagation
- /// primitive by calling dnnl_shuffle_forward_primitive_desc_create().
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param axis The axis along which the data is shuffled.
- /// @param group_size Shuffle group size.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- int axis, int group_size,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t pd = nullptr; // stays nullptr unless the C call below writes a handle
- dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create(
- &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
- src_desc.get(), dst_desc.get(), axis, group_size,
- attr.get());
- if (!allow_empty) // status is only inspected when failure is not tolerated
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the shuffle forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd); // reset() receives the handle, which is still nullptr if the C call wrote nothing
- }
- /// Constructs a primitive descriptor for a shuffle forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a shuffle forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {} // NOTE(review): presumably the kinds pd is matched against; confirm in the base class
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_axis()const
- int get_axis() const { return base::get_axis(); }
- /// @copydoc dnnl::primitive_desc_base::get_group_size()const
- memory::dim get_group_size() const { return base::get_group_size(); } // NOTE(review): the constructor takes group_size as int, this returns memory::dim
- };
- /// Default constructor. Produces an empty object.
- shuffle_forward() = default;
- /// Constructs a shuffle forward propagation primitive.
- /// @param pd Primitive descriptor for a shuffle forward propagation
- /// primitive.
- shuffle_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a shuffle forward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for a shuffle forward propagation
- /// primitive.
- /// @param cache_blob Cache blob, passed through to the primitive base constructor.
- shuffle_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Shuffle backward propagation primitive.
- struct shuffle_backward : public primitive {
- /// Primitive descriptor for a shuffle backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a shuffle backward propagation
- /// primitive by calling dnnl_shuffle_backward_primitive_desc_create().
- ///
- /// @param aengine Engine to use.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param axis The axis along which the data is shuffled.
- /// @param group_size Shuffle group size.
- /// @param hint_fwd_pd Primitive descriptor for a shuffle forward
- /// propagation primitive. It is used as a hint for deciding which
- /// memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, int axis, int group_size,
- const shuffle_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t pd = nullptr; // stays nullptr unless the C call below writes a handle
- dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create(
- &pd, aengine.get(), diff_src_desc.get(),
- diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(),
- attr.get());
- if (!allow_empty) // status is only inspected when failure is not tolerated
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the shuffle backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd); // reset() receives the handle, which is still nullptr if the C call wrote nothing
- }
- /// Constructs a primitive descriptor for a shuffle backward
- /// propagation primitive from a C API primitive descriptor that must
- /// have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a shuffle backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
- dnnl::prop_kind::backward_data) {} // NOTE(review): presumably the kinds pd is matched against; confirm in the base class
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_axis()const
- int get_axis() const { return base::get_axis(); }
- /// @copydoc dnnl::primitive_desc_base::get_group_size()const
- memory::dim get_group_size() const { return base::get_group_size(); } // NOTE(review): the constructor takes group_size as int, this returns memory::dim
- };
- /// Default constructor. Produces an empty object.
- shuffle_backward() = default;
- /// Constructs a shuffle backward propagation primitive.
- /// @param pd Primitive descriptor for a shuffle backward propagation
- /// primitive.
- shuffle_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a shuffle backward propagation primitive from a cache blob.
- /// @param pd Primitive descriptor for a shuffle backward propagation
- /// primitive.
- /// @param cache_blob Cache blob, passed through to the primitive base constructor.
- shuffle_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_shuffle
- /// @addtogroup dnnl_api_binary Binary
- ///
- /// A primitive to perform tensor operations over two tensors.
- ///
- /// @sa @ref dev_guide_binary in developer guide
- ///
- /// @{
- /// Elementwise binary operator primitive.
- struct binary : public primitive {
- /// Primitive descriptor for an elementwise binary operator primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for an elementwise binary operator
- /// primitive by calling dnnl_binary_primitive_desc_create().
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Elementwise binary algorithm.
- /// @param src0 Memory descriptor for source tensor #0.
- /// @param src1 Memory descriptor for source tensor #1.
- /// @param dst Memory descriptor for destination tensor.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src0, const memory::desc &src1,
- const memory::desc &dst,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t pd = nullptr; // stays nullptr unless the C call below writes a handle
- dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd,
- aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
- src1.get(), dst.get(), attr.get());
- if (!allow_empty) // status is only inspected when failure is not tolerated
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the binary operation primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- reset(pd); // reset() receives the handle, which is still nullptr if the C call wrote nothing
- }
- /// Constructs a primitive descriptor for an elementwise binary operator
- /// primitive with support of ternary operators, by calling dnnl_binary_primitive_desc_create_v2().
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Elementwise binary algorithm.
- /// @param src0 Memory descriptor for source tensor #0.
- /// @param src1 Memory descriptor for source tensor #1.
- /// @param src2 Memory descriptor for source tensor #2 for ternary
- /// operations. Might be empty.
- /// @param dst Memory descriptor for destination tensor.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &src0, const memory::desc &src1,
- const memory::desc &src2, const memory::desc &dst,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t pd = nullptr; // stays nullptr unless the C call below writes a handle
- dnnl_status_t status = dnnl_binary_primitive_desc_create_v2(&pd,
- aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
- src1.get(), src2.get(), dst.get(), attr.get());
- if (!allow_empty) // status is only inspected when failure is not tolerated
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the binary v2 operation primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- reset(pd); // reset() receives the handle, which is still nullptr if the C call wrote nothing
- }
- /// Constructs a primitive descriptor for a binary primitive from a C
- /// API primitive descriptor that must have a matching kind.
- ///
- /// @param pd C API primitive descriptor for a binary primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {} // only the primitive kind is passed along with pd
- /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
- memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
- /// Returns the memory descriptor for source #0 (same result as src_desc(0)).
- memory::desc src0_desc() const { return base::src_desc(0); }
- /// Returns the memory descriptor for source #1 (same result as src_desc(1)).
- memory::desc src1_desc() const { return base::src_desc(1); }
- /// Returns the memory descriptor for source #2 (same result as src_desc(2)).
- memory::desc src2_desc() const { return base::src_desc(2); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- };
- /// Default constructor. Produces an empty object.
- binary() = default;
- /// Constructs an elementwise binary operation primitive.
- /// @param pd Primitive descriptor for an elementwise binary operation
- /// primitive.
- binary(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs an elementwise binary operation primitive from a cache blob.
- /// @param pd Primitive descriptor for an elementwise binary operation
- /// primitive.
- /// @param cache_blob Cache blob, passed through to the primitive base constructor.
- binary(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_binary
- /// @addtogroup dnnl_api_matmul Matrix Multiplication
- ///
- /// A primitive to perform matrix-matrix multiplication. The batched mode
- /// is supported with 3D tensors.
- ///
- /// @sa @ref dev_guide_matmul in developer guide
- ///
- ///
- /// @{
- /// Matrix multiplication (matmul) primitive.
- struct matmul : public primitive {
- /// Primitive descriptor of a matmul primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for a matmul primitive that has
- /// no bias.
- ///
- /// @param aengine Engine to use.
- /// @param src_desc Source (matrix A) memory descriptor.
- /// @param weights_desc Weights (matrix B) memory descriptor.
- /// @param dst_desc Destination (matrix C) memory descriptor.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc,
- attr, allow_empty) {}
- /// Creates a primitive descriptor for a matmul primitive that has a
- /// bias.
- ///
- /// @param aengine Engine to use.
- /// @param src_desc Source (matrix A) memory descriptor.
- /// @param weights_desc Weights (matrix B) memory descriptor.
- /// @param dst_desc Destination (matrix C) memory descriptor.
- /// @param bias_desc Bias memory descriptor.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc &bias_desc,
- const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, src_desc, weights_desc, &bias_desc,
- dst_desc, attr, allow_empty) {}
- /// Wraps a C API primitive descriptor, which must be of the matmul
- /// kind.
- ///
- /// @param pd C API primitive descriptor of a matmul primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const {
- return query_md(query::src_md, 0);
- }
- /// @copydoc dnnl::primitive_desc_base::weights_desc()const
- memory::desc weights_desc() const {
- return query_md(query::weights_md, 0);
- }
- /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
- memory::desc bias_desc() const {
- // The bias is queried as the weights memory descriptor with index 1.
- return query_md(query::weights_md, 1);
- }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const {
- return query_md(query::dst_md, 0);
- }
- private:
- // Common implementation shared by the public constructors. A null
- // bias_desc selects the variant without bias.
- primitive_desc(const engine &aengine, const memory::desc &src_desc,
- const memory::desc &weights_desc, const memory::desc *bias_desc,
- const memory::desc &dst_desc, const primitive_attr &attr,
- bool allow_empty) {
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc = dnnl_matmul_primitive_desc_create(&c_pd,
- aengine.get(), src_desc.get(), weights_desc.get(),
- optional_arg(bias_desc), dst_desc.get(), attr.get());
- // The status is only checked when the caller did not opt into
- // receiving an empty object on failure.
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a primitive descriptor for "
- "the matmul primitive. Run workload with "
- "environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- }
- reset(c_pd);
- }
- };
- /// Creates an empty matmul primitive.
- matmul() = default;
- /// Creates a matmul primitive.
- /// @param pd Primitive descriptor of the matmul primitive.
- matmul(const primitive_desc &pd) : primitive(pd) {}
- /// Creates a matmul primitive using a cache blob.
- /// @param pd Primitive descriptor of the matmul primitive.
- /// @param cache_blob Cache blob.
- matmul(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_matmul
- /// @addtogroup dnnl_api_resampling Resampling
- ///
- /// A primitive to compute resampling operation on 1D, 2D or 3D data tensor
- /// using Nearest Neighbor, or Linear (Bilinear, Trilinear) interpolation
- /// method.
- ///
- /// @sa @ref dev_guide_resampling in developer guide
- ///
- /// @{
- /// Resampling forward propagation primitive.
- struct resampling_forward : public primitive {
- /// Primitive descriptor of a resampling forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for a resampling forward
- /// propagation primitive from the source and destination memory
- /// descriptors (no explicit scaling factors).
- ///
- /// @note
- /// Destination memory descriptor may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind:
- /// #dnnl::prop_kind::forward_training or
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Resampling algorithm kind:
- /// #dnnl::algorithm::resampling_nearest or
- /// #dnnl::algorithm::resampling_linear.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc,
- &dst_desc, attr, allow_empty) {}
- /// Creates a primitive descriptor for a resampling forward
- /// propagation primitive from the source memory descriptor and
- /// scaling factors (no explicit destination memory descriptor).
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind:
- /// #dnnl::prop_kind::forward_training or
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Resampling algorithm kind:
- /// #dnnl::algorithm::resampling_nearest or
- /// #dnnl::algorithm::resampling_linear.
- /// @param factors Vector of scaling factors for spatial dimension.
- /// @param src_desc Source memory descriptor.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const std::vector<float> &factors,
- const memory::desc &src_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
- src_desc, nullptr, attr, allow_empty) {}
- /// Creates a primitive descriptor for a resampling forward
- /// propagation primitive from scaling factors together with the
- /// source and destination memory descriptors.
- ///
- /// @note
- /// The destination memory descriptor may be initialized with
- /// #dnnl::memory::format_tag::any value of @p format_tag.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind:
- /// #dnnl::prop_kind::forward_training or
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Resampling algorithm kind:
- /// #dnnl::algorithm::resampling_nearest or
- /// #dnnl::algorithm::resampling_linear.
- /// @param factors Vector of scaling factors for spatial dimension.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const std::vector<float> &factors,
- const memory::desc &src_desc, const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
- src_desc, &dst_desc, attr, allow_empty) {}
- /// Wraps a C API primitive descriptor, which must be of the
- /// resampling kind with a forward propagation kind.
- ///
- /// @param pd C API primitive descriptor of a resampling forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const {
- return base::src_desc(0);
- }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const {
- return base::dst_desc(0);
- }
- private:
- // Common implementation shared by the public constructors. Either of
- // factors and dst_desc may be null when the caller did not supply it.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const std::vector<float> *factors,
- const memory::desc &src_desc, const memory::desc *dst_desc,
- const primitive_attr &attr, bool allow_empty) {
- if (factors != nullptr) {
- // Factors are validated against the source rank minus two.
- const auto spatial_ndims = src_desc.get_ndims() - 2;
- memory::validate_dims(*factors, spatial_ndims);
- }
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc
- = dnnl_resampling_forward_primitive_desc_create(&c_pd,
- aengine.get(), dnnl::convert_to_c(aprop_kind),
- convert_to_c(aalgorithm), optional_arg(factors),
- src_desc.get(), optional_arg(dst_desc), attr.get());
- // The status is only checked when the caller did not opt into
- // receiving an empty object on failure.
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a primitive descriptor for "
- "the resampling forward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- }
- reset(c_pd);
- }
- };
- /// Creates an empty resampling forward propagation primitive.
- resampling_forward() = default;
- /// Creates a resampling forward propagation primitive.
- /// @param pd Primitive descriptor of the resampling forward propagation
- /// primitive.
- resampling_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Creates a resampling forward propagation primitive using a cache
- /// blob.
- /// @param pd Primitive descriptor of the resampling forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- resampling_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Resampling backward propagation primitive.
- struct resampling_backward : public primitive {
- /// Primitive descriptor of a resampling backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for a resampling backward
- /// propagation primitive from the diff source and diff destination
- /// memory descriptors (no explicit scaling factors).
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Resampling algorithm kind:
- /// #dnnl::algorithm::resampling_nearest or
- /// #dnnl::algorithm::resampling_linear.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param hint_fwd_pd Primitive descriptor of a resampling forward
- /// propagation primitive, used as a hint when deciding which
- /// memory format to use.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const resampling_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc,
- diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
- /// Creates a primitive descriptor for a resampling backward
- /// propagation primitive from scaling factors together with the diff
- /// source and diff destination memory descriptors.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Resampling algorithm kind:
- /// #dnnl::algorithm::resampling_nearest or
- /// #dnnl::algorithm::resampling_linear.
- /// @param factors Vector of scaling factors for spatial dimension.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param hint_fwd_pd Primitive descriptor of a resampling forward
- /// propagation primitive, used as a hint when deciding which
- /// memory format to use.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const std::vector<float> &factors,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const resampling_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false)
- : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc,
- diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}
- /// Wraps a C API primitive descriptor, which must be of the
- /// resampling kind with the backward-data propagation kind.
- ///
- /// @param pd C API primitive descriptor of a resampling backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const {
- return base::diff_src_desc(0);
- }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const {
- return base::diff_dst_desc(0);
- }
- private:
- // Common implementation shared by the public constructors. A null
- // factors pointer means the caller did not supply scaling factors.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const std::vector<float> *factors,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc,
- const resampling_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr, bool allow_empty) {
- if (factors != nullptr) {
- // Factors are validated against the diff source rank minus two.
- const auto spatial_ndims = diff_src_desc.get_ndims() - 2;
- memory::validate_dims(*factors, spatial_ndims);
- }
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc
- = dnnl_resampling_backward_primitive_desc_create(&c_pd,
- aengine.get(), convert_to_c(aalgorithm),
- optional_arg(factors), diff_src_desc.get(),
- diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
- // The status is only checked when the caller did not opt into
- // receiving an empty object on failure.
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a primitive descriptor for "
- "the resampling backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- }
- reset(c_pd);
- }
- };
- /// Creates an empty resampling backward propagation primitive.
- resampling_backward() = default;
- /// Creates a resampling backward propagation primitive.
- /// @param pd Primitive descriptor of the resampling backward propagation
- /// primitive.
- resampling_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Creates a resampling backward propagation primitive using a cache
- /// blob.
- /// @param pd Primitive descriptor of the resampling backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- resampling_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_resampling
- /// @addtogroup dnnl_api_pooling Pooling
- ///
- /// A primitive to perform max or average pooling with dilation.
- ///
- /// @sa @ref dev_guide_pooling in developer guide
- ///
- /// @{
- /// Pooling forward propagation primitive.
- struct pooling_forward : public primitive {
- /// Primitive descriptor for a pooling forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for pooling forward propagation
- /// primitive.
- ///
- /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
- /// and @p padding_r contain values for spatial dimensions only and
- /// hence must have the same number of elements as there are spatial
- /// dimensions. The order of values is the same as in the tensor:
- /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind. Possible values are
- /// #dnnl::prop_kind::forward_training, and
- /// #dnnl::prop_kind::forward_inference.
- /// @param aalgorithm Pooling algorithm kind: either
- /// #dnnl::algorithm::pooling_max,
- /// #dnnl::algorithm::pooling_avg_include_padding,
- /// or #dnnl::algorithm::pooling_avg_exclude_padding.
- /// @param src_desc Source memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param strides Vector of strides for spatial dimension.
- /// @param kernel Vector of kernel spatial dimensions.
- /// @param dilation Vector of dilations for spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- algorithm aalgorithm, const memory::desc &src_desc,
- const memory::desc &dst_desc, const memory::dims &strides,
- const memory::dims &kernel, const memory::dims &dilation,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- // Every spatial vector is validated against the source rank minus
- // two.
- const auto spatial_ndims = src_desc.get_ndims() - 2;
- memory::validate_dims(strides, spatial_ndims);
- memory::validate_dims(kernel, spatial_ndims);
- memory::validate_dims(padding_l, spatial_ndims);
- memory::validate_dims(padding_r, spatial_ndims);
- memory::validate_dims(dilation, spatial_ndims);
- dnnl_primitive_desc_t pd = nullptr;
- // data() is used instead of &v[0]: indexing element 0 of an empty
- // vector is undefined behavior, whereas data() is always valid to
- // call.
- dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create(
- &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
- convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
- strides.data(), kernel.data(), dilation.data(),
- padding_l.data(), padding_r.data(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the pooling forward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- reset(pd);
- }
- /// Constructs a primitive descriptor for a pooling forward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a pooling forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const { return base::src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const { return base::dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const { return base::workspace_desc(); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_kernel()const
- memory::dims get_kernel() const { return base::get_kernel(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const { return base::get_dilations(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const { return base::get_padding_l(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const { return base::get_padding_r(); }
- };
- /// Default constructor. Produces an empty object.
- pooling_forward() = default;
- /// Constructs a pooling forward propagation primitive.
- ///
- /// @param pd Primitive descriptor for a pooling forward propagation
- /// primitive.
- pooling_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a pooling forward propagation primitive from a cache blob.
- ///
- /// @param pd Primitive descriptor for a pooling forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- pooling_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// Pooling backward propagation primitive.
- struct pooling_backward : public primitive {
- /// Primitive descriptor for a pooling backward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Default constructor. Produces an empty object.
- primitive_desc() = default;
- /// Constructs a primitive descriptor for a pooling backward propagation
- /// primitive.
- ///
- /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
- /// and @p padding_r contain values for spatial dimensions only and
- /// hence must have the same number of elements as there are spatial
- /// dimensions. The order of values is the same as in the tensor:
- /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
- ///
- /// @param aengine Engine to use.
- /// @param aalgorithm Pooling algorithm kind: either
- /// #dnnl::algorithm::pooling_max,
- /// #dnnl::algorithm::pooling_avg_include_padding,
- /// or #dnnl::algorithm::pooling_avg_exclude_padding.
- /// @param diff_src_desc Diff source memory descriptor.
- /// @param diff_dst_desc Diff destination memory descriptor.
- /// @param strides Vector of strides for spatial dimension.
- /// @param kernel Vector of kernel spatial dimensions.
- /// @param dilation Vector of dilations for spatial dimension.
- /// @param padding_l Vector of padding values for low indices for each
- /// spatial dimension `([[front,] top,] left)`.
- /// @param padding_r Vector of padding values for high indices for
- /// each spatial dimension `([[back,] bottom,] right)`.
- /// @param hint_fwd_pd Primitive descriptor for a pooling
- /// forward propagation primitive. It is used as a hint for
- /// deciding which memory format to use.
- /// @param attr Primitive attributes to use. Attributes are optional
- /// and default to empty attributes.
- /// @param allow_empty A flag signifying whether construction is
- /// allowed to fail without throwing an exception. In this case an
- /// empty object will be produced. This flag is optional and
- /// defaults to false.
- primitive_desc(const engine &aengine, algorithm aalgorithm,
- const memory::desc &diff_src_desc,
- const memory::desc &diff_dst_desc, const memory::dims &strides,
- const memory::dims &kernel, const memory::dims &dilation,
- const memory::dims &padding_l, const memory::dims &padding_r,
- const pooling_forward::primitive_desc &hint_fwd_pd,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- // Every spatial vector is validated against the diff source rank
- // minus two.
- const auto spatial_ndims = diff_src_desc.get_ndims() - 2;
- memory::validate_dims(strides, spatial_ndims);
- memory::validate_dims(kernel, spatial_ndims);
- memory::validate_dims(padding_l, spatial_ndims);
- memory::validate_dims(padding_r, spatial_ndims);
- memory::validate_dims(dilation, spatial_ndims);
- dnnl_primitive_desc_t pd = nullptr;
- // data() is used instead of &v[0]: indexing element 0 of an empty
- // vector is undefined behavior, whereas data() is always valid to
- // call.
- dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create(
- &pd, aengine.get(), convert_to_c(aalgorithm),
- diff_src_desc.get(), diff_dst_desc.get(), strides.data(),
- kernel.data(), dilation.data(), padding_l.data(),
- padding_r.data(), hint_fwd_pd.get(), attr.get());
- if (!allow_empty)
- error::wrap_c_api(status,
- "could not create a primitive descriptor for "
- "the pooling backward propagation primitive. Run "
- "workload with environment variable ONEDNN_VERBOSE=all "
- "to get additional diagnostic information.");
- reset(pd);
- }
- /// Constructs a primitive descriptor for a pooling backward propagation
- /// primitive from a C API primitive descriptor that must have a
- /// matching kind.
- ///
- /// @param pd C API primitive descriptor for a pooling backward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
- dnnl::prop_kind::backward_data) {}
- /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
- memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
- memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
- /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
- memory::desc workspace_desc() const { return base::workspace_desc(); }
- /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
- algorithm get_algorithm() const { return base::get_algorithm(); }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const { return base::get_prop_kind(); }
- /// @copydoc dnnl::primitive_desc_base::get_strides()const
- memory::dims get_strides() const { return base::get_strides(); }
- /// @copydoc dnnl::primitive_desc_base::get_kernel()const
- memory::dims get_kernel() const { return base::get_kernel(); }
- /// @copydoc dnnl::primitive_desc_base::get_dilations()const
- memory::dims get_dilations() const { return base::get_dilations(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
- memory::dims get_padding_l() const { return base::get_padding_l(); }
- /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
- memory::dims get_padding_r() const { return base::get_padding_r(); }
- };
- /// Default constructor. Produces an empty object.
- pooling_backward() = default;
- /// Constructs a pooling backward propagation primitive.
- ///
- /// @param pd Primitive descriptor for a pooling backward propagation
- /// primitive.
- pooling_backward(const primitive_desc &pd) : primitive(pd) {}
- /// Constructs a pooling backward propagation primitive from a cache blob.
- ///
- /// @param pd Primitive descriptor for a pooling backward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- pooling_backward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_pooling
- /// @addtogroup dnnl_api_prelu PReLU
- ///
- /// PReLU primitive.
- /// A primitive to perform PReLU (leaky ReLU with a trainable alpha parameter).
- ///
- /// @sa @ref dev_guide_prelu in developer guide
- ///
- /// @{
- /// PReLU forward propagation primitive.
- struct prelu_forward : public primitive {
- /// Primitive descriptor of a PReLU forward propagation primitive.
- struct primitive_desc : public dnnl::primitive_desc {
- /// Creates an empty primitive descriptor.
- primitive_desc() = default;
- /// Creates a primitive descriptor for a PReLU forward propagation
- /// primitive.
- ///
- /// @param aengine Engine to use.
- /// @param aprop_kind Propagation kind:
- /// #dnnl::prop_kind::forward_training or
- /// #dnnl::prop_kind::forward_inference.
- /// @param src_desc Source memory descriptor.
- /// @param weight_desc Alpha parameters memory descriptor.
- /// @param dst_desc Destination memory descriptor.
- /// @param attr Primitive attributes. Optional; empty attributes are
- /// used by default.
- /// @param allow_empty When true, a creation failure does not throw
- /// and an empty object is produced instead. Optional; false by
- /// default.
- primitive_desc(const engine &aengine, prop_kind aprop_kind,
- const memory::desc &src_desc, const memory::desc &weight_desc,
- const memory::desc &dst_desc,
- const primitive_attr &attr = default_attr(),
- bool allow_empty = false) {
- dnnl_primitive_desc_t c_pd = nullptr;
- const dnnl_status_t rc = dnnl_prelu_forward_primitive_desc_create(
- &c_pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
- src_desc.get(), weight_desc.get(), dst_desc.get(),
- attr.get());
- // The status is only checked when the caller did not opt into
- // receiving an empty object on failure.
- if (!allow_empty) {
- error::wrap_c_api(rc,
- "could not create a primitive descriptor for "
- "the prelu forward propagation primitive. Run workload "
- "with environment variable ONEDNN_VERBOSE=all to get "
- "additional diagnostic information.");
- }
- reset(c_pd);
- }
- /// Wraps a C API primitive descriptor, which must be of the prelu
- /// kind with a forward propagation kind.
- ///
- /// @param pd C API primitive descriptor of a prelu forward
- /// propagation primitive.
- primitive_desc(dnnl_primitive_desc_t pd)
- : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
- dnnl::prop_kind::forward_training,
- dnnl::prop_kind::forward_inference) {}
- /// @copydoc dnnl::primitive_desc_base::src_desc()const
- memory::desc src_desc() const {
- return base::src_desc(0);
- }
- /// @copydoc dnnl::primitive_desc_base::dst_desc()const
- memory::desc dst_desc() const {
- return base::dst_desc(0);
- }
- /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
- prop_kind get_prop_kind() const {
- return base::get_prop_kind();
- }
- };
- /// Creates an empty prelu forward propagation primitive.
- prelu_forward() = default;
- /// Creates a prelu forward propagation primitive.
- /// @param pd Primitive descriptor of the prelu forward propagation
- /// primitive.
- prelu_forward(const primitive_desc &pd) : primitive(pd) {}
- /// Creates a prelu forward propagation primitive using a cache blob.
- /// @param pd Primitive descriptor of the prelu forward propagation
- /// primitive.
- /// @param cache_blob Cache blob.
- prelu_forward(
- const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
- : primitive(pd, cache_blob) {}
- };
- /// PReLU backward propagation primitive.
- struct prelu_backward : public primitive {
-     /// Primitive descriptor for a PReLU backward propagation primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Default constructor. Produces an empty object.
-         primitive_desc() = default;
-         /// Constructs a primitive descriptor for a PReLU backward
-         /// propagation primitive.
-         ///
-         /// @param aengine Engine to use.
-         /// @param src_desc Source memory descriptor.
-         /// @param weight_desc Alpha parameters memory descriptor.
-         /// @param diff_src_desc Diff source memory descriptor.
-         /// @param diff_weights_desc Diff alpha parameters memory descriptor.
-         /// @param diff_dst_desc Diff destination memory descriptor.
-         /// @param hint_fwd_pd Primitive descriptor for a PReLU forward
-         ///     propagation primitive. It is used as a hint for deciding
-         ///     which memory format to use.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case
-         ///     an empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, const memory::desc &src_desc,
-                 const memory::desc &weight_desc,
-                 const memory::desc &diff_src_desc,
-                 const memory::desc &diff_weights_desc,
-                 const memory::desc &diff_dst_desc,
-                 const prelu_forward::primitive_desc &hint_fwd_pd,
-                 const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd = nullptr;
-             const dnnl_status_t rc = dnnl_prelu_backward_primitive_desc_create(
-                     &c_pd, aengine.get(), src_desc.get(), weight_desc.get(),
-                     diff_src_desc.get(), diff_weights_desc.get(),
-                     diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());
-             // The status is only checked when an empty result is not
-             // acceptable to the caller.
-             if (!allow_empty) {
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the prelu backward propagation primitive. Run "
-                         "workload with environment variable ONEDNN_VERBOSE=all "
-                         "to get additional diagnostic information.");
-             }
-             reset(c_pd);
-         }
-         /// Constructs a primitive descriptor for a PReLU backward
-         /// propagation primitive from a C API primitive descriptor that
-         /// must have a matching kind.
-         ///
-         /// @param pd C API primitive descriptor for a PReLU backward
-         ///     propagation primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
-                     dnnl::prop_kind::backward) {}
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const {
-             return base::src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
-         memory::desc diff_src_desc() const {
-             return base::diff_src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
-         memory::desc diff_dst_desc() const {
-             return base::diff_dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
-         prop_kind get_prop_kind() const {
-             return base::get_prop_kind();
-         }
-     };
-     /// Default constructor. Produces an empty object.
-     prelu_backward() = default;
-     /// Constructs a PReLU backward propagation primitive.
-     /// @param pd Primitive descriptor for a PReLU backward propagation
-     ///     primitive.
-     prelu_backward(const primitive_desc &pd) : primitive(pd) {}
-     /// Constructs a PReLU backward propagation primitive from a cache blob.
-     /// @param pd Primitive descriptor for a PReLU backward propagation
-     ///     primitive.
-     /// @param cache_blob Cache blob.
-     prelu_backward(
-             const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_prelu
- /// @addtogroup dnnl_api_reduction Reduction
- ///
- /// A primitive to compute a reduction operation on a data tensor
- /// using the min, max, mul, sum, mean, and norm_lp operations.
- ///
- /// @sa @ref dev_guide_reduction in developer guide
- ///
- /// @{
- /// Reduction.
- struct reduction : public primitive {
-     /// Primitive descriptor for a reduction primitive.
-     struct primitive_desc : public dnnl::primitive_desc {
-         /// Default constructor. Produces an empty object.
-         primitive_desc() = default;
-         /// Constructs a primitive descriptor for a reduction primitive using
-         /// algorithm specific parameters, source and destination memory
-         /// descriptors.
-         ///
-         /// @note
-         ///     Destination memory descriptor may be initialized with
-         ///     #dnnl::memory::format_tag::any value of @p format_tag.
-         ///
-         /// @param aengine Engine to use.
-         /// @param aalgorithm reduction algorithm kind. Possible values:
-         ///     #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
-         ///     #dnnl_reduction_mul, #dnnl_reduction_mean,
-         ///     #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum,
-         ///     #dnnl_reduction_norm_lp_power_p_max,
-         ///     #dnnl_reduction_norm_lp_power_p_sum.
-         /// @param src_desc Source memory descriptor.
-         /// @param dst_desc Destination memory descriptor.
-         /// @param p algorithm specific parameter.
-         /// @param eps algorithm specific parameter.
-         /// @param attr Primitive attributes to use. Attributes are optional
-         ///     and default to empty attributes.
-         /// @param allow_empty A flag signifying whether construction is
-         ///     allowed to fail without throwing an exception. In this case
-         ///     an empty object will be produced. This flag is optional and
-         ///     defaults to false.
-         primitive_desc(const engine &aengine, algorithm aalgorithm,
-                 const memory::desc &src_desc, const memory::desc &dst_desc,
-                 float p, float eps, const primitive_attr &attr = default_attr(),
-                 bool allow_empty = false) {
-             dnnl_primitive_desc_t c_pd = nullptr;
-             const dnnl_status_t rc = dnnl_reduction_primitive_desc_create(&c_pd,
-                     aengine.get(), convert_to_c(aalgorithm), src_desc.get(),
-                     dst_desc.get(), p, eps, attr.get());
-             // The status is only checked when an empty result is not
-             // acceptable to the caller.
-             if (!allow_empty) {
-                 error::wrap_c_api(rc,
-                         "could not create a primitive descriptor for "
-                         "the reduction primitive. Run workload with "
-                         "environment variable ONEDNN_VERBOSE=all to get "
-                         "additional diagnostic information.");
-             }
-             reset(c_pd);
-         }
-         /// Constructs a primitive descriptor for a reduction primitive from
-         /// a C API primitive descriptor that must have a matching kind.
-         ///
-         /// @param pd C API primitive descriptor for a reduction primitive.
-         primitive_desc(dnnl_primitive_desc_t pd)
-             : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {}
-         /// @copydoc dnnl::primitive_desc_base::src_desc()const
-         memory::desc src_desc() const {
-             return base::src_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::dst_desc()const
-         memory::desc dst_desc() const {
-             return base::dst_desc(0);
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_p()const
-         float get_p() const {
-             return base::get_p();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
-         float get_epsilon() const {
-             return base::get_epsilon();
-         }
-         /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
-         algorithm get_algorithm() const {
-             return base::get_algorithm();
-         }
-     };
-     /// Default constructor. Produces an empty object.
-     reduction() = default;
-     /// Constructs a reduction primitive.
-     /// @param pd Primitive descriptor for a reduction primitive.
-     reduction(const primitive_desc &pd) : primitive(pd) {}
-     /// Constructs a reduction primitive from a cache blob.
-     /// @param pd Primitive descriptor for a reduction primitive.
-     /// @param cache_blob Cache blob.
-     reduction(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
-         : primitive(pd, cache_blob) {}
- };
- /// @} dnnl_api_reduction
- /// @} dnnl_api_primitives
- /// @addtogroup dnnl_api_service Service
- ///
- /// A set of functions that aid in oneDNN debugging and profiling.
- ///
- /// @{
- /// @copydoc dnnl_version_t
- using version_t = dnnl_version_t; // C++ alias of the C type; dnnl::version() returns a pointer to it.
- /// Status values returned by the library functions. Each enumerator is initialized from the C API constant named in its @copydoc line.
- enum class status {
- /// @copydoc dnnl_success
- success = dnnl_success,
- /// @copydoc dnnl_out_of_memory
- out_of_memory = dnnl_out_of_memory,
- /// @copydoc dnnl_invalid_arguments
- invalid_arguments = dnnl_invalid_arguments,
- /// @copydoc dnnl_unimplemented
- unimplemented = dnnl_unimplemented,
- /// @copydoc dnnl_last_impl_reached
- last_impl_reached = dnnl_last_impl_reached,
- /// @copydoc dnnl_runtime_error
- runtime_error = dnnl_runtime_error,
- /// @copydoc dnnl_not_required
- not_required = dnnl_not_required,
- };
- /// @copydoc dnnl_set_verbose()
- /// @note The C status code is returned converted to dnnl::status; this
- ///     wrapper does not pass it through error::wrap_c_api.
- inline status set_verbose(int level) {
-     const auto c_status = dnnl_set_verbose(level);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_version()
- inline const version_t *version() {
-     // Forward the pointer obtained from the C API unchanged.
-     const version_t *info = dnnl_version();
-     return info;
- }
- /// Returns the floating-point math mode that will be used by default
- /// for all subsequently created primitives.
- ///
- /// The C API status is passed to error::wrap_c_api, so a failure is
- /// reported through it rather than through the return value.
- ///
- /// @returns Output FP math mode.
- inline fpmath_mode get_default_fpmath_mode() {
-     dnnl_fpmath_mode_t c_mode;
-     const auto rc = dnnl_get_default_fpmath_mode(&c_mode);
-     error::wrap_c_api(rc, "could not get a default fpmath mode");
-     return static_cast<fpmath_mode>(c_mode);
- }
- /// @copydoc dnnl_set_default_fpmath_mode()
- /// @note The C status code is returned converted to dnnl::status.
- inline status set_default_fpmath_mode(fpmath_mode mode) {
-     const auto c_mode = convert_to_c(mode);
-     const auto c_status = dnnl_set_default_fpmath_mode(c_mode);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_set_jit_dump()
- /// @note The C status code is returned converted to dnnl::status.
- inline status set_jit_dump(int enable) {
-     const auto c_status = dnnl_set_jit_dump(enable);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_set_jit_profiling_flags()
- /// @note The C status code is returned converted to dnnl::status.
- inline status set_jit_profiling_flags(unsigned flags) {
-     const auto c_status = dnnl_set_jit_profiling_flags(flags);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_set_jit_profiling_jitdumpdir()
- /// @note The directory is handed to the C API as a NUL-terminated string
- ///     obtained from @p dir.
- inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
-     const char *c_dir = dir.c_str();
-     const auto c_status = dnnl_set_jit_profiling_jitdumpdir(c_dir);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_cpu_isa_t
- enum class cpu_isa { // Each enumerator is initialized from the same-named dnnl_cpu_isa_* C constant.
- /// @copydoc dnnl_cpu_isa_default
- isa_default = dnnl_cpu_isa_default,
- /// @copydoc dnnl_cpu_isa_sse41
- sse41 = dnnl_cpu_isa_sse41,
- /// @copydoc dnnl_cpu_isa_avx
- avx = dnnl_cpu_isa_avx,
- /// @copydoc dnnl_cpu_isa_avx2
- avx2 = dnnl_cpu_isa_avx2,
- /// @copydoc dnnl_cpu_isa_avx2_vnni
- avx2_vnni = dnnl_cpu_isa_avx2_vnni,
- /// @copydoc dnnl_cpu_isa_avx2_vnni_2
- avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2,
- /// @copydoc dnnl_cpu_isa_avx512_core
- avx512_core = dnnl_cpu_isa_avx512_core,
- /// @copydoc dnnl_cpu_isa_avx512_core_vnni
- avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
- /// @copydoc dnnl_cpu_isa_avx512_core_bf16
- avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
- /// @copydoc dnnl_cpu_isa_avx10_1_512
- avx10_1_512 = dnnl_cpu_isa_avx10_1_512,
- /// @copydoc dnnl_cpu_isa_avx512_core_fp16
- avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16,
- /// @copydoc dnnl_cpu_isa_avx10_1_512_amx
- avx10_1_512_amx = dnnl_cpu_isa_avx10_1_512_amx,
- /// @copydoc dnnl_cpu_isa_avx512_core_amx
- avx512_core_amx = dnnl_cpu_isa_avx512_core_amx,
- /// @copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16
- avx10_1_512_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16,
- /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16
- avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16,
- /// @copydoc dnnl_cpu_isa_avx10_2_512
- avx10_2_512 = dnnl_cpu_isa_avx10_2_512,
- /// @copydoc dnnl_cpu_isa_avx10_2_512_amx_2
- avx10_2_512_amx_2 = dnnl_cpu_isa_avx10_2_512_amx_2,
- };
- /// @copydoc dnnl_set_max_cpu_isa()
- /// @note The C status code is returned converted to dnnl::status.
- inline status set_max_cpu_isa(cpu_isa isa) {
-     const auto c_isa = static_cast<dnnl_cpu_isa_t>(isa);
-     const auto c_status = dnnl_set_max_cpu_isa(c_isa);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_get_effective_cpu_isa()
- inline cpu_isa get_effective_cpu_isa() {
-     // Convert the C API value to the C++ enumeration.
-     const auto c_isa = dnnl_get_effective_cpu_isa();
-     return static_cast<cpu_isa>(c_isa);
- }
- /// @copydoc dnnl_cpu_isa_hints_t
- enum class cpu_isa_hints { // Each enumerator is initialized from the same-named dnnl_cpu_isa_* C constant.
- /// @copydoc dnnl_cpu_isa_no_hints
- no_hints = dnnl_cpu_isa_no_hints,
- /// @copydoc dnnl_cpu_isa_prefer_ymm
- prefer_ymm = dnnl_cpu_isa_prefer_ymm,
- };
- /// @copydoc dnnl_set_cpu_isa_hints()
- /// @note The C status code is returned converted to dnnl::status.
- inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) {
-     const auto c_hints = static_cast<dnnl_cpu_isa_hints_t>(isa_hints);
-     const auto c_status = dnnl_set_cpu_isa_hints(c_hints);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_get_cpu_isa_hints()
- inline cpu_isa_hints get_cpu_isa_hints() {
-     // Convert the C API value to the C++ enumeration.
-     const auto c_hints = dnnl_get_cpu_isa_hints();
-     return static_cast<cpu_isa_hints>(c_hints);
- }
- /// @} dnnl_api_service
- #ifdef DNNL_EXPERIMENTAL_PROFILING
- /// @addtogroup dnnl_api_profiling Profiling
- /// @{
- /// Profiling data kind. Passed to dnnl::get_profiling_data to select what is queried.
- enum class profiling_data_kind {
- /// Undefined profiling data kind.
- undef = dnnl_profiling_data_kind_undef,
- /// Data kind to query an execution time in nanoseconds.
- time = dnnl_profiling_data_kind_time,
- };
- /// Resets a profiler's state.
- ///
- /// The C API status is passed to error::wrap_c_api, so a failure is
- /// reported through it; nothing is returned.
- ///
- /// @param stream Stream associated with the profiler.
- inline void reset_profiling(stream &stream) {
-     const auto rc = dnnl_reset_profiling(stream.get());
-     error::wrap_c_api(rc, "could not reset profiling");
- }
- /// Returns requested profiling data. The profiling data accumulates for each
- /// primitive execution. The size of the vector will be equal to the number
- /// of executions since the last `dnnl::reset_profiling` call.
- ///
- /// The profiling data can be reset by calling #dnnl::reset_profiling.
- ///
- /// @note
- ///     It is required to wait for all submitted primitives to complete
- ///     using #dnnl::stream::wait prior to querying profiling data.
- ///
- /// @param stream Stream that was used for executing a primitive that
- ///     is being profiled.
- /// @param data_kind Profiling data kind to query.
- ///
- /// @returns A vector with the requested profiling data. The vector is
- ///     empty when the library reports no entries.
- inline std::vector<uint64_t> get_profiling_data(
-         stream &stream, profiling_data_kind data_kind) {
-     const auto c_kind = static_cast<dnnl_profiling_data_kind_t>(data_kind);
-     // First query: a null data pointer is passed to obtain only the number
-     // of entries.
-     int num_entries = 0;
-     error::wrap_c_api(
-             dnnl_query_profiling_data(
-                     stream.get(), c_kind, &num_entries, nullptr),
-             "could not get number of entries for profiling data");
-     // A non-positive count yields an empty result. Checking `<= 0` (not
-     // only `== 0`) keeps a negative int from being converted to a huge
-     // unsigned vector size.
-     if (num_entries <= 0) return {};
-     std::vector<uint64_t> data(static_cast<size_t>(num_entries));
-     // Second query: fill the buffer sized from the first query.
-     error::wrap_c_api(
-             dnnl_query_profiling_data(
-                     stream.get(), c_kind, &num_entries, data.data()),
-             "could not get profiling data");
-     return data;
- }
- /// @} dnnl_api_profiling
- #endif
- /// @addtogroup dnnl_api_primitive_cache Primitive Cache
- ///
- /// A set of functions that provide primitive cache control.
- ///
- /// @{
- /// Returns the number of primitives that can be held in the primitive cache
- /// at the same time.
- ///
- /// The C API status is passed to error::wrap_c_api, so a failure is
- /// reported through it rather than through the return value.
- inline int get_primitive_cache_capacity() {
-     int capacity = 0;
-     const auto rc = dnnl_get_primitive_cache_capacity(&capacity);
-     error::wrap_c_api(rc, "could not get primitive cache capacity");
-     return capacity;
- }
- /// @copydoc dnnl_set_primitive_cache_capacity(int capacity)
- /// @note The C status is passed to error::wrap_c_api instead of being
- ///     returned.
- inline void set_primitive_cache_capacity(int capacity) {
-     const auto rc = dnnl_set_primitive_cache_capacity(capacity);
-     error::wrap_c_api(rc, "could not set primitive cache capacity");
- }
- /// @} dnnl_api_primitive_cache
- /// @addtogroup dnnl_api_blas BLAS functions
- ///
- /// A subset of Basic Linear Algebra (BLAS) functions that perform
- /// matrix-matrix multiplication.
- ///
- /// @{
- /// @copydoc dnnl_sgemm()
- /// @note All arguments are forwarded unchanged to the C API; the C status
- ///     code is returned converted to dnnl::status.
- inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
-         dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
-         const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
-     const auto c_status = dnnl_sgemm(
-             transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_gemm_u8s8s32()
- /// @note All arguments are forwarded unchanged to the C API; the C status
- ///     code is returned converted to dnnl::status.
- inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
-         dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
-         dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
-         float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
-     const auto c_status = dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N, K,
-             alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co);
-     return static_cast<status>(c_status);
- }
- /// @copydoc dnnl_gemm_s8s8s32()
- /// @note All arguments are forwarded unchanged to the C API; the C status
- ///     code is returned converted to dnnl::status.
- inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
-         dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
-         dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
-         float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
-     const auto c_status = dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N, K,
-             alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co);
-     return static_cast<status>(c_status);
- }
- /// @} dnnl_api_blas
- // implementation section
- /// @cond DO_NOT_DOCUMENT_THIS
- // Creates the primitive from a C API primitive descriptor and takes the
- // resulting handle via reset(). The creation status goes through
- // error::wrap_c_api.
- inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
-     dnnl_primitive_t c_primitive;
-     const auto rc = dnnl_primitive_create(&c_primitive, c_pd);
-     error::wrap_c_api(rc, "could not create a primitive");
-     reset(c_primitive);
- }
- // Creates the primitive from a C API primitive descriptor and a cache blob.
- // The blob is handed to the C API as a (size, pointer) pair; the creation
- // status goes through error::wrap_c_api.
- inline primitive::primitive(const_dnnl_primitive_desc_t c_pd,
-         const std::vector<uint8_t> &cache_blob) {
-     dnnl_primitive_t c_primitive;
-     const uint8_t *blob_bytes = cache_blob.data();
-     const size_t blob_size = cache_blob.size();
-     const auto rc = dnnl_primitive_create_from_cache_blob(
-             &c_primitive, c_pd, blob_size, blob_bytes);
-     error::wrap_c_api(rc, "could not create a primitive from a cache blob");
-     reset(c_primitive);
- }
- // Creates the primitive from a C++ primitive descriptor by delegating to the
- // constructor that takes the handle returned by pd.get().
- inline primitive::primitive(const primitive_desc &pd)
-     : primitive(pd.get()) {}
- // Creates the primitive from a C++ primitive descriptor and a cache blob by
- // delegating to the constructor that takes the handle returned by pd.get().
- inline primitive::primitive(const primitive_desc &pd,
-         const std::vector<uint8_t> &cache_blob)
-     : primitive(pd.get(), cache_blob) {}
- // Executes the primitive on the given stream. Each (index, memory) entry of
- // `args` is converted into a dnnl_exec_arg_t, and the resulting array is
- // passed to the C API. The execution status goes through
- // error::wrap_c_api.
- inline void primitive::execute(const stream &astream,
-         const std::unordered_map<int, memory> &args) const {
-     std::vector<dnnl_exec_arg_t> exec_args;
-     exec_args.reserve(args.size());
-     for (const auto &entry : args) {
-         // NOTE(review): get(true) presumably tolerates an empty memory
-         // object; confirm against memory::get.
-         exec_args.push_back({entry.first, entry.second.get(true)});
-     }
-     const int nargs = static_cast<int>(exec_args.size());
-     error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(), nargs,
-                               exec_args.data()),
-             "could not execute a primitive");
- }
- /// @endcond
- #undef DNNL_DEFINE_BITMASK_OPS
- } // namespace dnnl
- /// oneAPI namespace
- /// The oneAPI namespace.
- /// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace.
- namespace oneapi {
- // Note: the alias below is hidden when DOXYGEN_SHOULD_SKIP_THIS is defined because, without this guard, doxygen warns of a potentially recursive namespace.
- #ifndef DOXYGEN_SHOULD_SKIP_THIS
- /// oneDNN alias namespace
- namespace dnnl = ::dnnl;
- #endif
- } // namespace oneapi
- /// @} dnnl_api
- // NOLINTEND(readability-identifier-naming)
- #endif /* ONEAPI_DNNL_DNNL_HPP */
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|