test_autowrap.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467
  1. # Tests that require installed backends go into
  2. # sympy/test_external/test_autowrap
  3. import os
  4. import tempfile
  5. import shutil
  6. from io import StringIO
  7. from pathlib import Path
  8. from sympy.core import symbols, Eq
  9. from sympy.utilities.autowrap import (autowrap, binary_function,
  10. CythonCodeWrapper, UfuncifyCodeWrapper, CodeWrapper)
  11. from sympy.utilities.codegen import (
  12. CCodeGen, C99CodeGen, CodeGenArgumentListError, make_routine
  13. )
  14. from sympy.testing.pytest import raises
  15. from sympy.testing.tmpfiles import TmpFileManager
  16. def get_string(dump_fn, routines, prefix="file", **kwargs):
  17. """Wrapper for dump_fn. dump_fn writes its results to a stream object and
  18. this wrapper returns the contents of that stream as a string. This
  19. auxiliary function is used by many tests below.
  20. The header and the empty lines are not generator to facilitate the
  21. testing of the output.
  22. """
  23. output = StringIO()
  24. dump_fn(routines, output, prefix, **kwargs)
  25. source = output.getvalue()
  26. output.close()
  27. return source
  28. def test_cython_wrapper_scalar_function():
  29. x, y, z = symbols('x,y,z')
  30. expr = (x + y)*z
  31. routine = make_routine("test", expr)
  32. code_gen = CythonCodeWrapper(CCodeGen())
  33. source = get_string(code_gen.dump_pyx, [routine])
  34. expected = (
  35. "cdef extern from 'file.h':\n"
  36. " double test(double x, double y, double z)\n"
  37. "\n"
  38. "def test_c(double x, double y, double z):\n"
  39. "\n"
  40. " return test(x, y, z)")
  41. assert source == expected
  42. def test_cython_wrapper_outarg():
  43. from sympy.core.relational import Equality
  44. x, y, z = symbols('x,y,z')
  45. code_gen = CythonCodeWrapper(C99CodeGen())
  46. routine = make_routine("test", Equality(z, x + y))
  47. source = get_string(code_gen.dump_pyx, [routine])
  48. expected = (
  49. "cdef extern from 'file.h':\n"
  50. " void test(double x, double y, double *z)\n"
  51. "\n"
  52. "def test_c(double x, double y):\n"
  53. "\n"
  54. " cdef double z = 0\n"
  55. " test(x, y, &z)\n"
  56. " return z")
  57. assert source == expected
  58. def test_cython_wrapper_inoutarg():
  59. from sympy.core.relational import Equality
  60. x, y, z = symbols('x,y,z')
  61. code_gen = CythonCodeWrapper(C99CodeGen())
  62. routine = make_routine("test", Equality(z, x + y + z))
  63. source = get_string(code_gen.dump_pyx, [routine])
  64. expected = (
  65. "cdef extern from 'file.h':\n"
  66. " void test(double x, double y, double *z)\n"
  67. "\n"
  68. "def test_c(double x, double y, double z):\n"
  69. "\n"
  70. " test(x, y, &z)\n"
  71. " return z")
  72. assert source == expected
  73. def test_cython_wrapper_compile_flags():
  74. from sympy.core.relational import Equality
  75. x, y, z = symbols('x,y,z')
  76. routine = make_routine("test", Equality(z, x + y))
  77. code_gen = CythonCodeWrapper(CCodeGen())
  78. expected = """\
  79. from setuptools import setup
  80. from setuptools import Extension
  81. from Cython.Build import cythonize
  82. cy_opts = {'compiler_directives': {'language_level': '3'}}
  83. ext_mods = [Extension(
  84. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  85. include_dirs=[],
  86. library_dirs=[],
  87. libraries=[],
  88. extra_compile_args=['-std=c99'],
  89. extra_link_args=[]
  90. )]
  91. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  92. """ % {'num': CodeWrapper._module_counter}
  93. temp_dir = tempfile.mkdtemp()
  94. TmpFileManager.tmp_folder(temp_dir)
  95. setup_file_path = os.path.join(temp_dir, 'setup.py')
  96. code_gen._prepare_files(routine, build_dir=temp_dir)
  97. setup_text = Path(setup_file_path).read_text()
  98. assert setup_text == expected
  99. code_gen = CythonCodeWrapper(CCodeGen(),
  100. include_dirs=['/usr/local/include', '/opt/booger/include'],
  101. library_dirs=['/user/local/lib'],
  102. libraries=['thelib', 'nilib'],
  103. extra_compile_args=['-slow-math'],
  104. extra_link_args=['-lswamp', '-ltrident'],
  105. cythonize_options={'compiler_directives': {'boundscheck': False}}
  106. )
  107. expected = """\
  108. from setuptools import setup
  109. from setuptools import Extension
  110. from Cython.Build import cythonize
  111. cy_opts = {'compiler_directives': {'boundscheck': False}}
  112. ext_mods = [Extension(
  113. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  114. include_dirs=['/usr/local/include', '/opt/booger/include'],
  115. library_dirs=['/user/local/lib'],
  116. libraries=['thelib', 'nilib'],
  117. extra_compile_args=['-slow-math', '-std=c99'],
  118. extra_link_args=['-lswamp', '-ltrident']
  119. )]
  120. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  121. """ % {'num': CodeWrapper._module_counter}
  122. code_gen._prepare_files(routine, build_dir=temp_dir)
  123. setup_text = Path(setup_file_path).read_text()
  124. assert setup_text == expected
  125. expected = """\
  126. from setuptools import setup
  127. from setuptools import Extension
  128. from Cython.Build import cythonize
  129. cy_opts = {'compiler_directives': {'boundscheck': False}}
  130. import numpy as np
  131. ext_mods = [Extension(
  132. 'wrapper_module_%(num)s', ['wrapper_module_%(num)s.pyx', 'wrapped_code_%(num)s.c'],
  133. include_dirs=['/usr/local/include', '/opt/booger/include', np.get_include()],
  134. library_dirs=['/user/local/lib'],
  135. libraries=['thelib', 'nilib'],
  136. extra_compile_args=['-slow-math', '-std=c99'],
  137. extra_link_args=['-lswamp', '-ltrident']
  138. )]
  139. setup(ext_modules=cythonize(ext_mods, **cy_opts))
  140. """ % {'num': CodeWrapper._module_counter}
  141. code_gen._need_numpy = True
  142. code_gen._prepare_files(routine, build_dir=temp_dir)
  143. setup_text = Path(setup_file_path).read_text()
  144. assert setup_text == expected
  145. TmpFileManager.cleanup()
  146. def test_cython_wrapper_unique_dummyvars():
  147. from sympy.core.relational import Equality
  148. from sympy.core.symbol import Dummy
  149. x, y, z = Dummy('x'), Dummy('y'), Dummy('z')
  150. x_id, y_id, z_id = [str(d.dummy_index) for d in [x, y, z]]
  151. expr = Equality(z, x + y)
  152. routine = make_routine("test", expr)
  153. code_gen = CythonCodeWrapper(CCodeGen())
  154. source = get_string(code_gen.dump_pyx, [routine])
  155. expected_template = (
  156. "cdef extern from 'file.h':\n"
  157. " void test(double x_{x_id}, double y_{y_id}, double *z_{z_id})\n"
  158. "\n"
  159. "def test_c(double x_{x_id}, double y_{y_id}):\n"
  160. "\n"
  161. " cdef double z_{z_id} = 0\n"
  162. " test(x_{x_id}, y_{y_id}, &z_{z_id})\n"
  163. " return z_{z_id}")
  164. expected = expected_template.format(x_id=x_id, y_id=y_id, z_id=z_id)
  165. assert source == expected
  166. def test_autowrap_dummy():
  167. x, y, z = symbols('x y z')
  168. # Uses DummyWrapper to test that codegen works as expected
  169. f = autowrap(x + y, backend='dummy')
  170. assert f() == str(x + y)
  171. assert f.args == "x, y"
  172. assert f.returns == "nameless"
  173. f = autowrap(Eq(z, x + y), backend='dummy')
  174. assert f() == str(x + y)
  175. assert f.args == "x, y"
  176. assert f.returns == "z"
  177. f = autowrap(Eq(z, x + y + z), backend='dummy')
  178. assert f() == str(x + y + z)
  179. assert f.args == "x, y, z"
  180. assert f.returns == "z"
  181. def test_autowrap_args():
  182. x, y, z = symbols('x y z')
  183. raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y),
  184. backend='dummy', args=[x]))
  185. f = autowrap(Eq(z, x + y), backend='dummy', args=[y, x])
  186. assert f() == str(x + y)
  187. assert f.args == "y, x"
  188. assert f.returns == "z"
  189. raises(CodeGenArgumentListError, lambda: autowrap(Eq(z, x + y + z),
  190. backend='dummy', args=[x, y]))
  191. f = autowrap(Eq(z, x + y + z), backend='dummy', args=[y, x, z])
  192. assert f() == str(x + y + z)
  193. assert f.args == "y, x, z"
  194. assert f.returns == "z"
  195. f = autowrap(Eq(z, x + y + z), backend='dummy', args=(y, x, z))
  196. assert f() == str(x + y + z)
  197. assert f.args == "y, x, z"
  198. assert f.returns == "z"
  199. def test_autowrap_store_files():
  200. x, y = symbols('x y')
  201. tmp = tempfile.mkdtemp()
  202. TmpFileManager.tmp_folder(tmp)
  203. f = autowrap(x + y, backend='dummy', tempdir=tmp)
  204. assert f() == str(x + y)
  205. assert os.access(tmp, os.F_OK)
  206. TmpFileManager.cleanup()
  207. def test_autowrap_store_files_issue_gh12939():
  208. x, y = symbols('x y')
  209. tmp = './tmp'
  210. saved_cwd = os.getcwd()
  211. temp_cwd = tempfile.mkdtemp()
  212. try:
  213. os.chdir(temp_cwd)
  214. f = autowrap(x + y, backend='dummy', tempdir=tmp)
  215. assert f() == str(x + y)
  216. assert os.access(tmp, os.F_OK)
  217. finally:
  218. os.chdir(saved_cwd)
  219. shutil.rmtree(temp_cwd)
  220. def test_binary_function():
  221. x, y = symbols('x y')
  222. f = binary_function('f', x + y, backend='dummy')
  223. assert f._imp_() == str(x + y)
  224. def test_ufuncify_source():
  225. x, y, z = symbols('x,y,z')
  226. code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
  227. routine = make_routine("test", x + y + z)
  228. source = get_string(code_wrapper.dump_c, [routine])
  229. expected = """\
  230. #include "Python.h"
  231. #include "math.h"
  232. #include "numpy/ndarraytypes.h"
  233. #include "numpy/ufuncobject.h"
  234. #include "numpy/halffloat.h"
  235. #include "file.h"
  236. static PyMethodDef wrapper_module_%(num)sMethods[] = {
  237. {NULL, NULL, 0, NULL}
  238. };
  239. #ifdef NPY_1_19_API_VERSION
  240. static void test_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
  241. #else
  242. static void test_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
  243. #endif
  244. {
  245. npy_intp i;
  246. npy_intp n = dimensions[0];
  247. char *in0 = args[0];
  248. char *in1 = args[1];
  249. char *in2 = args[2];
  250. char *out0 = args[3];
  251. npy_intp in0_step = steps[0];
  252. npy_intp in1_step = steps[1];
  253. npy_intp in2_step = steps[2];
  254. npy_intp out0_step = steps[3];
  255. for (i = 0; i < n; i++) {
  256. *((double *)out0) = test(*(double *)in0, *(double *)in1, *(double *)in2);
  257. in0 += in0_step;
  258. in1 += in1_step;
  259. in2 += in2_step;
  260. out0 += out0_step;
  261. }
  262. }
  263. PyUFuncGenericFunction test_funcs[1] = {&test_ufunc};
  264. static char test_types[4] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
  265. static void *test_data[1] = {NULL};
  266. #if PY_VERSION_HEX >= 0x03000000
  267. static struct PyModuleDef moduledef = {
  268. PyModuleDef_HEAD_INIT,
  269. "wrapper_module_%(num)s",
  270. NULL,
  271. -1,
  272. wrapper_module_%(num)sMethods,
  273. NULL,
  274. NULL,
  275. NULL,
  276. NULL
  277. };
  278. PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
  279. {
  280. PyObject *m, *d;
  281. PyObject *ufunc0;
  282. m = PyModule_Create(&moduledef);
  283. if (!m) {
  284. return NULL;
  285. }
  286. import_array();
  287. import_umath();
  288. d = PyModule_GetDict(m);
  289. ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
  290. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  291. PyDict_SetItemString(d, "test", ufunc0);
  292. Py_DECREF(ufunc0);
  293. return m;
  294. }
  295. #else
  296. PyMODINIT_FUNC initwrapper_module_%(num)s(void)
  297. {
  298. PyObject *m, *d;
  299. PyObject *ufunc0;
  300. m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
  301. if (m == NULL) {
  302. return;
  303. }
  304. import_array();
  305. import_umath();
  306. d = PyModule_GetDict(m);
  307. ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1,
  308. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  309. PyDict_SetItemString(d, "test", ufunc0);
  310. Py_DECREF(ufunc0);
  311. }
  312. #endif""" % {'num': CodeWrapper._module_counter}
  313. assert source == expected
  314. def test_ufuncify_source_multioutput():
  315. x, y, z = symbols('x,y,z')
  316. var_symbols = (x, y, z)
  317. expr = x + y**3 + 10*z**2
  318. code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify"))
  319. routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))]
  320. source = get_string(code_wrapper.dump_c, routines, funcname='multitest')
  321. expected = """\
  322. #include "Python.h"
  323. #include "math.h"
  324. #include "numpy/ndarraytypes.h"
  325. #include "numpy/ufuncobject.h"
  326. #include "numpy/halffloat.h"
  327. #include "file.h"
  328. static PyMethodDef wrapper_module_%(num)sMethods[] = {
  329. {NULL, NULL, 0, NULL}
  330. };
  331. #ifdef NPY_1_19_API_VERSION
  332. static void multitest_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data)
  333. #else
  334. static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data)
  335. #endif
  336. {
  337. npy_intp i;
  338. npy_intp n = dimensions[0];
  339. char *in0 = args[0];
  340. char *in1 = args[1];
  341. char *in2 = args[2];
  342. char *out0 = args[3];
  343. char *out1 = args[4];
  344. char *out2 = args[5];
  345. npy_intp in0_step = steps[0];
  346. npy_intp in1_step = steps[1];
  347. npy_intp in2_step = steps[2];
  348. npy_intp out0_step = steps[3];
  349. npy_intp out1_step = steps[4];
  350. npy_intp out2_step = steps[5];
  351. for (i = 0; i < n; i++) {
  352. *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2);
  353. *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2);
  354. *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2);
  355. in0 += in0_step;
  356. in1 += in1_step;
  357. in2 += in2_step;
  358. out0 += out0_step;
  359. out1 += out1_step;
  360. out2 += out2_step;
  361. }
  362. }
  363. PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc};
  364. static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE};
  365. static void *multitest_data[1] = {NULL};
  366. #if PY_VERSION_HEX >= 0x03000000
  367. static struct PyModuleDef moduledef = {
  368. PyModuleDef_HEAD_INIT,
  369. "wrapper_module_%(num)s",
  370. NULL,
  371. -1,
  372. wrapper_module_%(num)sMethods,
  373. NULL,
  374. NULL,
  375. NULL,
  376. NULL
  377. };
  378. PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void)
  379. {
  380. PyObject *m, *d;
  381. PyObject *ufunc0;
  382. m = PyModule_Create(&moduledef);
  383. if (!m) {
  384. return NULL;
  385. }
  386. import_array();
  387. import_umath();
  388. d = PyModule_GetDict(m);
  389. ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
  390. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  391. PyDict_SetItemString(d, "multitest", ufunc0);
  392. Py_DECREF(ufunc0);
  393. return m;
  394. }
  395. #else
  396. PyMODINIT_FUNC initwrapper_module_%(num)s(void)
  397. {
  398. PyObject *m, *d;
  399. PyObject *ufunc0;
  400. m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods);
  401. if (m == NULL) {
  402. return;
  403. }
  404. import_array();
  405. import_umath();
  406. d = PyModule_GetDict(m);
  407. ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3,
  408. PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0);
  409. PyDict_SetItemString(d, "multitest", ufunc0);
  410. Py_DECREF(ufunc0);
  411. }
  412. #endif""" % {'num': CodeWrapper._module_counter}
  413. assert source == expected