_mptestutils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. import os
  2. import sys
  3. import time
  4. from itertools import zip_longest
  5. import numpy as np
  6. from numpy.testing import assert_
  7. import pytest
  8. from scipy.special._testutils import assert_func_equal
  9. try:
  10. import mpmath
  11. except ImportError:
  12. pass
  13. # ------------------------------------------------------------------------------
  14. # Machinery for systematic tests with mpmath
  15. # ------------------------------------------------------------------------------
  16. class Arg:
  17. """Generate a set of numbers on the real axis, concentrating on
  18. 'interesting' regions and covering all orders of magnitude.
  19. """
  20. def __init__(self, a=-np.inf, b=np.inf, inclusive_a=True, inclusive_b=True):
  21. if a > b:
  22. raise ValueError("a should be less than or equal to b")
  23. if a == -np.inf:
  24. a = -0.5*np.finfo(float).max
  25. if b == np.inf:
  26. b = 0.5*np.finfo(float).max
  27. self.a, self.b = a, b
  28. self.inclusive_a, self.inclusive_b = inclusive_a, inclusive_b
  29. def _positive_values(self, a, b, n):
  30. if a < 0:
  31. raise ValueError("a should be positive")
  32. # Try to put half of the points into a linspace between a and
  33. # 10 the other half in a logspace.
  34. if n % 2 == 0:
  35. nlogpts = n//2
  36. nlinpts = nlogpts
  37. else:
  38. nlogpts = n//2
  39. nlinpts = nlogpts + 1
  40. if a >= 10:
  41. # Outside of linspace range; just return a logspace.
  42. pts = np.logspace(np.log10(a), np.log10(b), n)
  43. elif a > 0 and b < 10:
  44. # Outside of logspace range; just return a linspace
  45. pts = np.linspace(a, b, n)
  46. elif a > 0:
  47. # Linspace between a and 10 and a logspace between 10 and
  48. # b.
  49. linpts = np.linspace(a, 10, nlinpts, endpoint=False)
  50. logpts = np.logspace(1, np.log10(b), nlogpts)
  51. pts = np.hstack((linpts, logpts))
  52. elif a == 0 and b <= 10:
  53. # Linspace between 0 and b and a logspace between 0 and
  54. # the smallest positive point of the linspace
  55. linpts = np.linspace(0, b, nlinpts)
  56. if linpts.size > 1:
  57. right = np.log10(linpts[1])
  58. else:
  59. right = -30
  60. logpts = np.logspace(-30, right, nlogpts, endpoint=False)
  61. pts = np.hstack((logpts, linpts))
  62. else:
  63. # Linspace between 0 and 10, logspace between 0 and the
  64. # smallest positive point of the linspace, and a logspace
  65. # between 10 and b.
  66. if nlogpts % 2 == 0:
  67. nlogpts1 = nlogpts//2
  68. nlogpts2 = nlogpts1
  69. else:
  70. nlogpts1 = nlogpts//2
  71. nlogpts2 = nlogpts1 + 1
  72. linpts = np.linspace(0, 10, nlinpts, endpoint=False)
  73. if linpts.size > 1:
  74. right = np.log10(linpts[1])
  75. else:
  76. right = -30
  77. logpts1 = np.logspace(-30, right, nlogpts1, endpoint=False)
  78. logpts2 = np.logspace(1, np.log10(b), nlogpts2)
  79. pts = np.hstack((logpts1, linpts, logpts2))
  80. return np.sort(pts)
  81. def values(self, n):
  82. """Return an array containing n numbers."""
  83. a, b = self.a, self.b
  84. if a == b:
  85. return np.zeros(n)
  86. if not self.inclusive_a:
  87. n += 1
  88. if not self.inclusive_b:
  89. n += 1
  90. if n % 2 == 0:
  91. n1 = n//2
  92. n2 = n1
  93. else:
  94. n1 = n//2
  95. n2 = n1 + 1
  96. if a >= 0:
  97. pospts = self._positive_values(a, b, n)
  98. negpts = []
  99. elif b <= 0:
  100. pospts = []
  101. negpts = -self._positive_values(-b, -a, n)
  102. else:
  103. pospts = self._positive_values(0, b, n1)
  104. negpts = -self._positive_values(0, -a, n2 + 1)
  105. # Don't want to get zero twice
  106. negpts = negpts[1:]
  107. pts = np.hstack((negpts[::-1], pospts))
  108. if not self.inclusive_a:
  109. pts = pts[1:]
  110. if not self.inclusive_b:
  111. pts = pts[:-1]
  112. return pts
  113. class FixedArg:
  114. def __init__(self, values):
  115. self._values = np.asarray(values)
  116. def values(self, n):
  117. return self._values
  118. class ComplexArg:
  119. def __init__(self, a=complex(-np.inf, -np.inf), b=complex(np.inf, np.inf)):
  120. self.real = Arg(a.real, b.real)
  121. self.imag = Arg(a.imag, b.imag)
  122. def values(self, n):
  123. m = int(np.floor(np.sqrt(n)))
  124. x = self.real.values(m)
  125. y = self.imag.values(m + 1)
  126. return (x[:,None] + 1j*y[None,:]).ravel()
  127. class IntArg:
  128. def __init__(self, a=-1000, b=1000):
  129. self.a = a
  130. self.b = b
  131. def values(self, n):
  132. v1 = Arg(self.a, self.b).values(max(1 + n//2, n-5)).astype(int)
  133. v2 = np.arange(-5, 5)
  134. v = np.unique(np.r_[v1, v2])
  135. v = v[(v >= self.a) & (v < self.b)]
  136. return v
  137. def get_args(argspec, n):
  138. if isinstance(argspec, np.ndarray):
  139. args = argspec.copy()
  140. else:
  141. nargs = len(argspec)
  142. ms = np.asarray(
  143. [1.5 if isinstance(spec, ComplexArg) else 1.0 for spec in argspec]
  144. )
  145. ms = (n**(ms/sum(ms))).astype(int) + 1
  146. args = [spec.values(m) for spec, m in zip(argspec, ms)]
  147. args = np.array(np.broadcast_arrays(*np.ix_(*args))).reshape(nargs, -1).T
  148. return args
  149. class MpmathData:
  150. def __init__(self, scipy_func, mpmath_func, arg_spec, name=None,
  151. dps=None, prec=None, n=None, rtol=1e-7, atol=1e-300,
  152. ignore_inf_sign=False, distinguish_nan_and_inf=True,
  153. nan_ok=True, param_filter=None):
  154. # mpmath tests are really slow (see gh-6989). Use a small number of
  155. # points by default, increase back to 5000 (old default) if XSLOW is
  156. # set
  157. if n is None:
  158. try:
  159. is_xslow = int(os.environ.get('SCIPY_XSLOW', '0'))
  160. except ValueError:
  161. is_xslow = False
  162. n = 5000 if is_xslow else 500
  163. self.scipy_func = scipy_func
  164. self.mpmath_func = mpmath_func
  165. self.arg_spec = arg_spec
  166. self.dps = dps
  167. self.prec = prec
  168. self.n = n
  169. self.rtol = rtol
  170. self.atol = atol
  171. self.ignore_inf_sign = ignore_inf_sign
  172. self.nan_ok = nan_ok
  173. if isinstance(self.arg_spec, np.ndarray):
  174. self.is_complex = np.issubdtype(self.arg_spec.dtype, np.complexfloating)
  175. else:
  176. self.is_complex = any(
  177. [isinstance(arg, ComplexArg) for arg in self.arg_spec]
  178. )
  179. self.ignore_inf_sign = ignore_inf_sign
  180. self.distinguish_nan_and_inf = distinguish_nan_and_inf
  181. if not name or name == '<lambda>':
  182. name = getattr(scipy_func, '__name__', None)
  183. if not name or name == '<lambda>':
  184. name = getattr(mpmath_func, '__name__', None)
  185. self.name = name
  186. self.param_filter = param_filter
  187. def check(self):
  188. np.random.seed(1234)
  189. # Generate values for the arguments
  190. argarr = get_args(self.arg_spec, self.n)
  191. # Check
  192. old_dps, old_prec = mpmath.mp.dps, mpmath.mp.prec
  193. try:
  194. if self.dps is not None:
  195. dps_list = [self.dps]
  196. else:
  197. dps_list = [20]
  198. if self.prec is not None:
  199. mpmath.mp.prec = self.prec
  200. # Proper casting of mpmath input and output types. Using
  201. # native mpmath types as inputs gives improved precision
  202. # in some cases.
  203. if np.issubdtype(argarr.dtype, np.complexfloating):
  204. pytype = mpc2complex
  205. def mptype(x):
  206. return mpmath.mpc(complex(x))
  207. else:
  208. def mptype(x):
  209. return mpmath.mpf(float(x))
  210. def pytype(x):
  211. if abs(x.imag) > 1e-16*(1 + abs(x.real)):
  212. return np.nan
  213. else:
  214. return mpf2float(x.real)
  215. # Try out different dps until one (or none) works
  216. for j, dps in enumerate(dps_list):
  217. mpmath.mp.dps = dps
  218. try:
  219. assert_func_equal(
  220. self.scipy_func,
  221. lambda *a: pytype(self.mpmath_func(*map(mptype, a))),
  222. argarr,
  223. vectorized=False,
  224. rtol=self.rtol,
  225. atol=self.atol,
  226. ignore_inf_sign=self.ignore_inf_sign,
  227. distinguish_nan_and_inf=self.distinguish_nan_and_inf,
  228. nan_ok=self.nan_ok,
  229. param_filter=self.param_filter
  230. )
  231. break
  232. except AssertionError:
  233. if j >= len(dps_list)-1:
  234. # reraise the Exception
  235. tp, value, tb = sys.exc_info()
  236. if value.__traceback__ is not tb:
  237. raise value.with_traceback(tb)
  238. raise value
  239. finally:
  240. mpmath.mp.dps, mpmath.mp.prec = old_dps, old_prec
  241. def __repr__(self):
  242. if self.is_complex:
  243. return f"<MpmathData: {self.name} (complex)>"
  244. else:
  245. return f"<MpmathData: {self.name}>"
  246. def assert_mpmath_equal(*a, **kw):
  247. d = MpmathData(*a, **kw)
  248. d.check()
  249. def nonfunctional_tooslow(func):
  250. return pytest.mark.skip(
  251. reason=" Test not yet functional (too slow), needs more work."
  252. )(func)
  253. # ------------------------------------------------------------------------------
  254. # Tools for dealing with mpmath quirks
  255. # ------------------------------------------------------------------------------
  256. def mpf2float(x):
  257. """
  258. Convert an mpf to the nearest floating point number. Just using
  259. float directly doesn't work because of results like this:
  260. with mp.workdps(50):
  261. float(mpf("0.99999999999999999")) = 0.9999999999999999
  262. """
  263. return float(mpmath.nstr(x, 17, min_fixed=0, max_fixed=0))
  264. def mpc2complex(x):
  265. return complex(mpf2float(x.real), mpf2float(x.imag))
  266. def trace_args(func):
  267. def tofloat(x):
  268. if isinstance(x, mpmath.mpc):
  269. return complex(x)
  270. else:
  271. return float(x)
  272. def wrap(*a, **kw):
  273. sys.stderr.write(f"{tuple(map(tofloat, a))!r}: ")
  274. sys.stderr.flush()
  275. try:
  276. r = func(*a, **kw)
  277. sys.stderr.write(f"-> {r!r}")
  278. finally:
  279. sys.stderr.write("\n")
  280. sys.stderr.flush()
  281. return r
  282. return wrap
  283. try:
  284. import signal
  285. POSIX = ('setitimer' in dir(signal))
  286. except ImportError:
  287. POSIX = False
  288. class TimeoutError(Exception):
  289. pass
  290. def time_limited(timeout=0.5, return_val=np.nan, use_sigalrm=True):
  291. """
  292. Decorator for setting a timeout for pure-Python functions.
  293. If the function does not return within `timeout` seconds, the
  294. value `return_val` is returned instead.
  295. On POSIX this uses SIGALRM by default. On non-POSIX, settrace is
  296. used. Do not use this with threads: the SIGALRM implementation
  297. does probably not work well. The settrace implementation only
  298. traces the current thread.
  299. The settrace implementation slows down execution speed. Slowdown
  300. by a factor around 10 is probably typical.
  301. """
  302. if POSIX and use_sigalrm:
  303. def sigalrm_handler(signum, frame):
  304. raise TimeoutError()
  305. def deco(func):
  306. def wrap(*a, **kw):
  307. old_handler = signal.signal(signal.SIGALRM, sigalrm_handler)
  308. signal.setitimer(signal.ITIMER_REAL, timeout)
  309. try:
  310. return func(*a, **kw)
  311. except TimeoutError:
  312. return return_val
  313. finally:
  314. signal.setitimer(signal.ITIMER_REAL, 0)
  315. signal.signal(signal.SIGALRM, old_handler)
  316. return wrap
  317. else:
  318. def deco(func):
  319. def wrap(*a, **kw):
  320. start_time = time.time()
  321. def trace(frame, event, arg):
  322. if time.time() - start_time > timeout:
  323. raise TimeoutError()
  324. return trace
  325. sys.settrace(trace)
  326. try:
  327. return func(*a, **kw)
  328. except TimeoutError:
  329. sys.settrace(None)
  330. return return_val
  331. finally:
  332. sys.settrace(None)
  333. return wrap
  334. return deco
  335. def exception_to_nan(func):
  336. """Decorate function to return nan if it raises an exception"""
  337. def wrap(*a, **kw):
  338. try:
  339. return func(*a, **kw)
  340. except Exception:
  341. return np.nan
  342. return wrap
  343. def inf_to_nan(func):
  344. """Decorate function to return nan if it returns inf"""
  345. def wrap(*a, **kw):
  346. v = func(*a, **kw)
  347. if not np.isfinite(v):
  348. return np.nan
  349. return v
  350. return wrap
  351. def mp_assert_allclose(res, std, atol=0, rtol=1e-17):
  352. """
  353. Compare lists of mpmath.mpf's or mpmath.mpc's directly so that it
  354. can be done to higher precision than double.
  355. """
  356. failures = []
  357. for k, (resval, stdval) in enumerate(zip_longest(res, std)):
  358. if resval is None or stdval is None:
  359. raise ValueError('Lengths of inputs res and std are not equal.')
  360. if mpmath.fabs(resval - stdval) > atol + rtol*mpmath.fabs(stdval):
  361. failures.append((k, resval, stdval))
  362. nfail = len(failures)
  363. if nfail > 0:
  364. ndigits = int(abs(np.log10(rtol)))
  365. msg = [""]
  366. msg.append(f"Bad results ({nfail} out of {k + 1}) for the following points:")
  367. for k, resval, stdval in failures:
  368. resrep = mpmath.nstr(resval, ndigits, min_fixed=0, max_fixed=0)
  369. stdrep = mpmath.nstr(stdval, ndigits, min_fixed=0, max_fixed=0)
  370. if stdval == 0:
  371. rdiff = "inf"
  372. else:
  373. rdiff = mpmath.fabs((resval - stdval)/stdval)
  374. rdiff = mpmath.nstr(rdiff, 3)
  375. msg.append(f"{k}: {resrep} != {stdrep} (rdiff {rdiff})")
  376. assert_(False, "\n".join(msg))