// igamma.h
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/metal/utils.h>
  4. #include <metal_math>
  5. #include <metal_stdlib>
  6. using namespace c10::metal;
  7. using namespace metal;
  namespace c10 {
  namespace metal {
  // Forward declarations of helpers used further down in this header. Their
  // definitions are not visible here (presumably provided by other c10/metal
  // headers -- confirm against the including translation unit).

  // Called by the local `lgamma` adapter.
  template <typename T>
  inline float log_gamma(const T x);

  // Called by the local `expm1` adapter.
  inline float expm1f(float a);

  // Used by the uniform asymptotic expansion.
  template <typename T>
  float erfc(T x);
  } // namespace metal
  } // namespace c10
  17. namespace {
  // Adapter that lets the ported code keep calling `::lgamma`; it simply
  // forwards to `log_gamma` and returns its float result.
  template <typename T>
  inline float lgamma(const T a) {
    const float log_gamma_of_a = log_gamma(a);
    return log_gamma_of_a;
  }
  22. inline float expm1(float a) {
  23. return expm1f(a);
  24. }
  25. // NOTE: The following code was ported directly from the CUDA implementation in
  26. // `aten/src/ATen/native/cuda/IGammaKernel.cu`
  /*
   * These implementations of the regularized incomplete gamma functions and
   * their helper functions are derived from the implementations of SciPy's
   * gammainc, Cephes's igam and igamc, and Boost's Lanczos approximations.
   * See NOTICE for the licenses.
   */
  33. // regularized lower & upper incomplete gamma
  34. template <typename scalar_t>
  35. scalar_t ratevl(
  36. scalar_t x,
  37. const scalar_t num[],
  38. int64_t M,
  39. const scalar_t denom[],
  40. int64_t N) {
  41. // evaluating rational function, i.e., the ratio of two polynomials
  42. // the coefficients for numerator are given by `num` while coeffs for
  43. // denumerator are given by `denom`
  44. using accscalar_t = opmath_t<scalar_t>;
  45. int64_t i, dir;
  46. accscalar_t y, num_ans, denom_ans;
  47. accscalar_t absx = ::fabs(x);
  48. thread const accscalar_t* p;
  49. if (absx > 1) {
  50. /* Evaluate as a polynomial in 1/x. */
  51. dir = -1;
  52. p = num + M;
  53. y = 1 / x;
  54. } else {
  55. dir = 1;
  56. p = num;
  57. y = x;
  58. }
  59. /* Evaluate the numerator */
  60. num_ans = *p;
  61. p += dir;
  62. for (i = 1; i <= M; i++) {
  63. num_ans = num_ans * y + *p;
  64. p += dir;
  65. }
  66. /* Evaluate the denominator */
  67. if (absx > 1) {
  68. p = denom + N;
  69. } else {
  70. p = denom;
  71. }
  72. denom_ans = *p;
  73. p += dir;
  74. for (i = 1; i <= N; i++) {
  75. denom_ans = denom_ans * y + *p;
  76. p += dir;
  77. }
  78. if (absx > 1) {
  79. i = N - M;
  80. return ::pow(x, static_cast<accscalar_t>(i)) * num_ans / denom_ans;
  81. } else {
  82. return num_ans / denom_ans;
  83. }
  84. }
  85. template <typename scalar_t>
  86. scalar_t lanczos_sum_expg_scaled(scalar_t x) {
  87. // lanczos approximation
  88. using accscalar_t = opmath_t<scalar_t>;
  89. const accscalar_t lanczos_sum_expg_scaled_num[13] = {
  90. 0.006061842346248906525783753964555936883222,
  91. 0.5098416655656676188125178644804694509993,
  92. 19.51992788247617482847860966235652136208,
  93. 449.9445569063168119446858607650988409623,
  94. 6955.999602515376140356310115515198987526,
  95. 75999.29304014542649875303443598909137092,
  96. 601859.6171681098786670226533699352302507,
  97. 3481712.15498064590882071018964774556468,
  98. 14605578.08768506808414169982791359218571,
  99. 43338889.32467613834773723740590533316085,
  100. 86363131.28813859145546927288977868422342,
  101. 103794043.1163445451906271053616070238554,
  102. 56906521.91347156388090791033559122686859};
  103. const accscalar_t lanczos_sum_expg_scaled_denom[13] = {
  104. 1.,
  105. 66.,
  106. 1925.,
  107. 32670.,
  108. 357423.,
  109. 2637558.,
  110. 13339535.,
  111. 45995730.,
  112. 105258076.,
  113. 150917976.,
  114. 120543840.,
  115. 39916800.,
  116. 0};
  117. return ratevl(
  118. static_cast<accscalar_t>(x),
  119. lanczos_sum_expg_scaled_num,
  120. sizeof(lanczos_sum_expg_scaled_num) /
  121. sizeof(lanczos_sum_expg_scaled_num[0]) -
  122. 1,
  123. lanczos_sum_expg_scaled_denom,
  124. sizeof(lanczos_sum_expg_scaled_denom) /
  125. sizeof(lanczos_sum_expg_scaled_denom[0]) -
  126. 1);
  127. }
  128. template <typename scalar_t>
  129. scalar_t _igam_helper_fac(scalar_t a, scalar_t x) {
  130. // compute x^a * exp(-a) / gamma(a)
  131. // corrected from (15) and (16) in [igam2] by replacing exp(x - a) with
  132. // exp(a - x).
  133. using accscalar_t = opmath_t<scalar_t>;
  134. accscalar_t ax, fac, res, num, numfac;
  135. const accscalar_t MAXLOG = 88.72283905206835;
  136. const accscalar_t EXP1 = 2.718281828459045;
  137. const accscalar_t lanczos_g = 6.024680040776729583740234375;
  138. if (::fabs(a - x) > 0.4 * ::fabs(a)) {
  139. ax = a * ::log(x) - x - ::lgamma(a);
  140. if (ax < -MAXLOG) {
  141. return 0.0;
  142. }
  143. return ::exp(ax);
  144. }
  145. fac = a + lanczos_g - 0.5;
  146. res = ::sqrt(fac / EXP1) / lanczos_sum_expg_scaled(a);
  147. if ((a < 200) && (x < 200)) {
  148. res *= ::exp(a - x) * ::pow(x / fac, a);
  149. } else {
  150. num = x - a - lanczos_g + 0.5;
  151. numfac = num / fac;
  152. res *= ::exp(a * (::log1p(numfac) - numfac) + x * (0.5 - lanczos_g) / fac);
  153. }
  154. return res;
  155. }
  156. template <typename scalar_t>
  157. scalar_t _igam_helper_series(scalar_t a, scalar_t x) {
  158. // Compute igam using DLMF 8.11.4. [igam1]
  159. using accscalar_t = opmath_t<scalar_t>;
  160. const accscalar_t MACHEP = 5.9604644775390625E-8;
  161. const int MAXITER = 2000;
  162. int i;
  163. accscalar_t ans, ax, c, r;
  164. ax = _igam_helper_fac(a, x);
  165. if (ax == 0.0) {
  166. return 0.0;
  167. }
  168. /* power series */
  169. r = a;
  170. c = 1.0;
  171. ans = 1.0;
  172. for (i = 0; i < MAXITER; i++) {
  173. r += 1.0;
  174. c *= x / r;
  175. ans += c;
  176. if (c <= MACHEP * ans) {
  177. break;
  178. }
  179. }
  180. return (ans * ax / a);
  181. }
  182. template <typename scalar_t>
  183. scalar_t _igamc_helper_series(scalar_t a, scalar_t x) {
  184. // Compute igamc using DLMF 8.7.3 [igam1]. This is related to the series in
  185. // _igam_helper_series but extra care is taken to avoid cancellation.
  186. using accscalar_t = opmath_t<scalar_t>;
  187. int n;
  188. accscalar_t fac = 1;
  189. accscalar_t sum = 0;
  190. accscalar_t term, logx;
  191. const int MAXITER = 2000;
  192. const accscalar_t MACHEP = 5.9604644775390625E-8;
  193. for (n = 1; n < MAXITER; n++) {
  194. fac *= -x / n;
  195. term = fac / (a + n);
  196. sum += term;
  197. if (::fabs(term) <= MACHEP * ::fabs(sum)) {
  198. break;
  199. }
  200. }
  201. logx = ::log(x);
  202. term = -::expm1(a * logx - ::lgamma(1 + a));
  203. return term - ::exp(a * logx - ::lgamma(a)) * sum;
  204. }
  205. template <typename scalar_t>
  206. scalar_t _igam_helper_asymptotic_series(scalar_t a, scalar_t x, bool igam) {
  207. // Compute igam/igamc using DLMF 8.12.3/8.12.4 [igam1]
  208. using accscalar_t = opmath_t<scalar_t>;
  209. const accscalar_t d[25][25] = {
  210. {-3.3333333333333333e-1, 8.3333333333333333e-2,
  211. -1.4814814814814815e-2, 1.1574074074074074e-3,
  212. 3.527336860670194e-4, -1.7875514403292181e-4,
  213. 3.9192631785224378e-5, -2.1854485106799922e-6,
  214. -1.85406221071516e-6, 8.296711340953086e-7,
  215. -1.7665952736826079e-7, 6.7078535434014986e-9,
  216. 1.0261809784240308e-8, -4.3820360184533532e-9,
  217. 9.1476995822367902e-10, -2.551419399494625e-11,
  218. -5.8307721325504251e-11, 2.4361948020667416e-11,
  219. -5.0276692801141756e-12, 1.1004392031956135e-13,
  220. 3.3717632624009854e-13, -1.3923887224181621e-13,
  221. 2.8534893807047443e-14, -5.1391118342425726e-16,
  222. -1.9752288294349443e-15},
  223. {-1.8518518518518519e-3, -3.4722222222222222e-3, 2.6455026455026455e-3,
  224. -9.9022633744855967e-4, 2.0576131687242798e-4, -4.0187757201646091e-7,
  225. -1.8098550334489978e-5, 7.6491609160811101e-6, -1.6120900894563446e-6,
  226. 4.6471278028074343e-9, 1.378633446915721e-7, -5.752545603517705e-8,
  227. 1.1951628599778147e-8, -1.7543241719747648e-11, -1.0091543710600413e-9,
  228. 4.1627929918425826e-10, -8.5639070264929806e-11, 6.0672151016047586e-14,
  229. 7.1624989648114854e-12, -2.9331866437714371e-12, 5.9966963656836887e-13,
  230. -2.1671786527323314e-16, -4.9783399723692616e-14, 2.0291628823713425e-14,
  231. -4.13125571381061e-15},
  232. {4.1335978835978836e-3, -2.6813271604938272e-3, 7.7160493827160494e-4,
  233. 2.0093878600823045e-6, -1.0736653226365161e-4, 5.2923448829120125e-5,
  234. -1.2760635188618728e-5, 3.4235787340961381e-8, 1.3721957309062933e-6,
  235. -6.298992138380055e-7, 1.4280614206064242e-7, -2.0477098421990866e-10,
  236. -1.4092529910867521e-8, 6.228974084922022e-9, -1.3670488396617113e-9,
  237. 9.4283561590146782e-13, 1.2872252400089318e-10, -5.5645956134363321e-11,
  238. 1.1975935546366981e-11, -4.1689782251838635e-15, -1.0940640427884594e-12,
  239. 4.6622399463901357e-13, -9.905105763906906e-14, 1.8931876768373515e-17,
  240. 8.8592218725911273e-15},
  241. {6.4943415637860082e-4, 2.2947209362139918e-4, -4.6918949439525571e-4,
  242. 2.6772063206283885e-4, -7.5618016718839764e-5, -2.3965051138672967e-7,
  243. 1.1082654115347302e-5, -5.6749528269915966e-6, 1.4230900732435884e-6,
  244. -2.7861080291528142e-11, -1.6958404091930277e-7, 8.0994649053880824e-8,
  245. -1.9111168485973654e-8, 2.3928620439808118e-12, 2.0620131815488798e-9,
  246. -9.4604966618551322e-10, 2.1541049775774908e-10, -1.388823336813903e-14,
  247. -2.1894761681963939e-11, 9.7909989511716851e-12, -2.1782191880180962e-12,
  248. 6.2088195734079014e-17, 2.126978363279737e-13, -9.3446887915174333e-14,
  249. 2.0453671226782849e-14},
  250. {-8.618882909167117e-4, 7.8403922172006663e-4,
  251. -2.9907248030319018e-4, -1.4638452578843418e-6,
  252. 6.6414982154651222e-5, -3.9683650471794347e-5,
  253. 1.1375726970678419e-5, 2.5074972262375328e-10,
  254. -1.6954149536558306e-6, 8.9075075322053097e-7,
  255. -2.2929348340008049e-7, 2.956794137544049e-11,
  256. 2.8865829742708784e-8, -1.4189739437803219e-8,
  257. 3.4463580499464897e-9, -2.3024517174528067e-13,
  258. -3.9409233028046405e-10, 1.8602338968504502e-10,
  259. -4.356323005056618e-11, 1.2786001016296231e-15,
  260. 4.6792750266579195e-12, -2.1492464706134829e-12,
  261. 4.9088156148096522e-13, -6.3385914848915603e-18,
  262. -5.0453320690800944e-14},
  263. {-3.3679855336635815e-4, -6.9728137583658578e-5, 2.7727532449593921e-4,
  264. -1.9932570516188848e-4, 6.7977804779372078e-5, 1.419062920643967e-7,
  265. -1.3594048189768693e-5, 8.0184702563342015e-6, -2.2914811765080952e-6,
  266. -3.252473551298454e-10, 3.4652846491085265e-7, -1.8447187191171343e-7,
  267. 4.8240967037894181e-8, -1.7989466721743515e-14, -6.3061945000135234e-9,
  268. 3.1624176287745679e-9, -7.8409242536974293e-10, 5.1926791652540407e-15,
  269. 9.3589442423067836e-11, -4.5134262161632782e-11, 1.0799129993116827e-11,
  270. -3.661886712685252e-17, -1.210902069055155e-12, 5.6807435849905643e-13,
  271. -1.3249659916340829e-13},
  272. {5.3130793646399222e-4, -5.9216643735369388e-4, 2.7087820967180448e-4,
  273. 7.9023532326603279e-7, -8.1539693675619688e-5, 5.6116827531062497e-5,
  274. -1.8329116582843376e-5, -3.0796134506033048e-9, 3.4651553688036091e-6,
  275. -2.0291327396058604e-6, 5.7887928631490037e-7, 2.338630673826657e-13,
  276. -8.8286007463304835e-8, 4.7435958880408128e-8, -1.2545415020710382e-8,
  277. 8.6496488580102925e-14, 1.6846058979264063e-9, -8.5754928235775947e-10,
  278. 2.1598224929232125e-10, -7.6132305204761539e-16, -2.6639822008536144e-11,
  279. 1.3065700536611057e-11, -3.1799163902367977e-12, 4.7109761213674315e-18,
  280. 3.6902800842763467e-13},
  281. {3.4436760689237767e-4, 5.1717909082605922e-5,
  282. -3.3493161081142236e-4, 2.812695154763237e-4,
  283. -1.0976582244684731e-4, -1.2741009095484485e-7,
  284. 2.7744451511563644e-5, -1.8263488805711333e-5,
  285. 5.7876949497350524e-6, 4.9387589339362704e-10,
  286. -1.0595367014026043e-6, 6.1667143761104075e-7,
  287. -1.7562973359060462e-7, -1.2974473287015439e-12,
  288. 2.695423606288966e-8, -1.4578352908731271e-8,
  289. 3.887645959386175e-9, -3.8810022510194121e-17,
  290. -5.3279941738772867e-10, 2.7437977643314845e-10,
  291. -6.9957960920705679e-11, 2.5899863874868481e-17,
  292. 8.8566890996696381e-12, -4.403168815871311e-12,
  293. 1.0865561947091654e-12},
  294. {-6.5262391859530942e-4, 8.3949872067208728e-4, -4.3829709854172101e-4,
  295. -6.969091458420552e-7, 1.6644846642067548e-4, -1.2783517679769219e-4,
  296. 4.6299532636913043e-5, 4.5579098679227077e-9, -1.0595271125805195e-5,
  297. 6.7833429048651666e-6, -2.1075476666258804e-6, -1.7213731432817145e-11,
  298. 3.7735877416110979e-7, -2.1867506700122867e-7, 6.2202288040189269e-8,
  299. 6.5977038267330006e-16, -9.5903864974256858e-9, 5.2132144922808078e-9,
  300. -1.3991589583935709e-9, 5.382058999060575e-16, 1.9484714275467745e-10,
  301. -1.0127287556389682e-10, 2.6077347197254926e-11, -5.0904186999932993e-18,
  302. -3.3721464474854592e-12},
  303. {-5.9676129019274625e-4, -7.2048954160200106e-5,
  304. 6.7823088376673284e-4, -6.4014752602627585e-4,
  305. 2.7750107634328704e-4, 1.8197008380465151e-7,
  306. -8.4795071170685032e-5, 6.105192082501531e-5,
  307. -2.1073920183404862e-5, -8.8585890141255994e-10,
  308. 4.5284535953805377e-6, -2.8427815022504408e-6,
  309. 8.7082341778646412e-7, 3.6886101871706965e-12,
  310. -1.5344695190702061e-7, 8.862466778790695e-8,
  311. -2.5184812301826817e-8, -1.0225912098215092e-14,
  312. 3.8969470758154777e-9, -2.1267304792235635e-9,
  313. 5.7370135528051385e-10, -1.887749850169741e-19,
  314. -8.0931538694657866e-11, 4.2382723283449199e-11,
  315. -1.1002224534207726e-11},
  316. {1.3324454494800656e-3, -1.9144384985654775e-3, 1.1089369134596637e-3,
  317. 9.932404122642299e-7, -5.0874501293093199e-4, 4.2735056665392884e-4,
  318. -1.6858853767910799e-4, -8.1301893922784998e-9, 4.5284402370562147e-5,
  319. -3.127053674781734e-5, 1.044986828530338e-5, 4.8435226265680926e-11,
  320. -2.1482565873456258e-6, 1.329369701097492e-6, -4.0295693092101029e-7,
  321. -1.7567877666323291e-13, 7.0145043163668257e-8, -4.040787734999483e-8,
  322. 1.1474026743371963e-8, 3.9642746853563325e-18, -1.7804938269892714e-9,
  323. 9.7480262548731646e-10, -2.6405338676507616e-10, 5.794875163403742e-18,
  324. 3.7647749553543836e-11},
  325. {1.579727660730835e-3, 1.6251626278391582e-4, -2.0633421035543276e-3,
  326. 2.1389686185689098e-3, -1.0108559391263003e-3, -3.9912705529919201e-7,
  327. 3.6235025084764691e-4, -2.8143901463712154e-4, 1.0449513336495887e-4,
  328. 2.1211418491830297e-9, -2.5779417251947842e-5, 1.7281818956040463e-5,
  329. -5.6413773872904282e-6, -1.1024320105776174e-11, 1.1223224418895175e-6,
  330. -6.8693396379526735e-7, 2.0653236975414887e-7, 4.6714772409838506e-14,
  331. -3.5609886164949055e-8, 2.0470855345905963e-8, -5.8091738633283358e-9,
  332. -1.332821287582869e-16, 9.0354604391335133e-10, -4.9598782517330834e-10,
  333. 1.3481607129399749e-10},
  334. {-4.0725121195140166e-3, 6.4033628338080698e-3, -4.0410161081676618e-3,
  335. -2.183732802866233e-6, 2.1740441801254639e-3, -1.9700440518418892e-3,
  336. 8.3595469747962458e-4, 1.9445447567109655e-8, -2.5779387120421696e-4,
  337. 1.9009987368139304e-4, -6.7696499937438965e-5, -1.4440629666426572e-10,
  338. 1.5712512518742269e-5, -1.0304008744776893e-5, 3.304517767401387e-6,
  339. 7.9829760242325709e-13, -6.4097794149313004e-7, 3.8894624761300056e-7,
  340. -1.1618347644948869e-7, -2.816808630596451e-15, 1.9878012911297093e-8,
  341. -1.1407719956357511e-8, 3.2355857064185555e-9, 4.1759468293455945e-20,
  342. -5.0423112718105824e-10},
  343. {-5.9475779383993003e-3, -5.4016476789260452e-4, 8.7910413550767898e-3,
  344. -9.8576315587856125e-3, 5.0134695031021538e-3, 1.2807521786221875e-6,
  345. -2.0626019342754683e-3, 1.7109128573523058e-3, -6.7695312714133799e-4,
  346. -6.9011545676562133e-9, 1.8855128143995902e-4, -1.3395215663491969e-4,
  347. 4.6263183033528039e-5, 4.0034230613321351e-11, -1.0255652921494033e-5,
  348. 6.612086372797651e-6, -2.0913022027253008e-6, -2.0951775649603837e-13,
  349. 3.9756029041993247e-7, -2.3956211978815887e-7, 7.1182883382145864e-8,
  350. 8.925574873053455e-16, -1.2101547235064676e-8, 6.9350618248334386e-9,
  351. -1.9661464453856102e-9},
  352. {1.7402027787522711e-2, -2.9527880945699121e-2, 2.0045875571402799e-2,
  353. 7.0289515966903407e-6, -1.2375421071343148e-2, 1.1976293444235254e-2,
  354. -5.4156038466518525e-3, -6.3290893396418616e-8, 1.8855118129005065e-3,
  355. -1.473473274825001e-3, 5.5515810097708387e-4, 5.2406834412550662e-10,
  356. -1.4357913535784836e-4, 9.9181293224943297e-5, -3.3460834749478311e-5,
  357. -3.5755837291098993e-12, 7.1560851960630076e-6, -4.5516802628155526e-6,
  358. 1.4236576649271475e-6, 1.8803149082089664e-14, -2.6623403898929211e-7,
  359. 1.5950642189595716e-7, -4.7187514673841102e-8, -6.5107872958755177e-17,
  360. 7.9795091026746235e-9},
  361. {3.0249124160905891e-2, 2.4817436002649977e-3, -4.9939134373457022e-2,
  362. 5.9915643009307869e-2, -3.2483207601623391e-2, -5.7212968652103441e-6,
  363. 1.5085251778569354e-2, -1.3261324005088445e-2, 5.5515262632426148e-3,
  364. 3.0263182257030016e-8, -1.7229548406756723e-3, 1.2893570099929637e-3,
  365. -4.6845138348319876e-4, -1.830259937893045e-10, 1.1449739014822654e-4,
  366. -7.7378565221244477e-5, 2.5625836246985201e-5, 1.0766165333192814e-12,
  367. -5.3246809282422621e-6, 3.349634863064464e-6, -1.0381253128684018e-6,
  368. -5.608909920621128e-15, 1.9150821930676591e-7, -1.1418365800203486e-7,
  369. 3.3654425209171788e-8},
  370. {-9.9051020880159045e-2, 1.7954011706123486e-1, -1.2989606383463778e-1,
  371. -3.1478872752284357e-5, 9.0510635276848131e-2, -9.2828824411184397e-2,
  372. 4.4412112839877808e-2, 2.7779236316835888e-7, -1.7229543805449697e-2,
  373. 1.4182925050891573e-2, -5.6214161633747336e-3, -2.39598509186381e-9,
  374. 1.6029634366079908e-3, -1.1606784674435773e-3, 4.1001337768153873e-4,
  375. 1.8365800754090661e-11, -9.5844256563655903e-5, 6.3643062337764708e-5,
  376. -2.076250624489065e-5, -1.1806020912804483e-13, 4.2131808239120649e-6,
  377. -2.6262241337012467e-6, 8.0770620494930662e-7, 6.0125912123632725e-16,
  378. -1.4729737374018841e-7},
  379. {-1.9994542198219728e-1, -1.5056113040026424e-2, 3.6470239469348489e-1,
  380. -4.6435192311733545e-1, 2.6640934719197893e-1, 3.4038266027147191e-5,
  381. -1.3784338709329624e-1, 1.276467178337056e-1, -5.6213828755200985e-2,
  382. -1.753150885483011e-7, 1.9235592956768113e-2, -1.5088821281095315e-2,
  383. 5.7401854451350123e-3, 1.0622382710310225e-9, -1.5335082692563998e-3,
  384. 1.0819320643228214e-3, -3.7372510193945659e-4, -6.6170909729031985e-12,
  385. 8.4263617380909628e-5, -5.5150706827483479e-5, 1.7769536448348069e-5,
  386. 3.8827923210205533e-14, -3.53513697488768e-6, 2.1865832130045269e-6,
  387. -6.6812849447625594e-7},
  388. {7.2438608504029431e-1, -1.3918010932653375, 1.0654143352413968,
  389. 1.876173868950258e-4, -8.2705501176152696e-1, 8.9352433347828414e-1,
  390. -4.4971003995291339e-1, -1.6107401567546652e-6, 1.9235590165271091e-1,
  391. -1.6597702160042609e-1, 6.8882222681814333e-2, 1.3910091724608687e-8,
  392. -2.146911561508663e-2, 1.6228980898865892e-2, -5.9796016172584256e-3,
  393. -1.1287469112826745e-10, 1.5167451119784857e-3, -1.0478634293553899e-3,
  394. 3.5539072889126421e-4, 8.1704322111801517e-13, -7.7773013442452395e-5,
  395. 5.0291413897007722e-5, -1.6035083867000518e-5, 1.2469354315487605e-14,
  396. 3.1369106244517615e-6},
  397. {1.6668949727276811, 1.165462765994632e-1, -3.3288393225018906,
  398. 4.4692325482864037, -2.6977693045875807, -2.600667859891061e-4,
  399. 1.5389017615694539, -1.4937962361134612, 6.8881964633233148e-1,
  400. 1.3077482004552385e-6, -2.5762963325596288e-1, 2.1097676102125449e-1,
  401. -8.3714408359219882e-2, -7.7920428881354753e-9, 2.4267923064833599e-2,
  402. -1.7813678334552311e-2, 6.3970330388900056e-3, 4.9430807090480523e-11,
  403. -1.5554602758465635e-3, 1.0561196919903214e-3, -3.5277184460472902e-4,
  404. 9.3002334645022459e-14, 7.5285855026557172e-5, -4.8186515569156351e-5,
  405. 1.5227271505597605e-5},
  406. {-6.6188298861372935, 1.3397985455142589e+1, -1.0789350606845146e+1,
  407. -1.4352254537875018e-3, 9.2333694596189809, -1.0456552819547769e+1,
  408. 5.5105526029033471, 1.2024439690716742e-5, -2.5762961164755816,
  409. 2.3207442745387179, -1.0045728797216284, -1.0207833290021914e-7,
  410. 3.3975092171169466e-1, -2.6720517450757468e-1, 1.0235252851562706e-1,
  411. 8.4329730484871625e-10, -2.7998284958442595e-2, 2.0066274144976813e-2,
  412. -7.0554368915086242e-3, 1.9402238183698188e-12, 1.6562888105449611e-3,
  413. -1.1082898580743683e-3, 3.654545161310169e-4, -5.1290032026971794e-11,
  414. -7.6340103696869031e-5},
  415. {-1.7112706061976095e+1, -1.1208044642899116, 3.7131966511885444e+1,
  416. -5.2298271025348962e+1, 3.3058589696624618e+1, 2.4791298976200222e-3,
  417. -2.061089403411526e+1, 2.088672775145582e+1, -1.0045703956517752e+1,
  418. -1.2238783449063012e-5, 4.0770134274221141, -3.473667358470195,
  419. 1.4329352617312006, 7.1359914411879712e-8, -4.4797257159115612e-1,
  420. 3.4112666080644461e-1, -1.2699786326594923e-1, -2.8953677269081528e-10,
  421. 3.3125776278259863e-2, -2.3274087021036101e-2, 8.0399993503648882e-3,
  422. -1.177805216235265e-9, -1.8321624891071668e-3, 1.2108282933588665e-3,
  423. -3.9479941246822517e-4},
  424. {7.389033153567425e+1, -1.5680141270402273e+2, 1.322177542759164e+2,
  425. 1.3692876877324546e-2, -1.2366496885920151e+2, 1.4620689391062729e+2,
  426. -8.0365587724865346e+1, -1.1259851148881298e-4, 4.0770132196179938e+1,
  427. -3.8210340013273034e+1, 1.719522294277362e+1, 9.3519707955168356e-7,
  428. -6.2716159907747034, 5.1168999071852637, -2.0319658112299095,
  429. -4.9507215582761543e-9, 5.9626397294332597e-1, -4.4220765337238094e-1,
  430. 1.6079998700166273e-1, -2.4733786203223402e-8, -4.0307574759979762e-2,
  431. 2.7849050747097869e-2, -9.4751858992054221e-3, 6.419922235909132e-6,
  432. 2.1250180774699461e-3},
  433. {2.1216837098382522e+2, 1.3107863022633868e+1, -4.9698285932871748e+2,
  434. 7.3121595266969204e+2, -4.8213821720890847e+2, -2.8817248692894889e-2,
  435. 3.2616720302947102e+2, -3.4389340280087117e+2, 1.7195193870816232e+2,
  436. 1.4038077378096158e-4, -7.52594195897599e+1, 6.651969984520934e+1,
  437. -2.8447519748152462e+1, -7.613702615875391e-7, 9.5402237105304373,
  438. -7.5175301113311376, 2.8943997568871961, -4.6612194999538201e-7,
  439. -8.0615149598794088e-1, 5.8483006570631029e-1, -2.0845408972964956e-1,
  440. 1.4765818959305817e-4, 5.1000433863753019e-2, -3.3066252141883665e-2,
  441. 1.5109265210467774e-2},
  442. {-9.8959643098322368e+2, 2.1925555360905233e+3, -1.9283586782723356e+3,
  443. -1.5925738122215253e-1, 1.9569985945919857e+3, -2.4072514765081556e+3,
  444. 1.3756149959336496e+3, 1.2920735237496668e-3, -7.525941715948055e+2,
  445. 7.3171668742208716e+2, -3.4137023466220065e+2, -9.9857390260608043e-6,
  446. 1.3356313181291573e+2, -1.1276295161252794e+2, 4.6310396098204458e+1,
  447. -7.9237387133614756e-6, -1.4510726927018646e+1, 1.1111771248100563e+1,
  448. -4.1690817945270892, 3.1008219800117808e-3, 1.1220095449981468,
  449. -7.6052379926149916e-1, 3.6262236505085254e-1, 2.216867741940747e-1,
  450. 4.8683443692930507e-1}};
  451. int k, n, sgn;
  452. int maxpow = 0;
  453. const accscalar_t MACHEP = 5.9604644775390625E-8;
  454. accscalar_t lambda = x / a;
  455. accscalar_t sigma = (x - a) / a;
  456. accscalar_t eta, res, ck, ckterm, term, absterm;
  457. accscalar_t absoldterm = INFINITY;
  458. accscalar_t etapow[25] = {1};
  459. accscalar_t sum = 0;
  460. accscalar_t afac = 1;
  461. if (igam) {
  462. sgn = -1;
  463. } else {
  464. sgn = 1;
  465. }
  466. if (lambda > 1) {
  467. eta = ::sqrt(-2 * (::log1p(sigma) - sigma));
  468. } else if (lambda < 1) {
  469. eta = -::sqrt(-2 * (::log1p(sigma) - sigma));
  470. } else {
  471. eta = 0;
  472. }
  473. res = 0.5 * ::erfc(sgn * eta * ::sqrt(a / 2));
  474. for (k = 0; k < 25; k++) {
  475. ck = d[k][0];
  476. for (n = 1; n < 25; n++) {
  477. if (n > maxpow) {
  478. etapow[n] = eta * etapow[n - 1];
  479. maxpow += 1;
  480. }
  481. ckterm = d[k][n] * etapow[n];
  482. ck += ckterm;
  483. if (::fabs(ckterm) < MACHEP * ::fabs(ck)) {
  484. break;
  485. }
  486. }
  487. term = ck * afac;
  488. absterm = ::fabs(term);
  489. if (absterm > absoldterm) {
  490. break;
  491. }
  492. sum += term;
  493. if (absterm < MACHEP * ::fabs(sum)) {
  494. break;
  495. }
  496. absoldterm = absterm;
  497. afac /= a;
  498. }
  499. res += sgn * ::exp(-0.5 * a * eta * eta) * sum / ::sqrt(2 * 3.1415926535 * a);
  500. return res;
  501. }
  502. template <typename scalar_t>
  503. scalar_t _igamc_helper_continued_fraction(scalar_t a, scalar_t x) {
  504. // Compute igamc using DLMF 8.9.2. [igam1]
  505. using accscalar_t = opmath_t<scalar_t>;
  506. int i;
  507. accscalar_t ans, ax, c, yc, r, t, y, z;
  508. accscalar_t pk, pkm1, pkm2, qk, qkm1, qkm2;
  509. const int MAXITER = 2000;
  510. const accscalar_t MACHEP = 5.9604644775390625E-8;
  511. const accscalar_t BIG = 16777216.;
  512. const accscalar_t BIGINV = 5.9604644775390625E-8;
  513. ax = _igam_helper_fac(a, x);
  514. if (ax == 0.0) {
  515. return 0.0;
  516. }
  517. /* continued fraction */
  518. y = 1.0 - a;
  519. z = x + y + 1.0;
  520. c = 0.0;
  521. pkm2 = 1.0;
  522. qkm2 = x;
  523. pkm1 = x + 1.0;
  524. qkm1 = z * x;
  525. ans = pkm1 / qkm1;
  526. for (i = 0; i < MAXITER; i++) {
  527. c += 1.0;
  528. y += 1.0;
  529. z += 2.0;
  530. yc = y * c;
  531. pk = pkm1 * z - pkm2 * yc;
  532. qk = qkm1 * z - qkm2 * yc;
  533. if (qk != 0) {
  534. r = pk / qk;
  535. t = ::fabs((ans - r) / r);
  536. ans = r;
  537. } else {
  538. t = 1.0;
  539. }
  540. pkm2 = pkm1;
  541. pkm1 = pk;
  542. qkm2 = qkm1;
  543. qkm1 = qk;
  544. if (::fabs(pk) > BIG) {
  545. pkm2 *= BIGINV;
  546. pkm1 *= BIGINV;
  547. qkm2 *= BIGINV;
  548. qkm1 *= BIGINV;
  549. }
  550. if (t <= MACHEP) {
  551. break;
  552. }
  553. }
  554. return ans * ax;
  555. }
  556. template <typename scalar_t>
  557. scalar_t calc_igammac(scalar_t a, scalar_t x) {
  558. /* the calculation of the regularized upper incomplete gamma function
  559. * is done differently based on the values of a and x:
  560. * - if x and/or a is at the boundary of defined region, then assign the
  561. * result at the boundary
  562. * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
  563. * Large Parameter (see DLMF 8.12.4 [igam1])
  564. * - if x > 1.1 and x < a, using the subtraction from the regularized lower
  565. * incomplete gamma
  566. * - otherwise, calculate the series from [igam2] eq (5)
  567. */
  568. using accscalar_t = opmath_t<scalar_t>;
  569. accscalar_t absxma_a;
  570. const accscalar_t SMALL = 20.0;
  571. const accscalar_t LARGE = 200.0;
  572. const accscalar_t SMALLRATIO = 0.3;
  573. const accscalar_t LARGERATIO = 4.5;
  574. if ((x < 0) || (a < 0)) {
  575. // out of defined-region of the function
  576. return NAN;
  577. } else if (a == 0) {
  578. if (x > 0) {
  579. return 0.0;
  580. } else {
  581. return NAN;
  582. }
  583. } else if (x == 0) {
  584. return 1.0;
  585. } else if (isinf(a)) {
  586. if (isinf(x)) {
  587. return NAN;
  588. }
  589. return 1.0;
  590. } else if (isinf(x)) {
  591. return 0.0;
  592. }
  593. absxma_a = ::fabs(x - a) / a;
  594. if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
  595. return _igam_helper_asymptotic_series(a, x, 0);
  596. } else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) {
  597. return _igam_helper_asymptotic_series(a, x, 0);
  598. }
  599. if (x > 1.1) {
  600. if (x < a) {
  601. return 1.0 - _igam_helper_series(a, x);
  602. } else {
  603. return _igamc_helper_continued_fraction(a, x);
  604. }
  605. } else if (x <= 0.5) {
  606. if (-0.4 / ::log(x) < a) {
  607. return 1.0 - _igam_helper_series(a, x);
  608. } else {
  609. return _igamc_helper_series(a, x);
  610. }
  611. } else {
  612. if (x * 1.1 < a) {
  613. return 1.0 - _igam_helper_series(a, x);
  614. } else {
  615. return _igamc_helper_series(a, x);
  616. }
  617. }
  618. }
  619. template <typename scalar_t>
  620. scalar_t calc_igamma(scalar_t a, scalar_t x) {
  621. /* the calculation of the regularized lower incomplete gamma function
  622. * is done differently based on the values of a and x:
  623. * - if x and/or a is at the boundary of defined region, then assign the
  624. * result at the boundary
  625. * - if a is large and a ~ x, then using Uniform Asymptotic Expansions for
  626. * Large Parameter (see DLMF 8.12.3 [igam1])
  627. * - if x > 1 and x > a, using the subtraction from the regularized upper
  628. * incomplete gamma
  629. * - otherwise, calculate the series from [igam2] eq (4)
  630. */
  631. using accscalar_t = opmath_t<scalar_t>;
  632. accscalar_t absxma_a;
  633. const accscalar_t SMALL = 20.0;
  634. const accscalar_t LARGE = 200.0;
  635. const accscalar_t SMALLRATIO = 0.3;
  636. const accscalar_t LARGERATIO = 4.5;
  637. // boundary values following SciPy
  638. if ((x < 0) || (a < 0)) {
  639. // out of defined-region of the function
  640. return NAN;
  641. } else if (a == 0) {
  642. if (x > 0) {
  643. return 1.0;
  644. } else {
  645. return NAN;
  646. }
  647. } else if (x == 0) {
  648. return 0.0; // zero integration limit
  649. } else if (isinf(a)) {
  650. if (isinf(x)) {
  651. return NAN;
  652. }
  653. return 0.0;
  654. } else if (isinf(x)) {
  655. return 1.0;
  656. }
  657. /* Asymptotic regime where a ~ x. */
  658. absxma_a = ::fabs(x - a) / a;
  659. if ((a > SMALL) && (a < LARGE) && (absxma_a < SMALLRATIO)) {
  660. return _igam_helper_asymptotic_series(a, x, 1);
  661. } else if ((a > LARGE) && (absxma_a < LARGERATIO / ::sqrt(a))) {
  662. return _igam_helper_asymptotic_series(a, x, 1);
  663. }
  664. if ((x > 1.0) && (x > a)) {
  665. return 1.0 - calc_igammac(a, x);
  666. }
  667. return _igam_helper_series(a, x);
  668. }
  669. } // namespace
  670. // end of regularized lower & upper incomplete gamma
  671. namespace c10 {
  672. namespace metal {
  // Public entry point for the regularized lower incomplete gamma function;
  // forwards (a, b) to calc_igamma.
  template <typename T>
  inline T igamma(T a, T b) {
    const T lower = calc_igamma(a, b);
    return lower;
  }
  // Public entry point for the regularized upper incomplete gamma function;
  // forwards (a, b) to calc_igammac.
  template <typename T>
  inline T igammac(T a, T b) {
    const T upper = calc_igammac(a, b);
    return upper;
  }
  681. } // namespace metal
  682. } // namespace c10
  683. #else
  684. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  685. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)