dnnl.h 189 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /*******************************************************************************
  3. * Copyright 2016-2025 Intel Corporation
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. *******************************************************************************/
  17. /// @file
  18. /// C API
  19. #ifndef ONEAPI_DNNL_DNNL_H
  20. #define ONEAPI_DNNL_DNNL_H
  21. #include "oneapi/dnnl/dnnl_common.h"
  22. #include "oneapi/dnnl/dnnl_config.h"
  23. #include "oneapi/dnnl/dnnl_types.h"
  24. #include "oneapi/dnnl/dnnl_version.h"
  25. #ifdef __cplusplus
  26. extern "C" {
  27. #endif
  28. /// @addtogroup dnnl_api
  29. /// @{
  30. /// @addtogroup dnnl_api_primitives
  31. /// @{
  32. /// @addtogroup dnnl_api_primitives_common
  33. /// @{
  34. /// Changes the primitive descriptor to point to the next available
  35. /// implementation.
  36. ///
  37. /// @param primitive_desc A primitive descriptor to change.
  38. /// @returns #dnnl_success on success and a status describing the error
  39. /// otherwise.
  40. /// @returns #dnnl_last_impl_reached if no more implementations are available,
  41. /// in which case the primitive descriptor itself is kept unchanged.
  42. dnnl_status_t DNNL_API dnnl_primitive_desc_next_impl(
  43. dnnl_primitive_desc_t primitive_desc);
  44. /// Clones a primitive descriptor. The resulting primitive descriptor must be
  45. /// destroyed separately with #dnnl_primitive_desc_destroy().
  46. ///
  47. /// @param primitive_desc Output primitive descriptor (the clone).
  48. /// @param existing_primitive_desc Primitive descriptor to clone.
  49. /// @returns #dnnl_success on success and a status describing the error
  50. /// otherwise.
  51. dnnl_status_t DNNL_API dnnl_primitive_desc_clone(
  52. dnnl_primitive_desc_t *primitive_desc,
  53. const_dnnl_primitive_desc_t existing_primitive_desc);
  54. /// Returns a constant reference to the attributes of a primitive descriptor.
  55. ///
  56. /// @warning
  57. /// It is an error to destroy the resulting @p attr (for example, with #dnnl_primitive_attr_destroy()).
  58. ///
  59. /// @warning
  60. /// The lifetime of @p attr is the same as that of @p
  61. /// primitive_desc, so it is an error to use @p attr once @p
  62. /// primitive_desc has been destroyed.
  63. ///
  64. /// @param primitive_desc Primitive descriptor.
  65. /// @param attr Output primitive attributes.
  66. /// @returns #dnnl_success on success and a status describing the error
  67. /// otherwise.
  68. dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(
  69. const_dnnl_primitive_desc_t primitive_desc,
  70. const_dnnl_primitive_attr_t *attr);
  71. /// Destroys a primitive descriptor.
  72. ///
  73. /// @param primitive_desc Primitive descriptor to destroy. Attributes obtained from it via #dnnl_primitive_desc_get_attr() must not be used afterwards.
  74. /// @returns #dnnl_success on success and a status describing the error
  75. /// otherwise.
  76. dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(
  77. dnnl_primitive_desc_t primitive_desc);
  78. /// Queries a primitive descriptor for various pieces of information.
  79. ///
  80. /// The most common use case is to query a primitive descriptor, created with
  81. /// source, weights, and destination memory descriptors with format tags set
  82. /// to #dnnl_format_tag_any, for the corresponding memory descriptors (in this
  83. /// case @p what is set to #dnnl_query_src_md, #dnnl_query_weights_md, and
  84. /// #dnnl_query_dst_md respectively) so that it is possible to create memory
  85. /// objects and reorder primitives if necessary.
  86. ///
  87. /// Another typical use case is to query a primitive descriptor for the workspace
  88. /// memory descriptor (with @p what set to #dnnl_query_workspace_md). If this
  89. /// query returns #dnnl_not_required status, then workspace memory is not
  90. /// required.
  91. ///
  92. /// @note
  93. /// When querying for a memory descriptor for a scratchpad, a workspace,
  94. /// or an optional parameter, the query will return a pointer to a zero
  95. /// memory descriptor if the parameter is not needed.
  96. ///
  97. /// A few other use cases:
  98. /// - query a primitive descriptor for the implementation information string
  99. /// (#dnnl_query_impl_info_str)
  100. /// - query a primitive descriptor for the number of inputs and outputs
  101. /// (#dnnl_query_num_of_inputs_s32 and #dnnl_query_num_of_outputs_s32
  102. /// respectively)
  103. ///
  104. /// @sa dnnl_query_t for more options
  105. ///
  106. /// @param primitive_desc Primitive descriptor.
  107. /// @param what Parameter to query.
  108. /// @param index Index of the parameter to query for.
  109. /// @param result Output result. The type depends on the query. For example,
  110. /// it must be a @c dnnl_memory_desc_t* if querying for a memory
  111. /// descriptor.
  112. /// @returns #dnnl_success on success and a status describing the error
  113. /// otherwise.
  114. dnnl_status_t DNNL_API dnnl_primitive_desc_query(
  115. const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
  116. int index, void *result);
  117. /// Queries a primitive descriptor for a memory descriptor.
  118. ///
  119. /// @note
  120. /// This function is a convenience version of
  121. /// #dnnl_primitive_desc_query().
  122. ///
  123. /// @param primitive_desc Primitive descriptor.
  124. /// @param what Kind of memory descriptor parameter to query for.
  125. /// @param index Index of the parameter to query.
  126. /// @returns A pointer to the requested memory descriptor.
  127. /// @returns A pointer to a zero memory descriptor if the parameter is not
  128. /// needed.
  129. /// @returns NULL in case of any error.
  130. ///
  131. const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md(
  132. const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
  133. int index);
  134. /// Queries a primitive descriptor for a signed 32-bit integer value.
  135. ///
  136. /// @note
  137. /// This function is a convenience version of
  138. /// #dnnl_primitive_desc_query().
  139. ///
  140. /// @param primitive_desc Primitive descriptor.
  141. /// @param what Kind of the value to query for.
  142. /// @param index Index of the parameter to query.
  143. /// @returns The requested value.
  144. /// @returns 0 in case of any error (in particular if the queried entity is
  145. /// not of type int32_t). Note that 0 may also be a legitimate value, so a
  146. /// 0 result alone does not indicate an error.
  147. int DNNL_API dnnl_primitive_desc_query_s32(
  148. const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
  149. int index);
  150. /// Creates a primitive.
  151. ///
  152. /// @param primitive Output primitive.
  153. /// @param primitive_desc Primitive descriptor used to create the primitive.
  154. /// @returns #dnnl_success on success and a status describing the error
  155. /// otherwise.
  156. dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive,
  157. const_dnnl_primitive_desc_t primitive_desc);
  158. /// Creates a primitive from a cache blob.
  159. ///
  160. /// @param primitive Output primitive.
  161. /// @param primitive_desc Primitive descriptor used to create the primitive.
  162. /// @param size Size of the cache blob in bytes.
  163. /// @param cache_blob Cache blob of size @p size, as retrieved with #dnnl_primitive_get_cache_blob().
  164. /// @returns #dnnl_success on success and a status describing the error
  165. /// otherwise.
  166. dnnl_status_t DNNL_API dnnl_primitive_create_from_cache_blob(
  167. dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc,
  168. size_t size, const uint8_t *cache_blob);
  169. /// Executes a primitive.
  170. ///
  171. /// @param primitive Primitive to execute.
  172. /// @param stream Stream to use.
  173. /// @param nargs Number of arguments, i.e. the number of elements in @p args.
  174. /// @param args Array of arguments. Each argument is an
  175. /// <index, #dnnl_memory_t> pair. The index is one of the `DNNL_ARG_*`
  176. /// values such as `DNNL_ARG_SRC`. Unless runtime shapes are used (see
  177. /// #DNNL_RUNTIME_DIM_VAL), the memory object must have the same memory
  178. /// descriptor as that returned by
  179. /// #dnnl_primitive_desc_query_md(primitive_desc, #dnnl_query_exec_arg_md, index), where primitive_desc is the primitive descriptor of @p primitive.
  180. /// @returns #dnnl_success on success and a status describing the error
  181. /// otherwise.
  182. /// @note If any argument in @p args is padded (padded_dims >
  183. /// dims), the primitive execution will assume properly zero-padded
  184. /// input arguments and produce zero-padded output arguments.
  185. dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive,
  186. dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args);
  187. /// Retrieves a constant reference to the primitive descriptor of a given
  188. /// primitive.
  189. ///
  190. /// @warning
  191. /// It is an error to destroy the returned object. It is owned by the
  192. /// primitive. The @c const qualifier of the returned object prevents
  193. /// such attempts, since #dnnl_primitive_desc_destroy() takes a non-const descriptor.
  194. ///
  195. /// @param primitive Primitive to query for the primitive descriptor.
  196. /// @param primitive_desc Output primitive descriptor.
  197. /// @returns #dnnl_success on success and a status describing the error
  198. /// otherwise.
  199. dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(
  200. const_dnnl_primitive_t primitive,
  201. const_dnnl_primitive_desc_t *primitive_desc);
  202. /// Retrieves a cache blob associated with the given primitive.
  203. ///
  204. /// @param primitive Primitive to query for the cache blob.
  205. /// @param size Pointer to the size of the cache blob in bytes.
  206. /// @param cache_blob Cache blob of size @p size. If @p cache_blob is NULL,
  207. /// the size of the cache blob is returned in @p size.
  208. /// @returns #dnnl_success on success and a status describing the error
  209. /// otherwise.
  210. ///
  211. /// @note The cache blob can be empty. It's the user's responsibility to check
  212. /// whether it's empty prior to passing it to
  213. /// #dnnl_primitive_create_from_cache_blob().
  214. dnnl_status_t DNNL_API dnnl_primitive_get_cache_blob(
  215. const_dnnl_primitive_t primitive, size_t *size, uint8_t *cache_blob);
  216. /// Destroys a primitive.
  217. ///
  218. /// @param primitive The primitive to destroy.
  219. /// @returns #dnnl_success on success and a status describing the error
  220. /// otherwise.
  221. dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive);
  222. /// @} dnnl_api_primitives_common
  223. /// @addtogroup dnnl_api_attributes
  224. /// @{
  225. /// Creates an empty (default) primitive attributes object with all
  226. /// parameters set to their default values.
  227. ///
  228. /// Empty attributes are implied whenever the respective argument is NULL.
  229. ///
  230. /// @param attr Output primitive attributes.
  231. /// @returns #dnnl_success on success and a status describing the error
  232. /// otherwise.
  233. dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr);
  234. /// Clones primitive attributes.
  235. ///
  236. /// @param attr Output primitive attributes (the clone).
  237. /// @param existing_attr Primitive attributes to clone.
  238. /// @returns #dnnl_success on success and a status describing the error
  239. /// otherwise.
  240. dnnl_status_t DNNL_API dnnl_primitive_attr_clone(
  241. dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr);
  242. /// Destroys primitive attributes.
  243. ///
  244. /// @param attr Primitive attributes to destroy. Attributes returned by #dnnl_primitive_desc_get_attr() must not be destroyed.
  245. /// @returns #dnnl_success on success and a status describing the error
  246. /// otherwise.
  247. dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr);
  248. /// Returns the memory descriptor of the output dropout primitive attribute.
  249. ///
  250. /// @param attr Primitive attributes.
  251. /// @param dropout_desc Output: the dropout memory descriptor stored in @p attr.
  252. /// @returns #dnnl_success on success and a status describing the error
  253. /// otherwise.
  254. dnnl_status_t DNNL_API dnnl_primitive_attr_get_dropout(
  255. const_dnnl_primitive_attr_t attr,
  256. const_dnnl_memory_desc_t *dropout_desc);
  257. /// Sets the output dropout primitive attribute from a memory descriptor.
  258. ///
  259. /// @param attr Primitive attributes.
  260. /// @param dropout_desc Memory descriptor for the output dropout (an input to this call).
  261. /// @returns #dnnl_success on success and a status describing the error
  262. /// otherwise.
  263. dnnl_status_t DNNL_API dnnl_primitive_attr_set_dropout(
  264. dnnl_primitive_attr_t attr, const_dnnl_memory_desc_t dropout_desc);
  265. /// Returns the floating-point math mode primitive attribute.
  266. ///
  267. /// @param attr Primitive attributes.
  268. /// @param mode Output FP math mode.
  269. /// @returns #dnnl_success on success and a status describing the error
  270. /// otherwise.
  271. dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode(
  272. const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode);
  273. /// Sets the floating-point math mode primitive attribute.
  274. ///
  275. /// @param attr Primitive attributes.
  276. /// @param mode FP math mode. The possible values are:
  277. /// #dnnl_fpmath_mode_strict (default),
  278. /// #dnnl_fpmath_mode_bf16,
  279. /// #dnnl_fpmath_mode_f16,
  280. /// #dnnl_fpmath_mode_tf32,
  281. /// #dnnl_fpmath_mode_any.
  282. /// @returns #dnnl_success on success and a status describing the error
  283. /// otherwise.
  284. dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode(
  285. dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode);
  286. /// Returns the floating-point math mode primitive attribute.
  287. ///
  288. /// @param attr Primitive attributes.
  289. /// @param mode Output FP math mode.
  290. /// @param apply_to_int Output boolean flag: whether floating-point arithmetic is used for integer primitives.
  291. /// @returns #dnnl_success on success and a status describing the error
  292. /// otherwise.
  293. dnnl_status_t DNNL_API dnnl_primitive_attr_get_fpmath_mode_v2(
  294. const_dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t *mode,
  295. int *apply_to_int);
  296. /// Sets the floating-point math mode primitive attribute.
  297. ///
  298. /// @param attr Primitive attributes.
  299. /// @param mode FP math mode. The possible values are:
  300. /// #dnnl_fpmath_mode_strict (default),
  301. /// #dnnl_fpmath_mode_bf16,
  302. /// #dnnl_fpmath_mode_f16,
  303. /// #dnnl_fpmath_mode_tf32,
  304. /// #dnnl_fpmath_mode_any.
  305. /// @param apply_to_int Boolean flag: whether to use floating-point arithmetic for integer primitives.
  306. /// @returns #dnnl_success on success and a status describing the error
  307. /// otherwise.
  308. dnnl_status_t DNNL_API dnnl_primitive_attr_set_fpmath_mode_v2(
  309. dnnl_primitive_attr_t attr, dnnl_fpmath_mode_t mode, int apply_to_int);
  310. /// Returns the deterministic primitive attribute value.
  311. ///
  312. /// @param attr Primitive attributes.
  313. /// @param value Output deterministic attribute value (boolean).
  314. /// @returns #dnnl_success on success and a status describing the error
  315. /// otherwise.
  316. dnnl_status_t DNNL_API dnnl_primitive_attr_get_deterministic(
  317. const_dnnl_primitive_attr_t attr, int *value);
  318. /// Sets the deterministic primitive attribute value.
  319. ///
  320. /// @param attr Primitive attributes.
  321. /// @param value Boolean value to set the deterministic attribute to.
  322. /// @returns #dnnl_success on success and a status describing the error
  323. /// otherwise.
  324. dnnl_status_t DNNL_API dnnl_primitive_attr_set_deterministic(
  325. dnnl_primitive_attr_t attr, int value);
  326. /// Returns the accumulation mode primitive attribute.
  327. ///
  328. /// @param attr Primitive attributes.
  329. /// @param mode Output accumulation mode.
  330. /// @returns #dnnl_success on success and a status describing the error
  331. /// otherwise.
  332. dnnl_status_t DNNL_API dnnl_primitive_attr_get_accumulation_mode(
  333. const_dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t *mode);
  334. /// Sets the accumulation mode primitive attribute.
  335. ///
  336. /// @param attr Primitive attributes.
  337. /// @param mode Accumulation mode. The possible values are:
  338. /// #dnnl_accumulation_mode_strict (default), where accumulators are s32 for quantized primitives and f32/f64 otherwise,
  339. /// #dnnl_accumulation_mode_relaxed, which is the same as strict but allows intermediate accumulators to be in the src/dst data type,
  340. /// #dnnl_accumulation_mode_any, which allows accumulators to be in the src/dst data type or any wider type,
  341. /// #dnnl_accumulation_mode_f32,
  342. /// #dnnl_accumulation_mode_s32,
  343. /// #dnnl_accumulation_mode_f16.
  344. /// @returns #dnnl_success on success and a status describing the error
  345. /// otherwise.
  346. dnnl_status_t DNNL_API dnnl_primitive_attr_set_accumulation_mode(
  347. dnnl_primitive_attr_t attr, dnnl_accumulation_mode_t mode);
  348. /// Returns the primitive attributes scratchpad mode.
  349. ///
  350. /// @param attr Primitive attributes.
  351. /// @param mode Output scratchpad mode.
  352. /// @returns #dnnl_success on success and a status describing the error
  353. /// otherwise.
  354. dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(
  355. const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode);
  356. /// Sets primitive attributes scratchpad mode.
  357. ///
  358. /// @param attr Primitive attributes.
  359. /// @param mode Scratchpad mode. The possible values are:
  360. /// #dnnl_scratchpad_mode_library (default) and
  361. /// #dnnl_scratchpad_mode_user.
  362. /// @returns #dnnl_success on success and a status describing the error
  363. /// otherwise.
  364. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(
  365. dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode);
  366. /// Sets primitive attributes scaling factors for primitive operations for a
  367. /// given memory argument. The scaling factors must be passed at execution time
  368. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  369. ///
  370. /// @sa dnnl_primitive_attr_set_scales_mask
  371. ///
  372. ///
  373. /// @param attr Primitive attributes.
  374. /// @param arg Parameter argument index as passed to the
  375. /// dnnl_primitive_execute() call.
  376. /// @param mask Scaling factors correspondence mask that defines the
  377. /// correspondence between the tensor dimensions and the @p scales array.
  378. /// The set i-th bit indicates that a dedicated scaling factor is used for
  379. /// each index along that dimension. Set the mask to 0 to use a common
  380. /// scaling factor for the whole output tensor.
  381. /// @returns #dnnl_success on success and a status describing the error
  382. /// otherwise.
  383. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_mask(
  384. dnnl_primitive_attr_t attr, int arg, int mask);
  385. /// Sets primitive attributes scaling factors for primitive operations for a
  386. /// given memory argument. The scaling factors must be passed at execution time
  387. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  388. ///
  389. /// @sa dnnl_primitive_attr_set_scales
  390. ///
  391. ///
  392. /// @param attr Primitive attributes.
  393. /// @param arg Parameter argument index as passed to the
  394. /// dnnl_primitive_execute() call.
  395. /// @param mask Scaling factors correspondence mask that defines the
  396. /// correspondence between the tensor dimensions and the @p scales array.
  397. /// The set i-th bit indicates that a dedicated scaling factor is used for
  398. /// each index along that dimension. Set the mask to 0 to use a common
  399. /// scaling factor for the whole output tensor.
  400. /// @param group_ndims Number of group dimensions.
  401. /// @param group_dims Scaling factors correspondence groups that define the
  402. /// correspondence between the tensor dimensions and the scales array.
  403. /// The group dimensions should only be provided for each logical dimension
  404. /// that has correspondence mask @p mask set.
  405. /// @param data_type Scaling factors data_type.
  406. /// @returns #dnnl_success on success and a status describing the error
  407. /// otherwise.
  408. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(
  409. dnnl_primitive_attr_t attr, int arg, int mask, int group_ndims,
  410. const dnnl_dims_t group_dims, dnnl_data_type_t data_type);
  411. /// Sets primitive attributes scaling factors for primitive operations for a
  412. /// given memory argument. The scaling factors must be passed at execution time
  413. /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
  414. /// If `is_on_host` is true, sets a single host-side scalar scaling factor
  415. /// for the specified memory argument. In this case, the scaling factor must
  416. /// be provided as a host scalar memory object at execution time with index
  417. /// #DNNL_ARG_ATTR_SCALES | arg.
  418. ///
  419. /// @sa dnnl_primitive_attr_set_scales
  420. ///
  421. ///
  422. /// @param attr Primitive attributes.
  423. /// @param arg Parameter argument index as passed to the
  424. /// dnnl_primitive_execute() call.
  425. /// @param mask Scaling factors correspondence mask that defines the
  426. /// correspondence between the tensor dimensions and the @p scales array.
  427. /// The set i-th bit indicates that a dedicated scaling factor is used for
  428. /// each index along that dimension. Set the mask to 0 to use a common
  429. /// scaling factor for the whole output tensor.
  430. /// @param ndims Number of group dimensions.
  431. /// @param group_dims Scaling factors correspondence groups that define the
  432. /// correspondence between the tensor dimensions and the scales array.
  433. /// The group dimensions should only be provided for each logical dimension
  434. /// that has correspondence mask @p mask set.
  435. /// @param data_type Scaling factors data_type.
  436. /// @param is_on_host Indicates whether the zero point is a host-side scalar.
  437. /// @returns #dnnl_success on success and a status describing the error
  438. /// otherwise.
  439. dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_v2(
  440. dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
  441. const dnnl_dims_t group_dims, dnnl_data_type_t data_type,
  442. int is_on_host);
  443. /// Sets primitive attributes zero points for primitive operations for a given
  444. /// memory argument. The zero points must be passed at execution time
  445. /// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  446. ///
  447. /// @param attr Primitive attributes.
  448. /// @param arg Parameter argument index as passed to the
  449. /// dnnl_primitive_execute() call.
  450. /// @param mask Zero point correspondence mask that defines the
  451. /// correspondence between the tensor dimensions and the @p
  452. /// zero_points array. The set i-th bit indicates that a dedicated
  453. /// zero point is used for each index along that dimension. Set the
  454. /// mask to 0 to use a common zero point for the whole output tensor.
  455. /// @returns #dnnl_success on success and a status describing the error
  456. /// otherwise.
  457. dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_mask(
  458. dnnl_primitive_attr_t attr, int arg, int mask);
  459. /// Sets primitive attributes zero points for primitive operations for a given
  460. /// memory argument. The zero points must be passed at execution time
  461. /// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  462. ///
  463. /// @sa dnnl_primitive_attr_set_zero_points
  464. ///
  465. ///
  466. /// @param attr Primitive attributes.
  467. /// @param arg Parameter argument index as passed to the
  468. /// dnnl_primitive_execute() call.
  469. /// @param mask Zero point correspondence mask that defines the
  470. /// correspondence between the tensor dimensions and the
  471. /// zero points array. The set i-th bit indicates that a dedicated
  472. /// zero point is used for each index along that dimension. Set the
  473. /// mask to 0 to use a common zero point for the whole output tensor.
  474. /// @param group_ndims Number of group dimensions.
  475. /// @param group_dims Zero point factors correspondence groups that define the
  476. /// correspondence between the tensor dimensions and the zero points array.
  477. /// The group dimensions should be only provided for each logical dimension
  478. /// that has the bit set correspondence mask @p mask set.
  479. /// @param data_type Zero points factors data_type.
  480. /// @returns #dnnl_success on success and a status describing the error
  481. /// otherwise.
  482. dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(
  483. dnnl_primitive_attr_t attr, int arg, int mask, int group_ndims,
  484. const dnnl_dims_t group_dims, dnnl_data_type_t data_type);
  485. /// Sets primitive attributes precomputed reductions for primitive operations
  486. /// for a given memory argument. The precomputed reductions must be passed at
  487. /// execution time as an argument with index
  488. /// #DNNL_ARG_ATTR_PRECOMPUTED_REDUCTIONS | arg.
  489. ///
  490. /// @sa dnnl_primitive_attr_set_precomputed_reductions
  491. ///
  492. ///
  493. /// @param attr Primitive attributes.
  494. /// @param arg Parameter argument index as passed to the
  495. /// dnnl_primitive_execute() call.
  496. /// @param mask Precomputed reductions correspondence mask that defines the
  497. /// correspondence between the tensor dimensions and the precomputed
  498. /// reductions array. The set i-th bit indicates that a dedicated
  499. /// precomputed reductions is used for each index along that dimension.
  500. /// @param group_ndims Number of group dimensions.
  501. /// @param group_dims Precomputed reduction factors correspondence groups that
  502. /// define the correspondence between the tensor dimensions and the
  503. /// precomputed reductions array.
  504. /// The group dimensions should be only provided for each logical dimension
  505. /// that has the bit set correspondence mask @p mask set.
  506. /// @param data_type Precomputed reduction factors data_type.
  507. /// @returns #dnnl_success on success and a status describing the error
  508. /// otherwise.
  509. dnnl_status_t DNNL_API dnnl_primitive_attr_set_precomputed_reductions(
  510. dnnl_primitive_attr_t attr, int arg, int mask, int group_ndims,
  511. const dnnl_dims_t group_dims, dnnl_data_type_t data_type);
  512. /// Sets primitive attributes zero points for primitive operations for a given
  513. /// memory argument. The zero points must be passed at execution time
  514. /// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  515. /// If `is_on_host` is true, sets a single host-side scalar zero point
  516. /// for the specified memory argument. In this case, the zero point must
  517. /// be provided as a host scalar memory object at execution time with index
  518. /// #DNNL_ARG_ATTR_ZERO_POINTS | arg.
  519. ///
  520. /// @sa dnnl_primitive_attr_set_zero_points
  521. ///
  522. /// @param attr Primitive attributes.
  523. /// @param arg Parameter argument index as passed to the
  524. /// dnnl_primitive_execute() call.
  525. /// @param mask Zero point correspondence mask that defines the
  526. /// correspondence between the tensor dimensions and the @p
  527. /// zero_points array. The set i-th bit indicates that a dedicated
  528. /// zero point is used for each index along that dimension. Set the
  529. /// mask to 0 to use a common zero point for the whole output tensor.
  530. /// @param ndims Number of group dimensions.
  531. /// @param group_dims Zero point factors correspondence groups that define the
  532. /// correspondence between the tensor dimensions and the zero_points array.
  533. /// The group dimensions should be only provided for each logical dimension
  534. /// that has the bit set correspondence mask @p mask set.
  535. /// @param data_type Zero points factors data_type.
  536. /// @param is_on_host Indicates whether the zero point is a host-side scalar.
  537. /// @returns #dnnl_success on success and a status describing the error
  538. /// otherwise.
  539. dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points_v2(
  540. dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
  541. const dnnl_dims_t group_dims, dnnl_data_type_t data_type,
  542. int is_on_host);
  543. /// Sets the rounding mode attribute value for a given argument
  544. ///
  545. /// @param attr Primitive attributes.
  546. /// @param arg Argument for which rounding mode should be set.
  547. /// @param mode Rounding mode to apply to the argument.
  548. /// @returns #dnnl_success on success and a status describing the error
  549. /// otherwise.
  550. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rounding(
  551. dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t mode);
  552. /// Returns the rounding mode attribute value for a given argument
  553. ///
  554. /// @param attr Primitive attributes.
  555. /// @param arg Argument for which rounding mode query applies.
  556. /// @param mode Output rounding mode.
  557. /// @returns #dnnl_success on success and a status describing the error
  558. /// otherwise.
  559. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rounding(
  560. dnnl_primitive_attr_t attr, int arg, dnnl_rounding_mode_t *mode);
  561. /// Returns primitive attributes post-ops.
  562. ///
  563. /// @warning
  564. /// The output @p post_ops points to the internal @p attr field, so it is
  565. /// an error to modify or destroy them. The lifetime of @p post_ops is
  566. /// the same as that of the @p attr it belongs to, so it is an error to
  567. /// use @p post_ops after @p attr has been destroyed.
  568. ///
  569. /// @param attr Primitive attributes.
  570. /// @param post_ops Output post-ops.
  571. /// @returns #dnnl_success on success and a status describing the error
  572. /// otherwise.
  573. dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(
  574. const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops);
  575. /// Sets primitive attributes post-ops.
  576. ///
  577. /// @note
  578. /// There is no way to check whether the post-ops would be supported by
  579. /// the target primitive. Any error will be reported by the
  580. /// dnnl_<primitive name>_[propagation kind]_primitive_desc_create() function call.
  581. ///
  582. /// @param attr Primitive attributes.
  583. /// @param post_ops Post-ops to set.
  584. /// @returns #dnnl_success on success and a status describing the error
  585. /// otherwise.
  586. dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(
  587. dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops);
  588. /// Creates empty post-ops sequence.
  589. ///
  590. /// @param post_ops Output post-ops.
  591. /// @returns #dnnl_success on success and a status describing the error
  592. /// otherwise.
  593. dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops);
  594. /// Clones post-ops primitive attribute.
  595. ///
  596. /// @param post_ops Output post-ops primitive attribute.
  597. /// @param existing_post_ops Post-ops primitive attribute to clone.
  598. /// @returns #dnnl_success on success and a status describing the error
  599. /// otherwise.
  600. dnnl_status_t DNNL_API dnnl_post_ops_clone(
  601. dnnl_post_ops_t *post_ops, const_dnnl_post_ops_t existing_post_ops);
  602. /// Destroys post-ops.
  603. ///
  604. /// @param post_ops Post-ops to destroy.
  605. /// @returns #dnnl_success on success and a status describing the error
  606. /// otherwise.
  607. dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops);
  608. /// Returns the length of post-ops.
  609. ///
  610. /// @param post_ops Post-ops.
  611. /// @returns The number of post-ops entries.
  612. int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops);
  613. /// Returns the kind of a post-op entry.
  614. ///
  615. /// @param post_ops Post-ops.
  616. /// @param index Post-op entry index.
  617. /// @returns The kind of the post-op with the specified index.
  618. /// @returns #dnnl_undefined_primitive if there is no post-op at the specified
  619. /// index.
  620. dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(
  621. const_dnnl_post_ops_t post_ops, int index);
  622. /// Appends an accumulation v3 (sum) to post-ops. Prior to accumulating the
  623. /// result, a zero point is subtracted from the previous value and is
  624. /// multiplied by the scale.
  625. ///
  626. /// The kind of this post-op is #dnnl_sum.
  627. ///
  628. /// This feature may improve performance for cases like dequantize the
  629. /// asymmetrically quantized sum's src1 tensor to f32 domain before performing
  630. /// the sum operation by subtracting the @p zero_point before the scaling.
  631. ///
  632. /// In the simplest case where accumulation is the only post-op, the
  633. /// computations will be:
  634. ///
  635. /// dst[:] <- scale * (dst[:] - zero_point) + op(...)
  636. /// // instead of dst[:] <- op(...)
  637. ///
  638. /// If @p data_type is specified, original dst tensor will be reinterpreted
  639. /// as a tensor with provided data type. Since it is reinterpretation,
  640. /// data_type and dst data type should have the same size.
  641. /// As a result, computations will be:
  642. ///
  643. /// dst[:] <- scale * (as_data_type(dst[:]) - zero_point) + op(...)
  644. /// // instead of dst[:] <- op(...)
  645. /// @note
  646. /// This post-op executes in-place and does not change the
  647. /// destination layout.
  648. ///
  649. /// @param post_ops Post-ops.
  650. /// @param scale Accumulation scaling factor.
  651. /// @param zero_point Single scalar int32_t value of zero point.
  652. /// @param data_type Accumulation data_type.
  653. /// @returns #dnnl_success on success and a status describing the error
  654. /// otherwise.
  655. dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops,
  656. float scale, int32_t zero_point, dnnl_data_type_t data_type);
  657. /// Returns the parameters of an accumulation (sum) post-op with
  658. /// zero point and data type parameter.
  659. ///
  660. /// @param post_ops Post-ops.
  661. /// @param index Index of the sum post-op.
  662. /// @param scale Output accumulation scaling factor.
  663. /// @param zero_point Zero point.
  664. /// @param data_type Data type for accumulation.
  665. /// @returns #dnnl_success on success and a status describing the error
  666. /// otherwise.
  667. dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(
  668. const_dnnl_post_ops_t post_ops, int index, float *scale,
  669. int32_t *zero_point, dnnl_data_type_t *data_type);
  670. /// Appends an elementwise post-op.
  671. ///
  672. /// The kind of this post operation is #dnnl_eltwise.
  673. ///
  674. /// In the simplest case when the elementwise is the only post operation, the
  675. /// computations would be:
  676. ///
  677. /// dst[:] <- eltwise_op (op(...)) // instead of dst[:] <- op(...)
  678. ///
  679. /// where eltwise_op is configured with the given parameters.
  680. ///
  681. /// @param post_ops Post-ops.
  682. /// @param alg_kind Elementwise algorithm for the post-op.
  683. /// @param alpha Alpha parameter for the elementwise algorithm.
  684. /// @param beta Beta parameter for the elementwise algorithm.
  685. /// @returns #dnnl_success on success and a status describing the error
  686. /// otherwise.
  687. dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops,
  688. dnnl_alg_kind_t alg_kind, float alpha, float beta);
  689. /// Returns the parameters of an elementwise post-op.
  690. ///
  691. /// @param post_ops Post-ops.
  692. /// @param index Index of the elementwise post-op.
  693. /// @param alg_kind Output elementwise algorithm kind.
  694. /// @param alpha Output alpha parameter for the elementwise algorithm.
  695. /// @param beta Output beta parameter for the elementwise algorithm.
  696. /// @returns #dnnl_success on success and a status describing the error
  697. /// otherwise.
  698. /// @returns #dnnl_invalid_arguments if @p index does not refer to an
  699. /// elementwise post-op.
  700. dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(
  701. const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
  702. float *alpha, float *beta);
  703. /// Appends a depthwise post-op convolution.
  704. ///
  705. /// This post-op can only be fused with a 2D 1x1 convolution (convolution with
  706. /// weights spatial dimensions equal to 1 i.e., kh=kw=1).
  707. ///
  708. /// The kind of this post-op is #dnnl_convolution.
  709. ///
  710. /// The number of outputs for primitive with fusion is one. The output spatial
  711. /// size can be derived as below:
  712. ///
  713. /// output_height = ceil(output_height_1x1_convolution, stride)
  714. /// output_width = ceil(output_width_1x1_convolution, stride)
  715. ///
  716. /// See @ref dev_guide_attributes_post_ops_depthwise and
  717. /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
  718. ///
  719. /// @param post_ops Post-ops.
  720. /// @param weights_data_type Weights data type of depthwise post-op
  721. /// @param bias_data_type Bias data type of depthwise post-op
  722. /// @param dst_data_type Output data type of depthwise post-op
  723. /// @param kernel_size Size of kernel of depthwise post-op
  724. /// @param stride_size Size of stride of depthwise post-op
  725. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  726. /// @returns #dnnl_success on success and a status describing the error
  727. /// otherwise
  728. dnnl_status_t DNNL_API dnnl_post_ops_append_dw(dnnl_post_ops_t post_ops,
  729. dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type,
  730. dnnl_data_type_t dst_data_type, dnnl_dim_t kernel_size,
  731. dnnl_dim_t stride_size, dnnl_dim_t padding_l_size);
  732. /// Returns the parameters of an depthwise post-op.
  733. ///
  734. /// @param post_ops Post-ops.
  735. /// @param index Index of the elementwise post-op.
  736. /// @param weights_data_type Weights data type of depthwise post-op
  737. /// @param bias_data_type Bias data type of depthwise post-op
  738. /// @param dst_data_type Output data type of depthwise post-op
  739. /// @param kernel_size Size of kernel of depthwise post-op
  740. /// @param stride_size Size of stride of depthwise post-op
  741. /// @param padding_l_size Size of left and top paddings of depthwise post-op
  742. /// @returns #dnnl_success on success and a status describing the error
  743. /// otherwise
  744. dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
  745. const_dnnl_post_ops_t post_ops, int index,
  746. dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type,
  747. dnnl_data_type_t *dst_data_type, dnnl_dim_t *kernel_size,
  748. dnnl_dim_t *stride_size, dnnl_dim_t *padding_l_size);
  749. /// Appends a binary post-op.
  750. ///
  751. /// This post operation is categorized as #dnnl_binary.
  752. ///
  753. /// In the simplest case when the binary is the only post operation, the
  754. /// computations would be:
  755. ///
  756. /// dst[:] <- binary_op (dst[:], another_input[:])
  757. ///
  758. /// where binary_op is configured with the given parameters. binary_op supports
  759. /// broadcast semantics for a second operand.
  760. ///
  761. /// @param post_ops Post-ops.
  762. /// @param alg_kind Binary algorithm for the post-op.
  763. /// @param src1_desc Memory descriptor of a second operand.
  764. /// @returns #dnnl_success on success and a status describing the error
  765. /// otherwise.
  766. dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
  767. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc);
  768. /// Appends a binary post-op with ternary operators.
  769. ///
  770. /// This post operation is categorized as #dnnl_binary.
  771. ///
  772. /// In the simplest case when the binary is the only post operation, the
  773. /// computations will be:
  774. ///
  775. /// dst[:] <- binary_op (dst[:], another_input1[:], another_input2[:])
  776. ///
  777. /// where binary_op is configured with the given parameters. binary_op supports
  778. /// broadcast semantics only for the second operand and not for the third
  779. /// operand.
  780. ///
  781. /// @param post_ops Post-ops.
  782. /// @param alg_kind Binary algorithm for the post-op.
  783. /// @param src1_desc Memory descriptor of a second operand.
  784. /// @param src2_desc Memory descriptor of a third operand. If the specificed
  785. /// algorithm is not one that requires a ternary input, src2_desc will be
  786. /// ignored.
  787. /// @returns #dnnl_success on success and a status describing the error
  788. /// otherwise.
  789. dnnl_status_t DNNL_API dnnl_post_ops_append_binary_v2(dnnl_post_ops_t post_ops,
  790. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc,
  791. const_dnnl_memory_desc_t src2_desc);
  792. /// Returns the parameters of a binary post-op.
  793. ///
  794. /// @param post_ops Post-ops.
  795. /// @param index Index of the binary post-op.
  796. /// @param alg_kind Output binary algorithm kind.
  797. /// @param src1_desc Output memory descriptor of a second operand.
  798. /// @returns #dnnl_success on success and a status describing the error
  799. /// otherwise.
  800. /// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
  801. /// post-op.
  802. dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
  803. const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
  804. const_dnnl_memory_desc_t *src1_desc);
  805. /// Returns the parameters of a binary post-op with ternary operators.
  806. ///
  807. /// @param post_ops Post-ops.
  808. /// @param index Index of the binary post-op.
  809. /// @param alg_kind Output binary algorithm kind.
  810. /// @param src1_desc Output memory descriptor of a second operand.
  811. /// @param src2_desc Output memory descriptor of a third operand. If the
  812. /// specified algorithm is not one that requires a ternary input, src2_desc
  813. /// will be ignored.
  814. /// @returns #dnnl_success on success and a status describing the error
  815. /// otherwise.
  816. /// @returns #dnnl_invalid_arguments if @p index does not refer to a binary
  817. /// post-op.
  818. dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary_v2(
  819. const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
  820. const_dnnl_memory_desc_t *src1_desc,
  821. const_dnnl_memory_desc_t *src2_desc);
  822. /// Appends a prelu forward post-op.
  823. ///
  824. /// The kind of this post-op is #dnnl::primitive::kind::prelu.
  825. ///
  826. /// The post-op can be defined as:
  827. ///
  828. /// dst[:] <- prelu(dst[:], weights[:])
  829. /// prelu:
  830. /// dst[:] <- dst[:] if dst[:] > 0
  831. /// dst[:] <- dst[:] * weights[:] if dst[:] <= 0
  832. ///
  833. ///
  834. /// @note
  835. /// The order of dimensions does not depend on how elements are laid
  836. /// out in memory. For example:
  837. /// - for a 2D CNN activations tensor the order is always (n, c)
  838. /// - for a 4D CNN activations tensor the order is always (n, c, h, w)
  839. /// - for a 5D CNN weights tensor the order is always
  840. /// (g, oc, ic, kh, kw)
  841. ///
  842. /// Prelu weights tensor is passed in runtime execution phase. Prelu
  843. /// weights tensor data type is implicitly assumed as f32 using plain
  844. /// layout (a, ab, acb, acdb, acdeb)
  845. ///
  846. /// @param post_ops Post-ops.
  847. /// @param mask Defines the correspondence between the output tensor
  848. /// dimensions and the prelu weights tensor. The set i-th bit indicates
  849. /// that a dedicated weights value is used for each index along that
  850. /// dimension. Set the mask to 0 to use a common weights value
  851. /// for the whole output tensor.
  852. /// @returns #dnnl_success on success and a status describing the error
  853. /// otherwise.
  854. dnnl_status_t DNNL_API dnnl_post_ops_append_prelu(
  855. dnnl_post_ops_t post_ops, int mask);
  856. /// Returns the parameters of a prelu post-op.
  857. ///
  858. /// @param post_ops Post-ops.
  859. /// @param index Index of the prelu post-op.
  860. /// @param mask Mask of the prelu post-op.
  861. /// @returns #dnnl_success on success and a status describing the error
  862. /// otherwise.
  863. dnnl_status_t DNNL_API dnnl_post_ops_get_params_prelu(
  864. const_dnnl_post_ops_t post_ops, int index, int *mask);
  865. /// @} dnnl_api_attributes
  866. /// @} dnnl_api_primitives
  867. /// @addtogroup dnnl_api_memory
  868. /// @{
  869. /// Destroys a memory descriptor.
  870. ///
  871. /// @param memory_desc Memory descriptor to destroy.
  872. /// @returns #dnnl_success on success and a status describing the error
  873. /// otherwise.
  874. dnnl_status_t DNNL_API dnnl_memory_desc_destroy(dnnl_memory_desc_t memory_desc);
  875. /// Clones a memory descriptor. The resulting memory descriptor must be
  876. /// destroyed separately.
  877. ///
  878. /// @param memory_desc Output memory descriptor.
  879. /// @param existing_memory_desc Memory descriptor to clone.
  880. /// @returns #dnnl_success on success and a status describing the error
  881. /// otherwise.
  882. dnnl_status_t DNNL_API dnnl_memory_desc_clone(dnnl_memory_desc_t *memory_desc,
  883. const_dnnl_memory_desc_t existing_memory_desc);
  884. /// Retrieves a binary blob associated with the given memory descriptor
  885. ///
  886. /// @param blob Output pointer to binary blob.
  887. /// If not nullptr, size bytes of the memory descriptor blob are written.
  888. /// @param size Output pointer to the size of the binary blob in bytes.
  889. /// Size is written if blob is nullptr.
  890. /// @param memory_desc input memory descriptor to serialize
  891. /// @returns #dnnl_success on success and a status describing the error
  892. /// otherwise.
  893. dnnl_status_t DNNL_API dnnl_memory_desc_get_blob(
  894. uint8_t *blob, size_t *size, const_dnnl_memory_desc_t memory_desc);
  895. /// Creates a memory descriptor from a memory descriptor binary blob.
  896. ///
  897. /// @param memory_desc Output pointer to a newly allocated memory descriptor.
  898. /// @param blob Pointer to a memory descriptor binary blob.
  899. /// @returns #dnnl_success on success and a status describing the error
  900. /// otherwise.
  901. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_blob(
  902. dnnl_memory_desc_t *memory_desc, const uint8_t *blob);
  903. /// Creates a memory descriptor using dimensions and strides.
  904. ///
  905. /// @note
  906. /// As always, the logical order of dimensions corresponds to the `abc...`
  907. /// format tag, and the physical meaning of the dimensions depends on both
  908. /// the primitive that consumes the memory and the context of that
  909. /// consumption.
  910. ///
  911. /// @param memory_desc Output memory descriptor.
  912. /// @param ndims Number of dimensions
  913. /// @param dims Array of dimensions.
  914. /// @param data_type Elements data type.
  915. /// @param strides Strides in each dimension.
  916. /// @returns #dnnl_success on success and a status describing the error
  917. /// otherwise.
  918. dnnl_status_t DNNL_API dnnl_memory_desc_create_with_strides(
  919. dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
  920. dnnl_data_type_t data_type, const dnnl_dims_t strides);
