__init__.py 31 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459
  1. import torch
  2. from torch._C import _add_docstr, _special # type: ignore[attr-defined]
  3. from torch._torch_docs import common_args, multi_dim_common
  4. __all__ = [
  5. "airy_ai",
  6. "bessel_j0",
  7. "bessel_j1",
  8. "bessel_y0",
  9. "bessel_y1",
  10. "chebyshev_polynomial_t",
  11. "chebyshev_polynomial_u",
  12. "chebyshev_polynomial_v",
  13. "chebyshev_polynomial_w",
  14. "digamma",
  15. "entr",
  16. "erf",
  17. "erfc",
  18. "erfcx",
  19. "erfinv",
  20. "exp2",
  21. "expit",
  22. "expm1",
  23. "gammainc",
  24. "gammaincc",
  25. "gammaln",
  26. "hermite_polynomial_h",
  27. "hermite_polynomial_he",
  28. "i0",
  29. "i0e",
  30. "i1",
  31. "i1e",
  32. "laguerre_polynomial_l",
  33. "legendre_polynomial_p",
  34. "log1p",
  35. "log_ndtr",
  36. "log_softmax",
  37. "logit",
  38. "logsumexp",
  39. "modified_bessel_i0",
  40. "modified_bessel_i1",
  41. "modified_bessel_k0",
  42. "modified_bessel_k1",
  43. "multigammaln",
  44. "ndtr",
  45. "ndtri",
  46. "polygamma",
  47. "psi",
  48. "round",
  49. "shifted_chebyshev_polynomial_t",
  50. "shifted_chebyshev_polynomial_u",
  51. "shifted_chebyshev_polynomial_v",
  52. "shifted_chebyshev_polynomial_w",
  53. "scaled_modified_bessel_k0",
  54. "scaled_modified_bessel_k1",
  55. "sinc",
  56. "softmax",
  57. "spherical_bessel_j0",
  58. "xlog1py",
  59. "xlogy",
  60. "zeta",
  61. ]
  62. Tensor = torch.Tensor
  63. entr = _add_docstr(
  64. _special.special_entr,
  65. r"""
  66. entr(input, *, out=None) -> Tensor
  67. Computes the entropy on :attr:`input` (as defined below), elementwise.
  68. .. math::
  69. \begin{align}
  70. \text{entr(x)} = \begin{cases}
  71. -x * \ln(x) & x > 0 \\
  72. 0 & x = 0.0 \\
  73. -\infty & x < 0
  74. \end{cases}
  75. \end{align}
  76. """
  77. + """
  78. Args:
  79. input (Tensor): the input tensor.
  80. Keyword args:
  81. out (Tensor, optional): the output tensor.
  82. Example::
  83. >>> a = torch.arange(-0.5, 1, 0.5)
  84. >>> a
  85. tensor([-0.5000, 0.0000, 0.5000])
  86. >>> torch.special.entr(a)
  87. tensor([ -inf, 0.0000, 0.3466])
  88. """,
  89. )
  90. psi = _add_docstr(
  91. _special.special_psi,
  92. r"""
  93. psi(input, *, out=None) -> Tensor
  94. Alias for :func:`torch.special.digamma`.
  95. """,
  96. )
  97. digamma = _add_docstr(
  98. _special.special_digamma,
  99. r"""
  100. digamma(input, *, out=None) -> Tensor
  101. Computes the logarithmic derivative of the gamma function on `input`.
  102. .. math::
  103. \digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)}
  104. """
  105. + r"""
  106. Args:
  107. input (Tensor): the tensor to compute the digamma function on
  108. Keyword args:
  109. {out}
  110. .. note:: This function is similar to SciPy's `scipy.special.digamma`.
  111. .. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`.
  112. Previously it returned `NaN` for `0`.
  113. Example::
  114. >>> a = torch.tensor([1, 0.5])
  115. >>> torch.special.digamma(a)
  116. tensor([-0.5772, -1.9635])
  117. """.format(**common_args),
  118. )
  119. gammaln = _add_docstr(
  120. _special.special_gammaln,
  121. r"""
  122. gammaln(input, *, out=None) -> Tensor
  123. Computes the natural logarithm of the absolute value of the gamma function on :attr:`input`.
  124. .. math::
  125. \text{out}_{i} = \ln \Gamma(|\text{input}_{i}|)
  126. """
  127. + """
  128. Args:
  129. {input}
  130. Keyword args:
  131. {out}
  132. Example::
  133. >>> a = torch.arange(0.5, 2, 0.5)
  134. >>> torch.special.gammaln(a)
  135. tensor([ 0.5724, 0.0000, -0.1208])
  136. """.format(**common_args),
  137. )
  138. polygamma = _add_docstr(
  139. _special.special_polygamma,
  140. r"""
  141. polygamma(n, input, *, out=None) -> Tensor
  142. Computes the :math:`n^{th}` derivative of the digamma function on :attr:`input`.
  143. :math:`n \geq 0` is called the order of the polygamma function.
  144. .. math::
  145. \psi^{(n)}(x) = \frac{d^{(n)}}{dx^{(n)}} \psi(x)
  146. .. note::
  147. This function is implemented only for nonnegative integers :math:`n \geq 0`.
  148. """
  149. + """
  150. Args:
  151. n (int): the order of the polygamma function
  152. {input}
  153. Keyword args:
  154. {out}
  155. Example::
  156. >>> a = torch.tensor([1, 0.5])
  157. >>> torch.special.polygamma(1, a)
  158. tensor([1.64493, 4.9348])
  159. >>> torch.special.polygamma(2, a)
  160. tensor([ -2.4041, -16.8288])
  161. >>> torch.special.polygamma(3, a)
  162. tensor([ 6.4939, 97.4091])
  163. >>> torch.special.polygamma(4, a)
  164. tensor([ -24.8863, -771.4742])
  165. """.format(**common_args),
  166. )
  167. erf = _add_docstr(
  168. _special.special_erf,
  169. r"""
  170. erf(input, *, out=None) -> Tensor
  171. Computes the error function of :attr:`input`. The error function is defined as follows:
  172. .. math::
  173. \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt
  174. """
  175. + r"""
  176. Args:
  177. {input}
  178. Keyword args:
  179. {out}
  180. Example::
  181. >>> torch.special.erf(torch.tensor([0, -1., 10.]))
  182. tensor([ 0.0000, -0.8427, 1.0000])
  183. """.format(**common_args),
  184. )
  185. erfc = _add_docstr(
  186. _special.special_erfc,
  187. r"""
  188. erfc(input, *, out=None) -> Tensor
  189. Computes the complementary error function of :attr:`input`.
  190. The complementary error function is defined as follows:
  191. .. math::
  192. \mathrm{erfc}(x) = 1 - \frac{2}{\sqrt{\pi}} \int_{0}^{x} e^{-t^2} dt
  193. """
  194. + r"""
  195. Args:
  196. {input}
  197. Keyword args:
  198. {out}
  199. Example::
  200. >>> torch.special.erfc(torch.tensor([0, -1., 10.]))
  201. tensor([ 1.0000, 1.8427, 0.0000])
  202. """.format(**common_args),
  203. )
  204. erfcx = _add_docstr(
  205. _special.special_erfcx,
  206. r"""
  207. erfcx(input, *, out=None) -> Tensor
  208. Computes the scaled complementary error function for each element of :attr:`input`.
  209. The scaled complementary error function is defined as follows:
  210. .. math::
  211. \mathrm{erfcx}(x) = e^{x^2} \mathrm{erfc}(x)
  212. """
  213. + r"""
  214. """
  215. + r"""
  216. Args:
  217. {input}
  218. Keyword args:
  219. {out}
  220. Example::
  221. >>> torch.special.erfcx(torch.tensor([0, -1., 10.]))
  222. tensor([ 1.0000, 5.0090, 0.0561])
  223. """.format(**common_args),
  224. )
  225. erfinv = _add_docstr(
  226. _special.special_erfinv,
  227. r"""
  228. erfinv(input, *, out=None) -> Tensor
  229. Computes the inverse error function of :attr:`input`.
  230. The inverse error function is defined in the range :math:`(-1, 1)` as:
  231. .. math::
  232. \mathrm{erfinv}(\mathrm{erf}(x)) = x
  233. """
  234. + r"""
  235. Args:
  236. {input}
  237. Keyword args:
  238. {out}
  239. Example::
  240. >>> torch.special.erfinv(torch.tensor([0, 0.5, -1.]))
  241. tensor([ 0.0000, 0.4769, -inf])
  242. """.format(**common_args),
  243. )
  244. logit = _add_docstr(
  245. _special.special_logit,
  246. r"""
  247. logit(input, eps=None, *, out=None) -> Tensor
  248. Returns a new tensor with the logit of the elements of :attr:`input`.
  249. :attr:`input` is clamped to [eps, 1 - eps] when eps is not None.
  250. When eps is None and :attr:`input` < 0 or :attr:`input` > 1, the function will yields NaN.
  251. .. math::
  252. \begin{align}
  253. y_{i} &= \ln(\frac{z_{i}}{1 - z_{i}}) \\
  254. z_{i} &= \begin{cases}
  255. x_{i} & \text{if eps is None} \\
  256. \text{eps} & \text{if } x_{i} < \text{eps} \\
  257. x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
  258. 1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
  259. \end{cases}
  260. \end{align}
  261. """
  262. + r"""
  263. Args:
  264. {input}
  265. eps (float, optional): the epsilon for input clamp bound. Default: ``None``
  266. Keyword args:
  267. {out}
  268. Example::
  269. >>> a = torch.rand(5)
  270. >>> a
  271. tensor([0.2796, 0.9331, 0.6486, 0.1523, 0.6516])
  272. >>> torch.special.logit(a, eps=1e-6)
  273. tensor([-0.9466, 2.6352, 0.6131, -1.7169, 0.6261])
  274. """.format(**common_args),
  275. )
  276. logsumexp = _add_docstr(
  277. _special.special_logsumexp,
  278. r"""
  279. logsumexp(input, dim, keepdim=False, *, out=None)
  280. Alias for :func:`torch.logsumexp`.
  281. """.format(**multi_dim_common),
  282. )
  283. expit = _add_docstr(
  284. _special.special_expit,
  285. r"""
  286. expit(input, *, out=None) -> Tensor
  287. Computes the expit (also known as the logistic sigmoid function) of the elements of :attr:`input`.
  288. .. math::
  289. \text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}}
  290. """
  291. + r"""
  292. Args:
  293. {input}
  294. Keyword args:
  295. {out}
  296. Example::
  297. >>> t = torch.randn(4)
  298. >>> t
  299. tensor([ 0.9213, 1.0887, -0.8858, -1.7683])
  300. >>> torch.special.expit(t)
  301. tensor([ 0.7153, 0.7481, 0.2920, 0.1458])
  302. """.format(**common_args),
  303. )
  304. exp2 = _add_docstr(
  305. _special.special_exp2,
  306. r"""
  307. exp2(input, *, out=None) -> Tensor
  308. Computes the base two exponential function of :attr:`input`.
  309. .. math::
  310. y_{i} = 2^{x_{i}}
  311. """
  312. + r"""
  313. Args:
  314. {input}
  315. Keyword args:
  316. {out}
  317. Example::
  318. >>> torch.special.exp2(torch.tensor([0, math.log2(2.), 3, 4]))
  319. tensor([ 1., 2., 8., 16.])
  320. """.format(**common_args),
  321. )
  322. expm1 = _add_docstr(
  323. _special.special_expm1,
  324. r"""
  325. expm1(input, *, out=None) -> Tensor
  326. Computes the exponential of the elements minus 1
  327. of :attr:`input`.
  328. .. math::
  329. y_{i} = e^{x_{i}} - 1
  330. .. note:: This function provides greater precision than exp(x) - 1 for small values of x.
  331. """
  332. + r"""
  333. Args:
  334. {input}
  335. Keyword args:
  336. {out}
  337. Example::
  338. >>> torch.special.expm1(torch.tensor([0, math.log(2.)]))
  339. tensor([ 0., 1.])
  340. """.format(**common_args),
  341. )
  342. xlog1py = _add_docstr(
  343. _special.special_xlog1py,
  344. r"""
  345. xlog1py(input, other, *, out=None) -> Tensor
  346. Computes ``input * log1p(other)`` with the following cases.
  347. .. math::
  348. \text{out}_{i} = \begin{cases}
  349. \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
  350. 0 & \text{if } \text{input}_{i} = 0.0 \text{ and } \text{other}_{i} != \text{NaN} \\
  351. \text{input}_{i} * \text{log1p}(\text{other}_{i})& \text{otherwise}
  352. \end{cases}
  353. Similar to SciPy's `scipy.special.xlog1py`.
  354. """
  355. + r"""
  356. Args:
  357. input (Number or Tensor) : Multiplier
  358. other (Number or Tensor) : Argument
  359. .. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
  360. Keyword args:
  361. {out}
  362. Example::
  363. >>> x = torch.zeros(5,)
  364. >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
  365. >>> torch.special.xlog1py(x, y)
  366. tensor([0., 0., 0., 0., nan])
  367. >>> x = torch.tensor([1, 2, 3])
  368. >>> y = torch.tensor([3, 2, 1])
  369. >>> torch.special.xlog1py(x, y)
  370. tensor([1.3863, 2.1972, 2.0794])
  371. >>> torch.special.xlog1py(x, 4)
  372. tensor([1.6094, 3.2189, 4.8283])
  373. >>> torch.special.xlog1py(2, y)
  374. tensor([2.7726, 2.1972, 1.3863])
  375. """.format(**common_args),
  376. )
  377. xlogy = _add_docstr(
  378. _special.special_xlogy,
  379. r"""
  380. xlogy(input, other, *, out=None) -> Tensor
  381. Computes ``input * log(other)`` with the following cases.
  382. .. math::
  383. \text{out}_{i} = \begin{cases}
  384. \text{NaN} & \text{if } \text{other}_{i} = \text{NaN} \\
  385. 0 & \text{if } \text{input}_{i} = 0.0 \\
  386. \text{input}_{i} * \log{(\text{other}_{i})} & \text{otherwise}
  387. \end{cases}
  388. Similar to SciPy's `scipy.special.xlogy`.
  389. """
  390. + r"""
  391. Args:
  392. input (Number or Tensor) : Multiplier
  393. other (Number or Tensor) : Argument
  394. .. note:: At least one of :attr:`input` or :attr:`other` must be a tensor.
  395. Keyword args:
  396. {out}
  397. Example::
  398. >>> x = torch.zeros(5,)
  399. >>> y = torch.tensor([-1, 0, 1, float('inf'), float('nan')])
  400. >>> torch.special.xlogy(x, y)
  401. tensor([0., 0., 0., 0., nan])
  402. >>> x = torch.tensor([1, 2, 3])
  403. >>> y = torch.tensor([3, 2, 1])
  404. >>> torch.special.xlogy(x, y)
  405. tensor([1.0986, 1.3863, 0.0000])
  406. >>> torch.special.xlogy(x, 4)
  407. tensor([1.3863, 2.7726, 4.1589])
  408. >>> torch.special.xlogy(2, y)
  409. tensor([2.1972, 1.3863, 0.0000])
  410. """.format(**common_args),
  411. )
  412. i0 = _add_docstr(
  413. _special.special_i0,
  414. r"""
  415. i0(input, *, out=None) -> Tensor
  416. Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`.
  417. .. math::
  418. \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2}
  419. """
  420. + r"""
  421. Args:
  422. input (Tensor): the input tensor
  423. Keyword args:
  424. {out}
  425. Example::
  426. >>> torch.i0(torch.arange(5, dtype=torch.float32))
  427. tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019])
  428. """.format(**common_args),
  429. )
  430. i0e = _add_docstr(
  431. _special.special_i0e,
  432. r"""
  433. i0e(input, *, out=None) -> Tensor
  434. Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below)
  435. for each element of :attr:`input`.
  436. .. math::
  437. \text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2}
  438. """
  439. + r"""
  440. Args:
  441. {input}
  442. Keyword args:
  443. {out}
  444. Example::
  445. >>> torch.special.i0e(torch.arange(5, dtype=torch.float32))
  446. tensor([1.0000, 0.4658, 0.3085, 0.2430, 0.2070])
  447. """.format(**common_args),
  448. )
  449. i1 = _add_docstr(
  450. _special.special_i1,
  451. r"""
  452. i1(input, *, out=None) -> Tensor
  453. Computes the first order modified Bessel function of the first kind (as defined below)
  454. for each element of :attr:`input`.
  455. .. math::
  456. \text{out}_{i} = \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!}
  457. """
  458. + r"""
  459. Args:
  460. {input}
  461. Keyword args:
  462. {out}
  463. Example::
  464. >>> torch.special.i1(torch.arange(5, dtype=torch.float32))
  465. tensor([0.0000, 0.5652, 1.5906, 3.9534, 9.7595])
  466. """.format(**common_args),
  467. )
  468. i1e = _add_docstr(
  469. _special.special_i1e,
  470. r"""
  471. i1e(input, *, out=None) -> Tensor
  472. Computes the exponentially scaled first order modified Bessel function of the first kind (as defined below)
  473. for each element of :attr:`input`.
  474. .. math::
  475. \text{out}_{i} = \exp(-|x|) * i1(x) =
  476. \exp(-|x|) * \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!}
  477. """
  478. + r"""
  479. Args:
  480. {input}
  481. Keyword args:
  482. {out}
  483. Example::
  484. >>> torch.special.i1e(torch.arange(5, dtype=torch.float32))
  485. tensor([0.0000, 0.2079, 0.2153, 0.1968, 0.1788])
  486. """.format(**common_args),
  487. )
  488. ndtr = _add_docstr(
  489. _special.special_ndtr,
  490. r"""
  491. ndtr(input, *, out=None) -> Tensor
  492. Computes the area under the standard Gaussian probability density function,
  493. integrated from minus infinity to :attr:`input`, elementwise.
  494. .. math::
  495. \text{ndtr}(x) = \frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt
  496. """
  497. + r"""
  498. Args:
  499. {input}
  500. Keyword args:
  501. {out}
  502. Example::
  503. >>> torch.special.ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
  504. tensor([0.0013, 0.0228, 0.1587, 0.5000, 0.8413, 0.9772, 0.9987])
  505. """.format(**common_args),
  506. )
  507. ndtri = _add_docstr(
  508. _special.special_ndtri,
  509. r"""
  510. ndtri(input, *, out=None) -> Tensor
  511. Computes the argument, x, for which the area under the Gaussian probability density function
  512. (integrated from minus infinity to x) is equal to :attr:`input`, elementwise.
  513. .. math::
  514. \text{ndtri}(p) = \sqrt{2}\text{erf}^{-1}(2p - 1)
  515. .. note::
  516. Also known as quantile function for Normal Distribution.
  517. """
  518. + r"""
  519. Args:
  520. {input}
  521. Keyword args:
  522. {out}
  523. Example::
  524. >>> torch.special.ndtri(torch.tensor([0, 0.25, 0.5, 0.75, 1]))
  525. tensor([ -inf, -0.6745, 0.0000, 0.6745, inf])
  526. """.format(**common_args),
  527. )
  528. log_ndtr = _add_docstr(
  529. _special.special_log_ndtr,
  530. r"""
  531. log_ndtr(input, *, out=None) -> Tensor
  532. Computes the log of the area under the standard Gaussian probability density function,
  533. integrated from minus infinity to :attr:`input`, elementwise.
  534. .. math::
  535. \text{log\_ndtr}(x) = \log\left(\frac{1}{\sqrt{2 \pi}}\int_{-\infty}^{x} e^{-\frac{1}{2}t^2} dt \right)
  536. """
  537. + r"""
  538. Args:
  539. {input}
  540. Keyword args:
  541. {out}
  542. Example::
  543. >>> torch.special.log_ndtr(torch.tensor([-3., -2, -1, 0, 1, 2, 3]))
  544. tensor([-6.6077 -3.7832 -1.841 -0.6931 -0.1728 -0.023 -0.0014])
  545. """.format(**common_args),
  546. )
  547. log1p = _add_docstr(
  548. _special.special_log1p,
  549. r"""
  550. log1p(input, *, out=None) -> Tensor
  551. Alias for :func:`torch.log1p`.
  552. """,
  553. )
  554. sinc = _add_docstr(
  555. _special.special_sinc,
  556. r"""
  557. sinc(input, *, out=None) -> Tensor
  558. Computes the normalized sinc of :attr:`input.`
  559. .. math::
  560. \text{out}_{i} =
  561. \begin{cases}
  562. 1, & \text{if}\ \text{input}_{i}=0 \\
  563. \sin(\pi \text{input}_{i}) / (\pi \text{input}_{i}), & \text{otherwise}
  564. \end{cases}
  565. """
  566. + r"""
  567. Args:
  568. {input}
  569. Keyword args:
  570. {out}
  571. Example::
  572. >>> t = torch.randn(4)
  573. >>> t
  574. tensor([ 0.2252, -0.2948, 1.0267, -1.1566])
  575. >>> torch.special.sinc(t)
  576. tensor([ 0.9186, 0.8631, -0.0259, -0.1300])
  577. """.format(**common_args),
  578. )
  579. round = _add_docstr(
  580. _special.special_round,
  581. r"""
  582. round(input, *, out=None) -> Tensor
  583. Alias for :func:`torch.round`.
  584. """,
  585. )
  586. softmax = _add_docstr(
  587. _special.special_softmax,
  588. r"""
  589. softmax(input, dim, *, dtype=None) -> Tensor
  590. Computes the softmax function.
  591. Softmax is defined as:
  592. :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`
  593. It is applied to all slices along dim, and will re-scale them so that the elements
  594. lie in the range `[0, 1]` and sum to 1.
  595. Args:
  596. input (Tensor): input
  597. dim (int): A dimension along which softmax will be computed.
  598. dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
  599. If specified, the input tensor is cast to :attr:`dtype` before the operation
  600. is performed. This is useful for preventing data type overflows. Default: None.
  601. Examples::
  602. >>> t = torch.ones(2, 2)
  603. >>> torch.special.softmax(t, 0)
  604. tensor([[0.5000, 0.5000],
  605. [0.5000, 0.5000]])
  606. """,
  607. )
  608. log_softmax = _add_docstr(
  609. _special.special_log_softmax,
  610. r"""
  611. log_softmax(input, dim, *, dtype=None) -> Tensor
  612. Computes softmax followed by a logarithm.
  613. While mathematically equivalent to log(softmax(x)), doing these two
  614. operations separately is slower and numerically unstable. This function
  615. is computed as:
  616. .. math::
  617. \text{log\_softmax}(x_{i}) = \log\left(\frac{\exp(x_i) }{ \sum_j \exp(x_j)} \right)
  618. """
  619. + r"""
  620. Args:
  621. input (Tensor): input
  622. dim (int): A dimension along which log_softmax will be computed.
  623. dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
  624. If specified, the input tensor is cast to :attr:`dtype` before the operation
  625. is performed. This is useful for preventing data type overflows. Default: None.
  626. Example::
  627. >>> t = torch.ones(2, 2)
  628. >>> torch.special.log_softmax(t, 0)
  629. tensor([[-0.6931, -0.6931],
  630. [-0.6931, -0.6931]])
  631. """,
  632. )
  633. zeta = _add_docstr(
  634. _special.special_zeta,
  635. r"""
  636. zeta(input, other, *, out=None) -> Tensor
  637. Computes the Hurwitz zeta function, elementwise.
  638. .. math::
  639. \zeta(x, q) = \sum_{k=0}^{\infty} \frac{1}{(k + q)^x}
  640. """
  641. + r"""
  642. Args:
  643. input (Tensor): the input tensor corresponding to `x`.
  644. other (Tensor): the input tensor corresponding to `q`.
  645. .. note::
  646. The Riemann zeta function corresponds to the case when `q = 1`
  647. Keyword args:
  648. {out}
  649. Example::
  650. >>> x = torch.tensor([2., 4.])
  651. >>> torch.special.zeta(x, 1)
  652. tensor([1.6449, 1.0823])
  653. >>> torch.special.zeta(x, torch.tensor([1., 2.]))
  654. tensor([1.6449, 0.0823])
  655. >>> torch.special.zeta(2, torch.tensor([1., 2.]))
  656. tensor([1.6449, 0.6449])
  657. """.format(**common_args),
  658. )
  659. multigammaln = _add_docstr(
  660. _special.special_multigammaln,
  661. r"""
  662. multigammaln(input, p, *, out=None) -> Tensor
  663. Computes the `multivariate log-gamma function
  664. <https://en.wikipedia.org/wiki/Multivariate_gamma_function>`_ with dimension
  665. :math:`p` element-wise, given by
  666. .. math::
  667. \log(\Gamma_{p}(a)) = C + \displaystyle \sum_{i=1}^{p} \log\left(\Gamma\left(a - \frac{i - 1}{2}\right)\right)
  668. where :math:`C = \log(\pi) \cdot \frac{p (p - 1)}{4}` and :math:`\Gamma(-)` is the Gamma function.
  669. All elements must be greater than :math:`\frac{p - 1}{2}`, otherwise the behavior is undefined.
  670. """
  671. + """
  672. Args:
  673. input (Tensor): the tensor to compute the multivariate log-gamma function
  674. p (int): the number of dimensions
  675. Keyword args:
  676. {out}
  677. Example::
  678. >>> a = torch.empty(2, 3).uniform_(1, 2)
  679. >>> a
  680. tensor([[1.6835, 1.8474, 1.1929],
  681. [1.0475, 1.7162, 1.4180]])
  682. >>> torch.special.multigammaln(a, 2)
  683. tensor([[0.3928, 0.4007, 0.7586],
  684. [1.0311, 0.3901, 0.5049]])
  685. """.format(**common_args),
  686. )
  687. gammainc = _add_docstr(
  688. _special.special_gammainc,
  689. r"""
  690. gammainc(input, other, *, out=None) -> Tensor
  691. Computes the regularized lower incomplete gamma function:
  692. .. math::
  693. \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_0^{\text{other}_i} t^{\text{input}_i-1} e^{-t} dt
  694. where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
  695. and at least one is strictly positive.
  696. If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`.
  697. :math:`\Gamma(\cdot)` in the equation above is the gamma function,
  698. .. math::
  699. \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
  700. See :func:`torch.special.gammaincc` and :func:`torch.special.gammaln` for related functions.
  701. Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
  702. and float inputs.
  703. .. note::
  704. The backward pass with respect to :attr:`input` is not yet supported.
  705. Please open an issue on PyTorch's Github to request it.
  706. """
  707. + r"""
  708. Args:
  709. input (Tensor): the first non-negative input tensor
  710. other (Tensor): the second non-negative input tensor
  711. Keyword args:
  712. {out}
  713. Example::
  714. >>> a1 = torch.tensor([4.0])
  715. >>> a2 = torch.tensor([3.0, 4.0, 5.0])
  716. >>> a = torch.special.gammaincc(a1, a2)
  717. tensor([0.3528, 0.5665, 0.7350])
  718. tensor([0.3528, 0.5665, 0.7350])
  719. >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
  720. tensor([1., 1., 1.])
  721. """.format(**common_args),
  722. )
  723. gammaincc = _add_docstr(
  724. _special.special_gammaincc,
  725. r"""
  726. gammaincc(input, other, *, out=None) -> Tensor
  727. Computes the regularized upper incomplete gamma function:
  728. .. math::
  729. \text{out}_{i} = \frac{1}{\Gamma(\text{input}_i)} \int_{\text{other}_i}^{\infty} t^{\text{input}_i-1} e^{-t} dt
  730. where both :math:`\text{input}_i` and :math:`\text{other}_i` are weakly positive
  731. and at least one is strictly positive.
  732. If both are zero or either is negative then :math:`\text{out}_i=\text{nan}`.
  733. :math:`\Gamma(\cdot)` in the equation above is the gamma function,
  734. .. math::
  735. \Gamma(\text{input}_i) = \int_0^\infty t^{(\text{input}_i-1)} e^{-t} dt.
  736. See :func:`torch.special.gammainc` and :func:`torch.special.gammaln` for related functions.
  737. Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`
  738. and float inputs.
  739. .. note::
  740. The backward pass with respect to :attr:`input` is not yet supported.
  741. Please open an issue on PyTorch's Github to request it.
  742. """
  743. + r"""
  744. Args:
  745. input (Tensor): the first non-negative input tensor
  746. other (Tensor): the second non-negative input tensor
  747. Keyword args:
  748. {out}
  749. Example::
  750. >>> a1 = torch.tensor([4.0])
  751. >>> a2 = torch.tensor([3.0, 4.0, 5.0])
  752. >>> a = torch.special.gammaincc(a1, a2)
  753. tensor([0.6472, 0.4335, 0.2650])
  754. >>> b = torch.special.gammainc(a1, a2) + torch.special.gammaincc(a1, a2)
  755. tensor([1., 1., 1.])
  756. """.format(**common_args),
  757. )
  758. airy_ai = _add_docstr(
  759. _special.special_airy_ai,
  760. r"""
  761. airy_ai(input, *, out=None) -> Tensor
  762. Airy function :math:`\text{Ai}\left(\text{input}\right)`.
  763. """
  764. + r"""
  765. Args:
  766. {input}
  767. Keyword args:
  768. {out}
  769. """.format(**common_args),
  770. )
  771. bessel_j0 = _add_docstr(
  772. _special.special_bessel_j0,
  773. r"""
  774. bessel_j0(input, *, out=None) -> Tensor
  775. Bessel function of the first kind of order :math:`0`.
  776. """
  777. + r"""
  778. Args:
  779. {input}
  780. Keyword args:
  781. {out}
  782. """.format(**common_args),
  783. )
  784. bessel_j1 = _add_docstr(
  785. _special.special_bessel_j1,
  786. r"""
  787. bessel_j1(input, *, out=None) -> Tensor
  788. Bessel function of the first kind of order :math:`1`.
  789. """
  790. + r"""
  791. Args:
  792. {input}
  793. Keyword args:
  794. {out}
  795. """.format(**common_args),
  796. )
  797. bessel_y0 = _add_docstr(
  798. _special.special_bessel_y0,
  799. r"""
  800. bessel_y0(input, *, out=None) -> Tensor
  801. Bessel function of the second kind of order :math:`0`.
  802. """
  803. + r"""
  804. Args:
  805. {input}
  806. Keyword args:
  807. {out}
  808. """.format(**common_args),
  809. )
  810. bessel_y1 = _add_docstr(
  811. _special.special_bessel_y1,
  812. r"""
  813. bessel_y1(input, *, out=None) -> Tensor
  814. Bessel function of the second kind of order :math:`1`.
  815. """
  816. + r"""
  817. Args:
  818. {input}
  819. Keyword args:
  820. {out}
  821. """.format(**common_args),
  822. )
  823. chebyshev_polynomial_t = _add_docstr(
  824. _special.special_chebyshev_polynomial_t,
  825. r"""
  826. chebyshev_polynomial_t(input, n, *, out=None) -> Tensor
  827. Chebyshev polynomial of the first kind :math:`T_{n}(\text{input})`.
  828. If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}`
  829. is returned. If :math:`n < 6` or :math:`|\text{input}| > 1` the recursion:
  830. .. math::
  831. T_{n + 1}(\text{input}) = 2 \times \text{input} \times T_{n}(\text{input}) - T_{n - 1}(\text{input})
  832. is evaluated. Otherwise, the explicit trigonometric formula:
  833. .. math::
  834. T_{n}(\text{input}) = \text{cos}(n \times \text{arccos}(x))
  835. is evaluated.
  836. """
  837. + r"""
  838. Args:
  839. {input}
  840. n (Tensor): Degree of the polynomial.
  841. Keyword args:
  842. {out}
  843. """.format(**common_args),
  844. )
  845. chebyshev_polynomial_u = _add_docstr(
  846. _special.special_chebyshev_polynomial_u,
  847. r"""
  848. chebyshev_polynomial_u(input, n, *, out=None) -> Tensor
  849. Chebyshev polynomial of the second kind :math:`U_{n}(\text{input})`.
  850. If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`,
  851. :math:`2 \times \text{input}` is returned. If :math:`n < 6` or
  852. :math:`|\text{input}| > 1`, the recursion:
  853. .. math::
  854. U_{n + 1}(\text{input}) = 2 \times \text{input} \times U_{n}(\text{input}) - U_{n - 1}(\text{input})
  855. is evaluated. Otherwise, the explicit trigonometric formula:
  856. .. math::
  857. \frac{\text{sin}((n + 1) \times \text{arccos}(\text{input}))}{\text{sin}(\text{arccos}(\text{input}))}
  858. is evaluated.
  859. """
  860. + r"""
  861. Args:
  862. {input}
  863. n (Tensor): Degree of the polynomial.
  864. Keyword args:
  865. {out}
  866. """.format(**common_args),
  867. )
  868. chebyshev_polynomial_v = _add_docstr(
  869. _special.special_chebyshev_polynomial_v,
  870. r"""
  871. chebyshev_polynomial_v(input, n, *, out=None) -> Tensor
  872. Chebyshev polynomial of the third kind :math:`V_{n}^{\ast}(\text{input})`.
  873. """
  874. + r"""
  875. Args:
  876. {input}
  877. n (Tensor): Degree of the polynomial.
  878. Keyword args:
  879. {out}
  880. """.format(**common_args),
  881. )
  882. chebyshev_polynomial_w = _add_docstr(
  883. _special.special_chebyshev_polynomial_w,
  884. r"""
  885. chebyshev_polynomial_w(input, n, *, out=None) -> Tensor
  886. Chebyshev polynomial of the fourth kind :math:`W_{n}^{\ast}(\text{input})`.
  887. """
  888. + r"""
  889. Args:
  890. {input}
  891. n (Tensor): Degree of the polynomial.
  892. Keyword args:
  893. {out}
  894. """.format(**common_args),
  895. )
  896. hermite_polynomial_h = _add_docstr(
  897. _special.special_hermite_polynomial_h,
  898. r"""
  899. hermite_polynomial_h(input, n, *, out=None) -> Tensor
  900. Physicist's Hermite polynomial :math:`H_{n}(\text{input})`.
  901. If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}`
  902. is returned. Otherwise, the recursion:
  903. .. math::
  904. H_{n + 1}(\text{input}) = 2 \times \text{input} \times H_{n}(\text{input}) - H_{n - 1}(\text{input})
  905. is evaluated.
  906. """
  907. + r"""
  908. Args:
  909. {input}
  910. n (Tensor): Degree of the polynomial.
  911. Keyword args:
  912. {out}
  913. """.format(**common_args),
  914. )
  915. hermite_polynomial_he = _add_docstr(
  916. _special.special_hermite_polynomial_he,
  917. r"""
  918. hermite_polynomial_he(input, n, *, out=None) -> Tensor
  919. Probabilist's Hermite polynomial :math:`He_{n}(\text{input})`.
  920. If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}`
  921. is returned. Otherwise, the recursion:
  922. .. math::
  923. He_{n + 1}(\text{input}) = 2 \times \text{input} \times He_{n}(\text{input}) - He_{n - 1}(\text{input})
  924. is evaluated.
  925. """
  926. + r"""
  927. Args:
  928. {input}
  929. n (Tensor): Degree of the polynomial.
  930. Keyword args:
  931. {out}
  932. """.format(**common_args),
  933. )
  934. laguerre_polynomial_l = _add_docstr(
  935. _special.special_laguerre_polynomial_l,
  936. r"""
  937. laguerre_polynomial_l(input, n, *, out=None) -> Tensor
  938. Laguerre polynomial :math:`L_{n}(\text{input})`.
  939. If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}`
  940. is returned. Otherwise, the recursion:
  941. .. math::
  942. L_{n + 1}(\text{input}) = 2 \times \text{input} \times L_{n}(\text{input}) - L_{n - 1}(\text{input})
  943. is evaluated.
  944. """
  945. + r"""
  946. Args:
  947. {input}
  948. n (Tensor): Degree of the polynomial.
  949. Keyword args:
  950. {out}
  951. """.format(**common_args),
  952. )
  953. legendre_polynomial_p = _add_docstr(
  954. _special.special_legendre_polynomial_p,
  955. r"""
  956. legendre_polynomial_p(input, n, *, out=None) -> Tensor
  957. Legendre polynomial :math:`P_{n}(\text{input})`.
  958. If :math:`n = 0`, :math:`1` is returned. If :math:`n = 1`, :math:`\text{input}`
  959. is returned. Otherwise, the recursion:
  960. .. math::
  961. P_{n + 1}(\text{input}) = 2 \times \text{input} \times P_{n}(\text{input}) - P_{n - 1}(\text{input})
  962. is evaluated.
  963. """
  964. + r"""
  965. Args:
  966. {input}
  967. n (Tensor): Degree of the polynomial.
  968. Keyword args:
  969. {out}
  970. """.format(**common_args),
  971. )
  972. modified_bessel_i0 = _add_docstr(
  973. _special.special_modified_bessel_i0,
  974. r"""
  975. modified_bessel_i0(input, *, out=None) -> Tensor
  976. Modified Bessel function of the first kind of order :math:`0`.
  977. """
  978. + r"""
  979. Args:
  980. {input}
  981. Keyword args:
  982. {out}
  983. """.format(**common_args),
  984. )
  985. modified_bessel_i1 = _add_docstr(
  986. _special.special_modified_bessel_i1,
  987. r"""
  988. modified_bessel_i1(input, *, out=None) -> Tensor
  989. Modified Bessel function of the first kind of order :math:`1`.
  990. """
  991. + r"""
  992. Args:
  993. {input}
  994. Keyword args:
  995. {out}
  996. """.format(**common_args),
  997. )
  998. modified_bessel_k0 = _add_docstr(
  999. _special.special_modified_bessel_k0,
  1000. r"""
  1001. modified_bessel_k0(input, *, out=None) -> Tensor
  1002. Modified Bessel function of the second kind of order :math:`0`.
  1003. """
  1004. + r"""
  1005. Args:
  1006. {input}
  1007. Keyword args:
  1008. {out}
  1009. """.format(**common_args),
  1010. )
  1011. modified_bessel_k1 = _add_docstr(
  1012. _special.special_modified_bessel_k1,
  1013. r"""
  1014. modified_bessel_k1(input, *, out=None) -> Tensor
  1015. Modified Bessel function of the second kind of order :math:`1`.
  1016. """
  1017. + r"""
  1018. Args:
  1019. {input}
  1020. Keyword args:
  1021. {out}
  1022. """.format(**common_args),
  1023. )
  1024. scaled_modified_bessel_k0 = _add_docstr(
  1025. _special.special_scaled_modified_bessel_k0,
  1026. r"""
  1027. scaled_modified_bessel_k0(input, *, out=None) -> Tensor
  1028. Scaled modified Bessel function of the second kind of order :math:`0`.
  1029. """
  1030. + r"""
  1031. Args:
  1032. {input}
  1033. Keyword args:
  1034. {out}
  1035. """.format(**common_args),
  1036. )
  1037. scaled_modified_bessel_k1 = _add_docstr(
  1038. _special.special_scaled_modified_bessel_k1,
  1039. r"""
  1040. scaled_modified_bessel_k1(input, *, out=None) -> Tensor
  1041. Scaled modified Bessel function of the second kind of order :math:`1`.
  1042. """
  1043. + r"""
  1044. Args:
  1045. {input}
  1046. Keyword args:
  1047. {out}
  1048. """.format(**common_args),
  1049. )
  1050. shifted_chebyshev_polynomial_t = _add_docstr(
  1051. _special.special_shifted_chebyshev_polynomial_t,
  1052. r"""
  1053. shifted_chebyshev_polynomial_t(input, n, *, out=None) -> Tensor
  1054. Chebyshev polynomial of the first kind :math:`T_{n}^{\ast}(\text{input})`.
  1055. """
  1056. + r"""
  1057. Args:
  1058. {input}
  1059. n (Tensor): Degree of the polynomial.
  1060. Keyword args:
  1061. {out}
  1062. """.format(**common_args),
  1063. )
  1064. shifted_chebyshev_polynomial_u = _add_docstr(
  1065. _special.special_shifted_chebyshev_polynomial_u,
  1066. r"""
  1067. shifted_chebyshev_polynomial_u(input, n, *, out=None) -> Tensor
  1068. Chebyshev polynomial of the second kind :math:`U_{n}^{\ast}(\text{input})`.
  1069. """
  1070. + r"""
  1071. Args:
  1072. {input}
  1073. n (Tensor): Degree of the polynomial.
  1074. Keyword args:
  1075. {out}
  1076. """.format(**common_args),
  1077. )
  1078. shifted_chebyshev_polynomial_v = _add_docstr(
  1079. _special.special_shifted_chebyshev_polynomial_v,
  1080. r"""
  1081. shifted_chebyshev_polynomial_v(input, n, *, out=None) -> Tensor
  1082. Chebyshev polynomial of the third kind :math:`V_{n}^{\ast}(\text{input})`.
  1083. """
  1084. + r"""
  1085. Args:
  1086. {input}
  1087. n (Tensor): Degree of the polynomial.
  1088. Keyword args:
  1089. {out}
  1090. """.format(**common_args),
  1091. )
  1092. shifted_chebyshev_polynomial_w = _add_docstr(
  1093. _special.special_shifted_chebyshev_polynomial_w,
  1094. r"""
  1095. shifted_chebyshev_polynomial_w(input, n, *, out=None) -> Tensor
  1096. Chebyshev polynomial of the fourth kind :math:`W_{n}^{\ast}(\text{input})`.
  1097. """
  1098. + r"""
  1099. Args:
  1100. {input}
  1101. n (Tensor): Degree of the polynomial.
  1102. Keyword args:
  1103. {out}
  1104. """.format(**common_args),
  1105. )
  1106. spherical_bessel_j0 = _add_docstr(
  1107. _special.special_spherical_bessel_j0,
  1108. r"""
  1109. spherical_bessel_j0(input, *, out=None) -> Tensor
  1110. Spherical Bessel function of the first kind of order :math:`0`.
  1111. """
  1112. + r"""
  1113. Args:
  1114. {input}
  1115. Keyword args:
  1116. {out}
  1117. """.format(**common_args),
  1118. )