test_continued_fraction.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. import math
  2. import pytest
  3. import numpy as np
  4. from scipy._lib._array_api import array_namespace
  5. from scipy._lib._array_api_no_0d import xp_assert_close, xp_assert_less, xp_assert_equal
  6. from scipy.stats._continued_fraction import _continued_fraction
  @pytest.mark.skip_xp_backends('array_api_strict', reason='No fancy indexing assignment')
  @pytest.mark.skip_xp_backends('jax.numpy', reason="Don't support mutation")
  # dask doesn't like lines like this
  # n = int(xp.real(xp_ravel(n))[0])
  # (at some point in here the shape becomes nan)
  @pytest.mark.skip_xp_backends('dask.array', reason="dask has issues with the shapes")
  class TestContinuedFraction:
      """Tests for the private `_continued_fraction` helper.

      Most tests use the partial numerators `a1` and denominators `b1`, which
      define the continued fraction

          tan(x) = x / (1 - x**2 / (3 - x**2 / (5 - ...))),

      and compare the result against ``xp.tan(x)``. `log_a1` and `log_b1`
      return (complex) logarithms of the same terms for use with ``log=True``.
      """

      # NOTE(review): `rng` and `p` are not used by any test in this class.
      rng = np.random.default_rng(5895448232066142650)
      p = rng.uniform(1, 10, size=10)

      # True on NumPy 1.x, i.e. before NEP 50 changed how Python scalars take
      # part in type promotion. Compare the integer major version: comparing
      # version *strings* is lexicographic (e.g. "10.0" < "2.0" is True).
      _np_pre_nep50 = int(np.__version__.split('.')[0]) < 2

      def a1(self, n, x=1.5):
          """Return the partial numerator ``a_n`` of the tan(x) continued fraction.

          ``a_0 = 0``, ``a_1 = x`` and ``a_n = -x**2`` for ``n >= 2``.
          """
          if n == 0:
              y = 0*x
          elif n == 1:
              y = x
          else:
              y = -x**2
          if np.isscalar(y) and self._np_pre_nep50:
              y = np.full_like(x, y)  # preserve dtype pre NEP 50
          return y

      def b1(self, n, x=1.5):
          """Return the partial denominator ``b_n`` of the tan(x) continued fraction.

          ``b_0 = 0`` and ``b_n = 2*n - 1`` for ``n >= 1``, broadcast like `x`.
          """
          if n == 0:
              y = 0*x
          else:
              # Ones broadcast like `x` (same array type and shape). Unlike
              # `x/x`, this is 1 rather than NaN (or ZeroDivisionError for a
              # Python float) where x == 0; NaN/inf inputs still give NaN.
              one = 0*x + 1
              y = one * (2*n - 1)
          if np.isscalar(y) and self._np_pre_nep50:
              y = np.full_like(x, y)  # preserve dtype pre NEP 50
          return y

      def log_a1(self, n, x):
          """Return a logarithm of ``a_n`` for the tan(x) continued fraction.

          ``log(a_0) = -inf``, ``log(a_1) = log(x)``, and for ``n >= 2`` the
          value ``2*log(x) + pi*1j``, a logarithm of ``a_n = -x**2``. The
          tests call this with complex `x` (see `test_log`).
          """
          xp = array_namespace(x)
          if n == 0:
              y = xp.full_like(x, -xp.asarray(math.inf, dtype=x.dtype))
          elif n == 1:
              y = xp.log(x)
          else:
              y = 2 * xp.log(x) + math.pi * 1j
          return y

      def log_b1(self, n, x):
          """Return ``log(b_n)`` for the tan(x) continued fraction.

          ``log(b_0) = -inf`` and ``log(b_n) = log(2*n - 1)`` for ``n >= 1``,
          with the dtype and shape of `x`.
          """
          xp = array_namespace(x)
          if n == 0:
              y = xp.full_like(x, -xp.asarray(math.inf, dtype=x.dtype))
          else:
              zero = x - x  # zeros with the array type, dtype and shape of `x`
              y = zero + math.log(2 * n - 1)
          return y

      def test_input_validation(self, xp):
          """Invalid arguments raise `ValueError` with informative messages."""
          a1 = self.a1
          b1 = self.b1

          message = '`a` and `b` must be callable.'
          with pytest.raises(ValueError, match=message):
              _continued_fraction(1, b1)
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, 1)

          message = r'`eps` and `tiny` must be \(or represent the logarithm of\)...'
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, tolerances={'eps': -10})
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, tolerances={'eps': np.nan})
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, tolerances={'eps': 1+1j}, log=True)
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, tolerances={'tiny': 0})
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, tolerances={'tiny': np.inf})
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, tolerances={'tiny': np.inf}, log=True)

          # With `log=True` the tolerances are logarithms, so negative values
          # that were rejected above are accepted here; this should not raise.
          kwargs = dict(args=xp.asarray(1.5+0j), log=True, maxiter=0)
          _continued_fraction(a1, b1, tolerances={'eps': -10}, **kwargs)
          _continued_fraction(a1, b1, tolerances={'tiny': -10}, **kwargs)

          message = '`maxiter` must be a non-negative integer.'
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, maxiter=-1)

          message = '`log` must be boolean.'
          with pytest.raises(ValueError, match=message):
              _continued_fraction(a1, b1, log=2)

      @pytest.mark.parametrize('dtype', ['float32', 'float64', 'complex64', 'complex128'])
      @pytest.mark.parametrize('shape', [(), (1,), (3,), (3, 2)])
      def test_basic(self, shape, dtype, xp):
          """The tan(x) continued fraction matches `xp.tan` for each dtype/shape."""
          np_dtype = getattr(np, dtype)
          xp_dtype = getattr(xp, dtype)
          rng = np.random.default_rng(2435908729190400)
          x = rng.random(shape).astype(np_dtype)
          if dtype.startswith('c'):
              # give complex inputs a nonzero imaginary part too
              x = x + rng.random(shape).astype(np_dtype)*1j
          x = xp.asarray(x, dtype=xp_dtype)
          res = _continued_fraction(self.a1, self.b1, args=(x,))
          ref = xp.tan(x)
          xp_assert_close(res.f, ref)

      @pytest.mark.skip_xp_backends('torch', reason='pytorch/pytorch#136063')
      @pytest.mark.parametrize('dtype', ['float32', 'float64'])
      @pytest.mark.parametrize('shape', [(), (1,), (3,), (3, 2)])
      def test_log(self, shape, dtype, xp):
          """With ``log=True``, `log_a1`/`log_b1` reproduce ``log(tan(x))``."""
          if self._np_pre_nep50 and dtype == 'float32':
              pytest.skip("Scalar dtypes only respected after NEP 50.")
          np_dtype = getattr(np, dtype)
          rng = np.random.default_rng(2435908729190400)
          x = rng.random(shape).astype(np_dtype)
          x = xp.asarray(x)
          # log(a_n) has an imaginary part for n >= 2, so pass complex x
          res = _continued_fraction(self.log_a1, self.log_b1, args=(x + 0j,), log=True)
          ref = xp.tan(x)
          # x is in [0, 1), where tan(x) >= 0, so exp(Re(log f)) = |f| = f
          xp_assert_close(xp.exp(xp.real(res.f)), ref)

      def test_maxiter(self, xp):
          """`maxiter` caps `nit`, and more iterations give a closer result."""
          rng = np.random.default_rng(2435908729190400)
          x = xp.asarray(rng.random(), dtype=xp.float64)
          ref = xp.tan(x)
          res1 = _continued_fraction(self.a1, self.b1, args=(x,), maxiter=3)
          assert res1.nit == 3
          res2 = _continued_fraction(self.a1, self.b1, args=(x,), maxiter=6)
          assert res2.nit == 6
          xp_assert_less(xp.abs(res2.f - ref), xp.abs(res1.f - ref))

      def test_eps(self, xp):
          """A looser `eps` stops in fewer iterations with a larger error."""
          x = xp.asarray(1.5, dtype=xp.float64)  # same as the default `x` of a1/b1
          ref = xp.tan(x)
          res1 = _continued_fraction(self.a1, self.b1, args=(x,),
                                     tolerances={'eps': 1e-6})
          res2 = _continued_fraction(self.a1, self.b1, args=(x,))
          xp_assert_less(res1.nit, res2.nit)
          xp_assert_less(xp.abs(res2.f - ref), xp.abs(res1.f - ref))

      def test_feval(self, xp):
          """`nfev` equals the calls made to each of `a` and `b`, i.e. ``nit + 1``."""
          def a(n, x):
              a.nfev += 1
              return n * x

          def b(n, x):
              b.nfev += 1
              return n * x

          a.nfev, b.nfev = 0, 0
          res = _continued_fraction(a, b, args=(xp.asarray(1.),))
          assert res.nfev == a.nfev == b.nfev == res.nit + 1

      def test_status(self, xp):
          """`success` and `status` are reported elementwise."""
          # x = 1 converges; x = 10 is expected not to converge within 15
          # iterations; the NaN input gets its own failure status.
          x = xp.asarray([1, 10, np.nan], dtype=xp.float64)
          res = _continued_fraction(self.a1, self.b1, args=(x,), maxiter=15)
          xp_assert_equal(res.success, xp.asarray([True, False, False]))
          xp_assert_equal(res.status, xp.asarray([0, -2, -3], dtype=xp.int32))

      def test_special_cases(self, xp):
          """With ``maxiter=0``, constant terms give f == 1 after one evaluation."""
          one = xp.asarray(1)
          # without `args`, each callable receives only the term index `n`
          res = _continued_fraction(lambda n: one, lambda n: one, maxiter=0)
          xp_assert_close(res.f, xp.asarray(1.))
          assert res.nit == res.nfev - 1 == 0