xnnpack.h 197 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. // Copyright (c) Facebook, Inc. and its affiliates.
  3. // All rights reserved.
  4. //
  5. // Copyright 2019 Google LLC
  6. //
  7. // This source code is licensed under the BSD-style license found in the
  8. // LICENSE file in the root directory of this source tree.
  9. #pragma once
  10. #include <stdbool.h>
  11. #include <stddef.h>
  12. #include <stdint.h>
  13. #include "pthreadpool.h"
  14. #ifdef __cplusplus
  15. extern "C" {
  16. #endif
  /// The number of bytes XNNPACK may read beyond array bounds.
  /// The caller must allocate at least this many extra bytes after the tensor data passed to XNNPACK.
  ///
  /// Note: XNNPACK reads, but never writes beyond array bounds.
  // Check defined() first: consumers of this public header normally never define
  // XNN_ARCH_HEXAGON, and a bare `#if XNN_ARCH_HEXAGON` then trips -Wundef.
  // An undefined macro still selects the 16-byte branch, exactly as before.
  #if defined(XNN_ARCH_HEXAGON) && XNN_ARCH_HEXAGON
  #define XNN_EXTRA_BYTES 128
  #else
  #define XNN_EXTRA_BYTES 16
  #endif  // XNN_ARCH_HEXAGON
  /// Maximum number of dimensions in tensor shape.
  #define XNN_MAX_TENSOR_DIMS 6
  /// A value ID that cannot be valid.
  #define XNN_INVALID_VALUE_ID UINT32_MAX
  /// Allow sparse inference in a Runtime.
  ///
  /// Note: this flag is a hint to XNNPACK that it should consider sparse inference, but does not guarantee it.
  #define XNN_FLAG_HINT_SPARSE_INFERENCE 0x00000001
  /// Allow IEEE FP16 inference in a Runtime.
  ///
  /// Note: this flag hints XNNPACK to consider IEEE FP16 inference, but does not guarantee it.
  #define XNN_FLAG_HINT_FP16_INFERENCE 0x00000002
  /// Force IEEE FP16 inference in a Runtime, and fail if FP16 inference is not possible.
  ///
  /// Note: this flag guarantees that XNNPACK will use IEEE FP16 inference, or fail to create the Runtime object.
  /// Warning: on x86 systems FP16 computations will be emulated at a substantial performance cost.
  #define XNN_FLAG_FORCE_FP16_INFERENCE 0x00000004
  /// Enable timing of each operator's runtime.
  #define XNN_FLAG_BASIC_PROFILING 0x00000008
  /// Enable the just-in-time compiler.
  #define XNN_FLAG_JIT 0x00000010
  // Note: the bit values below overlap with the ones above and with each other
  // (e.g. XNN_FLAG_HINT_SPARSE_INFERENCE, XNN_FLAG_DEPTHWISE_CONVOLUTION and
  // XNN_FLAG_TRANSPOSE_WEIGHTS are all 0x00000001), so a flag's meaning depends on
  // the API it is passed to, as named in its comment.
  /// The convolution operator represents a depthwise convolution, and uses HWGo layout for filters.
  #define XNN_FLAG_DEPTHWISE_CONVOLUTION 0x00000001
  /// Assume transposed weights in a fully connected operator.
  #define XNN_FLAG_TRANSPOSE_WEIGHTS 0x00000001
  /// The operator assumes NHWC layout for the input, regardless of the output layout.
  #define XNN_FLAG_INPUT_NHWC 0x00000002
  /// Match "SAME" padding in TensorFlow. Exact padding values are computed dynamically depending on input size.
  #define XNN_FLAG_TENSORFLOW_SAME_PADDING 0x00000004
  /// Assume transposed weights in a batch matrix multiply operator.
  #define XNN_FLAG_TRANSPOSE_B XNN_FLAG_TRANSPOSE_WEIGHTS
  /// Assume transposed input in a batch matrix multiply operator.
  #define XNN_FLAG_TRANSPOSE_A 0x00000002
  /// Implicitly flatten and reshape input of a Fully Connected operator into a 2D tensor.
  #define XNN_FLAG_TENSORFLOW_RESHAPE_2D 0x00000004
  /// Match behaviour of TensorFlow 1.x.
  #define XNN_FLAG_TENSORFLOW_LEGACY_MODE 0x00000004
  /// Static weights of the FP16 operator are in FP32 format.
  #define XNN_FLAG_FP32_STATIC_WEIGHTS 0x00000008
  /// Static biases of the FP16 operator are in FP32 format.
  #define XNN_FLAG_FP32_STATIC_BIASES 0x00000080
  /// Align corners of input and output images in resize operations.
  #define XNN_FLAG_ALIGN_CORNERS 0x00000008
  /// Yield worker threads of the thread pool to the system scheduler after the inference.
  #define XNN_FLAG_YIELD_WORKERS 0x00000010
  /// Use transient indirection buffer to reduce memory footprint.
  #define XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER 0x00000020
  /// Retain reduced dimensions with length 1.
  #define XNN_FLAG_KEEP_DIMS 0x00000040
  // Next unused flag value: 0x00000100.
  /// How many xnn_quantization_params entries XNNPACK may read past the end of an array of them.
  /// Such arrays must be over-allocated by at least this many entries before being handed to XNNPACK.
  ///
  /// Note: the over-read is read-only; XNNPACK never writes past the end of the array.
  #define XNN_EXTRA_QUANTIZATION_PARAMS 15
  /// Minimum block size for blockwise-quantized operators.
  #define XNN_MIN_BLOCKSIZE 32
  // Deprecation marker: expands to the GNU `deprecated` attribute on GCC-compatible
  // compilers and to nothing everywhere else.
  #if !defined(__GNUC__)
  #define XNN_DEPRECATED
  #else
  #define XNN_DEPRECATED __attribute__((deprecated))
  #endif
  /// A pair of quantization parameters: an integer zero point and a floating-point scale.
  struct xnn_quantization_params {
    int32_t zero_point;  // integer zero point
    float scale;         // scale factor
  };
  /// Result code returned by XNNPACK function calls.
  enum xnn_status {
    xnn_status_success = 0,                ///< The call succeeded; all output arguments now hold valid data.
    xnn_status_uninitialized = 1,
    xnn_status_invalid_parameter = 2,
    xnn_status_invalid_state = 3,
    xnn_status_unsupported_parameter = 4,
    xnn_status_unsupported_hardware = 5,   ///< E.g. reported by xnn_initialize when the host processor is below XNNPACK's minimum requirements.
    xnn_status_out_of_memory = 6,          ///< The operation hit an out-of-memory condition.
    xnn_status_reallocation_required = 7,
    xnn_status_deprecated = 8,
  };
  /// Table of user-supplied memory-management callbacks.
  ///
  /// Each callback receives the `context` field of this structure, unchanged, as its first argument.
  struct xnn_allocator {
    /// Opaque user pointer, forwarded verbatim to every callback below.
    void* context;

    /// General-purpose allocation.
    ///
    /// @param context - the `context` field of this xnn_allocator.
    /// @param size - requested block size, in bytes.
    ///
    /// @returns A block of at least @ref size bytes, or NULL if the allocation fails.
    void* (*allocate)(void* context, size_t size);

    /// Resize (grow or shrink) a block previously obtained from @ref allocate or @ref reallocate,
    /// copying the old block's contents into the new one.
    ///
    /// @param context - the `context` field of this xnn_allocator.
    /// @param pointer - block to resize, or NULL; with NULL the call behaves like @ref allocate.
    /// @param size - requested new block size, in bytes.
    ///
    /// @returns A block of at least @ref size bytes holding the previous block's contents.
    ///          On failure, returns NULL and must NOT release the previous block.
    void* (*reallocate)(void* context, void* pointer, size_t size);

    /// Release a block previously obtained from @ref allocate or @ref reallocate.
    ///
    /// @param context - the `context` field of this xnn_allocator.
    /// @param pointer - block to release, or NULL, in which case the call does nothing.
    void (*deallocate)(void* context, void* pointer);

    /// Aligned allocation.
    ///
    /// @param context - the `context` field of this xnn_allocator.
    /// @param alignment - required alignment, in bytes; always a power of two.
    /// @param size - requested block size, in bytes.
    ///
    /// @returns A block of at least @ref size bytes, or NULL if the allocation fails.
    void* (*aligned_allocate)(void* context, size_t alignment, size_t size);

    /// Release a block previously obtained from @ref aligned_allocate.
    ///
    /// @param context - the `context` field of this xnn_allocator.
    /// @param pointer - block to release, or NULL, in which case the call does nothing.
    void (*aligned_deallocate)(void* context, void* pointer);
  };
  /// Initialize XNNPACK library.
  ///
  /// XNNPACK must be successfully initialized before use. During initialization, XNNPACK populates internal
  /// structures depending on the host processor. Initialization can be time-consuming.
  ///
  /// @param[in] allocator - structure with function pointers to be used for memory allocation and
  ///                        de-allocation. If this argument is NULL, system-provided memory management
  ///                        functions (e.g. malloc/free) will be used.
  ///
  /// @retval xnn_status_success - XNNPACK is successfully initialized and ready to use.
  /// @retval xnn_status_out_of_memory - initialization failed due to an out-of-memory condition.
  /// @retval xnn_status_unsupported_hardware - initialization failed because the host processor does not
  ///                                           satisfy the minimum hardware requirements for XNNPACK.
  ///                                           E.g. this may happen on x86 processors without the SSE2
  ///                                           extension, or on 32-bit ARM processors without the NEON
  ///                                           SIMD extension.
  enum xnn_status xnn_initialize(const struct xnn_allocator* allocator);
  /// Tear down the XNNPACK library.
  ///
  /// Each successful xnn_initialize call must be balanced by one xnn_deinitialize call;
  /// skipping it leaks memory and other resources.
  ///
  /// @retval xnn_status_success - the deinitialization call succeeded.
  enum xnn_status xnn_deinitialize(void);
  // These take `(void)` rather than `()`: before C23, an empty parameter list declares
  // a function with unspecified parameters and disables argument checking at call sites.
  // `(void)` is compatible with existing zero-argument callers and definitions, and is
  // equivalent to `()` in C++.

  /// Get the microkernel implementation build identifier's data.
  ///
  /// That identifier will be unique for the current set of microkernel implementations.
  ///
  /// @returns A pointer to the current identifier's data.
  const void* xnn_experimental_get_build_identifier_data(void);
  /// Get the microkernel implementation build identifier's data size.
  ///
  /// @returns The size in bytes of the identifier's data.
  size_t xnn_experimental_get_build_identifier_size(void);
  /// Check whether the given data matches this version's build identifier.
  ///
  /// @param data - identifier data to compare against this build's identifier.
  /// @param size - size of @a data, in bytes.
  ///
  /// @returns true if @a data matches this build's identifier, false otherwise.
  bool xnn_experimental_check_build_identifier(const void* data, size_t size);
  /// Opaque handle to a Subgraph: an abstract description of a neural network model.
  /// A Subgraph holds the Values (tensors) and Nodes (operators) that make up the model.
  struct xnn_subgraph;
  typedef struct xnn_subgraph* xnn_subgraph_t;
  189. /// Create a empty Subgraph object.
  190. ///
  191. /// @param external_value_ids - number of Value IDs to reserve for communication with external graph representation.
  192. /// The Subgraph object would avoid creating internal Value IDs in the
  193. /// [0, reserved_value_ids-1] range.
  194. /// @param flags - binary features of the subgraph. No supported flags are currently defined.
  195. /// @param subgraph_out - pointer to the variable that will be initialized with a handle to the Subgraph object upon
  196. /// successful return.
  197. enum xnn_status xnn_create_subgraph(
  198. uint32_t external_value_ids,
  199. uint32_t flags,
  200. xnn_subgraph_t* subgraph_out);
  201. /// Destroy a Subgraph object, as well as Values, and Nodes associated with the subgraph.
  202. ///
  203. /// @param subgraph - the Subgraph object to destroy.
  204. enum xnn_status xnn_delete_subgraph(
  205. xnn_subgraph_t subgraph);
  // Bit flags for the `flags` argument of Value-definition functions such as
  // xnn_define_tensor_value. (XNN_INVALID_VALUE_ID was redefined here as well;
  // the duplicate is dropped because it is already defined once near the top of
  // this header.)

  /// Marks the Value as an external input of the Subgraph.
  #define XNN_VALUE_FLAG_EXTERNAL_INPUT 0x00000001
  /// Marks the Value as an external output of the Subgraph.
  #define XNN_VALUE_FLAG_EXTERNAL_OUTPUT 0x00000002
  /// Marks the Value as persistent.
  /// NOTE(review): the exact semantics of this flag are not described in this header section; confirm before use.
  #define XNN_VALUE_FLAG_PERSISTENT 0x00000004
  /// Element type of the data held by a Value object.
  enum xnn_datatype {
    xnn_datatype_invalid = 0,   ///< Not a real type: valid Values never use it.
    xnn_datatype_fp32 = 1,      ///< IEEE 754 single-precision floating point.
    xnn_datatype_fp16 = 2,      ///< IEEE 754 half-precision floating point.
    xnn_datatype_qint8 = 3,     ///< Quantized signed 8-bit integer; one set of quantization parameters per Value.
    xnn_datatype_quint8 = 4,    ///< Quantized unsigned 8-bit integer; one set of quantization parameters per Value.
    xnn_datatype_qint32 = 5,    ///< Quantized signed 32-bit integer; one set of quantization parameters per Value.
    xnn_datatype_qcint8 = 6,    ///< Quantized signed 8-bit integer; quantization parameters shared per channel.
    xnn_datatype_qcint32 = 7,   ///< Quantized signed 32-bit integer; quantization parameters shared per channel.
    xnn_datatype_qcint4 = 8,    ///< Quantized signed 4-bit integer; quantization parameters shared per channel.
    xnn_datatype_qdint8 = 9,    ///< Dynamically quantized signed 8-bit integer; quantization parameters per batch.
    xnn_datatype_qpint8 = 10,   ///< Dynamically quantized signed 8-bit integers packed with their per-row quantization parameters.
    xnn_datatype_int32 = 11,    ///< Signed 32-bit integer.
    xnn_datatype_qbint4 = 12,   ///< Quantized signed 4-bit integer; quantization parameters shared per channel-block.
    xnn_datatype_pfp32 = 13,    ///< IEEE 754 single-precision packed floating point.
    xnn_datatype_bf16 = 14,     ///< BFloat16, i.e. the upper 16 bits of a float32.
    xnn_datatype_qduint8 = 15,  ///< Dynamically quantized unsigned 8-bit integer; quantization parameters per batch.
  };
  255. /// Define a tensor-type Value and add it to a Subgraph.
  256. ///
  257. /// @param subgraph - a Subgraph object that will own the created Value.
  258. /// @param datatype - type of the tensor elements.
  259. /// @param num_dims - number of dimensions in the shape.
  260. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  261. /// XNNPACK does not keep any pointers to this array after the function returns.
  262. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  263. /// this pointer must be NULL. If non-NULL, the lifetime of the static data must exceed the lifetime
  264. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  265. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified at
  266. /// Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  267. /// created for the Value.
  268. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  269. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  270. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  271. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  272. enum xnn_status xnn_define_tensor_value(
  273. xnn_subgraph_t subgraph,
  274. enum xnn_datatype datatype,
  275. size_t num_dims,
  276. const size_t* dims,
  277. const void* data,
  278. uint32_t external_id,
  279. uint32_t flags,
  280. uint32_t* id_out);
  281. /// Define a quantized tensor-type Value and add it to a Subgraph.
  282. ///
  283. /// @param subgraph - a Subgraph object that will own the created Value.
  284. /// @param datatype - type of the tensor elements.
  285. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  286. /// @param scale - multiplication factor to convert quantized elements to real representation.
  287. /// @param num_dims - number of dimensions in the shape.
  288. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  289. /// XNNPACK does not keep any pointers to this array after the function returns.
  290. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  291. /// this pointer must be NULL. If non-NULL, the lifetime of the static data must exceed the lifetime
  292. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  293. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified at
  294. /// Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  295. /// created for the Value.
  296. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  297. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  298. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  299. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  300. enum xnn_status xnn_define_quantized_tensor_value(
  301. xnn_subgraph_t subgraph,
  302. enum xnn_datatype datatype,
  303. int32_t zero_point,
  304. float scale,
  305. size_t num_dims,
  306. const size_t* dims,
  307. const void* data,
  308. uint32_t external_id,
  309. uint32_t flags,
  310. uint32_t* id_out);
  311. enum xnn_status xnn_define_channelwise_quantized_tensor_value( // NOTE(review): undocumented; parameters match xnn_define_channelwise_quantized_tensor_value_v2 minus @a zero_point -- confirm semantics.
  312. xnn_subgraph_t subgraph,
  313. enum xnn_datatype datatype,
  314. const float* scale,
  315. size_t num_dims,
  316. size_t channel_dim,
  317. const size_t* dims,
  318. const void* data,
  319. uint32_t external_id,
  320. uint32_t flags,
  321. uint32_t* id_out);
  322. /// Validate the dimensions, zero point, datatype, and scale of a quantized tensor-type.
  323. ///
  324. /// @param datatype - type of the tensor elements.
  325. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  326. /// @param scale - multiplication factor to convert quantized elements to real representation.
  327. /// @param num_dims - number of dimensions in the shape.
  328. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  329. /// XNNPACK does not keep any pointers to this array after the function returns.
  330. enum xnn_status xnn_validate_quantized_tensor(
  331. enum xnn_datatype datatype,
  332. int32_t zero_point,
  333. float scale,
  334. size_t num_dims,
  335. const size_t* dims);
  336. /// Validate the dimensions, channel_dim, zero point, datatype, and per-channel scales of a channelwise quantized tensor-type.
  337. ///
  338. /// @param datatype - type of the tensor elements.
  339. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  340. /// @param scale - per-channel multiplication factors to convert quantized elements to real representation.
  341. /// @param num_dims - number of dimensions in the shape.
  342. /// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters.
  343. /// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution,
  344. /// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in
  345. /// the Depthwise Convolution operators.
  346. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  347. /// XNNPACK does not keep any pointers to this array after the function returns.
  348. enum xnn_status xnn_validate_channelwise_quantized_tensor(
  349. enum xnn_datatype datatype,
  350. int32_t zero_point,
  351. const float* scale,
  352. size_t num_dims,
  353. size_t channel_dim,
  354. const size_t* dims);
  355. /// Define a channelwise quantized tensor-type Value and add it to a Subgraph.
  356. ///
  357. /// @param subgraph - a Subgraph object that will own the created Value.
  358. /// @param datatype - type of the tensor elements.
  359. /// @param zero_point - offset from zero to subtract from the quantized elements in the Value.
  360. /// @param scale - per-channel multiplication factors to convert quantized elements to real representation.
  361. /// @param num_dims - number of dimensions in the shape.
  362. /// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters.
  363. /// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution,
  364. /// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in
  365. /// the Depthwise Convolution operators.
  366. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  367. /// XNNPACK does not keep any pointers to this array after the function returns.
  368. /// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized,
  369. /// this pointer must be NULL. If non-NULL, the lifetime of the static data must exceed the lifetime
  370. /// of the Subgraph object, and of any Runtime objects created from the Subgraph.
  371. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified at
  372. /// Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  373. /// created for the Value.
  374. /// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT
  375. /// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT.
  376. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  377. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  378. enum xnn_status xnn_define_channelwise_quantized_tensor_value_v2(
  379. xnn_subgraph_t subgraph,
  380. enum xnn_datatype datatype,
  381. int32_t zero_point,
  382. const float* scale,
  383. size_t num_dims,
  384. size_t channel_dim,
  385. const size_t* dims,
  386. const void* data,
  387. uint32_t external_id,
  388. uint32_t flags,
  389. uint32_t* id_out);
  390. /// Define a blockwise quantized tensor-type Value and add it to a Subgraph.
  391. /// @param block_size - size of a block in the tensor with blockwise quantization parameters. A block is defined as a
  392. /// number of input-channel elements per output channel.
  393. /// For Fully Connected operators with 2D filters of size [output_channels, input_channels],
  394. /// the number of scale values is expected to be output_channels * (input_channels / block_size).
  395. enum xnn_status xnn_define_blockwise_quantized_tensor_value(
  396. xnn_subgraph_t subgraph,
  397. enum xnn_datatype datatype,
  398. int32_t zero_point,
  399. const uint16_t* scale, // NOTE(review): scales are passed as 16-bit values; their encoding is not documented here -- confirm.
  400. size_t num_dims,
  401. size_t channel_dim,
  402. size_t block_size,
  403. const size_t* dims,
  404. const void* data,
  405. uint32_t external_id,
  406. uint32_t flags,
  407. uint32_t* id_out);
  408. /// Define a dynamically quantized tensor-type Value and add it to a Subgraph.
  409. ///
  410. /// @param subgraph - a Subgraph object that will own the created Value.
  411. /// @param datatype - type of the tensor elements.
  412. /// @param num_dims - number of dimensions in the shape.
  413. /// @param num_nonbatch_dims - number of non-batch dimensions in the shape. The leading (num_dims - num_nonbatch_dims)
  414. /// dimensions will be flattened and treated as batch size. A set of quantization parameters
  415. /// will be calculated for each batch element.
  416. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
  417. /// XNNPACK does not keep any pointers to this array after the function returns.
  418. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified at
  419. /// Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
  420. /// created for the Value.
  421. /// @param flags - binary features of the Value. No supported flags are currently defined.
  422. /// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a
  423. /// valid @a external_id was provided, the variable will be initialized with the @a external_id value.
  424. enum xnn_status xnn_define_dynamically_quantized_tensor_value(
  425. xnn_subgraph_t subgraph,
  426. enum xnn_datatype datatype,
  427. size_t num_dims,
  428. size_t num_nonbatch_dims,
  429. const size_t* dims,
  430. uint32_t external_id,
  431. uint32_t flags,
  432. uint32_t* id_out);
  433. /// Type of unary operation; passed as @a type to xnn_define_unary.
  434. enum xnn_unary_operator {
  435. xnn_unary_invalid = -1, ///< Not a valid operator; the enumerators that follow are numbered from 0.
  436. xnn_unary_convert,
  437. xnn_unary_clamp,
  438. xnn_unary_abs,
  439. xnn_unary_bankers_rounding,
  440. xnn_unary_ceiling,
  441. xnn_unary_elu,
  442. xnn_unary_exp,
  443. xnn_unary_floor,
  444. xnn_unary_gelu,
  445. xnn_unary_hardswish,
  446. xnn_unary_leaky_relu,
  447. xnn_unary_log,
  448. xnn_unary_negate,
  449. xnn_unary_sigmoid,
  450. xnn_unary_square,
  451. xnn_unary_square_root,
  452. xnn_unary_reciprocal_square_root,
  453. xnn_unary_tanh,
  454. // The following operators are experimental and may be removed.
  455. xnn_unary_cube_root,
  456. xnn_unary_cosine,
  457. xnn_unary_sine,
  458. xnn_unary_count_leading_zeros,
  459. xnn_unary_bitwise_not,
  460. xnn_unary_popcount,
  461. xnn_unary_sign,
  462. };
  463. /// Operator-specific parameters for xnn_define_unary, interpreted according to the operator type.
  464. union xnn_unary_params {
  465. struct {
  466. /// Lower bound for clipping output values.
  467. float min;
  468. /// Upper bound for clipping output values.
  469. float max;
  470. } clamp;
  471. struct {
  472. /// Scale factor for negative input elements.
  473. float alpha;
  474. } elu;
  475. struct {
  476. /// Scale factor for negative input elements.
  477. float negative_slope;
  478. } leaky_relu;
  479. };
  480. /// Define a unary operator Node and add it to a Subgraph.
  481. ///
  482. /// @param subgraph - a Subgraph object that will own the created Node.
  483. /// @param type - type of unary operator to define.
  484. /// @param params - parameters to be interpreted by the specific operator type.
  485. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  486. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  487. /// shape must match the shape of the input tensor.
  488. /// @param flags - binary features of the Node. No supported flags are currently defined.
  489. enum xnn_status xnn_define_unary(
  490. xnn_subgraph_t subgraph,
  491. enum xnn_unary_operator type,
  492. const union xnn_unary_params* params,
  493. uint32_t input_id,
  494. uint32_t output_id,
  495. uint32_t flags);
  496. /// Define a Convert Node and add it to a Subgraph. Deprecated; NOTE(review): presumably superseded by xnn_define_unary with xnn_unary_convert -- confirm.
  497. ///
  498. /// @param subgraph - a Subgraph object that will own the created Node.
  499. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  500. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  501. /// shape must match the shape of the input tensor.
  502. /// @param flags - binary features of the Convert Node. No supported flags are currently defined.
  503. XNN_DEPRECATED enum xnn_status xnn_define_convert(
  504. xnn_subgraph_t subgraph,
  505. uint32_t input_id,
  506. uint32_t output_id,
  507. uint32_t flags);
  508. /// Define a 2D Convolution Node and add it to a Subgraph.
  509. ///
  510. /// @param subgraph - a Subgraph object that will own the created Node.
  511. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  512. /// flag is specified.
  513. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  514. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  515. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  516. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  517. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  518. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  519. /// @param kernel_height - kernel (filter) height.
  520. /// @param kernel_width - kernel (filter) width.
  521. /// @param subsampling_height - height of subsampling region for convolution output (convolution height stride).
  522. /// @param subsampling_width - width of subsampling region for convolution output (convolution width stride).
  523. /// @param dilation_height - dilation of kernel elements along the height dimension.
  524. /// @param dilation_width - dilation of kernel elements along the width dimension.
  525. /// @param groups - number of convolution groups.
  526. /// @param group_input_channels - number of input channels per group.
  527. /// @param group_output_channels - number of output channels per group.
  528. /// @param output_min - lower bound for clipping output values.
  529. /// @param output_max - upper bound for clipping output values.
  530. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  531. /// with [N, IH, IW, groups * group_input_channels] dimensions.
  532. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 4D tensor defined in the @a subgraph
  533. /// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels]
  534. /// dimensions.
  535. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Convolution Node without a bias. If
  536. /// present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  537. /// [groups * group_output_channels] dimensions.
  538. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  539. /// with [N, OH, OW, groups * group_output_channels] dimensions.
  540. /// @param flags - binary features of the 2D Convolution Node. The only currently supported value is
  541. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  542. enum xnn_status xnn_define_convolution_2d(
  543. xnn_subgraph_t subgraph,
  544. uint32_t input_padding_top,
  545. uint32_t input_padding_right,
  546. uint32_t input_padding_bottom,
  547. uint32_t input_padding_left,
  548. uint32_t kernel_height,
  549. uint32_t kernel_width,
  550. uint32_t subsampling_height,
  551. uint32_t subsampling_width,
  552. uint32_t dilation_height,
  553. uint32_t dilation_width,
  554. uint32_t groups,
  555. size_t group_input_channels,
  556. size_t group_output_channels,
  557. float output_min,
  558. float output_max,
  559. uint32_t input_id,
  560. uint32_t filter_id,
  561. uint32_t bias_id,
  562. uint32_t output_id,
  563. uint32_t flags);
  564. /// Define a 2D Deconvolution (Transposed Convolution) Node and add it to a Subgraph.
  565. ///
  566. /// @param subgraph - a Subgraph object that will own the created Node.
  567. /// @param padding_top - implicit padding above 2D output data.
  568. /// @param padding_right - implicit padding to the right of 2D output data.
  569. /// @param padding_bottom - implicit padding below 2D output data.
  570. /// @param padding_left - implicit padding to the left of 2D output data.
  571. /// @param adjustment_height - additional elements in the bottom of the 2D output data.
  572. /// @param adjustment_width - additional elements to the right of the 2D output data.
  573. /// @param kernel_height - kernel (filter) height.
  574. /// @param kernel_width - kernel (filter) width.
  575. /// @param upsampling_height - height of upsampling region for deconvolution input (deconvolution height stride).
  576. /// @param upsampling_width - width of upsampling region for deconvolution input (deconvolution width stride).
  577. /// @param dilation_height - dilation of kernel elements along the height dimension.
  578. /// @param dilation_width - dilation of kernel elements along the width dimension.
  579. /// @param groups - number of convolution groups.
  580. /// @param group_input_channels - number of input channels per group.
  581. /// @param group_output_channels - number of output channels per group.
  582. /// @param output_min - lower bound for clipping output values.
  583. /// @param output_max - upper bound for clipping output values.
  584. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  585. /// with [N, IH, IW, groups * group_input_channels] dimensions.
  586. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 4D tensor defined in the @a subgraph
  587. /// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels]
  588. /// dimensions.
  589. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Deconvolution Node without a bias. If
  590. /// present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  591. /// [groups * group_output_channels] dimensions.
  592. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  593. /// with [N, OH, OW, groups * group_output_channels] dimensions.
  594. /// @param flags - binary features of the 2D Deconvolution Node. No supported flags are currently defined.
  595. enum xnn_status xnn_define_deconvolution_2d(
  596. xnn_subgraph_t subgraph,
  597. uint32_t padding_top,
  598. uint32_t padding_right,
  599. uint32_t padding_bottom,
  600. uint32_t padding_left,
  601. uint32_t adjustment_height,
  602. uint32_t adjustment_width,
  603. uint32_t kernel_height,
  604. uint32_t kernel_width,
  605. uint32_t upsampling_height,
  606. uint32_t upsampling_width,
  607. uint32_t dilation_height,
  608. uint32_t dilation_width,
  609. uint32_t groups,
  610. size_t group_input_channels,
  611. size_t group_output_channels,
  612. float output_min,
  613. float output_max,
  614. uint32_t input_id,
  615. uint32_t filter_id,
  616. uint32_t bias_id,
  617. uint32_t output_id,
  618. uint32_t flags);
  619. /// Define a 2D Depthwise Convolution Node and add it to a Subgraph.
  620. ///
  621. /// @param subgraph - a Subgraph object that will own the created Node.
  622. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  623. /// flag is specified.
  624. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  625. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  626. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  627. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  628. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  629. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  630. /// @param kernel_height - kernel (filter) height.
  631. /// @param kernel_width - kernel (filter) width.
  632. /// @param subsampling_height - height of subsampling region for convolution output (convolution height stride).
  633. /// @param subsampling_width - width of subsampling region for convolution output (convolution width stride).
  634. /// @param dilation_height - dilation of kernel elements along the height dimension.
  635. /// @param dilation_width - dilation of kernel elements along the width dimension.
  636. /// @param depth_multiplier - ratio of output channels to input channels.
  637. /// @param input_channels - number of input channels.
  638. /// @param output_min - lower bound for clipping output values.
  639. /// @param output_max - upper bound for clipping output values.
  640. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  641. /// with [N, IH, IW, input_channels] dimensions.
  642. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 4D tensor defined in the @a subgraph
  643. /// with [1, kernel_height, kernel_width, input_channels * depth_multiplier] dimensions.
  644. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Depthwise Convolution Node without
  645. /// a bias. If present, the bias tensor must be a 1D tensor defined in the @a subgraph with
  646. /// [input_channels * depth_multiplier] dimensions.
  647. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  648. /// with [N, OH, OW, input_channels * depth_multiplier] dimensions.
  649. /// @param flags - binary features of the 2D Depthwise Convolution Node. The only currently supported value is
  650. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  651. enum xnn_status xnn_define_depthwise_convolution_2d(
  652. xnn_subgraph_t subgraph,
  653. uint32_t input_padding_top,
  654. uint32_t input_padding_right,
  655. uint32_t input_padding_bottom,
  656. uint32_t input_padding_left,
  657. uint32_t kernel_height,
  658. uint32_t kernel_width,
  659. uint32_t subsampling_height,
  660. uint32_t subsampling_width,
  661. uint32_t dilation_height,
  662. uint32_t dilation_width,
  663. uint32_t depth_multiplier,
  664. size_t input_channels,
  665. float output_min,
  666. float output_max,
  667. uint32_t input_id,
  668. uint32_t filter_id,
  669. uint32_t bias_id,
  670. uint32_t output_id,
  671. uint32_t flags);
  672. /// Define a 2D Depth To Space Node and add it to a Subgraph.
  673. ///
  674. /// The 2D Depth To Space Node rearranges data from depth into blocks of spatial data (a reverse transform to
  675. /// Space To Depth). For a given input pixel, an output square of pixels with side @a block_size is formed from values
  676. /// in the corresponding number of its channels. The output depth is therefore @a block_size x @a block_size times
  677. /// smaller than that of the input.
  678. ///
  679. /// @param subgraph - a Subgraph object that will own the created Node.
  680. /// @param block_size - the size of the spatial block.
  681. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  682. /// with [N, IH, IW, OC * block_size * block_size] dimensions.
  683. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  684. /// with [N, IH * block_size, IW * block_size, OC] dimensions.
  685. /// @param flags - binary features of the Depth To Space Node. No supported flags are currently defined.
  686. enum xnn_status xnn_define_depth_to_space_2d(
  687. xnn_subgraph_t subgraph,
  688. uint32_t block_size,
  689. uint32_t input_id,
  690. uint32_t output_id,
  691. uint32_t flags);
  692. enum xnn_status xnn_define_depth_to_space( // NOTE(review): undocumented; takes the same arguments as xnn_define_depth_to_space_2d but with @a block_size after @a output_id -- confirm semantics match.
  693. xnn_subgraph_t subgraph,
  694. uint32_t input_id,
  695. uint32_t output_id,
  696. uint32_t block_size,
  697. uint32_t flags);
  698. /// Define a 1D Global Average Pooling Node and add it to a Subgraph (deprecated).
  699. ///
  700. /// @param subgraph - a Subgraph object that will own the created Node.
  701. /// @param output_min - lower bound for clipping output values.
  702. /// @param output_max - upper bound for clipping output values.
  703. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions
  704. /// defined in the @a subgraph. Averaging is performed across the second-innermost dimension.
  705. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more
  706. /// dimensions defined in the @a subgraph.
  707. /// @param flags - binary features of the 1D Global Average Pooling Node. The only currently supported value is
  708. /// XNN_FLAG_KEEP_DIMS.
  709. XNN_DEPRECATED enum xnn_status xnn_define_global_average_pooling_1d(
  710. xnn_subgraph_t subgraph,
  711. float output_min,
  712. float output_max,
  713. uint32_t input_id,
  714. uint32_t output_id,
  715. uint32_t flags);
  716. /// Define a 2D Global Average Pooling Node and add it to a Subgraph (deprecated).
  717. ///
  718. /// @param subgraph - a Subgraph object that will own the created Node.
  719. /// @param output_min - lower bound for clipping output values.
  720. /// @param output_max - upper bound for clipping output values.
  721. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions
  722. /// defined in the @a subgraph. Averaging is performed across the second- and third-innermost
  723. /// dimensions.
  724. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more
  725. /// dimensions defined in the @a subgraph.
  726. /// @param flags - binary features of the 2D Global Average Pooling Node. The only currently supported value is
  727. /// XNN_FLAG_KEEP_DIMS.
  728. XNN_DEPRECATED enum xnn_status xnn_define_global_average_pooling_2d(
  729. xnn_subgraph_t subgraph,
  730. float output_min,
  731. float output_max,
  732. uint32_t input_id,
  733. uint32_t output_id,
  734. uint32_t flags);
  735. /// Define a 1D Global Sum Pooling Node and add it to a Subgraph (deprecated).
  736. ///
  737. /// @param subgraph - a Subgraph object that will own the created Node.
  738. /// @param output_min - lower bound for clipping output values.
  739. /// @param output_max - upper bound for clipping output values.
  740. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions
  741. /// defined in the @a subgraph. Summation is performed across the second-innermost dimension.
  742. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more
  743. /// dimensions defined in the @a subgraph.
  744. /// @param flags - binary features of the 1D Global Sum Pooling Node. The only currently supported value is
  745. /// XNN_FLAG_KEEP_DIMS.
  746. XNN_DEPRECATED enum xnn_status xnn_define_global_sum_pooling_1d(
  747. xnn_subgraph_t subgraph,
  748. float output_min,
  749. float output_max,
  750. uint32_t input_id,
  751. uint32_t output_id,
  752. uint32_t flags);
  753. /// Define a 2D Global Sum Pooling Node and add it to a Subgraph (deprecated).
  754. ///
  755. /// @param subgraph - a Subgraph object that will own the created Node.
  756. /// @param output_min - lower bound for clipping output values.
  757. /// @param output_max - upper bound for clipping output values.
  758. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions
  759. /// defined in the @a subgraph. Summation is performed across the second- and third-innermost
  760. /// dimensions.
  761. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more
  762. /// dimensions defined in the @a subgraph.
  763. /// @param flags - binary features of the 2D Global Sum Pooling Node. The only currently supported value is
  764. /// XNN_FLAG_KEEP_DIMS.
  765. XNN_DEPRECATED enum xnn_status xnn_define_global_sum_pooling_2d(
  766. xnn_subgraph_t subgraph,
  767. float output_min,
  768. float output_max,
  769. uint32_t input_id,
  770. uint32_t output_id,
  771. uint32_t flags);
  772. /// Define a 2D Average Pooling Node and add it to a Subgraph.
  773. ///
  774. /// @param subgraph - a Subgraph object that will own the created Node.
  775. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  776. /// flag is specified.
  777. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  778. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  779. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  780. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  781. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  782. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  783. /// @param pooling_height - pooling (kernel) height.
  784. /// @param pooling_width - pooling (kernel) width.
  785. /// @param stride_height - displacement of the pooling window in the vertical dimension of the input pixels corresponding
  786. /// to vertically adjacent output pixels.
  787. /// @param stride_width - displacement of the pooling window in the horizontal dimension of the input pixels corresponding
  788. /// to horizontally adjacent output pixels.
  789. /// @param output_min - lower bound for clipping output values.
  790. /// @param output_max - upper bound for clipping output values.
  791. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  792. /// with [N, IH, IW, channels] dimensions.
  793. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  794. /// with [N, OH, OW, channels] dimensions.
  795. /// @param flags - binary features of the 2D Average Pooling Node. The only currently supported value is
  796. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  797. enum xnn_status xnn_define_average_pooling_2d(
  798. xnn_subgraph_t subgraph,
  799. uint32_t input_padding_top,
  800. uint32_t input_padding_right,
  801. uint32_t input_padding_bottom,
  802. uint32_t input_padding_left,
  803. uint32_t pooling_height,
  804. uint32_t pooling_width,
  805. uint32_t stride_height,
  806. uint32_t stride_width,
  807. float output_min,
  808. float output_max,
  809. uint32_t input_id,
  810. uint32_t output_id,
  811. uint32_t flags);
  812. /// Define a Fully Connected Node and add it to a Subgraph.
  813. ///
  814. /// @param subgraph - a Subgraph object that will own the created Node.
  815. /// @param output_min - lower bound for clipping output values.
  816. /// @param output_max - upper bound for clipping output values.
  817. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the
  818. /// @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least
  819. /// 1D and its last dimension must equal input_channels of the filter tensor. In particular, if
  820. /// input is a 2D tensor, it must have [batch_size, input_channels] dimensions.
  821. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be
  822. /// divisible by the input_channels. The tensor will be first flattened into a 1D tensor of
  823. /// [num_input_elements] dimensions, then reshaped into a 2D tensor of
  824. /// [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the
  825. /// total number of elements in the input tensor.
  826. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 2D tensor defined in the @a subgraph.
  827. /// If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have
  828. /// [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is
  829. /// specified, the filter tensor must have [input_channels, output_channels] dimensions.
  830. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Fully Connected Node without a bias.
  831. /// If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels]
  832. /// dimensions.
  833. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
  834. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same
  835. /// dimensionality as the input tensor, all its dimensions but the last one must match the
  836. /// corresponding dimensions of the input tensor, and the last dimension of the output tensor must
  837. /// equal output_channels of the filter tensor. In particular, if input is a 2D tensor, output
  838. /// must be a 2D tensor of [batch_size, output_channels] dimensions.
  839. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of
  840. /// [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the
  841. /// total number of elements in the input tensor.
  842. /// @param flags - binary features of the Fully Connected Node. The only currently supported values are
  843. /// XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS.
  844. enum xnn_status xnn_define_fully_connected(
  845. xnn_subgraph_t subgraph,
  846. float output_min,
  847. float output_max,
  848. uint32_t input_id,
  849. uint32_t filter_id,
  850. uint32_t bias_id,
  851. uint32_t output_id,
  852. uint32_t flags);
  853. /// Define a Sparse Fully Connected Node and add it to a Subgraph.
  854. ///
  855. /// This operator is experimental, and will be removed in the future.
  856. ///
  857. /// @param subgraph - a Subgraph object that will own the created Node.
  858. /// @param output_min - lower bound for clipping output values.
  859. /// @param output_max - upper bound for clipping output values.
  860. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the
  861. /// @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least
  862. /// 1D and its last dimension must equal input_channels of the filter tensor. In particular, if
  863. /// input is a 2D tensor, it must have [batch_size, input_channels] dimensions.
  864. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be
  865. /// divisible by the input_channels. The tensor will be first flattened into a 1D tensor of
  866. /// [num_input_elements] dimensions, then reshaped into a 2D tensor of
  867. /// [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the
  868. /// total number of elements in the input tensor.
  869. /// @param filter_id - Value ID for the filter tensor. The filter tensor must be a 2D tensor defined in the @a subgraph.
  870. /// If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have
  871. /// [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is
  872. /// specified, the filter tensor must have [input_channels, output_channels] dimensions.
  873. /// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Fully Connected Node without a bias.
  874. /// If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels]
  875. /// dimensions.
  876. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
  877. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same
  878. /// dimensionality as the input tensor, all its dimensions but the last one must match the
  879. /// corresponding dimensions of the input tensor, and the last dimension of the output tensor must
  880. /// equal output_channels of the filter tensor. In particular, if input is a 2D tensor, output
  881. /// must be a 2D tensor of [batch_size, output_channels] dimensions.
  882. /// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of
  883. /// [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the
  884. /// total number of elements in the input tensor.
  885. /// @param flags - binary features of the Sparse Fully Connected Node. The only currently supported values are
  886. /// XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS.
  887. enum xnn_status xnn_define_fully_connected_sparse(
  888. xnn_subgraph_t subgraph,
  889. float output_min,
  890. float output_max,
  891. uint32_t input_id,
  892. uint32_t filter_id,
  893. uint32_t bias_id,
  894. uint32_t output_id,
  895. uint32_t flags);
  896. /// Define a 2D Max Pooling Node and add it to a Subgraph.
  897. ///
  898. /// @param subgraph - a Subgraph object that will own the created Node.
  899. /// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING
  900. /// flag is specified.
  901. /// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if
  902. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  903. /// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if
  904. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  905. /// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if
  906. /// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified.
  907. /// @param pooling_height - pooling (kernel) height.
  908. /// @param pooling_width - pooling (kernel) width.
  909. /// @param stride_height - displacement of the pooling window in the vertical dimension of the input pixels corresponding
  910. /// to vertically adjacent output pixels.
  911. /// @param stride_width - displacement of the pooling window in the horizontal dimension of the input pixels corresponding
  912. /// to horizontally adjacent output pixels.
  913. /// @param dilation_height - dilation of pooling elements along the height dimension.
  914. /// @param dilation_width - dilation of pooling elements along the width dimension.
  915. /// @param output_min - lower bound for clipping output values.
  916. /// @param output_max - upper bound for clipping output values.
  917. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  918. /// with [N, IH, IW, channels] dimensions.
  919. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  920. /// with [N, OH, OW, channels] dimensions.
  921. /// @param flags - binary features of the 2D Max Pooling Node. The only currently supported value is
  922. /// XNN_FLAG_TENSORFLOW_SAME_PADDING.
  923. enum xnn_status xnn_define_max_pooling_2d(
  924. xnn_subgraph_t subgraph,
  925. uint32_t input_padding_top,
  926. uint32_t input_padding_right,
  927. uint32_t input_padding_bottom,
  928. uint32_t input_padding_left,
  929. uint32_t pooling_height,
  930. uint32_t pooling_width,
  931. uint32_t stride_height,
  932. uint32_t stride_width,
  933. uint32_t dilation_height,
  934. uint32_t dilation_width,
  935. float output_min,
  936. float output_max,
  937. uint32_t input_id,
  938. uint32_t output_id,
  939. uint32_t flags);
  940. /// Define a 2D ArgMax Pooling Node and add it to a Subgraph.
  941. ///
  942. /// @param subgraph - a Subgraph object that will own the created Node.
  943. /// @param input_padding_top - implicit zero-padding above 2D input data.
  944. /// @param input_padding_right - implicit zero-padding to the right of 2D input data.
  945. /// @param input_padding_bottom - implicit zero-padding below 2D input data.
  946. /// @param input_padding_left - implicit zero-padding to the left of 2D input data.
  947. /// @param pooling_height - pooling (kernel) height. Vertical stride between pooling regions matches this value.
  948. /// @param pooling_width - pooling (kernel) width. Horizontal stride between pooling regions matches this value.
  949. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  950. /// with [N, IH, IW, channels] dimensions.
  951. /// @param output_value_id - Value ID for the output tensor with the maximum values in the pools. The output tensor must
  952. /// be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels] dimensions.
  953. /// @param output_index_id - Value ID for the output tensor with the indexes of the maximum values in the pools. The
  954. /// output tensor must be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels]
  955. /// dimensions.
  956. /// @param flags - binary features of the 2D ArgMax Pooling Node. No supported flags are currently defined.
  957. enum xnn_status xnn_define_argmax_pooling_2d(
  958. xnn_subgraph_t subgraph,
  959. uint32_t input_padding_top,
  960. uint32_t input_padding_right,
  961. uint32_t input_padding_bottom,
  962. uint32_t input_padding_left,
  963. uint32_t pooling_height,
  964. uint32_t pooling_width,
  965. uint32_t input_id,
  966. uint32_t output_value_id,
  967. uint32_t output_index_id,
  968. uint32_t flags);
  969. /// Define a 2D UnPooling Node and add it to a Subgraph.
  970. ///
  971. /// @param subgraph - a Subgraph object that will own the created Node.
  972. /// @param padding_top - implicit padding above 2D output data.
  973. /// @param padding_right - implicit padding to the right of 2D output data.
  974. /// @param padding_bottom - implicit padding below 2D output data.
  975. /// @param padding_left - implicit padding to the left of 2D output data.
  976. /// @param pooling_height - height of the pooling window.
  977. /// @param pooling_width - width of the pooling window.
  978. /// @param input_value_id - Value ID for the input tensor with the max-pooling values to invert. The input value tensor
  979. /// must be a 4D tensor defined in the @a subgraph with [N, IH, IW, channels] dimensions.
  980. /// @param input_index_id - Value ID for the input tensor with the indices of the per-pool maximum values produced by
  981. /// a 2D ArgMax Pooling Node. The input tensor must be a 4D tensor defined in the @a subgraph with
  982. /// [N, IH, IW, channels] dimensions.
  983. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  984. /// with [N, OH, OW, channels] dimensions.
  985. /// @param flags - binary features of the 2D UnPooling Node. No supported flags are currently defined.
  986. enum xnn_status xnn_define_unpooling_2d(
  987. xnn_subgraph_t subgraph,
  988. uint32_t padding_top,
  989. uint32_t padding_right,
  990. uint32_t padding_bottom,
  991. uint32_t padding_left,
  992. uint32_t pooling_height,
  993. uint32_t pooling_width,
  994. uint32_t input_value_id,
  995. uint32_t input_index_id,
  996. uint32_t output_id,
  997. uint32_t flags);
  998. enum xnn_binary_operator {  // Operator selector for the `type` argument of xnn_define_binary.
  999. xnn_binary_invalid = -1,  // Sentinel outside the valid range; the enumerators below are numbered from 0.
  1000. xnn_binary_add,
  1001. xnn_binary_subtract,
  1002. xnn_binary_multiply,
  1003. xnn_binary_divide,
  1004. xnn_binary_maximum,
  1005. xnn_binary_minimum,
  1006. xnn_binary_copysign,
  1007. xnn_binary_squared_difference,
  1008. xnn_binary_prelu,
  1009. // The following operators are experimental and may be removed.
  1010. xnn_binary_modulus,
  1011. xnn_binary_atan2,
  1012. xnn_binary_pow,
  1013. xnn_binary_bitwise_and,
  1014. xnn_binary_bitwise_or,
  1015. xnn_binary_bitwise_xor,
  1016. xnn_binary_shift_left,
  1017. xnn_binary_shift_right_logical,
  1018. xnn_binary_shift_right_arithmetic,
  1019. };
  1020. struct xnn_binary_params {  // Optional parameters for xnn_define_binary (its `params` argument).
  1021. /// lower bound for clipping output values.
  1022. double output_min;
  1023. /// upper bound for clipping output values.
  1024. double output_max;
  1025. };
  1026. /// Define a 2-Input binary operator Node and add it to a Subgraph.
  1027. ///
  1028. /// @param subgraph - a Subgraph object that will own the created Node.
  1029. /// @param type - Type of operator to apply to the two inputs (one of enum xnn_binary_operator).
  1030. /// @param params - Optional parameters for the operator; see struct xnn_binary_params for the output clipping bounds.
  1031. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1032. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1033. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1034. /// that dimension.
  1035. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1036. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1037. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1038. /// that dimension.
  1039. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1040. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1041. /// of the two inputs.
  1042. /// @param flags - binary features of the Node. No supported flags are currently defined.
  1043. enum xnn_status xnn_define_binary(
  1044. xnn_subgraph_t subgraph,
  1045. enum xnn_binary_operator type,
  1046. const struct xnn_binary_params* params,
  1047. uint32_t input1_id,
  1048. uint32_t input2_id,
  1049. uint32_t output_id,
  1050. uint32_t flags);
  1051. /// Define a 2-Input Add Node and add it to a Subgraph.
  1052. ///
  1053. /// The 2-Input Add Node computes elementwise addition of two tensor inputs with numpy broadcasting rules.
  1054. ///
  1055. /// @param subgraph - a Subgraph object that will own the created Node.
  1056. /// @param output_min - lower bound for clipping output values.
  1057. /// @param output_max - upper bound for clipping output values.
  1058. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1059. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1060. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1061. /// that dimension.
  1062. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1063. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1064. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1065. /// that dimension.
  1066. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1067. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1068. /// of the two inputs.
  1069. /// @param flags - binary features of the Add Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_add (confirm).
  1070. XNN_DEPRECATED enum xnn_status xnn_define_add2(
  1071. xnn_subgraph_t subgraph,
  1072. float output_min,
  1073. float output_max,
  1074. uint32_t input1_id,
  1075. uint32_t input2_id,
  1076. uint32_t output_id,
  1077. uint32_t flags);
  1078. /// Define a 2-Input Multiply Node and add it to a Subgraph.
  1079. ///
  1080. /// The 2-Input Multiply Node computes elementwise multiplication of two tensor inputs with numpy broadcasting rules.
  1081. ///
  1082. /// @param subgraph - a Subgraph object that will own the created Node.
        /// @param output_min - lower bound for clipping output values.
        /// @param output_max - upper bound for clipping output values.
  1083. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1084. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1085. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1086. /// that dimension.
  1087. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1088. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1089. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1090. /// that dimension.
  1091. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1092. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1093. /// of the two inputs.
  1094. /// @param flags - binary features of the Multiply Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_multiply (confirm).
  1095. XNN_DEPRECATED enum xnn_status xnn_define_multiply2(
  1096. xnn_subgraph_t subgraph,
  1097. float output_min,
  1098. float output_max,
  1099. uint32_t input1_id,
  1100. uint32_t input2_id,
  1101. uint32_t output_id,
  1102. uint32_t flags);
  1103. // Cap operations applied to the attention logits (query dot key); selected via the cap_type argument of xnn_define_scaled_dot_product_attention.
  1104. enum xnn_attention_logits_cap_type {
  1105. // No capping.
  1106. xnn_attention_logits_cap_type_none = 0,
  1107. // Cap the absolute values of logits by tanh: tanh(logits / cap) * cap
  1108. xnn_attention_logits_cap_type_tanh
  1109. };
  1110. // Params when the cap type is xnn_attention_logits_cap_type_tanh; passed via the cap_params argument of xnn_define_scaled_dot_product_attention.
  1111. struct xnn_attention_logits_cap_tanh_params {
  1112. float cap;  // The `cap` in tanh(logits / cap) * cap.
  1113. };
  1114. /// Define a Scaled Dot-Product Attention Node and add it to a Subgraph.
  1115. ///
  1116. /// This operator is experimental.
  1117. ///
  1118. /// The Scaled Dot-Product Attention Node computes a multi-head or multi-query scaled dot-product attention on the query,
  1119. /// key, and value tensors.
  1120. ///
  1121. /// @param subgraph - a Subgraph object that will own the created Node.
  1122. /// @param cap_type - type of cap to be applied to the logits.
  1123. /// @param cap_params - parameters for the cap. Must be a pointer to xnn_attention_logits_cap_tanh_params if cap_type
  1124. /// is xnn_attention_logits_cap_type_tanh.
  1125. /// @param query_id - Value ID for the query tensor. The query tensor must be a 3+-dimensional tensor defined in the
  1126. /// @a subgraph with the dimensions as [*, H, T, C], where H/T/C are the heads/tokens/channels, and *
  1127. /// denotes 0 or more dimensions treated as batch size.
  1128. /// @param key_id - Value ID for the key tensor. The key tensor must be a 2+-dimensional tensor defined in the
  1129. /// @a subgraph. It can have the same number of dimensions as the query, with the dimensions as
  1130. /// [*, H, U, C] (multi-head), or have 1 less dimension than the query, with the dimensions as
  1131. /// [*, U, C] (multi-query, number of heads omitted implies single head), where H/U/C are the
  1132. /// heads/key_value_tokens/channels, and * denotes 0 or more dimensions treated as batch size. These
  1133. /// batch size dimensions must be the same as query.
  1134. /// @param value_id - Value ID for the value tensor. The value tensor must be a 2+-dimensional tensor defined in the
  1135. /// @a subgraph. It can have the same number of dimensions as the query, with the dimensions as
  1136. /// [*, H, U, D] (multi-head), or have 1 less dimension than the query, with the dimensions as
  1137. /// [*, U, D] (multi-query, number of heads omitted implies single head), where H/U/D are the
  1138. /// heads/key_value_tokens/value_channels, and * denotes 0 or more dimensions treated as batch size.
  1139. /// These batch size dimensions must be the same as query and key.
  1140. /// @param scale_id - Value ID for the scale tensor. The scale tensor must be a 1D tensor defined in the @a subgraph
  1141. /// with [C] dimensions. The query tensor is multiplied with this scale tensor before the dot product
  1142. /// with the key tensor.
  1143. /// @param mask_id - Value ID for the mask tensor. The mask tensor must be a 2D tensor defined in the @a subgraph with
  1144. /// [T, U] dimensions. The mask tensor is added to the logits (query dot key).
  1145. /// @param output_id - Value ID for the output tensor. The output tensor must be a 3+-dimensional tensor defined in the
  1146. /// @a subgraph with the dimensions as [*, H, T, D], where H/T/D are the heads/tokens/value_channels,
  1147. /// and * denotes 0 or more dimensions treated as batch size. These batch size dimensions must be the
  1148. /// same as query, key, and value.
  1149. /// @param flags - binary features of the Scaled Dot Product Attention Node. No supported flags are currently defined.
  1150. enum xnn_status xnn_define_scaled_dot_product_attention(
  1151. xnn_subgraph_t subgraph,
  1152. enum xnn_attention_logits_cap_type cap_type,
  1153. const void* cap_params,
  1154. uint32_t query_id,
  1155. uint32_t key_id,
  1156. uint32_t value_id,
  1157. uint32_t scale_id,
  1158. uint32_t mask_id,
  1159. uint32_t output_id,
  1160. uint32_t flags);
  1161. /// Define a Subtract Node and add it to a Subgraph.
  1162. ///
  1163. /// The Subtract Node computes elementwise subtraction of two tensor inputs with numpy broadcasting rules.
  1164. ///
  1165. /// @param subgraph - a Subgraph object that will own the created Node.
  1166. /// @param output_min - lower bound for clipping output values.
  1167. /// @param output_max - upper bound for clipping output values.
  1168. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1169. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1170. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1171. /// that dimension.
  1172. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1173. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1174. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1175. /// that dimension.
  1176. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1177. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1178. /// of the two inputs.
  1179. /// @param flags - binary features of the Subtract Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_subtract (confirm).
  1180. XNN_DEPRECATED enum xnn_status xnn_define_subtract(
  1181. xnn_subgraph_t subgraph,
  1182. float output_min,
  1183. float output_max,
  1184. uint32_t input1_id,
  1185. uint32_t input2_id,
  1186. uint32_t output_id,
  1187. uint32_t flags);
  1188. /// Define a Divide Node and add it to a Subgraph.
  1189. ///
  1190. /// The Divide Node computes elementwise division of two tensor inputs with numpy broadcasting rules.
  1191. ///
  1192. /// @param subgraph - a Subgraph object that will own the created Node.
  1193. /// @param output_min - lower bound for clipping output values.
  1194. /// @param output_max - upper bound for clipping output values.
  1195. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1196. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1197. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1198. /// that dimension.
  1199. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1200. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1201. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1202. /// that dimension.
  1203. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1204. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1205. /// of the two inputs.
  1206. /// @param flags - binary features of the Divide Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_divide (confirm).
  1207. XNN_DEPRECATED enum xnn_status xnn_define_divide(
  1208. xnn_subgraph_t subgraph,
  1209. float output_min,
  1210. float output_max,
  1211. uint32_t input1_id,
  1212. uint32_t input2_id,
  1213. uint32_t output_id,
  1214. uint32_t flags);
  1215. /// Define a 2-Input Maximum Node and add it to a Subgraph.
  1216. ///
  1217. /// The 2-Input Maximum Node computes elementwise maximum of two tensor inputs with numpy broadcasting rules.
  1218. ///
  1219. /// @param subgraph - a Subgraph object that will own the created Node.
  1220. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1221. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1222. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1223. /// that dimension.
  1224. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1225. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1226. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1227. /// that dimension.
  1228. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1229. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1230. /// of the two inputs.
  1231. /// @param flags - binary features of the Maximum Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_maximum (confirm).
  1232. XNN_DEPRECATED enum xnn_status xnn_define_maximum2(
  1233. xnn_subgraph_t subgraph,
  1234. uint32_t input1_id,
  1235. uint32_t input2_id,
  1236. uint32_t output_id,
  1237. uint32_t flags);
  1238. /// Define a 2-Input Minimum Node and add it to a Subgraph.
  1239. ///
  1240. /// The 2-Input Minimum Node computes elementwise minimum of two tensor inputs with numpy broadcasting rules.
  1241. ///
  1242. /// @param subgraph - a Subgraph object that will own the created Node.
  1243. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1244. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1245. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1246. /// that dimension.
  1247. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1248. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1249. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1250. /// that dimension.
  1251. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1252. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1253. /// of the two inputs.
  1254. /// @param flags - binary features of the Minimum Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_minimum (confirm).
  1255. XNN_DEPRECATED enum xnn_status xnn_define_minimum2(
  1256. xnn_subgraph_t subgraph,
  1257. uint32_t input1_id,
  1258. uint32_t input2_id,
  1259. uint32_t output_id,
  1260. uint32_t flags);
  1261. /// Define a Squared Difference Node and add it to a Subgraph.
  1262. ///
  1263. /// The Squared Difference Node computes elementwise squared difference of two tensor inputs with numpy broadcasting
  1264. /// rules.
  1265. ///
  1266. /// @param subgraph - a Subgraph object that will own the created Node.
  1267. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1268. /// the @a subgraph with each dimension either equal to the corresponding dimension of the second
  1269. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1270. /// that dimension.
  1271. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in
  1272. /// the @a subgraph with each dimension either equal to the corresponding dimension of the first
  1273. /// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along
  1274. /// that dimension.
  1275. /// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined
  1276. /// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension
  1277. /// of the two inputs.
  1278. /// @param flags - binary features of the Squared Difference Node. No supported flags are currently defined.
        /// NOTE(review): declared XNN_DEPRECATED; presumably superseded by xnn_define_binary with xnn_binary_squared_difference (confirm).
  1279. XNN_DEPRECATED enum xnn_status xnn_define_squared_difference(
  1280. xnn_subgraph_t subgraph,
  1281. uint32_t input1_id,
  1282. uint32_t input2_id,
  1283. uint32_t output_id,
  1284. uint32_t flags);
  1285. /// Define a Constant Pad Node with static padding specification and add it to a Subgraph.
  1286. ///
  1287. /// @param subgraph - a Subgraph object that will own the created Node.
  1288. /// @param pre_paddings - number of padding elements to insert before input elements for every dimension. This array
  1289. /// must have as many elements as the number of dimensions in the input tensor.
  1290. /// @param post_paddings - number of padding elements to insert after input elements for every dimension. This array
  1291. /// must have as many elements as the number of dimensions in the input tensor.
  1292. /// @param padding_value - constant value used to initialize padding elements.
  1293. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1294. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1295. /// dimension i must equal pre_paddings[i] + (input dimension i) + post_paddings[i].
  1296. /// @param flags - binary features of the Constant Pad Node. No supported flags are currently defined.
  1297. enum xnn_status xnn_define_static_constant_pad(
  1298. xnn_subgraph_t subgraph,
  1299. const size_t* pre_paddings,
  1300. const size_t* post_paddings,
  1301. float padding_value,
  1302. uint32_t input_id,
  1303. uint32_t output_id,
  1304. uint32_t flags);
  1305. /// Define an Expand Dims Node with static axes specification and add it to a Subgraph.
  1306. ///
  1307. /// @param subgraph - a Subgraph object that will own the created Node.
  1308. /// @param num_new_axes - number of new axes of size 1 to be inserted.
  1309. /// @param new_axes - array of @a num_new_axes positions of the new axes in the expanded (output) dimensions.
  1310. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1311. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1312. /// shape must match the shape of the input tensor with a dimension of size 1 inserted at each of @a new_axes.
  1313. /// @param flags - binary features of the Expand Dims Node. No supported flags are currently defined.
  1314. enum xnn_status xnn_define_static_expand_dims(
  1315. xnn_subgraph_t subgraph,
  1316. size_t num_new_axes,
  1317. const size_t* new_axes,
  1318. uint32_t input_id,
  1319. uint32_t output_id,
  1320. uint32_t flags);
  1321. /// Define a Mean Node and add it to a Subgraph. Deprecated: consider xnn_define_static_reduce with xnn_reduce_mean.
  1322. ///
  1323. /// @param subgraph - a Subgraph object that will own the created Node.
  1324. /// @param num_reduction_axes - number of axes along which mean is computed.
  1325. /// @param reduction_axes - axes along which mean is computed.
  1326. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least
  1327. /// @a num_reduction_axes dimensions defined in the @a subgraph.
  1328. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the
  1329. /// @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if
  1330. /// XNN_FLAG_KEEP_DIMS is not specified), or with the same rank but the dimensions at
  1331. /// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified).
  1332. /// @param flags - binary features of the Mean Node. The only currently supported value is XNN_FLAG_KEEP_DIMS.
  1333. XNN_DEPRECATED enum xnn_status xnn_define_static_mean(
  1334. xnn_subgraph_t subgraph,
  1335. size_t num_reduction_axes,
  1336. const size_t* reduction_axes,
  1337. uint32_t input_id,
  1338. uint32_t output_id,
  1339. uint32_t flags);
  1340. enum xnn_reduce_operator { // Reduction selected by xnn_define_static_reduce / xnn_define_static_reduce_v2.
  1341. xnn_reduce_invalid = -1, // Invalid operator value.
  1342. xnn_reduce_sum, // Sum over the reduction axes.
  1343. xnn_reduce_mean, // Mean over the reduction axes.
  1344. };
  1345. /// Define a Reduce Node and add it to a Subgraph.
  1346. ///
  1347. /// @param subgraph - a Subgraph object that will own the created Node.
  1348. /// @param reduce_operator_type - the reduction to apply (e.g. xnn_reduce_sum or xnn_reduce_mean).
  1349. /// @param num_reduction_axes - number of axes along which reduce is computed.
  1350. /// @param reduction_axes - axes along which reduce is computed.
  1351. /// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least
  1352. /// @a num_reduction_axes dimensions defined in the @a subgraph.
  1353. /// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the @a subgraph
  1354. /// with @a num_reduction_axes fewer dimensions than the input tensor (if XNN_FLAG_KEEP_DIMS is not specified),
  1355. /// or with the same rank but the dimensions at @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified).
  1356. /// @param flags - binary features of the Reduce Node. The only currently supported value is XNN_FLAG_KEEP_DIMS.
  1357. enum xnn_status xnn_define_static_reduce(
  1358. xnn_subgraph_t subgraph,
  1359. enum xnn_reduce_operator reduce_operator_type,
  1360. size_t num_reduction_axes,
  1361. const size_t* reduction_axes,
  1362. uint32_t input_id,
  1363. uint32_t output_id,
  1364. uint32_t flags);
  1365. /// Define a Reduce Node with signed (possibly negative) reduction axes and add it to a Subgraph.
  1366. ///
  1367. /// @param subgraph - a Subgraph object that will own the created Node.
  1368. /// @param reduce_operator_type - the reduction to apply (e.g. xnn_reduce_sum or
  1369. /// xnn_reduce_mean).
  1370. /// @param num_reduction_axes - number of axes along which reduce is computed.
  1371. /// @param reduction_axes - axes along which reduce is computed. Negative values
  1372. /// are interpreted as offsets from the number of
  1373. /// dimensions of the input tensor.
  1374. /// @param input_id - Value ID for the input tensor. The input tensor must be a
  1375. /// dense tensor with at least @a num_reduction_axes
  1376. /// dimensions defined in the @a subgraph.
  1377. /// @param output_id - Value ID for the output tensor. The output tensor must be
  1378. /// a dense tensor defined in the @a subgraph with @a
  1379. /// num_reduction_axes fewer dimensions than the input tensor
  1380. /// (if XNN_FLAG_KEEP_DIMS is not specified), or with the same
  1381. /// rank but the dimensions at @a reduction_axes reduced to 1
  1382. /// (if XNN_FLAG_KEEP_DIMS is specified).
  1383. /// @param flags - binary features of the Reduce Node. The only currently supported value is XNN_FLAG_KEEP_DIMS.
  1384. enum xnn_status xnn_define_static_reduce_v2( //
  1385. xnn_subgraph_t subgraph, //
  1386. enum xnn_reduce_operator reduce_operator_type, //
  1387. size_t num_reduction_axes, //
  1388. const int64_t* reduction_axes, //
  1389. uint32_t input_id, //
  1390. uint32_t output_id, //
  1391. uint32_t flags);
  1392. /// Define a 2-Input Concatenate Node and add it to a Subgraph.
  1393. ///
  1394. /// The 2-Input Concatenate Node concatenates two tensors along a specified axis.
  1395. ///
  1396. /// @param subgraph - a Subgraph object that will own the created Node.
  1397. /// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of
  1398. /// dimensions is added to it.
  1399. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1400. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1401. /// second input.
  1402. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1403. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1404. /// first input.
  1405. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
  1406. /// in the @a subgraph with each dimension equal to the dimension of both inputs, except the axis
  1407. /// dimension, where it is the sum of the corresponding dimensions of both inputs.
  1408. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1409. enum xnn_status xnn_define_concatenate2(
  1410. xnn_subgraph_t subgraph,
  1411. int32_t axis,
  1412. uint32_t input1_id,
  1413. uint32_t input2_id,
  1414. uint32_t output_id,
  1415. uint32_t flags);
  1416. /// Define a 3-Input Concatenate Node and add it to a Subgraph.
  1417. ///
  1418. /// The 3-Input Concatenate Node concatenates three tensors along a specified axis.
  1419. ///
  1420. /// @param subgraph - a Subgraph object that will own the created Node.
  1421. /// @param axis - the axis to concatenate the three input tensors along. If this is less than zero, the number of
  1422. /// dimensions is added to it.
  1423. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1424. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1425. /// other inputs.
  1426. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1427. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1428. /// other inputs.
  1429. /// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
  1430. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1431. /// other inputs.
  1432. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
  1433. /// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
  1434. /// dimension, where it is the sum of the corresponding dimensions of all inputs.
  1435. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1436. enum xnn_status xnn_define_concatenate3(
  1437. xnn_subgraph_t subgraph,
  1438. int32_t axis,
  1439. uint32_t input1_id,
  1440. uint32_t input2_id,
  1441. uint32_t input3_id,
  1442. uint32_t output_id,
  1443. uint32_t flags);
  1444. /// Define a 4-Input Concatenate Node and add it to a Subgraph.
  1445. ///
  1446. /// The 4-Input Concatenate Node concatenates four tensors along a specified axis.
  1447. ///
  1448. /// @param subgraph - a Subgraph object that will own the created Node.
  1449. /// @param axis - the axis to concatenate the four input tensors along. If this is less than zero, the number of
  1450. /// dimensions is added to it.
  1451. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1452. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1453. /// other inputs.
  1454. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1455. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1456. /// other inputs.
  1457. /// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
  1458. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1459. /// other inputs.
  1460. /// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
  1461. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1462. /// other inputs.
  1463. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
  1464. /// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis
  1465. /// dimension, where it is the sum of the corresponding dimensions of all inputs.
  1466. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1467. enum xnn_status xnn_define_concatenate4(
  1468. xnn_subgraph_t subgraph,
  1469. int32_t axis,
  1470. uint32_t input1_id,
  1471. uint32_t input2_id,
  1472. uint32_t input3_id,
  1473. uint32_t input4_id,
  1474. uint32_t output_id,
  1475. uint32_t flags);
  1476. /// Define a 5-Input Concatenate Node and add it to a Subgraph.
  1477. ///
  1478. /// The 5-Input Concatenate Node concatenates five tensors along a specified axis.
  1479. ///
  1480. /// @param subgraph - a Subgraph object that will own the created Node.
  1481. /// @param axis - the axis to concatenate the five input tensors along. If this is less than zero, the number of
  1482. /// dimensions is added to it.
  1483. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1484. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1485. /// other inputs.
  1486. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in
  1487. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1488. /// other inputs.
  1489. /// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in
  1490. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1491. /// other inputs.
  1492. /// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in
  1493. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1494. /// other inputs.
  1495. /// @param input5_id - Value ID for the fifth input tensor. The input tensor must be an N-dimensional tensor defined in
  1496. /// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the
  1497. /// other inputs.
  1498. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined in the
  1499. /// @a subgraph with each dimension equal to that of all inputs, except the axis dimension, which is the sum of the inputs' axis dimensions.
  1500. /// @param flags - binary features of the Concatenate Node. No supported flags are currently defined.
  1501. enum xnn_status xnn_define_concatenate5(
  1502. xnn_subgraph_t subgraph,
  1503. int32_t axis,
  1504. uint32_t input1_id,
  1505. uint32_t input2_id,
  1506. uint32_t input3_id,
  1507. uint32_t input4_id,
  1508. uint32_t input5_id,
  1509. uint32_t output_id,
  1510. uint32_t flags);
  1511. /// Define a Copy Sign Node and add it to a Subgraph.
  1512. ///
  1513. /// The Copy Sign Node produces values with the magnitude of the first input and the sign of the second input.
  1514. ///
  1515. /// @param subgraph - a Subgraph object that will own the created Node.
  1516. /// @param input1_id - Value ID for the first input tensor (magnitude source). The input tensor must be defined in the @a subgraph.
  1517. /// @param input2_id - Value ID for the second input tensor (sign source). The input tensor must be defined in the @a subgraph.
  1518. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph.
  1519. /// @param flags - binary features of the Copy Sign Node. No supported flags are currently defined.
  1520. XNN_DEPRECATED enum xnn_status xnn_define_copysign(
  1521. xnn_subgraph_t subgraph,
  1522. uint32_t input1_id,
  1523. uint32_t input2_id,
  1524. uint32_t output_id,
  1525. uint32_t flags);
  1526. /// Define a Copy Node and add it to a Subgraph.
  1527. ///
  1528. /// The Copy Node copies an input tensor to an output tensor.
  1529. ///
  1530. /// @param subgraph - a Subgraph object that will own the created Node.
  1531. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1532. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1533. /// shape must match the shape of the input tensor.
  1534. /// @param flags - binary features of the Copy Node. No supported flags are currently defined.
  1535. enum xnn_status xnn_define_copy(
  1536. xnn_subgraph_t subgraph,
  1537. uint32_t input_id,
  1538. uint32_t output_id,
  1539. uint32_t flags);
  1540. /// Define a 2-Output Split Node and add it to a Subgraph.
  1541. ///
  1542. /// The 2-Output Split Node splits an input tensor into two output tensors along a specified axis evenly.
  1543. ///
  1544. /// @param subgraph - a Subgraph object that will own the created Node.
  1545. /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
  1546. /// dimensions is added to it.
  1547. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
  1548. /// subgraph.
  1549. /// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
  1550. /// in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding dimension
  1551. /// of the second output. The split_dim dimension is half of the input's split_dim.
  1552. /// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
  1553. /// defined in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding
  1554. /// dimension of the first output. The split_dim dimension is half of the input's split_dim.
  1555. /// @param flags - binary features of the Split Node. No supported flags are currently defined.
  1556. enum xnn_status xnn_define_even_split2(
  1557. xnn_subgraph_t subgraph,
  1558. int32_t split_dim,
  1559. uint32_t input_id,
  1560. uint32_t output1_id,
  1561. uint32_t output2_id,
  1562. uint32_t flags);
  1563. /// Define a 3-Output Split Node and add it to a Subgraph.
  1564. ///
  1565. /// The 3-Output Split Node splits an input tensor into three output tensors along a specified axis evenly.
  1566. ///
  1567. /// @param subgraph - a Subgraph object that will own the created Node.
  1568. /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
  1569. /// dimensions is added to it.
  1570. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
  1571. /// subgraph.
  1572. /// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
  1573. /// in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding dimension
  1574. /// of the second and third output. The split_dim dimension is one third of the input's split_dim.
  1575. /// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
  1576. /// defined in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding
  1577. /// dimension of the first and third output. The split_dim dimension is one third of the input's
  1578. /// split_dim.
  1579. /// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor
  1580. /// defined in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding
  1581. /// dimension of the first and second output. The split_dim dimension is one third of the input's
  1582. /// split_dim.
  1583. /// @param flags - binary features of the Split Node. No supported flags are currently defined.
  1584. enum xnn_status xnn_define_even_split3(
  1585. xnn_subgraph_t subgraph,
  1586. int32_t split_dim,
  1587. uint32_t input_id,
  1588. uint32_t output1_id,
  1589. uint32_t output2_id,
  1590. uint32_t output3_id,
  1591. uint32_t flags);
  1592. /// Define a 4-Output Split Node and add it to a Subgraph.
  1593. ///
  1594. /// The 4-Output Split Node splits an input tensor into four output tensors along a specified axis evenly.
  1595. ///
  1596. /// @param subgraph - a Subgraph object that will own the created Node.
  1597. /// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of
  1598. /// dimensions is added to it.
  1599. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a
  1600. /// subgraph.
  1601. /// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined
  1602. /// in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding dimension
  1603. /// of the other output tensors. The split_dim dimension is one fourth of the input's split_dim.
  1604. /// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor
  1605. /// defined in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding
  1606. /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
  1607. /// split_dim.
  1608. /// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor
  1609. /// defined in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding
  1610. /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
  1611. /// split_dim.
  1612. /// @param output4_id - Value ID for the fourth output tensor. The output tensor must be an N-dimensional tensor
  1613. /// defined in the @a subgraph with each dimension, except @a split_dim, equal to the corresponding
  1614. /// dimension of the other output tensors. The split_dim dimension is one fourth of the input's
  1615. /// split_dim.
  1616. /// @param flags - binary features of the Split Node. No supported flags are currently defined.
  1617. enum xnn_status xnn_define_even_split4(
  1618. xnn_subgraph_t subgraph,
  1619. int32_t split_dim,
  1620. uint32_t input_id,
  1621. uint32_t output1_id,
  1622. uint32_t output2_id,
  1623. uint32_t output3_id,
  1624. uint32_t output4_id,
  1625. uint32_t flags);
  1626. /// Define a Reshape Node with static shape specification and add it to a Subgraph.
  1627. ///
  1628. /// @param subgraph - a Subgraph object that will own the created Node.
  1629. /// @param num_dims - number of shape dimensions in the output tensor.
  1630. /// @param new_shape - shape dimensions of the output tensor (an array of @a num_dims elements).
  1631. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1632. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1633. /// shape must be the reshaped input shape described by @a num_dims and @a new_shape.
  1634. /// @param flags - binary features of the Reshape Node. No supported flags are currently defined.
  1635. enum xnn_status xnn_define_static_reshape(
  1636. xnn_subgraph_t subgraph,
  1637. size_t num_dims,
  1638. const size_t* new_shape,
  1639. uint32_t input_id,
  1640. uint32_t output_id,
  1641. uint32_t flags);
  1642. /// Define a 2D Resize Bilinear Node with static output height and width specification and add it to a Subgraph.
  1643. ///
  1644. /// @param subgraph - a Subgraph object that will own the created Node.
  1645. /// @param new_height - height (H) dimension of the output tensor.
  1646. /// @param new_width - width (W) dimension of the output tensor.
  1647. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1648. /// with [N, H, W, C] dimensions.
  1649. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1650. /// with [N, new_height, new_width, C] dimensions.
  1651. /// @param flags - binary features of the 2D Resize Bilinear Node. The only currently supported values are
  1652. /// XNN_FLAG_TENSORFLOW_LEGACY_MODE and XNN_FLAG_ALIGN_CORNERS, which are mutually exclusive.
  1653. enum xnn_status xnn_define_static_resize_bilinear_2d(
  1654. xnn_subgraph_t subgraph,
  1655. size_t new_height,
  1656. size_t new_width,
  1657. uint32_t input_id,
  1658. uint32_t output_id,
  1659. uint32_t flags);
  1660. /// Define a PReLU (Parametric ReLU) Node, which computes x for x > 0 and slope * x otherwise, and add it to a Subgraph.
  1661. ///
  1662. /// @param subgraph - a Subgraph object that will own the created Node.
  1663. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1664. /// with [N, H, W, channels] dimensions.
  1665. /// @param slope_id - Value ID for the slope tensor. The slope tensor must be a 1D tensor defined in the @a subgraph with
  1666. /// either [1] or [channels] dimensions.
  1667. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1668. /// with [N, H, W, channels] dimensions.
  1669. /// @param flags - binary features of the PReLU Node. No supported flags are currently defined.
  1670. XNN_DEPRECATED enum xnn_status xnn_define_prelu(
  1671. xnn_subgraph_t subgraph,
  1672. uint32_t input_id,
  1673. uint32_t slope_id,
  1674. uint32_t output_id,
  1675. uint32_t flags);
  1676. /// Define a RoPE (Rotary Positional Embeddings) Node and add it to a Subgraph.
  1677. ///
  1678. /// @param subgraph - a Subgraph object that will own the created Node.
  1679. /// @param max_sequence_size - deprecated.
  1680. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
  1681. /// with [batch, tokens, heads, channels] dimensions.
  1682. /// @param weights_id - Value ID for the weights tensor. The weights tensor must be a 2D tensor defined in the
  1683. /// @a subgraph with [max_sequence_size, channels] dimensions.
  1684. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
  1685. /// with [batch, tokens, heads, channels] dimensions.
  1686. /// @param flags - binary features of the RoPE Node. No supported flags are currently defined.
  1687. enum xnn_status xnn_define_rope(
  1688. xnn_subgraph_t subgraph,
  1689. size_t max_sequence_size,
  1690. uint32_t input_id,
  1691. uint32_t weights_id,
  1692. uint32_t output_id,
  1693. uint32_t flags);
  1694. /// Define an Abs Node, which computes the elementwise absolute value, and add it to a Subgraph.
  1695. ///
  1696. /// @param subgraph - a Subgraph object that will own the created Node.
  1697. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1698. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1699. /// shape must match the shape of the input tensor.
  1700. /// @param flags - binary features of the Abs Node. No supported flags are currently defined.
  1701. XNN_DEPRECATED enum xnn_status xnn_define_abs(
  1702. xnn_subgraph_t subgraph,
  1703. uint32_t input_id,
  1704. uint32_t output_id,
  1705. uint32_t flags);
  1706. /// Define a Bankers' Rounding (round half to even) Node and add it to a Subgraph.
  1707. ///
  1708. /// @param subgraph - a Subgraph object that will own the created Node.
  1709. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1710. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1711. /// shape must match the shape of the input tensor.
  1712. /// @param flags - binary features of the Bankers' Rounding Node. No supported flags are currently defined.
  1713. XNN_DEPRECATED enum xnn_status xnn_define_bankers_rounding(
  1714. xnn_subgraph_t subgraph,
  1715. uint32_t input_id,
  1716. uint32_t output_id,
  1717. uint32_t flags);
  1718. /// Define a Batch Matrix Multiply Node and add it to a Subgraph.
  1719. ///
  1720. /// @param subgraph - a Subgraph object that will own the created Node.
  1721. /// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in
  1722. /// the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the second input
  1723. /// tensor. The last 2 dimensions are [M, K]. If XNN_FLAG_TRANSPOSE_B is not specified, the last
  1724. /// dimension must match the second last dimension of the second input tensor. If
  1725. /// XNN_FLAG_TRANSPOSE_B is specified, the last dimension must match the last dimension of the
  1726. /// second input tensor.
  1727. /// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined
  1728. /// in the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first input
  1729. /// tensor. If XNN_FLAG_TRANSPOSE_B is not specified, the last 2 dimensions are [K, N], and the
  1730. /// second last dimension must match the last dimension of the first input tensor. If
  1731. /// XNN_FLAG_TRANSPOSE_B is specified, the last 2 dimensions are [N, K], and the last dimension must
  1732. /// match the last dimension of the first input tensor.
  1733. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined in the
  1734. /// @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first and second
  1735. /// input tensors. The last 2 dimensions must be [M, N].
  1736. /// @param flags - binary features of the Batch Matrix Multiply Node. The only currently supported value is
  1737. /// XNN_FLAG_TRANSPOSE_B.
  1738. enum xnn_status xnn_define_batch_matrix_multiply(
  1739. xnn_subgraph_t subgraph,
  1740. uint32_t input1_id,
  1741. uint32_t input2_id,
  1742. uint32_t output_id,
  1743. uint32_t flags);
  1744. /// Define a Ceiling Node, which rounds each element up to an integral value, and add it to a Subgraph.
  1745. ///
  1746. /// @param subgraph - a Subgraph object that will own the created Node.
  1747. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1748. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1749. /// shape must match the shape of the input tensor.
  1750. /// @param flags - binary features of the Ceiling Node. No supported flags are currently defined.
  1751. XNN_DEPRECATED enum xnn_status xnn_define_ceiling(
  1752. xnn_subgraph_t subgraph,
  1753. uint32_t input_id,
  1754. uint32_t output_id,
  1755. uint32_t flags);
  1756. /// Define a Clamp Node, which limits each element to the range [output_min, output_max], and add it to a Subgraph.
  1757. ///
  1758. /// @param subgraph - a Subgraph object that will own the created Node.
  1759. /// @param output_min - lower bound for clipping output values.
  1760. /// @param output_max - upper bound for clipping output values.
  1761. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1762. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1763. /// shape must match the shape of the input tensor.
  1764. /// @param flags - binary features of the Clamp Node. No supported flags are currently defined.
  1765. XNN_DEPRECATED enum xnn_status xnn_define_clamp(
  1766. xnn_subgraph_t subgraph,
  1767. float output_min,
  1768. float output_max,
  1769. uint32_t input_id,
  1770. uint32_t output_id,
  1771. uint32_t flags);
  1772. /// Define an ELU (Exponential Linear Unit) Node, computing x for x > 0 and alpha * (exp(x) - 1) otherwise, and add it to a Subgraph.
  1773. ///
  1774. /// @param subgraph - a Subgraph object that will own the created Node.
  1775. /// @param alpha - scale factor for negative output elements.
  1776. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
  1777. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
  1778. /// shape must match the shape of the input tensor.
  1779. /// @param flags - binary features of the ELU Node. No supported flags are currently defined.
  1780. XNN_DEPRECATED enum xnn_status xnn_define_elu(
  1781. xnn_subgraph_t subgraph,
  1782. float alpha,
  1783. uint32_t input_id,
  1784. uint32_t output_id,
  1785. uint32_t flags);
1786. /// Define an Exp Node and add it to a Subgraph.
1787. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1788. /// @param subgraph - a Subgraph object that will own the created Node.
1789. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1790. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1791. /// shape must match the shape of the input tensor.
1792. /// @param flags - binary features of the Exp Node. No supported flags are currently defined.
1793. XNN_DEPRECATED enum xnn_status xnn_define_exp(
1794. xnn_subgraph_t subgraph,
1795. uint32_t input_id,
1796. uint32_t output_id,
1797. uint32_t flags);
1798. /// Define a Floor Node and add it to a Subgraph.
1799. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1800. /// @param subgraph - a Subgraph object that will own the created Node.
1801. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1802. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1803. /// shape must match the shape of the input tensor.
1804. /// @param flags - binary features of the Floor Node. No supported flags are currently defined.
1805. XNN_DEPRECATED enum xnn_status xnn_define_floor(
1806. xnn_subgraph_t subgraph,
1807. uint32_t input_id,
1808. uint32_t output_id,
1809. uint32_t flags);
1810. /// Define a GELU (Gaussian Error Linear Unit) Node and add it to a Subgraph.
1811. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1812. /// @param subgraph - a Subgraph object that will own the created Node.
1813. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1814. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1815. /// shape must match the shape of the input tensor.
1816. /// @param flags - binary features of the GELU Node. No supported flags are currently defined.
1817. XNN_DEPRECATED enum xnn_status xnn_define_gelu(
1818. xnn_subgraph_t subgraph,
1819. uint32_t input_id,
1820. uint32_t output_id,
1821. uint32_t flags);
1822. /// Define a HardSwish Node and add it to a Subgraph.
1823. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1824. /// @param subgraph - a Subgraph object that will own the created Node.
1825. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1826. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1827. /// shape must match the shape of the input tensor.
1828. /// @param flags - binary features of the HardSwish Node. No supported flags are currently defined.
1829. XNN_DEPRECATED enum xnn_status xnn_define_hardswish(
1830. xnn_subgraph_t subgraph,
1831. uint32_t input_id,
1832. uint32_t output_id,
1833. uint32_t flags);
1834. /// Define a Leaky ReLU Node and add it to a Subgraph.
1835. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1836. /// @param subgraph - a Subgraph object that will own the created Node.
1837. /// @param negative_slope - scale factor for negative input elements.
1838. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1839. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1840. /// shape must match the shape of the input tensor.
1841. /// @param flags - binary features of the Leaky ReLU Node. No supported flags are currently defined.
1842. XNN_DEPRECATED enum xnn_status xnn_define_leaky_relu(
1843. xnn_subgraph_t subgraph,
1844. float negative_slope,
1845. uint32_t input_id,
1846. uint32_t output_id,
1847. uint32_t flags);
1848. /// Define a Log Node and add it to a Subgraph.
1849. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1850. /// @param subgraph - a Subgraph object that will own the created Node.
1851. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1852. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1853. /// shape must match the shape of the input tensor.
1854. /// @param flags - binary features of the Log Node. No supported flags are currently defined.
1855. XNN_DEPRECATED enum xnn_status xnn_define_log(
1856. xnn_subgraph_t subgraph,
1857. uint32_t input_id,
1858. uint32_t output_id,
1859. uint32_t flags);
1860. /// Define a Negate Node and add it to a Subgraph.
1861. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1862. /// @param subgraph - a Subgraph object that will own the created Node.
1863. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1864. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1865. /// shape must match the shape of the input tensor.
1866. /// @param flags - binary features of the Negate Node. No supported flags are currently defined.
1867. XNN_DEPRECATED enum xnn_status xnn_define_negate(
1868. xnn_subgraph_t subgraph,
1869. uint32_t input_id,
1870. uint32_t output_id,
1871. uint32_t flags);
1872. /// Define a Sigmoid Node and add it to a Subgraph.
1873. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1874. /// @param subgraph - a Subgraph object that will own the created Node.
1875. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1876. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1877. /// shape must match the shape of the input tensor.
1878. /// @param flags - binary features of the Sigmoid Node. No supported flags are currently defined.
1879. XNN_DEPRECATED enum xnn_status xnn_define_sigmoid(
1880. xnn_subgraph_t subgraph,
1881. uint32_t input_id,
1882. uint32_t output_id,
1883. uint32_t flags);
1884. /// Define a SoftMax Node and add it to a Subgraph.
1885. ///
1886. /// @param subgraph - a Subgraph object that will own the created Node.
1887. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph, and must have at
1888. /// least one dimension.
1889. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1890. /// shape must match the shape of the input tensor.
1891. /// @param flags - binary features of the SoftMax Node. No supported flags are currently defined.
1892. enum xnn_status xnn_define_softmax(
1893. xnn_subgraph_t subgraph,
1894. uint32_t input_id,
1895. uint32_t output_id,
1896. uint32_t flags);
1897. /// Define a Space To Depth 2D Node and add it to a Subgraph.
1898. ///
1899. /// The Space To Depth 2D Node rearranges blocks of spatial data into the depth (channel) dimension (the reverse transform of Depth To Space 2D).
1900. /// Each square block of input pixels with side @a block_size is gathered into the channels of a single output
1901. /// pixel, so the output height and width are @a block_size times smaller, and the output depth is @a block_size x
1902. /// @a block_size times greater than that of the input.
1903. ///
1904. /// @param subgraph - a Subgraph object that will own the created Node.
1905. /// @param block_size - the size (side length) of the square spatial block.
1906. /// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph
1907. /// with [N, OH * block_size, OW * block_size, C] dimensions.
1908. /// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph
1909. /// with [N, OH, OW, C * block_size * block_size] dimensions.
1910. /// @param flags - binary features of the Space To Depth 2D Node. No supported flags are currently defined.
1911. enum xnn_status xnn_define_space_to_depth_2d(
1912. xnn_subgraph_t subgraph,
1913. uint32_t block_size,
1914. uint32_t input_id,
1915. uint32_t output_id,
1916. uint32_t flags);
1917. /// Define a Square Node and add it to a Subgraph.
1918. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1919. /// @param subgraph - a Subgraph object that will own the created Node.
1920. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1921. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1922. /// shape must match the shape of the input tensor.
1923. /// @param flags - binary features of the Square Node. No supported flags are currently defined.
1924. XNN_DEPRECATED enum xnn_status xnn_define_square(
1925. xnn_subgraph_t subgraph,
1926. uint32_t input_id,
1927. uint32_t output_id,
1928. uint32_t flags);
1929. /// Define a Square Root Node and add it to a Subgraph.
1930. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1931. /// @param subgraph - a Subgraph object that will own the created Node.
1932. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1933. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1934. /// shape must match the shape of the input tensor.
1935. /// @param flags - binary features of the Square Root Node. No supported flags are currently defined.
1936. XNN_DEPRECATED enum xnn_status xnn_define_square_root(
1937. xnn_subgraph_t subgraph,
1938. uint32_t input_id,
1939. uint32_t output_id,
1940. uint32_t flags);
1941. /// Define a Reciprocal Square Root Node and add it to a Subgraph.
1942. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
1943. /// @param subgraph - a Subgraph object that will own the created Node.
1944. /// @param input_id - Value ID for the input tensor. The input tensor must be
1945. /// defined in the @a subgraph.
1946. /// @param output_id - Value ID for the output tensor. The output tensor must be
1947. /// defined in the @a subgraph, and its
1948. /// shape must match the shape of the input tensor.
1949. /// @param flags - binary features of the Reciprocal Square Root Node. No supported flags
1950. /// are currently defined.
1951. XNN_DEPRECATED enum xnn_status xnn_define_reciprocal_square_root(
1952. xnn_subgraph_t subgraph,
1953. uint32_t input_id,
1954. uint32_t output_id,
1955. uint32_t flags);
1956. enum xnn_status xnn_define_static_slice( // NOTE(review): undocumented; looks like the size_t-offset (non-negative offsets only) variant of xnn_define_static_slice_v2 -- confirm.
1957. xnn_subgraph_t subgraph,
1958. size_t num_dims,
1959. const size_t* offsets,
1960. const size_t* sizes,
1961. uint32_t input_id,
1962. uint32_t output_id,
1963. uint32_t flags);
1964. /// Define a Static Slice Node and add it to a Subgraph.
1965. /// Unlike xnn_define_static_slice, @a offsets are signed (int64_t) and may be negative.
1966. /// @param subgraph - a Subgraph object that will own the created Node.
1967. /// @param num_dims - number of shape dimensions in the input and output tensor.
1968. /// @param offsets - offsets in each dimension of the input tensor. This array must have @a num_dims elements. Can be
1969. /// negative, meaning that the offset is relative to the end of the dimension.
1970. /// @param sizes - size of each dimension in the output tensor. This array must have @a num_dims elements.
1971. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
1972. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
1973. /// dimensions must match @a sizes.
1974. /// @param flags - binary features of the Static Slice Node. No supported flags are currently defined.
1975. enum xnn_status xnn_define_static_slice_v2( //
1976. xnn_subgraph_t subgraph, //
1977. size_t num_dims, //
1978. const int64_t* offsets, //
1979. const size_t* sizes, //
1980. uint32_t input_id, //
1981. uint32_t output_id, //
1982. uint32_t flags);
1983. /// Define a Static Transpose Node and add it to a Subgraph.
1984. ///
1985. /// The Static Transpose Node applies a generalized transpose to the input tensor using the permutation in @a perm.
1986. ///
1987. /// @param subgraph - a Subgraph object that will own the created Node.
1988. /// @param num_dims - the number of permutation dimensions. This must be equal to the number of input dimensions.
1989. /// @param perm - the permutation of the axes of the input tensor. The @a perm array must contain each of 0 to N-1
1990. /// (N = @a num_dims) exactly once, in the permuted order.
1991. /// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in
1992. /// the @a subgraph.
1993. /// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined
1994. /// in the @a subgraph with each dimension equal to its corresponding permuted input dimension.
1995. /// @param flags - binary features of the Static Transpose Node. No supported flags are currently defined.
1996. enum xnn_status xnn_define_static_transpose(
1997. xnn_subgraph_t subgraph,
1998. size_t num_dims,
1999. const size_t* perm,
2000. uint32_t input_id,
2001. uint32_t output_id,
2002. uint32_t flags);
2003. /// Define a Tanh Node and add it to a Subgraph.
2004. /// @deprecated Declared with XNN_DEPRECATED. NOTE(review): presumably superseded by the generic unary API (enum xnn_unary_operator); confirm the replacement.
2005. /// @param subgraph - a Subgraph object that will own the created Node.
2006. /// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph.
2007. /// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its
2008. /// shape must match the shape of the input tensor.
2009. /// @param flags - binary features of the Tanh Node. No supported flags are currently defined.
2010. XNN_DEPRECATED enum xnn_status xnn_define_tanh(
2011. xnn_subgraph_t subgraph,
2012. uint32_t input_id,
2013. uint32_t output_id,
2014. uint32_t flags);
2015. /// Code cache is a cache for JIT-generated code.
2016. typedef struct xnn_code_cache* xnn_code_cache_t;
2017. /// Weights cache can be finalized in these ways:
2018. enum xnn_weights_cache_finalization_kind {
2019. /// The weights cache is finalized: no insert operations into the weights cache are allowed, even if the "inserted"
2020. /// weights already exist in the cache. Weights cache memory will also be trimmed to a page boundary and set to
2021. /// read-only (to prevent writes).
2022. xnn_weights_cache_finalization_kind_hard,
2023. /// The weights cache will be finalized with some extra space at the end; this allows "inserting" into the cache only
2024. /// if the weights are already in the cache, and fails when inserting uncached weights. This has some memory overhead.
2025. xnn_weights_cache_finalization_kind_soft,
2026. };
2027. /// A combination of factors that uniquely identifies an entry in the weights cache.
2028. struct xnn_weights_cache_look_up_key {
2029. /// The unique seed for each ukernel. It is guaranteed that each ukernel provides
2030. /// a consistent and identical seed.
2031. uint32_t seed;
2032. /// Pointer to the original kernel.
2033. const void* kernel;
2034. /// Pointer to the original bias; may be NULL.
2035. const void* bias;
2036. };
2037. /// A group of function pointers to manage a weights cache. All functions may be
2038. /// called from multiple threads.
2039. struct xnn_weights_cache_provider {
2040. /// User-specified pointer that will be passed as-is to all functions in this
2041. /// structure.
2042. void* context;
2043. /// Looks up @a cache_key (the {seed, kernel, bias} tuple) in the cache. If it is found,
2044. /// returns the offset to the found entry for reuse. Otherwise, returns SIZE_MAX.
2045. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
2046. /// @param cache_key - The key used to locate the weights cache entry.
2047. size_t (*look_up)(void* context, const struct xnn_weights_cache_look_up_key* cache_key);
2048. /// Ensures that the cache has enough space for `n` bytes. Returns the address at which to
2049. /// store the packed weights. Returns NULL if it fails to reserve space.
2050. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
2051. /// @param n - size to be reserved, in bytes.
2052. void* (*reserve_space)(void* context, size_t n);
2053. /// Looks up packed weights at `ptr` in the cache. If they are found, they are reused.
2054. /// Otherwise, they are added to the cache. Returns the offset of the entry in the cache.
2055. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
2056. /// @param cache_key - The key used to locate the weights cache entry.
2057. /// @param ptr - pointer to the packed weights.
2058. /// @param size - size of the packed weights.
2059. size_t (*look_up_or_insert)(void* context, const struct xnn_weights_cache_look_up_key* cache_key, void* ptr, size_t size);
2060. /// Returns whether the cache is finalized.
2061. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
2062. bool (*is_finalized)(void* context);
2063. /// Returns the absolute pointer corresponding to `offset`, where the offset is returned from
2064. /// `look_up` or `look_up_or_insert`. This function must be called after the cache is finalized.
2065. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
2066. /// @param offset - offset from the start of the internal buffer.
2067. void* (*offset_to_addr)(void* context, size_t offset);
2068. /// Destroy a weights cache object, as well as memory used for the cache.
2069. /// @param context - The user-specified pointer from xnn_weights_cache_provider structure.
2070. enum xnn_status (*delete_cache)(void* context);
2071. };
2072. /// Weights cache is a cache for packed weights, accessed through an xnn_weights_cache_provider. It can be reused between runtimes.
2073. typedef struct xnn_weights_cache_provider* xnn_weights_cache_t;
2074. /// Create a weights cache object specifying the initial size of weights cache (in bytes).
2075. ///
2076. /// @param size - initial capacity of the weights cache (in bytes), i.e. it can hold size bytes without growing.
2077. /// @param weights_cache_out - pointer to the variable that will be initialized to a handle to the weights cache provider
2078. /// upon successful return. Once created, the weights cache provider can be shared between
2079. /// different Runtime objects.
2080. enum xnn_status xnn_create_weights_cache_with_size(size_t size, xnn_weights_cache_t* weights_cache_out);
2081. enum xnn_status xnn_create_weights_cache(xnn_weights_cache_t* weights_cache_out); // Like xnn_create_weights_cache_with_size, but the initial size is chosen internally (presumably a default -- TODO confirm).
2082. /// Finalizes the weights cache. The kind of finalization is specified by `finalization_kind`.
2083. /// @param weights_cache - the weights cache object to finalize.
2084. /// @param finalization_kind - the kind of finalization (see enum xnn_weights_cache_finalization_kind).
2085. enum xnn_status xnn_finalize_weights_cache(
2086. xnn_weights_cache_t weights_cache,
2087. enum xnn_weights_cache_finalization_kind finalization_kind);
2088. /// Returns whether @a cache is finalized; a wrapper around the function pointers in `xnn_weights_cache_t` (its `is_finalized` member).
2089. bool xnn_weights_cache_is_finalized(xnn_weights_cache_t cache);
2090. /// Destroy a weights cache object, as well as memory used for the cache.
2091. /// @param weights_cache - the weights cache object to destroy.
2092. enum xnn_status xnn_delete_weights_cache(xnn_weights_cache_t weights_cache);
2093. typedef struct xnn_workspace* xnn_workspace_t; // Workspace holding memory for a Runtime's internal tensors; can be shared between Runtime objects.
2094. /// Create a workspace object.
2095. /// @param workspace_out - pointer to the variable that will be initialized to a handle to the workspace object upon
2096. /// successful return. Once created, the workspace can be shared between different Runtime
2097. /// objects.
2098. enum xnn_status xnn_create_workspace(xnn_workspace_t* workspace_out);
2099. /// Release a workspace object, as well as memory used by the workspace. Destruction of the object can be deferred until all
2100. /// Runtime objects created with this workspace are destroyed.
2101. /// @param workspace - the workspace object to release.
2102. enum xnn_status xnn_release_workspace(xnn_workspace_t workspace);
2103. /// Runtime is a combination of an execution plan for subgraph Nodes and a memory manager for subgraph Values.
2104. typedef struct xnn_runtime* xnn_runtime_t;
2105. enum xnn_profile_info { // Kinds of profiling information that xnn_get_runtime_profiling_info can return.
2106. /// Returns a size_t containing the number of operators.
2107. xnn_profile_info_num_operators,
2108. /// Returns a char[] containing the null-character-separated names of all operators.
2109. xnn_profile_info_operator_name,
2110. /// Returns a uint64_t[] with the runtimes of all operators in the same order as xnn_profile_info_operator_name.
2111. xnn_profile_info_operator_timing,
2112. };
2113. /// Return profile information for all operators.
2114. ///
2115. /// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2,
2116. /// @ref xnn_create_runtime_v3 or @ref xnn_create_runtime_v4.
2117. /// @param param_name - type of profile information required.
2118. /// @param param_value_size - the size in bytes of memory pointed to by param_value. If this is not sufficient then
2119. /// param_value_size_ret will be set to the required size and xnn_status_out_of_memory will be
2120. /// returned.
2121. /// @param param_value - a pointer to the memory location where the values for the given @a param_name will be written.
2122. /// @param param_value_size_ret - returns number of bytes required to write the result if param_value_size is not
2123. /// sufficient.
2124. enum xnn_status xnn_get_runtime_profiling_info(xnn_runtime_t runtime,
2125. enum xnn_profile_info param_name,
2126. size_t param_value_size,
2127. void* param_value,
2128. size_t* param_value_size_ret);
2129. /// Create a Runtime object from a subgraph.
2130. ///
2131. /// @param subgraph - a Subgraph object with all Values and Nodes that would be handled by the runtime. No Values or
2132. /// Nodes can be added to the runtime once it is constructed.
2133. /// @param weights_cache - a cache for packed weights. The runtime will look up and reuse packed weights in this cache,
2134. /// which reduces the memory allocated for packed weights.
2135. /// @param workspace - a workspace to hold internal tensors. The runtime will allocate space used for internal tensors
2136. /// and track them using workspace. Workspace can be shared and reused across different runtimes. If
2137. /// workspace is NULL, there will be no sharing: each runtime has its own workspace.
2138. /// @param threadpool - the thread pool to be used for parallelization of computations in the runtime. If the thread
2139. /// pool is NULL, the computation would run on the caller thread without parallelization.
2140. /// @param flags - binary features of the runtime. The only currently supported values are
2141. /// XNN_FLAG_HINT_SPARSE_INFERENCE, XNN_FLAG_HINT_FP16_INFERENCE, XNN_FLAG_FORCE_FP16_INFERENCE,
2142. /// XNN_FLAG_YIELD_WORKERS, and XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER. If XNN_FLAG_YIELD_WORKERS is
2143. /// specified, worker threads would be yielded to the system scheduler after processing the last operator
2144. /// in the Runtime. If XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER is specified, convolution operators will
2145. /// initialize indirection buffers on each inference run using temporary memory in the workspace, instead
2146. /// of initializing persistent indirection buffers once.
2147. /// @param runtime_out - pointer to the variable that will be initialized with a handle to the Runtime object upon
2148. /// successful return. Once constructed, the Runtime object is independent of the Subgraph object
2149. /// used to create it.
2150. enum xnn_status xnn_create_runtime_v4(
2151. xnn_subgraph_t subgraph,
2152. xnn_weights_cache_t weights_cache,
2153. xnn_workspace_t workspace,
2154. pthreadpool_t threadpool,
2155. uint32_t flags,
2156. xnn_runtime_t* runtime_out);
2157. enum xnn_status xnn_create_runtime_v3( // NOTE(review): undocumented; presumably xnn_create_runtime_v4 with a NULL workspace -- confirm.
2158. xnn_subgraph_t subgraph,
2159. xnn_weights_cache_t weights_cache,
2160. pthreadpool_t threadpool,
2161. uint32_t flags,
2162. xnn_runtime_t* runtime_out);
2163. enum xnn_status xnn_create_runtime_v2( // NOTE(review): undocumented; presumably xnn_create_runtime_v3 without a weights cache -- confirm.
2164. xnn_subgraph_t subgraph,
2165. pthreadpool_t threadpool,
2166. uint32_t flags,
2167. xnn_runtime_t* runtime_out);
2168. enum xnn_status xnn_create_runtime( // NOTE(review): undocumented; presumably xnn_create_runtime_v2 with no thread pool and no flags -- confirm.
2169. xnn_subgraph_t subgraph,
2170. xnn_runtime_t* runtime_out);
2171. struct xnn_external_value { // Associates a data pointer with an external Value; passed to xnn_setup_runtime / xnn_setup_runtime_v2.
2172. uint32_t id; ///< external ID of the Value.
2173. void* data; ///< pointer to the data of the Value.
2174. };
2175. /// Reshape an external value.
2176. /// @param runtime - the Runtime object that contains the external Value.
2177. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
2178. /// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be
2179. /// created for the Value. NOTE(review): this sentence looks copied from Value definition; xnn_get_external_value_shape forbids XNN_INVALID_VALUE_ID -- confirm.
2180. /// @param num_dims - number of dimensions in the shape.
2181. /// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL.
2182. /// XNNPACK does not keep any pointers to this array after the function returns.
2183. enum xnn_status xnn_reshape_external_value(
2184. xnn_runtime_t runtime,
2185. uint32_t external_id,
2186. size_t num_dims,
2187. const size_t* dims);
2188. /// Get the external value shape.
2189. /// @param runtime - the Runtime object that contains the external Value.
2190. /// @param external_id - external ID for the Value. The ID must be within the range of reserved Value IDs specified on
2191. /// the Subgraph creation. The external ID cannot be XNN_INVALID_VALUE_ID.
2192. /// @param num_dims - a valid pointer into which the number of dimensions in the shape will be written. The written value is at most XNN_MAX_TENSOR_DIMS.
2193. /// @param dims - pointer to an array into which the shape dimensions will be written. This pointer cannot be NULL. It must be large enough to hold
2194. /// all written dimensions (XNN_MAX_TENSOR_DIMS elements always suffice). XNNPACK does not keep any pointers to this array after the function returns.
2195. enum xnn_status xnn_get_external_value_shape(
2196. xnn_runtime_t runtime,
2197. uint32_t external_id,
2198. size_t* num_dims,
2199. size_t* dims);
2200. /// Reshape the XNNPACK runtime.
2201. ///
2202. /// Propagates the shapes of input tensors through the graph to determine the shapes of intermediate and output tensors.
2203. /// Memory is allocated if required. Output tensor shapes can then be queried with xnn_get_external_value_shape.
2204. ///
2205. /// @param runtime - a Runtime object created with @ref xnn_create_runtime or one of its _v2/_v3/_v4 variants.
2206. enum xnn_status xnn_reshape_runtime(
2207. xnn_runtime_t runtime);
2208. /// @deprecated Use xnn_reshape_runtime and xnn_setup_runtime_v2.
2209. ///
2210. /// Setup data pointers for external inputs and outputs in a Runtime object and
2211. /// allocate memory.
2212. ///
2213. /// @param runtime - a Runtime object created with @ref xnn_create_runtime or one of its _v2/_v3/_v4 variants.
2214. /// @param num_external_values - the number of external inputs and outputs specified in this call. This number must
2215. /// match the number of external inputs and outputs in the runtime, i.e. all external
2216. /// inputs and outputs in the runtime must be specified in one call.
2217. /// @param external_values - array with location information for all external inputs and outputs in the runtime.
2218. enum xnn_status xnn_setup_runtime(
2219. xnn_runtime_t runtime,
2220. size_t num_external_values,
2221. const struct xnn_external_value* external_values);
2222. /// Setup data pointers for external inputs and outputs in a Runtime object.
2223. /// Should be called after xnn_reshape_runtime.
2224. ///
2225. /// @param runtime - a Runtime object created with @ref xnn_create_runtime or one of its _v2/_v3/_v4 variants.
2226. /// @param num_external_values - the number of external inputs and outputs specified in this call. This number must
2227. /// match the number of external inputs and outputs in the runtime, i.e. all external
2228. /// inputs and outputs in the runtime must be specified in one call.
2229. /// @param external_values - array with location information for all external inputs and outputs in the runtime.
2230. enum xnn_status xnn_setup_runtime_v2(
2231. xnn_runtime_t runtime,
2232. size_t num_external_values,
2233. const struct xnn_external_value* external_values);
2234. /// Execute a forward pass for all operators in the runtime.
2235. /// NOTE(review): presumably requires external values to have been set up (xnn_setup_runtime_v2) first -- confirm.
2236. /// @param runtime - the Runtime object with the execution plan to invoke.
2237. enum xnn_status xnn_invoke_runtime(
2238. xnn_runtime_t runtime);
2239. /// Destroy a Runtime object, as well as operators and memory associated with it.
2240. ///
2241. /// @param runtime - the Runtime object to destroy.
2242. enum xnn_status xnn_delete_runtime(
2243. xnn_runtime_t runtime);
2244. typedef struct xnn_operator* xnn_operator_t; // Handle to an operator created by one of the xnn_create_* operator functions below.
2245. enum xnn_status xnn_run_operator( // Run op using threadpool; presumably only valid after op has been reshaped and set up -- confirm.
2246. xnn_operator_t op,
2247. pthreadpool_t threadpool);
2248. enum xnn_status xnn_delete_operator( // Destroy op.
2249. xnn_operator_t op);
2250. /// Operator API:
2251. /// - create operator will create and populate an xnn_operator_t
2252. /// - reshape operator will update fields in xnn_operator_t with shape/dimensions and parallelization information
2253. /// - setup operator will update pointers to inputs and outputs
2254. /// Each supported operator must have a create, reshape, and setup function. (Optionally a run function.)
2255. /// Apart from the generic binary/unary elementwise operators listed first, operators below are in alphabetical order
2256. /// by operator name; within each operator, we sort alphabetically by data layout and type. We also group create, reshape, setup (and optionally run) functions of each operator together.
2257. enum xnn_status xnn_create_binary_elementwise_nd( // Create an N-dimensional elementwise operator applying the binary operator `type`.
2258. enum xnn_binary_operator type,
2259. enum xnn_datatype datatype,
2260. const struct xnn_quantization_params* input1_quantization,
2261. const struct xnn_quantization_params* input2_quantization,
2262. const struct xnn_quantization_params* output_quantization,
2263. uint32_t flags,
2264. xnn_operator_t* binary_op_out);
2265. enum xnn_status xnn_reshape_binary_elementwise_nd( // Reshape binary_op for inputs of the given shapes.
2266. xnn_operator_t binary_op,
2267. size_t num_input1_dims,
2268. const size_t* input1_shape,
2269. size_t num_input2_dims,
2270. const size_t* input2_shape,
2271. pthreadpool_t threadpool);
2272. enum xnn_status xnn_setup_binary_elementwise_nd( // Set the input and output data pointers of binary_op.
2273. xnn_operator_t binary_op,
2274. const void* input1,
2275. const void* input2,
2276. void* output);
2277. enum xnn_status xnn_run_binary_elementwise_nd( // Create, reshape, set up and run a binary elementwise operator in one call.
2278. enum xnn_binary_operator type,
2279. enum xnn_datatype datatype,
2280. const struct xnn_quantization_params* input1_quantization,
2281. const struct xnn_quantization_params* input2_quantization,
2282. const struct xnn_quantization_params* output_quantization,
2283. uint32_t flags,
2284. size_t num_input1_dims,
2285. const size_t* input1_shape,
2286. size_t num_input2_dims,
2287. const size_t* input2_shape,
2288. const void* input1,
2289. const void* input2,
2290. void* output,
2291. pthreadpool_t threadpool);
2292. enum xnn_status xnn_create_unary_elementwise_nc( // Create an elementwise operator applying the unary operator `op_type`.
2293. enum xnn_unary_operator op_type,
2294. enum xnn_datatype input_datatype,
2295. enum xnn_datatype output_datatype,
2296. const union xnn_unary_params* params,
2297. const struct xnn_quantization_params* input_quantization,
2298. const struct xnn_quantization_params* output_quantization,
2299. uint32_t flags,
2300. xnn_operator_t* op_out);
2301. enum xnn_status xnn_reshape_unary_elementwise_nc( // Reshape op for batch_size x channels elements with the given input/output strides.
2302. xnn_operator_t op,
2303. size_t batch_size,
2304. size_t channels,
2305. size_t input_stride,
2306. size_t output_stride,
2307. pthreadpool_t threadpool);
2308. enum xnn_status xnn_setup_unary_elementwise_nc( // Set the input and output data pointers of op.
2309. xnn_operator_t op,
2310. const void* input,
2311. void* output);
2312. enum xnn_status xnn_run_unary_elementwise_nc( // Create, reshape, set up and run a unary elementwise operator in one call.
2313. // create parameters
2314. enum xnn_unary_operator op_type,
2315. enum xnn_datatype input_datatype,
2316. enum xnn_datatype output_datatype,
2317. const union xnn_unary_params* params,
2318. const struct xnn_quantization_params* input_quantization,
2319. const struct xnn_quantization_params* output_quantization,
2320. uint32_t flags,
2321. // reshape parameters
2322. size_t batch_size,
2323. size_t channels,
2324. size_t input_stride,
2325. size_t output_stride,
2326. pthreadpool_t threadpool,
2327. // setup parameters
2328. const void* input,
2329. void* output);
  2330. enum xnn_status xnn_create_argmax_pooling2d_nhwc_f32(
  2331. uint32_t input_padding_top,
  2332. uint32_t input_padding_right,
  2333. uint32_t input_padding_bottom,
  2334. uint32_t input_padding_left,
  2335. uint32_t pooling_height,
  2336. uint32_t pooling_width,
  2337. uint32_t flags,
  2338. xnn_operator_t* argmax_pooling_op_out);
  2339. enum xnn_status xnn_reshape_argmax_pooling2d_nhwc_f32(
  2340. xnn_operator_t argmax_pooling_op,
  2341. size_t batch_size,
  2342. size_t input_height,
  2343. size_t input_width,
  2344. size_t channels,
  2345. size_t input_pixel_stride,
  2346. size_t output_pixel_stride,
  2347. size_t* workspace_size,
  2348. size_t* workspace_alignment,
  2349. size_t* output_height_out,
  2350. size_t* output_width_out,
  2351. pthreadpool_t threadpool);
  2352. enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32(
  2353. xnn_operator_t argmax_pooling_op,
  2354. void* workspace,
  2355. const float* input,
  2356. float* output,
  2357. uint32_t* index);
  2358. enum xnn_status xnn_create_average_pooling2d_nhwc_f16(
  2359. uint32_t input_padding_top,
  2360. uint32_t input_padding_right,
  2361. uint32_t input_padding_bottom,
  2362. uint32_t input_padding_left,
  2363. uint32_t pooling_height,
  2364. uint32_t pooling_width,
  2365. uint32_t stride_height,
  2366. uint32_t stride_width,
  2367. float output_min,
  2368. float output_max,
  2369. uint32_t flags,
  2370. xnn_operator_t* average_pooling_op_out);
  2371. enum xnn_status xnn_reshape_average_pooling2d_nhwc_f16(
  2372. xnn_operator_t average_pooling_op,
  2373. size_t batch_size,
  2374. size_t input_height,
  2375. size_t input_width,
  2376. size_t channels,
  2377. size_t input_pixel_stride,
  2378. size_t output_pixel_stride,
  2379. size_t* workspace_size,
  2380. size_t* workspace_alignment,
  2381. size_t* output_height_out,
  2382. size_t* output_width_out,
  2383. pthreadpool_t threadpool);
  2384. enum xnn_status xnn_setup_average_pooling2d_nhwc_f16(
  2385. xnn_operator_t average_pooling_op,
  2386. void* workspace,
  2387. const void* input,
  2388. void* output);
  2389. enum xnn_status xnn_create_average_pooling2d_nhwc_f32(
  2390. uint32_t input_padding_top,
  2391. uint32_t input_padding_right,
  2392. uint32_t input_padding_bottom,
  2393. uint32_t input_padding_left,
  2394. uint32_t pooling_height,
  2395. uint32_t pooling_width,
  2396. uint32_t stride_height,
  2397. uint32_t stride_width,
  2398. float output_min,
  2399. float output_max,
  2400. uint32_t flags,
  2401. xnn_operator_t* average_pooling_op_out);
  2402. enum xnn_status xnn_reshape_average_pooling2d_nhwc_f32(
  2403. xnn_operator_t average_pooling_op,
  2404. size_t batch_size,
  2405. size_t input_height,
  2406. size_t input_width,
  2407. size_t channels,
  2408. size_t input_pixel_stride,
  2409. size_t output_pixel_stride,
  2410. size_t* workspace_size,
  2411. size_t* workspace_alignment,
  2412. size_t* output_height_out,
  2413. size_t* output_width_out,
  2414. pthreadpool_t threadpool);
  2415. enum xnn_status xnn_setup_average_pooling2d_nhwc_f32(
  2416. xnn_operator_t average_pooling_op,
  2417. void* workspace,
  2418. const float* input,
  2419. float* output);
  2420. enum xnn_status xnn_create_average_pooling2d_nhwc_qu8(
  2421. uint32_t input_padding_top,
  2422. uint32_t input_padding_right,
  2423. uint32_t input_padding_bottom,
  2424. uint32_t input_padding_left,
  2425. uint32_t pooling_height,
  2426. uint32_t pooling_width,
  2427. uint32_t stride_height,
  2428. uint32_t stride_width,
  2429. uint8_t input_zero_point,
  2430. float input_scale,
  2431. uint8_t output_zero_point,
  2432. float output_scale,
  2433. uint8_t output_min,
  2434. uint8_t output_max,
  2435. uint32_t flags,
  2436. xnn_operator_t* average_pooling_op_out);
  2437. enum xnn_status xnn_reshape_average_pooling2d_nhwc_qu8(
  2438. xnn_operator_t average_pooling_op,
  2439. size_t batch_size,
  2440. size_t input_height,
  2441. size_t input_width,
  2442. size_t channels,
  2443. size_t input_pixel_stride,
  2444. size_t output_pixel_stride,
  2445. size_t* workspace_size,
  2446. size_t* workspace_alignment,
  2447. size_t* output_height_out,
  2448. size_t* output_width_out,
  2449. pthreadpool_t threadpool);
  2450. enum xnn_status xnn_setup_average_pooling2d_nhwc_qu8(
  2451. xnn_operator_t average_pooling_op,
  2452. void* workspace,
  2453. const uint8_t* input,
  2454. uint8_t* output);
  2455. enum xnn_status xnn_create_batch_matrix_multiply_nc_f16(
  2456. uint32_t flags,
  2457. xnn_operator_t* batch_matrix_multiply_op);
  2458. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f16(
  2459. xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims,
  2460. const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k,
  2461. size_t n, size_t* workspace_size, size_t* workspace_alignment,
  2462. pthreadpool_t threadpool);
  2463. enum xnn_status xnn_setup_batch_matrix_multiply_nc_f16(
  2464. xnn_operator_t batch_matrix_multiply_op, void* workspace,
  2465. const void* input_a, const void* input_b, void* output);
  2466. enum xnn_status xnn_create_batch_matrix_multiply_nc_f32(
  2467. uint32_t flags, xnn_operator_t* batch_matrix_multiply_op);
  2468. enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights(
  2469. size_t batch_size_b, size_t k, size_t n, const float* data_b,
  2470. uint32_t flags, xnn_operator_t* batch_matrix_multiply_op);
  2471. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f32(
  2472. xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims,
  2473. const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k,
  2474. size_t n, size_t* workspace_size, size_t* workspace_alignment,
  2475. pthreadpool_t threadpool);
  2476. enum xnn_status xnn_setup_batch_matrix_multiply_nc_f32(
  2477. xnn_operator_t batch_matrix_multiply_op, void* workspace,
  2478. const float* input_a, const float* input_b, float* output);
  2479. enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w(
  2480. size_t batch_size_b, size_t k, size_t n, const int8_t* data_b,
  2481. const float* scale_b, uint32_t flags,
  2482. xnn_operator_t* batch_matrix_multiply_op);
  2483. enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qd8_f32_qc8w(
  2484. xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims,
  2485. const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k,
  2486. size_t n, pthreadpool_t threadpool);
  2487. enum xnn_status xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w(
  2488. xnn_operator_t batch_matrix_multiply_op, const int8_t* input_a,
  2489. const struct xnn_quantization_params* quantization_params,
  2490. float* output);
  2491. enum xnn_status xnn_create_channel_shuffle_nc_x8(
  2492. size_t groups,
  2493. size_t group_channels,
  2494. size_t input_stride,
  2495. size_t output_stride,
  2496. uint32_t flags,
  2497. xnn_operator_t* channel_shuffle_op_out);
  2498. enum xnn_status xnn_reshape_channel_shuffle_nc_x8(
  2499. xnn_operator_t channel_shuffle_op,
  2500. size_t batch_size,
  2501. pthreadpool_t threadpool);
  2502. enum xnn_status xnn_setup_channel_shuffle_nc_x8(
  2503. xnn_operator_t channel_shuffle_op,
  2504. const void* input,
  2505. void* output);
  2506. enum xnn_status xnn_create_channel_shuffle_nc_x32(
  2507. size_t groups,
  2508. size_t group_channels,
  2509. size_t input_stride,
  2510. size_t output_stride,
  2511. uint32_t flags,
  2512. xnn_operator_t* channel_shuffle_op_out);
  2513. enum xnn_status xnn_reshape_channel_shuffle_nc_x32(
  2514. xnn_operator_t channel_shuffle_op,
  2515. size_t batch_size,
  2516. pthreadpool_t threadpool);
  2517. enum xnn_status xnn_setup_channel_shuffle_nc_x32(
  2518. xnn_operator_t channel_shuffle_op,
  2519. const void* input,
  2520. void* output);
  2521. enum xnn_status xnn_create_constant_pad_nd_x8(
  2522. const void* padding_value,
  2523. uint32_t flags,
  2524. xnn_operator_t* constant_pad_op_out);
  2525. enum xnn_status xnn_reshape_constant_pad_nd_x8(
  2526. xnn_operator_t constant_pad_op,
  2527. size_t num_dims,
  2528. const size_t* input_shape,
  2529. const size_t* pre_padding,
  2530. const size_t* post_padding,
  2531. pthreadpool_t threadpool);
  2532. enum xnn_status xnn_setup_constant_pad_nd_x8(
  2533. xnn_operator_t constant_pad_op,
  2534. const void* input,
  2535. void* output);
  2536. enum xnn_status xnn_run_constant_pad_nd_x8(
  2537. uint32_t flags,
  2538. size_t num_dims,
  2539. const size_t* input_shape,
  2540. const size_t* pre_paddings,
  2541. const size_t* post_paddings,
  2542. const void* input,
  2543. void* output,
  2544. const void* padding_value,
  2545. pthreadpool_t threadpool);
  2546. enum xnn_status xnn_create_constant_pad_nd_x16(
  2547. const void* padding_value,
  2548. uint32_t flags,
  2549. xnn_operator_t* constant_pad_op_out);
  2550. enum xnn_status xnn_reshape_constant_pad_nd_x16(
  2551. xnn_operator_t constant_pad_op,
  2552. size_t num_dims,
  2553. const size_t* input_shape,
  2554. const size_t* pre_padding,
  2555. const size_t* post_padding,
  2556. pthreadpool_t threadpool);
  2557. enum xnn_status xnn_setup_constant_pad_nd_x16(
  2558. xnn_operator_t constant_pad_op,
  2559. const void* input,
  2560. void* output);
  2561. enum xnn_status xnn_run_constant_pad_nd_x16(
  2562. uint32_t flags,
  2563. size_t num_dims,
  2564. const size_t* input_shape,
  2565. const size_t* pre_paddings,
  2566. const size_t* post_paddings,
  2567. const void* input,
  2568. void* output,
  2569. const void* padding_value,
  2570. pthreadpool_t threadpool);
  2571. enum xnn_status xnn_create_constant_pad_nd_x32(
  2572. const void* padding_value,
  2573. uint32_t flags,
  2574. xnn_operator_t* constant_pad_op_out);
  2575. enum xnn_status xnn_reshape_constant_pad_nd_x32(
  2576. xnn_operator_t constant_pad_op,
  2577. size_t num_dims,
  2578. const size_t* input_shape,
  2579. const size_t* pre_padding,
  2580. const size_t* post_padding,
  2581. pthreadpool_t threadpool);
  2582. enum xnn_status xnn_setup_constant_pad_nd_x32(
  2583. xnn_operator_t constant_pad_op,
  2584. const void* input,
  2585. void* output);
  2586. enum xnn_status xnn_run_constant_pad_nd_x32(
  2587. uint32_t flags,
  2588. size_t num_dims,
  2589. const size_t* input_shape,
  2590. const size_t* pre_paddings,
  2591. const size_t* post_paddings,
  2592. const void* input,
  2593. void* output,
  2594. const void* padding_value,
  2595. pthreadpool_t threadpool);
  2596. enum xnn_status xnn_create_convert_nc_f16_qd8(
  2597. uint32_t flags,
  2598. xnn_operator_t* convert_op_out);
  2599. enum xnn_status xnn_reshape_convert_nc_f16_qd8(
  2600. xnn_operator_t convert_op,
  2601. size_t batch_size,
  2602. size_t channels,
  2603. size_t input_stride,
  2604. size_t output_stride,
  2605. pthreadpool_t threadpool);
  2606. // quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries.
  2607. enum xnn_status xnn_setup_convert_nc_f16_qd8(
  2608. xnn_operator_t convert_op,
  2609. const void* input,
  2610. int8_t* output,
  2611. struct xnn_quantization_params* quantization_params);
  2612. enum xnn_status xnn_create_convert_nc_f32_qd8(
  2613. uint32_t flags,
  2614. xnn_operator_t* convert_op_out);
  2615. enum xnn_status xnn_reshape_convert_nc_f32_qd8(
  2616. xnn_operator_t convert_op,
  2617. size_t batch_size,
  2618. size_t channels,
  2619. size_t input_stride,
  2620. size_t output_stride,
  2621. pthreadpool_t threadpool);
  2622. // quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries.
  2623. enum xnn_status xnn_setup_convert_nc_f32_qd8(
  2624. xnn_operator_t convert_op,
  2625. const float* input,
  2626. int8_t* output,
  2627. struct xnn_quantization_params* quantization_params);
  2628. XNN_DEPRECATED enum xnn_status xnn_run_convert_nc_f32_f16(
  2629. size_t channels,
  2630. size_t input_stride,
  2631. size_t output_stride,
  2632. size_t batch_size,
  2633. const float* input,
  2634. void* output,
  2635. uint32_t flags,
  2636. pthreadpool_t threadpool);
  2637. enum xnn_status xnn_create_convolution2d_nchw_f16(
  2638. uint32_t input_padding_top,
  2639. uint32_t input_padding_right,
  2640. uint32_t input_padding_bottom,
  2641. uint32_t input_padding_left,
  2642. uint32_t kernel_height,
  2643. uint32_t kernel_width,
  2644. uint32_t subsampling_height,
  2645. uint32_t subsampling_width,
  2646. uint32_t dilation_height,
  2647. uint32_t dilation_width,
  2648. uint32_t groups,
  2649. size_t group_input_channels,
  2650. size_t group_output_channels,
  2651. size_t input_channel_stride,
  2652. size_t output_channel_stride,
  2653. const void* kernel,
  2654. const void* bias,
  2655. float output_min,
  2656. float output_max,
  2657. uint32_t flags,
  2658. xnn_code_cache_t code_cache,
  2659. xnn_weights_cache_t weights_cache,
  2660. xnn_operator_t* convolution_op_out);
  2661. enum xnn_status xnn_reshape_convolution2d_nchw_f16(
  2662. xnn_operator_t convolution_op,
  2663. size_t batch_size,
  2664. size_t input_height,
  2665. size_t input_width,
  2666. size_t* output_height_out,
  2667. size_t* output_width_out,
  2668. pthreadpool_t threadpool);
  2669. enum xnn_status xnn_setup_convolution2d_nchw_f16(
  2670. xnn_operator_t convolution_op,
  2671. const void* input,
  2672. void* output);
  2673. enum xnn_status xnn_create_convolution2d_nchw_f32(
  2674. uint32_t input_padding_top,
  2675. uint32_t input_padding_right,
  2676. uint32_t input_padding_bottom,
  2677. uint32_t input_padding_left,
  2678. uint32_t kernel_height,
  2679. uint32_t kernel_width,
  2680. uint32_t subsampling_height,
  2681. uint32_t subsampling_width,
  2682. uint32_t dilation_height,
  2683. uint32_t dilation_width,
  2684. uint32_t groups,
  2685. size_t group_input_channels,
  2686. size_t group_output_channels,
  2687. size_t input_channel_stride,
  2688. size_t output_channel_stride,
  2689. const float* kernel,
  2690. const float* bias,
  2691. float output_min,
  2692. float output_max,
  2693. uint32_t flags,
  2694. xnn_code_cache_t code_cache,
  2695. xnn_weights_cache_t weights_cache,
  2696. xnn_operator_t* convolution_op_out);
  2697. enum xnn_status xnn_reshape_convolution2d_nchw_f32(
  2698. xnn_operator_t convolution_op,
  2699. size_t batch_size,
  2700. size_t input_height,
  2701. size_t input_width,
  2702. size_t* output_height_out,
  2703. size_t* output_width_out,
  2704. pthreadpool_t threadpool);
  2705. enum xnn_status xnn_setup_convolution2d_nchw_f32(
  2706. xnn_operator_t convolution_op,
  2707. const float* input,
  2708. float* output);
  2709. enum xnn_status xnn_create_convolution2d_nhwc_f16(
  2710. uint32_t input_padding_top,
  2711. uint32_t input_padding_right,
  2712. uint32_t input_padding_bottom,
  2713. uint32_t input_padding_left,
  2714. uint32_t kernel_height,
  2715. uint32_t kernel_width,
  2716. uint32_t subsampling_height,
  2717. uint32_t subsampling_width,
  2718. uint32_t dilation_height,
  2719. uint32_t dilation_width,
  2720. uint32_t groups,
  2721. size_t group_input_channels,
  2722. size_t group_output_channels,
  2723. size_t input_channel_stride,
  2724. size_t output_channel_stride,
  2725. const void* kernel,
  2726. const void* bias,
  2727. float output_min,
  2728. float output_max,
  2729. uint32_t flags,
  2730. xnn_code_cache_t code_cache,
  2731. xnn_weights_cache_t weights_cache,
  2732. xnn_operator_t* convolution_op_out);
  2733. enum xnn_status xnn_reshape_convolution2d_nhwc_f16(
  2734. xnn_operator_t convolution_op,
  2735. size_t batch_size,
  2736. size_t input_height,
  2737. size_t input_width,
  2738. size_t* workspace_size,
  2739. size_t* workspace_alignment,
  2740. size_t* output_height_out,
  2741. size_t* output_width_out,
  2742. pthreadpool_t threadpool);
  2743. enum xnn_status xnn_setup_convolution2d_nhwc_f16(
  2744. xnn_operator_t convolution_op,
  2745. void* workspace,
  2746. const void* input,
  2747. void* output);
  2748. enum xnn_status xnn_create_convolution2d_nhwc_f32(
  2749. uint32_t input_padding_top,
  2750. uint32_t input_padding_right,
  2751. uint32_t input_padding_bottom,
  2752. uint32_t input_padding_left,
  2753. uint32_t kernel_height,
  2754. uint32_t kernel_width,
  2755. uint32_t subsampling_height,
  2756. uint32_t subsampling_width,
  2757. uint32_t dilation_height,
  2758. uint32_t dilation_width,
  2759. uint32_t groups,
  2760. size_t group_input_channels,
  2761. size_t group_output_channels,
  2762. size_t input_channel_stride,
  2763. size_t output_channel_stride,
  2764. const float* kernel,
  2765. const float* bias,
  2766. float output_min,
  2767. float output_max,
  2768. uint32_t flags,
  2769. xnn_code_cache_t code_cache,
  2770. xnn_weights_cache_t weights_cache,
  2771. xnn_operator_t* convolution_op_out);
  2772. enum xnn_status xnn_create_convolution2d_nhwc_f32_f16(
  2773. uint32_t input_padding_top,
  2774. uint32_t input_padding_right,
  2775. uint32_t input_padding_bottom,
  2776. uint32_t input_padding_left,
  2777. uint32_t kernel_height,
  2778. uint32_t kernel_width,
  2779. uint32_t subsampling_height,
  2780. uint32_t subsampling_width,
  2781. uint32_t dilation_height,
  2782. uint32_t dilation_width,
  2783. uint32_t groups,
  2784. size_t group_input_channels,
  2785. size_t group_output_channels,
  2786. size_t input_channel_stride,
  2787. size_t output_channel_stride,
  2788. const void* kernel,
  2789. const void* bias,
  2790. float output_min,
  2791. float output_max,
  2792. uint32_t flags,
  2793. xnn_code_cache_t code_cache,
  2794. xnn_weights_cache_t weights_cache,
  2795. xnn_operator_t* convolution_op_out);
  2796. // Forward declare.
  2797. struct xnn_post_operation;
  2798. /// Deprecated
  2799. enum xnn_status xnn_create_fused_convolution2d_nhwc_f32(
  2800. uint32_t input_padding_top,
  2801. uint32_t input_padding_right,
  2802. uint32_t input_padding_bottom,
  2803. uint32_t input_padding_left,
  2804. uint32_t kernel_height,
  2805. uint32_t kernel_width,
  2806. uint32_t subsampling_height,
  2807. uint32_t subsampling_width,
  2808. uint32_t dilation_height,
  2809. uint32_t dilation_width,
  2810. uint32_t groups,
  2811. size_t group_input_channels,
  2812. size_t group_output_channels,
  2813. size_t input_channel_stride,
  2814. size_t output_channel_stride,
  2815. const float* kernel,
  2816. const float* bias,
  2817. size_t num_post_operations,
  2818. struct xnn_post_operation* post_operations,
  2819. uint32_t flags,
  2820. xnn_code_cache_t code_cache,
  2821. xnn_weights_cache_t weights_cache,
  2822. xnn_operator_t* convolution_op_out);
  2823. enum xnn_status xnn_reshape_convolution2d_nhwc_f32(
  2824. xnn_operator_t convolution_op,
  2825. size_t batch_size,
  2826. size_t input_height,
  2827. size_t input_width,
  2828. size_t* workspace_size,
  2829. size_t* workspace_alignment,
  2830. size_t* output_height_out,
  2831. size_t* output_width_out,
  2832. pthreadpool_t threadpool);
  2833. enum xnn_status xnn_setup_convolution2d_nhwc_f32(
  2834. xnn_operator_t convolution_op,
  2835. void* workspace,
  2836. const float* input,
  2837. float* output);
  2838. enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w(
  2839. uint32_t input_padding_top, uint32_t input_padding_right,
  2840. uint32_t input_padding_bottom, uint32_t input_padding_left,
  2841. uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
  2842. uint32_t subsampling_width, uint32_t dilation_height,
  2843. uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
  2844. size_t group_output_channels, size_t input_channel_stride,
  2845. size_t output_channel_stride, const float* kernel_scale,
  2846. const int8_t* kernel, const float* bias, float output_min, float output_max,
  2847. uint32_t flags, xnn_code_cache_t code_cache,
  2848. xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
  2849. enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w(
  2850. uint32_t input_padding_top, uint32_t input_padding_right,
  2851. uint32_t input_padding_bottom, uint32_t input_padding_left,
  2852. uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height,
  2853. uint32_t subsampling_width, uint32_t dilation_height,
  2854. uint32_t dilation_width, uint32_t groups, size_t group_input_channels,
  2855. size_t group_output_channels, size_t input_channel_stride,
  2856. size_t output_channel_stride, const float* kernel_scale,
  2857. const int8_t* kernel, const float* bias, float output_min, float output_max,
  2858. uint32_t flags, xnn_code_cache_t code_cache,
  2859. xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out);
  2860. enum xnn_status xnn_create_convolution2d_nhwc_qs8(
  2861. uint32_t input_padding_top,
  2862. uint32_t input_padding_right,
  2863. uint32_t input_padding_bottom,
  2864. uint32_t input_padding_left,
  2865. uint32_t kernel_height,
  2866. uint32_t kernel_width,
  2867. uint32_t subsampling_height,
  2868. uint32_t subsampling_width,
  2869. uint32_t dilation_height,
  2870. uint32_t dilation_width,
  2871. uint32_t groups,
  2872. size_t group_input_channels,
  2873. size_t group_output_channels,
  2874. size_t input_channel_stride,
  2875. size_t output_channel_stride,
  2876. int8_t input_zero_point,
  2877. float input_scale,
  2878. float kernel_scale,
  2879. const int8_t* kernel,
  2880. const int32_t* bias,
  2881. int8_t output_zero_point,
  2882. float output_scale,
  2883. int8_t output_min,
  2884. int8_t output_max,
  2885. uint32_t flags,
  2886. xnn_code_cache_t code_cache,
  2887. xnn_weights_cache_t weights_cache,
  2888. xnn_operator_t* convolution_op_out);
  2889. enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w(
  2890. xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
  2891. size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
  2892. size_t* output_height_out, size_t* output_width_out,
  2893. pthreadpool_t threadpool);
  2894. enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w(
  2895. xnn_operator_t convolution_op, size_t batch_size, size_t input_height,
  2896. size_t input_width, size_t* workspace_size, size_t* workspace_alignment,
  2897. size_t* output_height_out, size_t* output_width_out,
  2898. pthreadpool_t threadpool);
  2899. enum xnn_status xnn_reshape_convolution2d_nhwc_qs8(
  2900. xnn_operator_t convolution_op,
  2901. size_t batch_size,
  2902. size_t input_height,
  2903. size_t input_width,
  2904. size_t* workspace_size,
  2905. size_t* workspace_alignment,
  2906. size_t* output_height_out,
  2907. size_t* output_width_out,
  2908. pthreadpool_t threadpool);
  2909. enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f16_qc8w(
  2910. xnn_operator_t convolution_op, void* workspace, const int8_t* input,
  2911. void* output,
  2912. const struct xnn_quantization_params* quantization_params);
  2913. enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w(
  2914. xnn_operator_t convolution_op, void* workspace, const int8_t* input,
  2915. float* output,
  2916. const struct xnn_quantization_params* quantization_params);
  2917. enum xnn_status xnn_setup_convolution2d_nhwc_qs8(
  2918. xnn_operator_t convolution_op,
  2919. void* workspace,
  2920. const int8_t* input,
  2921. int8_t* output);
  2922. enum xnn_status xnn_create_convolution2d_nhwc_qs8_qc8w(
  2923. uint32_t input_padding_top,
  2924. uint32_t input_padding_right,
  2925. uint32_t input_padding_bottom,
  2926. uint32_t input_padding_left,
  2927. uint32_t kernel_height,
  2928. uint32_t kernel_width,
  2929. uint32_t subsampling_height,
  2930. uint32_t subsampling_width,
  2931. uint32_t dilation_height,
  2932. uint32_t dilation_width,
  2933. uint32_t groups,
  2934. size_t group_input_channels,
  2935. size_t group_output_channels,
  2936. size_t input_channel_stride,
  2937. size_t output_channel_stride,
  2938. int8_t input_zero_point,
  2939. float input_scale,
  2940. const float* kernel_scale,
  2941. const int8_t* kernel,
  2942. const int32_t* bias,
  2943. int8_t output_zero_point,
  2944. float output_scale,
  2945. int8_t output_min,
  2946. int8_t output_max,
  2947. uint32_t flags,
  2948. xnn_code_cache_t code_cache,
  2949. xnn_weights_cache_t weights_cache,
  2950. xnn_operator_t* convolution_op_out);
  2951. enum xnn_status xnn_reshape_convolution2d_nhwc_qs8_qc8w(
  2952. xnn_operator_t convolution_op,
  2953. size_t batch_size,
  2954. size_t input_height,
  2955. size_t input_width,
  2956. size_t* workspace_size,
  2957. size_t* workspace_alignment,
  2958. size_t* output_height_out,
  2959. size_t* output_width_out,
  2960. pthreadpool_t threadpool);
  2961. enum xnn_status xnn_setup_convolution2d_nhwc_qs8_qc8w(
  2962. xnn_operator_t convolution_op,
  2963. void* workspace,
  2964. const int8_t* input,
  2965. int8_t* output);
  2966. enum xnn_status xnn_create_convolution2d_nhwc_qu8(
  2967. uint32_t input_padding_top,
  2968. uint32_t input_padding_right,
  2969. uint32_t input_padding_bottom,
  2970. uint32_t input_padding_left,
  2971. uint32_t kernel_height,
  2972. uint32_t kernel_width,
  2973. uint32_t subsampling_height,
  2974. uint32_t subsampling_width,
  2975. uint32_t dilation_height,
  2976. uint32_t dilation_width,
  2977. uint32_t groups,
  2978. size_t group_input_channels,
  2979. size_t group_output_channels,
  2980. size_t input_channel_stride,
  2981. size_t output_channel_stride,
  2982. uint8_t input_zero_point,
  2983. float input_scale,
  2984. uint8_t kernel_zero_point,
  2985. float kernel_scale,
  2986. const uint8_t* kernel,
  2987. const int32_t* bias,
  2988. uint8_t output_zero_point,
  2989. float output_scale,
  2990. uint8_t output_min,
  2991. uint8_t output_max,
  2992. uint32_t flags,
  2993. xnn_code_cache_t code_cache,
  2994. xnn_weights_cache_t weights_cache,
  2995. xnn_operator_t* convolution_op_out);
  2996. enum xnn_status xnn_reshape_convolution2d_nhwc_qu8(
  2997. xnn_operator_t convolution_op,
  2998. size_t batch_size,
  2999. size_t input_height,
  3000. size_t input_width,
  3001. size_t* workspace_size,
  3002. size_t* workspace_alignment,
  3003. size_t* output_height_out,
  3004. size_t* output_width_out,
  3005. pthreadpool_t threadpool);
  3006. enum xnn_status xnn_setup_convolution2d_nhwc_qu8(
  3007. xnn_operator_t convolution_op,
  3008. void* workspace,
  3009. const uint8_t* input,
  3010. uint8_t* output);
  3011. enum xnn_status xnn_create_copy_nc_x8(
  3012. uint32_t flags,
  3013. xnn_operator_t* copy_op_out);
  3014. enum xnn_status xnn_reshape_copy_nc_x8(
  3015. xnn_operator_t copy_op,
  3016. size_t batch_size,
  3017. size_t channels,
  3018. size_t input_stride,
  3019. size_t output_stride,
  3020. pthreadpool_t threadpool);
  3021. enum xnn_status xnn_setup_copy_nc_x8(
  3022. xnn_operator_t copy_op,
  3023. const void* input,
  3024. void* output);
  3025. enum xnn_status xnn_create_copy_nc_x16(
  3026. uint32_t flags,
  3027. xnn_operator_t* copy_op_out);
  3028. enum xnn_status xnn_reshape_copy_nc_x16(
  3029. xnn_operator_t copy_op,
  3030. size_t batch_size,
  3031. size_t channels,
  3032. size_t input_stride,
  3033. size_t output_stride,
  3034. pthreadpool_t threadpool);
  3035. enum xnn_status xnn_setup_copy_nc_x16(
  3036. xnn_operator_t copy_op,
  3037. const void* input,
  3038. void* output);
  3039. enum xnn_status xnn_create_copy_nc_x32(
  3040. uint32_t flags,
  3041. xnn_operator_t* copy_op_out);
  3042. enum xnn_status xnn_reshape_copy_nc_x32(
  3043. xnn_operator_t copy_op,
  3044. size_t batch_size,
  3045. size_t channels,
  3046. size_t input_stride,
  3047. size_t output_stride,
  3048. pthreadpool_t threadpool);
  3049. enum xnn_status xnn_setup_copy_nc_x32(
  3050. xnn_operator_t copy_op,
  3051. const void* input,
  3052. void* output);
  3053. enum xnn_status xnn_run_copy_nc_x32(
  3054. size_t channels,
  3055. size_t input_stride,
  3056. size_t output_stride,
  3057. size_t batch_size,
  3058. const uint32_t* input,
  3059. uint32_t* output,
  3060. uint32_t flags,
  3061. pthreadpool_t threadpool);
  3062. enum xnn_status xnn_create_deconvolution2d_nhwc_f16(
  3063. uint32_t output_padding_top,
  3064. uint32_t output_padding_right,
  3065. uint32_t output_padding_bottom,
  3066. uint32_t output_padding_left,
  3067. uint32_t kernel_height,
  3068. uint32_t kernel_width,
  3069. uint32_t stride_height,
  3070. uint32_t stride_width,
  3071. uint32_t dilation_height,
  3072. uint32_t dilation_width,
  3073. uint32_t groups,
  3074. size_t group_input_channels,
  3075. size_t group_output_channels,
  3076. size_t input_pixel_stride,
  3077. size_t output_pixel_stride,
  3078. const void* kernel,
  3079. const void* bias,
  3080. float output_min,
  3081. float output_max,
  3082. uint32_t flags,
  3083. xnn_code_cache_t code_cache,
  3084. xnn_weights_cache_t weights_cache,
  3085. xnn_operator_t* deconvolution_op_out);
  3086. enum xnn_status xnn_reshape_deconvolution2d_nhwc_f16(
  3087. xnn_operator_t deconvolution_op,
  3088. size_t batch_size,
  3089. size_t input_height,
  3090. size_t input_width,
  3091. uint32_t adjustment_height,
  3092. uint32_t adjustment_width,
  3093. size_t* output_height_out,
  3094. size_t* output_width_out,
  3095. pthreadpool_t threadpool);
  3096. enum xnn_status xnn_setup_deconvolution2d_nhwc_f16(
  3097. xnn_operator_t deconvolution_op,
  3098. const void* input,
  3099. void* output);
  3100. enum xnn_status xnn_create_deconvolution2d_nhwc_f32(
  3101. uint32_t output_padding_top,
  3102. uint32_t output_padding_right,
  3103. uint32_t output_padding_bottom,
  3104. uint32_t output_padding_left,
  3105. uint32_t kernel_height,
  3106. uint32_t kernel_width,
  3107. uint32_t stride_height,
  3108. uint32_t stride_width,
  3109. uint32_t dilation_height,
  3110. uint32_t dilation_width,
  3111. uint32_t groups,
  3112. size_t group_input_channels,
  3113. size_t group_output_channels,
  3114. size_t input_pixel_stride,
  3115. size_t output_pixel_stride,
  3116. const float* kernel,
  3117. const float* bias,
  3118. float output_min,
  3119. float output_max,
  3120. uint32_t flags,
  3121. xnn_code_cache_t code_cache,
  3122. xnn_weights_cache_t weights_cache,
  3123. xnn_operator_t* deconvolution_op_out);
  3124. enum xnn_status xnn_create_deconvolution2d_nhwc_f32_f16(
  3125. uint32_t output_padding_top,
  3126. uint32_t output_padding_right,
  3127. uint32_t output_padding_bottom,
  3128. uint32_t output_padding_left,
  3129. uint32_t kernel_height,
  3130. uint32_t kernel_width,
  3131. uint32_t stride_height,
  3132. uint32_t stride_width,
  3133. uint32_t dilation_height,
  3134. uint32_t dilation_width,
  3135. uint32_t groups,
  3136. size_t group_input_channels,
  3137. size_t group_output_channels,
  3138. size_t input_pixel_stride,
  3139. size_t output_pixel_stride,
  3140. const void* kernel,
  3141. const void* bias,
  3142. float output_min,
  3143. float output_max,
  3144. uint32_t flags,
  3145. xnn_code_cache_t code_cache,
  3146. xnn_weights_cache_t weights_cache,
  3147. xnn_operator_t* deconvolution_op_out);
  3148. enum xnn_status xnn_reshape_deconvolution2d_nhwc_f32(
  3149. xnn_operator_t deconvolution_op,
  3150. size_t batch_size,
  3151. size_t input_height,
  3152. size_t input_width,
  3153. uint32_t adjustment_height,
  3154. uint32_t adjustment_width,
  3155. size_t* output_height_out,
  3156. size_t* output_width_out,
  3157. pthreadpool_t threadpool);
  3158. enum xnn_status xnn_setup_deconvolution2d_nhwc_f32(
  3159. xnn_operator_t deconvolution_op,
  3160. const float* input,
  3161. float* output);
  3162. enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w(
  3163. uint32_t output_padding_top,
  3164. uint32_t output_padding_right,
  3165. uint32_t output_padding_bottom,
  3166. uint32_t output_padding_left,
  3167. uint32_t kernel_height,
  3168. uint32_t kernel_width,
  3169. uint32_t stride_height,
  3170. uint32_t stride_width,
  3171. uint32_t dilation_height,
  3172. uint32_t dilation_width,
  3173. uint32_t groups,
  3174. size_t group_input_channels,
  3175. size_t group_output_channels,
  3176. size_t input_pixel_stride,
  3177. size_t output_pixel_stride,
  3178. const float* kernel_scale,
  3179. const int8_t* kernel,
  3180. const float* bias,
  3181. float output_min,
  3182. float output_max,
  3183. uint32_t flags,
  3184. xnn_code_cache_t code_cache,
  3185. xnn_weights_cache_t weights_cache,
  3186. xnn_operator_t* deconvolution_op_out);
  3187. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w(
  3188. xnn_operator_t deconvolution_op,
  3189. size_t batch_size,
  3190. size_t input_height,
  3191. size_t input_width,
  3192. uint32_t adjustment_height,
  3193. uint32_t adjustment_width,
  3194. size_t* output_height_out,
  3195. size_t* output_width_out,
  3196. pthreadpool_t threadpool);
  3197. enum xnn_status xnn_setup_deconvolution2d_nhwc_qd8_f32_qc8w(
  3198. xnn_operator_t deconvolution_op,
  3199. const int8_t* input,
  3200. float* output,
  3201. const struct xnn_quantization_params* quantization_params);
  3202. enum xnn_status xnn_create_deconvolution2d_nhwc_qs8(
  3203. uint32_t output_padding_top,
  3204. uint32_t output_padding_right,
  3205. uint32_t output_padding_bottom,
  3206. uint32_t output_padding_left,
  3207. uint32_t kernel_height,
  3208. uint32_t kernel_width,
  3209. uint32_t stride_height,
  3210. uint32_t stride_width,
  3211. uint32_t dilation_height,
  3212. uint32_t dilation_width,
  3213. uint32_t groups,
  3214. size_t group_input_channels,
  3215. size_t group_output_channels,
  3216. size_t input_pixel_stride,
  3217. size_t output_pixel_stride,
  3218. int8_t input_zero_point,
  3219. float input_scale,
  3220. float kernel_scale,
  3221. const int8_t* kernel,
  3222. const int32_t* bias,
  3223. int8_t output_zero_point,
  3224. float output_scale,
  3225. int8_t output_min,
  3226. int8_t output_max,
  3227. uint32_t flags,
  3228. xnn_code_cache_t code_cache,
  3229. xnn_weights_cache_t weights_cache,
  3230. xnn_operator_t* deconvolution_op_out);
  3231. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8(
  3232. xnn_operator_t deconvolution_op,
  3233. size_t batch_size,
  3234. size_t input_height,
  3235. size_t input_width,
  3236. uint32_t adjustment_height,
  3237. uint32_t adjustment_width,
  3238. size_t* output_height_out,
  3239. size_t* output_width_out,
  3240. pthreadpool_t threadpool);
  3241. enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8(
  3242. xnn_operator_t deconvolution_op,
  3243. const int8_t* input,
  3244. int8_t* output);
  3245. enum xnn_status xnn_create_deconvolution2d_nhwc_qs8_qc8w(
  3246. uint32_t output_padding_top,
  3247. uint32_t output_padding_right,
  3248. uint32_t output_padding_bottom,
  3249. uint32_t output_padding_left,
  3250. uint32_t kernel_height,
  3251. uint32_t kernel_width,
  3252. uint32_t stride_height,
  3253. uint32_t stride_width,
  3254. uint32_t dilation_height,
  3255. uint32_t dilation_width,
  3256. uint32_t groups,
  3257. size_t group_input_channels,
  3258. size_t group_output_channels,
  3259. size_t input_pixel_stride,
  3260. size_t output_pixel_stride,
  3261. int8_t input_zero_point,
  3262. float input_scale,
  3263. const float* kernel_scale,
  3264. const int8_t* kernel,
  3265. const int32_t* bias,
  3266. int8_t output_zero_point,
  3267. float output_scale,
  3268. int8_t output_min,
  3269. int8_t output_max,
  3270. uint32_t flags,
  3271. xnn_code_cache_t code_cache,
  3272. xnn_weights_cache_t weights_cache,
  3273. xnn_operator_t* deconvolution_op_out);
  3274. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8_qc8w(
  3275. xnn_operator_t deconvolution_op,
  3276. size_t batch_size,
  3277. size_t input_height,
  3278. size_t input_width,
  3279. uint32_t adjustment_height,
  3280. uint32_t adjustment_width,
  3281. size_t* output_height_out,
  3282. size_t* output_width_out,
  3283. pthreadpool_t threadpool);
  3284. enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8_qc8w(
  3285. xnn_operator_t deconvolution_op,
  3286. const int8_t* input,
  3287. int8_t* output);
  3288. enum xnn_status xnn_create_deconvolution2d_nhwc_qu8(
  3289. uint32_t output_padding_top,
  3290. uint32_t output_padding_right,
  3291. uint32_t output_padding_bottom,
  3292. uint32_t output_padding_left,
  3293. uint32_t kernel_height,
  3294. uint32_t kernel_width,
  3295. uint32_t stride_height,
  3296. uint32_t stride_width,
  3297. uint32_t dilation_height,
  3298. uint32_t dilation_width,
  3299. uint32_t groups,
  3300. size_t group_input_channels,
  3301. size_t group_output_channels,
  3302. size_t input_pixel_stride,
  3303. size_t output_pixel_stride,
  3304. uint8_t input_zero_point,
  3305. float input_scale,
  3306. uint8_t kernel_zero_point,
  3307. float kernel_scale,
  3308. const uint8_t* kernel,
  3309. const int32_t* bias,
  3310. uint8_t output_zero_point,
  3311. float output_scale,
  3312. uint8_t output_min,
  3313. uint8_t output_max,
  3314. uint32_t flags,
  3315. xnn_code_cache_t code_cache,
  3316. xnn_weights_cache_t weights_cache,
  3317. xnn_operator_t* deconvolution_op_out);
  3318. enum xnn_status xnn_reshape_deconvolution2d_nhwc_qu8(
  3319. xnn_operator_t deconvolution_op,
  3320. size_t batch_size,
  3321. size_t input_height,
  3322. size_t input_width,
  3323. uint32_t adjustment_height,
  3324. uint32_t adjustment_width,
  3325. size_t* output_height_out,
  3326. size_t* output_width_out,
  3327. pthreadpool_t threadpool);
  3328. enum xnn_status xnn_setup_deconvolution2d_nhwc_qu8(
  3329. xnn_operator_t deconvolution_op,
  3330. const uint8_t* input,
  3331. uint8_t* output);
  3332. enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x16(
  3333. uint32_t block_size,
  3334. uint32_t flags,
  3335. xnn_operator_t* depth_to_space_op_out);
  3336. enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x16(
  3337. xnn_operator_t depth_to_space_op,
  3338. size_t batch_size,
  3339. size_t input_height,
  3340. size_t input_width,
  3341. size_t input_channels,
  3342. size_t* output_height_out,
  3343. size_t* output_width_out,
  3344. size_t* output_channels_out,
  3345. pthreadpool_t threadpool);
  3346. enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x16(
  3347. xnn_operator_t depth_to_space_op,
  3348. const void* input,
  3349. void* output);
  3350. enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x32(
  3351. uint32_t block_size,
  3352. uint32_t flags,
  3353. xnn_operator_t* depth_to_space_op_out);
  3354. enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x32(
  3355. xnn_operator_t depth_to_space_op,
  3356. size_t batch_size,
  3357. size_t input_height,
  3358. size_t input_width,
  3359. size_t input_channels,
  3360. size_t* output_height_out,
  3361. size_t* output_width_out,
  3362. size_t* output_channels_out,
  3363. pthreadpool_t threadpool);
  3364. enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x32(
  3365. xnn_operator_t depth_to_space_op,
  3366. const void* input,
  3367. void* output);
  3368. enum xnn_status xnn_create_depth_to_space_nhwc_x8(
  3369. uint32_t block_size,
  3370. uint32_t flags,
  3371. xnn_operator_t* depth_to_space_op_out);
  3372. enum xnn_status xnn_reshape_depth_to_space_nhwc_x8(
  3373. xnn_operator_t depth_to_space_op,
  3374. size_t batch_size,
  3375. size_t input_height,
  3376. size_t input_width,
  3377. size_t input_channels,
  3378. size_t* output_height_out,
  3379. size_t* output_width_out,
  3380. size_t* output_channels_out,
  3381. pthreadpool_t threadpool);
  3382. enum xnn_status xnn_setup_depth_to_space_nhwc_x8(
  3383. xnn_operator_t depth_to_space_op,
  3384. const void* input,
  3385. void* output);
  3386. enum xnn_status xnn_create_depth_to_space_nhwc_x16(
  3387. uint32_t block_size,
  3388. uint32_t flags,
  3389. xnn_operator_t* depth_to_space_op_out);
  3390. enum xnn_status xnn_reshape_depth_to_space_nhwc_x16(
  3391. xnn_operator_t depth_to_space_op,
  3392. size_t batch_size,
  3393. size_t input_height,
  3394. size_t input_width,
  3395. size_t input_channels,
  3396. size_t* output_height_out,
  3397. size_t* output_width_out,
  3398. size_t* output_channels_out,
  3399. pthreadpool_t threadpool);
  3400. enum xnn_status xnn_setup_depth_to_space_nhwc_x16(
  3401. xnn_operator_t depth_to_space_op,
  3402. const void* input,
  3403. void* output);
  3404. enum xnn_status xnn_create_depth_to_space_nhwc_x32(
  3405. uint32_t block_size,
  3406. uint32_t flags,
  3407. xnn_operator_t* depth_to_space_op_out);
  3408. enum xnn_status xnn_reshape_depth_to_space_nhwc_x32(
  3409. xnn_operator_t depth_to_space_op,
  3410. size_t batch_size,
  3411. size_t input_height,
  3412. size_t input_width,
  3413. size_t input_channels,
  3414. size_t* output_height_out,
  3415. size_t* output_width_out,
  3416. size_t* output_channels_out,
  3417. pthreadpool_t threadpool);
  3418. enum xnn_status xnn_setup_depth_to_space_nhwc_x32(
  3419. xnn_operator_t depth_to_space_op,
  3420. const void* input,
  3421. void* output);
  3422. enum xnn_status xnn_create_dynamic_fully_connected_nc_f16(
  3423. float output_min,
  3424. float output_max,
  3425. uint32_t flags,
  3426. xnn_operator_t* dynamic_fully_connected_op_out);
  3427. enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f16(
  3428. xnn_operator_t dynamic_fully_connected_op,
  3429. size_t batch_size,
  3430. size_t input_channels,
  3431. size_t output_channels,
  3432. size_t input_stride,
  3433. size_t output_stride,
  3434. size_t* workspace_size,
  3435. size_t* workspace_alignment,
  3436. pthreadpool_t threadpool);
  3437. enum xnn_status xnn_setup_dynamic_fully_connected_nc_f16(
  3438. xnn_operator_t dynamic_fully_connected_op,
  3439. void* workspace,
  3440. const void* input,
  3441. const void* kernel,
  3442. const void* bias,
  3443. void* output);
  3444. enum xnn_status xnn_create_dynamic_fully_connected_nc_f32(
  3445. float output_min,
  3446. float output_max,
  3447. uint32_t flags,
  3448. xnn_operator_t* dynamic_fully_connected_op_out);
  3449. enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f32(
  3450. xnn_operator_t dynamic_fully_connected_op,
  3451. size_t batch_size,
  3452. size_t input_channels,
  3453. size_t output_channels,
  3454. size_t input_stride,
  3455. size_t output_stride,
  3456. size_t* workspace_size,
  3457. size_t* workspace_alignment,
  3458. pthreadpool_t threadpool);
  3459. enum xnn_status xnn_setup_dynamic_fully_connected_nc_f32(
  3460. xnn_operator_t dynamic_fully_connected_op,
  3461. void* workspace,
  3462. const float* input,
  3463. const float* kernel,
  3464. const float* bias,
  3465. float* output);
  3466. enum xnn_status xnn_create_fully_connected_nc_f16(
  3467. size_t input_channels,
  3468. size_t output_channels,
  3469. size_t input_stride,
  3470. size_t output_stride,
  3471. const void* kernel,
  3472. const void* bias,
  3473. float output_min,
  3474. float output_max,
  3475. uint32_t flags,
  3476. xnn_code_cache_t code_cache,
  3477. xnn_weights_cache_t weights_cache,
  3478. xnn_operator_t* fully_connected_op_out);
  3479. enum xnn_status xnn_reshape_fully_connected_nc_f16(
  3480. xnn_operator_t fully_connected_op,
  3481. size_t batch_size,
  3482. pthreadpool_t threadpool);
  3483. enum xnn_status xnn_setup_fully_connected_nc_f16(
  3484. xnn_operator_t fully_connected_op,
  3485. const void* input,
  3486. void* output);
  3487. enum xnn_status xnn_create_fully_connected_nc_f32_f16(
  3488. size_t input_channels,
  3489. size_t output_channels,
  3490. size_t input_stride,
  3491. size_t output_stride,
  3492. const void* kernel,
  3493. const void* bias,
  3494. float output_min,
  3495. float output_max,
  3496. uint32_t flags,
  3497. xnn_code_cache_t code_cache,
  3498. xnn_weights_cache_t weights_cache,
  3499. xnn_operator_t* fully_connected_op_out);
  3500. enum xnn_status xnn_create_fully_connected_nc_f32(
  3501. size_t input_channels,
  3502. size_t output_channels,
  3503. size_t input_stride,
  3504. size_t output_stride,
  3505. const float* kernel,
  3506. const float* bias,
  3507. float output_min,
  3508. float output_max,
  3509. uint32_t flags,
  3510. xnn_code_cache_t code_cache,
  3511. xnn_weights_cache_t weights_cache,
  3512. xnn_operator_t* fully_connected_op_out);
  3513. enum xnn_status xnn_reshape_fully_connected_nc_f32_f16(
  3514. xnn_operator_t fully_connected_op,
  3515. size_t batch_size,
  3516. pthreadpool_t threadpool);
  3517. enum xnn_status xnn_reshape_fully_connected_nc_f32(
  3518. xnn_operator_t fully_connected_op,
  3519. size_t batch_size,
  3520. pthreadpool_t threadpool);
  3521. enum xnn_status xnn_setup_fully_connected_nc_f32_f16(
  3522. xnn_operator_t fully_connected_op,
  3523. const float* input,
  3524. float* output);
  3525. enum xnn_status xnn_setup_fully_connected_nc_f32(
  3526. xnn_operator_t fully_connected_op,
  3527. const float* input,
  3528. float* output);
  3529. enum xnn_status xnn_create_fully_connected_nc_f32_qc4w(
  3530. size_t input_channels,
  3531. size_t output_channels,
  3532. size_t input_stride,
  3533. size_t output_stride,
  3534. uint8_t kernel_zero_point,
  3535. const float* kernel_scale,
  3536. const uint8_t* kernel,
  3537. const float* bias,
  3538. float output_min,
  3539. float output_max,
  3540. uint32_t flags,
  3541. xnn_code_cache_t code_cache,
  3542. xnn_weights_cache_t weights_cache,
  3543. xnn_operator_t* fully_connected_op_out);
  3544. enum xnn_status xnn_reshape_fully_connected_nc_f32_qc4w(
  3545. xnn_operator_t fully_connected_op,
  3546. size_t batch_size,
  3547. pthreadpool_t threadpool);
  3548. enum xnn_status xnn_setup_fully_connected_nc_f32_qc4w(
  3549. xnn_operator_t fully_connected_op,
  3550. const float* input,
  3551. float* output);
  3552. enum xnn_status xnn_create_fully_connected_nc_f32_qc8w(
  3553. size_t input_channels,
  3554. size_t output_channels,
  3555. size_t input_stride,
  3556. size_t output_stride,
  3557. const float* kernel_scale,
  3558. const int8_t* kernel,
  3559. const float* bias,
  3560. float output_min,
  3561. float output_max,
  3562. uint32_t flags,
  3563. xnn_code_cache_t code_cache,
  3564. xnn_weights_cache_t weights_cache,
  3565. xnn_operator_t* fully_connected_op_out);
  3566. enum xnn_status xnn_reshape_fully_connected_nc_f32_qc8w(
  3567. xnn_operator_t fully_connected_op,
  3568. size_t batch_size,
  3569. pthreadpool_t threadpool);
  3570. enum xnn_status xnn_setup_fully_connected_nc_f32_qc8w(
  3571. xnn_operator_t fully_connected_op,
  3572. const float* input,
  3573. float* output);
  3574. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w(
  3575. size_t input_channels,
  3576. size_t output_channels,
  3577. size_t input_stride,
  3578. size_t output_stride,
  3579. uint8_t kernel_zero_point,
  3580. const float* kernel_scale,
  3581. const void* kernel,
  3582. const float* bias,
  3583. float output_min,
  3584. float output_max,
  3585. uint32_t flags,
  3586. xnn_code_cache_t code_cache,
  3587. xnn_weights_cache_t weights_cache,
  3588. xnn_operator_t* fully_connected_op_out);
  3589. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc4w(
  3590. xnn_operator_t fully_connected_op,
  3591. const int8_t* input,
  3592. void* output,
  3593. const struct xnn_quantization_params* quantization_params);
  3594. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc4w(
  3595. xnn_operator_t fully_connected_op,
  3596. size_t batch_size,
  3597. pthreadpool_t threadpool);
  3598. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w(
  3599. size_t input_channels,
  3600. size_t output_channels,
  3601. size_t input_stride,
  3602. size_t output_stride,
  3603. size_t block_size,
  3604. uint8_t kernel_zero_point,
  3605. const uint16_t* kernel_scale,
  3606. const void* kernel,
  3607. const float* bias,
  3608. float output_min,
  3609. float output_max,
  3610. uint32_t flags,
  3611. xnn_code_cache_t code_cache,
  3612. xnn_weights_cache_t weights_cache,
  3613. xnn_operator_t* fully_connected_op_out);
  3614. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qb4w(
  3615. xnn_operator_t fully_connected_op,
  3616. size_t batch_size,
  3617. pthreadpool_t threadpool);
  3618. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qb4w(
  3619. xnn_operator_t fully_connected_op,
  3620. const int8_t* input,
  3621. void* output,
  3622. const struct xnn_quantization_params* quantization_params);
  3623. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w(
  3624. size_t input_channels,
  3625. size_t output_channels,
  3626. size_t input_stride,
  3627. size_t output_stride,
  3628. uint8_t kernel_zero_point,
  3629. const float* kernel_scale,
  3630. const void* kernel,
  3631. const float* bias,
  3632. float output_min,
  3633. float output_max,
  3634. uint32_t flags,
  3635. xnn_code_cache_t code_cache,
  3636. xnn_weights_cache_t weights_cache,
  3637. xnn_operator_t* fully_connected_op_out);
  3638. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc4w(
  3639. xnn_operator_t fully_connected_op,
  3640. const int8_t* input,
  3641. float* output,
  3642. const struct xnn_quantization_params* quantization_params);
  3643. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc4w(
  3644. xnn_operator_t fully_connected_op,
  3645. size_t batch_size,
  3646. pthreadpool_t threadpool);
  3647. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w(
  3648. size_t input_channels,
  3649. size_t output_channels,
  3650. size_t input_stride,
  3651. size_t output_stride,
  3652. size_t block_size,
  3653. uint8_t kernel_zero_point,
  3654. const uint16_t* kernel_scale,
  3655. const void* kernel,
  3656. const float* bias,
  3657. float output_min,
  3658. float output_max,
  3659. uint32_t flags,
  3660. xnn_code_cache_t code_cache,
  3661. xnn_weights_cache_t weights_cache,
  3662. xnn_operator_t* fully_connected_op_out);
  3663. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qb4w(
  3664. xnn_operator_t fully_connected_op,
  3665. size_t batch_size,
  3666. pthreadpool_t threadpool);
  3667. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qb4w(
  3668. xnn_operator_t fully_connected_op,
  3669. const int8_t* input,
  3670. float* output,
  3671. const struct xnn_quantization_params* quantization_params);
  3672. enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w(
  3673. size_t input_channels,
  3674. size_t output_channels,
  3675. size_t input_stride,
  3676. size_t output_stride,
  3677. const float* kernel_scale,
  3678. const int8_t* kernel,
  3679. const float* bias,
  3680. float output_min,
  3681. float output_max,
  3682. uint32_t flags,
  3683. xnn_code_cache_t code_cache,
  3684. xnn_weights_cache_t weights_cache,
  3685. xnn_operator_t* fully_connected_op_out);
  3686. enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc8w(
  3687. xnn_operator_t fully_connected_op,
  3688. const int8_t* input,
  3689. void* output,
  3690. const struct xnn_quantization_params* quantization_params);
  3691. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc8w(
  3692. xnn_operator_t fully_connected_op,
  3693. size_t batch_size,
  3694. pthreadpool_t threadpool);
  3695. enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w(
  3696. size_t input_channels,
  3697. size_t output_channels,
  3698. size_t input_stride,
  3699. size_t output_stride,
  3700. const float* kernel_scale,
  3701. const int8_t* kernel,
  3702. const float* bias,
  3703. float output_min,
  3704. float output_max,
  3705. uint32_t flags,
  3706. xnn_code_cache_t code_cache,
  3707. xnn_weights_cache_t weights_cache,
  3708. xnn_operator_t* fully_connected_op_out);
  3709. enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc8w(
  3710. xnn_operator_t fully_connected_op,
  3711. const int8_t* input,
  3712. float* output,
  3713. const struct xnn_quantization_params* quantization_params);
  3714. enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc8w(
  3715. xnn_operator_t fully_connected_op,
  3716. size_t batch_size,
  3717. pthreadpool_t threadpool);
  3718. enum xnn_status xnn_create_fully_connected_nc_qs8(
  3719. size_t input_channels,
  3720. size_t output_channels,
  3721. size_t input_stride,
  3722. size_t output_stride,
  3723. int8_t input_zero_point,
  3724. float input_scale,
  3725. float kernel_scale,
  3726. const int8_t* kernel,
  3727. const int32_t* bias,
  3728. int8_t output_zero_point,
  3729. float output_scale,
  3730. int8_t output_min,
  3731. int8_t output_max,
  3732. uint32_t flags,
  3733. xnn_code_cache_t code_cache,
  3734. xnn_weights_cache_t weights_cache,
  3735. xnn_operator_t* fully_connected_op_out);
  3736. enum xnn_status xnn_reshape_fully_connected_nc_qs8(
  3737. xnn_operator_t fully_connected_op,
  3738. size_t batch_size,
  3739. pthreadpool_t threadpool);
  3740. enum xnn_status xnn_setup_fully_connected_nc_qs8(
  3741. xnn_operator_t fully_connected_op,
  3742. const int8_t* input,
  3743. int8_t* output);
  3744. enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w(
  3745. size_t input_channels,
  3746. size_t output_channels,
  3747. size_t input_stride,
  3748. size_t output_stride,
  3749. int8_t input_zero_point,
  3750. float input_scale,
  3751. const float* kernel_scale,
  3752. const int8_t* kernel,
  3753. const int32_t* bias,
  3754. int8_t output_zero_point,
  3755. float output_scale,
  3756. int8_t output_min,
  3757. int8_t output_max,
  3758. uint32_t flags,
  3759. xnn_code_cache_t code_cache,
  3760. xnn_weights_cache_t weights_cache,
  3761. xnn_operator_t* fully_connected_op_out);
  3762. enum xnn_status xnn_reshape_fully_connected_nc_qs8_qc8w(
  3763. xnn_operator_t fully_connected_op,
  3764. size_t batch_size,
  3765. pthreadpool_t threadpool);
  3766. enum xnn_status xnn_setup_fully_connected_nc_qs8_qc8w(
  3767. xnn_operator_t fully_connected_op,
  3768. const int8_t* input,
  3769. int8_t* output);
  3770. enum xnn_status xnn_create_fully_connected_nc_qu8(
  3771. size_t input_channels,
  3772. size_t output_channels,
  3773. size_t input_stride,
  3774. size_t output_stride,
  3775. uint8_t input_zero_point,
  3776. float input_scale,
  3777. uint8_t kernel_zero_point,
  3778. float kernel_scale,
  3779. const uint8_t* kernel,
  3780. const int32_t* bias,
  3781. uint8_t output_zero_point,
  3782. float output_scale,
  3783. uint8_t output_min,
  3784. uint8_t output_max,
  3785. uint32_t flags,
  3786. xnn_code_cache_t code_cache,
  3787. xnn_weights_cache_t weights_cache,
  3788. xnn_operator_t* fully_connected_op_out);
  3789. enum xnn_status xnn_reshape_fully_connected_nc_qu8(
  3790. xnn_operator_t fully_connected_op,
  3791. size_t batch_size,
  3792. pthreadpool_t threadpool);
  3793. enum xnn_status xnn_setup_fully_connected_nc_qu8(
  3794. xnn_operator_t fully_connected_op,
  3795. const uint8_t* input,
  3796. uint8_t* output);
  3797. enum xnn_status xnn_create_max_pooling2d_nhwc_f16(
  3798. uint32_t input_padding_top,
  3799. uint32_t input_padding_right,
  3800. uint32_t input_padding_bottom,
  3801. uint32_t input_padding_left,
  3802. uint32_t pooling_height,
  3803. uint32_t pooling_width,
  3804. uint32_t stride_height,
  3805. uint32_t stride_width,
  3806. uint32_t dilation_height,
  3807. uint32_t dilation_width,
  3808. float output_min,
  3809. float output_max,
  3810. uint32_t flags,
  3811. xnn_operator_t* max_pooling_op_out);
  3812. enum xnn_status xnn_reshape_max_pooling2d_nhwc_f16(
  3813. xnn_operator_t max_pooling_op,
  3814. size_t batch_size,
  3815. size_t input_height,
  3816. size_t input_width,
  3817. size_t channels,
  3818. size_t input_pixel_stride,
  3819. size_t output_pixel_stride,
  3820. size_t* output_height_out,
  3821. size_t* output_width_out,
  3822. pthreadpool_t threadpool);
  3823. enum xnn_status xnn_setup_max_pooling2d_nhwc_f16(
  3824. xnn_operator_t max_pooling_op,
  3825. const void* input,
  3826. void* output);
  3827. enum xnn_status xnn_create_max_pooling2d_nhwc_f32(
  3828. uint32_t input_padding_top,
  3829. uint32_t input_padding_right,
  3830. uint32_t input_padding_bottom,
  3831. uint32_t input_padding_left,
  3832. uint32_t pooling_height,
  3833. uint32_t pooling_width,
  3834. uint32_t stride_height,
  3835. uint32_t stride_width,
  3836. uint32_t dilation_height,
  3837. uint32_t dilation_width,
  3838. float output_min,
  3839. float output_max,
  3840. uint32_t flags,
  3841. xnn_operator_t* max_pooling_op_out);
  3842. enum xnn_status xnn_reshape_max_pooling2d_nhwc_f32(
  3843. xnn_operator_t max_pooling_op,
  3844. size_t batch_size,
  3845. size_t input_height,
  3846. size_t input_width,
  3847. size_t channels,
  3848. size_t input_pixel_stride,
  3849. size_t output_pixel_stride,
  3850. size_t* output_height_out,
  3851. size_t* output_width_out,
  3852. pthreadpool_t threadpool);
  3853. enum xnn_status xnn_setup_max_pooling2d_nhwc_f32(
  3854. xnn_operator_t max_pooling_op,
  3855. const float* input,
  3856. float* output);
  3857. enum xnn_status xnn_create_max_pooling2d_nhwc_s8(
  3858. uint32_t input_padding_top,
  3859. uint32_t input_padding_right,
  3860. uint32_t input_padding_bottom,
  3861. uint32_t input_padding_left,
  3862. uint32_t pooling_height,
  3863. uint32_t pooling_width,
  3864. uint32_t stride_height,
  3865. uint32_t stride_width,
  3866. uint32_t dilation_height,
  3867. uint32_t dilation_width,
  3868. int8_t output_min,
  3869. int8_t output_max,
  3870. uint32_t flags,
  3871. xnn_operator_t* max_pooling_op_out);
  3872. enum xnn_status xnn_reshape_max_pooling2d_nhwc_s8(
  3873. xnn_operator_t max_pooling_op,
  3874. size_t batch_size,
  3875. size_t input_height,
  3876. size_t input_width,
  3877. size_t channels,
  3878. size_t input_pixel_stride,
  3879. size_t output_pixel_stride,
  3880. size_t* output_height_out,
  3881. size_t* output_width_out,
  3882. pthreadpool_t threadpool);
  3883. enum xnn_status xnn_setup_max_pooling2d_nhwc_s8(
  3884. xnn_operator_t max_pooling_op,
  3885. const int8_t* input,
  3886. int8_t* output);
  3887. enum xnn_status xnn_create_max_pooling2d_nhwc_u8(
  3888. uint32_t input_padding_top,
  3889. uint32_t input_padding_right,
  3890. uint32_t input_padding_bottom,
  3891. uint32_t input_padding_left,
  3892. uint32_t pooling_height,
  3893. uint32_t pooling_width,
  3894. uint32_t stride_height,
  3895. uint32_t stride_width,
  3896. uint32_t dilation_height,
  3897. uint32_t dilation_width,
  3898. uint8_t output_min,
  3899. uint8_t output_max,
  3900. uint32_t flags,
  3901. xnn_operator_t* max_pooling_op_out);
  3902. enum xnn_status xnn_reshape_max_pooling2d_nhwc_u8(
  3903. xnn_operator_t max_pooling_op,
  3904. size_t batch_size,
  3905. size_t input_height,
  3906. size_t input_width,
  3907. size_t channels,
  3908. size_t input_pixel_stride,
  3909. size_t output_pixel_stride,
  3910. size_t* output_height_out,
  3911. size_t* output_width_out,
  3912. pthreadpool_t threadpool);
  3913. enum xnn_status xnn_setup_max_pooling2d_nhwc_u8(
  3914. xnn_operator_t max_pooling_op,
  3915. const uint8_t* input,
  3916. uint8_t* output);
  3917. enum xnn_status xnn_create_reduce_nd(
  3918. enum xnn_reduce_operator reduce_operator_type,
  3919. enum xnn_datatype datatype,
  3920. const struct xnn_quantization_params* input_quantization,
  3921. const struct xnn_quantization_params* output_quantization,
  3922. uint32_t flags,
  3923. xnn_operator_t* reduce_op_out);
  3924. enum xnn_status xnn_reshape_reduce_nd( //
  3925. xnn_operator_t reduce_op, //
  3926. size_t num_reduction_axes, //
  3927. const int64_t* reduction_axes, //
  3928. size_t num_input_dims, //
  3929. const size_t* input_shape, //
  3930. size_t* workspace_size, //
  3931. size_t* workspace_alignment, //
  3932. pthreadpool_t threadpool);
  3933. enum xnn_status xnn_setup_reduce_nd(
  3934. xnn_operator_t reduce_op,
  3935. void* workspace,
  3936. const void* input,
  3937. void* output);
  3938. enum xnn_status xnn_create_resize_bilinear2d_nchw_f32(
  3939. size_t output_height,
  3940. size_t output_width,
  3941. uint32_t flags,
  3942. xnn_operator_t* resize_op_out);
  3943. enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f32(
  3944. xnn_operator_t resize_op,
  3945. size_t batch_size,
  3946. size_t input_height,
  3947. size_t input_width,
  3948. size_t channels,
  3949. size_t input_pixel_stride,
  3950. size_t output_pixel_stride,
  3951. pthreadpool_t threadpool);
  3952. enum xnn_status xnn_setup_resize_bilinear2d_nchw_f32(
  3953. xnn_operator_t resize_op,
  3954. const float* input,
  3955. float* output);
  3956. enum xnn_status xnn_create_resize_bilinear2d_nchw_f16(
  3957. size_t output_height,
  3958. size_t output_width,
  3959. uint32_t flags,
  3960. xnn_operator_t* resize_op_out);
  3961. enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f16(
  3962. xnn_operator_t resize_op,
  3963. size_t batch_size,
  3964. size_t input_height,
  3965. size_t input_width,
  3966. size_t channels,
  3967. size_t input_pixel_stride,
  3968. size_t output_pixel_stride,
  3969. pthreadpool_t threadpool);
  3970. enum xnn_status xnn_setup_resize_bilinear2d_nchw_f16(
  3971. xnn_operator_t resize_op,
  3972. const void* input,
  3973. void* output);
  3974. enum xnn_status xnn_create_resize_bilinear2d_nhwc_f16(
  3975. size_t output_height,
  3976. size_t output_width,
  3977. uint32_t flags,
  3978. xnn_operator_t* resize_op_out);
  3979. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f16(
  3980. xnn_operator_t resize_op,
  3981. size_t batch_size,
  3982. size_t input_height,
  3983. size_t input_width,
  3984. size_t channels,
  3985. size_t input_pixel_stride,
  3986. size_t output_pixel_stride,
  3987. size_t* workspace_size,
  3988. size_t* workspace_alignment,
  3989. pthreadpool_t threadpool);
  3990. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f16(
  3991. xnn_operator_t resize_op,
  3992. void* workspace,
  3993. const void* input,
  3994. void* output);
  3995. enum xnn_status xnn_create_resize_bilinear2d_nhwc_f32(
  3996. size_t output_height,
  3997. size_t output_width,
  3998. uint32_t flags,
  3999. xnn_operator_t* resize_op_out);
  4000. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f32(
  4001. xnn_operator_t resize_op,
  4002. size_t batch_size,
  4003. size_t input_height,
  4004. size_t input_width,
  4005. size_t channels,
  4006. size_t input_pixel_stride,
  4007. size_t output_pixel_stride,
  4008. size_t* workspace_size,
  4009. size_t* workspace_alignment,
  4010. pthreadpool_t threadpool);
  4011. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f32(
  4012. xnn_operator_t resize_op,
  4013. void* workspace,
  4014. const float* input,
  4015. float* output);
  4016. enum xnn_status xnn_create_resize_bilinear2d_nhwc_s8(
  4017. size_t output_height,
  4018. size_t output_width,
  4019. uint32_t flags,
  4020. xnn_operator_t* resize_op_out);
  4021. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_s8(
  4022. xnn_operator_t resize_op,
  4023. size_t batch_size,
  4024. size_t input_height,
  4025. size_t input_width,
  4026. size_t channels,
  4027. size_t input_pixel_stride,
  4028. size_t output_pixel_stride,
  4029. size_t* workspace_size,
  4030. size_t* workspace,
  4031. pthreadpool_t threadpool);
  4032. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_s8(
  4033. xnn_operator_t resize_op,
  4034. void* workspace,
  4035. const int8_t* input,
  4036. int8_t* output);
  4037. enum xnn_status xnn_create_resize_bilinear2d_nhwc_u8(
  4038. size_t output_height,
  4039. size_t output_width,
  4040. uint32_t flags,
  4041. xnn_operator_t* resize_op_out);
  4042. enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_u8(
  4043. xnn_operator_t resize_op,
  4044. size_t batch_size,
  4045. size_t input_height,
  4046. size_t input_width,
  4047. size_t channels,
  4048. size_t input_pixel_stride,
  4049. size_t output_pixel_stride,
  4050. size_t* workspace_size,
  4051. size_t* workspace_alignment,
  4052. pthreadpool_t threadpool);
  4053. enum xnn_status xnn_setup_resize_bilinear2d_nhwc_u8(
  4054. xnn_operator_t resize_op,
  4055. void* workspace,
  4056. const uint8_t* input,
  4057. uint8_t* output);
  4058. enum xnn_status xnn_create_rope_nthc_f16(
  4059. uint32_t flags,
  4060. xnn_operator_t* rope_op_out);
  4061. enum xnn_status xnn_reshape_rope_nthc_f16(
  4062. xnn_operator_t rope_op,
  4063. size_t batch_size,
  4064. size_t tokens,
  4065. size_t heads,
  4066. size_t channels,
  4067. pthreadpool_t threadpool);
  4068. enum xnn_status xnn_setup_rope_nthc_f16(
  4069. xnn_operator_t rope_op,
  4070. const void* input,
  4071. const void* weights,
  4072. void* output);
  4073. enum xnn_status xnn_create_rope_nthc_f32(
  4074. uint32_t flags,
  4075. xnn_operator_t* rope_op_out);
  4076. enum xnn_status xnn_reshape_rope_nthc_f32(
  4077. xnn_operator_t rope_op,
  4078. size_t batch_size,
  4079. size_t tokens,
  4080. size_t heads,
  4081. size_t channels,
  4082. pthreadpool_t threadpool);
  4083. enum xnn_status xnn_setup_rope_nthc_f32(
  4084. xnn_operator_t rope_op,
  4085. const float* input,
  4086. const float* weights,
  4087. float* output);
  4088. // N: batch size
  4089. // H: number of heads
  4090. // T: tokens (sequence length)
  4091. // C: channels (head dimension)
  4092. enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f16(
  4093. enum xnn_attention_logits_cap_type cap_type,
  4094. const void* cap_params,
  4095. uint32_t flags,
  4096. xnn_operator_t* attention_op_out);
  4097. enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f16(
  4098. xnn_operator_t attention_op,
  4099. size_t batch_size,
  4100. size_t query_heads,
  4101. // Number of tokens in query.
  4102. size_t query_tokens,
  4103. size_t key_value_heads,
  4104. // Number of tokens in key/value. For self-attention, this is same as tokens.
  4105. size_t key_value_tokens,
  4106. size_t query_key_channels,
  4107. size_t value_channels,
  4108. size_t* workspace_size,
  4109. size_t* workspace_alignment,
  4110. pthreadpool_t threadpool);
  4111. // Query is of dimension [batch_size, query_heads, query_tokens, channels].
  4112. // Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, channels].
  4113. // Scale is of dimension [channels].
  4114. // Mask is of dimension [query_tokens, key_value_tokens].
  4115. enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f16(
  4116. xnn_operator_t attention_op,
  4117. void* workspace,
  4118. const void* query,
  4119. const void* key,
  4120. const void* value,
  4121. const void* scale,
  4122. const void* mask,
  4123. void* output);
  4124. // N: batch size
  4125. // H: number of heads
  4126. // T: tokens (sequence length)
  4127. // C: channels (head dimension)
  4128. enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f32(
  4129. enum xnn_attention_logits_cap_type cap_type,
  4130. const void* cap_params,
  4131. uint32_t flags,
  4132. xnn_operator_t* attention_op_out);
  4133. enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f32(
  4134. xnn_operator_t attention_op,
  4135. size_t batch_size,
  4136. size_t query_heads,
  4137. // Number of tokens in query.
  4138. size_t query_tokens,
  4139. size_t key_value_heads,
  4140. // Number of tokens in key/value. For self-attention, this is same as tokens.
  4141. size_t key_value_tokens,
  4142. size_t query_key_channels,
  4143. size_t value_channels,
  4144. size_t* workspace_size,
  4145. size_t* workspace_alignment,
  4146. pthreadpool_t threadpool);
  4147. // Query is of dimension [batch_size, query_heads, query_tokens, query_key_channels].
  4148. // Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, query_key_channels].
  4149. // Scale is of dimension [query_key_channels].
  4150. // Mask is of dimension [query_tokens, key_value_tokens].
  4151. // Output is of dimension [batch_size, query_heads, query_tokens, value_channels].
  4152. enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f32(
  4153. xnn_operator_t attention_op,
  4154. void* workspace,
  4155. const float* query,
  4156. const float* key,
  4157. const float* value,
  4158. const float* scale,
  4159. const float* mask,
  4160. float* output);
  4161. enum xnn_status xnn_create_slice_nd_x16(
  4162. uint32_t flags,
  4163. xnn_operator_t* slice_op_out);
  4164. enum xnn_status xnn_reshape_slice_nd_x16(
  4165. xnn_operator_t slice_op,
  4166. size_t num_dims,
  4167. const size_t* input_shape,
  4168. const size_t* offsets,
  4169. const size_t* sizes,
  4170. pthreadpool_t threadpool);
  4171. enum xnn_status xnn_setup_slice_nd_x16(
  4172. xnn_operator_t slice_op,
  4173. const void* input,
  4174. void* output);
  4175. enum xnn_status xnn_create_slice_nd_x32(
  4176. uint32_t flags,
  4177. xnn_operator_t* slice_op_out);
  4178. enum xnn_status xnn_reshape_slice_nd_x32(
  4179. xnn_operator_t slice_op,
  4180. size_t num_dims,
  4181. const size_t* input_shape,
  4182. const size_t* offsets,
  4183. const size_t* sizes,
  4184. pthreadpool_t threadpool);
  4185. enum xnn_status xnn_setup_slice_nd_x32(
  4186. xnn_operator_t slice_op,
  4187. const void* input,
  4188. void* output);
  4189. enum xnn_status xnn_run_slice_nd_x32(
  4190. size_t num_dims,
  4191. const size_t* input_shape,
  4192. const size_t* offsets,
  4193. const size_t* sizes,
  4194. const void* input,
  4195. void* output,
  4196. uint32_t flags,
  4197. pthreadpool_t threadpool);
  4198. enum xnn_status xnn_create_softmax_nc_f16(
  4199. uint32_t flags,
  4200. xnn_operator_t* softmax_op_out);
  4201. enum xnn_status xnn_reshape_softmax_nc_f16(
  4202. xnn_operator_t softmax_op,
  4203. size_t channels,
  4204. size_t input_stride,
  4205. size_t output_stride,
  4206. size_t batch_size,
  4207. pthreadpool_t threadpool);
  4208. enum xnn_status xnn_setup_softmax_nc_f16(
  4209. xnn_operator_t softmax_op,
  4210. const void* input,
  4211. void* output);
  4212. enum xnn_status xnn_create_softmax_nc_f32(
  4213. uint32_t flags,
  4214. xnn_operator_t* softmax_op_out);
  4215. enum xnn_status xnn_reshape_softmax_nc_f32(
  4216. xnn_operator_t softmax_op,
  4217. size_t channels,
  4218. size_t input_stride,
  4219. size_t output_stride,
  4220. size_t batch_size,
  4221. pthreadpool_t threadpool);
  4222. enum xnn_status xnn_setup_softmax_nc_f32(
  4223. xnn_operator_t softmax_op,
  4224. const float* input,
  4225. float* output);
  4226. enum xnn_status xnn_create_softmax_nc_qu8(
  4227. float input_scale,
  4228. uint8_t output_zero_point,
  4229. float output_scale,
  4230. uint32_t flags,
  4231. xnn_operator_t* softmax_op_out);
  4232. enum xnn_status xnn_reshape_softmax_nc_qu8(
  4233. xnn_operator_t softmax_op,
  4234. size_t channels,
  4235. size_t input_stride,
  4236. size_t output_stride,
  4237. size_t batch_size,
  4238. pthreadpool_t threadpool);
  4239. enum xnn_status xnn_setup_softmax_nc_qu8(
  4240. xnn_operator_t softmax_op,
  4241. const uint8_t* input,
  4242. uint8_t* output);
  4243. enum xnn_status xnn_create_space_to_depth_nhwc_x16(
  4244. uint32_t block_size,
  4245. uint32_t flags,
  4246. xnn_operator_t* space_to_depth_op_out);
  4247. enum xnn_status xnn_reshape_space_to_depth_nhwc_x16(
  4248. xnn_operator_t space_to_depth_op,
  4249. size_t batch_size,
  4250. size_t input_height,
  4251. size_t input_width,
  4252. size_t input_channels,
  4253. size_t* output_height_out,
  4254. size_t* output_width_out,
  4255. size_t* output_channels_out,
  4256. pthreadpool_t threadpool);
  4257. enum xnn_status xnn_setup_space_to_depth_nhwc_x16(
  4258. xnn_operator_t space_to_depth_op,
  4259. const void* input,
  4260. void* output);
  4261. enum xnn_status xnn_create_space_to_depth_nhwc_x32(
  4262. uint32_t block_size,
  4263. uint32_t flags,
  4264. xnn_operator_t* space_to_depth_op_out);
  4265. enum xnn_status xnn_reshape_space_to_depth_nhwc_x32(
  4266. xnn_operator_t space_to_depth_op,
  4267. size_t batch_size,
  4268. size_t input_height,
  4269. size_t input_width,
  4270. size_t input_channels,
  4271. size_t* output_height_out,
  4272. size_t* output_width_out,
  4273. size_t* output_channels_out,
  4274. pthreadpool_t threadpool);
  4275. enum xnn_status xnn_setup_space_to_depth_nhwc_x32(
  4276. xnn_operator_t space_to_depth_op,
  4277. const void* input,
  4278. void* output);
  4279. enum xnn_status xnn_create_transpose_nd_x8(
  4280. uint32_t flags,
  4281. xnn_operator_t* transpose_op_out);
  4282. enum xnn_status xnn_reshape_transpose_nd_x8(
  4283. xnn_operator_t transpose_op,
  4284. size_t num_dims,
  4285. const size_t* input_shape,
  4286. const size_t* output_perm,
  4287. pthreadpool_t threadpool);
  4288. enum xnn_status xnn_setup_transpose_nd_x8(
  4289. xnn_operator_t transpose_op,
  4290. const void* input,
  4291. void* output);
  4292. enum xnn_status xnn_run_transpose_nd_x8(
  4293. const void* input,
  4294. void* output,
  4295. size_t num_dims,
  4296. const size_t* input_shape,
  4297. const size_t* output_perm,
  4298. uint32_t flags,
  4299. pthreadpool_t threadpool);
  4300. enum xnn_status xnn_create_transpose_nd_x16(
  4301. uint32_t flags,
  4302. xnn_operator_t* transpose_op_out);
  4303. enum xnn_status xnn_reshape_transpose_nd_x16(
  4304. xnn_operator_t transpose_op,
  4305. size_t num_dims,
  4306. const size_t* input_shape,
  4307. const size_t* output_perm,
  4308. pthreadpool_t threadpool);
  4309. enum xnn_status xnn_setup_transpose_nd_x16(
  4310. xnn_operator_t transpose_op,
  4311. const void* input,
  4312. void* output);
  4313. enum xnn_status xnn_run_transpose_nd_x16(
  4314. const void* input,
  4315. void* output,
  4316. size_t num_dims,
  4317. const size_t* input_shape,
  4318. const size_t* output_perm,
  4319. uint32_t flags,
  4320. pthreadpool_t threadpool);
  4321. enum xnn_status xnn_create_transpose_nd_x32(
  4322. uint32_t flags,
  4323. xnn_operator_t* transpose_op_out);
  4324. enum xnn_status xnn_reshape_transpose_nd_x32(
  4325. xnn_operator_t transpose_op,
  4326. size_t num_dims,
  4327. const size_t* input_shape,
  4328. const size_t* output_perm,
  4329. pthreadpool_t threadpool);
  4330. enum xnn_status xnn_setup_transpose_nd_x32(
  4331. xnn_operator_t transpose_op,
  4332. const void* input,
  4333. void* output);
  4334. enum xnn_status xnn_run_transpose_nd_x32(
  4335. const void* input,
  4336. void* output,
  4337. size_t num_dims,
  4338. const size_t* input_shape,
  4339. const size_t* output_perm,
  4340. uint32_t flags,
  4341. pthreadpool_t threadpool);
  4342. enum xnn_status xnn_create_transpose_nd_x64(
  4343. uint32_t flags,
  4344. xnn_operator_t* transpose_op_out);
  4345. enum xnn_status xnn_reshape_transpose_nd_x64(
  4346. xnn_operator_t transpose_op,
  4347. size_t num_dims,
  4348. const size_t* input_shape,
  4349. const size_t* output_perm,
  4350. pthreadpool_t threadpool);
  4351. enum xnn_status xnn_setup_transpose_nd_x64(
  4352. xnn_operator_t transpose_op,
  4353. const void* input,
  4354. void* output);
  4355. enum xnn_status xnn_run_transpose_nd_x64(
  4356. const void* input,
  4357. void* output,
  4358. size_t num_dims,
  4359. const size_t* input_shape,
  4360. const size_t* output_perm,
  4361. uint32_t flags,
  4362. pthreadpool_t threadpool);
  4363. enum xnn_status xnn_create_unpooling2d_nhwc_x32(
  4364. uint32_t input_padding_top,
  4365. uint32_t input_padding_right,
  4366. uint32_t input_padding_bottom,
  4367. uint32_t input_padding_left,
  4368. uint32_t pooling_height,
  4369. uint32_t pooling_width,
  4370. size_t channels,
  4371. size_t input_pixel_stride,
  4372. size_t output_pixel_stride,
  4373. uint32_t flags,
  4374. xnn_operator_t* unpooling_op_out);
  4375. enum xnn_status xnn_reshape_unpooling2d_nhwc_x32(
  4376. xnn_operator_t unpooling_op,
  4377. size_t batch_size,
  4378. size_t input_height,
  4379. size_t input_width,
  4380. size_t* output_height_out,
  4381. size_t* output_width_out,
  4382. pthreadpool_t threadpool);
  4383. enum xnn_status xnn_setup_unpooling2d_nhwc_x32(
  4384. xnn_operator_t unpooling_op,
  4385. const void* input,
  4386. const uint32_t* index,
  4387. void* output);
  4388. enum xnn_status xnn_create_slice_nd_x8(
  4389. uint32_t flags,
  4390. xnn_operator_t* slice_op_out);
  4391. enum xnn_status xnn_reshape_slice_nd_x8(
  4392. xnn_operator_t slice_op,
  4393. size_t num_dims,
  4394. const size_t* input_shape,
  4395. const size_t* offsets,
  4396. const size_t* sizes,
  4397. pthreadpool_t threadpool);
  4398. enum xnn_status xnn_setup_slice_nd_x8(
  4399. xnn_operator_t slice_op,
  4400. const void* input,
  4401. void* output);
  4402. enum xnn_status xnn_create_space_to_depth_nhwc_x8(
  4403. uint32_t block_size,
  4404. uint32_t flags,
  4405. xnn_operator_t* space_to_depth_op_out);
  4406. enum xnn_status xnn_reshape_space_to_depth_nhwc_x8(
  4407. xnn_operator_t space_to_depth_op,
  4408. size_t batch_size,
  4409. size_t input_height,
  4410. size_t input_width,
  4411. size_t input_channels,
  4412. size_t* output_height_out,
  4413. size_t* output_width_out,
  4414. size_t* output_channels_out,
  4415. pthreadpool_t threadpool);
  4416. enum xnn_status xnn_setup_space_to_depth_nhwc_x8(
  4417. xnn_operator_t space_to_depth_op,
  4418. const void* input,
  4419. void* output);
  4420. #ifdef __cplusplus
  4421. } // extern "C"
  4422. #endif
  4423. #else
  4424. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  4425. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)