/// Creates a memory descriptor using dimensions and a memory format tag.
///
/// @note
///     As always, the logical order of dimensions corresponds to the `abc...`
///     format tag, and the physical meaning of the dimensions depends on both
///     the primitive that consumes the memory and the context of that
///     consumption.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Element data type.
/// @param tag Memory format tag. Can be #dnnl_format_tag_any, which allows
///     a primitive to choose the final memory format. In this case the
///     format_kind field of the memory descriptor is set to
///     #dnnl_format_kind_any.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_tag(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_format_tag_t tag);
/// Creates a memory descriptor for CSR encoding.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Element data type.
/// @param nnz Number of non-zero entries.
/// @param indices_dt Data type of indices.
/// @param pointers_dt Data type of pointers.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_csr_encoding(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_dim_t nnz, dnnl_data_type_t indices_dt,
        dnnl_data_type_t pointers_dt);
/// Creates a memory descriptor for COO encoding.
///
/// The created memory descriptor describes a memory object that contains
/// n+1 buffers for an n-dimensional tensor. The buffers have the following
/// meaning and assigned numbers (index):
///  - 0: values
///  - 1: indices for dimension 0
///  - 2: indices for dimension 1
///  - ...
///  - n: indices for dimension n-1
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Element data type.
/// @param nnz Number of non-zero entries.
/// @param indices_dt Data type of indices.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_coo_encoding(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_dim_t nnz,
        dnnl_data_type_t indices_dt);
