test__quad_vec.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. import pytest
  2. import numpy as np
  3. from numpy.testing import assert_allclose
  4. from scipy.integrate import quad_vec
  5. from scipy._lib._array_api import make_xp_test_case
  6. from multiprocessing.dummy import Pool
  7. quadrature_params = pytest.mark.parametrize(
  8. 'quadrature', [None, "gk15", "gk21", "trapezoid"])
  9. def _lorenzian(x):
  10. return 1 / (1 + x**2)
  11. def _func_with_args(x, a):
  12. return x * (x + a) * np.arange(3)
  13. @make_xp_test_case(quad_vec)
  14. class TestQuadVec:
  15. @quadrature_params
  16. def test_quad_vec_simple(self, quadrature):
  17. n = np.arange(10)
  18. def f(x):
  19. return x ** n
  20. for epsabs in [0.1, 1e-3, 1e-6]:
  21. if quadrature == 'trapezoid' and epsabs < 1e-4:
  22. # slow: skip
  23. continue
  24. kwargs = dict(epsabs=epsabs, quadrature=quadrature)
  25. exact = 2**(n+1)/(n + 1)
  26. res, err = quad_vec(f, 0, 2, norm='max', **kwargs)
  27. assert_allclose(res, exact, rtol=0, atol=epsabs)
  28. res, err = quad_vec(f, 0, 2, norm='2', **kwargs)
  29. assert np.linalg.norm(res - exact) < epsabs
  30. res, err = quad_vec(f, 0, 2, norm='max', points=(0.5, 1.0), **kwargs)
  31. assert_allclose(res, exact, rtol=0, atol=epsabs)
  32. res, err, *rest = quad_vec(f, 0, 2, norm='max',
  33. epsrel=1e-8,
  34. full_output=True,
  35. limit=10000,
  36. **kwargs)
  37. assert_allclose(res, exact, rtol=0, atol=epsabs)
  38. @quadrature_params
  39. def test_quad_vec_simple_inf(self, quadrature):
  40. def f(x):
  41. return 1 / (1 + np.float64(x) ** 2)
  42. for epsabs in [0.1, 1e-3, 1e-6]:
  43. if quadrature == 'trapezoid' and epsabs < 1e-4:
  44. # slow: skip
  45. continue
  46. kwargs = dict(norm='max', epsabs=epsabs, quadrature=quadrature)
  47. res, err = quad_vec(f, 0, np.inf, **kwargs)
  48. assert_allclose(res, np.pi/2, rtol=0, atol=max(epsabs, err))
  49. res, err = quad_vec(f, 0, -np.inf, **kwargs)
  50. assert_allclose(res, -np.pi/2, rtol=0, atol=max(epsabs, err))
  51. res, err = quad_vec(f, -np.inf, 0, **kwargs)
  52. assert_allclose(res, np.pi/2, rtol=0, atol=max(epsabs, err))
  53. res, err = quad_vec(f, np.inf, 0, **kwargs)
  54. assert_allclose(res, -np.pi/2, rtol=0, atol=max(epsabs, err))
  55. res, err = quad_vec(f, -np.inf, np.inf, **kwargs)
  56. assert_allclose(res, np.pi, rtol=0, atol=max(epsabs, err))
  57. res, err = quad_vec(f, np.inf, -np.inf, **kwargs)
  58. assert_allclose(res, -np.pi, rtol=0, atol=max(epsabs, err))
  59. res, err = quad_vec(f, np.inf, np.inf, **kwargs)
  60. assert_allclose(res, 0, rtol=0, atol=max(epsabs, err))
  61. res, err = quad_vec(f, -np.inf, -np.inf, **kwargs)
  62. assert_allclose(res, 0, rtol=0, atol=max(epsabs, err))
  63. res, err = quad_vec(f, 0, np.inf, points=(1.0, 2.0), **kwargs)
  64. assert_allclose(res, np.pi/2, rtol=0, atol=max(epsabs, err))
  65. def f(x):
  66. return np.sin(x + 2) / (1 + x ** 2)
  67. exact = np.pi / np.e * np.sin(2)
  68. epsabs = 1e-5
  69. res, err, info = quad_vec(f, -np.inf, np.inf, limit=1000, norm='max',
  70. epsabs=epsabs, quadrature=quadrature,
  71. full_output=True)
  72. assert info.status == 1
  73. assert_allclose(res, exact, rtol=0, atol=max(epsabs, 1.5 * err))
  74. def test_quad_vec_args(self):
  75. def f(x, a):
  76. return x * (x + a) * np.arange(3)
  77. a = 2
  78. exact = np.array([0, 4/3, 8/3])
  79. res, err = quad_vec(f, 0, 1, args=(a,))
  80. assert_allclose(res, exact, rtol=0, atol=1e-4)
  81. @pytest.mark.fail_slow(10)
  82. def test_quad_vec_pool(self):
  83. f = _lorenzian
  84. res, err = quad_vec(f, -np.inf, np.inf, norm='max', epsabs=1e-4, workers=4)
  85. assert_allclose(res, np.pi, rtol=0, atol=1e-4)
  86. with Pool(10) as pool:
  87. def f(x):
  88. return 1 / (1 + x ** 2)
  89. res, _ = quad_vec(f, -np.inf, np.inf, norm='max', epsabs=1e-4,
  90. workers=pool.map)
  91. assert_allclose(res, np.pi, rtol=0, atol=1e-4)
  92. @pytest.mark.fail_slow(10)
  93. @pytest.mark.parametrize('extra_args', [2, (2,)])
  94. @pytest.mark.parametrize(
  95. 'workers',
  96. [1, pytest.param(10, marks=pytest.mark.parallel_threads_limit(4))]
  97. )
  98. def test_quad_vec_pool_args(self, extra_args, workers):
  99. f = _func_with_args
  100. exact = np.array([0, 4/3, 8/3])
  101. res, err = quad_vec(f, 0, 1, args=extra_args, workers=workers)
  102. assert_allclose(res, exact, rtol=0, atol=1e-4)
  103. with Pool(workers) as pool:
  104. res, err = quad_vec(f, 0, 1, args=extra_args, workers=pool.map)
  105. assert_allclose(res, exact, rtol=0, atol=1e-4)
  106. @quadrature_params
  107. def test_num_eval(self, quadrature):
  108. def f(x):
  109. count[0] += 1
  110. return x**5
  111. count = [0]
  112. res = quad_vec(f, 0, 1, norm='max', full_output=True, quadrature=quadrature)
  113. assert res[2].neval == count[0]
  114. def test_info(self):
  115. def f(x):
  116. return np.ones((3, 2, 1))
  117. res, err, info = quad_vec(f, 0, 1, norm='max', full_output=True)
  118. assert info.success is True
  119. assert info.status == 0
  120. assert info.message == 'Target precision reached.'
  121. assert info.neval > 0
  122. assert info.intervals.shape[1] == 2
  123. assert info.integrals.shape == (info.intervals.shape[0], 3, 2, 1)
  124. assert info.errors.shape == (info.intervals.shape[0],)
  125. def test_nan_inf(self):
  126. def f_nan(x):
  127. return np.nan
  128. def f_inf(x):
  129. return np.inf if x < 0.1 else 1/x
  130. res, err, info = quad_vec(f_nan, 0, 1, full_output=True)
  131. assert info.status == 3
  132. res, err, info = quad_vec(f_inf, 0, 1, full_output=True)
  133. assert info.status == 3
  134. @pytest.mark.parametrize('a,b', [(0, 1), (0, np.inf), (np.inf, 0),
  135. (-np.inf, np.inf), (np.inf, -np.inf)])
  136. def test_points(self, a, b):
  137. # Check that initial interval splitting is done according to
  138. # `points`, by checking that consecutive sets of 15 point (for
  139. # gk15) function evaluations lie between `points`
  140. points = (0, 0.25, 0.5, 0.75, 1.0)
  141. points += tuple(-x for x in points)
  142. quadrature_points = 15
  143. interval_sets = []
  144. count = 0
  145. def f(x):
  146. nonlocal count
  147. if count % quadrature_points == 0:
  148. interval_sets.append(set())
  149. count += 1
  150. interval_sets[-1].add(float(x))
  151. return 0.0
  152. quad_vec(f, a, b, points=points, quadrature='gk15', limit=0)
  153. # Check that all point sets lie in a single `points` interval
  154. for p in interval_sets:
  155. j = np.searchsorted(sorted(points), tuple(p))
  156. assert np.all(j == j[0])