# test_extending.py

  1. import os
  2. import platform
  3. import sysconfig
  4. import numpy as np
  5. import pytest
  6. from scipy._lib._testutils import IS_EDITABLE, _test_cython_extension, cython
  7. from scipy.linalg.blas import cdotu # type: ignore[attr-defined]
  8. from scipy.linalg.lapack import dgtsv # type: ignore[attr-defined]
  9. @pytest.mark.parallel_threads_limit(4) # 0.35 GiB per thread RAM usage
  10. @pytest.mark.fail_slow(120)
  11. # essential per https://github.com/scipy/scipy/pull/20487#discussion_r1567057247
  12. @pytest.mark.skipif(IS_EDITABLE,
  13. reason='Editable install cannot find .pxd headers.')
  14. @pytest.mark.skipif((platform.system() == 'Windows' and
  15. sysconfig.get_config_var('Py_GIL_DISABLED')),
  16. reason='gh-22039')
  17. @pytest.mark.skipif(platform.machine() in ["wasm32", "wasm64"],
  18. reason="Can't start subprocess")
  19. @pytest.mark.skipif(cython is None, reason="requires cython")
  20. def test_cython(tmp_path):
  21. srcdir = os.path.dirname(os.path.dirname(__file__))
  22. extensions, extensions_cpp = _test_cython_extension(tmp_path, srcdir)
  23. # actually test the cython c-extensions
  24. a = np.ones(8) * 3
  25. b = np.ones(9)
  26. c = np.ones(8) * 4
  27. x = np.ones(9)
  28. _, _, _, x, _ = dgtsv(a, b, c, x)
  29. a = np.ones(8) * 3
  30. b = np.ones(9)
  31. c = np.ones(8) * 4
  32. x_c = np.ones(9)
  33. extensions.tridiag(a, b, c, x_c)
  34. a = np.ones(8) * 3
  35. b = np.ones(9)
  36. c = np.ones(8) * 4
  37. x_cpp = np.ones(9)
  38. extensions_cpp.tridiag(a, b, c, x_cpp)
  39. np.testing.assert_array_equal(x, x_cpp)
  40. cx = np.array([1-1j, 2+2j, 3-3j], dtype=np.complex64)
  41. cy = np.array([4+4j, 5-5j, 6+6j], dtype=np.complex64)
  42. np.testing.assert_array_equal(cdotu(cx, cy), extensions.complex_dot(cx, cy))
  43. np.testing.assert_array_equal(cdotu(cx, cy), extensions_cpp.complex_dot(cx, cy))