/// Creates a memory descriptor for packed sparse encoding.
///
/// The created memory descriptor cannot be used to create a memory
/// object. It can only be used to create a primitive descriptor to
/// query the actual memory descriptor (similar to the format tag
/// #dnnl_format_tag_any).
///
/// @warning
///     The meaning and content of the handles of the memory object that
///     is created using the queried memory descriptor are unspecified;
///     therefore, using the content is undefined behavior.
///
/// @param memory_desc Output memory descriptor.
/// @param ndims Number of dimensions.
/// @param dims Array of dimensions.
/// @param data_type Element data type.
/// @param nnz Number of non-zero entries.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @sa @ref dev_guide_sparsity
dnnl_status_t DNNL_API dnnl_memory_desc_create_with_packed_encoding(
        dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims,
        dnnl_data_type_t data_type, dnnl_dim_t nnz);
/// Creates a memory descriptor for a scalar value that resides on the host.
///
/// The resulting descriptor can be passed to dnnl_memory_create_host_scalar()
/// to create the corresponding host-side scalar memory object.
///
/// @param memory_desc Output memory descriptor.
/// @param data_type Data type of the scalar value.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_host_scalar(
        dnnl_memory_desc_t *memory_desc, dnnl_data_type_t data_type);
/// Creates a memory descriptor for a region inside an area
/// described by an existing memory descriptor.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or dims may
///     result in a failure to create a submemory.
///
/// @param memory_desc Output memory descriptor.
/// @param parent_memory_desc An existing memory descriptor.
/// @param dims Sizes of the region.
/// @param offsets Offsets to the region from the encompassing
///     memory object in each dimension.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_create_submemory(
        dnnl_memory_desc_t *memory_desc,
        const_dnnl_memory_desc_t parent_memory_desc, const dnnl_dims_t dims,
        const dnnl_dims_t offsets);
/// Creates a memory descriptor by reshaping an existing one. The new
/// memory descriptor inherits the data type. This operation is valid only for
/// memory descriptors that have format_kind #dnnl_blocked or
/// #dnnl_format_kind_any.
///
/// The resulting memory descriptor must be destroyed separately.
///
/// The operation ensures that the transformation of the physical memory
/// format corresponds to the transformation of the logical dimensions. If
/// such a transformation is impossible, the function returns
/// #dnnl_invalid_arguments.
///
/// The reshape operation can be described as a combination of the following
/// basic operations:
/// 1. Add a dimension of size `1`. This is always possible.
/// 2. Remove a dimension of size `1`. This is possible only if the dimension
///    has no padding (i.e. `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
/// 3. Split a dimension into multiple ones. This is possible only if the size
///    of the dimension is exactly equal to the product of the split ones and
///    the dimension does not have padding (i.e.
///    `padded_dims[dim] == dims[dim]`).
/// 4. Join multiple consecutive dimensions into a single one. As in the
///    cases above, this requires that the dimensions do not have padding and
///    that the memory format is such that in physical memory these dimensions
///    are dense and have the same order as their logical counterparts. This
///    also assumes that these dimensions are not blocked.
///    - Here, dense means:
///      `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
///    - And same order means:
///      `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
///
/// @warning
///     Some combinations of physical memory layout and/or offsets or
///     dimensions may result in a failure to make a reshape.
///
/// @param out_memory_desc Output memory descriptor.
/// @param in_memory_desc An existing memory descriptor. Must have format_kind
///     set to #dnnl_blocked or #dnnl_format_kind_any.
/// @param ndims Number of dimensions for the output memory descriptor.
/// @param dims Dimensions for the output memory descriptor.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
        dnnl_memory_desc_t *out_memory_desc,
        const_dnnl_memory_desc_t in_memory_desc, int ndims,
        const dnnl_dims_t dims);
/// Creates a memory descriptor by permuting axes in an existing one.
///
/// The physical memory layout representation is adjusted accordingly to
/// maintain consistency between the logical and physical parts of the
/// memory descriptor.
///
/// The resulting memory descriptor must be destroyed separately.
///
/// The new memory descriptor inherits the data type. This operation is valid
/// only for memory descriptors that have format_kind set to #dnnl_blocked or
/// #dnnl_format_kind_any.
///
/// The logical axes are permuted in the following manner:
/// ```
///     for (i: 0 .. in_memory_desc->ndims)
///         out_memory_desc->dims[permutation[i]] = in_memory_desc->dims[i];
/// ```
///
/// Example:
/// @code
///     // `data_type` stands for any dnnl_data_type_t value; the same value
///     // is used for both descriptors.
///     dnnl_memory_desc_t in_md, out_md, expect_out_md;
///
///     const int permutation[] = {1, 0}; // swap the first and the second axes
///
///     dnnl_dims_t in_dims = {2, 3}, out_dims = {3, 2};
///     dnnl_format_tag_t in_tag = dnnl_ab, out_tag = dnnl_ba;
///
///     dnnl_memory_desc_create_with_tag(
///             &in_md, 2, in_dims, data_type, in_tag);
///     dnnl_memory_desc_create_with_tag(
///             &expect_out_md, 2, out_dims, data_type, out_tag);
///
///     dnnl_memory_desc_permute_axes(&out_md, in_md, permutation);
///     assert(dnnl_memory_desc_equal(out_md, expect_out_md));
///
///     dnnl_memory_desc_destroy(in_md);
///     dnnl_memory_desc_destroy(out_md);
///     dnnl_memory_desc_destroy(expect_out_md);
/// @endcode
///
/// @param out_memory_desc Output memory descriptor.
/// @param in_memory_desc An existing memory descriptor. Must have format_kind
///     set to #dnnl_blocked or #dnnl_format_kind_any.
/// @param permutation Axes permutation (of size `in_memory_desc->ndims`).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
        dnnl_memory_desc_t *out_memory_desc,
        const_dnnl_memory_desc_t in_memory_desc, const int *permutation);
/// Queries a memory descriptor for various pieces of information.
///
/// The following information can be queried:
///  - Number of dimensions (#dnnl_query_ndims_s32)
///  - Dimensions (#dnnl_query_dims) in the following order:
///    - CNN data tensors: mini-batch, channel, spatial
///      (<code>{N, C, [[D,] H,] W}</code>)
///    - CNN weight tensors: group (optional), output channel, input channel,
///      spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
///    - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
///      or layers, directions, states, mini-batch, channels
///      (<code>{L, D, S, N, C}</code>)
///    - RNN weight tensor: layers, directions, input channel, gates, output
///      channels (<code>{L, D, I, G, O}</code>)
///  - Data type of the tensor elements (#dnnl_query_data_type)
///  - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
///    padding in each dimension
///  - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
///    the padding to actual data, the top-level tensor with offsets applied
///    must lie within the padding area.
///  - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
///    origin to the current block, non-zero only in a description of a memory
///    sub-block.
///  - Format kind (#dnnl_query_format_kind) - memory format kind
///
/// @note
///     The order of dimensions does not depend on the memory format, so
///     whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
///     the dims for a 4D CNN data tensor would be <code>{N, C, H, W}</code>.
///
/// The following queries are applicable only to format kind #dnnl_blocked.
///  - Strides (#dnnl_query_strides) between the outermost blocks or in case
///    of plain (non-blocked) formats the strides between dimensions
///  - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g. 3 in
///    case of `OIhw_4i16o4i`
///  - Sizes of the innermost blocks (#dnnl_query_inner_blks), e.g.
///    `{4, 16, 4}` in case of `OIhw_4i16o4i`
///  - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
///    in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim
///
/// @param memory_desc Memory descriptor.
/// @param what Parameter to query.
/// @param result Output result. The type depends on the query. For example,
///     it must be a @c dnnl_dims_t** if querying for strides.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_query(
        const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, void *result);
/// Queries a memory descriptor for various pieces of information. This
/// version supports the additional queries #dnnl_query_sparse_encoding,
/// #dnnl_query_nnz_s64, #dnnl_query_num_handles_s32, and
/// #dnnl_query_data_type for a particular buffer.
///
/// The following information can be queried:
///  - Number of dimensions (#dnnl_query_ndims_s32)
///  - Dimensions (#dnnl_query_dims) in the following order:
///    - CNN data tensors: mini-batch, channel, spatial
///      (<code>{N, C, [[D,] H,] W}</code>)
///    - CNN weight tensors: group (optional), output channel, input channel,
///      spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
///    - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
///      or layers, directions, states, mini-batch, channels
///      (<code>{L, D, S, N, C}</code>)
///    - RNN weight tensor: layers, directions, input channel, gates, output
///      channels (<code>{L, D, I, G, O}</code>)
///  - Data type of the tensor elements (#dnnl_query_data_type)
///  - Padded dimensions (#dnnl_query_padded_dims) - size of the data including
///    padding in each dimension
///  - Padded offsets (#dnnl_query_padded_offsets) - per-dimension offset from
///    the padding to actual data, the top-level tensor with offsets applied
///    must lie within the padding area.
///  - Submemory offset (#dnnl_query_submemory_offset_s64) - offset from memory
///    origin to the current block, non-zero only in a description of a memory
///    sub-block.
///  - Format kind (#dnnl_query_format_kind) - memory format kind
///
/// @note
///     The order of dimensions does not depend on the memory format, so
///     whether the data is laid out in #dnnl_nchw or #dnnl_nhwc
///     the dims for a 4D CNN data tensor would be <code>{N, C, H, W}</code>.
///
/// The following queries are applicable only to format kind #dnnl_blocked.
///  - Strides (#dnnl_query_strides) between the outermost blocks or in case
///    of plain (non-blocked) formats the strides between dimensions
///  - Number of innermost blocks (#dnnl_query_inner_nblks_s32), e.g. 3 in
///    case of `OIhw_4i16o4i`
///  - Sizes of the innermost blocks (#dnnl_query_inner_blks), e.g.
///    `{4, 16, 4}` in case of `OIhw_4i16o4i`
///  - Logical indices of the blocks (#dnnl_query_inner_idxs), e.g. `{1, 0, 1}`
///    in case of `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim
///
/// @param memory_desc Memory descriptor.
/// @param what Parameter to query.
/// @param index Index of the parameter to query for. It is mostly used with
///     #dnnl_query_data_type to specify which data type is being queried.
///     The main data type (data type of values) always has index 0. For other
///     indices please refer to the API for creating a memory descriptor for
///     sparse encoding.
/// @param result Output result. The type depends on the query. For example,
///     it must be a @c dnnl_dims_t** if querying for strides.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
/// @sa @ref dev_guide_sparsity
dnnl_status_t DNNL_API dnnl_memory_desc_query_v2(
        const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, int index,
        void *result);
/// Compares two memory descriptors.
///
/// Use this function to identify whether a reorder is required between the
/// two memories.
///
/// @param lhs Left-hand side of the comparison.
/// @param rhs Right-hand side of the comparison.
/// @returns 1 if the descriptors are the same.
/// @returns 0 if the descriptors are different.
int DNNL_API dnnl_memory_desc_equal(
        const_dnnl_memory_desc_t lhs, const_dnnl_memory_desc_t rhs);
/// Returns the size, in bytes, of the memory described by a memory
/// descriptor.
///
/// @param memory_desc Memory descriptor.
/// @returns The number of bytes required for memory described by a memory
///     descriptor.
size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);
/// Returns the size, in bytes, of the data buffer that corresponds to the
/// given index.
///
/// @param memory_desc Memory descriptor.
/// @param index Index of the buffer. For sparse encodings, refer to the
///     corresponding memory descriptor creation function for the meaning of
///     each index.
///
/// @returns The number of bytes required for the requested data.
size_t DNNL_API dnnl_memory_desc_get_size_v2(
        const_dnnl_memory_desc_t memory_desc, int index);
/// Returns the size of a data type.
///
/// @param data_type Data type.
/// @returns The number of bytes occupied by the data type.
size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
/// Creates a memory object.
///
/// Unless @p handle is equal to DNNL_MEMORY_NONE, the constructed memory
/// object will have the underlying buffer set. In this case, the buffer will
/// be initialized as if dnnl_memory_set_data_handle() had been called.
///
/// @sa dnnl_memory_set_data_handle()
///
/// @param memory Output memory object.
/// @param memory_desc Memory descriptor.
/// @param engine Engine to use.
/// @param handle Handle of the memory buffer to use as an underlying storage.
///     - A pointer to the user-allocated buffer. In this case the library
///       doesn't own the buffer.
///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
///       allocate the buffer for the memory object. In this case the library
///       owns the buffer.
///     - The DNNL_MEMORY_NONE special value. Creates the memory object
///       without an underlying buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
        const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
        void *handle);
/// Creates a memory object with multiple handles.
///
/// @param memory Output memory object.
/// @param memory_desc Memory descriptor.
/// @param engine Engine to use.
/// @param nhandles Number of handles.
/// @param handles Array of @p nhandles handles of the memory buffers to use
///     as underlying storages. For each element of the @p handles array the
///     following applies:
///     - A pointer to the user-allocated buffer. In this case the library
///       doesn't own the buffer.
///     - The DNNL_MEMORY_ALLOCATE special value. Instructs the library to
///       allocate the buffer for the memory object. In this case the library
///       owns the buffer.
///     - The DNNL_MEMORY_NONE special value. Instructs the library to skip
///       allocation of the memory buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_create_v2(dnnl_memory_t *memory,
        const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
        int nhandles, void **handles);
/// Creates a memory object for a scalar value located on the host.
///
/// @note The scalar value is copied from the provided pointer into the newly
///     allocated memory storage, so the user does not need to manage the
///     lifetime of the original scalar data.
///
/// @param memory Output host-side scalar memory object.
/// @param memory_desc Memory descriptor describing a scalar value residing
///     on the host (for example, one created with
///     dnnl_memory_desc_create_host_scalar()).
/// @param scalar_ptr Pointer to the scalar value to be copied into the memory
///     object. This should be a host pointer to the scalar data.
/// @returns #dnnl_success on success; otherwise, returns a status code
///     describing the error.
dnnl_status_t DNNL_API dnnl_memory_create_host_scalar(dnnl_memory_t *memory,
        const_dnnl_memory_desc_t memory_desc, void *scalar_ptr);
/// Returns the memory descriptor for a memory object.
///
/// @param memory Memory object.
/// @param memory_desc Output memory descriptor (a copy).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
        const_dnnl_memory_t memory, const_dnnl_memory_desc_t *memory_desc);
/// Returns the engine of a memory object.
///
/// @param memory Memory object.
/// @param engine Output engine on which the memory is located.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_engine(
        const_dnnl_memory_t memory, dnnl_engine_t *engine);
/// Maps a memory object and returns a host-side pointer to a memory buffer
/// with a copy of its contents.
///
/// Mapping enables explicit direct access to memory contents for the engines
/// that do not support it implicitly.
///
/// Mapping is an exclusive operation - a memory object cannot be used in
/// other operations until this memory object is unmapped with
/// dnnl_memory_unmap_data().
///
/// @note
///     Any primitives working with @p memory should be completed before
///     the memory is mapped. Use dnnl_stream_wait() to synchronize the
///     corresponding execution stream.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Output pointer to the mapped buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_map_data(
        const_dnnl_memory_t memory, void **mapped_ptr);
/// Maps a memory object and returns a host-side pointer to a memory buffer
/// with a copy of its contents. The memory buffer corresponds to the given
/// index.
///
/// Mapping enables explicit direct access to memory contents for the engines
/// that do not support it implicitly.
///
/// Mapping is an exclusive operation - a memory object cannot be used in
/// other operations until this memory object is unmapped with
/// dnnl_memory_unmap_data_v2().
///
/// @note
///     Any primitives working with @p memory should be completed before
///     the memory is mapped. Use dnnl_stream_wait() to synchronize the
///     corresponding execution stream.
///
/// @note
///     The dnnl_memory_map_data_v2() and dnnl_memory_unmap_data_v2()
///     functions are mainly provided for debug and testing purposes, and
///     their performance may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Output pointer to the mapped buffer.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_map_data_v2(
        const_dnnl_memory_t memory, void **mapped_ptr, int index);
/// Unmaps a memory object and writes back any changes made to the previously
/// mapped memory buffer. The pointer to the mapped buffer must be obtained
/// via the dnnl_memory_map_data() call.
///
/// @note
///     The dnnl_memory_map_data() and dnnl_memory_unmap_data() functions are
///     mainly provided for debug and testing purposes, and their performance
///     may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Pointer to the mapped buffer that must have been
///     obtained using the dnnl_memory_map_data() function.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(
        const_dnnl_memory_t memory, void *mapped_ptr);
/// Unmaps a memory object and writes back any changes made to the previously
/// mapped memory buffer. The buffer corresponds to the given index, and the
/// pointer to it must be obtained via a dnnl_memory_map_data_v2() call with
/// the same index.
///
/// @note
///     The dnnl_memory_map_data_v2() and dnnl_memory_unmap_data_v2()
///     functions are mainly provided for debug and testing purposes, and
///     their performance may be suboptimal.
///
/// @param memory Memory object.
/// @param mapped_ptr Pointer to the mapped buffer that must have been
///     obtained using the dnnl_memory_map_data_v2() function.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_unmap_data_v2(
        const_dnnl_memory_t memory, void *mapped_ptr, int index);
/// Returns a memory object's data handle.
///
/// @param memory Memory object.
/// @param handle Output data handle. For the CPU engine or when USM is used,
///     the data handle is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(
        const_dnnl_memory_t memory, void **handle);
/// Sets the underlying memory buffer.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine or when USM is used, the
///     memory buffer is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle(
        dnnl_memory_t memory, void *handle);
/// Returns an underlying memory buffer that corresponds to the given index.
///
/// @param memory Memory object.
/// @param handle Output data handle. For the CPU engine or when USM is used,
///     the memory buffer is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle_v2(
        const_dnnl_memory_t memory, void **handle, int index);
/// Sets an underlying memory buffer that corresponds to the given index.
///
/// @param memory Memory object.
/// @param handle Data handle. For the CPU engine or when USM is used, the
///     memory buffer is a pointer to the actual data. For OpenCL it is a
///     `cl_mem`.
/// @param index Index of the buffer.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(
        dnnl_memory_t memory, void *handle, int index);
/// Retrieves the value stored in a host-side scalar memory object and writes
/// it to the host location pointed to by @p value.
///
/// @param memory Host-side scalar memory object.
/// @param value Output pointer to host storage that receives the scalar
///     value. The type of the value depends on the data type of the memory
///     object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_host_scalar_value(
        const_dnnl_memory_t memory, void *value);
/// Sets the value of a host-side scalar memory object from a host pointer.
///
/// @note The value is copied from the provided pointer into the
///     memory object, so the user does not need to manage the lifetime of the
///     original scalar data.
///
/// @param memory Host-side scalar memory object.
/// @param value Pointer to the scalar value to be copied into the
///     memory object. The type of the value must match the data type of the
///     memory object.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_set_host_scalar_value(
        dnnl_memory_t memory, const void *value);
/// Destroys a memory object.
///
/// @param memory Memory object to destroy.
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
  1487. /// @} dnnl_api_memory
  1488. /// @addtogroup dnnl_api_primitives
  1489. /// @{
  1490. /// @addtogroup dnnl_api_reorder
  1491. /// @{
/// Creates a primitive descriptor for a reorder primitive.
///
/// @param reorder_primitive_desc Output primitive descriptor.
/// @param src_desc Source memory descriptor.
/// @param src_engine Engine on which the source memory object will be
///     located.
/// @param dst_desc Destination memory descriptor.
/// @param dst_engine Engine on which the destination memory object
///     will be located.
/// @param attr Primitive attributes to use (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
        dnnl_primitive_desc_t *reorder_primitive_desc,
        const_dnnl_memory_desc_t src_desc, dnnl_engine_t src_engine,
        const_dnnl_memory_desc_t dst_desc, dnnl_engine_t dst_engine,
        const_dnnl_primitive_attr_t attr);
  1509. /// @} dnnl_api_reorder
  1510. /// @addtogroup dnnl_api_concat
  1511. /// @{
