test_bary_rational.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. # Copyright (c) 2017, The Chancellor, Masters and Scholars of the University
  2. # of Oxford, and the Chebfun Developers. All rights reserved.
  3. #
  4. # Redistribution and use in source and binary forms, with or without
  5. # modification, are permitted provided that the following conditions are met:
  6. # * Redistributions of source code must retain the above copyright
  7. # notice, this list of conditions and the following disclaimer.
  8. # * Redistributions in binary form must reproduce the above copyright
  9. # notice, this list of conditions and the following disclaimer in the
  10. # documentation and/or other materials provided with the distribution.
  11. # * Neither the name of the University of Oxford nor the names of its
  12. # contributors may be used to endorse or promote products derived from
  13. # this software without specific prior written permission.
  14. #
  15. # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
  16. # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
  17. # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
  18. # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
  19. # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
  20. # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
  21. # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
  22. # ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  23. # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  24. # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  25. from math import factorial
  26. import numpy as np
  27. from numpy.testing import assert_allclose, assert_equal, assert_array_less
  28. import pytest
  29. import scipy
  30. from scipy.interpolate import AAA, FloaterHormannInterpolator, BarycentricInterpolator
  31. TOL = 1e4 * np.finfo(np.float64).eps
  32. UNIT_INTERVAL = np.linspace(-1, 1, num=1000)
  33. PTS = np.logspace(-15, 0, base=10, num=500)
  34. PTS = np.concatenate([-PTS[::-1], [0], PTS])
  35. @pytest.mark.parametrize("method", [AAA, FloaterHormannInterpolator])
  36. @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.complex64, np.complex128])
  37. def test_dtype_preservation(method, dtype):
  38. rtol = np.finfo(dtype).eps ** 0.75 * 100
  39. if method is FloaterHormannInterpolator:
  40. rtol *= 100
  41. rng = np.random.default_rng(59846294526092468)
  42. z = np.linspace(-1, 1, dtype=dtype)
  43. r = method(z, np.sin(z))
  44. z2 = rng.uniform(-1, 1, size=100).astype(dtype)
  45. assert_allclose(r(z2), np.sin(z2), rtol=rtol)
  46. assert r(z2).dtype == dtype
  47. if method is AAA:
  48. assert r.support_points.dtype == dtype
  49. assert r.support_values.dtype == dtype
  50. assert r.errors.dtype == z.real.dtype
  51. assert r.weights.dtype == dtype
  52. assert r.poles().dtype == np.result_type(dtype, 1j)
  53. assert r.residues().dtype == np.result_type(dtype, 1j)
  54. assert r.roots().dtype == np.result_type(dtype, 1j)
  55. @pytest.mark.parametrize("method", [AAA, FloaterHormannInterpolator])
  56. @pytest.mark.parametrize("dtype", [np.int16, np.int32, np.int64])
  57. def test_integer_promotion(method, dtype):
  58. z = np.arange(10, dtype=dtype)
  59. r = method(z, z)
  60. assert r.weights.dtype == np.result_type(dtype, 1.0)
  61. if method is AAA:
  62. assert r.support_points.dtype == np.result_type(dtype, 1.0)
  63. assert r.support_values.dtype == np.result_type(dtype, 1.0)
  64. assert r.errors.dtype == np.result_type(dtype, 1.0)
  65. assert r.poles().dtype == np.result_type(dtype, 1j)
  66. assert r.residues().dtype == np.result_type(dtype, 1j)
  67. assert r.roots().dtype == np.result_type(dtype, 1j)
  68. assert r(z).dtype == np.result_type(dtype, 1.0)
  69. class TestAAA:
  70. def test_input_validation(self):
  71. with pytest.raises(ValueError, match="`x` be of size 2 but got size 1."):
  72. AAA([0], [1, 1])
  73. with pytest.raises(ValueError, match="1-D"):
  74. AAA([[0], [0]], [[1], [1]])
  75. with pytest.raises(ValueError, match="finite"):
  76. AAA([np.inf], [1])
  77. with pytest.raises(TypeError):
  78. AAA([1], [1], max_terms=1.0)
  79. with pytest.raises(ValueError, match="greater"):
  80. AAA([1], [1], max_terms=-1)
  81. def test_convergence_error(self):
  82. with pytest.warns(RuntimeWarning, match="AAA failed"):
  83. AAA(UNIT_INTERVAL, np.exp(UNIT_INTERVAL), max_terms=1)
  84. # The following tests are based on:
  85. # https://github.com/chebfun/chebfun/blob/master/tests/chebfun/test_aaa.m
  86. def test_exp(self):
  87. f = np.exp(UNIT_INTERVAL)
  88. r = AAA(UNIT_INTERVAL, f)
  89. assert_allclose(r(UNIT_INTERVAL), f, atol=TOL)
  90. assert_equal(r(np.nan), np.nan)
  91. assert np.isfinite(r(np.inf))
  92. m1 = r.support_points.size
  93. r = AAA(UNIT_INTERVAL, f, rtol=1e-3)
  94. assert r.support_points.size < m1
  95. def test_tan(self):
  96. f = np.tan(np.pi * UNIT_INTERVAL)
  97. r = AAA(UNIT_INTERVAL, f)
  98. assert_allclose(r(UNIT_INTERVAL), f, atol=10 * TOL, rtol=1.4e-7)
  99. assert_allclose(np.min(np.abs(r.roots())), 0, atol=3e-10)
  100. assert_allclose(np.min(np.abs(r.poles() - 0.5)), 0, atol=TOL)
  101. # Test for spurious poles (poles with tiny residue are likely spurious)
  102. assert np.min(np.abs(r.residues())) > 1e-13
  103. def test_short_cases(self):
  104. # Computed using Chebfun:
  105. # >> format long
  106. # >> [r, pol, res, zer, zj, fj, wj, errvec] = aaa([1 2], [0 1])
  107. z = np.array([0, 1])
  108. f = np.array([1, 2])
  109. r = AAA(z, f, rtol=1e-13)
  110. assert_allclose(r(z), f, atol=TOL)
  111. assert_allclose(r.poles(), 0.5)
  112. assert_allclose(r.residues(), 0.25)
  113. assert_allclose(r.roots(), 1/3)
  114. assert_equal(r.support_points, z)
  115. assert_equal(r.support_values, f)
  116. assert_allclose(r.weights, [0.707106781186547, 0.707106781186547])
  117. assert_equal(r.errors, [1, 0])
  118. # >> format long
  119. # >> [r, pol, res, zer, zj, fj, wj, errvec] = aaa([1 0 0], [0 1 2])
  120. z = np.array([0, 1, 2])
  121. f = np.array([1, 0, 0])
  122. r = AAA(z, f, rtol=1e-13)
  123. assert_allclose(r(z), f, atol=TOL)
  124. assert_allclose(np.sort(r.poles()),
  125. np.sort([1.577350269189626, 0.422649730810374]))
  126. assert_allclose(np.sort(r.residues()),
  127. np.sort([-0.070441621801729, -0.262891711531604]))
  128. assert_allclose(np.sort(r.roots()), np.sort([2, 1]))
  129. assert_equal(r.support_points, z)
  130. assert_equal(r.support_values, f)
  131. assert_allclose(r.weights, [0.577350269189626, 0.577350269189626,
  132. 0.577350269189626])
  133. assert_equal(r.errors, [1, 1, 0])
  134. def test_scale_invariance(self):
  135. z = np.linspace(0.3, 1.5)
  136. f = np.exp(z) / (1 + 1j)
  137. r1 = AAA(z, f)
  138. r2 = AAA(z, (2**311 * f).astype(np.complex128))
  139. r3 = AAA(z, (2**-311 * f).astype(np.complex128))
  140. assert_equal(r1(0.2j), 2**-311 * r2(0.2j))
  141. assert_equal(r1(1.4), 2**311 * r3(1.4))
  142. def test_log_func(self):
  143. rng = np.random.default_rng(1749382759832758297)
  144. z = rng.standard_normal(10000) + 3j * rng.standard_normal(10000)
  145. def f(z):
  146. return np.log(5 - z) / (1 + z**2)
  147. r = AAA(z, f(z))
  148. assert_allclose(r(0), f(0), atol=TOL)
  149. def test_infinite_data(self):
  150. z = np.linspace(-1, 1)
  151. r = AAA(z, scipy.special.gamma(z))
  152. assert_allclose(r(0.63), scipy.special.gamma(0.63), atol=1e-15)
  153. def test_nan(self):
  154. x = np.linspace(0, 20)
  155. with np.errstate(invalid="ignore"):
  156. f = np.sin(x) / x
  157. r = AAA(x, f)
  158. assert_allclose(r(2), np.sin(2) / 2, atol=1e-15)
  159. def test_residues(self):
  160. x = np.linspace(-1.337, 2, num=537)
  161. r = AAA(x, np.exp(x) / x)
  162. ii = np.flatnonzero(np.abs(r.poles()) < 1e-8)
  163. assert_allclose(r.residues()[ii], 1, atol=1e-15)
  164. r = AAA(x, (1 + 1j) * scipy.special.gamma(x))
  165. ii = np.flatnonzero(abs(r.poles() - (-1)) < 1e-8)
  166. assert_allclose(r.residues()[ii], -1 - 1j, atol=1e-15)
  167. # The following tests are based on:
  168. # https://github.com/complexvariables/RationalFunctionApproximation.jl/blob/main/test/interval.jl
  169. @pytest.mark.parametrize("func,atol,rtol",
  170. [(lambda x: np.abs(x + 0.5 + 0.01j), 5e-13, 1e-7),
  171. (lambda x: np.sin(1/(1.05 - x)), 2e-13, 1e-7),
  172. (lambda x: np.exp(-1/(x**2)), 3.5e-11, 0),
  173. (lambda x: np.exp(-100*x**2), 2e-12, 0),
  174. (lambda x: np.exp(-10/(1.2 - x)), 1e-14, 0),
  175. (lambda x: 1/(1+np.exp(100*(x + 0.5))), 2e-13, 1e-7),
  176. (lambda x: np.abs(x - 0.95), 1e-6, 1e-7)])
  177. def test_basic_functions(self, func, atol, rtol):
  178. with np.errstate(divide="ignore"):
  179. f = func(PTS)
  180. assert_allclose(AAA(UNIT_INTERVAL, func(UNIT_INTERVAL))(PTS),
  181. f, atol=atol, rtol=rtol)
  182. def test_poles_zeros_residues(self):
  183. def f(z):
  184. return (z+1) * (z+2) / ((z+3) * (z+4))
  185. r = AAA(UNIT_INTERVAL, f(UNIT_INTERVAL))
  186. assert_allclose(np.sum(r.poles() + r.roots()), -10, atol=1e-12)
  187. def f(z):
  188. return 2/(3 + z) + 5/(z - 2j)
  189. r = AAA(UNIT_INTERVAL, f(UNIT_INTERVAL))
  190. assert_allclose(r.residues().prod(), 10, atol=1e-8)
  191. r = AAA(UNIT_INTERVAL, np.sin(10*np.pi*UNIT_INTERVAL))
  192. assert_allclose(np.sort(np.abs(r.roots()))[18], 0.9, atol=1e-12)
  193. def f(z):
  194. return (z - (3 + 3j))/(z + 2)
  195. r = AAA(UNIT_INTERVAL, f(UNIT_INTERVAL))
  196. assert_allclose(r.poles()[0]*r.roots()[0], -6-6j, atol=1e-12)
  197. @pytest.mark.parametrize("func",
  198. [lambda z: np.zeros_like(z), lambda z: z, lambda z: 1j*z,
  199. lambda z: z**2 + z, lambda z: z**3 + z,
  200. lambda z: 1/(1.1 + z), lambda z: 1/(1 + 1j*z),
  201. lambda z: 1/(3 + z + z**2), lambda z: 1/(1.01 + z**3)])
  202. def test_polynomials_and_reciprocals(self, func):
  203. assert_allclose(AAA(UNIT_INTERVAL, func(UNIT_INTERVAL))(PTS),
  204. func(PTS), atol=2e-13)
  205. # The following tests are taken from:
  206. # https://github.com/macd/BaryRational.jl/blob/main/test/test_aaa.jl
  207. def test_spiral(self):
  208. z = np.exp(np.linspace(-0.5, 0.5 + 15j*np.pi, num=1000))
  209. r = AAA(z, np.tan(np.pi*z/2))
  210. assert_allclose(np.sort(np.abs(r.poles()))[:4], [1, 1, 3, 3], rtol=9e-7)
  211. def test_spiral_cleanup(self):
  212. z = np.exp(np.linspace(-0.5, 0.5 + 15j*np.pi, num=1000))
  213. # here we set `rtol=0` to force froissart doublets, without cleanup there
  214. # are many spurious poles
  215. with pytest.warns(RuntimeWarning):
  216. r = AAA(z, np.tan(np.pi*z/2), rtol=0, max_terms=60, clean_up=False)
  217. n_spurious = np.sum(np.abs(r.residues()) < 1e-14)
  218. with pytest.warns(RuntimeWarning):
  219. assert r.clean_up() >= 1
  220. # check there are less potentially spurious poles than before
  221. assert np.sum(np.abs(r.residues()) < 1e-14) < n_spurious
  222. # check accuracy
  223. assert_allclose(r(z), np.tan(np.pi*z/2), atol=6e-12, rtol=3e-12)
  224. def test_diag_scaling(self):
  225. # fails without diag scaling
  226. z = np.logspace(-15, 0, 300)
  227. f = np.sqrt(z)
  228. r = AAA(z, f)
  229. zz = np.logspace(-15, 0, 500)
  230. assert_allclose(r(zz), np.sqrt(zz), rtol=9e-6)
  231. class BatchFloaterHormann:
  232. # FloaterHormann class with reference batch behaviour
  233. def __init__(self, x, y, axis):
  234. y = np.moveaxis(y, axis, -1)
  235. self._batch_shape = y.shape[:-1]
  236. self._interps = [FloaterHormannInterpolator(x, yi,)
  237. for yi in y.reshape(-1, y.shape[-1])]
  238. self._axis = axis
  239. def __call__(self, x):
  240. y = [interp(x) for interp in self._interps]
  241. y = np.reshape(y, self._batch_shape + x.shape)
  242. return np.moveaxis(y, -1, self._axis) if x.shape else y
  243. class TestFloaterHormann:
  244. def runge(self, z):
  245. return 1/(1 + z**2)
  246. def scale(self, n, d):
  247. return (-1)**(np.arange(n) + d) * factorial(d)
  248. def test_iv(self):
  249. with pytest.raises(ValueError, match="`x`"):
  250. FloaterHormannInterpolator([[0]], [0], d=0)
  251. with pytest.raises(ValueError, match="`y`"):
  252. FloaterHormannInterpolator([0], 0, d=0)
  253. with pytest.raises(ValueError, match="`x` be of size 2 but got size 1."):
  254. FloaterHormannInterpolator([0], [[1, 1], [1, 1]], d=0)
  255. with pytest.raises(ValueError, match="finite"):
  256. FloaterHormannInterpolator([np.inf], [1], d=0)
  257. with pytest.raises(ValueError, match="`d`"):
  258. FloaterHormannInterpolator([0], [0], d=-1)
  259. with pytest.raises(ValueError, match="`d`"):
  260. FloaterHormannInterpolator([0], [0], d=10)
  261. with pytest.raises(TypeError):
  262. FloaterHormannInterpolator([0], [0], d=0.0)
  263. # reference values from Floater and Hormann 2007 page 8.
  264. @pytest.mark.parametrize("d,expected", [
  265. (0, [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  266. (1, [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]),
  267. (2, [1, 3, 4, 4, 4, 4, 4, 4, 4, 3, 1]),
  268. (3, [1, 4, 7, 8, 8, 8, 8, 8, 7, 4, 1]),
  269. (4, [1, 5, 11, 15, 16, 16, 16, 15, 11, 5, 1])
  270. ])
  271. def test_uniform_grid(self, d, expected):
  272. # Check against explicit results on an uniform grid
  273. x = np.arange(11)
  274. r = FloaterHormannInterpolator(x, 0.0*x, d=d)
  275. assert_allclose(r.weights.ravel()*self.scale(x.size, d), expected,
  276. rtol=1e-15, atol=1e-15)
  277. @pytest.mark.parametrize("d", range(10))
  278. def test_runge(self, d):
  279. x = np.linspace(0, 1, 51)
  280. rng = np.random.default_rng(802754237598370893)
  281. xx = rng.uniform(0, 1, size=1000)
  282. y = self.runge(x)
  283. h = x[1] - x[0]
  284. r = FloaterHormannInterpolator(x, y, d=d)
  285. tol = 10*h**(d+1)
  286. assert_allclose(r(xx), self.runge(xx), atol=1e-10, rtol=tol)
  287. # check interpolation property
  288. assert_equal(r(x), self.runge(x))
  289. def test_complex(self):
  290. x = np.linspace(-1, 1)
  291. z = x + x*1j
  292. r = FloaterHormannInterpolator(z, np.sin(z), d=12)
  293. xx = np.linspace(-1, 1, num=1000)
  294. zz = xx + xx*1j
  295. assert_allclose(r(zz), np.sin(zz), rtol=1e-12)
  296. def test_polyinterp(self):
  297. # check that when d=n-1 FH gives a polynomial interpolant
  298. x = np.linspace(0, 1, 11)
  299. xx = np.linspace(0, 1, 1001)
  300. y = np.sin(x)
  301. r = FloaterHormannInterpolator(x, y, d=x.size-1)
  302. p = BarycentricInterpolator(x, y)
  303. assert_allclose(r(xx), p(xx), rtol=1e-12, atol=1e-12)
  304. @pytest.mark.parametrize("y_shape", [(2,), (2, 3, 1), (1, 5, 6, 4)])
  305. @pytest.mark.parametrize("xx_shape", [(100), (10, 10)])
  306. def test_trailing_dim(self, y_shape, xx_shape):
  307. x = np.linspace(0, 1)
  308. y = np.broadcast_to(
  309. np.expand_dims(np.sin(x), tuple(range(1, len(y_shape) + 1))),
  310. x.shape + y_shape
  311. )
  312. r = FloaterHormannInterpolator(x, y)
  313. rng = np.random.default_rng(897138947238097528091759187597)
  314. xx = rng.random(xx_shape)
  315. yy = np.broadcast_to(
  316. np.expand_dims(np.sin(xx), tuple(range(xx.ndim, len(y_shape) + xx.ndim))),
  317. xx.shape + y_shape
  318. )
  319. rr = r(xx)
  320. assert rr.shape == xx.shape + y_shape
  321. assert_allclose(rr, yy, rtol=1e-6)
  322. def test_zeros(self):
  323. x = np.linspace(0, 10, num=100)
  324. r = FloaterHormannInterpolator(x, np.sin(np.pi*x))
  325. err = np.abs(np.subtract.outer(r.roots(), np.arange(11))).min(axis=0)
  326. assert_array_less(err, 1e-5)
  327. def test_no_poles(self):
  328. x = np.linspace(-1, 1)
  329. r = FloaterHormannInterpolator(x, 1/x**2)
  330. p = r.poles()
  331. mask = (p.real >= -1) & (p.real <= 1) & (np.abs(p.imag) < 1.e-12)
  332. assert np.sum(mask) == 0
  333. @pytest.mark.parametrize('eval_shape', [(), (1,), (3,)])
  334. @pytest.mark.parametrize('axis', [-1, 0, 1])
  335. def test_batch(self, eval_shape, axis):
  336. rng = np.random.default_rng(4329872134985134)
  337. n = 10
  338. shape = (2, 3, 4, n)
  339. domain = (0, 10)
  340. x = np.linspace(*domain, n)
  341. y = np.moveaxis(rng.random(shape), -1, axis)
  342. res = FloaterHormannInterpolator(x, y, axis=axis)
  343. ref = BatchFloaterHormann(x, y, axis=axis)
  344. x = rng.uniform(*domain, size=eval_shape)
  345. assert_allclose(res(x), ref(x))
  346. pytest.raises(NotImplementedError, res.roots)
  347. pytest.raises(NotImplementedError, res.residues)