_continued_fraction.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387
  1. import numpy as np
  2. from scipy._lib._array_api import (
  3. array_namespace, xp_ravel, xp_copy, xp_promote
  4. )
  5. import scipy._lib._elementwise_iterative_method as eim
  6. from scipy._lib._util import _RichResult
  7. from scipy import special
  8. # Todo:
  9. # Avoid special-casing key 'n' in _lib._elementwise_iterative_method::_check_termination
  10. # Rearrange termination condition to allow absolute and relative tolerances?
  11. # Interpret/return |f_n - f_{n-1}| as an error estimate?
  12. # Return gracefully for size=0 arrays
  13. def _logaddexp(x, y, xp=None):
  14. # logaddexp that supports complex numbers
  15. xp = array_namespace(x, y) if xp is None else xp
  16. x, y = xp.broadcast_arrays(x, y)
  17. xy = xp.stack((x, y), axis=0)
  18. return special.logsumexp(xy, axis=0)
  19. def _continued_fraction_iv(a, b, args, tolerances, maxiter, log):
  20. # Input validation for `_continued_fraction`
  21. if not callable(a) or not callable(b):
  22. raise ValueError('`a` and `b` must be callable.')
  23. if not np.iterable(args):
  24. args = (args,)
  25. # Call each callable once to determine namespace and dtypes
  26. a0, b0 = a(0, *args), b(0, *args)
  27. xp = array_namespace(a0, b0, *args)
  28. a0, b0, *args = xp_promote(a0, b0, *args, force_floating=True, broadcast=True,
  29. xp=xp)
  30. shape, dtype = a0.shape, a0.dtype
  31. a0, b0, *args = (xp_ravel(arg) for arg in (a0, b0) + tuple(args))
  32. tolerances = {} if tolerances is None else tolerances
  33. eps = tolerances.get('eps', None)
  34. tiny = tolerances.get('tiny', None)
  35. # tolerances are floats, not arrays, so it's OK to use NumPy
  36. message = ('`eps` and `tiny` must be (or represent the logarithm of) '
  37. 'finite, positive, real scalars.')
  38. tols = np.asarray([eps if eps is not None else 1,
  39. tiny if tiny is not None else 1])
  40. not_real = (not np.issubdtype(tols.dtype, np.number)
  41. or np.issubdtype(tols.dtype, np.complexfloating))
  42. not_positive = np.any(tols <= 0) if not log else False
  43. not_finite = not np.all(np.isfinite(tols))
  44. not_scalar = tols.shape != (2,)
  45. if not_real or not_positive or not_finite or not_scalar:
  46. raise ValueError(message)
  47. maxiter_int = int(maxiter)
  48. if maxiter != maxiter_int or maxiter < 0:
  49. raise ValueError('`maxiter` must be a non-negative integer.')
  50. if not isinstance(log, bool):
  51. raise ValueError('`log` must be boolean.')
  52. return a, b, args, eps, tiny, maxiter, log, a0, b0, shape, dtype, xp
  53. def _continued_fraction(a, b, *, args=(), tolerances=None, maxiter=100, log=False):
  54. r"""Evaluate a generalized continued fraction numerically.
  55. `_continued_fraction` iteratively evaluates convergents of a continued fraction
  56. given coefficients returned by callables `a` and `b`. Iteration terminates when
  57. `maxiter` terms have been evaluated or a termination criterion controlled by
  58. `tolerances` is satisfied, and the final convergent is returned as the ``f``
  59. attribute of the result object.
  60. This function works elementwise when `args` contains (broadcastable) arrays.
  61. Parameters
  62. ----------
  63. a, b: callable
  64. Functions that provide the *numerator* and *denominator* coefficients of
  65. the continued fraction, respectively.
  66. The signature of each must be::
  67. a(n: int, *argsj) -> ndarray
  68. where ``n`` is the coefficient number and ``argsj`` is a tuple, which may
  69. contain an arbitrary number of arrays of any shape. `a` and `b` must be
  70. elementwise functions: each scalar element ``a(n, *argsj)[i]`` must equal
  71. ``a(n, *[argj[i] for argj in argsj])`` for valid indices ``i``.
  72. `a` and `b` must not mutate the arrays in ``argsj``.
  73. The result shape is the broadcasted shape of ``a(0, *args)`` and
  74. ``b(0, *args)``. The dtype used throughout computation is the result dtype
  75. of these terms if it is a float, and the default float of the array library
  76. otherwise. The numerical value of ``a(0, *args)`` is ignored, and
  77. the value of the leading term ``b(0, *args)`` is the so-called "integer"
  78. part of the continued fraction (although it need not be integral).
  79. args : tuple of array_like, optional
  80. Additional positional *array* arguments to be passed to `a` and `b`. Arrays
  81. must be broadcastable with one another. If the coefficient callables
  82. require additional arguments that are not broadcastable with one
  83. another, wrap them with callables `a` and `b` such that `a` and `b` accept
  84. only ``n`` and broadcastable array arguments.
  85. tolerances : dictionary of floats, optional
  86. Tolerances and numerical thresholds used by the algorithm. Currently,
  87. valid keys of the dictionary are:
  88. - ``eps`` - the convergence threshold of Lentz' algorithm
  89. - ``tiny`` - not strictly a "tolerance", but a very small positive number
  90. used to avoid division by zero
  91. The default `eps` is the precision of the appropriate dtype, and the default
  92. `tiny` is the precision squared. [1]_ advises that ``eps`` is "as small as
  93. you like", but for most purposes, it should not be set smaller than the default
  94. because it may prevent convergence of the algorithm. [1]_ also advises that
  95. ``tiny`` should be less than typical values of ``eps * b(n)``, so the default
  96. is a good choice unless the :math:`b_n` are very small. See [1]_ for details.
  97. maxiter : int, default: 100
  98. The maximum number of iterations of the algorithm to perform.
  99. log : bool, default: False
  100. If True, `a` and `b` return the (natural) logarithm of the terms, `tolerances`
  101. contains the logarithm of the tolerances, and the result object reports the
  102. logarithm of the convergent.
  103. Returns
  104. -------
  105. res : _RichResult
  106. An object similar to an instance of `scipy.optimize.OptimizeResult` with the
  107. following attributes. The descriptions are written as though the values will
  108. be scalars; however, if `f` returns an array, the outputs will be
  109. arrays of the same shape.
  110. success : bool array
  111. ``True`` where the algorithm terminated successfully (status ``0``);
  112. ``False`` otherwise.
  113. status : int array
  114. An integer representing the exit status of the algorithm.
  115. - ``0`` : The algorithm converged to the specified tolerances.
  116. - ``-2`` : The maximum number of iterations was reached.
  117. - ``-3`` : A non-finite value was encountered.
  118. f : float array
  119. The convergent which satisfied a termination criterion.
  120. nit : int array
  121. The number of iterations of the algorithm that were performed.
  122. nfev : int array
  123. The number of terms that were evaluated.
  124. Notes
  125. -----
  126. A generalized continued fraction is an expression of the form
  127. .. math::
  128. b_0 + \frac{a_1}{b_1 + \frac{a_2}{b_2 + \frac{a_3}{b_3 + \cdots}}}
  129. Successive "convergents" approximate the infinitely recursive continued fraction
  130. with a finite number of terms :math:`a_n` and :math:`b_n`, which are provided
  131. by callables `a` and `b`, respectively. This implementation follows the modified
  132. Lentz algorithm ([1]_, [2]_) to evaluate successive convergents until a
  133. termination condition is satisfied.
  134. References
  135. ----------
  136. .. [1] Press, William H., and Saul A. Teukolsky. "Evaluating continued fractions
  137. and computing exponential integrals." Computers in Physics 2.5 (1988): 88-89.
  138. .. [2] Lentz's algorithm. Wikipedia.
  139. https://en.wikipedia.org/wiki/Lentz%27s_algorithm
  140. .. [3] Continued fraction. Wikipedia.
  141. https://en.wikipedia.org/wiki/Continued_fraction
  142. .. [4] Generalized continued fraction. Wikipedia.
  143. https://en.wikipedia.org/wiki/Generalized_continued_fraction
  144. Examples
  145. --------
  146. The "simple continued fraction" of :math:`\pi` is given in [3]_ as
  147. .. math::
  148. 3 + \frac{1}{7 + \frac{1}{15 + \frac{1}{1 + \cdots}}}
  149. where the :math:`b_n` terms follow no obvious pattern:
  150. >>> b = [3, 7, 15, 1, 292, 1, 1, 1, 2, 1, 3, 1]
  151. and the :math:`a_n` terms are all :math:`1`.
  152. In this case, all the terms have been precomputed, so we call `_continued_fraction`
  153. with simple callables which simply return the precomputed coefficients:
  154. >>> import numpy as np
  155. >>> from scipy.special._continued_fraction import _continued_fraction
  156. >>> res = _continued_fraction(a=lambda n: 1, b=lambda n: b[n], maxiter=len(b) - 1)
  157. >>> (res.f - np.pi) / np.pi
  158. np.float64(7.067899292141148e-15)
  159. A generalized continued fraction for :math:`\pi` is given by:
  160. .. math::
  161. 3 + \frac{1^2}{6 + \frac{3^2}{6 + \frac{5^2}{6 + \cdots}}}
  162. We define the coefficient callables as:
  163. >>> def a(n):
  164. ... return (2*n - 1)**2
  165. >>>
  166. >>> def b(n):
  167. ... if n == 0:
  168. ... return 3
  169. ... else:
  170. ... return 6
  171. Then the continued fraction can be evaluated as:
  172. >>> res = _continued_fraction(a, b)
  173. >>> res
  174. success: False
  175. status: -2
  176. f: 3.1415924109719846
  177. nit: 100
  178. nfev: 101
  179. Note that the requested tolerance was not reached within the (default)
  180. maximum number of iterations because it converges very slowly.
  181. An expression that converges more rapidly is expressed as the difference
  182. between two continued fractions. We will compute both of them in one
  183. vectorized call to `_continued_fraction`.
  184. >>> u, v = 5, 239
  185. >>>
  186. >>> def a(n, a1, _):
  187. ... # The shape of the output must be the shape of the arguments
  188. ... shape = a1.shape
  189. ... if n == 0:
  190. ... return np.zeros(shape)
  191. ... elif n == 1:
  192. ... return a1
  193. ... else:
  194. ... return np.full(shape, (n-1)**2)
  195. >>>
  196. >>> def b(n, _, uv):
  197. ... shape = uv.shape
  198. ... if n == 0:
  199. ... return np.zeros(shape)
  200. ... return np.full(shape, (2*n - 1)*uv)
  201. >>>
  202. >>> res = _continued_fraction(a, b, args=([16, 4], [u, v]))
  203. >>> res
  204. success: [ True True]
  205. status: [0 0]
  206. f: [ 3.158e+00 1.674e-02]
  207. nit: [10 4]
  208. nfev: [11 5]
  209. Note that the second term converged in fewer than half the number of iterations
  210. as the first. The approximation of :math:`\pi` is the difference between the two:
  211. >>> pi = res.f[0] - res.f[1]
  212. >>> (pi - np.pi) / np.pi
  213. np.float64(2.8271597168564594e-16)
  214. If it is more efficient to compute the :math:`a_n` and :math:`b_n` terms together,
  215. consider instantiating a class with a method that computes both terms and stores
  216. the results in an attribute. Separate methods of the class retrieve the
  217. coefficients, and these methods are passed to `_continued_fraction` as arguments
  218. `a` and `b`. Similarly,if the coefficients can be computed recursively in terms of
  219. previous coefficients, use a class to maintain state between callable evaluations.
  220. """
  221. res = _continued_fraction_iv(a, b, args, tolerances, maxiter, log)
  222. a, b, args, eps, tiny, maxiter, log, a0, b0, shape, dtype, xp = res
  223. callback = None # don't want to test it, but easy to add later
  224. # The EIM framework was designed for the case in where there would
  225. # be only one callable, and all arguments of the callable would be
  226. # arrays. We're going a bit beyond that here, since we have two callables,
  227. # and the first argument is an integer (the number of the term). Rather
  228. # than complicate the framework, we wrap the user-provided callables to
  229. # make this problem fit within the existing framework.
  230. def a(n, *args, a=a):
  231. n = int(xp.real(xp_ravel(n))[0])
  232. return a(n, *args)
  233. def b(n, *args, b=b):
  234. n = int(xp.real(xp_ravel(n))[0])
  235. return b(n, *args)
  236. def func(n, *args):
  237. return xp.stack((a(n, *args), b(n, *args)), axis=-1)
  238. status = xp.full_like(a0, eim._EINPROGRESS, dtype=xp.int32) # in progress
  239. nit, nfev = 0, 1 # one function evaluation (per function) performed above
  240. maxiter = 100 if maxiter is None else maxiter
  241. # Quotations describing the algorithm are from [1]_
  242. # "... as small as you like, say eps"
  243. if eps is None:
  244. eps = xp.finfo(dtype).eps if not log else np.log(xp.finfo(dtype).eps)
  245. # "The parameter tiny should be less than typical values of eps |b_n|"
  246. if tiny is None:
  247. tiny = xp.finfo(dtype).eps**2 if not log else 2*np.log(xp.finfo(dtype).eps)
  248. # "Set f0 and C0 to the value b0 or to tiny if b0=0. Set D0 = 0.
  249. zero = -xp.inf if log else 0
  250. fn = xp.where(b0 == zero, tiny, b0)
  251. Cnm1 = xp_copy(fn)
  252. Dnm1 = xp.full_like(fn, zero)
  253. CnDn = xp.full_like(fn, xp.inf)
  254. work = _RichResult(n=0, fn=fn, Cnm1=Cnm1, Dnm1=Dnm1, CnDn=CnDn,
  255. eps=eps, tiny=tiny,
  256. nit=nit, nfev=nfev, status=status)
  257. res_work_pairs = [('status', 'status'), ('f', 'fn'),
  258. ('nit', 'nit'), ('nfev', 'nfev')]
  259. def pre_func_eval(work):
  260. work.n = xp.reshape(xp.asarray(work.n + 1), (-1,))
  261. return work.n
  262. def post_func_eval(n, ab, work):
  263. an, bn = ab[..., 0], ab[..., 1]
  264. zero = 0 if not log else -xp.inf
  265. # "Set D_n = 1/(b_n + a_n D_{n-1}) or 1/tiny, if the denominator vanishes"
  266. denominator = (bn + an*work.Dnm1 if not log
  267. else _logaddexp(bn, an + work.Dnm1, xp=xp))
  268. denominator[denominator == zero] = tiny
  269. Dn = (1/denominator if not log
  270. else -denominator)
  271. # "Set C_n = b_n + a_n / C_{n-1} (or =tiny, if the expression vanishes)"
  272. Cn = (bn + an / work.Cnm1 if not log
  273. else _logaddexp(bn, an - work.Cnm1, xp=xp))
  274. Cn[Cn == zero] = tiny
  275. # "and set f_n = f_{n-1} C_n D_n"
  276. work.CnDn = (Cn * Dn if not log
  277. else Cn + Dn)
  278. work.fn = (work.fn * work.CnDn if not log
  279. else work.fn + work.CnDn)
  280. work.Cnm1, work.Dnm1 = Cn, Dn
  281. def check_termination(work):
  282. # Check for all terminal conditions and record statuses.
  283. stop = xp.zeros_like(work.CnDn, dtype=xp.bool)
  284. # "You quit when |D_n C_n - 1| is as small as you like, say eps"
  285. pij = xp.full_like(work.CnDn, xp.pi*1j) if log else None
  286. residual = (xp.abs(work.CnDn - 1) if not log
  287. else xp.real(_logaddexp(work.CnDn, pij, xp=xp)))
  288. i = residual < work.eps
  289. work.status[i] = eim._ECONVERGED
  290. stop[i] = True
  291. # If function value is NaN, report failure.
  292. i = (~xp.isfinite(work.fn) if not log
  293. else ~(xp.isfinite(work.fn) | (work.fn == -xp.inf)))
  294. work.status[i] = eim._EVALUEERR
  295. stop[i] = True
  296. return stop
  297. def post_termination_check(work):
  298. pass
  299. def customize_result(res, shape):
  300. # Only needed pre-NEP 50
  301. res['f'] = xp.asarray(res['f'], dtype=dtype)
  302. res['f'] = res['f'][()] if res['f'].ndim == 0 else res['f']
  303. return shape
  304. return eim._loop(work, callback, shape, maxiter, func, args, dtype,
  305. pre_func_eval, post_func_eval, check_termination,
  306. post_termination_check, customize_result, res_work_pairs,
  307. xp=xp)