/// Creates a primitive descriptor for an out-of-place concatenation
/// primitive.
///
/// @param concat_primitive_desc Output primitive descriptor.
/// @param engine Engine to use.
/// @param dst_desc Destination memory descriptor.
/// @param n Number of source parameters.
/// @param concat_dimension Source tensors will be concatenated over
///     the dimension with this index. Note that the order of dimensions
///     does not depend on the memory format.
/// @param src_descs Array of source memory descriptors with @p n elements.
/// @param attr Primitive attributes to use (can be NULL).
/// @returns #dnnl_success on success and a status describing the error
///     otherwise.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
        dnnl_primitive_desc_t *concat_primitive_desc, dnnl_engine_t engine,
        const_dnnl_memory_desc_t dst_desc, int n, int concat_dimension,
        const_dnnl_memory_desc_t const *src_descs,
        const_dnnl_primitive_attr_t attr);
  1531. /// @} dnnl_api_concat
  1532. /// @addtogroup dnnl_api_sum
  1533. /// @{
  1534. /// Creates a primitive descriptor for an (out-of-place) sum primitive.
  1535. ///
  1536. /// @param sum_primitive_desc Output primitive descriptor.
  1537. /// @param dst_desc Destination memory descriptor.
  1538. /// @param n Number of source parameters.
  1539. /// @param scales Vector of scales to multiply data in each source
  1540. /// memory by.
  1541. /// @param src_descs Array of source memory descriptors having @p n elements.
  1542. /// @param attr Primitive attributes to use (can be NULL).
  1543. /// @param engine Engine to use.
  1544. /// @returns #dnnl_success on success and a status describing the error
  1545. /// otherwise.
  1546. dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
  1547. dnnl_primitive_desc_t *sum_primitive_desc, dnnl_engine_t engine,
  1548. const_dnnl_memory_desc_t dst_desc, int n, const float *scales,
  1549. const_dnnl_memory_desc_t const *src_descs,
  1550. const_dnnl_primitive_attr_t attr);
  1551. /// @} dnnl_api_sum
  1552. /// @addtogroup dnnl_api_binary
  1553. /// @{
  1554. /// Creates a primitive descriptor for a binary primitive.
  1555. ///
  1556. /// @note
  1557. /// Memory descriptors @p src1_desc and @p dst_desc are allowed to be
  1558. /// initialized with #dnnl_format_tag_any or with format_kind set to
  1559. /// #dnnl_format_kind_any.
  1560. ///
  1561. /// @note
  1562. /// Both memory descriptors must have the same number of dimensions.
  1563. /// Element broadcasting is supported for memory descriptor @p src1_desc
  1564. /// and are applied to @p src1_desc dimensions that have size equal to 1.
  1565. ///
  1566. /// @param primitive_desc Output primitive descriptor.
  1567. /// @param engine Engine to use.
  1568. /// @param alg_kind Algorithm kind. Valid values are #dnnl_binary_add,
  1569. /// #dnnl_binary_mul, #dnnl_binary_max, #dnnl_binary_min, #dnnl_binary_div,
  1570. /// #dnnl_binary_sub, #dnnl_binary_ge, #dnnl_binary_gt, #dnnl_binary_le,
  1571. /// #dnnl_binary_lt, #dnnl_binary_eq and #dnnl_binary_ne.
  1572. /// @param src0_desc Source 0 memory descriptor.
  1573. /// @param src1_desc Source 1 memory descriptor.
  1574. /// @param dst_desc Destination memory descriptor.
  1575. /// @param attr Primitive attributes (can be NULL).
  1576. /// @returns #dnnl_success on success and a status describing the error
  1577. /// otherwise.
  1578. dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
  1579. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1580. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
  1581. const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc,
  1582. const_dnnl_primitive_attr_t attr);
  1583. /// Creates a primitive descriptor for a binary primitive with support of
  1584. /// ternary operators.
  1585. ///
  1586. /// @note
  1587. /// Memory descriptors @p src1_desc, @p src2_desc and @p dst_desc are
  1588. /// allowed to be initialized with #dnnl_format_tag_any or with format_kind
  1589. /// set to #dnnl_format_kind_any.
  1590. ///
  1591. /// @note
  1592. /// All memory descriptors must have the same number of dimensions.
  1593. /// Element broadcasting is supported for memory descriptor @p src1_desc
  1594. /// and is applied to @p src1_desc dimensions that have a size equal to 1.
  1595. /// There is no broadcasting support for @p src2_desc.
  1596. ///
  1597. /// @param primitive_desc Output primitive descriptor.
  1598. /// @param engine Engine to use.
  1599. /// @param alg_kind Algorithm kind.
  1600. /// @param src0_desc Source 0 memory descriptor.
  1601. /// @param src1_desc Source 1 memory descriptor.
  1602. /// @param src2_desc Source memory descriptor for ternary operations. Might
  1603. /// be empty.
  1604. /// @param dst_desc Destination memory descriptor.
  1605. /// @param attr Primitive attributes.
  1606. /// @returns #dnnl_success on success and a status describing the error
  1607. /// otherwise.
  1608. dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create_v2(
  1609. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1610. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
  1611. const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t src2_desc,
  1612. const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
  1613. /// @} dnnl_api_binary
  1614. /// @addtogroup dnnl_api_convolution
  1615. /// @{
  1616. /// Creates a primitive descriptor for a convolution forward propagation
  1617. /// primitive.
  1618. ///
  1619. /// @note
  1620. /// Memory descriptors can be initialized with
  1621. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1622. ///
  1623. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1624. /// values for spatial dimensions only and hence must have the same number of
  1625. /// elements as there are spatial dimensions. The order of values is the same
  1626. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1627. /// and width.
  1628. ///
  1629. /// @param primitive_desc Output primitive descriptor.
  1630. /// @param engine Engine to use.
  1631. /// @param prop_kind Propagation kind. Possible values are
  1632. /// #dnnl_forward_training and #dnnl_forward_inference.
  1633. /// @param alg_kind Convolution algorithm. Possible values are
  1634. /// #dnnl_convolution_direct, #dnnl_convolution_winograd,
  1635. /// #dnnl_convolution_auto.
  1636. /// @param src_desc Source memory descriptor.
  1637. /// @param weights_desc Weights memory descriptor.
  1638. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  1639. /// descriptor, or a memory descriptor with format_kind set to
  1640. /// #dnnl_format_kind_undef disables the bias term.
  1641. /// @param dst_desc Destination memory descriptor.
  1642. /// @param strides Array of strides for spatial dimension.
  1643. /// @param dilates Array of dilations for spatial dimension. A zero value
  1644. /// means no dilation in the corresponding dimension.
  1645. /// @param padding_l Array of padding values for low indices for each spatial
  1646. /// dimension `([[front,] top,] left)`.
  1647. /// @param padding_r Array of padding values for high indices for each spatial
  1648. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1649. /// padding is considered to be symmetrical.
  1650. /// @param attr Primitive attributes (can be NULL).
  1651. /// @returns #dnnl_success on success and a status describing the error
  1652. /// otherwise.
  1653. dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
  1654. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1655. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1656. const_dnnl_memory_desc_t src_desc,
  1657. const_dnnl_memory_desc_t weights_desc,
  1658. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  1659. const dnnl_dims_t strides, const dnnl_dims_t dilates,
  1660. const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
  1661. const_dnnl_primitive_attr_t attr);
  1662. /// Creates a primitive descriptor for a convolution backward propagation
  1663. /// primitive.
  1664. ///
  1665. /// @note
  1666. /// Memory descriptors can be initialized with
  1667. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1668. ///
  1669. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1670. /// values for spatial dimensions only and hence must have the same number of
  1671. /// elements as there are spatial dimensions. The order of values is the same
  1672. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1673. /// and width.
  1674. ///
  1675. /// @param primitive_desc Output primitive descriptor.
  1676. /// @param engine Engine to use.
  1677. /// @param alg_kind Convolution algorithm. Possible values are
  1678. /// #dnnl_convolution_direct, #dnnl_convolution_winograd,
  1679. /// #dnnl_convolution_auto.
  1680. /// @param diff_src_desc Diff source memory descriptor.
  1681. /// @param weights_desc Weights memory descriptor.
  1682. /// @param diff_dst_desc Diff destination memory descriptor.
  1683. /// @param strides Array of strides for spatial dimension.
  1684. /// @param dilates Array of dilations for spatial dimension. A zero value
  1685. /// means no dilation in the corresponding dimension.
  1686. /// @param padding_l Array of padding values for low indices for each spatial
  1687. /// dimension `([[front,] top,] left)`.
  1688. /// @param padding_r Array of padding values for high indices for each spatial
  1689. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1690. /// padding is considered to be symmetrical.
  1691. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1692. /// primitive.
  1693. /// @param attr Primitive attributes (can be NULL).
  1694. /// @returns #dnnl_success on success and a status describing the error
  1695. /// otherwise.
  1696. dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
  1697. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1698. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1699. const_dnnl_memory_desc_t weights_desc,
  1700. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1701. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1702. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1703. const_dnnl_primitive_attr_t attr);
  1704. /// Creates a primitive descriptor for a convolution weights gradient primitive.
  1705. ///
  1706. /// @note
  1707. /// Memory descriptors can be initialized with
  1708. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1709. ///
  1710. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1711. /// values for spatial dimensions only and hence must have the same number of
  1712. /// elements as there are spatial dimensions. The order of values is the same
  1713. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1714. /// and width.
  1715. ///
  1716. /// @param primitive_desc Output primitive descriptor.
  1717. /// @param engine Engine to use.
  1718. /// @param alg_kind Convolution algorithm. Possible values are
  1719. /// #dnnl_convolution_direct, #dnnl_convolution_winograd,
  1720. /// #dnnl_convolution_auto.
  1721. /// @param src_desc Source memory descriptor.
  1722. /// @param diff_weights_desc Diff weights memory descriptor.
  1723. /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
  1724. /// memory descriptor, or a memory descriptor with format_kind set to
  1725. /// #dnnl_format_kind_undef disables the bias term.
  1726. /// @param diff_dst_desc Diff destination memory descriptor.
  1727. /// @param strides Array of strides for spatial dimension.
  1728. /// @param dilates Array of dilations for spatial dimension. A zero value
  1729. /// means no dilation in the corresponding dimension.
  1730. /// @param padding_l Array of padding values for low indices for each spatial
  1731. /// dimension `([[front,] top,] left)`.
  1732. /// @param padding_r Array of padding values for high indices for each spatial
  1733. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1734. /// padding is considered to be symmetrical.
  1735. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1736. /// primitive.
  1737. /// @param attr Primitive attributes (can be NULL).
  1738. /// @returns #dnnl_success on success and a status describing the error
  1739. /// otherwise.
  1740. dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
  1741. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1742. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
  1743. const_dnnl_memory_desc_t diff_weights_desc,
  1744. const_dnnl_memory_desc_t diff_bias_desc,
  1745. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1746. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1747. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1748. const_dnnl_primitive_attr_t attr);
  1749. /// @} dnnl_api_convolution
  1750. /// @addtogroup dnnl_api_deconvolution
  1751. /// @{
  1752. /// Creates a primitive descriptor for a deconvolution forward propagation
  1753. /// primitive.
  1754. ///
  1755. /// @note
  1756. /// Memory descriptors can be initialized with
  1757. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1758. ///
  1759. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1760. /// values for spatial dimensions only and hence must have the same number of
  1761. /// elements as there are spatial dimensions. The order of values is the same
  1762. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1763. /// and width.
  1764. ///
  1765. /// @param primitive_desc Output primitive descriptor.
  1766. /// @param engine Engine to use.
  1767. /// @param prop_kind Propagation kind. Possible values are
  1768. /// #dnnl_forward_training and #dnnl_forward_inference.
  1769. /// @param alg_kind Deconvolution algorithm. Possible values are
  1770. /// #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
  1771. /// @param src_desc Source memory descriptor.
  1772. /// @param weights_desc Weights memory descriptor.
  1773. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  1774. /// descriptor, or a memory descriptor with format_kind set to
  1775. /// #dnnl_format_kind_undef disables the bias term.
  1776. /// @param dst_desc Destination memory descriptor.
  1777. /// @param strides Array of strides for spatial dimension.
  1778. /// @param dilates Array of dilations for spatial dimension. A zero value
  1779. /// means no dilation in the corresponding dimension.
  1780. /// @param padding_l Array of padding values for low indices for each spatial
  1781. /// dimension `([[front,] top,] left)`.
  1782. /// @param padding_r Array of padding values for high indices for each spatial
  1783. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1784. /// padding is considered to be symmetrical.
  1785. /// @param attr Primitive attributes (can be NULL).
  1786. /// @returns #dnnl_success on success and a status describing the error
  1787. /// otherwise.
  1788. dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
  1789. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1790. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1791. const_dnnl_memory_desc_t src_desc,
  1792. const_dnnl_memory_desc_t weights_desc,
  1793. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  1794. const dnnl_dims_t strides, const dnnl_dims_t dilates,
  1795. const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
  1796. const_dnnl_primitive_attr_t attr);
  1797. /// Creates a primitive descriptor for a deconvolution backward propagation
  1798. /// primitive.
  1799. ///
  1800. /// @note
  1801. /// Memory descriptors can be initialized with
  1802. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1803. ///
  1804. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1805. /// values for spatial dimensions only and hence must have the same number of
  1806. /// elements as there are spatial dimensions. The order of values is the same
  1807. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1808. /// and width.
  1809. ///
  1810. /// @param primitive_desc Output primitive descriptor.
  1811. /// @param engine Engine to use.
  1812. /// @param alg_kind Deconvolution algorithm. Possible values are
  1813. /// #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
  1814. /// @param diff_src_desc Diff source memory descriptor.
  1815. /// @param weights_desc Weights memory descriptor.
  1816. /// @param diff_dst_desc Diff destination memory descriptor.
  1817. /// @param strides Array of strides for spatial dimension.
  1818. /// @param dilates Array of dilations for spatial dimension. A zero value
  1819. /// means no dilation in the corresponding dimension.
  1820. /// @param padding_l Array of padding values for low indices for each spatial
  1821. /// dimension `([[front,] top,] left)`.
  1822. /// @param padding_r Array of padding values for high indices for each spatial
  1823. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1824. /// padding is considered to be symmetrical.
  1825. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1826. /// primitive.
  1827. /// @param attr Primitive attributes (can be NULL).
  1828. /// @returns #dnnl_success on success and a status describing the error
  1829. /// otherwise.
  1830. dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
  1831. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1832. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1833. const_dnnl_memory_desc_t weights_desc,
  1834. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1835. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1836. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1837. const_dnnl_primitive_attr_t attr);
  1838. /// Creates a primitive descriptor for a deconvolution weights gradient
  1839. /// primitive.
  1840. ///
  1841. /// @note
  1842. /// Memory descriptors can be initialized with
  1843. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  1844. ///
  1845. /// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r contain
  1846. /// values for spatial dimensions only and hence must have the same number of
  1847. /// elements as there are spatial dimensions. The order of values is the same
  1848. /// as in the tensor: depth (for 3D tensors), height (for 3D and 2D tensors),
  1849. /// and width.
  1850. ///
  1851. /// @param primitive_desc Output primitive descriptor.
  1852. /// @param engine Engine to use.
  1853. /// @param alg_kind Deconvolution algorithm. Possible values are
  1854. /// #dnnl_deconvolution_direct, #dnnl_deconvolution_winograd.
  1855. /// @param src_desc Source memory descriptor.
  1856. /// @param diff_weights_desc Diff weights memory descriptor.
  1857. /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
  1858. /// memory descriptor, or a memory descriptor with format_kind set to
  1859. /// #dnnl_format_kind_undef disables the bias term.
  1860. /// @param diff_dst_desc Diff destination memory descriptor.
  1861. /// @param strides Array of strides for spatial dimension.
  1862. /// @param dilates Array of dilations for spatial dimension. A zero value
  1863. /// means no dilation in the corresponding dimension.
  1864. /// @param padding_l Array of padding values for low indices for each spatial
  1865. /// dimension `([[front,] top,] left)`.
  1866. /// @param padding_r Array of padding values for high indices for each spatial
  1867. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  1868. /// padding is considered to be symmetrical.
  1869. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1870. /// primitive.
  1871. /// @param attr Primitive attributes (can be NULL).
  1872. /// @returns #dnnl_success on success and a status describing the error
  1873. /// otherwise.
  1874. dnnl_status_t DNNL_API
  1875. dnnl_deconvolution_backward_weights_primitive_desc_create(
  1876. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1877. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
  1878. const_dnnl_memory_desc_t diff_weights_desc,
  1879. const_dnnl_memory_desc_t diff_bias_desc,
  1880. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  1881. const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
  1882. const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
  1883. const_dnnl_primitive_attr_t attr);
  1884. /// @} dnnl_api_deconvolution
  1885. /// @addtogroup dnnl_api_shuffle
  1886. /// @{
  1887. /// Creates a primitive descriptor for a shuffle forward propagation primitive
  1888. ///
  1889. /// @param primitive_desc Output primitive descriptor.
  1890. /// @param engine Engine to use.
  1891. /// @param prop_kind Propagation kind. Possible values are
  1892. /// #dnnl_forward_training and #dnnl_forward_inference.
  1893. /// @param src_desc Source memory descriptor.
  1894. /// @param dst_desc Destination memory descriptor.
  1895. /// @param axis The axis along which the data is shuffled.
  1896. /// @param group_size Shuffle group size.
  1897. /// @param attr Primitive attributes (can be NULL).
  1898. /// @returns #dnnl_success on success and a status describing the error
  1899. /// otherwise.
  1900. dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
  1901. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1902. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  1903. const_dnnl_memory_desc_t dst_desc, int axis, dnnl_dim_t group_size,
  1904. const_dnnl_primitive_attr_t attr);
  1905. /// Creates a primitive descriptor for a shuffle backward propagation primitive
  1906. ///
  1907. /// @param primitive_desc Output primitive descriptor.
  1908. /// @param engine Engine to use.
  1909. /// @param diff_src_desc Diff source memory descriptor.
  1910. /// @param diff_dst_desc Diff destination memory descriptor.
  1911. /// @param axis The axis along which the data is shuffled.
  1912. /// @param group_size Shuffle group size.
  1913. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1914. /// primitive.
  1915. /// @param attr Primitive attributes (can be NULL).
  1916. /// @returns #dnnl_success on success and a status describing the error
  1917. /// otherwise.
  1918. dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
  1919. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1920. const_dnnl_memory_desc_t diff_src_desc,
  1921. const_dnnl_memory_desc_t diff_dst_desc, int axis, dnnl_dim_t group_size,
  1922. const_dnnl_primitive_desc_t hint_fwd_pd,
  1923. const_dnnl_primitive_attr_t attr);
  1924. /// @} dnnl_api_shuffle
  1925. /// @addtogroup dnnl_api_eltwise
  1926. /// @{
  1927. /// Creates a primitive descriptor for an eltwise forward propagation primitive.
  1928. ///
  1929. /// @param primitive_desc Output primitive descriptor.
  1930. /// @param engine Engine to use.
  1931. /// @param prop_kind Propagation kind. Possible values are
  1932. /// #dnnl_forward_training and #dnnl_forward_inference.
  1933. /// @param alg_kind Elementwise algorithm kind.
  1934. /// @param src_desc Source memory descriptor.
  1935. /// @param dst_desc Destination memory descriptor.
  1936. /// @param alpha The alpha parameter for the elementwise operation. Specific
  1937. /// meaning depends on the algorithm.
  1938. /// @param beta The beta parameter for the elementwise operation. Specific
  1939. /// meaning depends on the algorithm.
  1940. /// @param attr Primitive attributes (can be NULL).
  1941. /// @returns #dnnl_success on success and a status describing the error
  1942. /// otherwise.
  1943. dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
  1944. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1945. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1946. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  1947. float alpha, float beta, const_dnnl_primitive_attr_t attr);
  1948. /// Creates a primitive descriptor for an eltwise backward propagation
  1949. /// primitive.
  1950. ///
  1951. /// @param primitive_desc Output primitive descriptor.
  1952. /// @param engine Engine to use.
  1953. /// @param alg_kind Elementwise algorithm kind.
  1954. /// @param diff_src_desc Diff source memory descriptor.
  1955. /// @param diff_dst_desc Diff destination memory descriptor.
  1956. /// @param data_desc Destination memory descriptor if one of the
  1957. /// "use_dst_for_bwd" algorithms are used (such as
  1958. /// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor otherwise.
  1959. /// @param alpha The alpha parameter for the elementwise operation. Specific
  1960. /// meaning depends on the algorithm.
  1961. /// @param beta The beta parameter for the elementwise operation. Specific
  1962. /// meaning depends on the algorithm.
  1963. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  1964. /// primitive.
  1965. /// @param attr Primitive attributes (can be NULL).
  1966. /// @returns #dnnl_success on success and a status describing the error
  1967. /// otherwise.
  1968. dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
  1969. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1970. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  1971. const_dnnl_memory_desc_t diff_dst_desc,
  1972. const_dnnl_memory_desc_t data_desc, float alpha, float beta,
  1973. const_dnnl_primitive_desc_t hint_fwd_pd,
  1974. const_dnnl_primitive_attr_t attr);
  1975. /// @} dnnl_api_eltwise
  1976. /// @addtogroup dnnl_api_softmax
  1977. /// @{
  1978. /// Creates a primitive descriptor for a softmax forward propagation primitive.
  1979. ///
  1980. /// @param primitive_desc Output primitive descriptor.
  1981. /// @param engine Engine to use.
  1982. /// @param prop_kind Propagation kind. Possible values are
  1983. /// #dnnl_forward_training and #dnnl_forward_inference.
  1984. /// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
  1985. /// #dnnl_softmax_log.
  1986. /// @param src_desc Source memory descriptor.
  1987. /// @param dst_desc Destination memory descriptor.
  1988. /// @param softmax_axis Axis over which softmax is computed.
  1989. /// @param attr Primitive attributes (can be NULL).
  1990. /// @returns #dnnl_success on success and a status describing the error
  1991. /// otherwise.
  1992. dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
  1993. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  1994. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  1995. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  1996. int softmax_axis, const_dnnl_primitive_attr_t attr);
  1997. /// Creates a primitive descriptor for a softmax backward propagation primitive.
  1998. ///
  1999. /// @param primitive_desc Output primitive descriptor.
  2000. /// @param engine Engine to use.
  2001. /// @param alg_kind Softmax algorithm kind: either #dnnl_softmax_accurate, or
  2002. /// #dnnl_softmax_log.
  2003. /// @param diff_src_desc Diff source memory descriptor.
  2004. /// @param diff_dst_desc Diff destination memory descriptor.
  2005. /// @param dst_desc Destination memory descriptor.
  2006. /// @param softmax_axis Axis over which softmax is computed.
  2007. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2008. /// primitive.
  2009. /// @param attr Primitive attributes (can be NULL).
  2010. /// @returns #dnnl_success on success and a status describing the error
  2011. /// otherwise.
  2012. dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
  2013. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2014. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  2015. const_dnnl_memory_desc_t diff_dst_desc,
  2016. const_dnnl_memory_desc_t dst_desc, int softmax_axis,
  2017. const_dnnl_primitive_desc_t hint_fwd_pd,
  2018. const_dnnl_primitive_attr_t attr);
  2019. /// @} dnnl_api_softmax
  2020. /// @addtogroup dnnl_api_pooling
  2021. /// @{
  2022. /// Creates a primitive descriptor for a pooling forward propagation
  2023. /// primitive.
  2024. ///
  2025. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l and @p padding_r
  2026. /// contain values for spatial dimensions only and hence must have the same
  2027. /// number of elements as there are spatial dimensions. The order of values
  2028. /// is the same as in the tensor: depth (for 3D tensors),
  2029. /// height (for 3D and 2D tensors), and width.
  2030. ///
  2031. /// @param primitive_desc Output primitive descriptor.
  2032. /// @param engine Engine to use.
  2033. /// @param prop_kind Propagation kind. Possible values are
  2034. /// #dnnl_forward_training and #dnnl_forward_inference.
  2035. /// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
  2036. /// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
  2037. /// @param src_desc Source memory descriptor.
  2038. /// @param dst_desc Destination memory descriptor.
  2039. /// @param strides Array of strides for spatial dimension.
  2040. /// @param kernel Array of kernel spatial dimensions.
  2041. /// @param dilation Array of dilations for spatial dimension.
  2042. /// @param padding_l Array of padding values for low indices for each spatial
  2043. /// dimension `([[front,] top,] left)`.
  2044. /// @param padding_r Array of padding values for high indices for each spatial
  2045. /// dimension `([[back,] bottom,] right)`. Can be NULL in which case
  2046. /// padding is considered to be symmetrical.
  2047. /// @param attr Primitive attributes (can be NULL).
  2048. /// @returns #dnnl_success on success and a status describing the error
  2049. /// otherwise.
  2050. dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
  2051. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2052. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  2053. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  2054. const dnnl_dims_t strides, const dnnl_dims_t kernel,
  2055. const dnnl_dims_t dilation, const dnnl_dims_t padding_l,
  2056. const dnnl_dims_t padding_r, const_dnnl_primitive_attr_t attr);
  2057. /// Creates a primitive descriptor for a pooling backward propagation
  2058. /// primitive.
  2059. ///
  2060. /// Arrays @p strides, @p kernel, @p dilation, @p padding_l, and @p padding_r
  2061. /// contain values for spatial dimensions only and hence must have the same
  2062. /// number of elements as there are spatial dimensions. The order of values
  2063. /// is the same as in the tensor: depth (for 3D tensors),
  2064. /// height (for 3D and 2D tensors), and width.
  2065. ///
  2066. /// @param primitive_desc Output primitive descriptor.
  2067. /// @param engine Engine to use.
  2068. /// @param alg_kind Pooling algorithm kind: either #dnnl_pooling_max,
  2069. /// #dnnl_pooling_avg_include_padding, or #dnnl_pooling_avg_exclude_padding.
  2070. /// @param diff_src_desc Diff source memory descriptor.
  2071. /// @param diff_dst_desc Diff destination memory descriptor.
  2072. /// @param strides Array of strides for each spatial dimension.
  2073. /// @param kernel Array of kernel spatial dimensions.
  2074. /// @param dilation Array of dilations for each spatial dimension.
  2075. /// @param padding_l Array of padding values for low indices for each spatial
  2076. /// dimension `([[front,] top,] left)`.
  2077. /// @param padding_r Array of padding values for high indices for each spatial
  2078. /// dimension `([[back,] bottom,] right)`. Can be NULL, in which case
  2079. /// padding is considered to be symmetrical.
  2080. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2081. /// primitive.
  2082. /// @param attr Primitive attributes (can be NULL).
  2083. /// @returns #dnnl_success on success and a status describing the error
  2084. /// otherwise.
  2085. dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
  2086. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2087. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  2088. const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
  2089. const dnnl_dims_t kernel, const dnnl_dims_t dilation,
  2090. const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
  2091. const_dnnl_primitive_desc_t hint_fwd_pd,
  2092. const_dnnl_primitive_attr_t attr);
  2093. /// @} dnnl_api_pooling
  2094. /// @addtogroup dnnl_api_prelu
  2095. /// @{
  2096. /// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
  2097. /// alpha parameter) forward propagation primitive.
  2098. ///
  2099. /// @note
  2100. /// The weights descriptor is allowed to be initialized with
  2101. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2102. ///
  2103. /// @param primitive_desc Output primitive descriptor.
  2104. /// @param engine Engine to use.
  2105. /// @param prop_kind Propagation kind. Possible values are
  2106. /// #dnnl_forward_training and #dnnl_forward_inference.
  2107. /// @param src_desc Source memory descriptor.
  2108. /// @param weights_desc Alpha parameters (weights) memory descriptor.
  2109. /// @param dst_desc Destination memory descriptor.
  2110. /// @param attr Primitive attributes (can be NULL).
  2111. /// @returns #dnnl_success on success and a status describing the error
  2112. /// otherwise.
  2113. dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
  2114. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2115. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2116. const_dnnl_memory_desc_t weights_desc,
  2117. const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
  2118. /// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
  2119. /// alpha parameter) backward propagation primitive.
  2120. ///
  2121. /// @note
  2122. /// The weights and diff_weights descriptors are allowed
  2123. /// to be initialized with #dnnl_format_tag_any or with format_kind
  2124. /// set to #dnnl_format_kind_any.
  2125. ///
  2126. /// @param primitive_desc Output primitive descriptor.
  2127. /// @param engine Engine to use.
  2128. /// @param src_desc Source memory descriptor.
  2129. /// @param weights_desc Alpha parameters (weights) memory descriptor.
  2130. /// @param diff_src_desc Diff source memory descriptor.
  2131. /// @param diff_weights_desc Diff alpha parameters memory descriptor.
  2132. /// @param diff_dst_desc Diff destination memory descriptor.
  2133. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2134. /// primitive.
  2135. /// @param attr Primitive attributes (can be NULL).
  2136. /// @returns #dnnl_success on success and a status describing the error
  2137. /// otherwise.
  2138. dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
  2139. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2140. const_dnnl_memory_desc_t src_desc,
  2141. const_dnnl_memory_desc_t weights_desc,
  2142. const_dnnl_memory_desc_t diff_src_desc,
  2143. const_dnnl_memory_desc_t diff_weights_desc,
  2144. const_dnnl_memory_desc_t diff_dst_desc,
  2145. const_dnnl_primitive_desc_t hint_fwd_pd,
  2146. const_dnnl_primitive_attr_t attr);
  2147. /// @} dnnl_api_prelu
  2148. /// @addtogroup dnnl_api_lrn
  2149. /// @{
  2150. /// Creates a primitive descriptor for an LRN forward propagation primitive.
  2151. ///
  2152. /// @param primitive_desc Output primitive descriptor.
  2153. /// @param engine Engine to use.
  2154. /// @param prop_kind Propagation kind. Possible values are
  2155. /// #dnnl_forward_training and #dnnl_forward_inference.
  2156. /// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
  2157. /// #dnnl_lrn_within_channel.
  2158. /// @param src_desc Source memory descriptor.
  2159. /// @param dst_desc Destination memory descriptor.
  2160. /// @param local_size Regularization local size.
  2161. /// @param alpha The alpha regularization parameter.
  2162. /// @param beta The beta regularization parameter.
  2163. /// @param k The k regularization parameter.
  2164. /// @param attr Primitive attributes (can be NULL).
  2165. /// @returns #dnnl_success on success and a status describing the error
  2166. /// otherwise.
  2167. dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
  2168. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2169. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  2170. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
  2171. dnnl_dim_t local_size, float alpha, float beta, float k,
  2172. const_dnnl_primitive_attr_t attr);
  2173. /// Creates a primitive descriptor for an LRN backward propagation primitive.
  2174. ///
  2175. /// @param primitive_desc Output primitive descriptor.
  2176. /// @param engine Engine to use.
  2177. /// @param alg_kind LRN algorithm kind: either #dnnl_lrn_across_channels or
  2178. /// #dnnl_lrn_within_channel.
  2179. /// @param diff_src_desc Diff source memory descriptor.
  2180. /// @param diff_dst_desc Diff destination memory descriptor.
  2181. /// @param src_desc Source memory descriptor.
  2182. /// @param local_size Regularization local size.
  2183. /// @param alpha The alpha regularization parameter.
  2184. /// @param beta The beta regularization parameter.
  2185. /// @param k The k regularization parameter.
  2186. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2187. /// primitive.
  2188. /// @param attr Primitive attributes (can be NULL).
  2189. /// @returns #dnnl_success on success and a status describing the error
  2190. /// otherwise.
  2191. dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
  2192. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2193. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
  2194. const_dnnl_memory_desc_t diff_dst_desc,
  2195. const_dnnl_memory_desc_t src_desc, dnnl_dim_t local_size, float alpha,
  2196. float beta, float k, const_dnnl_primitive_desc_t hint_fwd_pd,
  2197. const_dnnl_primitive_attr_t attr);
  2198. /// @} dnnl_api_lrn
  2199. /// @addtogroup dnnl_api_batch_normalization
  2200. /// @{
  2201. /// Creates a primitive descriptor for a batch normalization forward propagation
  2202. /// primitive.
  2203. ///
  2204. /// @note
  2205. /// In-place operation is supported: the dst can refer to the same memory
  2206. /// as the src.
  2207. ///
  2208. /// @param primitive_desc Output primitive descriptor.
  2209. /// @param engine Engine to use.
  2210. /// @param prop_kind Propagation kind. Possible values are
  2211. /// #dnnl_forward_training and #dnnl_forward_inference.
  2212. /// @param src_desc Source memory descriptor.
  2213. /// @param dst_desc Destination memory descriptor.
  2214. /// @param epsilon Batch normalization epsilon parameter.
  2215. /// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
  2216. /// @param attr Primitive attributes (can be NULL).
  2217. /// @returns #dnnl_success on success and a status describing the error
  2218. /// otherwise.
  2219. dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
  2220. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2221. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2222. const_dnnl_memory_desc_t dst_desc, float epsilon, unsigned flags,
  2223. const_dnnl_primitive_attr_t attr);
  2224. /// Creates a primitive descriptor for a batch normalization backward
  2225. /// propagation primitive.
  2226. ///
  2227. /// @note
  2228. /// In-place operation is supported: the diff_dst can refer to the same
  2229. /// memory as the diff_src.
  2230. ///
  2231. /// @param primitive_desc Output primitive descriptor.
  2232. /// @param engine Engine to use.
  2233. /// @param prop_kind Propagation kind. Possible values are
  2234. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2235. /// computed in this case).
  2236. /// @param diff_src_desc Diff source memory descriptor.
  2237. /// @param diff_dst_desc Diff destination memory descriptor.
  2238. /// @param src_desc Source memory descriptor.
  2239. /// @param epsilon Batch normalization epsilon parameter.
  2240. /// @param flags Batch normalization flags (@ref dnnl_normalization_flags_t).
  2241. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2242. /// primitive.
  2243. /// @param attr Primitive attributes (can be NULL).
  2244. /// @returns #dnnl_success on success and a status describing the error
  2245. /// otherwise.
  2246. dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
  2247. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2248. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2249. const_dnnl_memory_desc_t diff_dst_desc,
  2250. const_dnnl_memory_desc_t src_desc, float epsilon, unsigned flags,
  2251. const_dnnl_primitive_desc_t hint_fwd_pd,
  2252. const_dnnl_primitive_attr_t attr);
  2253. /// @} dnnl_api_batch_normalization
  2254. /// @addtogroup dnnl_api_group_normalization
  2255. /// @{
  2256. /// Creates a primitive descriptor for a group normalization forward propagation
  2257. /// primitive.
  2258. ///
  2259. /// @note
  2260. /// In-place operation is supported: the dst can refer to the same memory
  2261. /// as the src.
  2262. ///
  2263. /// @param primitive_desc Output primitive descriptor.
  2264. /// @param engine Engine to use.
  2265. /// @param prop_kind Propagation kind. Possible values are
  2266. /// #dnnl_forward_training and #dnnl_forward_inference.
  2267. /// @param src_desc Source memory descriptor.
  2268. /// @param dst_desc Destination memory descriptor.
  2269. /// @param groups Group normalization groups parameter.
  2270. /// @param epsilon Group normalization epsilon parameter.
  2271. /// @param flags Group normalization flags (@ref dnnl_normalization_flags_t).
  2272. /// @param attr Primitive attributes (can be NULL).
  2273. /// @returns #dnnl_success on success and a status describing the error
  2274. /// otherwise.
  2275. dnnl_status_t DNNL_API dnnl_group_normalization_forward_primitive_desc_create(
  2276. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2277. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2278. const_dnnl_memory_desc_t dst_desc, dnnl_dim_t groups, float epsilon,
  2279. unsigned flags, const_dnnl_primitive_attr_t attr);
  2280. /// Creates a primitive descriptor for a group normalization backward
  2281. /// propagation primitive.
  2282. ///
  2283. /// @note
  2284. /// In-place operation is supported: the diff_dst can refer to the same
  2285. /// memory as the diff_src.
  2286. ///
  2287. /// @param primitive_desc Output primitive descriptor.
  2288. /// @param engine Engine to use.
  2289. /// @param prop_kind Propagation kind. Possible values are
  2290. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2291. /// computed in this case).
  2292. /// @param diff_src_desc Diff source memory descriptor.
  2293. /// @param diff_dst_desc Diff destination memory descriptor.
  2294. /// @param src_desc Source memory descriptor.
  2295. /// @param groups Group normalization groups parameter.
  2296. /// @param epsilon Group normalization epsilon parameter.
  2297. /// @param flags Group normalization flags (@ref dnnl_normalization_flags_t).
  2298. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2299. /// primitive.
  2300. /// @param attr Primitive attributes (can be NULL).
  2301. /// @returns #dnnl_success on success and a status describing the error
  2302. /// otherwise.
  2303. dnnl_status_t DNNL_API dnnl_group_normalization_backward_primitive_desc_create(
  2304. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2305. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2306. const_dnnl_memory_desc_t diff_dst_desc,
  2307. const_dnnl_memory_desc_t src_desc, dnnl_dim_t groups, float epsilon,
  2308. unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
  2309. const_dnnl_primitive_attr_t attr);
  2310. /// @} dnnl_api_group_normalization
  2311. /// @addtogroup dnnl_api_layer_normalization
  2312. /// @{
  2313. /// Creates a primitive descriptor for a layer normalization forward propagation
  2314. /// primitive.
  2315. ///
  2316. /// @note
  2317. /// In-place operation is supported: the dst can refer to the same memory
  2318. /// as the src.
  2319. ///
  2320. /// @param primitive_desc Output primitive descriptor.
  2321. /// @param engine Engine to use.
  2322. /// @param prop_kind Propagation kind. Possible values are
  2323. /// #dnnl_forward_training and #dnnl_forward_inference.
  2324. /// @param src_desc Source memory descriptor.
  2325. /// @param dst_desc Destination memory descriptor.
  2326. /// @param stat_desc Memory descriptor for mean and variance. If this
  2327. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2328. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2329. /// descriptor for stats is derived from @p src_desc by removing the last
  2330. /// dimension.
  2331. /// @param epsilon Layer normalization epsilon parameter.
  2332. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2333. /// @param attr Primitive attributes (can be NULL).
  2334. /// @returns #dnnl_success on success and a status describing the error
  2335. /// otherwise.
  2336. dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
  2337. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2338. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2339. const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
  2340. float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr);
  2341. /// Creates a primitive descriptor for a layer normalization backward
  2342. /// propagation primitive.
  2343. ///
  2344. /// @note
  2345. /// In-place operation is supported: the diff_dst can refer to the same
  2346. /// memory as the diff_src.
  2347. ///
  2348. /// @param primitive_desc Output primitive descriptor.
  2349. /// @param engine Engine to use.
  2350. /// @param prop_kind Propagation kind. Possible values are
  2351. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2352. /// computed in this case).
  2353. /// @param diff_src_desc Diff source memory descriptor.
  2354. /// @param diff_dst_desc Diff destination memory descriptor.
  2355. /// @param src_desc Source memory descriptor.
  2356. /// @param stat_desc Memory descriptor for mean and variance. If this
  2357. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2358. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2359. /// descriptor for stats is derived from @p src_desc by removing the last
  2360. /// dimension.
  2361. /// @param epsilon Layer normalization epsilon parameter.
  2362. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2363. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2364. /// primitive.
  2365. /// @param attr Primitive attributes (can be NULL).
  2366. /// @returns #dnnl_success on success and a status describing the error
  2367. /// otherwise.
  2368. dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
  2369. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2370. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2371. const_dnnl_memory_desc_t diff_dst_desc,
  2372. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
  2373. float epsilon, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
  2374. const_dnnl_primitive_attr_t attr);
  2375. /// Creates a primitive descriptor for a layer normalization forward propagation
  2376. /// primitive with a user-provided data type for the scale and shift
  2377. /// memory objects.
  2378. ///
  2379. /// @note
  2380. /// In-place operation is supported: the dst can refer to the same memory
  2381. /// as the src.
  2382. ///
  2383. /// @param primitive_desc Output primitive descriptor.
  2384. /// @param engine Engine to use.
  2385. /// @param prop_kind Propagation kind. Possible values are
  2386. /// #dnnl_forward_training and #dnnl_forward_inference.
  2387. /// @param src_desc Source memory descriptor.
  2388. /// @param dst_desc Destination memory descriptor.
  2389. /// @param stat_desc Memory descriptor for mean and variance. If this
  2390. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2391. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2392. /// descriptor for stats is derived from @p src_desc by removing the last
  2393. /// dimension.
  2394. /// @param scale_shift_data_type Data type of scale and shift memory. If neither
  2395. /// the scale nor the shift flag is specified, the parameter is ignored.
  2396. /// @param epsilon Layer normalization epsilon parameter.
  2397. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2398. /// @param attr Primitive attributes (can be NULL).
  2399. /// @returns #dnnl_success on success and a status describing the error
  2400. /// otherwise.
  2401. dnnl_status_t DNNL_API
  2402. dnnl_layer_normalization_forward_primitive_desc_create_v2(
  2403. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2404. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2405. const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
  2406. dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags,
  2407. const_dnnl_primitive_attr_t attr);
  2408. /// Creates a primitive descriptor for a layer normalization backward
  2409. /// propagation primitive with a user-provided data type for the
  2410. /// scale and shift memory objects.
  2411. ///
  2412. /// @note
  2413. /// In-place operation is supported: the diff_dst can refer to the same
  2414. /// memory as the diff_src.
  2415. ///
  2416. /// @param primitive_desc Output primitive descriptor.
  2417. /// @param engine Engine to use.
  2418. /// @param prop_kind Propagation kind. Possible values are
  2419. /// #dnnl_backward_data and #dnnl_backward (diffs for all parameters are
  2420. /// computed in this case).
  2421. /// @param diff_src_desc Diff source memory descriptor.
  2422. /// @param diff_dst_desc Diff destination memory descriptor.
  2423. /// @param src_desc Source memory descriptor.
  2424. /// @param stat_desc Memory descriptor for mean and variance. If this
  2425. /// parameter is NULL, a zero memory descriptor, or a memory descriptor
  2426. /// with format_kind set to #dnnl_format_kind_undef, then the memory
  2427. /// descriptor for stats is derived from @p src_desc by removing the last
  2428. /// dimension.
  2429. /// @param diff_scale_shift_data_type Data type of diff scale and shift memory.
  2430. /// If neither the scale nor the shift flag is specified, the parameter is ignored.
  2431. /// @param scale_shift_data_type Data type of scale and shift memory. If neither
  2432. /// the scale nor the shift flag is specified, the parameter is ignored.
  2433. /// @param epsilon Layer normalization epsilon parameter.
  2434. /// @param flags Layer normalization flags (@ref dnnl_normalization_flags_t).
  2435. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2436. /// primitive.
  2437. /// @param attr Primitive attributes (can be NULL).
  2438. /// @returns #dnnl_success on success and a status describing the error
  2439. /// otherwise.
  2440. dnnl_status_t DNNL_API
  2441. dnnl_layer_normalization_backward_primitive_desc_create_v2(
  2442. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2443. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
  2444. const_dnnl_memory_desc_t diff_dst_desc,
  2445. const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
  2446. dnnl_data_type_t diff_scale_shift_data_type,
  2447. dnnl_data_type_t scale_shift_data_type, float epsilon, unsigned flags,
  2448. const_dnnl_primitive_desc_t hint_fwd_pd,
  2449. const_dnnl_primitive_attr_t attr);
  2450. /// @} dnnl_api_layer_normalization
  2451. /// @addtogroup dnnl_api_inner_product
  2452. /// @{
  2453. /// Creates a primitive descriptor for an inner product forward propagation
  2454. /// primitive.
  2455. ///
  2456. /// @note
  2457. /// Memory descriptors can be initialized with
  2458. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2459. ///
  2460. /// @param primitive_desc Output primitive descriptor.
  2461. /// @param engine Engine to use.
  2462. /// @param prop_kind Propagation kind. Possible values are
  2463. /// #dnnl_forward_training and #dnnl_forward_inference.
  2464. /// @param src_desc Source memory descriptor.
  2465. /// @param weights_desc Weights memory descriptor.
  2466. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  2467. /// descriptor, or a memory descriptor with format_kind set to
  2468. /// #dnnl_format_kind_undef disables the bias term.
  2469. /// @param dst_desc Destination memory descriptor.
  2470. /// @param attr Primitive attributes (can be NULL).
  2471. /// @returns #dnnl_success on success and a status describing the error
  2472. /// otherwise.
  2473. dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
  2474. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2475. dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
  2476. const_dnnl_memory_desc_t weights_desc,
  2477. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  2478. const_dnnl_primitive_attr_t attr);
  2479. /// Creates a primitive descriptor for an inner product backward propagation
  2480. /// primitive.
  2481. ///
  2482. /// @note
  2483. /// Memory descriptors can be initialized with
  2484. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2485. ///
  2486. /// @param primitive_desc Output primitive descriptor.
  2487. /// @param engine Engine to use.
  2488. /// @param diff_src_desc Diff source memory descriptor.
  2489. /// @param weights_desc Weights memory descriptor.
  2490. /// @param diff_dst_desc Diff destination memory descriptor.
  2491. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2492. /// primitive.
  2493. /// @param attr Primitive attributes (can be NULL).
  2494. /// @returns #dnnl_success on success and a status describing the error
  2495. /// otherwise.
  2496. dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
  2497. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2498. const_dnnl_memory_desc_t diff_src_desc,
  2499. const_dnnl_memory_desc_t weights_desc,
  2500. const_dnnl_memory_desc_t diff_dst_desc,
  2501. const_dnnl_primitive_desc_t hint_fwd_pd,
  2502. const_dnnl_primitive_attr_t attr);
  2503. /// Creates a primitive descriptor for an inner product weights gradient
  2504. /// primitive.
  2505. ///
  2506. /// @note
  2507. /// Memory descriptors can be initialized with
  2508. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2509. ///
  2510. /// @param primitive_desc Output primitive descriptor.
  2511. /// @param engine Engine to use.
  2512. /// @param src_desc Source memory descriptor.
  2513. /// @param diff_weights_desc Diff weights memory descriptor.
  2514. /// @param diff_bias_desc Diff bias memory descriptor. Passing NULL, a zero
  2515. /// memory descriptor, or a memory descriptor with format_kind set to
  2516. /// #dnnl_format_kind_undef disables the bias term.
  2517. /// @param diff_dst_desc Diff destination memory descriptor.
  2518. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  2519. /// primitive.
  2520. /// @param attr Primitive attributes (can be NULL).
  2521. /// @returns #dnnl_success on success and a status describing the error
  2522. /// otherwise.
  2523. dnnl_status_t DNNL_API
  2524. dnnl_inner_product_backward_weights_primitive_desc_create(
  2525. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2526. const_dnnl_memory_desc_t src_desc,
  2527. const_dnnl_memory_desc_t diff_weights_desc,
  2528. const_dnnl_memory_desc_t diff_bias_desc,
  2529. const_dnnl_memory_desc_t diff_dst_desc,
  2530. const_dnnl_primitive_desc_t hint_fwd_pd,
  2531. const_dnnl_primitive_attr_t attr);
  2532. /// @} dnnl_api_inner_product
  2533. /// @addtogroup dnnl_api_attributes
  2534. /// @{
  2535. /// Sets quantization scale and shift parameters for RNN data tensors.
  2536. ///
  2537. /// For performance reasons, the low-precision configuration of the RNN
  2538. /// primitives expects input activations to have the unsigned 8-bit integer
  2539. /// data type. The scale and shift parameters are used to quantize
  2540. /// floating-point data to unsigned integer and must be passed to the RNN
  2541. /// primitive using attributes.
  2542. ///
  2543. /// The quantization formula is `scale * data + shift`.
  2544. ///
  2545. /// @note
  2546. /// Quantization scale and shift are common for src_layer, src_iter,
  2547. /// dst_iter, and dst_layer.
  2548. ///
  2549. /// Example usage:
  2550. /// @code
  2551. /// // RNN parameters
  2552. /// int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
  2553. /// // Activations quantization parameters
  2554. /// float scale = 63.f, shift = 64.f;
  2555. ///
  2556. /// dnnl_primitive_attr_t rnn_attr;
  2557. /// // Create default attributes
  2558. /// dnnl_primitive_attr_create(&rnn_attr);
  2559. ///
  2560. /// // Set scale and shift for int8 quantization of activation
  2561. /// dnnl_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
  2562. ///
  2563. /// // Create an RNN primitive descriptor.
  2564. /// dnnl_primitive_desc_t rnn_pd;
  2565. /// dnnl_vanilla_rnn_forward_primitive_desc_create(&rnn_pd,
  2566. /// engine, /* arguments */, rnn_attr);
  2567. /// @endcode
  2568. ///
  2569. /// @param attr Primitive attributes.
  2570. /// @param scale The value to scale the data by.
  2571. /// @param shift The value to shift the data by.
  2572. /// @returns #dnnl_success on success and a status describing the error
  2573. /// otherwise.
  2574. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(
  2575. dnnl_primitive_attr_t attr, const float scale, const float shift);
  2576. /// Returns the quantization scale and shift parameters for RNN data tensors.
  2577. ///
  2578. /// @note
  2579. /// Quantization scale and shift are common for src_layer, src_iter,
  2580. /// dst_iter, and dst_layer.
  2581. ///
  2582. /// @param attr Primitive attributes.
  2583. /// @param scale Output quantization scale (the value the data is scaled by).
  2584. /// @param shift Output quantization shift (the value the data is shifted by).
  2585. /// @returns #dnnl_success on success and a status describing the error
  2586. /// otherwise.
  2587. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(
  2588. const_dnnl_primitive_attr_t attr, float *scale, float *shift);
  2589. /// Sets quantization scaling factors for RNN weights tensors. The
  2590. /// low-precision configuration of the RNN primitives expects input weights to
  2591. /// use the signed 8-bit integer data type. The scaling factors are used to
  2592. /// quantize floating-point data to signed integer and must be passed to RNN
  2593. /// primitives using attributes.
  2594. ///
  2595. /// @note
  2596. /// The dimension order is always native and does not depend on the actual
  2597. /// layout used. For example, five-dimensional weights always have (l, d,
  2598. /// i, g, o) logical dimension ordering.
  2599. ///
  2600. /// @note
  2601. /// Quantization scales are common for weights_layer and weights_iter.
  2602. ///
  2603. /// @param attr Primitive attributes.
  2604. /// @param count Number of elements in the @p scales array.
  2605. /// @param mask Scaling factors correspondence mask that defines the
  2606. /// correspondence between the weights tensor dimensions and the @p
  2607. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2608. /// factor should be used for each index along that dimension. Set the
  2609. /// mask to 0 to use a common scaling factor for the whole weights
  2610. /// tensor.
  2611. /// @param scales Array of weights scaling factors that must contain @p count
  2612. /// values and the following equality must hold:
  2613. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2614. /// Violations can only be detected when the attributes are used to create
  2615. /// a primitive descriptor.
  2616. /// @returns #dnnl_success on success and a status describing the error
  2617. /// otherwise.
  2618. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(
  2619. dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
  2620. const float *scales);
  2621. /// Returns the quantization scaling factors for RNN weights tensors.
  2622. ///
  2623. /// @param attr Primitive attributes.
  2624. /// @param count Output number of elements in the @p scales array.
  2625. /// @param mask Output scaling factors correspondence mask that defines the
  2626. /// correspondence between the weights tensor dimensions and the @p
  2627. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2628. /// factor is used for each index along that dimension. A mask of 0
  2629. /// means a common scaling factor is used for the whole weights
  2630. /// tensor.
  2631. /// @param scales Output pointer to the array of weights scaling factors; it
  2632. /// contains @p count values and the following equality holds:
  2633. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2634. /// @returns #dnnl_success on success and a status describing the error
  2635. /// otherwise.
  2636. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(
  2637. const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
  2638. const float **scales);
  2639. /// Sets quantization scaling factors for RNN projection weights tensors. The
  2640. /// low-precision configuration of the RNN primitives expects input weights to
  2641. /// use the signed 8-bit integer data type. The scaling factors are used to
  2642. /// quantize floating-point data to signed integer and must be passed to RNN
  2643. /// primitives using attributes.
  2644. ///
  2645. /// @note
  2646. /// The dimension order is always native and does not depend on the actual
  2647. /// layout used. For example, five-dimensional weights always have (l, d,
  2648. /// i, g, o) logical dimension ordering.
  2649. ///
  2650. /// @param attr Primitive attributes.
  2651. /// @param count Number of elements in the @p scales array.
  2652. /// @param mask Scaling factors correspondence mask that defines the
  2653. /// correspondence between the weights tensor dimensions and the @p
  2654. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2655. /// factor should be used for each index along that dimension. Set the
  2656. /// mask to 0 to use a common scaling factor for the whole weights
  2657. /// tensor.
  2658. /// @param scales Array of weights scaling factors that must contain @p count
  2659. /// values and the following equality must hold:
  2660. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2661. /// Violations can only be detected when the attributes are used to create
  2662. /// a primitive descriptor.
  2663. /// @returns #dnnl_success on success and a status describing the error
  2664. /// otherwise.
  2665. dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(
  2666. dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask,
  2667. const float *scales);
  2668. /// Returns the quantization scaling factors for RNN projection weights tensors.
  2669. ///
  2670. /// @param attr Primitive attributes.
  2671. /// @param count Output number of elements in the @p scales array.
  2672. /// @param mask Output scaling factors correspondence mask that defines the
  2673. /// correspondence between the weights tensor dimensions and the @p
  2674. /// scales vector. The set i-th bit indicates that a dedicated scaling
  2675. /// factor is used for each index along that dimension. A mask of 0
  2676. /// means that a common scaling factor is used for the whole weights
  2677. /// tensor.
  2678. /// @param scales Output pointer to the array of scaling factors, which
  2679. /// contains @p count values; the following equality is expected to hold:
  2680. /// \f[count = \prod\limits_{d \in mask} weights.dims[d].\f]
  2681. /// @returns #dnnl_success on success and a status describing the error
  2682. /// otherwise.
  2683. dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(
  2684. const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask,
  2685. const float **scales);
  2686. /// @} dnnl_api_attributes
  2687. /// @addtogroup dnnl_api_rnn
  2688. /// @{
  2689. /// Creates a primitive descriptor for a vanilla RNN forward propagation
  2690. /// primitive.
  2691. ///
  2692. /// The following arguments may either be @c NULL or point to a zero memory
  2693. /// descriptor:
  2694. /// - @p src_iter_desc,
  2695. /// - @p bias_desc,
  2696. /// - @p dst_iter_desc.
  2697. ///
  2698. /// This would then indicate that the RNN forward propagation primitive should
  2699. /// not use them and should default to zero values instead.
  2700. ///
  2701. /// @note
  2702. /// All memory descriptors can be initialized with
  2703. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2704. ///
  2705. /// @param primitive_desc Output primitive descriptor.
  2706. /// @param engine Engine to use.
  2707. /// @param prop_kind Propagation kind. Possible values are
  2708. /// #dnnl_forward_training and #dnnl_forward_inference.
  2709. /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
  2710. /// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
  2711. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2712. /// info.
  2713. /// @param src_layer_desc Memory descriptor for the input vector.
  2714. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2715. /// state vector.
  2716. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2717. /// layer input.
  2718. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2719. /// recurrent input.
  2720. /// @param bias_desc Bias memory descriptor.
  2721. /// @param dst_layer_desc Memory descriptor for the output vector.
  2722. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2723. /// state vector.
  2724. /// @param flags Unused.
  2725. /// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
  2726. /// @param beta Unused.
  2727. /// @param attr Primitive attributes (can be @c NULL).
  2728. /// @returns #dnnl_success on success and a status describing the error
  2729. /// otherwise.
  2730. dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
  2731. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2732. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t activation,
  2733. dnnl_rnn_direction_t direction,
  2734. const_dnnl_memory_desc_t src_layer_desc,
  2735. const_dnnl_memory_desc_t src_iter_desc,
  2736. const_dnnl_memory_desc_t weights_layer_desc,
  2737. const_dnnl_memory_desc_t weights_iter_desc,
  2738. const_dnnl_memory_desc_t bias_desc,
  2739. const_dnnl_memory_desc_t dst_layer_desc,
  2740. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha,
  2741. float beta, const_dnnl_primitive_attr_t attr);
  2742. /// Creates a primitive descriptor for a vanilla RNN backward propagation
  2743. /// primitive.
  2744. ///
  2745. /// The following arguments may either be @c NULL or point to a zero memory
  2746. /// descriptor:
  2747. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  2748. /// - @p bias_desc together with @p diff_bias_desc,
  2749. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  2750. ///
  2751. /// This would then indicate that the RNN backward propagation primitive should
  2752. /// not use the respective data and should use zero values instead.
  2753. ///
  2754. /// @note
  2755. /// All memory descriptors can be initialized with
  2756. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2757. ///
  2758. /// @param primitive_desc Output primitive descriptor.
  2759. /// @param engine Engine to use.
  2760. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  2761. /// @param activation Activation kind. Possible values are #dnnl_eltwise_relu,
  2762. /// #dnnl_eltwise_tanh or #dnnl_eltwise_logistic.
  2763. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2764. /// info.
  2765. /// @param src_layer_desc Memory descriptor for the input vector.
  2766. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2767. /// state vector.
  2768. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2769. /// layer input.
  2770. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2771. /// recurrent input.
  2772. /// @param bias_desc Bias memory descriptor.
  2773. /// @param dst_layer_desc Memory descriptor for the output vector.
  2774. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2775. /// state vector.
  2776. /// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
  2777. /// @param diff_src_iter_desc Memory descriptor for the diff of the input
  2778. /// recurrent hidden state vector.
  2779. /// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
  2780. /// applied to the layer input.
  2781. /// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
  2782. /// applied to the recurrent input.
  2783. /// @param diff_bias_desc Diff bias memory descriptor.
  2784. /// @param diff_dst_layer_desc Memory descriptor for the diff of the output
  2785. /// vector.
  2786. /// @param diff_dst_iter_desc Memory descriptor for the diff of the output
  2787. /// recurrent hidden state vector.
  2788. /// @param flags Unused.
  2789. /// @param alpha Negative slope if activation is #dnnl_eltwise_relu.
  2790. /// @param beta Unused.
  2791. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  2792. /// propagation primitive.
  2793. /// @param attr Primitive attributes (can be @c NULL).
  2794. /// @returns #dnnl_success on success and a status describing the error
  2795. /// otherwise.
  2796. dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
  2797. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2798. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t activation,
  2799. dnnl_rnn_direction_t direction,
  2800. const_dnnl_memory_desc_t src_layer_desc,
  2801. const_dnnl_memory_desc_t src_iter_desc,
  2802. const_dnnl_memory_desc_t weights_layer_desc,
  2803. const_dnnl_memory_desc_t weights_iter_desc,
  2804. const_dnnl_memory_desc_t bias_desc,
  2805. const_dnnl_memory_desc_t dst_layer_desc,
  2806. const_dnnl_memory_desc_t dst_iter_desc,
  2807. const_dnnl_memory_desc_t diff_src_layer_desc,
  2808. const_dnnl_memory_desc_t diff_src_iter_desc,
  2809. const_dnnl_memory_desc_t diff_weights_layer_desc,
  2810. const_dnnl_memory_desc_t diff_weights_iter_desc,
  2811. const_dnnl_memory_desc_t diff_bias_desc,
  2812. const_dnnl_memory_desc_t diff_dst_layer_desc,
  2813. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  2814. float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd,
  2815. const_dnnl_primitive_attr_t attr);
  2816. /// Creates a primitive descriptor for an LSTM forward propagation primitive.
  2817. ///
  2818. /// The following arguments may either be @c NULL or point to a zero memory
  2819. /// descriptor:
  2820. /// - @p src_iter_desc together with @p src_iter_c_desc,
  2821. /// - @p weights_peephole_desc,
  2822. /// - @p bias_desc,
  2823. /// - @p dst_iter_desc together with @p dst_iter_c_desc.
  2824. ///
  2825. /// This would then indicate that the LSTM forward propagation primitive should
  2826. /// not use them and should default to zero values instead.
  2827. ///
  2828. /// The @p weights_projection_desc could either be @c NULL or point to a zero
  2829. /// memory descriptor. This would then indicate that the LSTM does not have a
  2830. /// recurrent projection layer.
  2831. ///
  2832. /// @note
  2833. /// All memory descriptors can be initialized with #dnnl_format_tag_any or
  2834. /// with format_kind set to #dnnl_format_kind_any.
  2835. ///
  2836. /// @param primitive_desc Output primitive descriptor.
  2837. /// @param engine Engine to use.
  2838. /// @param prop_kind Propagation kind. Possible values are
  2839. /// #dnnl_forward_training and #dnnl_forward_inference.
  2840. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2841. /// info.
  2842. /// @param src_layer_desc Memory descriptor for the input vector.
  2843. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2844. /// state vector.
  2845. /// @param src_iter_c_desc Memory descriptor for the input recurrent cell
  2846. /// state vector.
  2847. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2848. /// layer input.
  2849. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2850. /// recurrent input.
  2851. /// @param weights_peephole_desc Memory descriptor for the weights applied to
  2852. /// the cell states (according to the Peephole LSTM formula).
  2853. /// @param weights_projection_desc Memory descriptor for the weights applied to
  2854. /// the hidden states to get the recurrent projection (according to the
  2855. /// Projection LSTM formula).
  2856. /// @param bias_desc Bias memory descriptor.
  2857. /// @param dst_layer_desc Memory descriptor for the output vector.
  2858. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2859. /// state vector.
  2860. /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
  2861. /// state vector.
  2862. /// @param flags Unused.
  2863. /// @param attr Primitive attributes (can be @c NULL).
  2864. /// @returns #dnnl_success on success and a status describing the error
  2865. /// otherwise.
  2866. dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
  2867. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2868. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2869. const_dnnl_memory_desc_t src_layer_desc,
  2870. const_dnnl_memory_desc_t src_iter_desc,
  2871. const_dnnl_memory_desc_t src_iter_c_desc,
  2872. const_dnnl_memory_desc_t weights_layer_desc,
  2873. const_dnnl_memory_desc_t weights_iter_desc,
  2874. const_dnnl_memory_desc_t weights_peephole_desc,
  2875. const_dnnl_memory_desc_t weights_projection_desc,
  2876. const_dnnl_memory_desc_t bias_desc,
  2877. const_dnnl_memory_desc_t dst_layer_desc,
  2878. const_dnnl_memory_desc_t dst_iter_desc,
  2879. const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
  2880. const_dnnl_primitive_attr_t attr);
  2881. /// Creates a primitive descriptor for an LSTM backward propagation primitive.
  2882. ///
  2883. /// The following arguments may either be @c NULL or point to a zero memory
  2884. /// descriptor:
  2885. /// - @p src_iter_desc together with @p src_iter_c_desc, @p diff_src_iter_desc,
  2886. /// and @p diff_src_iter_c_desc,
  2887. /// - @p weights_peephole_desc together with @p diff_weights_peephole_desc,
  2888. /// - @p bias_desc together with @p diff_bias_desc,
  2889. /// - @p dst_iter_desc together with @p dst_iter_c_desc, @p diff_dst_iter_desc,
  2890. /// and @p diff_dst_iter_c_desc.
  2891. ///
  2892. /// This would then indicate that the LSTM backward propagation primitive
  2893. /// should not use them and should default to zero values instead.
  2894. ///
  2895. /// The @p weights_projection_desc together with @p
  2896. /// diff_weights_projection_desc could either be @c NULL or point to a zero
  2897. /// memory descriptor. This would then indicate that the LSTM does not have a
  2898. /// recurrent projection layer.
  2899. ///
  2900. /// @note
  2901. /// All memory descriptors can be initialized with #dnnl_format_tag_any or
  2902. /// with format_kind set to #dnnl_format_kind_any.
  2903. ///
  2904. /// @param primitive_desc Output primitive descriptor.
  2905. /// @param engine Engine to use.
  2906. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  2907. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  2908. /// info.
  2909. /// @param src_layer_desc Memory descriptor for the input vector.
  2910. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  2911. /// state vector.
  2912. /// @param src_iter_c_desc Memory descriptor for the input recurrent cell
  2913. /// state vector.
  2914. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  2915. /// layer input.
  2916. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  2917. /// recurrent input.
  2918. /// @param weights_peephole_desc Memory descriptor for the weights applied to
  2919. /// the cell states (according to the Peephole LSTM formula).
  2920. /// @param weights_projection_desc Memory descriptor for the weights applied to
  2921. /// the hidden states to get the recurrent projection (according to the
  2922. /// Projection LSTM formula).
  2923. /// @param bias_desc Bias memory descriptor.
  2924. /// @param dst_layer_desc Memory descriptor for the output vector.
  2925. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  2926. /// state vector.
  2927. /// @param dst_iter_c_desc Memory descriptor for the output recurrent cell
  2928. /// state vector.
  2929. /// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
  2930. /// @param diff_src_iter_desc Memory descriptor for the diff of the input
  2931. /// recurrent hidden state vector.
  2932. /// @param diff_src_iter_c_desc Memory descriptor for the diff of the input
  2933. /// recurrent cell state vector.
  2934. /// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
  2935. /// applied to the layer input.
  2936. /// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
  2937. /// applied to the recurrent input.
  2938. /// @param diff_weights_peephole_desc Memory descriptor for the diff of the
  2939. /// weights applied to the cell states (according to the Peephole LSTM formula).
  2940. /// @param diff_weights_projection_desc Memory descriptor for the diff of the
  2941. /// weights applied to the hidden states to get the recurrent projection
  2942. /// (according to the Projection LSTM formula).
  2943. /// @param diff_bias_desc Diff bias memory descriptor.
  2944. /// @param diff_dst_layer_desc Memory descriptor for the diff of the output
  2945. /// vector.
  2946. /// @param diff_dst_iter_desc Memory descriptor for the diff of the output
  2947. /// recurrent hidden state vector.
  2948. /// @param diff_dst_iter_c_desc Memory descriptor for the diff of the output
  2949. /// recurrent cell state vector.
  2950. /// @param flags Unused.
  2951. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  2952. /// propagation primitive.
  2953. /// @param attr Primitive attributes (can be @c NULL).
  2954. /// @returns #dnnl_success on success and a status describing the error
  2955. /// otherwise.
  2956. dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
  2957. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  2958. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  2959. const_dnnl_memory_desc_t src_layer_desc,
  2960. const_dnnl_memory_desc_t src_iter_desc,
  2961. const_dnnl_memory_desc_t src_iter_c_desc,
  2962. const_dnnl_memory_desc_t weights_layer_desc,
  2963. const_dnnl_memory_desc_t weights_iter_desc,
  2964. const_dnnl_memory_desc_t weights_peephole_desc,
  2965. const_dnnl_memory_desc_t weights_projection_desc,
  2966. const_dnnl_memory_desc_t bias_desc,
  2967. const_dnnl_memory_desc_t dst_layer_desc,
  2968. const_dnnl_memory_desc_t dst_iter_desc,
  2969. const_dnnl_memory_desc_t dst_iter_c_desc,
  2970. const_dnnl_memory_desc_t diff_src_layer_desc,
  2971. const_dnnl_memory_desc_t diff_src_iter_desc,
  2972. const_dnnl_memory_desc_t diff_src_iter_c_desc,
  2973. const_dnnl_memory_desc_t diff_weights_layer_desc,
  2974. const_dnnl_memory_desc_t diff_weights_iter_desc,
  2975. const_dnnl_memory_desc_t diff_weights_peephole_desc,
  2976. const_dnnl_memory_desc_t diff_weights_projection_desc,
  2977. const_dnnl_memory_desc_t diff_bias_desc,
  2978. const_dnnl_memory_desc_t diff_dst_layer_desc,
  2979. const_dnnl_memory_desc_t diff_dst_iter_desc,
  2980. const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
  2981. const_dnnl_primitive_desc_t hint_fwd_pd,
  2982. const_dnnl_primitive_attr_t attr);
  2983. /// Creates a primitive descriptor for a GRU forward propagation primitive.
  2984. ///
  2985. /// The following arguments may either be @c NULL or point to a zero memory
  2986. /// descriptor:
  2987. /// - @p src_iter_desc,
  2988. /// - @p bias_desc,
  2989. /// - @p dst_iter_desc.
  2990. ///
  2991. /// This would then indicate that the GRU forward propagation primitive should
  2992. /// not use them and should default to zero values instead.
  2993. ///
  2994. /// @note
  2995. /// All memory descriptors can be initialized with
  2996. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  2997. ///
  2998. /// @param primitive_desc Output primitive descriptor.
  2999. /// @param engine Engine to use.
  3000. /// @param prop_kind Propagation kind. Possible values are
  3001. /// #dnnl_forward_training and #dnnl_forward_inference.
  3002. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3003. /// info.
  3004. /// @param src_layer_desc Memory descriptor for the input vector.
  3005. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3006. /// state vector.
  3007. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3008. /// layer input.
  3009. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3010. /// recurrent input.
  3011. /// @param bias_desc Bias memory descriptor.
  3012. /// @param dst_layer_desc Memory descriptor for the output vector.
  3013. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3014. /// state vector.
  3015. /// @param flags Unused.
  3016. /// @param attr Primitive attributes (can be @c NULL).
  3017. /// @returns #dnnl_success on success and a status describing the error
  3018. /// otherwise.
  3019. dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
  3020. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3021. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3022. const_dnnl_memory_desc_t src_layer_desc,
  3023. const_dnnl_memory_desc_t src_iter_desc,
  3024. const_dnnl_memory_desc_t weights_layer_desc,
  3025. const_dnnl_memory_desc_t weights_iter_desc,
  3026. const_dnnl_memory_desc_t bias_desc,
  3027. const_dnnl_memory_desc_t dst_layer_desc,
  3028. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  3029. const_dnnl_primitive_attr_t attr);
  3030. /// Creates a primitive descriptor for a GRU backward propagation primitive.
  3031. ///
  3032. /// The following arguments may either be @c NULL or point to a zero memory
  3033. /// descriptor:
  3034. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  3035. /// - @p bias_desc together with @p diff_bias_desc,
  3036. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  3037. ///
  3038. /// This would then indicate that the GRU backward propagation primitive
  3039. /// should not use them and should default to zero values instead.
  3040. ///
  3041. /// @note
  3042. /// All memory descriptors can be initialized with
  3043. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3044. ///
  3045. /// @param primitive_desc Output primitive descriptor.
  3046. /// @param engine Engine to use.
  3047. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  3048. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3049. /// info.
  3050. /// @param src_layer_desc Memory descriptor for the input vector.
  3051. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3052. /// state vector.
  3053. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3054. /// layer input.
  3055. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3056. /// recurrent input.
  3057. /// @param bias_desc Bias memory descriptor.
  3058. /// @param dst_layer_desc Memory descriptor for the output vector.
  3059. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3060. /// state vector.
  3061. /// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
  3062. /// @param diff_src_iter_desc Memory descriptor for the diff of the input
  3063. /// recurrent hidden state vector.
  3064. /// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
  3065. /// applied to the layer input.
  3066. /// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
  3067. /// applied to the recurrent input.
  3068. /// @param diff_bias_desc Diff bias memory descriptor.
  3069. /// @param diff_dst_layer_desc Memory descriptor for the diff of the output
  3070. /// vector.
  3071. /// @param diff_dst_iter_desc Memory descriptor for the diff of the output
  3072. /// recurrent hidden state vector.
  3073. /// @param flags Unused.
  3074. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  3075. /// propagation primitive.
  3076. /// @param attr Primitive attributes (can be @c NULL).
  3077. /// @returns #dnnl_success on success and a status describing the error
  3078. /// otherwise.
  3079. dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
  3080. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3081. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3082. const_dnnl_memory_desc_t src_layer_desc,
  3083. const_dnnl_memory_desc_t src_iter_desc,
  3084. const_dnnl_memory_desc_t weights_layer_desc,
  3085. const_dnnl_memory_desc_t weights_iter_desc,
  3086. const_dnnl_memory_desc_t bias_desc,
  3087. const_dnnl_memory_desc_t dst_layer_desc,
  3088. const_dnnl_memory_desc_t dst_iter_desc,
  3089. const_dnnl_memory_desc_t diff_src_layer_desc,
  3090. const_dnnl_memory_desc_t diff_src_iter_desc,
  3091. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3092. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3093. const_dnnl_memory_desc_t diff_bias_desc,
  3094. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3095. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3096. const_dnnl_primitive_desc_t hint_fwd_pd,
  3097. const_dnnl_primitive_attr_t attr);
  3098. /// Creates a primitive descriptor for an LBR GRU forward propagation primitive.
  3099. ///
  3100. /// The following arguments may either be @c NULL or point to a zero memory
  3101. /// descriptor:
  3102. /// - @p src_iter_desc,
  3103. /// - @p bias_desc,
  3104. /// - @p dst_iter_desc.
  3105. ///
  3106. /// This would then indicate that the LBR GRU forward propagation primitive
  3107. /// should not use them and should default to zero values instead.
  3108. ///
  3109. /// @param primitive_desc Output primitive descriptor.
  3110. /// @param engine Engine to use.
  3111. /// @param prop_kind Propagation kind. Possible values are
  3112. /// #dnnl_forward_training and #dnnl_forward_inference.
  3113. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3114. /// info.
  3115. /// @param src_layer_desc Memory descriptor for the input vector.
  3116. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3117. /// state vector.
  3118. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3119. /// layer input.
  3120. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3121. /// recurrent input.
  3122. /// @param bias_desc Bias memory descriptor.
  3123. /// @param dst_layer_desc Memory descriptor for the output vector.
  3124. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3125. /// state vector.
  3126. /// @param flags Unused.
  3127. /// @param attr Primitive attributes (can be @c NULL).
  3128. /// @returns #dnnl_success on success and a status describing the error
  3129. /// otherwise.
  3130. dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
  3131. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3132. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3133. const_dnnl_memory_desc_t src_layer_desc,
  3134. const_dnnl_memory_desc_t src_iter_desc,
  3135. const_dnnl_memory_desc_t weights_layer_desc,
  3136. const_dnnl_memory_desc_t weights_iter_desc,
  3137. const_dnnl_memory_desc_t bias_desc,
  3138. const_dnnl_memory_desc_t dst_layer_desc,
  3139. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  3140. const_dnnl_primitive_attr_t attr);
  3141. /// Creates a primitive descriptor for an LBR GRU backward propagation primitive.
  3142. ///
  3143. /// The following arguments may either be @c NULL or point to a zero memory
  3144. /// descriptor:
  3145. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  3146. /// - @p bias_desc together with @p diff_bias_desc,
  3147. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  3148. ///
  3149. /// This would then indicate that the LBR GRU backward propagation primitive
  3150. /// should not use them and should default to zero values instead.
  3151. ///
  3152. /// @note
  3153. /// All memory descriptors can be initialized with
  3154. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3155. ///
  3156. /// @param primitive_desc Output primitive descriptor.
  3157. /// @param engine Engine to use.
  3158. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  3159. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3160. /// info.
  3161. /// @param src_layer_desc Memory descriptor for the input vector.
  3162. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3163. /// state vector.
  3164. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3165. /// layer input.
  3166. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3167. /// recurrent input.
  3168. /// @param bias_desc Bias memory descriptor.
  3169. /// @param dst_layer_desc Memory descriptor for the output vector.
  3170. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3171. /// state vector.
  3172. /// @param diff_src_layer_desc Memory descriptor for the diff of the input vector.
  3173. /// @param diff_src_iter_desc Memory descriptor for the diff of the input
  3174. /// recurrent hidden state vector.
  3175. /// @param diff_weights_layer_desc Memory descriptor for the diff of the weights
  3176. /// applied to the layer input.
  3177. /// @param diff_weights_iter_desc Memory descriptor for the diff of the weights
  3178. /// applied to the recurrent input.
  3179. /// @param diff_bias_desc Diff bias memory descriptor.
  3180. /// @param diff_dst_layer_desc Memory descriptor for the diff of the output
  3181. /// vector.
  3182. /// @param diff_dst_iter_desc Memory descriptor for the diff of the output
  3183. /// recurrent hidden state vector.
  3184. /// @param flags Unused.
  3185. /// @param hint_fwd_pd Primitive descriptor for the corresponding forward
  3186. /// propagation primitive.
  3187. /// @param attr Primitive attributes (can be @c NULL).
  3188. /// @returns #dnnl_success on success and a status describing the error
  3189. /// otherwise.
  3190. dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
  3191. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3192. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3193. const_dnnl_memory_desc_t src_layer_desc,
  3194. const_dnnl_memory_desc_t src_iter_desc,
  3195. const_dnnl_memory_desc_t weights_layer_desc,
  3196. const_dnnl_memory_desc_t weights_iter_desc,
  3197. const_dnnl_memory_desc_t bias_desc,
  3198. const_dnnl_memory_desc_t dst_layer_desc,
  3199. const_dnnl_memory_desc_t dst_iter_desc,
  3200. const_dnnl_memory_desc_t diff_src_layer_desc,
  3201. const_dnnl_memory_desc_t diff_src_iter_desc,
  3202. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3203. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3204. const_dnnl_memory_desc_t diff_bias_desc,
  3205. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3206. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3207. const_dnnl_primitive_desc_t hint_fwd_pd,
  3208. const_dnnl_primitive_attr_t attr);
  3209. /// Creates a primitive descriptor for an AUGRU forward propagation primitive.
  3210. ///
  3211. /// The following arguments may either be @c NULL or point to a zero memory
  3212. /// descriptor:
  3213. /// - @p src_iter_desc,
  3214. /// - @p bias_desc,
  3215. /// - @p dst_iter_desc.
  3216. ///
  3217. /// This would then indicate that the AUGRU forward propagation primitive should
  3218. /// not use them and should default to zero values instead.
  3219. ///
  3220. /// @note
  3221. /// All memory descriptors can be initialized with
  3222. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3223. ///
  3224. /// @param primitive_desc Output primitive descriptor.
  3225. /// @param engine Engine to use.
  3226. /// @param prop_kind Propagation kind. Possible values are
  3227. /// #dnnl_forward_training and #dnnl_forward_inference.
  3228. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3229. /// info.
  3230. /// @param src_layer_desc Memory descriptor for the input vector.
  3231. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3232. /// state vector.
  3233. /// @param attention_desc Memory descriptor for the attention vector.
  3234. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3235. /// layer input.
  3236. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3237. /// recurrent input.
  3238. /// @param bias_desc Bias memory descriptor.
  3239. /// @param dst_layer_desc Memory descriptor for the output vector.
  3240. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3241. /// state vector.
  3242. /// @param flags Unused.
  3243. /// @param attr Primitive attributes (can be @c NULL).
  3244. /// @returns #dnnl_success on success and a status describing the error
  3245. /// otherwise.
  3246. dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
  3247. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3248. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3249. const_dnnl_memory_desc_t src_layer_desc,
  3250. const_dnnl_memory_desc_t src_iter_desc,
  3251. const_dnnl_memory_desc_t attention_desc,
  3252. const_dnnl_memory_desc_t weights_layer_desc,
  3253. const_dnnl_memory_desc_t weights_iter_desc,
  3254. const_dnnl_memory_desc_t bias_desc,
  3255. const_dnnl_memory_desc_t dst_layer_desc,
  3256. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  3257. const_dnnl_primitive_attr_t attr);
  3258. /// Creates a primitive descriptor for AUGRU backward propagation primitive.
  3259. ///
  3260. /// The following arguments may either be @c NULL or point to a zero memory
  3261. /// descriptor:
  3262. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  3263. /// - @p bias_desc together with @p diff_bias_desc,
  3264. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  3265. ///
  3266. /// This would then indicate that the AUGRU backward propagation primitive
  3267. /// should not use them and should default to zero values instead.
  3268. ///
  3269. /// @note
  3270. /// All memory descriptors can be initialized with
  3271. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3272. ///
  3273. /// @param primitive_desc Output primitive descriptor.
  3274. /// @param engine Engine to use.
  3275. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  3276. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3277. /// info.
  3278. /// @param src_layer_desc Memory descriptor for the input vector.
  3279. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3280. /// state vector.
  3281. /// @param attention_desc Memory descriptor for the attention vector.
  3282. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3283. /// layer input.
  3284. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3285. /// recurrent input.
  3286. /// @param bias_desc Bias memory descriptor.
  3287. /// @param dst_layer_desc Memory descriptor for the output vector.
  3288. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3289. /// state vector.
  3290. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  3291. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  3292. /// hidden state vector.
  3293. /// @param diff_attention_desc Memory descriptor for the diff of attention vector.
  3294. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  3295. /// applied to the layer input.
  3296. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  3297. /// applied to the recurrent input.
  3298. /// @param diff_bias_desc Diff bias memory descriptor.
  3299. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  3300. /// vector.
  3301. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  3302. /// recurrent hidden state vector.
  3303. /// @param flags Unused.
  3304. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3305. /// primitive.
  3306. /// @param attr Primitive attributes (can be NULL).
  3307. /// @returns #dnnl_success on success and a status describing the error
  3308. /// otherwise.
  3309. dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
  3310. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3311. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3312. const_dnnl_memory_desc_t src_layer_desc,
  3313. const_dnnl_memory_desc_t src_iter_desc,
  3314. const_dnnl_memory_desc_t attention_desc,
  3315. const_dnnl_memory_desc_t weights_layer_desc,
  3316. const_dnnl_memory_desc_t weights_iter_desc,
  3317. const_dnnl_memory_desc_t bias_desc,
  3318. const_dnnl_memory_desc_t dst_layer_desc,
  3319. const_dnnl_memory_desc_t dst_iter_desc,
  3320. const_dnnl_memory_desc_t diff_src_layer_desc,
  3321. const_dnnl_memory_desc_t diff_src_iter_desc,
  3322. const_dnnl_memory_desc_t diff_attention_desc,
  3323. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3324. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3325. const_dnnl_memory_desc_t diff_bias_desc,
  3326. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3327. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3328. const_dnnl_primitive_desc_t hint_fwd_pd,
  3329. const_dnnl_primitive_attr_t attr);
  3330. /// Creates a primitive descriptor for LBR AUGRU forward propagation primitive.
  3331. ///
  3332. /// The following arguments may either be @c NULL or point to a zero memory
  3333. /// descriptor:
  3334. /// - @p src_iter_desc,
  3335. /// - @p bias_desc,
  3336. /// - @p dst_iter_desc.
  3337. ///
  3338. /// This would then indicate that the LBR AUGRU forward propagation primitive
  3339. /// should not use them and should default to zero values instead.
  3340. ///
  3341. /// @param primitive_desc Output primitive descriptor.
  3342. /// @param engine Engine to use.
  3343. /// @param prop_kind Propagation kind. Possible values are
  3344. /// #dnnl_forward_training and #dnnl_forward_inference.
  3345. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3346. /// info.
  3347. /// @param src_layer_desc Memory descriptor for the input vector.
  3348. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3349. /// state vector.
  3350. /// @param attention_desc Memory descriptor for the attention vector.
  3351. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3352. /// layer input.
  3353. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3354. /// recurrent input.
  3355. /// @param bias_desc Bias memory descriptor.
  3356. /// @param dst_layer_desc Memory descriptor for the output vector.
  3357. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3358. /// state vector.
  3359. /// @param flags Unused.
  3360. /// @param attr Primitive attributes (can be NULL).
  3361. /// @returns #dnnl_success on success and a status describing the error
  3362. /// otherwise.
  3363. dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
  3364. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3365. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3366. const_dnnl_memory_desc_t src_layer_desc,
  3367. const_dnnl_memory_desc_t src_iter_desc,
  3368. const_dnnl_memory_desc_t attention_desc,
  3369. const_dnnl_memory_desc_t weights_layer_desc,
  3370. const_dnnl_memory_desc_t weights_iter_desc,
  3371. const_dnnl_memory_desc_t bias_desc,
  3372. const_dnnl_memory_desc_t dst_layer_desc,
  3373. const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
  3374. const_dnnl_primitive_attr_t attr);
  3375. /// Creates a primitive descriptor for LBR AUGRU backward propagation primitive.
  3376. ///
  3377. /// The following arguments may either be @c NULL or point to a zero memory
  3378. /// descriptor:
  3379. /// - @p src_iter_desc together with @p diff_src_iter_desc,
  3380. /// - @p bias_desc together with @p diff_bias_desc,
  3381. /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
  3382. ///
  3383. /// This would then indicate that the LBR AUGRU backward propagation primitive
  3384. /// should not use them and should default to zero values instead.
  3385. ///
  3386. /// @note
  3387. /// All memory descriptors can be initialized with
  3388. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3389. ///
  3390. /// @param primitive_desc Output primitive descriptor.
  3391. /// @param engine Engine to use.
  3392. /// @param prop_kind Propagation kind. Must be #dnnl_backward.
  3393. /// @param direction RNN direction. See @ref dnnl_rnn_direction_t for more
  3394. /// info.
  3395. /// @param src_layer_desc Memory descriptor for the input vector.
  3396. /// @param src_iter_desc Memory descriptor for the input recurrent hidden
  3397. /// state vector.
  3398. /// @param attention_desc Memory descriptor for the attention vector.
  3399. /// @param weights_layer_desc Memory descriptor for the weights applied to the
  3400. /// layer input.
  3401. /// @param weights_iter_desc Memory descriptor for the weights applied to the
  3402. /// recurrent input.
  3403. /// @param bias_desc Bias memory descriptor.
  3404. /// @param dst_layer_desc Memory descriptor for the output vector.
  3405. /// @param dst_iter_desc Memory descriptor for the output recurrent hidden
  3406. /// state vector.
  3407. /// @param diff_src_layer_desc Memory descriptor for the diff of input vector.
  3408. /// @param diff_src_iter_desc Memory descriptor for the diff of input recurrent
  3409. /// hidden state vector.
  3410. /// @param diff_attention_desc Memory descriptor for the diff of attention vector.
  3411. /// @param diff_weights_layer_desc Memory descriptor for the diff of weights
  3412. /// applied to the layer input.
  3413. /// @param diff_weights_iter_desc Memory descriptor for the diff of weights
  3414. /// applied to the recurrent input.
  3415. /// @param diff_bias_desc Diff bias memory descriptor.
  3416. /// @param diff_dst_layer_desc Memory descriptor for the diff of output
  3417. /// vector.
  3418. /// @param diff_dst_iter_desc Memory descriptor for the diff of output
  3419. /// recurrent hidden state vector.
  3420. /// @param flags Unused.
  3421. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3422. /// primitive.
  3423. /// @param attr Primitive attributes (can be NULL).
  3424. /// @returns #dnnl_success on success and a status describing the error
  3425. /// otherwise.
  3426. dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
  3427. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3428. dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
  3429. const_dnnl_memory_desc_t src_layer_desc,
  3430. const_dnnl_memory_desc_t src_iter_desc,
  3431. const_dnnl_memory_desc_t attention_desc,
  3432. const_dnnl_memory_desc_t weights_layer_desc,
  3433. const_dnnl_memory_desc_t weights_iter_desc,
  3434. const_dnnl_memory_desc_t bias_desc,
  3435. const_dnnl_memory_desc_t dst_layer_desc,
  3436. const_dnnl_memory_desc_t dst_iter_desc,
  3437. const_dnnl_memory_desc_t diff_src_layer_desc,
  3438. const_dnnl_memory_desc_t diff_src_iter_desc,
  3439. const_dnnl_memory_desc_t diff_attention_desc,
  3440. const_dnnl_memory_desc_t diff_weights_layer_desc,
  3441. const_dnnl_memory_desc_t diff_weights_iter_desc,
  3442. const_dnnl_memory_desc_t diff_bias_desc,
  3443. const_dnnl_memory_desc_t diff_dst_layer_desc,
  3444. const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
  3445. const_dnnl_primitive_desc_t hint_fwd_pd,
  3446. const_dnnl_primitive_attr_t attr);
  3447. /// @} dnnl_api_rnn
  3448. /// @addtogroup dnnl_api_matmul
  3449. /// @{
  3450. /// Creates a primitive descriptor for a matrix multiplication primitive.
  3451. ///
  3452. /// @param primitive_desc Output primitive descriptor.
  3453. /// @param engine Engine to use.
  3454. /// @param src_desc Source memory descriptor (matrix A)
  3455. /// @param weights_desc Weights memory descriptor (matrix B)
  3456. /// @param bias_desc Bias memory descriptor. Passing NULL, a zero memory
  3457. /// descriptor, or a memory descriptor with format_kind set to
  3458. /// #dnnl_format_kind_undef disables the bias term.
  3459. /// @param dst_desc Destination memory descriptor (matrix C).
  3460. /// @param attr Primitive attributes (can be NULL).
  3461. /// @returns #dnnl_success on success and a status describing the error
  3462. /// otherwise.
  3463. dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
  3464. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3465. const_dnnl_memory_desc_t src_desc,
  3466. const_dnnl_memory_desc_t weights_desc,
  3467. const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
  3468. const_dnnl_primitive_attr_t attr);
  3469. /// @} dnnl_api_matmul
  3470. /// @addtogroup dnnl_api_resampling Resampling
  3471. /// @{
  3472. /// Creates a primitive descriptor for a resampling forward propagation
  3473. /// primitive.
  3474. ///
  3475. /// @note
  3476. /// Destination memory descriptor is allowed to be initialized with
  3477. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3478. ///
  3479. /// @param primitive_desc Output primitive descriptor.
  3480. /// @param engine Engine to use.
  3481. /// @param prop_kind Propagation kind. Possible values are
  3482. /// #dnnl_forward_training and #dnnl_forward_inference.
  3483. /// @param alg_kind resampling algorithm kind: either #dnnl_resampling_nearest,
  3484. /// or #dnnl_resampling_linear.
  3485. /// @param factors Array of scaling factors for spatial dimension.
  3486. /// @param src_desc Source memory descriptor.
  3487. /// @param dst_desc Destination memory descriptor.
  3488. /// @param attr Primitive attributes (can be NULL).
  3489. /// @returns #dnnl_success on success and a status describing the error
  3490. /// otherwise.
  3491. dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
  3492. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3493. dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
  3494. const float *factors, const_dnnl_memory_desc_t src_desc,
  3495. const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
  3496. /// Creates a primitive descriptor for a resampling backward propagation
  3497. /// primitive.
  3498. ///
  3499. /// @param primitive_desc Output primitive descriptor.
  3500. /// @param engine Engine to use.
  3501. /// @param alg_kind resamplinging algorithm kind: either
  3502. /// #dnnl_resampling_nearest, or #dnnl_resampling_linear.
  3503. /// @param diff_src_desc Diff source memory descriptor.
  3504. /// @param diff_dst_desc Diff destination memory descriptor.
  3505. /// @param factors Array of scaling factors for spatial dimension.
  3506. /// @param hint_fwd_pd Primitive descriptor for a respective forward propagation
  3507. /// primitive.
  3508. /// @param attr Primitive attributes (can be NULL).
  3509. /// @returns #dnnl_success on success and a status describing the error
  3510. /// otherwise.
  3511. ///
  3512. dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
  3513. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3514. dnnl_alg_kind_t alg_kind, const float *factors,
  3515. const_dnnl_memory_desc_t diff_src_desc,
  3516. const_dnnl_memory_desc_t diff_dst_desc,
  3517. const_dnnl_primitive_desc_t hint_fwd_pd,
  3518. const_dnnl_primitive_attr_t attr);
  3519. /// @} dnnl_api_resampling
  3520. /// @addtogroup dnnl_api_reduction Reduction
  3521. /// @{
  3522. /// Creates a primitive descriptor for a reduction primitive.
  3523. ///
  3524. /// @note
  3525. /// Destination memory descriptor is allowed to be initialized with
  3526. /// #dnnl_format_tag_any or with format_kind set to #dnnl_format_kind_any.
  3527. ///
  3528. /// @param primitive_desc Output primitive descriptor.
  3529. /// @param engine Engine to use.
  3530. /// @param alg_kind reduction algorithm kind. Possible values:
  3531. /// #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
  3532. /// #dnnl_reduction_mul, #dnnl_reduction_mean, #dnnl_reduction_norm_lp_max,
  3533. /// #dnnl_reduction_norm_lp_sum, #dnnl_reduction_norm_lp_power_p_max,
  3534. /// #dnnl_reduction_norm_lp_power_p_sum.
  3535. /// @param p Algorithm specific parameter.
  3536. /// @param eps Algorithm specific parameter.
  3537. /// @param src_desc Source memory descriptor.
  3538. /// @param dst_desc Destination memory descriptor.
  3539. /// @param attr Primitive attributes (can be NULL).
  3540. /// @returns #dnnl_success on success and a status describing the error
  3541. /// otherwise.
  3542. dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
  3543. dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
  3544. dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
  3545. const_dnnl_memory_desc_t dst_desc, float p, float eps,
  3546. const_dnnl_primitive_attr_t attr);
  3547. /// @} dnnl_api_reduction
  3548. /// @} dnnl_api_primitives
  3549. /// @addtogroup dnnl_api_primitive_cache
  3550. /// @{
  3551. /// Returns the number of primitives that can be held in the primitive cache
  3552. /// at the same time.
  3553. ///
  3554. /// @param capacity Primitive cache capacity to query. Concurrently
  3555. /// accessing @p capacity is safe.
  3556. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3557. /// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
  3558. /// success.
  3559. dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity);
  3560. /// Sets a number of primitives that can be held in the primitive cache
  3561. /// at a time.
  3562. ///
  3563. /// @param capacity Primitive cache capacity to set. If a new @p capacity is
  3564. /// less than a number of primitives that the primitive cache already has
  3565. /// then the excess entries will be evicted. Setting the @p capacity to 0
  3566. /// clears the primitive cache and disables it. Concurrently modifying
  3567. /// @p capacity is safe.
  3568. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3569. /// @p capacity value is invalid, and #dnnl_success/#dnnl::status::success on
  3570. /// success.
  3571. dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity);
  3572. /// @} dnnl_api_primitive_cache
  3573. /// @addtogroup dnnl_api_service
  3574. /// @{
  3575. /// Configures dumping of JIT-generated code.
  3576. ///
  3577. /// @note
  3578. /// This setting overrides the DNNL_JIT_DUMP environment variable.
  3579. ///
  3580. /// @param enable Flag value. Set to 0 to disable and set to 1 to enable.
  3581. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3582. /// @p flag value is invalid, and #dnnl_success/#dnnl::status::success on
  3583. /// success.
  3584. dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable);
  3585. /// Sets library profiling flags. The flags define which profilers are
  3586. /// supported.
  3587. ///
  3588. /// @note
  3589. /// This setting overrides DNNL_JIT_PROFILE environment variable.
  3590. ///
  3591. /// @sa @ref dev_guide_profilers
  3592. ///
  3593. /// @param flags Profiling flags that can contain the following bits:
  3594. /// - @ref DNNL_JIT_PROFILE_VTUNE -- integration with VTune Profiler
  3595. /// (on by default)
  3596. /// - @ref DNNL_JIT_PROFILE_LINUX_JITDUMP -- produce Linux-specific
  3597. /// jit-pid.dump output (off by default). The location of the output
  3598. /// is controlled via JITDUMPDIR environment variable or via
  3599. /// dnnl_set_jit_profiling_jitdumpdir() function.
  3600. /// - @ref DNNL_JIT_PROFILE_LINUX_PERFMAP -- produce Linux-specific
  3601. /// perf-pid.map output (off by default). The output is always placed
  3602. /// into /tmp.
  3603. ///
  3604. /// Passing @ref DNNL_JIT_PROFILE_NONE disables profiling completely.
  3605. ///
  3606. /// @returns #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the
  3607. /// @p flags value is invalid, and #dnnl_success/#dnnl::status::success on
  3608. /// success.
  3609. dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags);
  3610. /// Sets JIT dump output path. Only applicable to Linux and is only
  3611. /// used when profiling flags have DNNL_JIT_PROFILE_LINUX_PERF bit set.
  3612. ///
  3613. /// After the first JIT kernel is generated, the jitdump output will be placed
  3614. /// into temporary directory created using the mkdtemp template
  3615. /// 'dir/.debug/jit/dnnl.XXXXXX'.
  3616. ///
  3617. /// @sa @ref dev_guide_profilers
  3618. ///
  3619. /// @note
  3620. /// This setting overrides JITDUMPDIR environment variable. If
  3621. /// JITDUMPDIR is not set, and this function is never called, the path
  3622. /// defaults to HOME. Passing NULL reverts the value to default.
  3623. ///
  3624. /// @note
  3625. /// The directory is accessed only when the first JIT kernel is being
  3626. /// created. JIT profiling will be disabled in case of any errors
  3627. /// accessing or creating this directory.
  3628. ///
  3629. /// @param dir JIT dump output path.
  3630. /// @returns #dnnl_success/#dnnl::status::success if the
  3631. /// output directory was set correctly and an error status otherwise.
  3632. /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented on Windows.
  3633. dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir);
  3634. /// Sets the maximal ISA the library can dispatch to on the CPU. See
  3635. /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values accepted by
  3636. /// the C and C++ API functions respectively.
  3637. ///
  3638. /// This function has effect only once, and returns an error on subsequent
  3639. /// calls. It should also be invoked before any other oneDNN API call, otherwise
  3640. /// it may return an error.
  3641. ///
  3642. /// This function overrides the DNNL_MAX_CPU_ISA environment variable. The
  3643. /// environment variable can be set to the desired maximal ISA name in upper
  3644. /// case and with dnnl_cpu_isa prefix removed. For example:
  3645. /// `DNNL_MAX_CPU_ISA=AVX2`.
  3646. ///
  3647. /// @note
  3648. /// The ISAs are only partially ordered:
  3649. /// - SSE41 < AVX < AVX2 < AVX2_VNNI < AVX2_VNNI_2,
  3650. /// - AVX2 < AVX512_CORE < AVX512_CORE_VNNI < AVX512_CORE_BF16
  3651. /// < AVX10_1_512 < AVX10_2_512,
  3652. /// - AVX10_1_512 < AVX10_1_512_AMX < AVX10_1_512_AMX_FP16
  3653. /// < AVX10_2_512_AMX_2,
  3654. /// - AVX2_VNNI < AVX10_1_512,
  3655. /// - AVX10_2_512 < AVX10_2_512_AMX_2
  3656. ///
  3657. /// Aliases:
  3658. /// - AVX512_CORE_FP16 = AVX10_1_512
  3659. /// - AVX512_CORE_AMX = AVX10_1_512_AMX
  3660. /// - AVX512_CORE_AMX_FP16 = AVX10_1_512_AMX_FP16
  3661. ///
  3662. /// @sa @ref dev_guide_cpu_dispatcher_control for more details
  3663. ///
  3664. /// @param isa Maximal ISA the library should dispatch to. Pass
  3665. /// #dnnl_cpu_isa_default/#dnnl::cpu_isa::isa_default to remove ISA restrictions
  3666. /// (except for ISAs with initial support in the library).
  3667. /// @returns #dnnl_success/#dnnl::status::success on success and a
  3668. /// #dnnl_invalid_arguments/#dnnl::status::invalid_arguments if the @p isa
  3669. /// parameter is invalid or the ISA cannot be changed at this time.
  3670. /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
  3671. /// was disabled at build time (see @ref dev_guide_build_options for more
  3672. /// details).
  3673. dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa);
  3674. /// Gets the maximal ISA the library can dispatch to on the CPU. See
  3675. /// #dnnl_cpu_isa_t and #dnnl::cpu_isa for the list of the values returned by
  3676. /// the C and C++ API functions respectively.
  3677. ///
  3678. /// @sa @ref dev_guide_cpu_dispatcher_control for more details
  3679. ///
  3680. /// @returns #dnnl_cpu_isa_t value reflecting the maximal ISA the library may
  3681. /// dispatch to.
  3682. dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void);
  3683. /// Sets the hints flag for the CPU ISA. See #dnnl_cpu_isa_hints_t and
  3684. /// #dnnl::cpu_isa_hints for the list of the values accepted by the C and C++
  3685. /// API functions respectively.
  3686. ///
  3687. /// This function has effect only once, and returns an error on subsequent
  3688. /// calls. It should also be invoked before any other oneDNN API call, otherwise
  3689. /// it may return an error.
  3690. ///
  3691. /// This function overrides the DNNL_CPU_ISA_HINTS environment variable.
  3692. /// @sa @ref dev_guide_cpu_isa_hints for more details
  3693. ///
  3694. /// @param isa_hints CPU ISA hints to be passed over to the implementation.
  3695. /// Pass #dnnl_cpu_isa_no_hints/#dnnl::cpu_isa_hints::no_hints to use
  3696. /// default features i.e. no hints.
  3697. /// @returns #dnnl_success/#dnnl::status::success on success and a
  3698. /// #dnnl_runtime_error/#dnnl::status::runtime_error if the ISA hints cannot
  3699. /// be specified at the current time.
  3700. /// @returns #dnnl_unimplemented/#dnnl::status::unimplemented if the feature
  3701. /// was disabled at build time (see @ref dev_guide_build_options for more
  3702. /// details).
  3703. dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints);
  3704. /// Gets the ISA specific hints that library can follow. See
  3705. /// #dnnl_cpu_isa_hints_t and #dnnl::cpu_isa_hints for the list of the values
  3706. /// returned by the C and C++ API functions respectively.
  3707. ///
  3708. /// @sa @ref dev_guide_cpu_isa_hints for more details
  3709. ///
  3710. /// @returns #dnnl_cpu_isa_hints_t value reflecting the ISA specific hints the
  3711. /// library can follow.
  3712. dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void);
  3713. /// @} dnnl_api_service
  3714. #ifdef DNNL_EXPERIMENTAL_PROFILING
  3715. /// @addtogroup dnnl_api_profiling Profiling
  3716. /// @{
  3717. /// Resets a profiler's state.
  3718. ///
  3719. /// @param stream Stream associated with the profiler.
  3720. ///
  3721. /// @returns #dnnl_success on success and a status describing the error
  3722. /// otherwise.
  3723. dnnl_status_t DNNL_API dnnl_reset_profiling(dnnl_stream_t stream);
  3724. /// Queries profiling data. The profiling data accumulates for each primitive
  3725. /// execution. The @p num_entries will be equal to the number of executions
  3726. /// since the last `dnnl_reset_profiling` call. In order to query the
  3727. /// @p num_entries the @p data parameter should be NULL. When @p data is NULL
  3728. /// then the @p data_kind parameter is ignored.
  3729. ///
  3730. /// The profiling data can be reset by calling #dnnl_reset_profiling.
  3731. ///
  3732. /// @note
  3733. /// It is required to wait for all submitted primitives to complete
  3734. /// using #dnnl_stream_wait prior to querying profiling data.
  3735. ///
  3736. /// @param stream Stream that was used for executing a primitive that
  3737. /// is being profiled.
  3738. /// @param data_kind Profiling data kind to query.
  3739. /// @param num_entries Number of profiling data entries.
  3740. /// @param data Profiling data.
  3741. ///
  3742. /// @returns #dnnl_success on success and a status describing the error
  3743. /// otherwise.
  3744. dnnl_status_t DNNL_API dnnl_query_profiling_data(dnnl_stream_t stream,
  3745. dnnl_profiling_data_kind_t data_kind, int *num_entries, uint64_t *data);
  3746. /// @} dnnl_api_profiling
  3747. #endif
  3748. /// @addtogroup dnnl_api_blas
  3749. /// @{
  3750. /// Performs single-precision matrix-matrix multiply.
  3751. ///
  3752. /// The operation is defined as:
  3753. ///
  3754. /// `C := alpha * op( A ) * op( B ) + beta * C`
  3755. ///
  3756. /// where
  3757. /// - `op( X ) = X` or `op( X ) = X**T`,
  3758. /// - `alpha` and `beta` are scalars, and
  3759. /// - `A`, `B`, and `C` are matrices:
  3760. /// - `op( A )` is an `MxK` matrix,
  3761. /// - `op( B )` is an `KxN` matrix,
  3762. /// - `C` is an `MxN` matrix.
  3763. ///
  3764. /// The matrices are assumed to be stored in row-major order (the elements in
  3765. /// each of the matrix rows are contiguous in memory).
  3766. ///
  3767. /// @note
  3768. /// This API does not support XERBLA. Instead, unlike the standard BLAS
  3769. /// functions, this one returns a dnnl_status_t value to allow error
  3770. /// handling.
  3771. ///
  3772. /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
  3773. /// transposed, and 'T' or 't' means that A is transposed.
  3774. /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
  3775. /// transposed, and 'T' or 't' means that B is transposed.
  3776. /// @param M The M dimension.
  3777. /// @param N The N dimension.
  3778. /// @param K The K dimension.
  3779. /// @param alpha The alpha parameter that is used to scale the product of
  3780. /// matrices A and B.
  3781. /// @param A A pointer to the A matrix data.
  3782. /// @param lda The leading dimension for the matrix A.
  3783. /// @param B A pointer to the B matrix data.
  3784. /// @param ldb The leading dimension for the matrix B.
  3785. /// @param beta The beta parameter that is used to scale the matrix C.
  3786. /// @param C A pointer to the C matrix data.
  3787. /// @param ldc The leading dimension for the matrix C.
  3788. /// @returns #dnnl_success/#dnnl::status::success on success and a status
  3789. /// describing the error otherwise.
  3790. dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M,
  3791. dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
  3792. const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc);
  3793. /// Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit
  3794. /// signed matrix B, and 32-bit signed resulting matrix C.
  3795. ///
  3796. /// The operation is defined as:
  3797. ///
  3798. /// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
  3799. ///
  3800. /// where
  3801. /// - `op( X ) = X` or `op( X ) = X**T`,
  3802. /// - `alpha` and `beta` are scalars, and
  3803. /// - `A`, `B`, and `C` are matrices:
  3804. /// - `op( A )` is an `MxK` matrix,
  3805. /// - `op( B )` is an `KxN` matrix,
  3806. /// - `C` is an `MxN` matrix.
  3807. /// - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
  3808. /// - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
  3809. /// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
  3810. /// - if `offsetc = F`: the `len` must be at least `1`,
  3811. /// - if `offsetc = C`: the `len` must be at least `max(1, m)`,
  3812. /// - if `offsetc = R`: the `len` must be at least `max(1, n)`,
  3813. ///
  3814. /// The matrices are assumed to be stored in row-major order (the elements in
  3815. /// each of the matrix rows are contiguous in memory).
  3816. ///
  3817. /// @note
  3818. /// This API does not support XERBLA. Instead, unlike the standard BLAS
  3819. /// functions, this one returns a dnnl_status_t value to allow error
  3820. /// handling.
  3821. ///
  3822. /// @warning
  3823. /// On some architectures saturation may happen during intermediate
  3824. /// computations, which would lead to unexpected results. For more
  3825. /// details, refer to @ref dev_guide_int8_computations.
  3826. ///
  3827. /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
  3828. /// transposed, and 'T' or 't' means that A is transposed.
  3829. /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
  3830. /// transposed, and 'T' or 't' means that B is transposed.
  3831. /// @param offsetc Flag specifying how offsets should be applied to matrix C:
  3832. /// - 'F' means that the same offset will be applied to each element of
  3833. /// the matrix C,
  3834. /// - 'C' means that individual offset will be applied to each element
  3835. /// within each column,
  3836. /// - 'R' means that individual offset will be applied to each element
  3837. /// within each row.
  3838. /// @param M The M dimension.
  3839. /// @param N The N dimension.
  3840. /// @param K The K dimension.
  3841. /// @param alpha The alpha parameter that is used to scale the product of
  3842. /// matrices A and B.
  3843. /// @param A A pointer to the A matrix data.
  3844. /// @param lda The leading dimension for the matrix A.
  3845. /// @param ao The offset value for the matrix A.
  3846. /// @param B A pointer to the B matrix data.
  3847. /// @param ldb The leading dimension for the matrix B.
  3848. /// @param bo The offset value for the matrix B.
  3849. /// @param beta The beta parameter that is used to scale the matrix C.
  3850. /// @param C A pointer to the C matrix data.
  3851. /// @param ldc The leading dimension for the matrix C.
  3852. /// @param co An array of offset values for the matrix C. The number of
  3853. /// elements in the array depends on the value of @p offsetc.
  3854. /// @returns #dnnl_success/#dnnl::status::success on success and a status
  3855. /// describing the error otherwise.
  3856. dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc,
  3857. dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
  3858. dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  3859. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
  3860. /// Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit
  3861. /// signed matrix B, and 32-bit signed resulting matrix C.
  3862. ///
  3863. /// The operation is defined as:
  3864. ///
  3865. /// `C := alpha * (op(A) - A_offset) * (op(B) - B_offset) + beta * C + C_offset`
  3866. ///
  3867. /// where
  3868. /// - `op( X ) = X` or `op( X ) = X**T`,
  3869. /// - `alpha` and `beta` are scalars, and
  3870. /// - `A`, `B`, and `C` are matrices:
  3871. /// - `op( A )` is an `MxK` matrix,
  3872. /// - `op( B )` is an `KxN` matrix,
  3873. /// - `C` is an `MxN` matrix.
  3874. /// - `A_offset` is an `MxK` matrix with every element equal the `ao` value,
  3875. /// - `B_offset` is an `KxN` matrix with every element equal the `bo` value,
  3876. /// - `C_offset` is an `MxN` matrix which is defined by the `co` array of size `len`:
  3877. /// - if `offsetc = F`: the `len` must be at least `1`,
  3878. /// - if `offsetc = C`: the `len` must be at least `max(1, m)`,
  3879. /// - if `offsetc = R`: the `len` must be at least `max(1, n)`,
  3880. ///
  3881. /// The matrices are assumed to be stored in row-major order (the elements in
  3882. /// each of the matrix rows are contiguous in memory).
  3883. ///
  3884. /// @note
  3885. /// This API does not support XERBLA. Instead, unlike the standard BLAS
  3886. /// functions, this one returns a dnnl_status_t value to allow error
  3887. /// handling.
  3888. ///
  3889. /// @warning
  3890. /// On some architectures saturation may happen during intermediate
  3891. /// computations, which would lead to unexpected results. For more
  3892. /// details, refer to @ref dev_guide_int8_computations.
  3893. ///
  3894. /// @param transa Transposition flag for matrix A: 'N' or 'n' means A is not
  3895. /// transposed, and 'T' or 't' means that A is transposed.
  3896. /// @param transb Transposition flag for matrix B: 'N' or 'n' means B is not
  3897. /// transposed, and 'T' or 't' means that B is transposed.
  3898. /// @param offsetc Flag specifying how offsets should be applied to matrix C:
  3899. /// - 'F' means that the same offset will be applied to each element of
  3900. /// the matrix C,
  3901. /// - 'C' means that individual offset will be applied to each element
  3902. /// within each column,
  3903. /// - 'R' means that individual offset will be applied to each element
  3904. /// within each row.
  3905. /// @param M The M dimension.
  3906. /// @param N The N dimension.
  3907. /// @param K The K dimension.
  3908. /// @param alpha The alpha parameter that is used to scale the product of
  3909. /// matrices A and B.
  3910. /// @param A A pointer to the A matrix data.
  3911. /// @param lda The leading dimension for the matrix A.
  3912. /// @param ao The offset value for the matrix A.
  3913. /// @param B A pointer to the B matrix data.
  3914. /// @param ldb The leading dimension for the matrix B.
  3915. /// @param bo The offset value for the matrix B.
  3916. /// @param beta The beta parameter that is used to scale the matrix C.
  3917. /// @param C A pointer to the C matrix data.
  3918. /// @param ldc The leading dimension for the matrix C.
  3919. /// @param co An array of offset values for the matrix C. The number of
  3920. /// elements in the array depends on the value of @p offsetc.
  3921. /// @returns #dnnl_success/#dnnl::status::success on success and a status
  3922. /// describing the error otherwise.
  3923. dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc,
  3924. dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
  3925. dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
  3926. float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co);
// Close the Doxygen groups opened earlier in this header.
/// @} dnnl_api_blas
/// @} dnnl_api
#ifdef __cplusplus
} // NOTE(review): presumably closes an `extern "C" {` opened earlier in this header — confirm
#endif
#endif /* ONEAPI_DNNL_DNNL_H */
#else
// Reached when TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined (the
// negation of the condition named on the closing #endif below); including
// this header in such a build is a hard compile error.
#error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
#endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)