Math.h 139 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/AccumulateType.h>
  4. #include <ATen/NumericUtils.h>
  5. #include <ATen/jiterator_macros.h>
  6. #include <c10/macros/Macros.h>
  7. #include <c10/util/BFloat16.h>
  8. #include <c10/util/Half.h>
  9. #include <c10/util/MathConstants.h>
  10. #include <cfloat>
  11. #include <cmath>
  12. #include <cstdint>
  13. #include <cstdlib>
  14. #include <limits>
  15. #include <type_traits>
  16. C10_CLANG_DIAGNOSTIC_PUSH()
  17. #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
  18. C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
  19. #endif
  20. /* The calc_erfinv function below (defined after the Cephes-derived i0e helpers) is taken from https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
  21. Below is the copyright.
  22. Output was modified to be inf or -inf when input is 1 or -1. */
  23. /*
  24. Copyright (c) 2014 Indiana University
  25. All rights reserved.
  26. Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci.,
  27. Indiana University, Bloomington, IN
  28. This software is licensed under the New BSD license:
  29. Redistribution and use in source and binary forms,
  30. with or without modification, are permitted provided
  31. that the following conditions are met:
  32. Redistributions of source code must retain the above
  33. copyright notice, this list of conditions and the
  34. following disclaimer.
  35. Redistributions in binary form must reproduce the
  36. above copyright notice, this list of conditions and
  37. the following disclaimer in the documentation and/or
  38. other materials provided with the distribution.
  39. Neither the name of Indiana University nor
  40. the names of its contributors may be used to endorse
  41. or promote products derived from this software without
  42. specific prior written permission.
  43. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
  44. CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
  45. WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  46. WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
  47. PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
  48. THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
  49. DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  50. CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
  51. PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
  52. USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
  53. HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
  54. IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  55. NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
  56. USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  57. POSSIBILITY OF SUCH DAMAGE.
  58. */
  namespace {
  /*
   * This function is derived from the implementation of the i0e function in the
   * Cephes Math Library. See note [3-Clause BSD License for the Cephes Math
   * Library].
   *
   * Computes an approximation of the exponentially scaled zeroth order modified
   * Bessel function of the first kind. The approximation is actually two
   * (sub)approximations, both using a Chebyshev polynomial expansion. One
   * approximates the function over [0, 8], and the other over (8, infinity). This
   * function takes the absolute value of all inputs to convert them into the
   * domain of the approximation.
   *
   * Both helpers below are passed through
   * jiterator_also_stringify_as(jiterator_code(...), i0e_string). Presumably
   * (see <ATen/jiterator_macros.h>) this defines the templates normally and, in
   * GPU builds, also captures their source text as i0e_string for runtime
   * compilation. Treat the code tokens inside the macro as part of that string;
   * comments are stripped before macro expansion and do not affect it.
   *
   * NOTE(review): an unnamed namespace in a header gives every including
   * translation unit its own private copy of these definitions.
   */
  jiterator_also_stringify_as(jiterator_code(
    // chbevl: evaluates a Chebyshev series at x with the recurrence
    //   b0 = x * b1 - b2 + array[i]
    // over the len coefficients array[0] .. array[len - 1], and returns
    // 0.5 * (b0 - b2). calc_i0e below only passes arguments in [-2, 2].
    template <typename T>
    JITERATOR_HOST_DEVICE T chbevl(T x, const T array[], const int len) {
      // Only b2 has an initializer here; b0 and b1 are assigned just below.
      T b0, b1, b2 = 0;
      b0 = array[0];
      b1 = 0;
      for (int i = 1; i < len; ++i) {
        b2 = b1;
        b1 = b0;
        b0 = x * b1 - b2 + array[i];
      }
      return T{0.5} * (b0 - b2);
    }
    // calc_i0e: the exponentially scaled I0 described above, evaluated at |x|.
    // |x| <= 8 uses the 30-term series; anything else (including NaN, which
    // fails the <= test) uses the 25-term series divided by sqrt(|x|).
    template <typename T>
    JITERATOR_HOST_DEVICE T calc_i0e(T _x) {
      T x = std::fabs(_x);
      if (x <= T{8.0}) {
        static const T coefficients[] = {
            -4.41534164647933937950E-18, 3.33079451882223809783E-17,
            -2.43127984654795469359E-16, 1.71539128555513303061E-15,
            -1.16853328779934516808E-14, 7.67618549860493561688E-14,
            -4.85644678311192946090E-13, 2.95505266312963983461E-12,
            -1.72682629144155570723E-11, 9.67580903537323691224E-11,
            -5.18979560163526290666E-10, 2.65982372468238665035E-9,
            -1.30002500998624804212E-8, 6.04699502254191894932E-8,
            -2.67079385394061173391E-7, 1.11738753912010371815E-6,
            -4.41673835845875056359E-6, 1.64484480707288970893E-5,
            -5.75419501008210370398E-5, 1.88502885095841655729E-4,
            -5.76375574538582365885E-4, 1.63947561694133579842E-3,
            -4.32430999505057594430E-3, 1.05464603945949983183E-2,
            -2.37374148058994688156E-2, 4.93052842396707084878E-2,
            -9.49010970480476444210E-2, 1.71620901522208775349E-1,
            -3.04682672343198398683E-1, 6.76795274409476084995E-1};
        // Maps x in [0, 8] onto y in [-2, 2].
        T y = (x / T{2.0}) - T{2.0};
        return chbevl(y, coefficients, int{30});
      }
      // x > 8
      static const T coefficients[] = {
          -7.23318048787475395456E-18, -4.83050448594418207126E-18,
          4.46562142029675999901E-17, 3.46122286769746109310E-17,
          -2.82762398051658348494E-16, -3.42548561967721913462E-16,
          1.77256013305652638360E-15, 3.81168066935262242075E-15,
          -9.55484669882830764870E-15, -4.15056934728722208663E-14,
          1.54008621752140982691E-14, 3.85277838274214270114E-13,
          7.18012445138366623367E-13, -1.79417853150680611778E-12,
          -1.32158118404477131188E-11, -3.14991652796324136454E-11,
          1.18891471078464383424E-11, 4.94060238822496958910E-10,
          3.39623202570838634515E-9, 2.26666899049817806459E-8,
          2.04891858946906374183E-7, 2.89137052083475648297E-6,
          6.88975834691682398426E-5, 3.36911647825569408990E-3,
          8.04490411014108831608E-1};
      // 32 / x - 2 maps x in (8, inf) onto (-2, 2).
      return chbevl(T{32.0} / x - T{2.0}, coefficients, int{25}) / std::sqrt(x);
    }),
    i0e_string) // i0e_string
  }
  127. #define CENTRAL_RANGE 0.7
  128. template <typename T>
  129. inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  130. calc_erfinv(T y) {
  131. /* Function to calculate inverse error function. Rational approximation
  132. is used to generate an initial approximation, which is then improved to
  133. full accuracy by two steps of Newton's method. Code is a direct
  134. translation of the erfinv m file in matlab version 2.0.
  135. Author: Gary L. Pavlis, Indiana University
  136. Date: February 1996
  137. */
  138. T x, z, num, dem; /*working variables */
  139. /* coefficients in rational expansion */
  140. T a[4] = { T(0.886226899), T(-1.645349621), T(0.914624893), T(-0.140543331) };
  141. T b[4] = { T(-2.118377725), T(1.442710462), T(-0.329097515), T(0.012229801) };
  142. T c[4] = { T(-1.970840454), T(-1.624906493), T(3.429567803), T(1.641345311) };
  143. T d[2] = { T(3.543889200), T(1.637067800) };
  144. T y_abs = std::abs(y);
  145. if(y_abs > 1.0) return std::numeric_limits<T>::quiet_NaN();
  146. #ifdef _WIN32
  147. // error C2039: '_copysign': is not a member of 'std'
  148. if(y_abs == 1.0) return copysign(std::numeric_limits<T>::infinity(), y);
  149. #else
  150. if(y_abs == 1.0) return std::copysign(std::numeric_limits<T>::infinity(), y);
  151. #endif
  152. if(y_abs <= static_cast<T>(CENTRAL_RANGE)) {
  153. z = y * y;
  154. num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
  155. dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0]) * z + static_cast<T>(1.0));
  156. x = y * num / dem;
  157. }
  158. else{
  159. z = std::sqrt(-std::log((static_cast<T>(1.0)-y_abs)/static_cast<T>(2.0)));
  160. num = ((c[3]*z + c[2])*z + c[1]) * z + c[0];
  161. dem = (d[1]*z + d[0])*z + static_cast<T>(1.0);
  162. #ifdef _WIN32
  163. // error C2039: '_copysign': is not a member of 'std'
  164. x = copysign(num, y) / dem;
  165. #else
  166. x = std::copysign(num, y) / dem;
  167. #endif
  168. }
  169. /* Two steps of Newton-Raphson correction */
  170. x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)*c10::frac_1_sqrt_pi<T>)*std::exp(-x*x));
  171. x = x - (std::erf(x) - y) / ((static_cast<T>(2.0)*c10::frac_1_sqrt_pi<T>)*std::exp(-x*x));
  172. return x;
  173. }
  174. #undef CENTRAL_RANGE
  175. /*
  176. * Note [3-Clause BSD License for the Cephes Math Library]
  177. * Code derived from implementations in the Cephes Math Library should mention its derivation and reference
  178. * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note
  179. * [3-Clause BSD License for the Cephes Math Library].'). The license is:
  180. * Copyright (c) 2018, Steven Moshier
  181. * All rights reserved.
  182. *
  183. * Redistribution and use in source and binary forms, with or without
  184. * modification, are permitted provided that the following conditions are met:
  185. * * Redistributions of source code must retain the above copyright
  186. * notice, this list of conditions and the following disclaimer.
  187. * * Redistributions in binary form must reproduce the above copyright
  188. * notice, this list of conditions and the following disclaimer in the
  189. * documentation and/or other materials provided with the distribution.
  190. * * Neither the name of the nor the
  191. * names of its contributors may be used to endorse or promote products
  192. * derived from this software without specific prior written permission.
  193. *
  194. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  195. * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  196. * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  197. * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY
  198. * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  199. * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  200. * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  201. * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  202. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  203. * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  204. */
  205. /*
  206. * This function is derived from the implementation of the zeta function in the Cephes Math Library.
  207. * See note [3-Clause BSD License for the Cephes Math Library].
  208. */
  209. template <typename scalar_t, bool is_cuda=false>
  210. C10_HOST_DEVICE inline scalar_t zeta(scalar_t x, scalar_t q) __ubsan_ignore_float_divide_by_zero__ {
  211. using acc_t = at::acc_type<scalar_t, is_cuda>;
  212. const acc_t MACHEP = acc_t{1.11022302462515654042E-16};
  213. constexpr acc_t zero = acc_t{0.0};
  214. constexpr acc_t half = acc_t{0.5};
  215. constexpr acc_t one = acc_t{1.0};
  216. static const acc_t A[] = {
  217. 12.0,
  218. -720.0,
  219. 30240.0,
  220. -1209600.0,
  221. 47900160.0,
  222. -1.8924375803183791606e9, /*1.307674368e12/691*/
  223. 7.47242496e10,
  224. -2.950130727918164224e12, /*1.067062284288e16/3617*/
  225. 1.1646782814350067249e14, /*5.109094217170944e18/43867*/
  226. -4.5979787224074726105e15, /*8.028576626982912e20/174611*/
  227. 1.8152105401943546773e17, /*1.5511210043330985984e23/854513*/
  228. -7.1661652561756670113e18 /*1.6938241367317436694528e27/236364091*/
  229. };
  230. acc_t a, b, k, s, t, w;
  231. if (x == one) {
  232. return std::numeric_limits<scalar_t>::infinity();
  233. }
  234. if (x < one) {
  235. return std::numeric_limits<scalar_t>::quiet_NaN();
  236. }
  237. if (q <= zero) {
  238. if (q == std::floor(q)) {
  239. return std::numeric_limits<scalar_t>::infinity();
  240. }
  241. if (x != std::floor(x)) {
  242. return std::numeric_limits<scalar_t>::quiet_NaN();
  243. }
  244. }
  245. s = std::pow(q, -x);
  246. a = q;
  247. int i = 0;
  248. b = zero;
  249. while ((i < 9) || (a <= acc_t{9.0})) {
  250. i += 1;
  251. a += one;
  252. b = ::pow(a, -x);
  253. s += b;
  254. if ((-MACHEP * s < b) && (b < MACHEP * s)) {
  255. return static_cast<scalar_t>(s);
  256. }
  257. };
  258. w = a;
  259. s += b * w / (x - one);
  260. s -= half * b;
  261. a = one;
  262. k = zero;
  263. for (i = 0; i < 12; i++) {
  264. a *= x + k;
  265. b /= w;
  266. t = a * b / A[i];
  267. s = s + t;
  268. t = ::fabs(t / s);
  269. if (t < MACHEP) {
  270. return static_cast<scalar_t>(s);
  271. }
  272. k += one;
  273. a *= x + k;
  274. b /= w;
  275. k += one;
  276. }
  277. return static_cast<scalar_t>(s);
  278. }
  279. /*
  280. * This function is derived from the implementation of the digamma function in the Cephes Math Library.
  281. * See note [3-Clause BSD License for the Cephes Math Library].
  282. *
  283. * Evaluates polynomial of degree N:
  284. *
  285. * 2 N
  286. * y = C + C x + C x +...+ C x
  287. * 0 1 2 N
  288. *
  289. * Coefficients are stored in reverse order:
  290. *
  291. * coef[0] = C , ..., coef[N] = C .
  292. * N 0
  293. */
  294. template <typename T>
  295. C10_HOST_DEVICE inline T polevl(const T x, const T A[], size_t len) {
  296. T result = 0;
  297. for (size_t i = 0; i <= len; i++) {
  298. result = result * x + A[i];
  299. }
  300. return result;
  301. }
  302. inline double trigamma(double x) __ubsan_ignore_float_divide_by_zero__ {
  303. double sign = +1;
  304. double result = 0;
  305. if (x < 0.5) {
  306. sign = -1;
  307. const double sin_pi_x = sin(c10::pi<double> * x);
  308. result -= (c10::pi<double> * c10::pi<double>) / (sin_pi_x * sin_pi_x);
  309. x = 1 - x;
  310. }
  311. for (int i = 0; i < 6; ++i) {
  312. result += 1 / (x * x);
  313. x += 1;
  314. }
  315. const double ixx = 1 / (x*x);
  316. result += (1 + 1 / (2*x) + ixx * (1./6 - ixx * (1./30 - ixx * (1./42)))) / x;
  317. return sign * result;
  318. }
  319. inline float trigamma(float x) __ubsan_ignore_float_divide_by_zero__ {
  320. float sign = +1;
  321. float result = 0;
  322. if (x < 0.5f) {
  323. sign = -1;
  324. const float sin_pi_x = sinf(c10::pi<float> * x);
  325. result -= (c10::pi<float> * c10::pi<float>) / (sin_pi_x * sin_pi_x);
  326. x = 1 - x;
  327. }
  328. for (int i = 0; i < 6; ++i) {
  329. result += 1 / (x * x);
  330. x += 1;
  331. }
  332. const float ixx = 1 / (x*x);
  333. result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x;
  334. return sign * result;
  335. }
  336. /*
  337. * This function is derived from the implementation of the digamma function in the Cephes Math Library.
  338. * See note [3-Clause BSD License for the Cephes Math Library].
  339. */
  340. inline double calc_digamma(double x) {
  341. // [C++ Standard Reference: Gamma Function] https://en.cppreference.com/w/cpp/numeric/math/tgamma
  342. static double PSI_10 = 2.25175258906672110764;
  343. if (x == 0) {
  344. // As per C++ standard for gamma related functions and SciPy,
  345. // If the argument is ±0, ±∞ is returned
  346. return std::copysign(INFINITY, -x);
  347. }
  348. bool x_is_integer = x == trunc(x);
  349. if (x < 0) {
  350. if (x_is_integer) {
  351. // As per C++ standard for gamma related functions and SciPy,
  352. // If the argument is a negative integer, NaN is returned
  353. return std::numeric_limits<double>::quiet_NaN();
  354. }
  355. // Extracts the fractional part of x as r, since tan(pi * r) is more numerically
  356. // accurate than tan(pi * x). While these operations are mathematically equivalent
  357. // since both x and r are in radians and tan() has a periodicity of pi, in practice
  358. // the computation of pi * x is a source of error (when |x| > 1).
  359. double q, r;
  360. r = std::modf(x, &q);
  361. return calc_digamma(1 - x) - c10::pi<double> / tan(c10::pi<double> * r);
  362. }
  363. // Push x to be >= 10
  364. double result = 0;
  365. while (x < 10) {
  366. result -= 1 / x;
  367. x += 1;
  368. }
  369. if (x == 10) {
  370. return result + PSI_10;
  371. }
  372. // Compute asymptotic digamma
  373. static const double A[] = {
  374. 8.33333333333333333333E-2,
  375. -2.10927960927960927961E-2,
  376. 7.57575757575757575758E-3,
  377. -4.16666666666666666667E-3,
  378. 3.96825396825396825397E-3,
  379. -8.33333333333333333333E-3,
  380. 8.33333333333333333333E-2,
  381. };
  382. double y = 0;
  383. if (x < 1.0e17) {
  384. double z = 1.0 / (x * x);
  385. y = z * polevl(z, A, 6);
  386. }
  387. return result + log(x) - (0.5 / x) - y;
  388. }
  389. /*
  390. * This function is derived from the implementation of the digamma function in the Cephes Math Library.
  391. * See note [3-Clause BSD License for the Cephes Math Library].
  392. */
  393. inline float calc_digamma(float x) {
  394. // See [C++ Standard Reference: Gamma Function]
  395. static float PSI_10 = 2.25175258906672110764f;
  396. if (x == 0) {
  397. // As per C++ standard for gamma related functions and SciPy,
  398. // If the argument is ±0, ±∞ is returned
  399. return std::copysign(INFINITY, -x);
  400. }
  401. bool x_is_integer = x == truncf(x);
  402. if (x < 0) {
  403. if (x_is_integer) {
  404. // As per C++ standard for gamma related functions and SciPy,
  405. // If the argument is a negative integer, NaN is returned
  406. return std::numeric_limits<float>::quiet_NaN();
  407. }
  408. // Extracts the fractional part of x as r, since tan(pi * r) is more numerically
  409. // accurate than tan(pi * x). While these operations are mathematically equivalent
  410. // since both x and r are in radians and tan() has a periodicity of pi, in practice
  411. // the computation of pi * x is a source of error (when |x| > 1).
  412. double q, r;
  413. r = std::modf(x, &q);
  414. float pi_over_tan_pi_x = (float)(c10::pi<double> / tan(c10::pi<double> * r));
  415. return calc_digamma(1 - x) - pi_over_tan_pi_x;
  416. }
  417. // Push x to be >= 10
  418. float result = 0;
  419. while (x < 10) {
  420. result -= 1 / x;
  421. x += 1;
  422. }
  423. if (x == 10) {
  424. return result + PSI_10;
  425. }
  426. // Compute asymptotic digamma
  427. static const float A[] = {
  428. 8.33333333333333333333E-2f,
  429. -2.10927960927960927961E-2f,
  430. 7.57575757575757575758E-3f,
  431. -4.16666666666666666667E-3f,
  432. 3.96825396825396825397E-3f,
  433. -8.33333333333333333333E-3f,
  434. 8.33333333333333333333E-2f,
  435. };
  436. float y = 0;
  437. if (x < 1.0e17f) {
  438. float z = 1 / (x * x);
  439. y = z * polevl(z, A, 6);
  440. }
  441. return result + logf(x) - (0.5f / x) - y;
  442. }
  443. inline c10::BFloat16 calc_digamma(c10::BFloat16 a) {
  444. return calc_digamma(static_cast<float>(a));
  445. }
  446. inline c10::Half calc_digamma(c10::Half a) {
  447. return calc_digamma(static_cast<float>(a));
  448. }
  449. template <typename scalar_t, bool is_cuda=false>
  450. inline C10_HOST_DEVICE scalar_t calc_polygamma(scalar_t x, int n) {
  451. // already blocked if n <= 1
  452. const auto one = scalar_t{1};
  453. return ((n % 2) ? one : -one) *
  454. std::exp(std::lgamma(static_cast<scalar_t>(n) + one)) *
  455. zeta<scalar_t, is_cuda>(static_cast<scalar_t>(n + 1), x);
  456. }
// Regularized lower and upper incomplete gamma functions.
// These functions and their helpers follow SciPy's implementation.
  460. /* References
  461. * [igam1] "The Digital Library of Mathematical Functions", dlmf.nist.gov
  462. * [igam2] Maddock et al., "Incomplete Gamma Functions",
  463. * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/sf_gamma/igamma.html
  464. */
/*
 * This implementation of the regularized incomplete gamma functions and
 * their helper functions is derived from the implementation of SciPy's
 * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations.
 * See NOTICE for the licenses.
 */
  471. template <typename scalar_t>
  472. scalar_t ratevl(scalar_t x, const scalar_t num[], int64_t M,
  473. const scalar_t denom[], int64_t N) {
  474. // evaluating rational function, i.e., the ratio of two polynomials
  475. // the coefficients for numerator are given by `num` while coeffs for
  476. // denumerator are given by `denom`
  477. int64_t i, dir;
  478. scalar_t y, num_ans, denom_ans;
  479. scalar_t absx = std::fabs(x);
  480. const scalar_t *p;
  481. if (absx > 1) {
  482. /* Evaluate as a polynomial in 1/x. */
  483. dir = -1;
  484. p = num + M;
  485. y = 1 / x;
  486. }
  487. else {
  488. dir = 1;
  489. p = num;
  490. y = x;
  491. }
  492. /* Evaluate the numerator */
  493. num_ans = *p;
  494. p += dir;
  495. for (i = 1; i <= M; i++) {
  496. num_ans = num_ans * y + *p;
  497. p += dir;
  498. }
  499. /* Evaluate the denominator */
  500. if (absx > 1) {
  501. p = denom + N;
  502. }
  503. else {
  504. p = denom;
  505. }
  506. denom_ans = *p;
  507. p += dir;
  508. for (i = 1; i <= N; i++) {
  509. denom_ans = denom_ans * y + *p;
  510. p += dir;
  511. }
  512. if (absx > 1) {
  513. i = N - M;
  514. return std::pow(x, i) * num_ans / denom_ans;
  515. }
  516. else {
  517. return num_ans / denom_ans;
  518. }
  519. }
  520. // SciPy's lanczos implementation is taken from Boost
  521. /* (C) Copyright John Maddock 2006.
  522. * Use, modification and distribution are subject to the
  523. * Boost Software License, Version 1.0. See
  524. * https://www.boost.org/LICENSE_1_0.txt or see NOTICE.
  525. */
  526. template <typename scalar_t>
  527. static scalar_t lanczos_sum_expg_scaled(scalar_t x) {
  528. // lanczos approximation
  529. static constexpr scalar_t lanczos_sum_expg_scaled_num[13] = {
  530. 0.006061842346248906525783753964555936883222,
  531. 0.5098416655656676188125178644804694509993,
  532. 19.51992788247617482847860966235652136208,
  533. 449.9445569063168119446858607650988409623,
  534. 6955.999602515376140356310115515198987526,
  535. 75999.29304014542649875303443598909137092,
  536. 601859.6171681098786670226533699352302507,
  537. 3481712.15498064590882071018964774556468,
  538. 14605578.08768506808414169982791359218571,
  539. 43338889.32467613834773723740590533316085,
  540. 86363131.28813859145546927288977868422342,
  541. 103794043.1163445451906271053616070238554,
  542. 56906521.91347156388090791033559122686859
  543. };
  544. static constexpr scalar_t lanczos_sum_expg_scaled_denom[13] = {
  545. 1.,
  546. 66.,
  547. 1925.,
  548. 32670.,
  549. 357423.,
  550. 2637558.,
  551. 13339535.,
  552. 45995730.,
  553. 105258076.,
  554. 150917976.,
  555. 120543840.,
  556. 39916800.,
  557. 0.
  558. };
  559. return ratevl(x, lanczos_sum_expg_scaled_num,
  560. sizeof(lanczos_sum_expg_scaled_num) / sizeof(lanczos_sum_expg_scaled_num[0]) - 1,
  561. lanczos_sum_expg_scaled_denom,
  562. sizeof(lanczos_sum_expg_scaled_denom) / sizeof(lanczos_sum_expg_scaled_denom[0]) - 1);
  563. }
  564. template <typename scalar_t>
  565. static scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
  566. // compute x^a * exp(-a) / gamma(a)
  567. // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with
  568. // exp(a - x).
  569. scalar_t ax, fac, res, num, numfac;
  570. static scalar_t MAXLOG = std::is_same_v<scalar_t,double> ?
  571. 7.09782712893383996843E2 : 88.72283905206835;
  572. static scalar_t EXP1 = 2.718281828459045;
  573. static scalar_t lanczos_g = 6.024680040776729583740234375;
  574. if (std::fabs(a - x) > 0.4 * std::fabs(a)) {
  575. ax = a * std::log(x) - x - std::lgamma(a);
  576. if (ax < -MAXLOG) {
  577. return 0.0;
  578. }
  579. return std::exp(ax);
  580. }
  581. fac = a + lanczos_g - 0.5;
  582. res = std::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a);
  583. if ((a < 200) && (x < 200)) {
  584. res *= std::exp(a - x) * std::pow(x / fac, a);
  585. }
  586. else {
  587. num = x - a - lanczos_g + 0.5;
  588. numfac = num / fac;
  589. res *= std::exp(a * (std::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac);
  590. }
  591. return res;
  592. }
  593. template <typename scalar_t>
  594. static scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
  595. // Compute igam using DLMF 8.11.4. [igam1]
  596. static scalar_t MACHEP = std::is_same_v<scalar_t, double> ?
  597. 1.11022302462515654042E-16 : 5.9604644775390625E-8;
  598. static int MAXITER = 2000;
  599. int i;
  600. scalar_t ans, ax, c, r;
  601. ax = _igam_helper_fac(a, x);
  602. if (ax == 0.0) {
  603. return 0.0;
  604. }
  605. /* power series */
  606. r = a;
  607. c = 1.0;
  608. ans = 1.0;
  609. for (i = 0; i < MAXITER; i++) {
  610. r += 1.0;
  611. c *= x / r;
  612. ans += c;
  613. if (c <= MACHEP * ans) {
  614. break;
  615. }
  616. }
  617. return (ans * ax / a);
  618. }
  619. template <typename scalar_t>
  620. static scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
  621. // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in
  622. // _igam_helper_series but extra care is taken to avoid cancellation.
  623. int n;
  624. scalar_t fac = 1;
  625. scalar_t sum = 0;
  626. scalar_t term, logx;
  627. static scalar_t MAXITER = 2000;
  628. static scalar_t MACHEP = std::is_same_v<scalar_t, double> ?
  629. 1.11022302462515654042E-16 : 5.9604644775390625E-8;
  630. for (n = 1; n < MAXITER; n++) {
  631. fac *= -x / n;
  632. term = fac / (a + n);
  633. sum += term;
  634. if (std::fabs(term) <= MACHEP * std::fabs(sum)) {
  635. break;
  636. }
  637. }
  638. logx = std::log(x);
  639. term = -std::expm1(a * logx - std::lgamma(1+a));
  640. return term - std::exp(a * logx - std::lgamma(a)) * sum;
  641. }
  642. template <typename scalar_t>
  643. static scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
  644. // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
  645. static constexpr scalar_t d[25][25] =
  646. {{-3.3333333333333333e-1, 8.3333333333333333e-2, -1.4814814814814815e-2,
  647. 1.1574074074074074e-3, 3.527336860670194e-4, -1.7875514403292181e-4,
  648. 3.9192631785224378e-5, -2.1854485106799922e-6, -1.85406221071516e-6,
  649. 8.296711340953086e-7, -1.7665952736826079e-7, 6.7078535434014986e-9,
  650. 1.0261809784240308e-8, -4.3820360184533532e-9, 9.1476995822367902e-10,
  651. -2.551419399494625e-11, -5.8307721325504251e-11, 2.4361948020667416e-11,
  652. -5.0276692801141756e-12, 1.1004392031956135e-13, 3.3717632624009854e-13,
  653. -1.3923887224181621e-13, 2.8534893807047443e-14, -5.1391118342425726e-16,
  654. -1.9752288294349443e-15},
  655. {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3,
  656. -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7,
  657. -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6,
  658. 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8,
  659. 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9,
  660. 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14,
  661. 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13,
  662. -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14,
  663. -4.13125571381061e-15},
  664. {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4,
  665. 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5,
  666. -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6,
  667. -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10,
  668. -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9,
  669. 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11,
  670. 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12,
  671. 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17,
  672. 8.8592218725911273e-15},
  673. {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4,
  674. 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7,
  675. 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6,
  676. -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8,
  677. -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9,
  678. -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14,
  679. -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12,
  680. 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14,
  681. 2.0453671226782849e-14},
  682. {-8.618882909167117e-4, 7.8403922172006663e-4, -2.9907248030319018e-4,
  683. -1.4638452578843418e-6, 6.6414982154651222e-5, -3.9683650471794347e-5,
  684. 1.1375726970678419e-5, 2.5074972262375328e-10, -1.6954149536558306e-6,
  685. 8.9075075322053097e-7, -2.2929348340008049e-7, 2.956794137544049e-11,
  686. 2.8865829742708784e-8, -1.4189739437803219e-8, 3.4463580499464897e-9,
  687. -2.3024517174528067e-13, -3.9409233028046405e-10, 1.8602338968504502e-10,
  688. -4.356323005056618e-11, 1.2786001016296231e-15, 4.6792750266579195e-12,
  689. -2.1492464706134829e-12, 4.9088156148096522e-13, -6.3385914848915603e-18,
  690. -5.0453320690800944e-14},
  691. {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4,
  692. -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7,
  693. -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6,
  694. -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7,
  695. 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9,
  696. 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15,
  697. 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11,
  698. -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13,
  699. -1.3249659916340829e-13},
  700. {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4,
  701. 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5,
  702. -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6,
  703. -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13,
  704. -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8,
  705. 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10,
  706. 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11,
  707. 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18,
  708. 3.6902800842763467e-13},
  709. {3.4436760689237767e-4, 5.1717909082605922e-5, -3.3493161081142236e-4,
  710. 2.812695154763237e-4, -1.0976582244684731e-4, -1.2741009095484485e-7,
  711. 2.7744451511563644e-5, -1.8263488805711333e-5, 5.7876949497350524e-6,
  712. 4.9387589339362704e-10, -1.0595367014026043e-6, 6.1667143761104075e-7,
  713. -1.7562973359060462e-7, -1.2974473287015439e-12, 2.695423606288966e-8,
  714. -1.4578352908731271e-8, 3.887645959386175e-9, -3.8810022510194121e-17,
  715. -5.3279941738772867e-10, 2.7437977643314845e-10, -6.9957960920705679e-11,
  716. 2.5899863874868481e-17, 8.8566890996696381e-12, -4.403168815871311e-12,
  717. 1.0865561947091654e-12},
  718. {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4,
  719. -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4,
  720. 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5,
  721. 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11,
  722. 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8,
  723. 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9,
  724. -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10,
  725. -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18,
  726. -3.3721464474854592e-12},
  727. {-5.9676129019274625e-4, -7.2048954160200106e-5, 6.7823088376673284e-4,
  728. -6.4014752602627585e-4, 2.7750107634328704e-4, 1.8197008380465151e-7,
  729. -8.4795071170685032e-5, 6.105192082501531e-5, -2.1073920183404862e-5,
  730. -8.8585890141255994e-10, 4.5284535953805377e-6, -2.8427815022504408e-6,
  731. 8.7082341778646412e-7, 3.6886101871706965e-12, -1.5344695190702061e-7,
  732. 8.862466778790695e-8, -2.5184812301826817e-8, -1.0225912098215092e-14,
  733. 3.8969470758154777e-9, -2.1267304792235635e-9, 5.7370135528051385e-10,
  734. -1.887749850169741e-19, -8.0931538694657866e-11, 4.2382723283449199e-11,
  735. -1.1002224534207726e-11},
  736. {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3,
  737. 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4,
  738. -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5,
  739. -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11,
  740. -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7,
  741. -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8,
  742. 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9,
  743. 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18,
  744. 3.7647749553543836e-11},
  745. {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3,
  746. 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7,
  747. 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4,
  748. 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5,
  749. -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6,
  750. -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14,
  751. -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9,
  752. -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10,
  753. 1.3481607129399749e-10},
  754. {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3,
  755. -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3,
  756. 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4,
  757. 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10,
  758. 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6,
  759. 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7,
  760. -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8,
  761. -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20,
  762. -5.0423112718105824e-10},
  763. {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3,
  764. -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6,
  765. -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4,
  766. -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4,
  767. 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5,
  768. 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13,
  769. 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8,
  770. 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9,
  771. -1.9661464453856102e-9},
  772. {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2,
  773. 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2,
  774. -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3,
  775. -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10,
  776. -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5,
  777. -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6,
  778. 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7,
  779. 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17,
  780. 7.9795091026746235e-9},
  781. {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2,
  782. 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6,
  783. 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3,
  784. 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3,
  785. -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4,
  786. -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12,
  787. -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6,
  788. -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7,
  789. 3.3654425209171788e-8},
  790. {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1,
  791. -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2,
  792. 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2,
  793. 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9,
  794. 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4,
  795. 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5,
  796. -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6,
  797. -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16,
  798. -1.4729737374018841e-7},
  799. {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1,
  800. -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5,
  801. -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2,
  802. -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2,
  803. 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3,
  804. 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12,
  805. 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5,
  806. 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6,
  807. -6.6812849447625594e-7},
  808. {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968,
  809. 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1,
  810. -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1,
  811. -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8,
  812. -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3,
  813. -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3,
  814. 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5,
  815. 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14,
  816. 3.1369106244517615e-6},
  817. {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906,
  818. 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4,
  819. 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1,
  820. 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1,
  821. -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2,
  822. -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11,
  823. -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4,
  824. 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5,
  825. 1.5227271505597605e-5},
  826. {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1,
  827. -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1,
  828. 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816,
  829. 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7,
  830. 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1,
  831. 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2,
  832. -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3,
  833. -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11,
  834. -7.6340103696869031e-5},
  835. {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1,
  836. -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3,
  837. -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1,
  838. -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195,
  839. 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1,
  840. 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10,
  841. 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3,
  842. -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3,
  843. -3.9479941246822517e-4},
  844. {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2,
  845. 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2,
  846. -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1,
  847. -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7,
  848. -6.2716159907747034, 5.1168999071852637, -2.0319658112299095,
  849. -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1,
  850. 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2,
  851. 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6,
  852. 2.1250180774699461e-3},
  853. {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2,
  854. 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2,
  855. 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2,
  856. 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1,
  857. -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373,
  858. -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7,
  859. -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1,
  860. 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2,
  861. 1.5109265210467774e-2},
  862. {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3,
  863. -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3,
  864. 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2,
  865. 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6,
  866. 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1,
  867. -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1,
  868. -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468,
  869. -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1,
  870. 4.8683443692930507e-1}};
  871. int k, n, sgn;
  872. int maxpow = 0;
  873. static scalar_t MACHEP = std::is_same_v<scalar_t, double> ?
  874. 1.11022302462515654042E-16 : 5.9604644775390625E-8;
  875. scalar_t lambda = x / a;
  876. scalar_t sigma = (x - a) / a;
  877. scalar_t eta, res, ck, ckterm, term, absterm;
  878. scalar_t absoldterm = INFINITY;
  879. scalar_t etapow[25] = {1};
  880. scalar_t sum = 0;
  881. scalar_t afac = 1;
  882. if (igam) {
  883. sgn = -1;
  884. }
  885. else {
  886. sgn = 1;
  887. }
  888. if (lambda > 1) {
  889. eta = std::sqrt(-2 * (std::log1p(sigma) - sigma));
  890. }
  891. else if (lambda < 1) {
  892. eta = -std::sqrt(-2 * (std::log1p(sigma) - sigma));
  893. }
  894. else {
  895. eta = 0;
  896. }
  897. res = 0.5 * std::erfc(sgn * eta * std::sqrt(a / 2));
  898. for (k = 0; k < 25; k++) {
  899. ck = d[k][0];
  900. for (n = 1; n < 25; n++) {
  901. if (n > maxpow) {
  902. etapow[n] = eta * etapow[n-1];
  903. maxpow += 1;
  904. }
  905. ckterm = d[k][n]*etapow[n];
  906. ck += ckterm;
  907. if (std::fabs(ckterm) < MACHEP * std::fabs(ck)) {
  908. break;
  909. }
  910. }
  911. term = ck * afac;
  912. absterm = std::fabs(term);
  913. if (absterm > absoldterm) {
  914. break;
  915. }
  916. sum += term;
  917. if (absterm < MACHEP * std::fabs(sum)) {
  918. break;
  919. }
  920. absoldterm = absterm;
  921. afac /= a;
  922. }
  923. res += sgn * std::exp(-0.5 * a * eta * eta) * sum / std::sqrt(2 * c10::pi<float> * a);
  924. return res;
  925. }
  926. template <typename scalar_t>
  927. static scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
  928. // Compute igamc using DLMF 8.9.2. [igam1]
  929. int i;
  930. scalar_t ans, ax, c, yc, r, t, y, z;
  931. scalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
  932. int MAXITER = 2000;
  933. static scalar_t MACHEP = std::is_same_v<scalar_t, double> ?
  934. 1.11022302462515654042E-16 : 5.9604644775390625E-8;
  935. static scalar_t BIG = std::is_same_v<scalar_t,double> ?
  936. 4.503599627370496e15 : 16777216.;
  937. static scalar_t BIGINV = std::is_same_v<scalar_t,double> ?
  938. 2.22044604925031308085e-16 : 5.9604644775390625E-8;
  939. ax = _igam_helper_fac(a, x);
  940. if (ax == 0.0) {
  941. return 0.0;
  942. }
  943. /* continued fraction */
  944. y = 1.0 - a;
  945. z = x + y + 1.0;
  946. c = 0.0;
  947. pkm2 = 1.0;
  948. qkm2 = x;
  949. pkm1 = x + 1.0;
  950. qkm1 = z * x;
  951. ans = pkm1 / qkm1;
  952. for (i = 0; i < MAXITER; i++) {
  953. c += 1.0;
  954. y += 1.0;
  955. z += 2.0;
  956. yc = y * c;
  957. pk = pkm1 * z - pkm2 * yc;
  958. qk = qkm1 * z - qkm2 * yc;
  959. if (qk != 0) {
  960. r = pk / qk;
  961. t = std::fabs((ans - r) / r);
  962. ans = r;
  963. }
  964. else {
  965. t = 1.0;
  966. }
  967. pkm2 = pkm1;
  968. pkm1 = pk;
  969. qkm2 = qkm1;
  970. qkm1 = qk;
  971. if (std::fabs(pk) > BIG) {
  972. pkm2 *= BIGINV;
  973. pkm1 *= BIGINV;
  974. qkm2 *= BIGINV;
  975. qkm1 *= BIGINV;
  976. }
  977. if (t <= MACHEP) {
  978. break;
  979. }
  980. }
  981. return ans * ax;
  982. }
  983. template <typename scalar_t>
  984. inline scalar_t calc_igammac(scalar_t a, scalar_t x) {
  985. /* the calculation of the regularized upper incomplete gamma function
  986. * is done differently based on the values of a and x:
  987. * - if x and/or a is at the boundary of defined region, then assign the
  988. * result at the boundary
  989. * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
  990. * Large Parameter (see DLMF 8.12.4 [igam1])
  991. * - if x > 1.1 and x < a, using the subtraction from the regularized lower
  992. * incomplete gamma
  993. * - otherwise, calculate the series from [igam2] eq (5)
  994. */
  995. scalar_t absxma_a;
  996. static scalar_t SMALL = 20.0;
  997. static scalar_t LARGE = 200.0;
  998. static scalar_t SMALLRATIO = 0.3;
  999. static scalar_t LARGERATIO = 4.5;
  1000. // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e.,
  1001. // at most 1 of them can be 0), where igammac(0, x) = 0.0 iff x > 0.
  1002. if ((x < 0) || (a < 0)) {
  1003. // out of defined-region of the function
  1004. return std::numeric_limits<scalar_t>::quiet_NaN();
  1005. }
  1006. else if (a == 0) {
  1007. if (x > 0) {
  1008. return 0.0;
  1009. }
  1010. else {
  1011. return std::numeric_limits<scalar_t>::quiet_NaN();
  1012. }
  1013. }
  1014. else if (x == 0) {
  1015. return 1.0;
  1016. }
  1017. else if (std::isinf(a)) {
  1018. if (std::isinf(x)) {
  1019. return std::numeric_limits<scalar_t>::quiet_NaN();
  1020. }
  1021. return 1.0;
  1022. }
  1023. else if (std::isinf(x)) {
  1024. return 0.0;
  1025. }
  1026. absxma_a = std::fabs(x - a) / a;
  1027. if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
  1028. return _igam_helper_asymptotic_series(a, x, 0);
  1029. }
  1030. else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) {
  1031. return _igam_helper_asymptotic_series(a, x, 0);
  1032. }
  1033. if (x > 1.1) {
  1034. if (x < a) {
  1035. return 1.0 - _igam_helper_series(a, x);
  1036. }
  1037. else {
  1038. return _igamc_helper_continued_fraction(a, x);
  1039. }
  1040. }
  1041. else if (x <= 0.5) {
  1042. if (-0.4 / std::log(x) < a) {
  1043. return 1.0 - _igam_helper_series(a, x);
  1044. }
  1045. else {
  1046. return _igamc_helper_series(a, x);
  1047. }
  1048. }
  1049. else {
  1050. if (x * 1.1 < a) {
  1051. return 1.0 - _igam_helper_series(a, x);
  1052. }
  1053. else {
  1054. return _igamc_helper_series(a, x);
  1055. }
  1056. }
  1057. }
  1058. template <typename scalar_t>
  1059. scalar_t calc_igamma(scalar_t a, scalar_t x) {
  1060. /* the calculation of the regularized lower incomplete gamma function
  1061. * is done differently based on the values of a and x:
  1062. * - if x and/or a is at the boundary of defined region, then assign the
  1063. * result at the boundary
  1064. * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
  1065. * Large Parameter (see DLMF 8.12.3 [igam1])
  1066. * - if x > 1 and x > a, using the subtraction from the regularized upper
  1067. * incomplete gamma
  1068. * - otherwise, calculate the series from [igam2] eq (4)
  1069. */
  1070. scalar_t absxma_a;
  1071. static scalar_t SMALL = 20.0;
  1072. static scalar_t LARGE = 200.0;
  1073. static scalar_t SMALLRATIO = 0.3;
  1074. static scalar_t LARGERATIO = 4.5;
  1075. // boundary values following SciPy
  1076. // note that in SciPy, a and x are non-negative, with exclusive 0s (i.e.,
  1077. // at most 1 of them can be 0), where igamma(0, x) = 1.0 iff x > 0.
  1078. if ((x < 0) || (a < 0)) {
  1079. // out of defined-region of the function
  1080. return std::numeric_limits<scalar_t>::quiet_NaN();
  1081. }
  1082. else if (a == 0) {
  1083. if (x > 0) {
  1084. return 1.0;
  1085. }
  1086. else {
  1087. return std::numeric_limits<scalar_t>::quiet_NaN();
  1088. }
  1089. }
  1090. else if (x == 0) {
  1091. return 0.0; // zero integration limit
  1092. }
  1093. else if (std::isinf(a)) {
  1094. if (std::isinf(x)) {
  1095. return std::numeric_limits<scalar_t>::quiet_NaN();
  1096. }
  1097. return 0.0;
  1098. }
  1099. else if (std::isinf(x)) {
  1100. return 1.0;
  1101. }
  1102. /* Asymptotic regime where a ~ x. See [igam2] */
  1103. absxma_a = std::fabs(x - a) / a;
  1104. if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
  1105. return _igam_helper_asymptotic_series(a, x, 1);
  1106. }
  1107. else if ((a > LARGE) && (absxma_a < LARGERATIO / std::sqrt(a))) {
  1108. return _igam_helper_asymptotic_series(a, x, 1);
  1109. }
  1110. if ((x > 1.0) && (x > a)) {
  1111. return 1.0 - calc_igammac(a, x);
  1112. }
  1113. return _igam_helper_series(a, x);
  1114. }
  1115. template <>
  1116. [[maybe_unused]] inline c10::BFloat16 calc_igamma<c10::BFloat16>(
  1117. c10::BFloat16 a,
  1118. c10::BFloat16 x) {
  1119. return calc_igamma<float>(float(a), float(x));
  1120. }
  1121. template <>
  1122. [[maybe_unused]] inline c10::Half calc_igamma<c10::Half>(
  1123. c10::Half a,
  1124. c10::Half x) {
  1125. return calc_igamma<float>(float(a), float(x));
  1126. }
  1127. template <>
  1128. [[maybe_unused]] inline c10::BFloat16 calc_igammac<c10::BFloat16>(
  1129. c10::BFloat16 a,
  1130. c10::BFloat16 x) {
  1131. return calc_igammac<float>(float(a), float(x));
  1132. }
  1133. template <>
  1134. [[maybe_unused]] inline c10::Half calc_igammac<c10::Half>(
  1135. c10::Half a,
  1136. c10::Half x) {
  1137. return calc_igammac<float>(float(a), float(x));
  1138. }
  1139. inline c10::BFloat16 calc_erfinv(c10::BFloat16 a) { return calc_erfinv(float(a)); }
  1140. template <typename T>
  1141. inline T abs_impl(T v) {
  1142. return std::abs(v);
  1143. }
  1144. template <>
  1145. [[maybe_unused]] inline uint8_t abs_impl(uint8_t v) {
  1146. return v;
  1147. }
  1148. template <typename T>
  1149. inline typename std::enable_if_t<std::is_integral_v<T>, T>
  1150. calc_gcd(T a, T b) {
  1151. a = abs_impl(a);
  1152. b = abs_impl(b);
  1153. while (a != 0) {
  1154. T c = a;
  1155. a = b % a;
  1156. b = c;
  1157. }
  1158. return b;
  1159. }
  1160. template <typename T>
  1161. C10_HOST_DEVICE T exp2_impl(T x) {
  1162. return std::exp2(x);
  1163. }
  1164. template <typename T>
  1165. C10_HOST_DEVICE c10::complex<T> exp2_impl(c10::complex<T> x) {
  1166. // There is no std::exp2 overload for complex, so instead
  1167. // use the identity 2^x = e^(ln(2) * x)
  1168. constexpr auto ln2 = c10::ln_2<T>;
  1169. return std::exp(ln2 * x);
  1170. }
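/*
 * This function is derived from the implementation of the chbevl function in the Cephes Math Library.
 * See note [3-Clause BSD License for the Cephes Math Library].
 *
 * Evaluates the series
 *
 *   y = {\sum_{i=0}^{len-1}}' c_i T_i(x/2),   with c_i = array[len-1-i],
 *
 * of Chebyshev polynomials T_i at argument x/2. The prime on the sum means the zero-order term
 * (c_0 T_0) is halved.
 *
 * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of
 * coefficients, not the order.
 *
 * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before
 * entering the routine. This maps x from (a, b) to (-2, 2), so the polynomial argument x/2 lies in (-1, 1), over
 * which the Chebyshev polynomials are defined.
 *
 * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation
 * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 2 (e.g. 32/x - 2 for a = 8).
 */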
  1184. *
  1185. * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of
  1186. * coefficients, not the order.
  1187. *
  1188. * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before
  1189. * entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined.
  1190. *
  1191. * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation
  1192. * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1.
  1193. */
  1194. template <typename T>
  1195. inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  1196. chbevl(const T x, const T array[], size_t len) {
  1197. T b0, b1, b2 = static_cast<T>(0.0);
  1198. b0 = array[0];
  1199. b1 = static_cast<T>(0.0);
  1200. for (size_t i = 1; i < len; ++i) {
  1201. b2 = b1;
  1202. b1 = b0;
  1203. b0 = x * b1 - b2 + array[i];
  1204. }
  1205. return (static_cast<T>(0.5) * (b0 - b2));
  1206. }
  1207. /*
  1208. * This function is derived from the implementation of the i0 function in the Cephes Math Library.
  1209. * See note [3-Clause BSD License for the Cephes Math Library].
  1210. *
  1211. * Computes an approximation of the zeroth order modified Bessel function of the first kind.
  1212. * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
  1213. * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
  1214. * of all inputs to convert them into the domain of the approximation.
  1215. */
  1216. template <typename T>
  1217. inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_A() {
  1218. /* Chebyshev coefficients for exp(-x) I0(x)
  1219. * in the interval [0,8].
  1220. *
  1221. * lim(x->0){ exp(-x) I0(x) } = 1.
  1222. */
  1223. static const T coeff[] = {
  1224. -4.41534164647933937950E-18, 3.33079451882223809783E-17,
  1225. -2.43127984654795469359E-16, 1.71539128555513303061E-15,
  1226. -1.16853328779934516808E-14, 7.67618549860493561688E-14,
  1227. -4.85644678311192946090E-13, 2.95505266312963983461E-12,
  1228. -1.72682629144155570723E-11, 9.67580903537323691224E-11,
  1229. -5.18979560163526290666E-10, 2.65982372468238665035E-9,
  1230. -1.30002500998624804212E-8, 6.04699502254191894932E-8,
  1231. -2.67079385394061173391E-7, 1.11738753912010371815E-6,
  1232. -4.41673835845875056359E-6, 1.64484480707288970893E-5,
  1233. -5.75419501008210370398E-5, 1.88502885095841655729E-4,
  1234. -5.76375574538582365885E-4, 1.63947561694133579842E-3,
  1235. -4.32430999505057594430E-3, 1.05464603945949983183E-2,
  1236. -2.37374148058994688156E-2, 4.93052842396707084878E-2,
  1237. -9.49010970480476444210E-2, 1.71620901522208775349E-1,
  1238. -3.04682672343198398683E-1, 6.76795274409476084995E-1};
  1239. return std::make_tuple(coeff, 30);
  1240. }
  1241. template <typename T>
  1242. inline std::tuple<const T*, size_t> chebyshev_coefficients_i0e_B() {
  1243. /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x)
  1244. * in the inverted interval [8,infinity].
  1245. *
  1246. * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi).
  1247. */
  1248. static const T coeff[] = {
  1249. -7.23318048787475395456E-18, -4.83050448594418207126E-18,
  1250. 4.46562142029675999901E-17, 3.46122286769746109310E-17,
  1251. -2.82762398051658348494E-16, -3.42548561967721913462E-16,
  1252. 1.77256013305652638360E-15, 3.81168066935262242075E-15,
  1253. -9.55484669882830764870E-15, -4.15056934728722208663E-14,
  1254. 1.54008621752140982691E-14, 3.85277838274214270114E-13,
  1255. 7.18012445138366623367E-13, -1.79417853150680611778E-12,
  1256. -1.32158118404477131188E-11, -3.14991652796324136454E-11,
  1257. 1.18891471078464383424E-11, 4.94060238822496958910E-10,
  1258. 3.39623202570838634515E-9, 2.26666899049817806459E-8,
  1259. 2.04891858946906374183E-7, 2.89137052083475648297E-6,
  1260. 6.88975834691682398426E-5, 3.36911647825569408990E-3,
  1261. 8.04490411014108831608E-1};
  1262. return std::make_tuple(coeff, 25);
  1263. }
  1264. template <typename T>
  1265. inline typename std::enable_if_t<std::is_same_v<double, T>, std::tuple<const T*, size_t>>
  1266. chebyshev_coefficients_i1e_A() {
  1267. /* Chebyshev coefficients for exp(-x) I1(x)
  1268. * in the interval [0,8].
  1269. *
  1270. * lim(x->0){ exp(-x) I1(x) / x } = 1/2.
  1271. */
  1272. static const T coeff[] = {
  1273. 2.77791411276104639959E-18, -2.11142121435816608115E-17,
  1274. 1.55363195773620046921E-16, -1.10559694773538630805E-15,
  1275. 7.60068429473540693410E-15, -5.04218550472791168711E-14,
  1276. 3.22379336594557470981E-13, -1.98397439776494371520E-12,
  1277. 1.17361862988909016308E-11, -6.66348972350202774223E-11,
  1278. 3.62559028155211703701E-10, -1.88724975172282928790E-9,
  1279. 9.38153738649577178388E-9, -4.44505912879632808065E-8,
  1280. 2.00329475355213526229E-7, -8.56872026469545474066E-7,
  1281. 3.47025130813767847674E-6, -1.32731636560394358279E-5,
  1282. 4.78156510755005422638E-5, -1.61760815825896745588E-4,
  1283. 5.12285956168575772895E-4, -1.51357245063125314899E-3,
  1284. 4.15642294431288815669E-3, -1.05640848946261981558E-2,
  1285. 2.47264490306265168283E-2, -5.29459812080949914269E-2,
  1286. 1.02643658689847095384E-1, -1.76416518357834055153E-1,
  1287. 2.52587186443633654823E-1};
  1288. return std::make_tuple(coeff, 29);
  1289. }
  1290. template <typename T>
  1291. inline typename std::enable_if_t<std::is_same_v<float, T>, std::tuple<const T*, size_t>>
  1292. chebyshev_coefficients_i1e_A() {
  1293. /* Chebyshev coefficients for exp(-x) I1(x)
  1294. * in the interval [0,8].
  1295. *
  1296. * lim(x->0){ exp(-x) I1(x) / x } = 1/2.
  1297. */
  1298. static const T coeff[] = {
  1299. 9.38153738649577178388E-9f,
  1300. -4.44505912879632808065E-8f,
  1301. 2.00329475355213526229E-7f,
  1302. -8.56872026469545474066E-7f,
  1303. 3.47025130813767847674E-6f,
  1304. -1.32731636560394358279E-5f,
  1305. 4.78156510755005422638E-5f,
  1306. -1.61760815825896745588E-4f,
  1307. 5.12285956168575772895E-4f,
  1308. -1.51357245063125314899E-3f,
  1309. 4.15642294431288815669E-3f,
  1310. -1.05640848946261981558E-2f,
  1311. 2.47264490306265168283E-2f,
  1312. -5.29459812080949914269E-2f,
  1313. 1.02643658689847095384E-1f,
  1314. -1.76416518357834055153E-1f,
  1315. 2.52587186443633654823E-1f};
  1316. return std::make_tuple(coeff, 17);
  1317. }
  1318. template <typename T>
  1319. inline typename std::enable_if_t<std::is_same_v<double, T>, std::tuple<const T*, size_t>>
  1320. chebyshev_coefficients_i1e_B() {
  1321. /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
  1322. * in the inverted interval [8,infinity].
  1323. *
  1324. * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
  1325. */
  1326. static const T coeff[] = {
  1327. 7.51729631084210481353E-18, 4.41434832307170791151E-18,
  1328. -4.65030536848935832153E-17, -3.20952592199342395980E-17,
  1329. 2.96262899764595013876E-16, 3.30820231092092828324E-16,
  1330. -1.88035477551078244854E-15, -3.81440307243700780478E-15,
  1331. 1.04202769841288027642E-14, 4.27244001671195135429E-14,
  1332. -2.10154184277266431302E-14, -4.08355111109219731823E-13,
  1333. -7.19855177624590851209E-13, 2.03562854414708950722E-12,
  1334. 1.41258074366137813316E-11, 3.25260358301548823856E-11,
  1335. -1.89749581235054123450E-11, -5.58974346219658380687E-10,
  1336. -3.83538038596423702205E-9, -2.63146884688951950684E-8,
  1337. -2.51223623787020892529E-7, -3.88256480887769039346E-6,
  1338. -1.10588938762623716291E-4, -9.76109749136146840777E-3,
  1339. 7.78576235018280120474E-1};
  1340. return std::make_tuple(coeff, 25);
  1341. }
  1342. template <typename T>
  1343. inline typename std::enable_if_t<std::is_same_v<float, T>, std::tuple<const T*, size_t>>
  1344. chebyshev_coefficients_i1e_B() {
  1345. /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x)
  1346. * in the inverted interval [8,infinity].
  1347. *
  1348. * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi).
  1349. */
  1350. static const T coeff[] = {
  1351. -3.83538038596423702205E-9f,
  1352. -2.63146884688951950684E-8f,
  1353. -2.51223623787020892529E-7f,
  1354. -3.88256480887769039346E-6f,
  1355. -1.10588938762623716291E-4f,
  1356. -9.76109749136146840777E-3f,
  1357. 7.78576235018280120474E-1f};
  1358. return std::make_tuple(coeff, 7);
  1359. }
  1360. template <typename T>
  1361. inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  1362. calc_i0(T _x) {
  1363. T x = std::abs(_x);
  1364. if (x <= T{8.0}) {
  1365. auto [A, len] = chebyshev_coefficients_i0e_A<T>();
  1366. T y = (x / T{2.0}) - T{2.0};
  1367. return static_cast<T>(std::exp(x) * chbevl(y, A, len));
  1368. }
  1369. auto [B, len] = chebyshev_coefficients_i0e_B<T>();
  1370. return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
  1371. }
  1372. // Upcast bfloat16/half input to float for numerical accuracy purposes
  1373. inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast<float>(a)); }
  1374. inline c10::Half calc_i0(c10::Half a) { return calc_i0(static_cast<float>(a)); }
  1375. /*
  1376. * This function is derived from the implementation of the i1 function in the Cephes Math Library.
  1377. * See note [3-Clause BSD License for the Cephes Math Library].
  1378. *
  1379. * Computes an approximation of the first order modified Bessel function of the first kind.
  1380. * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
  1381. * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
  1382. * of all inputs to convert them into the domain of the approximation.
  1383. */
  1384. template <typename T>
  1385. inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  1386. calc_i1(T _x) {
  1387. T x = std::abs(_x);
  1388. if (x <= T{8.0}) {
  1389. auto [A, len] = chebyshev_coefficients_i1e_A<T>();
  1390. T y = (x / T{2.0}) - T{2.0};
  1391. const T out = std::exp(x) * x * chbevl(y, A, len);
  1392. return (_x < T{0.0}) ? -out : out;
  1393. }
  1394. auto [B, len] = chebyshev_coefficients_i1e_B<T>();
  1395. const T out = (std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len)) / std::sqrt(x);
  1396. return (_x < T{0.0}) ? -out : out;
  1397. }
  1398. // Upcast bfloat16/half input to float for numerical accuracy purposes
  1399. inline c10::BFloat16 calc_i1(c10::BFloat16 a) { return calc_i1(static_cast<float>(a)); }
  1400. inline c10::Half calc_i1(c10::Half a) { return calc_i1(static_cast<float>(a)); }
  1401. /*
  1402. * This function is derived from the implementation of the i1e function in the Cephes Math Library.
  1403. * See note [3-Clause BSD License for the Cephes Math Library].
  1404. *
  1405. * Computes an approximation of the exponentially scaled first order modified Bessel function of the first kind.
  1406. * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion.
  1407. * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value
  1408. * of all inputs to convert them into the domain of the approximation.
  1409. */
  1410. template <typename T>
  1411. inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  1412. calc_i1e(T _x) {
  1413. T x = std::abs(_x);
  1414. if (x <= T{8.0}) {
  1415. auto [A, len] = chebyshev_coefficients_i1e_A<T>();
  1416. T y = (x / T{2.0}) - T{2.0};
  1417. const T out = chbevl(y, A, len) * x;
  1418. return (_x < T{0.0}) ? -out : out;
  1419. }
  1420. auto [B, len] = chebyshev_coefficients_i1e_B<T>();
  1421. const auto out = chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x);
  1422. return (_x < T{0.0}) ? -out : out;
  1423. }
  1424. // Upcast bfloat16/half input to float for numerical accuracy purposes
  1425. inline c10::BFloat16 calc_i1e(c10::BFloat16 a) { return calc_i1e(static_cast<float>(a)); }
  1426. inline c10::Half calc_i1e(c10::Half a) { return calc_i1e(static_cast<float>(a)); }
  1427. /*
  1428. * This function is derived from the implementation of the i1e function in the Cephes Math Library.
  1429. * See note [3-Clause BSD License for the Cephes Math Library].
  1430. *
  1431. * Computes the argument, x, for which the area under the Gaussian probability density function
  1432. * (integrated from minus infinity to x) is equal to y.
  1433. */
  1434. template <typename T>
  1435. inline C10_HOST_DEVICE T calc_ndtri(T y0) {
  1436. /* sqrt(2pi) */
  1437. constexpr T s2pi = 2.50662827463100050242E0;
  1438. constexpr T one = 1;
  1439. constexpr T zero = 0;
  1440. /* approximation for 0 <= |y - 0.5| <= 3/8 */
  1441. static const T P0[5] = {
  1442. -5.99633501014107895267E1,
  1443. 9.80010754185999661536E1,
  1444. -5.66762857469070293439E1,
  1445. 1.39312609387279679503E1,
  1446. -1.23916583867381258016E0,
  1447. };
  1448. static const T Q0[9] = {
  1449. 1.00000000000000000000E0,
  1450. 1.95448858338141759834E0,
  1451. 4.67627912898881538453E0,
  1452. 8.63602421390890590575E1,
  1453. -2.25462687854119370527E2,
  1454. 2.00260212380060660359E2,
  1455. -8.20372256168333339912E1,
  1456. 1.59056225126211695515E1,
  1457. -1.18331621121330003142E0,
  1458. };
  1459. /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
  1460. * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
  1461. */
  1462. static const T P1[9] = {
  1463. 4.05544892305962419923E0,
  1464. 3.15251094599893866154E1,
  1465. 5.71628192246421288162E1,
  1466. 4.40805073893200834700E1,
  1467. 1.46849561928858024014E1,
  1468. 2.18663306850790267539E0,
  1469. -1.40256079171354495875E-1,
  1470. -3.50424626827848203418E-2,
  1471. -8.57456785154685413611E-4,
  1472. };
  1473. static const T Q1[9] = {
  1474. 1.00000000000000000000E0,
  1475. 1.57799883256466749731E1,
  1476. 4.53907635128879210584E1,
  1477. 4.13172038254672030440E1,
  1478. 1.50425385692907503408E1,
  1479. 2.50464946208309415979E0,
  1480. -1.42182922854787788574E-1,
  1481. -3.80806407691578277194E-2,
  1482. -9.33259480895457427372E-4,
  1483. };
  1484. /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
  1485. * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
  1486. */
  1487. static const T P2[9] = {
  1488. 3.23774891776946035970E0,
  1489. 6.91522889068984211695E0,
  1490. 3.93881025292474443415E0,
  1491. 1.33303460815807542389E0,
  1492. 2.01485389549179081538E-1,
  1493. 1.23716634817820021358E-2,
  1494. 3.01581553508235416007E-4,
  1495. 2.65806974686737550832E-6,
  1496. 6.23974539184983293730E-9,
  1497. };
  1498. static const T Q2[9] = {
  1499. 1.00000000000000000000E0,
  1500. 6.02427039364742014255E0,
  1501. 3.67983563856160859403E0,
  1502. 1.37702099489081330271E0,
  1503. 2.16236993594496635890E-1,
  1504. 1.34204006088543189037E-2,
  1505. 3.28014464682127739104E-4,
  1506. 2.89247864745380683936E-6,
  1507. 6.79019408009981274425E-9,
  1508. };
  1509. if (y0 == zero) {
  1510. return -std::numeric_limits<T>::infinity();
  1511. }
  1512. if (y0 == one) {
  1513. return std::numeric_limits<T>::infinity();
  1514. }
  1515. if (y0 < zero || y0 > one) {
  1516. return std::numeric_limits<T>::quiet_NaN();
  1517. }
  1518. bool code = true;
  1519. T y = y0;
  1520. if (y > one - T{0.13533528323661269189}) { /* 0.135... = exp(-2) */
  1521. y = one - y;
  1522. code = false;
  1523. }
  1524. if (y > T{0.13533528323661269189}) {
  1525. y = y - T{0.5};
  1526. const T y2 = y * y;
  1527. T x = y + y * (y2 * polevl(y2, P0, 4) / polevl(y2, Q0, 8));
  1528. return (x * s2pi);
  1529. }
  1530. T x = ::sqrt(T{-2.0} * ::log(y));
  1531. const T x0 = x - ::log(x) / x;
  1532. const T z = one / x;
  1533. T x1;
  1534. if (x < T{8.0}) /* y > exp(-32) = 1.2664165549e-14 */
  1535. {
  1536. x1 = z * polevl(z, P1, 8) / polevl(z, Q1, 8);
  1537. } else {
  1538. x1 = z * polevl(z, P2, 8) / polevl(z, Q2, 8);
  1539. }
  1540. x = x0 - x1;
  1541. if (code) {
  1542. x = -x;
  1543. }
  1544. return x;
  1545. }
  1546. /* The next function is taken from http://ab-initio.mit.edu/faddeeva */
  1547. /* Copyright (c) 2012 Massachusetts Institute of Technology
  1548. *
  1549. * Permission is hereby granted, free of charge, to any person obtaining
  1550. * a copy of this software and associated documentation files (the
  1551. * "Software"), to deal in the Software without restriction, including
  1552. * without limitation the rights to use, copy, modify, merge, publish,
  1553. * distribute, sublicense, and/or sell copies of the Software, and to
  1554. * permit persons to whom the Software is furnished to do so, subject to
  1555. * the following conditions:
  1556. *
  1557. * The above copyright notice and this permission notice shall be
  1558. * included in all copies or substantial portions of the Software.
  1559. *
  1560. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
  1561. * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
  1562. * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
  1563. * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
  1564. * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
  1565. * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
  1566. * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  1567. */
  1568. /* erfcx(x) = exp(x^2) erfc(x) function, for real x, written by
  1569. Steven G. Johnson, October 2012.
  1570. This function combines a few different ideas.
  1571. First, for x > 50, it uses a continued-fraction expansion (same as
  1572. for the Faddeeva function, but with algebraic simplifications for z=i*x).
  1573. Second, for 0 <= x <= 50, it uses Chebyshev polynomial approximations,
  1574. but with two twists:
  1575. a) It maps x to y = 4 / (4+x) in [0,1]. This simple transformation,
  1576. inspired by a similar transformation in the octave-forge/specfun
  1577. erfcx by Soren Hauberg, results in much faster Chebyshev convergence
  1578. than other simple transformations I have examined.
  1579. b) Instead of using a single Chebyshev polynomial for the entire
  1580. [0,1] y interval, we break the interval up into 100 equal
  1581. subintervals, with a switch/lookup table, and use much lower
  1582. degree Chebyshev polynomials in each subinterval. This greatly
  1583. improves performance in my tests.
  1584. For x < 0, we use the relationship erfcx(-x) = 2 exp(x^2) - erfc(x),
  1585. with the usual checks for overflow etcetera.
  1586. Performance-wise, it seems to be substantially faster than either
  1587. the SLATEC DERFC function [or an erfcx function derived there from]
  1588. or Cody's CALERF function (from netlib.org/specfun), while
  1589. retaining near machine precision in accuracy. */
  1590. /* Given y100=100*y, where y = 4/(4+x) for x >= 0, compute erfc(x).
  1591. Uses a look-up table of 100 different Chebyshev polynomials
  1592. for y intervals [0,0.01], [0.01,0.02], ...., [0.99,1], generated
  1593. with the help of Maple and a little shell script. This allows
  1594. the Chebyshev polynomials to be of significantly lower degree (about 1/4)
  1595. compared to fitting the whole [0,1] interval with a single polynomial. */
  1596. template <typename T>
  1597. C10_HOST_DEVICE inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  1598. erfcx_y100(T y100)
  1599. {
  1600. switch (static_cast<int>(y100)) {
  1601. case 0: {
  1602. T t = 2*y100 - 1;
  1603. return 0.70878032454106438663e-3 + (0.71234091047026302958e-3 + (0.35779077297597742384e-5 + (0.17403143962587937815e-7 + (0.81710660047307788845e-10 + (0.36885022360434957634e-12 + 0.15917038551111111111e-14 * t) * t) * t) * t) * t) * t;
  1604. }
  1605. case 1: {
  1606. T t = 2*y100 - 3;
  1607. return 0.21479143208285144230e-2 + (0.72686402367379996033e-3 + (0.36843175430938995552e-5 + (0.18071841272149201685e-7 + (0.85496449296040325555e-10 + (0.38852037518534291510e-12 + 0.16868473576888888889e-14 * t) * t) * t) * t) * t) * t;
  1608. }
  1609. case 2: {
  1610. T t = 2*y100 - 5;
  1611. return 0.36165255935630175090e-2 + (0.74182092323555510862e-3 + (0.37948319957528242260e-5 + (0.18771627021793087350e-7 + (0.89484715122415089123e-10 + (0.40935858517772440862e-12 + 0.17872061464888888889e-14 * t) * t) * t) * t) * t) * t;
  1612. }
  1613. case 3: {
  1614. T t = 2*y100 - 7;
  1615. return 0.51154983860031979264e-2 + (0.75722840734791660540e-3 + (0.39096425726735703941e-5 + (0.19504168704300468210e-7 + (0.93687503063178993915e-10 + (0.43143925959079664747e-12 + 0.18939926435555555556e-14 * t) * t) * t) * t) * t) * t;
  1616. }
  1617. case 4: {
  1618. T t = 2*y100 - 9;
  1619. return 0.66457513172673049824e-2 + (0.77310406054447454920e-3 + (0.40289510589399439385e-5 + (0.20271233238288381092e-7 + (0.98117631321709100264e-10 + (0.45484207406017752971e-12 + 0.20076352213333333333e-14 * t) * t) * t) * t) * t) * t;
  1620. }
  1621. case 5: {
  1622. T t = 2*y100 - 11;
  1623. return 0.82082389970241207883e-2 + (0.78946629611881710721e-3 + (0.41529701552622656574e-5 + (0.21074693344544655714e-7 + (0.10278874108587317989e-9 + (0.47965201390613339638e-12 + 0.21285907413333333333e-14 * t) * t) * t) * t) * t) * t;
  1624. }
  1625. case 6: {
  1626. T t = 2*y100 - 13;
  1627. return 0.98039537275352193165e-2 + (0.80633440108342840956e-3 + (0.42819241329736982942e-5 + (0.21916534346907168612e-7 + (0.10771535136565470914e-9 + (0.50595972623692822410e-12 + 0.22573462684444444444e-14 * t) * t) * t) * t) * t) * t;
  1628. }
  1629. case 7: {
  1630. T t = 2*y100 - 15;
  1631. return 0.11433927298290302370e-1 + (0.82372858383196561209e-3 + (0.44160495311765438816e-5 + (0.22798861426211986056e-7 + (0.11291291745879239736e-9 + (0.53386189365816880454e-12 + 0.23944209546666666667e-14 * t) * t) * t) * t) * t) * t;
  1632. }
  1633. case 8: {
  1634. T t = 2*y100 - 17;
  1635. return 0.13099232878814653979e-1 + (0.84167002467906968214e-3 + (0.45555958988457506002e-5 + (0.23723907357214175198e-7 + (0.11839789326602695603e-9 + (0.56346163067550237877e-12 + 0.25403679644444444444e-14 * t) * t) * t) * t) * t) * t;
  1636. }
  1637. case 9: {
  1638. T t = 2*y100 - 19;
  1639. return 0.14800987015587535621e-1 + (0.86018092946345943214e-3 + (0.47008265848816866105e-5 + (0.24694040760197315333e-7 + (0.12418779768752299093e-9 + (0.59486890370320261949e-12 + 0.26957764568888888889e-14 * t) * t) * t) * t) * t) * t;
  1640. }
  1641. case 10: {
  1642. T t = 2*y100 - 21;
  1643. return 0.16540351739394069380e-1 + (0.87928458641241463952e-3 + (0.48520195793001753903e-5 + (0.25711774900881709176e-7 + (0.13030128534230822419e-9 + (0.62820097586874779402e-12 + 0.28612737351111111111e-14 * t) * t) * t) * t) * t) * t;
  1644. }
  1645. case 11: {
  1646. T t = 2*y100 - 23;
  1647. return 0.18318536789842392647e-1 + (0.89900542647891721692e-3 + (0.50094684089553365810e-5 + (0.26779777074218070482e-7 + (0.13675822186304615566e-9 + (0.66358287745352705725e-12 + 0.30375273884444444444e-14 * t) * t) * t) * t) * t) * t;
  1648. }
  1649. case 12: {
  1650. T t = 2*y100 - 25;
  1651. return 0.20136801964214276775e-1 + (0.91936908737673676012e-3 + (0.51734830914104276820e-5 + (0.27900878609710432673e-7 + (0.14357976402809042257e-9 + (0.70114790311043728387e-12 + 0.32252476000000000000e-14 * t) * t) * t) * t) * t) * t;
  1652. }
  1653. case 13: {
  1654. T t = 2*y100 - 27;
  1655. return 0.21996459598282740954e-1 + (0.94040248155366777784e-3 + (0.53443911508041164739e-5 + (0.29078085538049374673e-7 + (0.15078844500329731137e-9 + (0.74103813647499204269e-12 + 0.34251892320000000000e-14 * t) * t) * t) * t) * t) * t;
  1656. }
  1657. case 14: {
  1658. T t = 2*y100 - 29;
  1659. return 0.23898877187226319502e-1 + (0.96213386835900177540e-3 + (0.55225386998049012752e-5 + (0.30314589961047687059e-7 + (0.15840826497296335264e-9 + (0.78340500472414454395e-12 + 0.36381553564444444445e-14 * t) * t) * t) * t) * t) * t;
  1660. }
  1661. case 15: {
  1662. T t = 2*y100 - 31;
  1663. return 0.25845480155298518485e-1 + (0.98459293067820123389e-3 + (0.57082915920051843672e-5 + (0.31613782169164830118e-7 + (0.16646478745529630813e-9 + (0.82840985928785407942e-12 + 0.38649975768888888890e-14 * t) * t) * t) * t) * t) * t;
  1664. }
  1665. case 16: {
  1666. T t = 2*y100 - 33;
  1667. return 0.27837754783474696598e-1 + (0.10078108563256892757e-2 + (0.59020366493792212221e-5 + (0.32979263553246520417e-7 + (0.17498524159268458073e-9 + (0.87622459124842525110e-12 + 0.41066206488888888890e-14 * t) * t) * t) * t) * t) * t;
  1668. }
  1669. case 17: {
  1670. T t = 2*y100 - 35;
  1671. return 0.29877251304899307550e-1 + (0.10318204245057349310e-2 + (0.61041829697162055093e-5 + (0.34414860359542720579e-7 + (0.18399863072934089607e-9 + (0.92703227366365046533e-12 + 0.43639844053333333334e-14 * t) * t) * t) * t) * t) * t;
  1672. }
  1673. case 18: {
  1674. T t = 2*y100 - 37;
  1675. return 0.31965587178596443475e-1 + (0.10566560976716574401e-2 + (0.63151633192414586770e-5 + (0.35924638339521924242e-7 + (0.19353584758781174038e-9 + (0.98102783859889264382e-12 + 0.46381060817777777779e-14 * t) * t) * t) * t) * t) * t;
  1676. }
  1677. case 19: {
  1678. T t = 2*y100 - 39;
  1679. return 0.34104450552588334840e-1 + (0.10823541191350532574e-2 + (0.65354356159553934436e-5 + (0.37512918348533521149e-7 + (0.20362979635817883229e-9 + (0.10384187833037282363e-11 + 0.49300625262222222221e-14 * t) * t) * t) * t) * t) * t;
  1680. }
  1681. case 20: {
  1682. T t = 2*y100 - 41;
  1683. return 0.36295603928292425716e-1 + (0.11089526167995268200e-2 + (0.67654845095518363577e-5 + (0.39184292949913591646e-7 + (0.21431552202133775150e-9 + (0.10994259106646731797e-11 + 0.52409949102222222221e-14 * t) * t) * t) * t) * t) * t;
  1684. }
  1685. case 21: {
  1686. T t = 2*y100 - 43;
  1687. return 0.38540888038840509795e-1 + (0.11364917134175420009e-2 + (0.70058230641246312003e-5 + (0.40943644083718586939e-7 + (0.22563034723692881631e-9 + (0.11642841011361992885e-11 + 0.55721092871111111110e-14 * t) * t) * t) * t) * t) * t;
  1688. }
  1689. case 22: {
  1690. T t = 2*y100 - 45;
  1691. return 0.40842225954785960651e-1 + (0.11650136437945673891e-2 + (0.72569945502343006619e-5 + (0.42796161861855042273e-7 + (0.23761401711005024162e-9 + (0.12332431172381557035e-11 + 0.59246802364444444445e-14 * t) * t) * t) * t) * t) * t;
  1692. }
  1693. case 23: {
  1694. T t = 2*y100 - 47;
  1695. return 0.43201627431540222422e-1 + (0.11945628793917272199e-2 + (0.75195743532849206263e-5 + (0.44747364553960993492e-7 + (0.25030885216472953674e-9 + (0.13065684400300476484e-11 + 0.63000532853333333334e-14 * t) * t) * t) * t) * t) * t;
  1696. }
  1697. case 24: {
  1698. T t = 2*y100 - 49;
  1699. return 0.45621193513810471438e-1 + (0.12251862608067529503e-2 + (0.77941720055551920319e-5 + (0.46803119830954460212e-7 + (0.26375990983978426273e-9 + (0.13845421370977119765e-11 + 0.66996477404444444445e-14 * t) * t) * t) * t) * t) * t;
  1700. }
  1701. case 25: {
  1702. T t = 2*y100 - 51;
  1703. return 0.48103121413299865517e-1 + (0.12569331386432195113e-2 + (0.80814333496367673980e-5 + (0.48969667335682018324e-7 + (0.27801515481905748484e-9 + (0.14674637611609884208e-11 + 0.71249589351111111110e-14 * t) * t) * t) * t) * t) * t;
  1704. }
  1705. case 26: {
  1706. T t = 2*y100 - 53;
  1707. return 0.50649709676983338501e-1 + (0.12898555233099055810e-2 + (0.83820428414568799654e-5 + (0.51253642652551838659e-7 + (0.29312563849675507232e-9 + (0.15556512782814827846e-11 + 0.75775607822222222221e-14 * t) * t) * t) * t) * t) * t;
  1708. }
  1709. case 27: {
  1710. T t = 2*y100 - 55;
  1711. return 0.53263363664388864181e-1 + (0.13240082443256975769e-2 + (0.86967260015007658418e-5 + (0.53662102750396795566e-7 + (0.30914568786634796807e-9 + (0.16494420240828493176e-11 + 0.80591079644444444445e-14 * t) * t) * t) * t) * t) * t;
  1712. }
  1713. case 28: {
  1714. T t = 2*y100 - 57;
  1715. return 0.55946601353500013794e-1 + (0.13594491197408190706e-2 + (0.90262520233016380987e-5 + (0.56202552975056695376e-7 + (0.32613310410503135996e-9 + (0.17491936862246367398e-11 + 0.85713381688888888890e-14 * t) * t) * t) * t) * t) * t;
  1716. }
  1717. case 29: {
  1718. T t = 2*y100 - 59;
  1719. return 0.58702059496154081813e-1 + (0.13962391363223647892e-2 + (0.93714365487312784270e-5 + (0.58882975670265286526e-7 + (0.34414937110591753387e-9 + (0.18552853109751857859e-11 + 0.91160736711111111110e-14 * t) * t) * t) * t) * t) * t;
  1720. }
  1721. case 30: {
  1722. T t = 2*y100 - 61;
  1723. return 0.61532500145144778048e-1 + (0.14344426411912015247e-2 + (0.97331446201016809696e-5 + (0.61711860507347175097e-7 + (0.36325987418295300221e-9 + (0.19681183310134518232e-11 + 0.96952238400000000000e-14 * t) * t) * t) * t) * t) * t;
  1724. }
  1725. case 31: {
  1726. T t = 2*y100 - 63;
  1727. return 0.64440817576653297993e-1 + (0.14741275456383131151e-2 + (0.10112293819576437838e-4 + (0.64698236605933246196e-7 + (0.38353412915303665586e-9 + (0.20881176114385120186e-11 + 0.10310784480000000000e-13 * t) * t) * t) * t) * t) * t;
  1728. }
  1729. case 32: {
  1730. T t = 2*y100 - 65;
  1731. return 0.67430045633130393282e-1 + (0.15153655418916540370e-2 + (0.10509857606888328667e-4 + (0.67851706529363332855e-7 + (0.40504602194811140006e-9 + (0.22157325110542534469e-11 + 0.10964842115555555556e-13 * t) * t) * t) * t) * t) * t;
  1732. }
  1733. case 33: {
  1734. T t = 2*y100 - 67;
  1735. return 0.70503365513338850709e-1 + (0.15582323336495709827e-2 + (0.10926868866865231089e-4 + (0.71182482239613507542e-7 + (0.42787405890153386710e-9 + (0.23514379522274416437e-11 + 0.11659571751111111111e-13 * t) * t) * t) * t) * t) * t;
  1736. }
  1737. case 34: {
  1738. T t = 2*y100 - 69;
  1739. return 0.73664114037944596353e-1 + (0.16028078812438820413e-2 + (0.11364423678778207991e-4 + (0.74701423097423182009e-7 + (0.45210162777476488324e-9 + (0.24957355004088569134e-11 + 0.12397238257777777778e-13 * t) * t) * t) * t) * t) * t;
  1740. }
  1741. case 35: {
  1742. T t = 2*y100 - 71;
  1743. return 0.76915792420819562379e-1 + (0.16491766623447889354e-2 + (0.11823685320041302169e-4 + (0.78420075993781544386e-7 + (0.47781726956916478925e-9 + (0.26491544403815724749e-11 + 0.13180196462222222222e-13 * t) * t) * t) * t) * t) * t;
  1744. }
  1745. case 36: {
  1746. T t = 2*y100 - 73;
  1747. return 0.80262075578094612819e-1 + (0.16974279491709504117e-2 + (0.12305888517309891674e-4 + (0.82350717698979042290e-7 + (0.50511496109857113929e-9 + (0.28122528497626897696e-11 + 0.14010889635555555556e-13 * t) * t) * t) * t) * t) * t;
  1748. }
  1749. case 37: {
  1750. T t = 2*y100 - 75;
  1751. return 0.83706822008980357446e-1 + (0.17476561032212656962e-2 + (0.12812343958540763368e-4 + (0.86506399515036435592e-7 + (0.53409440823869467453e-9 + (0.29856186620887555043e-11 + 0.14891851591111111111e-13 * t) * t) * t) * t) * t) * t;
  1752. }
  1753. case 38: {
  1754. T t = 2*y100 - 77;
  1755. return 0.87254084284461718231e-1 + (0.17999608886001962327e-2 + (0.13344443080089492218e-4 + (0.90900994316429008631e-7 + (0.56486134972616465316e-9 + (0.31698707080033956934e-11 + 0.15825697795555555556e-13 * t) * t) * t) * t) * t) * t;
  1756. }
  1757. case 39: {
  1758. T t = 2*y100 - 79;
  1759. return 0.90908120182172748487e-1 + (0.18544478050657699758e-2 + (0.13903663143426120077e-4 + (0.95549246062549906177e-7 + (0.59752787125242054315e-9 + (0.33656597366099099413e-11 + 0.16815130613333333333e-13 * t) * t) * t) * t) * t) * t;
  1760. }
  1761. case 40: {
  1762. T t = 2*y100 - 81;
  1763. return 0.94673404508075481121e-1 + (0.19112284419887303347e-2 + (0.14491572616545004930e-4 + (0.10046682186333613697e-6 + (0.63221272959791000515e-9 + (0.35736693975589130818e-11 + 0.17862931591111111111e-13 * t) * t) * t) * t) * t) * t;
  1764. }
  1765. case 41: {
  1766. T t = 2*y100 - 83;
  1767. return 0.98554641648004456555e-1 + (0.19704208544725622126e-2 + (0.15109836875625443935e-4 + (0.10567036667675984067e-6 + (0.66904168640019354565e-9 + (0.37946171850824333014e-11 + 0.18971959040000000000e-13 * t) * t) * t) * t) * t) * t;
  1768. }
  1769. case 42: {
  1770. T t = 2*y100 - 85;
  1771. return 0.10255677889470089531e0 + (0.20321499629472857418e-2 + (0.15760224242962179564e-4 + (0.11117756071353507391e-6 + (0.70814785110097658502e-9 + (0.40292553276632563925e-11 + 0.20145143075555555556e-13 * t) * t) * t) * t) * t) * t;
  1772. }
  1773. case 43: {
  1774. T t = 2*y100 - 87;
  1775. return 0.10668502059865093318e0 + (0.20965479776148731610e-2 + (0.16444612377624983565e-4 + (0.11700717962026152749e-6 + (0.74967203250938418991e-9 + (0.42783716186085922176e-11 + 0.21385479360000000000e-13 * t) * t) * t) * t) * t) * t;
  1776. }
  1777. case 44: {
  1778. T t = 2*y100 - 89;
  1779. return 0.11094484319386444474e0 + (0.21637548491908170841e-2 + (0.17164995035719657111e-4 + (0.12317915750735938089e-6 + (0.79376309831499633734e-9 + (0.45427901763106353914e-11 + 0.22696025653333333333e-13 * t) * t) * t) * t) * t) * t;
  1780. }
  1781. case 45: {
  1782. T t = 2*y100 - 91;
  1783. return 0.11534201115268804714e0 + (0.22339187474546420375e-2 + (0.17923489217504226813e-4 + (0.12971465288245997681e-6 + (0.84057834180389073587e-9 + (0.48233721206418027227e-11 + 0.24079890062222222222e-13 * t) * t) * t) * t) * t) * t;
  1784. }
  1785. case 46: {
  1786. T t = 2*y100 - 93;
  1787. return 0.11988259392684094740e0 + (0.23071965691918689601e-2 + (0.18722342718958935446e-4 + (0.13663611754337957520e-6 + (0.89028385488493287005e-9 + (0.51210161569225846701e-11 + 0.25540227111111111111e-13 * t) * t) * t) * t) * t) * t;
  1788. }
  1789. case 47: {
  1790. T t = 2*y100 - 95;
  1791. return 0.12457298393509812907e0 + (0.23837544771809575380e-2 + (0.19563942105711612475e-4 + (0.14396736847739470782e-6 + (0.94305490646459247016e-9 + (0.54366590583134218096e-11 + 0.27080225920000000000e-13 * t) * t) * t) * t) * t) * t;
  1792. }
  1793. case 48: {
  1794. T t = 2*y100 - 97;
  1795. return 0.12941991566142438816e0 + (0.24637684719508859484e-2 + (0.20450821127475879816e-4 + (0.15173366280523906622e-6 + (0.99907632506389027739e-9 + (0.57712760311351625221e-11 + 0.28703099555555555556e-13 * t) * t) * t) * t) * t) * t;
  1796. }
  1797. case 49: {
  1798. T t = 2*y100 - 99;
  1799. return 0.13443048593088696613e0 + (0.25474249981080823877e-2 + (0.21385669591362915223e-4 + (0.15996177579900443030e-6 + (0.10585428844575134013e-8 + (0.61258809536787882989e-11 + 0.30412080142222222222e-13 * t) * t) * t) * t) * t) * t;
  1800. }
  1801. case 50: {
  1802. T t = 2*y100 - 101;
  1803. return 0.13961217543434561353e0 + (0.26349215871051761416e-2 + (0.22371342712572567744e-4 + (0.16868008199296822247e-6 + (0.11216596910444996246e-8 + (0.65015264753090890662e-11 + 0.32210394506666666666e-13 * t) * t) * t) * t) * t) * t;
  1804. }
  1805. case 51: {
  1806. T t = 2*y100 - 103;
  1807. return 0.14497287157673800690e0 + (0.27264675383982439814e-2 + (0.23410870961050950197e-4 + (0.17791863939526376477e-6 + (0.11886425714330958106e-8 + (0.68993039665054288034e-11 + 0.34101266222222222221e-13 * t) * t) * t) * t) * t) * t;
  1808. }
  1809. case 52: {
  1810. T t = 2*y100 - 105;
  1811. return 0.15052089272774618151e0 + (0.28222846410136238008e-2 + (0.24507470422713397006e-4 + (0.18770927679626136909e-6 + (0.12597184587583370712e-8 + (0.73203433049229821618e-11 + 0.36087889048888888890e-13 * t) * t) * t) * t) * t) * t;
  1812. }
  1813. case 53: {
  1814. T t = 2*y100 - 107;
  1815. return 0.15626501395774612325e0 + (0.29226079376196624949e-2 + (0.25664553693768450545e-4 + (0.19808568415654461964e-6 + (0.13351257759815557897e-8 + (0.77658124891046760667e-11 + 0.38173420035555555555e-13 * t) * t) * t) * t) * t) * t;
  1816. }
  1817. case 54: {
  1818. T t = 2*y100 - 109;
  1819. return 0.16221449434620737567e0 + (0.30276865332726475672e-2 + (0.26885741326534564336e-4 + (0.20908350604346384143e-6 + (0.14151148144240728728e-8 + (0.82369170665974313027e-11 + 0.40360957457777777779e-13 * t) * t) * t) * t) * t) * t;
  1820. }
  1821. case 55: {
  1822. T t = 2*y100 - 111;
  1823. return 0.16837910595412130659e0 + (0.31377844510793082301e-2 + (0.28174873844911175026e-4 + (0.22074043807045782387e-6 + (0.14999481055996090039e-8 + (0.87348993661930809254e-11 + 0.42653528977777777779e-13 * t) * t) * t) * t) * t) * t;
  1824. }
  1825. case 56: {
  1826. T t = 2*y100 - 113;
  1827. return 0.17476916455659369953e0 + (0.32531815370903068316e-2 + (0.29536024347344364074e-4 + (0.23309632627767074202e-6 + (0.15899007843582444846e-8 + (0.92610375235427359475e-11 + 0.45054073102222222221e-13 * t) * t) * t) * t) * t) * t;
  1828. }
  1829. case 57: {
  1830. T t = 2*y100 - 115;
  1831. return 0.18139556223643701364e0 + (0.33741744168096996041e-2 + (0.30973511714709500836e-4 + (0.24619326937592290996e-6 + (0.16852609412267750744e-8 + (0.98166442942854895573e-11 + 0.47565418097777777779e-13 * t) * t) * t) * t) * t) * t;
  1832. }
  1833. case 58: {
  1834. T t = 2*y100 - 117;
  1835. return 0.18826980194443664549e0 + (0.35010775057740317997e-2 + (0.32491914440014267480e-4 + (0.26007572375886319028e-6 + (0.17863299617388376116e-8 + (0.10403065638343878679e-10 + 0.50190265831111111110e-13 * t) * t) * t) * t) * t) * t;
  1836. }
  1837. case 59: {
  1838. T t = 2*y100 - 119;
  1839. return 0.19540403413693967350e0 + (0.36342240767211326315e-2 + (0.34096085096200907289e-4 + (0.27479061117017637474e-6 + (0.18934228504790032826e-8 + (0.11021679075323598664e-10 + 0.52931171733333333334e-13 * t) * t) * t) * t) * t) * t;
  1840. }
  1841. case 60: {
  1842. T t = 2*y100 - 121;
  1843. return 0.20281109560651886959e0 + (0.37739673859323597060e-2 + (0.35791165457592409054e-4 + (0.29038742889416172404e-6 + (0.20068685374849001770e-8 + (0.11673891799578381999e-10 + 0.55790523093333333334e-13 * t) * t) * t) * t) * t) * t;
  1844. }
  1845. case 61: {
  1846. T t = 2*y100 - 123;
  1847. return 0.21050455062669334978e0 + (0.39206818613925652425e-2 + (0.37582602289680101704e-4 + (0.30691836231886877385e-6 + (0.21270101645763677824e-8 + (0.12361138551062899455e-10 + 0.58770520160000000000e-13 * t) * t) * t) * t) * t) * t;
  1848. }
  1849. case 62: {
  1850. T t = 2*y100 - 125;
  1851. return 0.21849873453703332479e0 + (0.40747643554689586041e-2 + (0.39476163820986711501e-4 + (0.32443839970139918836e-6 + (0.22542053491518680200e-8 + (0.13084879235290858490e-10 + 0.61873153262222222221e-13 * t) * t) * t) * t) * t) * t;
  1852. }
  1853. case 63: {
  1854. T t = 2*y100 - 127;
  1855. return 0.22680879990043229327e0 + (0.42366354648628516935e-2 + (0.41477956909656896779e-4 + (0.34300544894502810002e-6 + (0.23888264229264067658e-8 + (0.13846596292818514601e-10 + 0.65100183751111111110e-13 * t) * t) * t) * t) * t) * t;
  1856. }
  1857. case 64: {
  1858. T t = 2*y100 - 129;
  1859. return 0.23545076536988703937e0 + (0.44067409206365170888e-2 + (0.43594444916224700881e-4 + (0.36268045617760415178e-6 + (0.25312606430853202748e-8 + (0.14647791812837903061e-10 + 0.68453122631111111110e-13 * t) * t) * t) * t) * t) * t;
  1860. }
  1861. case 65: {
  1862. T t = 2*y100 - 131;
  1863. return 0.24444156740777432838e0 + (0.45855530511605787178e-2 + (0.45832466292683085475e-4 + (0.38352752590033030472e-6 + (0.26819103733055603460e-8 + (0.15489984390884756993e-10 + 0.71933206364444444445e-13 * t) * t) * t) * t) * t) * t;
  1864. }
  1865. case 66: {
  1866. T t = 2*y100 - 133;
  1867. return 0.25379911500634264643e0 + (0.47735723208650032167e-2 + (0.48199253896534185372e-4 + (0.40561404245564732314e-6 + (0.28411932320871165585e-8 + (0.16374705736458320149e-10 + 0.75541379822222222221e-13 * t) * t) * t) * t) * t) * t;
  1868. }
  1869. case 67: {
  1870. T t = 2*y100 - 135;
  1871. return 0.26354234756393613032e0 + (0.49713289477083781266e-2 + (0.50702455036930367504e-4 + (0.42901079254268185722e-6 + (0.30095422058900481753e-8 + (0.17303497025347342498e-10 + 0.79278273368888888890e-13 * t) * t) * t) * t) * t) * t;
  1872. }
  1873. case 68: {
  1874. T t = 2*y100 - 137;
  1875. return 0.27369129607732343398e0 + (0.51793846023052643767e-2 + (0.53350152258326602629e-4 + (0.45379208848865015485e-6 + (0.31874057245814381257e-8 + (0.18277905010245111046e-10 + 0.83144182364444444445e-13 * t) * t) * t) * t) * t) * t;
  1876. }
  1877. case 69: {
  1878. T t = 2*y100 - 139;
  1879. return 0.28426714781640316172e0 + (0.53983341916695141966e-2 + (0.56150884865255810638e-4 + (0.48003589196494734238e-6 + (0.33752476967570796349e-8 + (0.19299477888083469086e-10 + 0.87139049137777777779e-13 * t) * t) * t) * t) * t) * t;
  1880. }
  1881. case 70: {
  1882. T t = 2*y100 - 141;
  1883. return 0.29529231465348519920e0 + (0.56288077305420795663e-2 + (0.59113671189913307427e-4 + (0.50782393781744840482e-6 + (0.35735475025851713168e-8 + (0.20369760937017070382e-10 + 0.91262442613333333334e-13 * t) * t) * t) * t) * t) * t;
  1884. }
  1885. case 71: {
  1886. T t = 2*y100 - 143;
  1887. return 0.30679050522528838613e0 + (0.58714723032745403331e-2 + (0.62248031602197686791e-4 + (0.53724185766200945789e-6 + (0.37827999418960232678e-8 + (0.21490291930444538307e-10 + 0.95513539182222222221e-13 * t) * t) * t) * t) * t) * t;
  1888. }
  1889. case 72: {
  1890. T t = 2*y100 - 145;
  1891. return 0.31878680111173319425e0 + (0.61270341192339103514e-2 + (0.65564012259707640976e-4 + (0.56837930287837738996e-6 + (0.40035151353392378882e-8 + (0.22662596341239294792e-10 + 0.99891109760000000000e-13 * t) * t) * t) * t) * t) * t;
  1892. }
  1893. case 73: {
  1894. T t = 2*y100 - 147;
  1895. return 0.33130773722152622027e0 + (0.63962406646798080903e-2 + (0.69072209592942396666e-4 + (0.60133006661885941812e-6 + (0.42362183765883466691e-8 + (0.23888182347073698382e-10 + 0.10439349811555555556e-12 * t) * t) * t) * t) * t) * t;
  1896. }
  1897. case 74: {
  1898. T t = 2*y100 - 149;
  1899. return 0.34438138658041336523e0 + (0.66798829540414007258e-2 + (0.72783795518603561144e-4 + (0.63619220443228800680e-6 + (0.44814499336514453364e-8 + (0.25168535651285475274e-10 + 0.10901861383111111111e-12 * t) * t) * t) * t) * t) * t;
  1900. }
  1901. case 75: {
  1902. T t = 2*y100 - 151;
  1903. return 0.35803744972380175583e0 + (0.69787978834882685031e-2 + (0.76710543371454822497e-4 + (0.67306815308917386747e-6 + (0.47397647975845228205e-8 + (0.26505114141143050509e-10 + 0.11376390933333333333e-12 * t) * t) * t) * t) * t) * t;
  1904. }
  1905. case 76: {
  1906. T t = 2*y100 - 153;
  1907. return 0.37230734890119724188e0 + (0.72938706896461381003e-2 + (0.80864854542670714092e-4 + (0.71206484718062688779e-6 + (0.50117323769745883805e-8 + (0.27899342394100074165e-10 + 0.11862637614222222222e-12 * t) * t) * t) * t) * t) * t;
  1908. }
  1909. case 77: {
  1910. T t = 2*y100 - 155;
  1911. return 0.38722432730555448223e0 + (0.76260375162549802745e-2 + (0.85259785810004603848e-4 + (0.75329383305171327677e-6 + (0.52979361368388119355e-8 + (0.29352606054164086709e-10 + 0.12360253370666666667e-12 * t) * t) * t) * t) * t) * t;
  1912. }
  1913. case 78: {
  1914. T t = 2*y100 - 157;
  1915. return 0.40282355354616940667e0 + (0.79762880915029728079e-2 + (0.89909077342438246452e-4 + (0.79687137961956194579e-6 + (0.55989731807360403195e-8 + (0.30866246101464869050e-10 + 0.12868841946666666667e-12 * t) * t) * t) * t) * t) * t;
  1916. }
  1917. case 79: {
  1918. T t = 2*y100 - 159;
  1919. return 0.41914223158913787649e0 + (0.83456685186950463538e-2 + (0.94827181359250161335e-4 + (0.84291858561783141014e-6 + (0.59154537751083485684e-8 + (0.32441553034347469291e-10 + 0.13387957943111111111e-12 * t) * t) * t) * t) * t) * t;
  1920. }
  1921. case 80: {
  1922. T t = 2*y100 - 161;
  1923. return 0.43621971639463786896e0 + (0.87352841828289495773e-2 + (0.10002929142066799966e-3 + (0.89156148280219880024e-6 + (0.62480008150788597147e-8 + (0.34079760983458878910e-10 + 0.13917107176888888889e-12 * t) * t) * t) * t) * t) * t;
  1924. }
  1925. case 81: {
  1926. T t = 2*y100 - 163;
  1927. return 0.45409763548534330981e0 + (0.91463027755548240654e-2 + (0.10553137232446167258e-3 + (0.94293113464638623798e-6 + (0.65972492312219959885e-8 + (0.35782041795476563662e-10 + 0.14455745872000000000e-12 * t) * t) * t) * t) * t) * t;
  1928. }
  1929. case 82: {
  1930. T t = 2*y100 - 165;
  1931. return 0.47282001668512331468e0 + (0.95799574408860463394e-2 + (0.11135019058000067469e-3 + (0.99716373005509038080e-6 + (0.69638453369956970347e-8 + (0.37549499088161345850e-10 + 0.15003280712888888889e-12 * t) * t) * t) * t) * t) * t;
  1932. }
  1933. case 83: {
  1934. T t = 2*y100 - 167;
  1935. return 0.49243342227179841649e0 + (0.10037550043909497071e-1 + (0.11750334542845234952e-3 + (0.10544006716188967172e-5 + (0.73484461168242224872e-8 + (0.39383162326435752965e-10 + 0.15559069118222222222e-12 * t) * t) * t) * t) * t) * t;
  1936. }
  1937. case 84: {
  1938. T t = 2*y100 - 169;
  1939. return 0.51298708979209258326e0 + (0.10520454564612427224e-1 + (0.12400930037494996655e-3 + (0.11147886579371265246e-5 + (0.77517184550568711454e-8 + (0.41283980931872622611e-10 + 0.16122419680000000000e-12 * t) * t) * t) * t) * t) * t;
  1940. }
  1941. case 85: {
  1942. T t = 2*y100 - 171;
  1943. return 0.53453307979101369843e0 + (0.11030120618800726938e-1 + (0.13088741519572269581e-3 + (0.11784797595374515432e-5 + (0.81743383063044825400e-8 + (0.43252818449517081051e-10 + 0.16692592640000000000e-12 * t) * t) * t) * t) * t) * t;
  1944. }
  1945. case 86: {
  1946. T t = 2*y100 - 173;
  1947. return 0.55712643071169299478e0 + (0.11568077107929735233e-1 + (0.13815797838036651289e-3 + (0.12456314879260904558e-5 + (0.86169898078969313597e-8 + (0.45290446811539652525e-10 + 0.17268801084444444444e-12 * t) * t) * t) * t) * t) * t;
  1948. }
  1949. case 87: {
  1950. T t = 2*y100 - 175;
  1951. return 0.58082532122519320968e0 + (0.12135935999503877077e-1 + (0.14584223996665838559e-3 + (0.13164068573095710742e-5 + (0.90803643355106020163e-8 + (0.47397540713124619155e-10 + 0.17850211608888888889e-12 * t) * t) * t) * t) * t) * t;
  1952. }
  1953. case 88: {
  1954. T t = 2*y100 - 177;
  1955. return 0.60569124025293375554e0 + (0.12735396239525550361e-1 + (0.15396244472258863344e-3 + (0.13909744385382818253e-5 + (0.95651595032306228245e-8 + (0.49574672127669041550e-10 + 0.18435945564444444444e-12 * t) * t) * t) * t) * t) * t;
  1956. }
  1957. case 89: {
  1958. T t = 2*y100 - 179;
  1959. return 0.63178916494715716894e0 + (0.13368247798287030927e-1 + (0.16254186562762076141e-3 + (0.14695084048334056083e-5 + (0.10072078109604152350e-7 + (0.51822304995680707483e-10 + 0.19025081422222222222e-12 * t) * t) * t) * t) * t) * t;
  1960. }
  1961. case 90: {
  1962. T t = 2*y100 - 181;
  1963. return 0.65918774689725319200e0 + (0.14036375850601992063e-1 + (0.17160483760259706354e-3 + (0.15521885688723188371e-5 + (0.10601827031535280590e-7 + (0.54140790105837520499e-10 + 0.19616655146666666667e-12 * t) * t) * t) * t) * t) * t;
  1964. }
  1965. case 91: {
  1966. T t = 2*y100 - 183;
  1967. return 0.68795950683174433822e0 + (0.14741765091365869084e-1 + (0.18117679143520433835e-3 + (0.16392004108230585213e-5 + (0.11155116068018043001e-7 + (0.56530360194925690374e-10 + 0.20209663662222222222e-12 * t) * t) * t) * t) * t) * t;
  1968. }
  1969. case 92: {
  1970. T t = 2*y100 - 185;
  1971. return 0.71818103808729967036e0 + (0.15486504187117112279e-1 + (0.19128428784550923217e-3 + (0.17307350969359975848e-5 + (0.11732656736113607751e-7 + (0.58991125287563833603e-10 + 0.20803065333333333333e-12 * t) * t) * t) * t) * t) * t;
  1972. }
  1973. case 93: {
  1974. T t = 2*y100 - 187;
  1975. return 0.74993321911726254661e0 + (0.16272790364044783382e-1 + (0.20195505163377912645e-3 + (0.18269894883203346953e-5 + (0.12335161021630225535e-7 + (0.61523068312169087227e-10 + 0.21395783431111111111e-12 * t) * t) * t) * t) * t) * t;
  1976. }
  1977. case 94: {
  1978. T t = 2*y100 - 189;
  1979. return 0.78330143531283492729e0 + (0.17102934132652429240e-1 + (0.21321800585063327041e-3 + (0.19281661395543913713e-5 + (0.12963340087354341574e-7 + (0.64126040998066348872e-10 + 0.21986708942222222222e-12 * t) * t) * t) * t) * t) * t;
  1980. }
  1981. case 95: {
  1982. T t = 2*y100 - 191;
  1983. return 0.81837581041023811832e0 + (0.17979364149044223802e-1 + (0.22510330592753129006e-3 + (0.20344732868018175389e-5 + (0.13617902941839949718e-7 + (0.66799760083972474642e-10 + 0.22574701262222222222e-12 * t) * t) * t) * t) * t) * t;
  1984. }
  1985. case 96: {
  1986. T t = 2*y100 - 193;
  1987. return 0.85525144775685126237e0 + (0.18904632212547561026e-1 + (0.23764237370371255638e-3 + (0.21461248251306387979e-5 + (0.14299555071870523786e-7 + (0.69543803864694171934e-10 + 0.23158593688888888889e-12 * t) * t) * t) * t) * t) * t;
  1988. }
  1989. case 97: {
  1990. T t = 2*y100 - 195;
  1991. return 0.89402868170849933734e0 + (0.19881418399127202569e-1 + (0.25086793128395995798e-3 + (0.22633402747585233180e-5 + (0.15008997042116532283e-7 + (0.72357609075043941261e-10 + 0.23737194737777777778e-12 * t) * t) * t) * t) * t) * t;
  1992. }
  1993. case 98: {
  1994. T t = 2*y100 - 197;
  1995. return 0.93481333942870796363e0 + (0.20912536329780368893e-1 + (0.26481403465998477969e-3 + (0.23863447359754921676e-5 + (0.15746923065472184451e-7 + (0.75240468141720143653e-10 + 0.24309291271111111111e-12 * t) * t) * t) * t) * t) * t;
  1996. }
  1997. case 99: {
  1998. T t = 2*y100 - 199;
  1999. return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682383001e-3 + (0.25153688325245314530e-5 + (0.16514019547822821453e-7 + (0.78191526829368231251e-10 + 0.24873652355555555556e-12 * t) * t) * t) * t) * t) * t;
  2000. }
  2001. }
  2002. // we only get here if y = 1, i.e. |x| < 4*eps, in which case
  2003. // erfcx is within 1e-15 of 1..
  2004. return 1.0;
  2005. }
  2006. template <typename T>
  2007. C10_HOST_DEVICE inline typename std::enable_if_t<std::is_floating_point_v<T>, T>
  2008. calc_erfcx(T x)
  2009. {
  2010. if (at::_isnan(x)) {
  2011. return x;
  2012. }
  2013. if (x >= 0) {
  2014. if (x > 50) { // continued-fraction expansion is faster
  2015. const T ispi = 0.56418958354775628694807945156; // 1 / sqrt(pi)
  2016. if (x > 5e7) { // 1-term expansion, important to avoid overflow
  2017. return ispi / x;
  2018. }
  2019. /* 5-term expansion (rely on compiler for CSE), simplified from:
  2020. ispi / (x+0.5/(x+1/(x+1.5/(x+2/x)))) */
  2021. return ispi*((x*x) * (x*x+4.5) + 2) / (x * ((x*x) * (x*x+5) + 3.75));
  2022. }
  2023. return erfcx_y100(400/(4+x));
  2024. }
  2025. else {
  2026. if (x < -26.7) {
  2027. return std::numeric_limits<T>::infinity();
  2028. }
  2029. else if (x < -6.1) {
  2030. return 2*exp(x*x);
  2031. }
  2032. else {
  2033. return 2*exp(x*x) - erfcx_y100(400/(4-x));
  2034. }
  2035. }
  2036. }
  2037. /*
  2038. * Logarithm of Gaussian cumulative distribution function.
  2039. * This implementation of log_ndtr and its helper functions
  2040. * follow SciPy's implementation
  2041. * See NOTICE for the licenses.
  2042. */
  2043. template <typename T>
  2044. inline C10_HOST_DEVICE T calc_log_ndtr(T x) {
  2045. T t = x * c10::frac_sqrt_2<T>;
  2046. if (x < T{-1.0}) {
  2047. return std::log(calc_erfcx(-t) / 2) - t * t;
  2048. } else {
  2049. return std::log1p(-std::erfc(t) / 2);
  2050. }
  2051. }
  2052. template<typename T>
  2053. inline C10_HOST_DEVICE T airy_ai_forward(T x) {
  2054. static const T AN[] = {
  2055. +3.46538101525629032477e-01,
  2056. +1.20075952739645805542e+01,
  2057. +7.62796053615234516538e+01,
  2058. +1.68089224934630576269e+02,
  2059. +1.59756391350164413639e+02,
  2060. +7.05360906840444183113e+01,
  2061. +1.40264691163389668864e+01,
  2062. +9.99999999999999995305e-01,
  2063. };
  2064. static const T AD[] = {
  2065. +5.67594532638770212846e-01,
  2066. +1.47562562584847203173e+01,
  2067. +8.45138970141474626562e+01,
  2068. +1.77318088145400459522e+02,
  2069. +1.64234692871529701831e+02,
  2070. +7.14778400825575695274e+01,
  2071. +1.40959135607834029598e+01,
  2072. +1.00000000000000000470e+00,
  2073. };
  2074. static const T AFN[] = {
  2075. -1.31696323418331795333e-01,
  2076. -6.26456544431912369773e-01,
  2077. -6.93158036036933542233e-01,
  2078. -2.79779981545119124951e-01,
  2079. -4.91900132609500318020e-02,
  2080. -4.06265923594885404393e-03,
  2081. -1.59276496239262096340e-04,
  2082. -2.77649108155232920844e-06,
  2083. -1.67787698489114633780e-08,
  2084. };
  2085. static const T AFD[] = {
  2086. +1.33560420706553243746e+01,
  2087. +3.26825032795224613948e+01,
  2088. +2.67367040941499554804e+01,
  2089. +9.18707402907259625840e+00,
  2090. +1.47529146771666414581e+00,
  2091. +1.15687173795188044134e-01,
  2092. +4.40291641615211203805e-03,
  2093. +7.54720348287414296618e-05,
  2094. +4.51850092970580378464e-07,
  2095. };
  2096. static const T AGN[] = {
  2097. +1.97339932091685679179e-02,
  2098. +3.91103029615688277255e-01,
  2099. +1.06579897599595591108e+00,
  2100. +9.39169229816650230044e-01,
  2101. +3.51465656105547619242e-01,
  2102. +6.33888919628925490927e-02,
  2103. +5.85804113048388458567e-03,
  2104. +2.82851600836737019778e-04,
  2105. +6.98793669997260967291e-06,
  2106. +8.11789239554389293311e-08,
  2107. +3.41551784765923618484e-10,
  2108. };
  2109. static const T AGD[] = {
  2110. +9.30892908077441974853e+00,
  2111. +1.98352928718312140417e+01,
  2112. +1.55646628932864612953e+01,
  2113. +5.47686069422975497931e+00,
  2114. +9.54293611618961883998e-01,
  2115. +8.64580826352392193095e-02,
  2116. +4.12656523824222607191e-03,
  2117. +1.01259085116509135510e-04,
  2118. +1.17166733214413521882e-06,
  2119. +4.91834570062930015649e-09,
  2120. };
  2121. int domain_flag = 0;
  2122. T ai;
  2123. if (std::isinf(x)) {
  2124. return std::numeric_limits<T>::quiet_NaN();
  2125. }
  2126. if (x > T(103.892)) {
  2127. return T(0.0);
  2128. }
  2129. T f;
  2130. T g;
  2131. T k;
  2132. if (x < T(-2.09)) {
  2133. T z = T(1.0) / (T(-2.0) * x * std::sqrt(-x) / T(3.0));
  2134. T afn = 0.0;
  2135. for (uint8_t index = 0; index <= 8; index++) {
  2136. afn = afn * (z * z) + AFN[index];
  2137. }
  2138. T afd = 0.0;
  2139. for (uint8_t index = 0; index <= 8; index++) {
  2140. afd = afd * (z * z) + AFD[index];
  2141. }
  2142. T agn = 0.0;
  2143. for (uint8_t index = 0; index <= 10 + 0; index++) {
  2144. agn = agn * (z * z) + AGN[index];
  2145. }
  2146. T agd = 0.0;
  2147. for (uint8_t index = 0; index <= 10 - 1; index++) {
  2148. agd = agd * (z * z) + AGD[index];
  2149. }
  2150. T t = T(-2.0) * x * std::sqrt(-x) / T(3.0) + T(0.25) * c10::pi<T>;
  2151. return T(5.64189583547756286948e-01) / std::sqrt(std::sqrt(-x)) * (std::sin(t) * (T(1.0) + z * z * afn / afd) - std::cos(t) * (z * agn / agd));
  2152. }
  2153. if (x >= T(2.09)) {
  2154. domain_flag = 5;
  2155. T zeta = T(2.0) * x * std::sqrt(x) / T(3.0);
  2156. T an = 0.0;
  2157. for (uint8_t index = 0; index <= 7; index++) {
  2158. an = an * (T(1.0) / zeta) + AN[index];
  2159. }
  2160. T ad = 0.0;
  2161. for (uint8_t index = 0; index <= 7; index++) {
  2162. ad = ad * (T(1.0) / zeta) + AD[index];
  2163. }
  2164. ai = T(5.64189583547756286948e-01) * (an / ad) / (T(2.0) * std::sqrt(std::sqrt(x)) * std::exp(zeta));
  2165. if (x > T(8.3203353)) {
  2166. return ai;
  2167. }
  2168. }
  2169. f = 1.0;
  2170. g = x;
  2171. k = 1.0;
  2172. T m = 1.0;
  2173. T n = x;
  2174. T t = 1.0;
  2175. T z = x * x * x;
  2176. while (t > std::numeric_limits<T>::epsilon()) {
  2177. m *= z;
  2178. k += T(1.0);
  2179. m /= k;
  2180. n *= z;
  2181. k += T(1.0);
  2182. n /= k;
  2183. m /= k;
  2184. f += m;
  2185. k += T(1.0);
  2186. n /= k;
  2187. g += n;
  2188. t = std::abs(m / f);
  2189. }
  2190. if ((domain_flag & 1) == 0) {
  2191. return T(0.355028053887817239260) * f - T(0.258819403792806798405) * g;
  2192. }
  2193. return ai;
  2194. } // T airy_ai(T x)
  2195. template<typename T>
  2196. inline C10_HOST_DEVICE T bessel_j0_forward(T x) {
  2197. static const T PP[] = {
  2198. +7.96936729297347051624e-04,
  2199. +8.28352392107440799803e-02,
  2200. +1.23953371646414299388e+00,
  2201. +5.44725003058768775090e+00,
  2202. +8.74716500199817011941e+00,
  2203. +5.30324038235394892183e+00,
  2204. +9.99999999999999997821e-01,
  2205. };
  2206. static const T PQ[] = {
  2207. +9.24408810558863637013e-04,
  2208. +8.56288474354474431428e-02,
  2209. +1.25352743901058953537e+00,
  2210. +5.47097740330417105182e+00,
  2211. +8.76190883237069594232e+00,
  2212. +5.30605288235394617618e+00,
  2213. +1.00000000000000000218e+00,
  2214. };
  2215. static const T QP[] = {
  2216. -1.13663838898469149931e-02,
  2217. -1.28252718670509318512e+00,
  2218. -1.95539544257735972385e+01,
  2219. -9.32060152123768231369e+01,
  2220. -1.77681167980488050595e+02,
  2221. -1.47077505154951170175e+02,
  2222. -5.14105326766599330220e+01,
  2223. -6.05014350600728481186e+00,
  2224. };
  2225. static const T QQ[] = {
  2226. +6.43178256118178023184e+01,
  2227. +8.56430025976980587198e+02,
  2228. +3.88240183605401609683e+03,
  2229. +7.24046774195652478189e+03,
  2230. +5.93072701187316984827e+03,
  2231. +2.06209331660327847417e+03,
  2232. +2.42005740240291393179e+02,
  2233. };
  2234. static const T RP[] = {
  2235. -4.79443220978201773821e+09,
  2236. +1.95617491946556577543e+12,
  2237. -2.49248344360967716204e+14,
  2238. +9.70862251047306323952e+15,
  2239. };
  2240. static const T RQ[] = {
  2241. +4.99563147152651017219e+02,
  2242. +1.73785401676374683123e+05,
  2243. +4.84409658339962045305e+07,
  2244. +1.11855537045356834862e+10,
  2245. +2.11277520115489217587e+12,
  2246. +3.10518229857422583814e+14,
  2247. +3.18121955943204943306e+16,
  2248. +1.71086294081043136091e+18,
  2249. };
  2250. if (x < T(0)) {
  2251. x = -x;
  2252. }
  2253. if (x <= T(5.0)) {
  2254. if (x < T(0.00001)) {
  2255. return T(1.0) - x * x / T(4.0);
  2256. }
  2257. T rp = 0.0;
  2258. for (uint8_t index = 0; index <= 3; index++) {
  2259. rp = rp * (x * x) + RP[index];
  2260. }
  2261. T rq = 0.0;
  2262. for (uint8_t index = 0; index <= 7; index++) {
  2263. rq = rq * (x * x) + RQ[index];
  2264. }
  2265. return (x * x - T(5.78318596294678452118e+00)) * (x * x - T(3.04712623436620863991e+01)) * rp / rq;
  2266. }
  2267. T pp = 0.0;
  2268. for (uint8_t index = 0; index <= 6; index++) {
  2269. pp = pp * (T(25.0) / (x * x)) + PP[index];
  2270. }
  2271. T pq = 0.0;
  2272. for (uint8_t index = 0; index <= 6; index++) {
  2273. pq = pq * (T(25.0) / (x * x)) + PQ[index];
  2274. }
  2275. T qp = 0.0;
  2276. for (uint8_t index = 0; index <= 7; index++) {
  2277. qp = qp * (T(25.0) / (x * x)) + QP[index];
  2278. }
  2279. T qq = 0.0;
  2280. for (uint8_t index = 0; index <= 6; index++) {
  2281. qq = qq * (T(25.0) / (x * x)) + QQ[index];
  2282. }
  2283. return (pp / pq * std::cos(x - T(0.785398163397448309615660845819875721)) - T(5.0) / x * (qp / qq) * std::sin(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
  2284. } // bessel_j0_forward(T x)
  2285. template<typename T>
  2286. inline C10_HOST_DEVICE T bessel_j1_forward(T x) {
  2287. static const T PP[] = {
  2288. +7.62125616208173112003e-04,
  2289. +7.31397056940917570436e-02,
  2290. +1.12719608129684925192e+00,
  2291. +5.11207951146807644818e+00,
  2292. +8.42404590141772420927e+00,
  2293. +5.21451598682361504063e+00,
  2294. +1.00000000000000000254e+00,
  2295. };
  2296. static const T PQ[] = {
  2297. +5.71323128072548699714e-04,
  2298. +6.88455908754495404082e-02,
  2299. +1.10514232634061696926e+00,
  2300. +5.07386386128601488557e+00,
  2301. +8.39985554327604159757e+00,
  2302. +5.20982848682361821619e+00,
  2303. +9.99999999999999997461e-01,
  2304. };
  2305. static const T QP[] = {
  2306. +5.10862594750176621635e-02,
  2307. +4.98213872951233449420e+00,
  2308. +7.58238284132545283818e+01,
  2309. +3.66779609360150777800e+02,
  2310. +7.10856304998926107277e+02,
  2311. +5.97489612400613639965e+02,
  2312. +2.11688757100572135698e+02,
  2313. +2.52070205858023719784e+01,
  2314. };
  2315. static const T QQ[] = {
  2316. +7.42373277035675149943e+01,
  2317. +1.05644886038262816351e+03,
  2318. +4.98641058337653607651e+03,
  2319. +9.56231892404756170795e+03,
  2320. +7.99704160447350683650e+03,
  2321. +2.82619278517639096600e+03,
  2322. +3.36093607810698293419e+02,
  2323. };
  2324. static const T RP[] = {
  2325. -8.99971225705559398224e+08,
  2326. +4.52228297998194034323e+11,
  2327. -7.27494245221818276015e+13,
  2328. +3.68295732863852883286e+15,
  2329. };
  2330. static const T RQ[] = {
  2331. +6.20836478118054335476e+02,
  2332. +2.56987256757748830383e+05,
  2333. +8.35146791431949253037e+07,
  2334. +2.21511595479792499675e+10,
  2335. +4.74914122079991414898e+12,
  2336. +7.84369607876235854894e+14,
  2337. +8.95222336184627338078e+16,
  2338. +5.32278620332680085395e+18,
  2339. };
  2340. if (x < T(0.0)) {
  2341. return -bessel_j1_forward(-x);
  2342. }
  2343. if (x <= T(5.0)) {
  2344. T rp = 0.0;
  2345. for (uint8_t index = 0; index <= 3; index++) {
  2346. rp = rp * (x * x) + RP[index];
  2347. }
  2348. T rq = 0.0;
  2349. for (uint8_t index = 0; index <= 7; index++) {
  2350. rq = rq * (x * x) + RQ[index];
  2351. }
  2352. return rp / rq * x * (x * x - T(1.46819706421238932572e+01)) * (x * x - T(4.92184563216946036703e+01));
  2353. }
  2354. T pp = 0.0;
  2355. for (uint8_t index = 0; index <= 6; index++) {
  2356. pp = pp * (T(5.0) / x * (T(5.0) / x)) + PP[index];
  2357. }
  2358. T pq = 0.0;
  2359. for (uint8_t index = 0; index <= 6; index++) {
  2360. pq = pq * (T(5.0) / x * (T(5.0) / x)) + PQ[index];
  2361. }
  2362. T qp = 0.0;
  2363. for (uint8_t index = 0; index <= 7; index++) {
  2364. qp = qp * (T(5.0) / x * (T(5.0) / x)) + QP[index];
  2365. }
  2366. T qq = 0.0;
  2367. for (uint8_t index = 0; index <= 6; index++) {
  2368. qq = qq * (T(5.0) / x * (T(5.0) / x)) + QQ[index];
  2369. }
  2370. return (pp / pq * std::cos(x - T(2.356194490192344928846982537459627163)) - T(5.0) / x * (qp / qq) * std::sin(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
  2371. } // bessel_j1_forward(T x)
  // Bessel function of the second kind of order zero, Y0(x).
  //
  //   x == 0     -> -infinity
  //   x <  0     -> quiet NaN
  //   0 < x <= 5 -> YP(x^2) / YQ(x^2) + (2/pi) * log(x) * J0(x)
  //   x >  5     -> sqrt(2/pi) / sqrt(x) * (P(z) * sin(x - pi/4) + (5/x) * Q(z) * cos(x - pi/4)),
  //                 with z = 25 / x^2, P = PP / PQ and Q = QP / QQ.
  // A NaN input fails every `<=` test and therefore takes the x > 5 branch.
  //
  // Every table is stored highest-degree coefficient first and evaluated with Horner's rule.
  // These tables are the Cephes j0.c tables. There, YQ and QQ are monic denominators evaluated
  // with p1evl: the leading coefficient 1 is implicit and not stored. Their accumulators
  // therefore start at 1 instead of 0; starting at 0 would silently drop the highest-degree term.
  template<typename T>
  inline C10_HOST_DEVICE T bessel_y0_forward(T x) {
    static const T PP[] = {
      +7.96936729297347051624e-04,
      +8.28352392107440799803e-02,
      +1.23953371646414299388e+00,
      +5.44725003058768775090e+00,
      +8.74716500199817011941e+00,
      +5.30324038235394892183e+00,
      +9.99999999999999997821e-01,
    };
    static const T PQ[] = {
      +9.24408810558863637013e-04,
      +8.56288474354474431428e-02,
      +1.25352743901058953537e+00,
      +5.47097740330417105182e+00,
      +8.76190883237069594232e+00,
      +5.30605288235394617618e+00,
      +1.00000000000000000218e+00,
    };
    static const T QP[] = {
      -1.13663838898469149931e-02,
      -1.28252718670509318512e+00,
      -1.95539544257735972385e+01,
      -9.32060152123768231369e+01,
      -1.77681167980488050595e+02,
      -1.47077505154951170175e+02,
      -5.14105326766599330220e+01,
      -6.05014350600728481186e+00,
    };
    // Monic denominator: implicit leading coefficient 1 (degree 7 in z).
    static const T QQ[] = {
      +6.43178256118178023184e+01,
      +8.56430025976980587198e+02,
      +3.88240183605401609683e+03,
      +7.24046774195652478189e+03,
      +5.93072701187316984827e+03,
      +2.06209331660327847417e+03,
      +2.42005740240291393179e+02,
    };
    static const T YP[] = {
      +1.55924367855235737965e+04,
      -1.46639295903971606143e+07,
      +5.43526477051876500413e+09,
      -9.82136065717911466409e+11,
      +8.75906394395366999549e+13,
      -3.46628303384729719441e+15,
      +4.42733268572569800351e+16,
      -1.84950800436986690637e+16,
    };
    // Monic denominator: implicit leading coefficient 1 (degree 7 in x^2).
    static const T YQ[] = {
      +1.04128353664259848412e+03,
      +6.26107330137134956842e+05,
      +2.68919633393814121987e+08,
      +8.64002487103935000337e+10,
      +2.02979612750105546709e+13,
      +3.17157752842975028269e+15,
      +2.50596256172653059228e+17,
    };
    if (x <= T(5.0)) {
      if (x == T(0.0)) {
        return -std::numeric_limits<T>::infinity();
      }
      if (x < T(0.0)) {
        return std::numeric_limits<T>::quiet_NaN();
      }
      const T z = x * x;
      T yp = 0.0;
      for (uint8_t index = 0; index <= 7; index++) {
        yp = yp * z + YP[index];
      }
      // Start at 1: YQ's leading coefficient is the implicit 1.
      T yq = 1.0;
      for (uint8_t index = 0; index <= 6; index++) {
        yq = yq * z + YQ[index];
      }
      // 0.6366... == 2 / pi
      return yp / yq + (T(0.636619772367581343075535053490057448) * std::log(x) * bessel_j0_forward(x));
    }
    const T z = T(25.0) / (x * x);
    T pp = 0.0;
    for (uint8_t index = 0; index <= 6; index++) {
      pp = pp * z + PP[index];
    }
    T pq = 0.0;
    for (uint8_t index = 0; index <= 6; index++) {
      pq = pq * z + PQ[index];
    }
    T qp = 0.0;
    for (uint8_t index = 0; index <= 7; index++) {
      qp = qp * z + QP[index];
    }
    // Start at 1: QQ's leading coefficient is the implicit 1.
    T qq = 1.0;
    for (uint8_t index = 0; index <= 6; index++) {
      qq = qq * z + QQ[index];
    }
    // 0.7853... == pi / 4, 0.7978... == sqrt(2 / pi)
    return (pp / pq * std::sin(x - T(0.785398163397448309615660845819875721)) + T(5.0) / x * (qp / qq) * std::cos(x - T(0.785398163397448309615660845819875721))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
  } // bessel_y0_forward(T x)
  // Bessel function of the second kind of order one, Y1(x).
  //
  //   x == 0     -> -infinity
  //   x <  0     -> quiet NaN (the `x <= 0` test runs after the x == 0 test)
  //   0 < x <= 5 -> x * YP(x^2) / YQ(x^2) + (2/pi) * (J1(x) * log(x) - 1/x)
  //   x >  5     -> sqrt(2/pi) / sqrt(x) * (P(z) * sin(x - 3pi/4) + (5/x) * Q(z) * cos(x - 3pi/4)),
  //                 with z = (5/x)^2, P = PP / PQ and Q = QP / QQ.
  // A NaN input fails every `<=` test and therefore takes the x > 5 branch.
  //
  // Every table is stored highest-degree coefficient first and evaluated with Horner's rule.
  // These tables are the Cephes j1.c tables. There, YQ and QQ are monic denominators evaluated
  // with p1evl: the leading coefficient 1 is implicit and not stored. Their accumulators
  // therefore start at 1 instead of 0.
  template<typename T>
  inline C10_HOST_DEVICE T bessel_y1_forward(T x) {
    static const T PP[] = {
      +7.62125616208173112003e-04,
      +7.31397056940917570436e-02,
      +1.12719608129684925192e+00,
      +5.11207951146807644818e+00,
      +8.42404590141772420927e+00,
      +5.21451598682361504063e+00,
      +1.00000000000000000254e+00,
    };
    static const T PQ[] = {
      +5.71323128072548699714e-04,
      +6.88455908754495404082e-02,
      +1.10514232634061696926e+00,
      +5.07386386128601488557e+00,
      +8.39985554327604159757e+00,
      +5.20982848682361821619e+00,
      +9.99999999999999997461e-01,
    };
    static const T QP[] = {
      +5.10862594750176621635e-02,
      +4.98213872951233449420e+00,
      +7.58238284132545283818e+01,
      +3.66779609360150777800e+02,
      +7.10856304998926107277e+02,
      +5.97489612400613639965e+02,
      +2.11688757100572135698e+02,
      +2.52070205858023719784e+01,
    };
    // Monic denominator: implicit leading coefficient 1 (degree 7 in z).
    static const T QQ[] = {
      +7.42373277035675149943e+01,
      +1.05644886038262816351e+03,
      +4.98641058337653607651e+03,
      +9.56231892404756170795e+03,
      +7.99704160447350683650e+03,
      +2.82619278517639096600e+03,
      +3.36093607810698293419e+02,
    };
    static const T YP[] = {
      +1.26320474790178026440e+09,
      -6.47355876379160291031e+11,
      +1.14509511541823727583e+14,
      -8.12770255501325109621e+15,
      +2.02439475713594898196e+17,
      -7.78877196265950026825e+17,
    };
    // Monic denominator: implicit leading coefficient 1 (degree 8 in x^2).
    static const T YQ[] = {
      +5.94301592346128195359e+02,
      +2.35564092943068577943e+05,
      +7.34811944459721705660e+07,
      +1.87601316108706159478e+10,
      +3.88231277496238566008e+12,
      +6.20557727146953693363e+14,
      +6.87141087355300489866e+16,
      +3.97270608116560655612e+18,
    };
    if (x <= T(5.0)) {
      if (x == T(0.0)) {
        return -std::numeric_limits<T>::infinity();
      }
      if (x <= T(0.0)) {
        return std::numeric_limits<T>::quiet_NaN();
      }
      const T z = x * x;
      T yp = 0.0;
      for (uint8_t index = 0; index <= 5; index++) {
        yp = yp * z + YP[index];
      }
      // Start at 1: YQ's leading coefficient is the implicit 1.
      T yq = 1.0;
      for (uint8_t index = 0; index <= 7; index++) {
        yq = yq * z + YQ[index];
      }
      // 0.6366... == 2 / pi
      return x * (yp / yq) + (T(0.636619772367581343075535053490057448) * (bessel_j1_forward(x) * std::log(x) - T(1.0) / x));
    }
    const T w = T(5.0) / x;
    const T z = w * w;
    T pp = 0.0;
    for (uint8_t index = 0; index <= 6; index++) {
      pp = pp * z + PP[index];
    }
    T pq = 0.0;
    for (uint8_t index = 0; index <= 6; index++) {
      pq = pq * z + PQ[index];
    }
    T qp = 0.0;
    for (uint8_t index = 0; index <= 7; index++) {
      qp = qp * z + QP[index];
    }
    // Start at 1: QQ's leading coefficient is the implicit 1.
    T qq = 1.0;
    for (uint8_t index = 0; index <= 6; index++) {
      qq = qq * z + QQ[index];
    }
    // 2.3561... == 3pi / 4, 0.7978... == sqrt(2 / pi)
    return (pp / pq * std::sin(x - T(2.356194490192344928846982537459627163)) + w * (qp / qq) * std::cos(x - T(2.356194490192344928846982537459627163))) * T(0.797884560802865355879892119868763737) / std::sqrt(x);
  } // bessel_y1_forward(T x)
  // Chebyshev polynomial of the first kind, T_n(x).
  //
  //   n < 0           -> 0
  //   x == 1          -> 1
  //   x == -1         -> (-1)^n
  //   n > 6, |x| < 1  -> cos(n * acos(x))
  //   otherwise       -> T_0 = 1, T_1 = x, T_k = 2x T_{k-1} - T_{k-2}
  // The recurrence stops early once an iterate becomes NaN.
  template<typename T>
  inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (std::abs(x) == T(1.0)) {
      if (x > T(0.0) || n % 2 == 0) {
        return T(1.0);
      }
      return T(-1.0);
    }
    if ((n > 6) && (std::abs(x) < T(1.0))) {
      return std::cos(n * std::acos(x));
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return x;
    }
    T p = T(1.0); // T_{k-2}
    T q = x;      // T_{k-1}
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T r = (x + x) * q - p;
      p = q;
      q = r;
    }
    // Return q rather than a loop temporary. When x is NaN the loop body never runs, and q (== x)
    // is the only initialized value. Otherwise q holds T_n, or the NaN that stopped the loop.
    return q;
  } // chebyshev_polynomial_t_forward(T x, int64_t n)
  // Overload taking the degree as T. The degree is truncated to int64_t.
  // NOTE(review): is_cuda is unused here.
  template<typename T, bool is_cuda=false>
  inline C10_HOST_DEVICE T chebyshev_polynomial_t_forward(T x, T n) {
    return chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
  } // chebyshev_polynomial_t_forward(T x, T n)
  // Chebyshev polynomial of the second kind, U_n(x).
  //
  //   n < 0           -> 0
  //   x == 1          -> n + 1
  //   x == -1         -> (-1)^n * (n + 1)
  //   n > 8, |x| < 1  -> sin((n + 1) * acos(x)) / sin(acos(x)), with a fallback when sin(acos(x)) == 0
  //   otherwise       -> U_0 = 1, U_1 = 2x, U_k = 2x U_{k-1} - U_{k-2}
  // The recurrence stops early once an iterate becomes NaN.
  template<typename T>
  inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (std::abs(x) == T(1.0)) {
      if (x > T(0.0) || n % 2 == 0) {
        return n + 1;
      }
      return -(n + 1);
    }
    if ((n > 8) && (std::abs(x) < T(1.0))) {
      if (std::sin(std::acos(x)) != T(0.0)) {
        return std::sin((n + 1) * std::acos(x)) / std::sin(std::acos(x));
      }
      return (n + 1) * std::cos((n + 1) * std::acos(x)) / x;
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return x + x;
    }
    T p = T(1.0); // U_{k-2}
    T q = x + x;  // U_{k-1}
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T r = (x + x) * q - p;
      p = q;
      q = r;
    }
    // Return q: when x is NaN the loop never runs and q is the only initialized value.
    return q;
  } // chebyshev_polynomial_u_forward(T x, int64_t n)
  // Overload taking the degree as T. The degree is truncated to int64_t.
  // NOTE(review): is_cuda is unused here.
  template<typename T, bool is_cuda=false>
  inline C10_HOST_DEVICE T chebyshev_polynomial_u_forward(T x, T n) {
    return chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
  } // chebyshev_polynomial_u_forward(T x, T n)
  // Chebyshev polynomial of the third kind, V_n(x).
  //
  //   n < 0           -> 0
  //   x == 1          -> 1
  //   x == -1         -> (-1)^n * (2n + 1)
  //   n > 8, |x| < 1  -> cos((n + 1/2) * acos(x)) / cos(acos(x) / 2)
  //   otherwise       -> V_0 = 1, V_1 = 2x - 1, V_k = 2x V_{k-1} - V_{k-2}
  // The recurrence stops early once an iterate becomes NaN.
  template<typename T>
  inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (std::abs(x) == T(1.0)) {
      if (x > T(0.0)) {
        return T(1.0);
      }
      if (n % 2 == 0) {
        return n + n + 1;
      }
      return -(n + n + 1);
    }
    if ((n > 8) && (std::abs(x) < T(1.0))) {
      // When sin(acos(x)/2) rounds to 1, the denominator cos(acos(x)/2) is (numerically) zero.
      // That happens as x approaches -1, so use the x == -1 value instead.
      if (std::sin(std::acos(x) / T(2.0)) != T(1.0)) {
        return std::cos((n + T(0.5)) * std::acos(x)) / std::cos(std::acos(x) / T(2.0));
      }
      if (n % 2 == 0) {
        return n + n + 1;
      }
      return -(n + n + 1);
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return x + x - T(1.0);
    }
    T p = T(1.0);         // V_{k-2}
    T q = x + x - T(1.0); // V_{k-1}
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T r = (x + x) * q - p;
      p = q;
      q = r;
    }
    // Return q: when x is NaN the loop never runs and q is the only initialized value.
    return q;
  } // chebyshev_polynomial_v_forward(T x, int64_t n)
  // Overload taking the degree as T. The degree is truncated to int64_t.
  // NOTE(review): is_cuda is unused here.
  template<typename T, bool is_cuda=false>
  inline C10_HOST_DEVICE T chebyshev_polynomial_v_forward(T x, T n) {
    return chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
  } // chebyshev_polynomial_v_forward(T x, T n)
  // Chebyshev polynomial of the fourth kind, W_n(x).
  //
  //   n < 0           -> 0
  //   x == 1          -> 2n + 1
  //   x == -1         -> (-1)^n
  //   n > 8, |x| < 1  -> sin((n + 1/2) * acos(x)) / sin(acos(x) / 2)
  //   otherwise       -> W_0 = 1, W_1 = 2x + 1, W_k = 2x W_{k-1} - W_{k-2}
  // The recurrence stops early once an iterate becomes NaN.
  template<typename T>
  inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (std::abs(x) == T(1.0)) {
      if (x > T(0.0)) {
        return n + n + 1;
      }
      if (n % 2 == 0) {
        return T(1.0);
      }
      return T(-1.0);
    }
    if ((n > 8) && (std::abs(x) < T(1.0))) {
      // When cos(acos(x)/2) rounds to 1, the denominator sin(acos(x)/2) is (numerically) zero.
      // That happens as x approaches 1, so fall back to the endpoint values instead.
      if (std::cos(std::acos(x) / T(2.0)) != T(1.0)) {
        return std::sin((n + T(0.5)) * std::acos(x)) / std::sin(std::acos(x) / T(2.0));
      }
      if (x > T(0.0)) {
        return n + n + 1;
      }
      if (n % 2 == 0) {
        return T(1.0);
      }
      return T(-1.0);
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return x + x + T(1.0);
    }
    T p = T(1.0);         // W_{k-2}
    T q = x + x + T(1.0); // W_{k-1}
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T r = (x + x) * q - p;
      p = q;
      q = r;
    }
    // Return q: when x is NaN the loop never runs and q is the only initialized value.
    return q;
  } // chebyshev_polynomial_w_forward(T x, int64_t n)
  // Overload taking the degree as T. The degree is truncated to int64_t.
  // NOTE(review): is_cuda is unused here.
  template<typename T, bool is_cuda=false>
  inline C10_HOST_DEVICE T chebyshev_polynomial_w_forward(T x, T n) {
    return chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
  } // chebyshev_polynomial_w_forward(T x, T n)
  // Largest polynomial degree the Hermite evaluators in this file will compute for element
  // type T. hermite_polynomial_h_forward and hermite_polynomial_he_forward return NaN for any
  // degree above it.
  //   float  -> 128
  //   double -> 512
  //   other  -> 1024
  template<typename T>
  constexpr auto getHermitianLimit() {
    constexpr bool is_float = std::is_same_v<T, float>;
    constexpr bool is_double = std::is_same_v<T, double>;
    return is_float ? 128 : (is_double ? 512 : 1024);
  }
  2727. template<typename T>
  2728. inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) {
  2729. if (n < 0) {
  2730. return T(0.0);
  2731. }
  2732. if (n == 0) {
  2733. return T(1.0);
  2734. }
  2735. if (n == 1) {
  2736. return x + x;
  2737. }
  2738. if (n > getHermitianLimit<T>()) {
  2739. return std::numeric_limits<T>::quiet_NaN();
  2740. }
  2741. T p = T(1.0);
  2742. T q = x + x;
  2743. T r = T(0.0);
  2744. for (int64_t k = 2; k < n + n; k += 2) {
  2745. r = (x + x) * q - k * p;
  2746. p = q;
  2747. q = r;
  2748. }
  2749. return r;
  2750. } // hermite_polynomial_h_forward(T x, int64_t n)
  2751. template<typename T, bool is_cuda=false, std::enable_if_t<!std::is_floating_point_v<T>, int> = 0>
  2752. inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
  2753. return hermite_polynomial_h_forward(x, static_cast<int64_t>(n));
  2754. } // hermite_polynomial_h_forward(T x, T n)
  2755. template<typename T, bool is_cuda=false, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
  2756. __ubsan_ignore_float_cast_overflow__ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) {
  2757. return hermite_polynomial_h_forward(x, (!std::isinf(n) && !std::isnan(n)) ? static_cast<int64_t>(n) : static_cast<int64_t>(-1));
  2758. } // hermite_polynomial_h_forward(T x, T n)
  2759. template<typename T>
  2760. inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, int64_t n) {
  2761. if (n < 0) {
  2762. return T(0.0);
  2763. }
  2764. if (n == 0) {
  2765. return T(1.0);
  2766. }
  2767. if (n == 1) {
  2768. return x;
  2769. }
  2770. if (n > getHermitianLimit<T>()) {
  2771. return std::numeric_limits<T>::quiet_NaN();
  2772. }
  2773. T p = T(1.0);
  2774. T q = x;
  2775. T r;
  2776. for (int64_t k = 1; k < n; k++) {
  2777. r = x * q - k * p;
  2778. p = q;
  2779. q = r;
  2780. }
  2781. return r;
  2782. } // hermite_polynomial_he_forward(T x, int64_t n)
  2783. template<typename T, bool is_cuda=false>
  2784. inline C10_HOST_DEVICE T hermite_polynomial_he_forward(T x, T n) {
  2785. return hermite_polynomial_he_forward(x, static_cast<int64_t>(n));
  2786. } // hermite_polynomial_he_forward(T x, T n)
  // Laguerre polynomial L_n(x), computed with
  //   L_0 = 1, L_1 = 1 - x, L_{k+1} = ((2k + 1 - x) L_k - k L_{k-1}) / (k + 1).
  // A negative n yields 0 and x == 0 yields 1. The recurrence stops early once an iterate is NaN.
  template<typename T>
  inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (std::abs(x) == T(0.0)) {
      return T(1.0);
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return T(1.0) - x;
    }
    T p = T(1.0);     // L_{k-1}
    T q = T(1.0) - x; // L_k
    for (int64_t k = 1; (k < n) && !std::isnan(q); k++) {
      const T r = (((k + k) + (T(1.0) - x)) * q - k * p) / (k + 1);
      p = q;
      q = r;
    }
    // Return q: for NaN x the loop never runs and q is the only initialized value.
    return q;
  } // laguerre_polynomial_l_forward(T x, int64_t n)
  // Overload taking the degree as T. The degree is truncated to int64_t.
  template<typename T, bool is_cuda=false>
  inline C10_HOST_DEVICE T laguerre_polynomial_l_forward(T x, T n) {
    return laguerre_polynomial_l_forward(x, static_cast<int64_t>(n));
  } // laguerre_polynomial_l_forward(T x, T n)
  // Legendre polynomial P_n(x), computed with Bonnet's recurrence
  //   P_0 = 1, P_1 = x, P_{k+1} = ((2k + 1) x P_k - k P_{k-1}) / (k + 1).
  // A negative n yields 0. The endpoints are exact: P_n(1) = 1 and P_n(-1) = (-1)^n.
  // The recurrence stops early once an iterate is NaN.
  template<typename T>
  inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (std::abs(x) == T(1.0)) {
      if (x > T(0.0) || n % 2 == 0) {
        return T(1.0);
      }
      return T(-1.0);
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return x;
    }
    T p = T(1.0); // P_{k-1}
    T q = x;      // P_k
    for (int64_t k = 1; (k < n) && !std::isnan(q); k++) {
      const T r = ((k + k + 1) * x * q - k * p) / (k + 1);
      p = q;
      q = r;
    }
    // Return q: for NaN x the loop never runs and q is the only initialized value.
    return q;
  } // legendre_polynomial_p_forward(T x, int64_t n)
  // Overload taking the degree as T. The degree is truncated to int64_t.
  template<typename T, bool is_cuda=false>
  inline C10_HOST_DEVICE T legendre_polynomial_p_forward(T x, T n) {
    return legendre_polynomial_p_forward(x, static_cast<int64_t>(n));
  } // legendre_polynomial_p_forward(T x, T n)
  // Modified Bessel function of the first kind of order zero, I0(x).
  //
  // Only |x| matters. The result is exp(|x|) times a Chebyshev series, summed with a
  // Clenshaw-style three-term recurrence and finished as 0.5 * (b0 - b2):
  //   |x| <= 8 : series A in the variable |x|/2 - 2    -> exp(|x|) * S_A
  //   |x| >  8 : series B in the variable 32/|x| - 2   -> exp(|x|) * S_B / sqrt(|x|)
  // A NaN input fails the `<= 8` test and takes the second branch.
  template<typename T>
  inline C10_HOST_DEVICE T modified_bessel_i0_forward(T x) {
    static const T A[] = {
      -4.41534164647933937950e-18,
      +3.33079451882223809783e-17,
      -2.43127984654795469359e-16,
      +1.71539128555513303061e-15,
      -1.16853328779934516808e-14,
      +7.67618549860493561688e-14,
      -4.85644678311192946090e-13,
      +2.95505266312963983461e-12,
      -1.72682629144155570723e-11,
      +9.67580903537323691224e-11,
      -5.18979560163526290666e-10,
      +2.65982372468238665035e-09,
      -1.30002500998624804212e-08,
      +6.04699502254191894932e-08,
      -2.67079385394061173391e-07,
      +1.11738753912010371815e-06,
      -4.41673835845875056359e-06,
      +1.64484480707288970893e-05,
      -5.75419501008210370398e-05,
      +1.88502885095841655729e-04,
      -5.76375574538582365885e-04,
      +1.63947561694133579842e-03,
      -4.32430999505057594430e-03,
      +1.05464603945949983183e-02,
      -2.37374148058994688156e-02,
      +4.93052842396707084878e-02,
      -9.49010970480476444210e-02,
      +1.71620901522208775349e-01,
      -3.04682672343198398683e-01,
      +6.76795274409476084995e-01,
    };
    static const T B[] = {
      -7.23318048787475395456e-18,
      -4.83050448594418207126e-18,
      +4.46562142029675999901e-17,
      +3.46122286769746109310e-17,
      -2.82762398051658348494e-16,
      -3.42548561967721913462e-16,
      +1.77256013305652638360e-15,
      +3.81168066935262242075e-15,
      -9.55484669882830764870e-15,
      -4.15056934728722208663e-14,
      +1.54008621752140982691e-14,
      +3.85277838274214270114e-13,
      +7.18012445138366623367e-13,
      -1.79417853150680611778e-12,
      -1.32158118404477131188e-11,
      -3.14991652796324136454e-11,
      +1.18891471078464383424e-11,
      +4.94060238822496958910e-10,
      +3.39623202570838634515e-09,
      +2.26666899049817806459e-08,
      +2.04891858946906374183e-07,
      +2.89137052083475648297e-06,
      +6.88975834691682398426e-05,
      +3.36911647825569408990e-03,
      +8.04490411014108831608e-01,
    };
    const T ax = std::abs(x);
    // Recurrence state: b0 is the newest value, b1 and b2 the two before it.
    T b1 = T(0.0);
    T b2 = T(0.0);
    if (ax <= T(8.0)) {
      T b0 = A[0];
      for (int i = 1; i < 30; ++i) {
        b2 = b1;
        b1 = b0;
        b0 = ((ax / T(2.0)) - T(2.0)) * b1 - b2 + A[i];
      }
      return std::exp(ax) * (T(0.5) * (b0 - b2));
    }
    T b0 = B[0];
    for (int i = 1; i < 25; ++i) {
      b2 = b1;
      b1 = b0;
      b0 = (T(32.0) / ax - T(2.0)) * b1 - b2 + B[i];
    }
    return std::exp(ax) * (T(0.5) * (b0 - b2)) / std::sqrt(ax);
  } // modified_bessel_i0_forward(T x)
  // Modified Bessel function of the first kind of order one, I1(x).
  //
  // I1 is odd. The magnitude is computed from |x| and the sign of x is reapplied at the end.
  // Chebyshev series are summed with a Clenshaw-style recurrence finished as 0.5 * (b0 - b2):
  //   |x| <= 8 : series A in |x|/2 - 2    -> S_A * |x| * exp(|x|)
  //   |x| >  8 : series B in 32/|x| - 2   -> exp(|x|) * S_B / sqrt(|x|)
  // A NaN input fails the `<= 8` test and takes the second branch.
  template<typename T>
  inline C10_HOST_DEVICE T modified_bessel_i1_forward(T x) {
    static const T A[] = {
      +2.77791411276104639959e-18,
      -2.11142121435816608115e-17,
      +1.55363195773620046921e-16,
      -1.10559694773538630805e-15,
      +7.60068429473540693410e-15,
      -5.04218550472791168711e-14,
      +3.22379336594557470981e-13,
      -1.98397439776494371520e-12,
      +1.17361862988909016308e-11,
      -6.66348972350202774223e-11,
      +3.62559028155211703701e-10,
      -1.88724975172282928790e-09,
      +9.38153738649577178388e-09,
      -4.44505912879632808065e-08,
      +2.00329475355213526229e-07,
      -8.56872026469545474066e-07,
      +3.47025130813767847674e-06,
      -1.32731636560394358279e-05,
      +4.78156510755005422638e-05,
      -1.61760815825896745588e-04,
      +5.12285956168575772895e-04,
      -1.51357245063125314899e-03,
      +4.15642294431288815669e-03,
      -1.05640848946261981558e-02,
      +2.47264490306265168283e-02,
      -5.29459812080949914269e-02,
      +1.02643658689847095384e-01,
      -1.76416518357834055153e-01,
      +2.52587186443633654823e-01,
    };
    static const T B[] = {
      +7.51729631084210481353e-18,
      +4.41434832307170791151e-18,
      -4.65030536848935832153e-17,
      -3.20952592199342395980e-17,
      +2.96262899764595013876e-16,
      +3.30820231092092828324e-16,
      -1.88035477551078244854e-15,
      -3.81440307243700780478e-15,
      +1.04202769841288027642e-14,
      +4.27244001671195135429e-14,
      -2.10154184277266431302e-14,
      -4.08355111109219731823e-13,
      -7.19855177624590851209e-13,
      +2.03562854414708950722e-12,
      +1.41258074366137813316e-11,
      +3.25260358301548823856e-11,
      -1.89749581235054123450e-11,
      -5.58974346219658380687e-10,
      -3.83538038596423702205e-09,
      -2.63146884688951950684e-08,
      -2.51223623787020892529e-07,
      -3.88256480887769039346e-06,
      -1.10588938762623716291e-04,
      -9.76109749136146840777e-03,
      +7.78576235018280120474e-01,
    };
    const T ax = std::abs(x);
    const bool negative = x < T(0.0);
    // Recurrence state: b0 is the newest value, b1 and b2 the two before it.
    T b1 = T(0.0);
    T b2 = T(0.0);
    if (ax <= T(8.0)) {
      T b0 = A[0];
      for (int i = 1; i < 29; ++i) {
        b2 = b1;
        b1 = b0;
        b0 = ((ax / T(2.0)) - T(2.0)) * b1 - b2 + A[i];
      }
      const T magnitude = T(0.5) * (b0 - b2) * ax * std::exp(ax);
      return negative ? -magnitude : magnitude;
    }
    T b0 = B[0];
    for (int i = 1; i < 25; ++i) {
      b2 = b1;
      b1 = b0;
      b0 = (T(32.0) / ax - T(2.0)) * b1 - b2 + B[i];
    }
    const T magnitude = std::exp(ax) * (T(0.5) * (b0 - b2)) / std::sqrt(ax);
    return negative ? -magnitude : magnitude;
  } // modified_bessel_i1_forward(T x)
  // Modified Bessel function of the second kind of order zero, K0(x).
  //
  //   x == 0  -> +infinity
  //   x <  0  -> quiet NaN
  //   x <= 2  -> S_A(x^2 - 2) - log(x/2) * I0(x)
  //   x >  2  -> exp(-x) * S_B(8/x - 2) / sqrt(x)
  // S_A and S_B are Chebyshev series summed with a Clenshaw-style recurrence finished as
  // 0.5 * (b0 - b2). A NaN input fails the `<= 2` test and takes the second branch.
  template<typename T>
  inline C10_HOST_DEVICE T modified_bessel_k0_forward(T x) {
    static const T A[] = {
      +1.37446543561352307156e-16,
      +4.25981614279661018399e-14,
      +1.03496952576338420167e-11,
      +1.90451637722020886025e-09,
      +2.53479107902614945675e-07,
      +2.28621210311945178607e-05,
      +1.26461541144692592338e-03,
      +3.59799365153615016266e-02,
      +3.44289899924628486886e-01,
      -5.35327393233902768720e-01,
    };
    static const T B[] = {
      +5.30043377268626276149e-18,
      -1.64758043015242134646e-17,
      +5.21039150503902756861e-17,
      -1.67823109680541210385e-16,
      +5.51205597852431940784e-16,
      -1.84859337734377901440e-15,
      +6.34007647740507060557e-15,
      -2.22751332699166985548e-14,
      +8.03289077536357521100e-14,
      -2.98009692317273043925e-13,
      +1.14034058820847496303e-12,
      -4.51459788337394416547e-12,
      +1.85594911495471785253e-11,
      -7.95748924447710747776e-11,
      +3.57739728140030116597e-10,
      -1.69753450938905987466e-09,
      +8.57403401741422608519e-09,
      -4.66048989768794782956e-08,
      +2.76681363944501510342e-07,
      -1.83175552271911948767e-06,
      +1.39498137188764993662e-05,
      -1.28495495816278026384e-04,
      +1.56988388573005337491e-03,
      -3.14481013119645005427e-02,
      +2.44030308206595545468e+00,
    };
    if (x == T(0.0)) {
      return std::numeric_limits<T>::infinity();
    }
    if (x < T(0.0)) {
      return std::numeric_limits<T>::quiet_NaN();
    }
    T p;
    T q = 0.0;
    if (x <= T(2.0)) {
      T a = A[0];
      for (uint8_t index = 1; index < 10; index++) {
        p = q;
        q = a;
        a = (x * x - T(2.0)) * q - p + A[index];
      }
      // Use T(0.5), not the double literal 0.5, so a float T is not silently promoted to double
      // math. This matches modified_bessel_k1_forward and the scaled variants.
      return T(0.5) * (a - p) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x);
    }
    T b = B[0];
    for (uint8_t index = 1; index < 25; index++) {
      p = q;
      q = b;
      b = (T(8.0) / x - T(2.0)) * q - p + B[index];
    }
    return std::exp(-x) * (T(0.5) * (b - p)) / std::sqrt(x);
  } // modified_bessel_k0_forward(T x)
  // Modified Bessel function of the second kind of order one, K1(x).
  //
  //   x == 0  -> +infinity
  //   x <  0  -> quiet NaN
  //   x <= 2  -> log(x/2) * I1(x) + S_A(x^2 - 2) / x
  //   x >  2  -> exp(-x) * S_B(8/x - 2) / sqrt(x)
  // S_A and S_B are Chebyshev series summed with a Clenshaw-style recurrence finished as
  // 0.5 * (b0 - b2). A NaN input fails every comparison and takes the last branch.
  template<typename T>
  inline C10_HOST_DEVICE T modified_bessel_k1_forward(T x) {
    static const T A[] = {
      -7.02386347938628759343e-18,
      -2.42744985051936593393e-15,
      -6.66690169419932900609e-13,
      -1.41148839263352776110e-10,
      -2.21338763073472585583e-08,
      -2.43340614156596823496e-06,
      -1.73028895751305206302e-04,
      -6.97572385963986435018e-03,
      -1.22611180822657148235e-01,
      -3.53155960776544875667e-01,
      +1.52530022733894777053e+00,
    };
    static const T B[] = {
      -5.75674448366501715755e-18,
      +1.79405087314755922667e-17,
      -5.68946255844285935196e-17,
      +1.83809354436663880070e-16,
      -6.05704724837331885336e-16,
      +2.03870316562433424052e-15,
      -7.01983709041831346144e-15,
      +2.47715442448130437068e-14,
      -8.97670518232499435011e-14,
      +3.34841966607842919884e-13,
      -1.28917396095102890680e-12,
      +5.13963967348173025100e-12,
      -2.12996783842756842877e-11,
      +9.21831518760500529508e-11,
      -4.19035475934189648750e-10,
      +2.01504975519703286596e-09,
      -1.03457624656780970260e-08,
      +5.74108412545004946722e-08,
      -3.50196060308781257119e-07,
      +2.40648494783721712015e-06,
      -1.93619797416608296024e-05,
      +1.95215518471351631108e-04,
      -2.85781685962277938680e-03,
      +1.03923736576817238437e-01,
      +2.72062619048444266945e+00,
    };
    // Non-positive inputs: +inf at zero (including -0.0), NaN for negatives.
    if (x <= T(0.0)) {
      return x == T(0.0) ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::quiet_NaN();
    }
    // Recurrence state: b0 is the newest value, b1 and b2 the two before it.
    T b1 = T(0.0);
    T b2 = T(0.0);
    if (x <= T(2.0)) {
      T b0 = A[0];
      for (int i = 1; i < 11; ++i) {
        b2 = b1;
        b1 = b0;
        b0 = (x * x - T(2.0)) * b1 - b2 + A[i];
      }
      return std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (b0 - b2) / x;
    }
    T b0 = B[0];
    for (int i = 1; i < 25; ++i) {
      b2 = b1;
      b1 = b0;
      b0 = (T(8.0) / x - T(2.0)) * b1 - b2 + B[i];
    }
    return std::exp(-x) * (T(0.5) * (b0 - b2)) / std::sqrt(x);
  } // modified_bessel_k1_forward(T x)
  // Exponentially scaled modified Bessel function of the second kind of order zero,
  // exp(x) * K0(x).
  //
  //   x == 0  -> +infinity
  //   x <  0  -> quiet NaN
  //   x <= 2  -> (S_A(x^2 - 2) - log(x/2) * I0(x)) * exp(x)
  //   x >  2  -> S_B(8/x - 2) / sqrt(x)   (the exp(-x) factor of K0 cancels the scaling)
  // S_A and S_B are Chebyshev series summed with a Clenshaw-style recurrence finished as
  // 0.5 * (b0 - b2).
  template<typename T>
  inline C10_HOST_DEVICE T scaled_modified_bessel_k0_forward(T x) {
    static const T A[] = {
      +1.37446543561352307156e-16,
      +4.25981614279661018399e-14,
      +1.03496952576338420167e-11,
      +1.90451637722020886025e-09,
      +2.53479107902614945675e-07,
      +2.28621210311945178607e-05,
      +1.26461541144692592338e-03,
      +3.59799365153615016266e-02,
      +3.44289899924628486886e-01,
      -5.35327393233902768720e-01,
    };
    static const T B[] = {
      +5.30043377268626276149e-18,
      -1.64758043015242134646e-17,
      +5.21039150503902756861e-17,
      -1.67823109680541210385e-16,
      +5.51205597852431940784e-16,
      -1.84859337734377901440e-15,
      +6.34007647740507060557e-15,
      -2.22751332699166985548e-14,
      +8.03289077536357521100e-14,
      -2.98009692317273043925e-13,
      +1.14034058820847496303e-12,
      -4.51459788337394416547e-12,
      +1.85594911495471785253e-11,
      -7.95748924447710747776e-11,
      +3.57739728140030116597e-10,
      -1.69753450938905987466e-09,
      +8.57403401741422608519e-09,
      -4.66048989768794782956e-08,
      +2.76681363944501510342e-07,
      -1.83175552271911948767e-06,
      +1.39498137188764993662e-05,
      -1.28495495816278026384e-04,
      +1.56988388573005337491e-03,
      -3.14481013119645005427e-02,
      +2.44030308206595545468e+00,
    };
    // Non-positive inputs: +inf at zero (including -0.0), NaN for negatives.
    if (x <= T(0.0)) {
      return x == T(0.0) ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::quiet_NaN();
    }
    // Recurrence state: b0 is the newest value, b1 and b2 the two before it.
    T b1 = T(0.0);
    T b2 = T(0.0);
    if (x <= T(2.0)) {
      T b0 = A[0];
      for (int i = 1; i < 10; ++i) {
        b2 = b1;
        b1 = b0;
        b0 = (x * x - T(2.0)) * b1 - b2 + A[i];
      }
      return (T(0.5) * (b0 - b2) - std::log(T(0.5) * x) * modified_bessel_i0_forward(x)) * std::exp(x);
    }
    T b0 = B[0];
    for (int i = 1; i < 25; ++i) {
      b2 = b1;
      b1 = b0;
      b0 = (T(8.0) / x - T(2.0)) * b1 - b2 + B[i];
    }
    return T(0.5) * (b0 - b2) / std::sqrt(x);
  } // T scaled_modified_bessel_k0_forward(T x)
  // Exponentially scaled modified Bessel function of the second kind of order one,
  // exp(x) * K1(x).
  //
  //   x == 0  -> +infinity
  //   x <  0  -> quiet NaN
  //   x <= 2  -> (log(x/2) * I1(x) + S_A(x^2 - 2) / x) * exp(x)
  //   x >  2  -> S_B(8/x - 2) / sqrt(x)   (the exp(-x) factor of K1 cancels the scaling)
  // S_A and S_B are Chebyshev series summed with a Clenshaw-style recurrence finished as
  // 0.5 * (b0 - b2).
  template<typename T>
  inline C10_HOST_DEVICE T scaled_modified_bessel_k1_forward(T x) {
    static const T A[] = {
      -7.02386347938628759343e-18,
      -2.42744985051936593393e-15,
      -6.66690169419932900609e-13,
      -1.41148839263352776110e-10,
      -2.21338763073472585583e-08,
      -2.43340614156596823496e-06,
      -1.73028895751305206302e-04,
      -6.97572385963986435018e-03,
      -1.22611180822657148235e-01,
      -3.53155960776544875667e-01,
      +1.52530022733894777053e+00,
    };
    static const T B[] = {
      -5.75674448366501715755e-18,
      +1.79405087314755922667e-17,
      -5.68946255844285935196e-17,
      +1.83809354436663880070e-16,
      -6.05704724837331885336e-16,
      +2.03870316562433424052e-15,
      -7.01983709041831346144e-15,
      +2.47715442448130437068e-14,
      -8.97670518232499435011e-14,
      +3.34841966607842919884e-13,
      -1.28917396095102890680e-12,
      +5.13963967348173025100e-12,
      -2.12996783842756842877e-11,
      +9.21831518760500529508e-11,
      -4.19035475934189648750e-10,
      +2.01504975519703286596e-09,
      -1.03457624656780970260e-08,
      +5.74108412545004946722e-08,
      -3.50196060308781257119e-07,
      +2.40648494783721712015e-06,
      -1.93619797416608296024e-05,
      +1.95215518471351631108e-04,
      -2.85781685962277938680e-03,
      +1.03923736576817238437e-01,
      +2.72062619048444266945e+00,
    };
    // Non-positive inputs: +inf at zero (including -0.0), NaN for negatives.
    if (x <= T(0.0)) {
      return x == T(0.0) ? std::numeric_limits<T>::infinity() : std::numeric_limits<T>::quiet_NaN();
    }
    // Recurrence state: b0 is the newest value, b1 and b2 the two before it.
    T b1 = T(0.0);
    T b2 = T(0.0);
    if (x <= T(2.0)) {
      T b0 = A[0];
      for (int i = 1; i < 11; ++i) {
        b2 = b1;
        b1 = b0;
        b0 = (x * x - T(2.0)) * b1 - b2 + A[i];
      }
      return (std::log(T(0.5) * x) * modified_bessel_i1_forward(x) + T(0.5) * (b0 - b2) / x) * std::exp(x);
    }
    T b0 = B[0];
    for (int i = 1; i < 25; ++i) {
      b2 = b1;
      b1 = b0;
      b0 = (T(8.0) / x - T(2.0)) * b1 - b2 + B[i];
    }
    return (T(0.5) * (b0 - b2) / std::sqrt(x));
  } // T scaled_modified_bessel_k1_forward(T x)
  /// Shifted Chebyshev polynomial of the first kind, T*_n(x) = T_n(2x - 1).
  ///
  /// With y = 2x - 1 (mapping x in [0, 1] onto y in [-1, 1]):
  ///   - n < 0            -> 0
  ///   - x == 1 (y == 1)  -> 1
  ///   - x == 0 (y == -1) -> (-1)^n
  ///   - n > 6, |y| < 1   -> cos(n * acos(y))
  ///   - otherwise        -> recurrence T_0 = 1, T_1 = y,
  ///                         T_k = 2y T_{k-1} - T_{k-2}, stopping early once a
  ///                         term becomes NaN.
  /// @param x evaluation point.
  /// @param n polynomial degree.
  /// @return T*_n(x); NaN when x is NaN (and n >= 2).
  template<typename T>
  inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (x == T(1.0)) {
      return T(1.0);
    }
    if (x == T(0.0)) {
      if (n % 2 == 0) {
        return T(1.0);
      }
      return T(-1.0);
    }
    const T y = x + x - T(1.0);
    if ((n > 6) && (std::abs(y) < T(1.0))) {
      return std::cos(n * std::acos(y));
    }
    if (n == 0) {
      return T(1.0);
    }
    if (n == 1) {
      return y;
    }
    const T two_y = y + y;  // loop-invariant recurrence coefficient
    T p = T(1.0);
    T q = y;
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T next = two_y * q - p;
      p = q;
      q = next;
    }
    // q always holds the most recent term. Returning it (instead of a separate
    // accumulator) keeps the result defined when x is NaN: the loop body never
    // runs in that case and q is the NaN starting term.
    return q;
  } // shifted_chebyshev_polynomial_t_forward(T x, int64_t n)
  3310. template<typename T, bool is_cuda=false>
  3311. inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_t_forward(T x, T n) {
  3312. return shifted_chebyshev_polynomial_t_forward(x, static_cast<int64_t>(n));
  3313. } // shifted_chebyshev_polynomial_t_forward(T x, T n)
  /// Shifted Chebyshev polynomial of the second kind, U*_n(x) = U_n(2x - 1).
  ///
  /// With y = 2x - 1 and theta = acos(y):
  ///   - n < 0            -> 0
  ///   - x == 1 (y == 1)  -> n + 1
  ///   - x == 0 (y == -1) -> (-1)^n * (n + 1)
  ///   - n > 6, |y| < 1   -> sin((n + 1) theta) / sin(theta)
  ///   - otherwise        -> recurrence U_0 = 1, U_1 = 2y,
  ///                         U_k = 2y U_{k-1} - U_{k-2}, stopping early once a
  ///                         term becomes NaN.
  /// @param x evaluation point.
  /// @param n polynomial degree.
  /// @return U*_n(x); NaN when x is NaN (and n >= 2).
  template<typename T>
  inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (x == T(1.0)) {
      return n + 1;
    }
    if (x == T(0.0)) {
      if (n % 2 == 0) {
        return n + 1;
      }
      return -(n + 1);
    }
    const T y = x + x - T(1.0);
    if ((n > 6) && (std::abs(y) < T(1.0))) {
      const T theta = std::acos(y);
      if (std::sin(theta) != T(0.0)) {
        return std::sin((n + 1) * theta) / std::sin(theta);
      }
      // sin(theta) == 0 would divide by zero; use the limiting form instead.
      return (n + 1) * std::cos((n + 1) * theta) / y;
    }
    if (n == 0) {
      return T(1.0);
    }
    const T two_y = y + y;  // loop-invariant recurrence coefficient
    if (n == 1) {
      return two_y;
    }
    T p = T(1.0);
    T q = two_y;
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T next = two_y * q - p;
      p = q;
      q = next;
    }
    // q always holds the most recent term; returning it keeps the result
    // defined (NaN) when x is NaN and the loop body never runs.
    return q;
  } // shifted_chebyshev_polynomial_u_forward(T x, int64_t n)
  3350. template<typename T, bool is_cuda=false>
  3351. inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_u_forward(T x, T n) {
  3352. return shifted_chebyshev_polynomial_u_forward(x, static_cast<int64_t>(n));
  3353. } // shifted_chebyshev_polynomial_u_forward(T x, T n)
  /// Shifted Chebyshev polynomial of the third kind, V*_n(x) = V_n(2x - 1).
  ///
  /// With y = 2x - 1 and theta = acos(y):
  ///   - n < 0            -> 0
  ///   - x == 1 (y == 1)  -> 1
  ///   - x == 0 (y == -1) -> (-1)^n * (2n + 1)
  ///   - n > 6, |y| < 1   -> cos((n + 1/2) theta) / cos(theta / 2)
  ///   - otherwise        -> recurrence V_0 = 1, V_1 = 2y - 1,
  ///                         V_k = 2y V_{k-1} - V_{k-2}, stopping early once a
  ///                         term becomes NaN.
  /// @param x evaluation point.
  /// @param n polynomial degree.
  /// @return V*_n(x); NaN when x is NaN (and n >= 2).
  template<typename T>
  inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (x == T(1.0)) {
      return T(1.0);
    }
    if (x == T(0.0)) {
      if (n % 2 == 0) {
        return (n + n + 1);
      }
      return -(n + n + 1);
    }
    const T y = x + x - T(1.0);
    if ((n > 6) && (std::abs(y) < T(1.0))) {
      const T theta = std::acos(y);
      if (std::sin(theta / T(2.0)) != T(1.0)) {
        return std::cos((n + T(0.5)) * theta) / std::cos(theta / T(2.0));
      }
      // sin(theta/2) == 1 means theta ~ pi (y ~ -1), where cos(theta/2)
      // vanishes; return the y == -1 value instead of dividing by zero.
      if (n % 2 == 0) {
        return n + n + 1;
      }
      return -(n + n + 1);
    }
    if (n == 0) {
      return T(1.0);
    }
    const T two_y = y + y;  // loop-invariant recurrence coefficient
    if (n == 1) {
      return two_y - T(1.0);
    }
    T p = T(1.0);
    T q = two_y - T(1.0);
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T next = two_y * q - p;
      p = q;
      q = next;
    }
    // q always holds the most recent term; returning it keeps the result
    // defined (NaN) when x is NaN and the loop body never runs.
    return q;
  } // shifted_chebyshev_polynomial_v_forward(T x, int64_t n)
  3393. template<typename T, bool is_cuda=false>
  3394. inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_v_forward(T x, T n) {
  3395. return shifted_chebyshev_polynomial_v_forward(x, static_cast<int64_t>(n));
  3396. } // shifted_chebyshev_polynomial_v_forward(T x, T n)
  /// Shifted Chebyshev polynomial of the fourth kind, W*_n(x) = W_n(2x - 1).
  ///
  /// With y = 2x - 1 and theta = acos(y):
  ///   - n < 0            -> 0
  ///   - x == 1 (y == 1)  -> 2n + 1
  ///   - x == 0 (y == -1) -> (-1)^n
  ///   - n > 4, |y| < 1   -> sin((n + 1/2) theta) / sin(theta / 2)
  ///   - otherwise        -> recurrence W_0 = 1, W_1 = 2y + 1,
  ///                         W_k = 2y W_{k-1} - W_{k-2}, stopping early once a
  ///                         term becomes NaN.
  /// @param x evaluation point.
  /// @param n polynomial degree.
  /// @return W*_n(x); NaN when x is NaN (and n >= 2).
  template<typename T>
  inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, int64_t n) {
    if (n < 0) {
      return T(0.0);
    }
    if (x == T(1.0)) {
      return n + n + 1;
    }
    if (x == T(0.0)) {
      if (n % 2 == 0) {
        return T(1.0);
      }
      return T(-1.0);
    }
    const T y = x + x - T(1.0);
    if ((n > 4) && (std::abs(y) < T(1.0))) {
      const T theta = std::acos(y);
      if (std::cos(theta / T(2.0)) != T(1.0)) {
        return std::sin((n + T(0.5)) * theta) / std::sin(theta / T(2.0));
      }
      // cos(theta/2) rounds to 1 only when theta ~ 0, i.e. y is within
      // rounding of +1, where sin(theta/2) vanishes. The limit there is
      // W_n(1) = 2n + 1 (the same value as the x == 1 case above), not the
      // y == -1 value (-1)^n.
      return n + n + 1;
    }
    if (n == 0) {
      return T(1.0);
    }
    const T two_y = y + y;  // loop-invariant recurrence coefficient
    if (n == 1) {
      return two_y + T(1.0);
    }
    T p = T(1.0);
    T q = two_y + T(1.0);
    for (int64_t k = 2; (k <= n) && !std::isnan(q); k++) {
      const T next = two_y * q - p;
      p = q;
      q = next;
    }
    // q always holds the most recent term; returning it keeps the result
    // defined (NaN) when x is NaN and the loop body never runs.
    return q;
  } // shifted_chebyshev_polynomial_w_forward(T x, int64_t n)
  3436. template<typename T, bool is_cuda=false>
  3437. inline C10_HOST_DEVICE T shifted_chebyshev_polynomial_w_forward(T x, T n) {
  3438. return shifted_chebyshev_polynomial_w_forward(x, static_cast<int64_t>(n));
  3439. } // shifted_chebyshev_polynomial_w_forward(T x, T n)
  /// Spherical Bessel function of the first kind of order zero, j0(x) = sin(x) / x.
  ///
  /// Returns 0 for +/-infinity. For |x| < 0.5 it evaluates the Maclaurin
  /// polynomial through x^12 (alternating coefficients 1/3!, 1/5!, ..., 1/13!)
  /// in Horner form in x^2, which avoids the 0/0 that sin(x) / x would produce
  /// at x == 0. NaN input falls through to sin(x) / x and yields NaN.
  template<typename T>
  inline C10_HOST_DEVICE T spherical_bessel_j0_forward(T x) {
    if (std::isinf(x)) {
      return T(0.0);
    }
    if (!(std::abs(x) < T(0.5))) {
      return std::sin(x) / x;
    }
    const T x2 = x * x;
    return T(1.0) + x2 * (T(-1.0) / T(6.0) +
                    x2 * (T(1.0) / T(120.0) +
                    x2 * (T(-1.0) / T(5040.0) +
                    x2 * (T(1.0) / T(362880.0) +
                    x2 * (T(-1.0) / T(39916800.0) +
                    x2 * (T(1.0) / T(6227020800.0)))))));
  } // T spherical_bessel_j0_forward(T x)
  3450. C10_CLANG_DIAGNOSTIC_POP()
  3451. #else
  3452. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  3453. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)