_testutils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import os
  2. import functools
  3. import operator
  4. from scipy._lib import _pep440
  5. import numpy as np
  6. from numpy.testing import assert_
  7. import pytest
  8. import scipy.special as sc
  9. __all__ = ['with_special_errors', 'assert_func_equal', 'FuncData']
  10. #------------------------------------------------------------------------------
  11. # Check if a module is present to be used in tests
  12. #------------------------------------------------------------------------------
  13. class MissingModule:
  14. def __init__(self, name):
  15. self.name = name
  16. def check_version(module, min_ver):
  17. if type(module) is MissingModule:
  18. return pytest.mark.skip(reason=f"{module.name} is not installed")
  19. return pytest.mark.skipif(
  20. _pep440.parse(module.__version__) < _pep440.Version(min_ver),
  21. reason=f"{module.__name__} version >= {min_ver} required"
  22. )
  23. #------------------------------------------------------------------------------
  24. # Enable convergence and loss of precision warnings -- turn off one by one
  25. #------------------------------------------------------------------------------
  26. def with_special_errors(func):
  27. """
  28. Enable special function errors (such as underflow, overflow,
  29. loss of precision, etc.)
  30. """
  31. @functools.wraps(func)
  32. def wrapper(*a, **kw):
  33. with sc.errstate(all='raise'):
  34. res = func(*a, **kw)
  35. return res
  36. return wrapper
  37. #------------------------------------------------------------------------------
  38. # Comparing function values at many data points at once, with helpful
  39. # error reports
  40. #------------------------------------------------------------------------------
  41. def assert_func_equal(func, results, points, rtol=None, atol=None,
  42. param_filter=None, knownfailure=None,
  43. vectorized=True, dtype=None, nan_ok=False,
  44. ignore_inf_sign=False, distinguish_nan_and_inf=True):
  45. if hasattr(points, 'next'):
  46. # it's a generator
  47. points = list(points)
  48. points = np.asarray(points)
  49. if points.ndim == 1:
  50. points = points[:,None]
  51. nparams = points.shape[1]
  52. if hasattr(results, '__name__'):
  53. # function
  54. data = points
  55. result_columns = None
  56. result_func = results
  57. else:
  58. # dataset
  59. data = np.c_[points, results]
  60. result_columns = list(range(nparams, data.shape[1]))
  61. result_func = None
  62. fdata = FuncData(func, data, list(range(nparams)),
  63. result_columns=result_columns, result_func=result_func,
  64. rtol=rtol, atol=atol, param_filter=param_filter,
  65. knownfailure=knownfailure, nan_ok=nan_ok, vectorized=vectorized,
  66. ignore_inf_sign=ignore_inf_sign,
  67. distinguish_nan_and_inf=distinguish_nan_and_inf)
  68. fdata.check()
  69. class FuncData:
  70. """
  71. Data set for checking a special function.
  72. Parameters
  73. ----------
  74. func : function
  75. Function to test
  76. data : numpy array
  77. columnar data to use for testing
  78. param_columns : int or tuple of ints
  79. Columns indices in which the parameters to `func` lie.
  80. Can be imaginary integers to indicate that the parameter
  81. should be cast to complex.
  82. result_columns : int or tuple of ints, optional
  83. Column indices for expected results from `func`.
  84. result_func : callable, optional
  85. Function to call to obtain results.
  86. rtol : float, optional
  87. Required relative tolerance. Default is 5*eps.
  88. atol : float, optional
  89. Required absolute tolerance. Default is 5*tiny.
  90. param_filter : function, or tuple of functions/Nones, optional
  91. Filter functions to exclude some parameter ranges.
  92. If omitted, no filtering is done.
  93. knownfailure : str, optional
  94. Known failure error message to raise when the test is run.
  95. If omitted, no exception is raised.
  96. nan_ok : bool, optional
  97. If nan is always an accepted result.
  98. vectorized : bool, optional
  99. Whether all functions passed in are vectorized.
  100. ignore_inf_sign : bool, optional
  101. Whether to ignore signs of infinities.
  102. (Doesn't matter for complex-valued functions.)
  103. distinguish_nan_and_inf : bool, optional
  104. If True, treat numbers which contain nans or infs as
  105. equal. Sets ignore_inf_sign to be True.
  106. """
  107. def __init__(self, func, data, param_columns, result_columns=None,
  108. result_func=None, rtol=None, atol=None, param_filter=None,
  109. knownfailure=None, dataname=None, nan_ok=False, vectorized=True,
  110. ignore_inf_sign=False, distinguish_nan_and_inf=True):
  111. self.func = func
  112. self.data = data
  113. self.dataname = dataname
  114. if not hasattr(param_columns, '__len__'):
  115. param_columns = (param_columns,)
  116. self.param_columns = tuple(param_columns)
  117. if result_columns is not None:
  118. if not hasattr(result_columns, '__len__'):
  119. result_columns = (result_columns,)
  120. self.result_columns = tuple(result_columns)
  121. if result_func is not None:
  122. message = "Only result_func or result_columns should be provided"
  123. raise ValueError(message)
  124. elif result_func is not None:
  125. self.result_columns = None
  126. else:
  127. raise ValueError("Either result_func or result_columns should be provided")
  128. self.result_func = result_func
  129. self.rtol = rtol
  130. self.atol = atol
  131. if not hasattr(param_filter, '__len__'):
  132. param_filter = (param_filter,)
  133. self.param_filter = param_filter
  134. self.knownfailure = knownfailure
  135. self.nan_ok = nan_ok
  136. self.vectorized = vectorized
  137. self.ignore_inf_sign = ignore_inf_sign
  138. self.distinguish_nan_and_inf = distinguish_nan_and_inf
  139. if not self.distinguish_nan_and_inf:
  140. self.ignore_inf_sign = True
  141. def get_tolerances(self, dtype):
  142. if not np.issubdtype(dtype, np.inexact):
  143. dtype = np.dtype(float)
  144. info = np.finfo(dtype)
  145. rtol, atol = self.rtol, self.atol
  146. if rtol is None:
  147. rtol = 5*info.eps
  148. if atol is None:
  149. atol = 5*info.tiny
  150. return rtol, atol
  151. def check(self, data=None, dtype=None, dtypes=None):
  152. """Check the special function against the data."""
  153. __tracebackhide__ = operator.methodcaller(
  154. 'errisinstance', AssertionError
  155. )
  156. if self.knownfailure:
  157. pytest.xfail(reason=self.knownfailure)
  158. if data is None:
  159. data = self.data
  160. if dtype is None:
  161. dtype = data.dtype
  162. else:
  163. data = data.astype(dtype)
  164. rtol, atol = self.get_tolerances(dtype)
  165. # Apply given filter functions
  166. if self.param_filter:
  167. param_mask = np.ones((data.shape[0],), np.bool_)
  168. for j, filter in zip(self.param_columns, self.param_filter):
  169. if filter:
  170. param_mask &= list(filter(data[:,j]))
  171. data = data[param_mask]
  172. # Pick parameters from the correct columns
  173. params = []
  174. for idx, j in enumerate(self.param_columns):
  175. if np.iscomplexobj(j):
  176. j = int(j.imag)
  177. params.append(data[:,j].astype(complex))
  178. elif dtypes and idx < len(dtypes):
  179. params.append(data[:, j].astype(dtypes[idx]))
  180. else:
  181. params.append(data[:,j])
  182. # Helper for evaluating results
  183. def eval_func_at_params(func, skip_mask=None):
  184. if self.vectorized:
  185. got = func(*params)
  186. else:
  187. got = []
  188. for j in range(len(params[0])):
  189. if skip_mask is not None and skip_mask[j]:
  190. got.append(np.nan)
  191. continue
  192. got.append(func(*tuple([params[i][j] for i in range(len(params))])))
  193. got = np.asarray(got)
  194. if not isinstance(got, tuple):
  195. got = (got,)
  196. return got
  197. # Evaluate function to be tested
  198. got = eval_func_at_params(self.func)
  199. # Grab the correct results
  200. if self.result_columns is not None:
  201. # Correct results passed in with the data
  202. wanted = tuple([data[:,icol] for icol in self.result_columns])
  203. else:
  204. # Function producing correct results passed in
  205. skip_mask = None
  206. if self.nan_ok and len(got) == 1:
  207. # Don't spend time evaluating what doesn't need to be evaluated
  208. skip_mask = np.isnan(got[0])
  209. wanted = eval_func_at_params(self.result_func, skip_mask=skip_mask)
  210. # Check the validity of each output returned
  211. assert_(len(got) == len(wanted))
  212. for output_num, (x, y) in enumerate(zip(got, wanted)):
  213. if np.issubdtype(x.dtype, np.complexfloating) or self.ignore_inf_sign:
  214. pinf_x = np.isinf(x)
  215. pinf_y = np.isinf(y)
  216. minf_x = np.isinf(x)
  217. minf_y = np.isinf(y)
  218. else:
  219. pinf_x = np.isposinf(x)
  220. pinf_y = np.isposinf(y)
  221. minf_x = np.isneginf(x)
  222. minf_y = np.isneginf(y)
  223. nan_x = np.isnan(x)
  224. nan_y = np.isnan(y)
  225. with np.errstate(all='ignore'):
  226. abs_y = np.absolute(y)
  227. abs_y[~np.isfinite(abs_y)] = 0
  228. diff = np.absolute(x - y)
  229. diff[~np.isfinite(diff)] = 0
  230. rdiff = diff / np.absolute(y)
  231. rdiff[~np.isfinite(rdiff)] = 0
  232. tol_mask = (diff <= atol + rtol*abs_y)
  233. pinf_mask = (pinf_x == pinf_y)
  234. minf_mask = (minf_x == minf_y)
  235. nan_mask = (nan_x == nan_y)
  236. bad_j = ~(tol_mask & pinf_mask & minf_mask & nan_mask)
  237. point_count = bad_j.size
  238. if self.nan_ok:
  239. bad_j &= ~nan_x
  240. bad_j &= ~nan_y
  241. point_count -= (nan_x | nan_y).sum()
  242. if not self.distinguish_nan_and_inf and not self.nan_ok:
  243. # If nan's are okay we've already covered all these cases
  244. inf_x = np.isinf(x)
  245. inf_y = np.isinf(y)
  246. both_nonfinite = (inf_x & nan_y) | (nan_x & inf_y)
  247. bad_j &= ~both_nonfinite
  248. point_count -= both_nonfinite.sum()
  249. if np.any(bad_j):
  250. # Some bad results: inform what, where, and how bad
  251. msg = [""]
  252. msg.append(f"Max |adiff|: {diff[bad_j].max():g}")
  253. msg.append(f"Max |rdiff|: {rdiff[bad_j].max():g}")
  254. msg.append(f"Bad results ({np.sum(bad_j)} out of "
  255. f"{point_count}) for the following points "
  256. f"(in output {output_num}):")
  257. for j in np.nonzero(bad_j)[0]:
  258. j = int(j)
  259. def fmt(x):
  260. return f'{np.array2string(x[j], precision=18):30s}'
  261. a = " ".join(map(fmt, params))
  262. b = " ".join(map(fmt, got))
  263. c = " ".join(map(fmt, wanted))
  264. d = fmt(rdiff)
  265. msg.append(f"{a} => {b} != {c} (rdiff {d})")
  266. assert_(False, "\n".join(msg))
  267. def __repr__(self):
  268. """Pretty-printing"""
  269. if np.any(list(map(np.iscomplexobj, self.param_columns))):
  270. is_complex = " (complex)"
  271. else:
  272. is_complex = ""
  273. if self.dataname:
  274. return (f"<Data for {self.func.__name__}{is_complex}: "
  275. f"{os.path.basename(self.dataname)}>")
  276. else:
  277. return f"<Data for {self.func.__name__}{is_complex}>"