_testutils.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. """
  2. Generic test utilities.
  3. """
  4. import inspect
  5. import os
  6. import re
  7. import shutil
  8. import subprocess
  9. import sys
  10. import sysconfig
  11. import threading
  12. from importlib.util import module_from_spec, spec_from_file_location
  13. import numpy as np
  14. import scipy
  15. try:
  16. # Need type: ignore[import-untyped] for mypy >= 1.6
  17. import cython # type: ignore[import-untyped]
  18. from Cython.Compiler.Version import ( # type: ignore[import-untyped]
  19. version as cython_version,
  20. )
  21. except ImportError:
  22. cython = None
  23. else:
  24. from scipy._lib import _pep440
  25. required_version = '3.0.8'
  26. if _pep440.parse(cython_version) < _pep440.Version(required_version):
  27. # too old or wrong cython, skip Cython API tests
  28. cython = None
  29. __all__ = ['PytestTester', 'check_free_memory', '_TestPythranFunc', 'IS_MUSL']
  30. IS_MUSL = False
  31. # alternate way is
  32. # from packaging.tags import sys_tags
  33. # _tags = list(sys_tags())
  34. # if 'musllinux' in _tags[0].platform:
  35. _v = sysconfig.get_config_var('HOST_GNU_TYPE') or ''
  36. if 'musl' in _v:
  37. IS_MUSL = True
  38. IS_EDITABLE = 'editable' in scipy.__path__[0]
  39. class FPUModeChangeWarning(RuntimeWarning):
  40. """Warning about FPU mode change"""
  41. pass
  42. class PytestTester:
  43. """
  44. Run tests for this namespace
  45. ``scipy.test()`` runs tests for all of SciPy, with the default settings.
  46. When used from a submodule (e.g., ``scipy.cluster.test()``, only the tests
  47. for that namespace are run.
  48. Parameters
  49. ----------
  50. label : {'fast', 'full'}, optional
  51. Whether to run only the fast tests, or also those marked as slow.
  52. Default is 'fast'.
  53. verbose : int, optional
  54. Test output verbosity. Default is 1.
  55. extra_argv : list, optional
  56. Arguments to pass through to Pytest.
  57. doctests : bool, optional
  58. Whether to run doctests or not. Default is False.
  59. coverage : bool, optional
  60. Whether to run tests with code coverage measurements enabled.
  61. Default is False.
  62. tests : list of str, optional
  63. List of module names to run tests for. By default, uses the module
  64. from which the ``test`` function is called.
  65. parallel : int, optional
  66. Run tests in parallel with pytest-xdist, if number given is larger than
  67. 1. Default is 1.
  68. """
  69. def __init__(self, module_name):
  70. self.module_name = module_name
  71. def __call__(self, label="fast", verbose=1, extra_argv=None, doctests=False,
  72. coverage=False, tests=None, parallel=None):
  73. import pytest
  74. module = sys.modules[self.module_name]
  75. module_path = os.path.abspath(module.__path__[0])
  76. pytest_args = ['--showlocals', '--tb=short']
  77. if extra_argv is None:
  78. extra_argv = []
  79. pytest_args += extra_argv
  80. if any(arg == "-m" or arg == "--markers" for arg in extra_argv):
  81. # Likely conflict with default --mode=fast
  82. raise ValueError("Must specify -m before --")
  83. if verbose and int(verbose) > 1:
  84. pytest_args += ["-" + "v"*(int(verbose)-1)]
  85. if coverage:
  86. pytest_args += ["--cov=" + module_path]
  87. if label == "fast":
  88. pytest_args += ["-m", "not slow"]
  89. elif label != "full":
  90. pytest_args += ["-m", label]
  91. if tests is None:
  92. tests = [self.module_name]
  93. if parallel is not None and parallel > 1:
  94. if _pytest_has_xdist():
  95. pytest_args += ['-n', str(parallel)]
  96. else:
  97. import warnings
  98. warnings.warn('Could not run tests in parallel because '
  99. 'pytest-xdist plugin is not available.',
  100. stacklevel=2)
  101. pytest_args += ['--pyargs'] + list(tests)
  102. try:
  103. code = pytest.main(pytest_args)
  104. except SystemExit as exc:
  105. code = exc.code
  106. return (code == 0)
  107. class _TestPythranFunc:
  108. '''
  109. These are situations that can be tested in our pythran tests:
  110. - A function with multiple array arguments and then
  111. other positional and keyword arguments.
  112. - A function with array-like keywords (e.g. `def somefunc(x0, x1=None)`.
  113. Note: list/tuple input is not yet tested!
  114. `self.arguments`: A dictionary which key is the index of the argument,
  115. value is tuple(array value, all supported dtypes)
  116. `self.partialfunc`: A function used to freeze some non-array argument
  117. that of no interests in the original function
  118. '''
  119. ALL_INTEGER = [np.int8, np.int16, np.int32, np.int64, np.intc, np.intp]
  120. ALL_FLOAT = [np.float32, np.float64]
  121. ALL_COMPLEX = [np.complex64, np.complex128]
  122. def setup_method(self):
  123. self.arguments = {}
  124. self.partialfunc = None
  125. self.expected = None
  126. def get_optional_args(self, func):
  127. # get optional arguments with its default value,
  128. # used for testing keywords
  129. signature = inspect.signature(func)
  130. optional_args = {}
  131. for k, v in signature.parameters.items():
  132. if v.default is not inspect.Parameter.empty:
  133. optional_args[k] = v.default
  134. return optional_args
  135. def get_max_dtype_list_length(self):
  136. # get the max supported dtypes list length in all arguments
  137. max_len = 0
  138. for arg_idx in self.arguments:
  139. cur_len = len(self.arguments[arg_idx][1])
  140. if cur_len > max_len:
  141. max_len = cur_len
  142. return max_len
  143. def get_dtype(self, dtype_list, dtype_idx):
  144. # get the dtype from dtype_list via index
  145. # if the index is out of range, then return the last dtype
  146. if dtype_idx > len(dtype_list)-1:
  147. return dtype_list[-1]
  148. else:
  149. return dtype_list[dtype_idx]
  150. def test_all_dtypes(self):
  151. for type_idx in range(self.get_max_dtype_list_length()):
  152. args_array = []
  153. for arg_idx in self.arguments:
  154. new_dtype = self.get_dtype(self.arguments[arg_idx][1],
  155. type_idx)
  156. args_array.append(self.arguments[arg_idx][0].astype(new_dtype))
  157. self.pythranfunc(*args_array)
  158. def test_views(self):
  159. args_array = []
  160. for arg_idx in self.arguments:
  161. args_array.append(self.arguments[arg_idx][0][::-1][::-1])
  162. self.pythranfunc(*args_array)
  163. def test_strided(self):
  164. args_array = []
  165. for arg_idx in self.arguments:
  166. args_array.append(np.repeat(self.arguments[arg_idx][0],
  167. 2, axis=0)[::2])
  168. self.pythranfunc(*args_array)
  169. def _pytest_has_xdist():
  170. """
  171. Check if the pytest-xdist plugin is installed, providing parallel tests
  172. """
  173. # Check xdist exists without importing, otherwise pytests emits warnings
  174. from importlib.util import find_spec
  175. return find_spec('xdist') is not None
  176. def check_free_memory(free_mb):
  177. """
  178. Check *free_mb* of memory is available, otherwise do pytest.skip
  179. """
  180. import pytest
  181. try:
  182. mem_free = _parse_size(os.environ['SCIPY_AVAILABLE_MEM'])
  183. msg = '{} MB memory required, but environment SCIPY_AVAILABLE_MEM={}'.format(
  184. free_mb, os.environ['SCIPY_AVAILABLE_MEM'])
  185. except KeyError:
  186. mem_free = _get_mem_available()
  187. if mem_free is None:
  188. pytest.skip("Could not determine available memory; set SCIPY_AVAILABLE_MEM "
  189. "variable to free memory in MB to run the test.")
  190. msg = f'{free_mb} MB memory required, but {mem_free/1e6} MB available'
  191. if mem_free < free_mb * 1e6:
  192. pytest.skip(msg)
  193. def _parse_size(size_str):
  194. suffixes = {'': 1e6,
  195. 'b': 1.0,
  196. 'k': 1e3, 'M': 1e6, 'G': 1e9, 'T': 1e12,
  197. 'kb': 1e3, 'Mb': 1e6, 'Gb': 1e9, 'Tb': 1e12,
  198. 'kib': 1024.0, 'Mib': 1024.0**2, 'Gib': 1024.0**3, 'Tib': 1024.0**4}
  199. m = re.match(r'^\s*(\d+)\s*({})\s*$'.format('|'.join(suffixes.keys())),
  200. size_str,
  201. re.I)
  202. if not m or m.group(2) not in suffixes:
  203. raise ValueError("Invalid size string")
  204. return float(m.group(1)) * suffixes[m.group(2)]
  205. def _get_mem_available():
  206. """
  207. Get information about memory available, not counting swap.
  208. """
  209. try:
  210. import psutil
  211. return psutil.virtual_memory().available
  212. except (ImportError, AttributeError):
  213. pass
  214. if sys.platform.startswith('linux'):
  215. info = {}
  216. with open('/proc/meminfo') as f:
  217. for line in f:
  218. p = line.split()
  219. info[p[0].strip(':').lower()] = float(p[1]) * 1e3
  220. if 'memavailable' in info:
  221. # Linux >= 3.14
  222. return info['memavailable']
  223. else:
  224. return info['memfree'] + info['cached']
  225. return None
  226. def _test_cython_extension(tmp_path, srcdir):
  227. """
  228. Helper function to test building and importing Cython modules that
  229. make use of the Cython APIs for BLAS, LAPACK, optimize, and special.
  230. """
  231. import pytest
  232. try:
  233. subprocess.check_call(["meson", "--version"])
  234. except FileNotFoundError:
  235. pytest.skip("No usable 'meson' found")
  236. # Make safe for being called by multiple threads within one test
  237. tmp_path = tmp_path / str(threading.get_ident())
  238. # build the examples in a temporary directory
  239. mod_name = os.path.split(srcdir)[1]
  240. shutil.copytree(srcdir, tmp_path / mod_name)
  241. build_dir = tmp_path / mod_name / 'tests' / '_cython_examples'
  242. target_dir = build_dir / 'build'
  243. os.makedirs(target_dir, exist_ok=True)
  244. # Ensure we use the correct Python interpreter even when `meson` is
  245. # installed in a different Python environment (see numpy#24956)
  246. native_file = str(build_dir / 'interpreter-native-file.ini')
  247. with open(native_file, 'w') as f:
  248. f.write("[binaries]\n")
  249. f.write(f"python = '{sys.executable}'")
  250. if sys.platform == "win32":
  251. subprocess.check_call(["meson", "setup",
  252. "--buildtype=release",
  253. "--native-file", native_file,
  254. "--vsenv", str(build_dir)],
  255. cwd=target_dir,
  256. )
  257. else:
  258. subprocess.check_call(["meson", "setup",
  259. "--native-file", native_file, str(build_dir)],
  260. cwd=target_dir
  261. )
  262. subprocess.check_call(["meson", "compile", "-vv"], cwd=target_dir)
  263. # import without adding the directory to sys.path
  264. suffix = sysconfig.get_config_var('EXT_SUFFIX')
  265. def load(modname):
  266. so = (target_dir / modname).with_suffix(suffix)
  267. spec = spec_from_file_location(modname, so)
  268. mod = module_from_spec(spec)
  269. spec.loader.exec_module(mod)
  270. return mod
  271. # test that the module can be imported
  272. return load("extending"), load("extending_cpp")
  273. def _run_concurrent_barrier(n_workers, fn, *args, **kwargs):
  274. """
  275. Run a given function concurrently across a given number of threads.
  276. This is equivalent to using a ThreadPoolExecutor, but using the threading
  277. primitives instead. This function ensures that the closure passed by
  278. parameter gets called concurrently by setting up a barrier before it gets
  279. called before any of the threads.
  280. Arguments
  281. ---------
  282. n_workers: int
  283. Number of concurrent threads to spawn.
  284. fn: callable
  285. Function closure to execute concurrently. Its first argument will
  286. be the thread id.
  287. *args: tuple
  288. Variable number of positional arguments to pass to the function.
  289. **kwargs: dict
  290. Keyword arguments to pass to the function.
  291. """
  292. barrier = threading.Barrier(n_workers)
  293. def closure(i, *args, **kwargs):
  294. barrier.wait()
  295. fn(i, *args, **kwargs)
  296. workers = []
  297. for i in range(0, n_workers):
  298. workers.append(threading.Thread(
  299. target=closure,
  300. args=(i,) + args, kwargs=kwargs))
  301. for worker in workers:
  302. worker.start()
  303. for worker in workers:
  304. worker.join()