_cubature.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731
  1. import math
  2. import heapq
  3. import itertools
  4. from dataclasses import dataclass, field
  5. from types import ModuleType
  6. from typing import Any, TypeAlias
  7. from scipy._lib._array_api import (
  8. array_namespace,
  9. xp_size,
  10. xp_copy,
  11. xp_promote,
  12. xp_capabilities
  13. )
  14. from scipy._lib._util import MapWrapper
  15. from scipy.integrate._rules import (
  16. ProductNestedFixed,
  17. GaussKronrodQuadrature,
  18. GenzMalikCubature,
  19. )
  20. from scipy.integrate._rules._base import _split_subregion
  21. __all__ = ['cubature']
  22. Array: TypeAlias = Any # To be changed to an array-api-typing Protocol later
  23. @dataclass
  24. class CubatureRegion:
  25. estimate: Array
  26. error: Array
  27. a: Array
  28. b: Array
  29. _xp: ModuleType = field(repr=False)
  30. def __lt__(self, other):
  31. # Consider regions with higher error estimates as being "less than" regions with
  32. # lower order estimates, so that regions with high error estimates are placed at
  33. # the top of the heap.
  34. this_err = self._xp.max(self._xp.abs(self.error))
  35. other_err = self._xp.max(self._xp.abs(other.error))
  36. return this_err > other_err
  37. @dataclass
  38. class CubatureResult:
  39. estimate: Array
  40. error: Array
  41. status: str
  42. regions: list[CubatureRegion]
  43. subdivisions: int
  44. atol: float
  45. rtol: float
  46. @xp_capabilities(allow_dask_compute=True, jax_jit=False)
  47. def cubature(f, a, b, *, rule="gk21", rtol=1e-8, atol=0, max_subdivisions=10000,
  48. args=(), workers=1, points=None):
  49. r"""
  50. Adaptive cubature of multidimensional array-valued function.
  51. Given an arbitrary integration rule, this function returns an estimate of the
  52. integral to the requested tolerance over the region defined by the arrays `a` and
  53. `b` specifying the corners of a hypercube.
  54. Convergence is not guaranteed for all integrals.
  55. Parameters
  56. ----------
  57. f : callable
  58. Function to integrate. `f` must have the signature::
  59. f(x : ndarray, *args) -> ndarray
  60. `f` should accept arrays ``x`` of shape::
  61. (npoints, ndim)
  62. and output arrays of shape::
  63. (npoints, output_dim_1, ..., output_dim_n)
  64. In this case, `cubature` will return arrays of shape::
  65. (output_dim_1, ..., output_dim_n)
  66. a, b : array_like
  67. Lower and upper limits of integration as 1D arrays specifying the left and right
  68. endpoints of the intervals being integrated over. Limits can be infinite.
  69. rule : str, optional
  70. Rule used to estimate the integral. If passing a string, the options are
  71. "gauss-kronrod" (21 node), or "genz-malik" (degree 7). If a rule like
  72. "gauss-kronrod" is specified for an ``n``-dim integrand, the corresponding
  73. Cartesian product rule is used. "gk21", "gk15" are also supported for
  74. compatibility with `quad_vec`. See Notes.
  75. rtol, atol : float, optional
  76. Relative and absolute tolerances. Iterations are performed until the error is
  77. estimated to be less than ``atol + rtol * abs(est)``. Here `rtol` controls
  78. relative accuracy (number of correct digits), while `atol` controls absolute
  79. accuracy (number of correct decimal places). To achieve the desired `rtol`, set
  80. `atol` to be smaller than the smallest value that can be expected from
  81. ``rtol * abs(y)`` so that `rtol` dominates the allowable error. If `atol` is
  82. larger than ``rtol * abs(y)`` the number of correct digits is not guaranteed.
  83. Conversely, to achieve the desired `atol`, set `rtol` such that
  84. ``rtol * abs(y)`` is always smaller than `atol`. Default values are 1e-8 for
  85. `rtol` and 0 for `atol`.
  86. max_subdivisions : int, optional
  87. Upper bound on the number of subdivisions to perform. Default is 10,000.
  88. args : tuple, optional
  89. Additional positional args passed to `f`, if any.
  90. workers : int or map-like callable, optional
  91. If `workers` is an integer, part of the computation is done in parallel
  92. subdivided to this many tasks (using :class:`python:multiprocessing.pool.Pool`).
  93. Supply `-1` to use all cores available to the Process. Alternatively, supply a
  94. map-like callable, such as :meth:`python:multiprocessing.pool.Pool.map` for
  95. evaluating the population in parallel. This evaluation is carried out as
  96. ``workers(func, iterable)``.
  97. points : list of array_like, optional
  98. List of points to avoid evaluating `f` at, under the condition that the rule
  99. being used does not evaluate `f` on the boundary of a region (which is the
  100. case for all Genz-Malik and Gauss-Kronrod rules). This can be useful if `f` has
  101. a singularity at the specified point. This should be a list of array-likes where
  102. each element has length ``ndim``. Default is empty. See Examples.
  103. Returns
  104. -------
  105. res : object
  106. Object containing the results of the estimation. It has the following
  107. attributes:
  108. estimate : ndarray
  109. Estimate of the value of the integral over the overall region specified.
  110. error : ndarray
  111. Estimate of the error of the approximation over the overall region
  112. specified.
  113. status : str
  114. Whether the estimation was successful. Can be either: "converged",
  115. "not_converged".
  116. subdivisions : int
  117. Number of subdivisions performed.
  118. atol, rtol : float
  119. Requested tolerances for the approximation.
  120. regions: list of object
  121. List of objects containing the estimates of the integral over smaller
  122. regions of the domain.
  123. Each object in ``regions`` has the following attributes:
  124. a, b : ndarray
  125. Points describing the corners of the region. If the original integral
  126. contained infinite limits or was over a region described by `region`,
  127. then `a` and `b` are in the transformed coordinates.
  128. estimate : ndarray
  129. Estimate of the value of the integral over this region.
  130. error : ndarray
  131. Estimate of the error of the approximation over this region.
  132. Notes
  133. -----
  134. The algorithm uses a similar algorithm to `quad_vec`, which itself is based on the
  135. implementation of QUADPACK's DQAG* algorithms, implementing global error control and
  136. adaptive subdivision.
  137. The source of the nodes and weights used for Gauss-Kronrod quadrature can be found
  138. in [1]_, and the algorithm for calculating the nodes and weights in Genz-Malik
  139. cubature can be found in [2]_.
  140. The rules currently supported via the `rule` argument are:
  141. - ``"gauss-kronrod"``, 21-node Gauss-Kronrod
  142. - ``"genz-malik"``, n-node Genz-Malik
  143. If using Gauss-Kronrod for an ``n``-dim integrand where ``n > 2``, then the
  144. corresponding Cartesian product rule will be found by taking the Cartesian product
  145. of the nodes in the 1D case. This means that the number of nodes scales
  146. exponentially as ``21^n`` in the Gauss-Kronrod case, which may be problematic in a
  147. moderate number of dimensions.
  148. Genz-Malik is typically less accurate than Gauss-Kronrod but has much fewer nodes,
  149. so in this situation using "genz-malik" might be preferable.
  150. Infinite limits are handled with an appropriate variable transformation. Assuming
  151. ``a = [a_1, ..., a_n]`` and ``b = [b_1, ..., b_n]``:
  152. If :math:`a_i = -\infty` and :math:`b_i = \infty`, the i-th integration variable
  153. will use the transformation :math:`x = \frac{1-|t|}{t}` and :math:`t \in (-1, 1)`.
  154. If :math:`a_i \ne \pm\infty` and :math:`b_i = \infty`, the i-th integration variable
  155. will use the transformation :math:`x = a_i + \frac{1-t}{t}` and
  156. :math:`t \in (0, 1)`.
  157. If :math:`a_i = -\infty` and :math:`b_i \ne \pm\infty`, the i-th integration
  158. variable will use the transformation :math:`x = b_i - \frac{1-t}{t}` and
  159. :math:`t \in (0, 1)`.
  160. References
  161. ----------
  162. .. [1] R. Piessens, E. de Doncker, Quadpack: A Subroutine Package for Automatic
  163. Integration, files: dqk21.f, dqk15.f (1983).
  164. .. [2] A.C. Genz, A.A. Malik, Remarks on algorithm 006: An adaptive algorithm for
  165. numerical integration over an N-dimensional rectangular region, Journal of
  166. Computational and Applied Mathematics, Volume 6, Issue 4, 1980, Pages 295-302,
  167. ISSN 0377-0427
  168. :doi:`10.1016/0771-050X(80)90039-X`
  169. Examples
  170. --------
  171. **1D integral with vector output**:
  172. .. math::
  173. \int^1_0 \mathbf f(x) \text dx
  174. Where ``f(x) = x^n`` and ``n = np.arange(10)`` is a vector. Since no rule is
  175. specified, the default "gk21" is used, which corresponds to Gauss-Kronrod
  176. integration with 21 nodes.
  177. >>> import numpy as np
  178. >>> from scipy.integrate import cubature
  179. >>> def f(x, n):
  180. ... # Make sure x and n are broadcastable
  181. ... return x[:, np.newaxis]**n[np.newaxis, :]
  182. >>> res = cubature(
  183. ... f,
  184. ... a=[0],
  185. ... b=[1],
  186. ... args=(np.arange(10),),
  187. ... )
  188. >>> res.estimate
  189. array([1. , 0.5 , 0.33333333, 0.25 , 0.2 ,
  190. 0.16666667, 0.14285714, 0.125 , 0.11111111, 0.1 ])
  191. **7D integral with arbitrary-shaped array output**::
  192. f(x) = cos(2*pi*r + alphas @ x)
  193. for some ``r`` and ``alphas``, and the integral is performed over the unit
  194. hybercube, :math:`[0, 1]^7`. Since the integral is in a moderate number of
  195. dimensions, "genz-malik" is used rather than the default "gauss-kronrod" to
  196. avoid constructing a product rule with :math:`21^7 \approx 2 \times 10^9` nodes.
  197. >>> import numpy as np
  198. >>> from scipy.integrate import cubature
  199. >>> def f(x, r, alphas):
  200. ... # f(x) = cos(2*pi*r + alphas @ x)
  201. ... # Need to allow r and alphas to be arbitrary shape
  202. ... npoints, ndim = x.shape[0], x.shape[-1]
  203. ... alphas = alphas[np.newaxis, ...]
  204. ... x = x.reshape(npoints, *([1]*(len(alphas.shape) - 1)), ndim)
  205. ... return np.cos(2*np.pi*r + np.sum(alphas * x, axis=-1))
  206. >>> rng = np.random.default_rng()
  207. >>> r, alphas = rng.random((2, 3)), rng.random((2, 3, 7))
  208. >>> res = cubature(
  209. ... f=f,
  210. ... a=np.array([0, 0, 0, 0, 0, 0, 0]),
  211. ... b=np.array([1, 1, 1, 1, 1, 1, 1]),
  212. ... rtol=1e-5,
  213. ... rule="genz-malik",
  214. ... args=(r, alphas),
  215. ... )
  216. >>> res.estimate
  217. array([[-0.79812452, 0.35246913, -0.52273628],
  218. [ 0.88392779, 0.59139899, 0.41895111]])
  219. **Parallel computation with** `workers`:
  220. >>> from concurrent.futures import ThreadPoolExecutor
  221. >>> with ThreadPoolExecutor() as executor:
  222. ... res = cubature(
  223. ... f=f,
  224. ... a=np.array([0, 0, 0, 0, 0, 0, 0]),
  225. ... b=np.array([1, 1, 1, 1, 1, 1, 1]),
  226. ... rtol=1e-5,
  227. ... rule="genz-malik",
  228. ... args=(r, alphas),
  229. ... workers=executor.map,
  230. ... )
  231. >>> res.estimate
  232. array([[-0.79812452, 0.35246913, -0.52273628],
  233. [ 0.88392779, 0.59139899, 0.41895111]])
  234. **2D integral with infinite limits**:
  235. .. math::
  236. \int^{ \infty }_{ -\infty }
  237. \int^{ \infty }_{ -\infty }
  238. e^{-x^2-y^2}
  239. \text dy
  240. \text dx
  241. >>> def gaussian(x):
  242. ... return np.exp(-np.sum(x**2, axis=-1))
  243. >>> res = cubature(gaussian, [-np.inf, -np.inf], [np.inf, np.inf])
  244. >>> res.estimate
  245. 3.1415926
  246. **1D integral with singularities avoided using** `points`:
  247. .. math::
  248. \int^{ 1 }_{ -1 }
  249. \frac{\sin(x)}{x}
  250. \text dx
  251. It is necessary to use the `points` parameter to avoid evaluating `f` at the origin.
  252. >>> def sinc(x):
  253. ... return np.sin(x)/x
  254. >>> res = cubature(sinc, [-1], [1], points=[[0]])
  255. >>> res.estimate
  256. 1.8921661
  257. """
  258. # It is also possible to use a custom rule, but this is not yet part of the public
  259. # API. An example of this can be found in the class scipy.integrate._rules.Rule.
  260. xp = array_namespace(a, b)
  261. max_subdivisions = float("inf") if max_subdivisions is None else max_subdivisions
  262. points = [] if points is None else points
  263. # Convert a and b to arrays and convert each point in points to an array, promoting
  264. # each to a common floating dtype.
  265. a, b, *points = xp_promote(a, b, *points, broadcast=True, force_floating=True,
  266. xp=xp)
  267. result_dtype = a.dtype
  268. if xp_size(a) == 0 or xp_size(b) == 0:
  269. raise ValueError("`a` and `b` must be nonempty")
  270. if a.ndim != 1 or b.ndim != 1:
  271. raise ValueError("`a` and `b` must be 1D arrays")
  272. # If the rule is a string, convert to a corresponding product rule
  273. if isinstance(rule, str):
  274. ndim = xp_size(a)
  275. if rule == "genz-malik":
  276. rule = GenzMalikCubature(ndim, xp=xp)
  277. else:
  278. quadratues = {
  279. "gauss-kronrod": GaussKronrodQuadrature(21, xp=xp),
  280. # Also allow names quad_vec uses:
  281. "gk21": GaussKronrodQuadrature(21, xp=xp),
  282. "gk15": GaussKronrodQuadrature(15, xp=xp),
  283. }
  284. base_rule = quadratues.get(rule)
  285. if base_rule is None:
  286. raise ValueError(f"unknown rule {rule}")
  287. rule = ProductNestedFixed([base_rule] * ndim)
  288. # If any of limits are the wrong way around (a > b), flip them and keep track of
  289. # the sign.
  290. sign = (-1) ** xp.sum(xp.astype(a > b, xp.int8), dtype=result_dtype)
  291. a_flipped = xp.min(xp.stack([a, b]), axis=0)
  292. b_flipped = xp.max(xp.stack([a, b]), axis=0)
  293. a, b = a_flipped, b_flipped
  294. # If any of the limits are infinite, apply a transformation
  295. if xp.any(xp.isinf(a)) or xp.any(xp.isinf(b)):
  296. f = _InfiniteLimitsTransform(f, a, b, xp=xp)
  297. a, b = f.transformed_limits
  298. # Map points from the original coordinates to the new transformed coordinates.
  299. #
  300. # `points` is a list of arrays of shape (ndim,), but transformations are applied
  301. # to arrays of shape (npoints, ndim).
  302. #
  303. # It is not possible to combine all the points into one array and then apply
  304. # f.inv to all of them at once since `points` needs to remain iterable.
  305. # Instead, each point is reshaped to an array of shape (1, ndim), `f.inv` is
  306. # applied, and then each is reshaped back to (ndim,).
  307. points = [xp.reshape(point, (1, -1)) for point in points]
  308. points = [f.inv(point) for point in points]
  309. points = [xp.reshape(point, (-1,)) for point in points]
  310. # Include any problematic points introduced by the transformation
  311. points.extend(f.points)
  312. # If any problematic points are specified, divide the initial region so that these
  313. # points lie on the edge of a subregion.
  314. #
  315. # This means ``f`` won't be evaluated there if the rule being used has no evaluation
  316. # points on the boundary.
  317. if len(points) == 0:
  318. initial_regions = [(a, b)]
  319. else:
  320. initial_regions = _split_region_at_points(a, b, points, xp)
  321. regions = []
  322. est = 0.0
  323. err = 0.0
  324. for a_k, b_k in initial_regions:
  325. est_k = rule.estimate(f, a_k, b_k, args)
  326. err_k = rule.estimate_error(f, a_k, b_k, args)
  327. regions.append(CubatureRegion(est_k, err_k, a_k, b_k, xp))
  328. est += est_k
  329. err += err_k
  330. subdivisions = 0
  331. success = True
  332. with MapWrapper(workers) as mapwrapper:
  333. while xp.any(err > atol + rtol * xp.abs(est)):
  334. # region_k is the region with highest estimated error
  335. region_k = heapq.heappop(regions)
  336. est_k = region_k.estimate
  337. err_k = region_k.error
  338. a_k, b_k = region_k.a, region_k.b
  339. # Subtract the estimate of the integral and its error over this region from
  340. # the current global estimates, since these will be refined in the loop over
  341. # all subregions.
  342. est -= est_k
  343. err -= err_k
  344. # Find all 2^ndim subregions formed by splitting region_k along each axis,
  345. # e.g. for 1D integrals this splits an estimate over an interval into an
  346. # estimate over two subintervals, for 3D integrals this splits an estimate
  347. # over a cube into 8 subcubes.
  348. #
  349. # For each of the new subregions, calculate an estimate for the integral and
  350. # the error there, and push these regions onto the heap for potential
  351. # further subdividing.
  352. executor_args = zip(
  353. itertools.repeat(f),
  354. itertools.repeat(rule),
  355. itertools.repeat(args),
  356. _split_subregion(a_k, b_k, xp),
  357. )
  358. for subdivision_result in mapwrapper(_process_subregion, executor_args):
  359. a_k_sub, b_k_sub, est_sub, err_sub = subdivision_result
  360. est += est_sub
  361. err += err_sub
  362. new_region = CubatureRegion(est_sub, err_sub, a_k_sub, b_k_sub, xp)
  363. heapq.heappush(regions, new_region)
  364. subdivisions += 1
  365. if subdivisions >= max_subdivisions:
  366. success = False
  367. break
  368. status = "converged" if success else "not_converged"
  369. # Apply sign change to handle any limits which were initially flipped.
  370. est = sign * est
  371. return CubatureResult(
  372. estimate=est,
  373. error=err,
  374. status=status,
  375. subdivisions=subdivisions,
  376. regions=regions,
  377. atol=atol,
  378. rtol=rtol,
  379. )
  380. def _process_subregion(data):
  381. f, rule, args, coord = data
  382. a_k_sub, b_k_sub = coord
  383. est_sub = rule.estimate(f, a_k_sub, b_k_sub, args)
  384. err_sub = rule.estimate_error(f, a_k_sub, b_k_sub, args)
  385. return a_k_sub, b_k_sub, est_sub, err_sub
  386. def _is_strictly_in_region(a, b, point, xp):
  387. if xp.all(point == a) or xp.all(point == b):
  388. return False
  389. return xp.all(a <= point) and xp.all(point <= b)
  390. def _split_region_at_points(a, b, points, xp):
  391. """
  392. Given the integration limits `a` and `b` describing a rectangular region and a list
  393. of `points`, find the list of ``[(a_1, b_1), ..., (a_l, b_l)]`` which breaks up the
  394. initial region into smaller subregion such that no `points` lie strictly inside
  395. any of the subregions.
  396. """
  397. regions = [(a, b)]
  398. for point in points:
  399. if xp.any(xp.isinf(point)):
  400. # If a point is specified at infinity, ignore.
  401. #
  402. # This case occurs when points are given by the user to avoid, but after
  403. # applying a transformation, they are removed.
  404. continue
  405. new_subregions = []
  406. for a_k, b_k in regions:
  407. if _is_strictly_in_region(a_k, b_k, point, xp):
  408. subregions = _split_subregion(a_k, b_k, xp, point)
  409. for left, right in subregions:
  410. # Skip any zero-width regions.
  411. if xp.any(left == right):
  412. continue
  413. else:
  414. new_subregions.append((left, right))
  415. new_subregions.extend(subregions)
  416. else:
  417. new_subregions.append((a_k, b_k))
  418. regions = new_subregions
  419. return regions
  420. class _VariableTransform:
  421. """
  422. A transformation that can be applied to an integral.
  423. """
  424. @property
  425. def transformed_limits(self):
  426. """
  427. New limits of integration after applying the transformation.
  428. """
  429. raise NotImplementedError
  430. @property
  431. def points(self):
  432. """
  433. Any problematic points introduced by the transformation.
  434. These should be specified as points where ``_VariableTransform(f)(self, point)``
  435. would be problematic.
  436. For example, if the transformation ``x = 1/((1-t)(1+t))`` is applied to a
  437. univariate integral, then points should return ``[ [1], [-1] ]``.
  438. """
  439. return []
  440. def inv(self, x):
  441. """
  442. Map points ``x`` to ``t`` such that if ``f`` is the original function and ``g``
  443. is the function after the transformation is applied, then::
  444. f(x) = g(self.inv(x))
  445. """
  446. raise NotImplementedError
  447. def __call__(self, t, *args, **kwargs):
  448. """
  449. Apply the transformation to ``f`` and multiply by the Jacobian determinant.
  450. This should be the new integrand after the transformation has been applied so
  451. that the following is satisfied::
  452. f_transformed = _VariableTransform(f)
  453. cubature(f, a, b) == cubature(
  454. f_transformed,
  455. *f_transformed.transformed_limits(a, b),
  456. )
  457. """
  458. raise NotImplementedError
  459. class _InfiniteLimitsTransform(_VariableTransform):
  460. r"""
  461. Transformation for handling infinite limits.
  462. Assuming ``a = [a_1, ..., a_n]`` and ``b = [b_1, ..., b_n]``:
  463. If :math:`a_i = -\infty` and :math:`b_i = \infty`, the i-th integration variable
  464. will use the transformation :math:`x = \frac{1-|t|}{t}` and :math:`t \in (-1, 1)`.
  465. If :math:`a_i \ne \pm\infty` and :math:`b_i = \infty`, the i-th integration variable
  466. will use the transformation :math:`x = a_i + \frac{1-t}{t}` and
  467. :math:`t \in (0, 1)`.
  468. If :math:`a_i = -\infty` and :math:`b_i \ne \pm\infty`, the i-th integration
  469. variable will use the transformation :math:`x = b_i - \frac{1-t}{t}` and
  470. :math:`t \in (0, 1)`.
  471. """
  472. def __init__(self, f, a, b, xp):
  473. self._xp = xp
  474. self._f = f
  475. self._orig_a = a
  476. self._orig_b = b
  477. # (-oo, oo) will be mapped to (-1, 1).
  478. self._double_inf_pos = (a == -math.inf) & (b == math.inf)
  479. # (start, oo) will be mapped to (0, 1).
  480. start_inf_mask = (a != -math.inf) & (b == math.inf)
  481. # (-oo, end) will be mapped to (0, 1).
  482. inf_end_mask = (a == -math.inf) & (b != math.inf)
  483. # This is handled by making the transformation t = -x and reducing it to
  484. # the other semi-infinite case.
  485. self._semi_inf_pos = start_inf_mask | inf_end_mask
  486. # Since we flip the limits, we don't need to separately multiply the
  487. # integrand by -1.
  488. self._orig_a[inf_end_mask] = -b[inf_end_mask]
  489. self._orig_b[inf_end_mask] = -a[inf_end_mask]
  490. self._num_inf = self._xp.sum(
  491. self._xp.astype(self._double_inf_pos | self._semi_inf_pos, self._xp.int64),
  492. ).__int__()
  493. @property
  494. def transformed_limits(self):
  495. a = xp_copy(self._orig_a)
  496. b = xp_copy(self._orig_b)
  497. a[self._double_inf_pos] = -1
  498. b[self._double_inf_pos] = 1
  499. a[self._semi_inf_pos] = 0
  500. b[self._semi_inf_pos] = 1
  501. return a, b
  502. @property
  503. def points(self):
  504. # If there are infinite limits, then the origin becomes a problematic point
  505. # due to a division by zero there.
  506. # If the function using this class only wraps f when a and b contain infinite
  507. # limits, this condition will always be met (as is the case with cubature).
  508. #
  509. # If a and b do not contain infinite limits but f is still wrapped with this
  510. # class, then without this condition the initial region of integration will
  511. # be split around the origin unnecessarily.
  512. if self._num_inf != 0:
  513. return [self._xp.zeros(self._orig_a.shape)]
  514. else:
  515. return []
  516. def inv(self, x):
  517. t = xp_copy(x)
  518. npoints = x.shape[0]
  519. double_inf_mask = self._xp.tile(
  520. self._double_inf_pos[self._xp.newaxis, :],
  521. (npoints, 1),
  522. )
  523. semi_inf_mask = self._xp.tile(
  524. self._semi_inf_pos[self._xp.newaxis, :],
  525. (npoints, 1),
  526. )
  527. # If any components of x are 0, then this component will be mapped to infinity
  528. # under the transformation used for doubly-infinite limits.
  529. #
  530. # Handle the zero values and non-zero values separately to avoid division by
  531. # zero.
  532. zero_mask = x[double_inf_mask] == 0
  533. non_zero_mask = double_inf_mask & ~zero_mask
  534. t[zero_mask] = math.inf
  535. t[non_zero_mask] = 1/(x[non_zero_mask] + self._xp.sign(x[non_zero_mask]))
  536. start = self._xp.tile(self._orig_a[self._semi_inf_pos], (npoints,))
  537. t[semi_inf_mask] = 1/(x[semi_inf_mask] - start + 1)
  538. return t
  539. def __call__(self, t, *args, **kwargs):
  540. x = xp_copy(t)
  541. npoints = t.shape[0]
  542. double_inf_mask = self._xp.tile(
  543. self._double_inf_pos[self._xp.newaxis, :],
  544. (npoints, 1),
  545. )
  546. semi_inf_mask = self._xp.tile(
  547. self._semi_inf_pos[self._xp.newaxis, :],
  548. (npoints, 1),
  549. )
  550. # For (-oo, oo) -> (-1, 1), use the transformation x = (1-|t|)/t.
  551. x[double_inf_mask] = (
  552. (1 - self._xp.abs(t[double_inf_mask])) / t[double_inf_mask]
  553. )
  554. start = self._xp.tile(self._orig_a[self._semi_inf_pos], (npoints,))
  555. # For (start, oo) -> (0, 1), use the transformation x = start + (1-t)/t.
  556. x[semi_inf_mask] = start + (1 - t[semi_inf_mask]) / t[semi_inf_mask]
  557. jacobian_det = 1/self._xp.prod(
  558. self._xp.reshape(
  559. t[semi_inf_mask | double_inf_mask]**2,
  560. (-1, self._num_inf),
  561. ),
  562. axis=-1,
  563. )
  564. f_x = self._f(x, *args, **kwargs)
  565. jacobian_det = self._xp.reshape(jacobian_det, (-1, *([1]*(len(f_x.shape) - 1))))
  566. return f_x * jacobian_det