_scimath_impl.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
"""
Wrapper functions for more user-friendly calling of certain math functions
whose output data-type is different from the input data-type in certain
domains of the input.

For example, for functions like `log` with branch cuts, the versions in this
module provide the mathematically valid answers in the complex plane::

  >>> import math
  >>> import numpy as np
  >>> np.emath.log(-math.exp(1)) == (1+1j*math.pi)
  True

Similarly, `sqrt`, other base logarithms, `power` and trig functions are
correctly handled.  See their respective docstrings for specific examples.

"""
  13. import numpy._core.numeric as nx
  14. import numpy._core.numerictypes as nt
  15. from numpy._core.numeric import asarray, any
  16. from numpy._core.overrides import array_function_dispatch, set_module
  17. from numpy.lib._type_check_impl import isreal
  18. __all__ = [
  19. 'sqrt', 'log', 'log2', 'logn', 'log10', 'power', 'arccos', 'arcsin',
  20. 'arctanh'
  21. ]
  22. _ln2 = nx.log(2.0)
  23. def _tocomplex(arr):
  24. """Convert its input `arr` to a complex array.
  25. The input is returned as a complex array of the smallest type that will fit
  26. the original data: types like single, byte, short, etc. become csingle,
  27. while others become cdouble.
  28. A copy of the input is always made.
  29. Parameters
  30. ----------
  31. arr : array
  32. Returns
  33. -------
  34. array
  35. An array with the same input data as the input but in complex form.
  36. Examples
  37. --------
  38. >>> import numpy as np
  39. First, consider an input of type short:
  40. >>> a = np.array([1,2,3],np.short)
  41. >>> ac = np.lib.scimath._tocomplex(a); ac
  42. array([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
  43. >>> ac.dtype
  44. dtype('complex64')
  45. If the input is of type double, the output is correspondingly of the
  46. complex double type as well:
  47. >>> b = np.array([1,2,3],np.double)
  48. >>> bc = np.lib.scimath._tocomplex(b); bc
  49. array([1.+0.j, 2.+0.j, 3.+0.j])
  50. >>> bc.dtype
  51. dtype('complex128')
  52. Note that even if the input was complex to begin with, a copy is still
  53. made, since the astype() method always copies:
  54. >>> c = np.array([1,2,3],np.csingle)
  55. >>> cc = np.lib.scimath._tocomplex(c); cc
  56. array([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
  57. >>> c *= 2; c
  58. array([2.+0.j, 4.+0.j, 6.+0.j], dtype=complex64)
  59. >>> cc
  60. array([1.+0.j, 2.+0.j, 3.+0.j], dtype=complex64)
  61. """
  62. if issubclass(arr.dtype.type, (nt.single, nt.byte, nt.short, nt.ubyte,
  63. nt.ushort, nt.csingle)):
  64. return arr.astype(nt.csingle)
  65. else:
  66. return arr.astype(nt.cdouble)
  67. def _fix_real_lt_zero(x):
  68. """Convert `x` to complex if it has real, negative components.
  69. Otherwise, output is just the array version of the input (via asarray).
  70. Parameters
  71. ----------
  72. x : array_like
  73. Returns
  74. -------
  75. array
  76. Examples
  77. --------
  78. >>> import numpy as np
  79. >>> np.lib.scimath._fix_real_lt_zero([1,2])
  80. array([1, 2])
  81. >>> np.lib.scimath._fix_real_lt_zero([-1,2])
  82. array([-1.+0.j, 2.+0.j])
  83. """
  84. x = asarray(x)
  85. if any(isreal(x) & (x < 0)):
  86. x = _tocomplex(x)
  87. return x
  88. def _fix_int_lt_zero(x):
  89. """Convert `x` to double if it has real, negative components.
  90. Otherwise, output is just the array version of the input (via asarray).
  91. Parameters
  92. ----------
  93. x : array_like
  94. Returns
  95. -------
  96. array
  97. Examples
  98. --------
  99. >>> import numpy as np
  100. >>> np.lib.scimath._fix_int_lt_zero([1,2])
  101. array([1, 2])
  102. >>> np.lib.scimath._fix_int_lt_zero([-1,2])
  103. array([-1., 2.])
  104. """
  105. x = asarray(x)
  106. if any(isreal(x) & (x < 0)):
  107. x = x * 1.0
  108. return x
  109. def _fix_real_abs_gt_1(x):
  110. """Convert `x` to complex if it has real components x_i with abs(x_i)>1.
  111. Otherwise, output is just the array version of the input (via asarray).
  112. Parameters
  113. ----------
  114. x : array_like
  115. Returns
  116. -------
  117. array
  118. Examples
  119. --------
  120. >>> import numpy as np
  121. >>> np.lib.scimath._fix_real_abs_gt_1([0,1])
  122. array([0, 1])
  123. >>> np.lib.scimath._fix_real_abs_gt_1([0,2])
  124. array([0.+0.j, 2.+0.j])
  125. """
  126. x = asarray(x)
  127. if any(isreal(x) & (abs(x) > 1)):
  128. x = _tocomplex(x)
  129. return x
  130. def _unary_dispatcher(x):
  131. return (x,)
  132. @set_module('numpy.lib.scimath')
  133. @array_function_dispatch(_unary_dispatcher)
  134. def sqrt(x):
  135. """
  136. Compute the square root of x.
  137. For negative input elements, a complex value is returned
  138. (unlike `numpy.sqrt` which returns NaN).
  139. Parameters
  140. ----------
  141. x : array_like
  142. The input value(s).
  143. Returns
  144. -------
  145. out : ndarray or scalar
  146. The square root of `x`. If `x` was a scalar, so is `out`,
  147. otherwise an array is returned.
  148. See Also
  149. --------
  150. numpy.sqrt
  151. Examples
  152. --------
  153. For real, non-negative inputs this works just like `numpy.sqrt`:
  154. >>> import numpy as np
  155. >>> np.emath.sqrt(1)
  156. 1.0
  157. >>> np.emath.sqrt([1, 4])
  158. array([1., 2.])
  159. But it automatically handles negative inputs:
  160. >>> np.emath.sqrt(-1)
  161. 1j
  162. >>> np.emath.sqrt([-1,4])
  163. array([0.+1.j, 2.+0.j])
  164. Different results are expected because:
  165. floating point 0.0 and -0.0 are distinct.
  166. For more control, explicitly use complex() as follows:
  167. >>> np.emath.sqrt(complex(-4.0, 0.0))
  168. 2j
  169. >>> np.emath.sqrt(complex(-4.0, -0.0))
  170. -2j
  171. """
  172. x = _fix_real_lt_zero(x)
  173. return nx.sqrt(x)
  174. @set_module('numpy.lib.scimath')
  175. @array_function_dispatch(_unary_dispatcher)
  176. def log(x):
  177. """
  178. Compute the natural logarithm of `x`.
  179. Return the "principal value" (for a description of this, see `numpy.log`)
  180. of :math:`log_e(x)`. For real `x > 0`, this is a real number (``log(0)``
  181. returns ``-inf`` and ``log(np.inf)`` returns ``inf``). Otherwise, the
  182. complex principle value is returned.
  183. Parameters
  184. ----------
  185. x : array_like
  186. The value(s) whose log is (are) required.
  187. Returns
  188. -------
  189. out : ndarray or scalar
  190. The log of the `x` value(s). If `x` was a scalar, so is `out`,
  191. otherwise an array is returned.
  192. See Also
  193. --------
  194. numpy.log
  195. Notes
  196. -----
  197. For a log() that returns ``NAN`` when real `x < 0`, use `numpy.log`
  198. (note, however, that otherwise `numpy.log` and this `log` are identical,
  199. i.e., both return ``-inf`` for `x = 0`, ``inf`` for `x = inf`, and,
  200. notably, the complex principle value if ``x.imag != 0``).
  201. Examples
  202. --------
  203. >>> import numpy as np
  204. >>> np.emath.log(np.exp(1))
  205. 1.0
  206. Negative arguments are handled "correctly" (recall that
  207. ``exp(log(x)) == x`` does *not* hold for real ``x < 0``):
  208. >>> np.emath.log(-np.exp(1)) == (1 + np.pi * 1j)
  209. True
  210. """
  211. x = _fix_real_lt_zero(x)
  212. return nx.log(x)
  213. @set_module('numpy.lib.scimath')
  214. @array_function_dispatch(_unary_dispatcher)
  215. def log10(x):
  216. """
  217. Compute the logarithm base 10 of `x`.
  218. Return the "principal value" (for a description of this, see
  219. `numpy.log10`) of :math:`log_{10}(x)`. For real `x > 0`, this
  220. is a real number (``log10(0)`` returns ``-inf`` and ``log10(np.inf)``
  221. returns ``inf``). Otherwise, the complex principle value is returned.
  222. Parameters
  223. ----------
  224. x : array_like or scalar
  225. The value(s) whose log base 10 is (are) required.
  226. Returns
  227. -------
  228. out : ndarray or scalar
  229. The log base 10 of the `x` value(s). If `x` was a scalar, so is `out`,
  230. otherwise an array object is returned.
  231. See Also
  232. --------
  233. numpy.log10
  234. Notes
  235. -----
  236. For a log10() that returns ``NAN`` when real `x < 0`, use `numpy.log10`
  237. (note, however, that otherwise `numpy.log10` and this `log10` are
  238. identical, i.e., both return ``-inf`` for `x = 0`, ``inf`` for `x = inf`,
  239. and, notably, the complex principle value if ``x.imag != 0``).
  240. Examples
  241. --------
  242. >>> import numpy as np
  243. (We set the printing precision so the example can be auto-tested)
  244. >>> np.set_printoptions(precision=4)
  245. >>> np.emath.log10(10**1)
  246. 1.0
  247. >>> np.emath.log10([-10**1, -10**2, 10**2])
  248. array([1.+1.3644j, 2.+1.3644j, 2.+0.j ])
  249. """
  250. x = _fix_real_lt_zero(x)
  251. return nx.log10(x)
  252. def _logn_dispatcher(n, x):
  253. return (n, x,)
  254. @set_module('numpy.lib.scimath')
  255. @array_function_dispatch(_logn_dispatcher)
  256. def logn(n, x):
  257. """
  258. Take log base n of x.
  259. If `x` contains negative inputs, the answer is computed and returned in the
  260. complex domain.
  261. Parameters
  262. ----------
  263. n : array_like
  264. The integer base(s) in which the log is taken.
  265. x : array_like
  266. The value(s) whose log base `n` is (are) required.
  267. Returns
  268. -------
  269. out : ndarray or scalar
  270. The log base `n` of the `x` value(s). If `x` was a scalar, so is
  271. `out`, otherwise an array is returned.
  272. Examples
  273. --------
  274. >>> import numpy as np
  275. >>> np.set_printoptions(precision=4)
  276. >>> np.emath.logn(2, [4, 8])
  277. array([2., 3.])
  278. >>> np.emath.logn(2, [-4, -8, 8])
  279. array([2.+4.5324j, 3.+4.5324j, 3.+0.j ])
  280. """
  281. x = _fix_real_lt_zero(x)
  282. n = _fix_real_lt_zero(n)
  283. return nx.log(x)/nx.log(n)
  284. @set_module('numpy.lib.scimath')
  285. @array_function_dispatch(_unary_dispatcher)
  286. def log2(x):
  287. """
  288. Compute the logarithm base 2 of `x`.
  289. Return the "principal value" (for a description of this, see
  290. `numpy.log2`) of :math:`log_2(x)`. For real `x > 0`, this is
  291. a real number (``log2(0)`` returns ``-inf`` and ``log2(np.inf)`` returns
  292. ``inf``). Otherwise, the complex principle value is returned.
  293. Parameters
  294. ----------
  295. x : array_like
  296. The value(s) whose log base 2 is (are) required.
  297. Returns
  298. -------
  299. out : ndarray or scalar
  300. The log base 2 of the `x` value(s). If `x` was a scalar, so is `out`,
  301. otherwise an array is returned.
  302. See Also
  303. --------
  304. numpy.log2
  305. Notes
  306. -----
  307. For a log2() that returns ``NAN`` when real `x < 0`, use `numpy.log2`
  308. (note, however, that otherwise `numpy.log2` and this `log2` are
  309. identical, i.e., both return ``-inf`` for `x = 0`, ``inf`` for `x = inf`,
  310. and, notably, the complex principle value if ``x.imag != 0``).
  311. Examples
  312. --------
  313. We set the printing precision so the example can be auto-tested:
  314. >>> np.set_printoptions(precision=4)
  315. >>> np.emath.log2(8)
  316. 3.0
  317. >>> np.emath.log2([-4, -8, 8])
  318. array([2.+4.5324j, 3.+4.5324j, 3.+0.j ])
  319. """
  320. x = _fix_real_lt_zero(x)
  321. return nx.log2(x)
  322. def _power_dispatcher(x, p):
  323. return (x, p)
  324. @set_module('numpy.lib.scimath')
  325. @array_function_dispatch(_power_dispatcher)
  326. def power(x, p):
  327. """
  328. Return x to the power p, (x**p).
  329. If `x` contains negative values, the output is converted to the
  330. complex domain.
  331. Parameters
  332. ----------
  333. x : array_like
  334. The input value(s).
  335. p : array_like of ints
  336. The power(s) to which `x` is raised. If `x` contains multiple values,
  337. `p` has to either be a scalar, or contain the same number of values
  338. as `x`. In the latter case, the result is
  339. ``x[0]**p[0], x[1]**p[1], ...``.
  340. Returns
  341. -------
  342. out : ndarray or scalar
  343. The result of ``x**p``. If `x` and `p` are scalars, so is `out`,
  344. otherwise an array is returned.
  345. See Also
  346. --------
  347. numpy.power
  348. Examples
  349. --------
  350. >>> import numpy as np
  351. >>> np.set_printoptions(precision=4)
  352. >>> np.emath.power(2, 2)
  353. 4
  354. >>> np.emath.power([2, 4], 2)
  355. array([ 4, 16])
  356. >>> np.emath.power([2, 4], -2)
  357. array([0.25 , 0.0625])
  358. >>> np.emath.power([-2, 4], 2)
  359. array([ 4.-0.j, 16.+0.j])
  360. >>> np.emath.power([2, 4], [2, 4])
  361. array([ 4, 256])
  362. """
  363. x = _fix_real_lt_zero(x)
  364. p = _fix_int_lt_zero(p)
  365. return nx.power(x, p)
  366. @set_module('numpy.lib.scimath')
  367. @array_function_dispatch(_unary_dispatcher)
  368. def arccos(x):
  369. """
  370. Compute the inverse cosine of x.
  371. Return the "principal value" (for a description of this, see
  372. `numpy.arccos`) of the inverse cosine of `x`. For real `x` such that
  373. `abs(x) <= 1`, this is a real number in the closed interval
  374. :math:`[0, \\pi]`. Otherwise, the complex principle value is returned.
  375. Parameters
  376. ----------
  377. x : array_like or scalar
  378. The value(s) whose arccos is (are) required.
  379. Returns
  380. -------
  381. out : ndarray or scalar
  382. The inverse cosine(s) of the `x` value(s). If `x` was a scalar, so
  383. is `out`, otherwise an array object is returned.
  384. See Also
  385. --------
  386. numpy.arccos
  387. Notes
  388. -----
  389. For an arccos() that returns ``NAN`` when real `x` is not in the
  390. interval ``[-1,1]``, use `numpy.arccos`.
  391. Examples
  392. --------
  393. >>> import numpy as np
  394. >>> np.set_printoptions(precision=4)
  395. >>> np.emath.arccos(1) # a scalar is returned
  396. 0.0
  397. >>> np.emath.arccos([1,2])
  398. array([0.-0.j , 0.-1.317j])
  399. """
  400. x = _fix_real_abs_gt_1(x)
  401. return nx.arccos(x)
  402. @set_module('numpy.lib.scimath')
  403. @array_function_dispatch(_unary_dispatcher)
  404. def arcsin(x):
  405. """
  406. Compute the inverse sine of x.
  407. Return the "principal value" (for a description of this, see
  408. `numpy.arcsin`) of the inverse sine of `x`. For real `x` such that
  409. `abs(x) <= 1`, this is a real number in the closed interval
  410. :math:`[-\\pi/2, \\pi/2]`. Otherwise, the complex principle value is
  411. returned.
  412. Parameters
  413. ----------
  414. x : array_like or scalar
  415. The value(s) whose arcsin is (are) required.
  416. Returns
  417. -------
  418. out : ndarray or scalar
  419. The inverse sine(s) of the `x` value(s). If `x` was a scalar, so
  420. is `out`, otherwise an array object is returned.
  421. See Also
  422. --------
  423. numpy.arcsin
  424. Notes
  425. -----
  426. For an arcsin() that returns ``NAN`` when real `x` is not in the
  427. interval ``[-1,1]``, use `numpy.arcsin`.
  428. Examples
  429. --------
  430. >>> import numpy as np
  431. >>> np.set_printoptions(precision=4)
  432. >>> np.emath.arcsin(0)
  433. 0.0
  434. >>> np.emath.arcsin([0,1])
  435. array([0. , 1.5708])
  436. """
  437. x = _fix_real_abs_gt_1(x)
  438. return nx.arcsin(x)
  439. @set_module('numpy.lib.scimath')
  440. @array_function_dispatch(_unary_dispatcher)
  441. def arctanh(x):
  442. """
  443. Compute the inverse hyperbolic tangent of `x`.
  444. Return the "principal value" (for a description of this, see
  445. `numpy.arctanh`) of ``arctanh(x)``. For real `x` such that
  446. ``abs(x) < 1``, this is a real number. If `abs(x) > 1`, or if `x` is
  447. complex, the result is complex. Finally, `x = 1` returns``inf`` and
  448. ``x=-1`` returns ``-inf``.
  449. Parameters
  450. ----------
  451. x : array_like
  452. The value(s) whose arctanh is (are) required.
  453. Returns
  454. -------
  455. out : ndarray or scalar
  456. The inverse hyperbolic tangent(s) of the `x` value(s). If `x` was
  457. a scalar so is `out`, otherwise an array is returned.
  458. See Also
  459. --------
  460. numpy.arctanh
  461. Notes
  462. -----
  463. For an arctanh() that returns ``NAN`` when real `x` is not in the
  464. interval ``(-1,1)``, use `numpy.arctanh` (this latter, however, does
  465. return +/-inf for ``x = +/-1``).
  466. Examples
  467. --------
  468. >>> import numpy as np
  469. >>> np.set_printoptions(precision=4)
  470. >>> np.emath.arctanh(0.5)
  471. 0.5493061443340549
  472. >>> from numpy.testing import suppress_warnings
  473. >>> with suppress_warnings() as sup:
  474. ... sup.filter(RuntimeWarning)
  475. ... np.emath.arctanh(np.eye(2))
  476. array([[inf, 0.],
  477. [ 0., inf]])
  478. >>> np.emath.arctanh([1j])
  479. array([0.+0.7854j])
  480. """
  481. x = _fix_real_abs_gt_1(x)
  482. return nx.arctanh(x)