test_codegen.py 55 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632
  1. from io import StringIO
  2. from sympy.core import symbols, Eq, pi, Catalan, Lambda, Dummy
  3. from sympy.core.relational import Equality
  4. from sympy.core.symbol import Symbol
  5. from sympy.functions.special.error_functions import erf
  6. from sympy.integrals.integrals import Integral
  7. from sympy.matrices import Matrix, MatrixSymbol
  8. from sympy.utilities.codegen import (
  9. codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument,
  10. CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument,
  11. InOutArgument)
  12. from sympy.testing.pytest import raises
  13. from sympy.utilities.lambdify import implemented_function
# FIXME: the import below (codegen from the top-level sympy package) fails
# because of a circular import involving sympy.core:
# from sympy import codegen
  16. def get_string(dump_fn, routines, prefix="file", header=False, empty=False):
  17. """Wrapper for dump_fn. dump_fn writes its results to a stream object and
  18. this wrapper returns the contents of that stream as a string. This
  19. auxiliary function is used by many tests below.
  20. The header and the empty lines are not generated to facilitate the
  21. testing of the output.
  22. """
  23. output = StringIO()
  24. dump_fn(routines, output, prefix, header, empty)
  25. source = output.getvalue()
  26. output.close()
  27. return source
  28. def test_Routine_argument_order():
  29. a, x, y, z = symbols('a x y z')
  30. expr = (x + y)*z
  31. raises(CodeGenArgumentListError, lambda: make_routine("test", expr,
  32. argument_sequence=[z, x]))
  33. raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a,
  34. expr), argument_sequence=[z, x, y]))
  35. r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
  36. assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
  37. assert [ type(arg) for arg in r.arguments ] == [
  38. InputArgument, InputArgument, OutputArgument, InputArgument ]
  39. r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y])
  40. assert [ type(arg) for arg in r.arguments ] == [
  41. InOutArgument, InputArgument, InputArgument ]
  42. from sympy.tensor import IndexedBase, Idx
  43. A, B = map(IndexedBase, ['A', 'B'])
  44. m = symbols('m', integer=True)
  45. i = Idx('i', m)
  46. r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m])
  47. assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m]
  48. expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3))
  49. r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y])
  50. assert [ arg.name for arg in r.arguments ] == [z, x, a, y]
  51. def test_empty_c_code():
  52. code_gen = C89CodeGen()
  53. source = get_string(code_gen.dump_c, [])
  54. assert source == "#include \"file.h\"\n#include <math.h>\n"
  55. def test_empty_c_code_with_comment():
  56. code_gen = C89CodeGen()
  57. source = get_string(code_gen.dump_c, [], header=True)
  58. assert source[:82] == (
  59. "/******************************************************************************\n *"
  60. )
  61. # " Code generated with SymPy 0.7.2-git "
  62. assert source[158:] == ( "*\n"
  63. " * *\n"
  64. " * See http://www.sympy.org/ for more information. *\n"
  65. " * *\n"
  66. " * This file is part of 'project' *\n"
  67. " ******************************************************************************/\n"
  68. "#include \"file.h\"\n"
  69. "#include <math.h>\n"
  70. )
  71. def test_empty_c_header():
  72. code_gen = C99CodeGen()
  73. source = get_string(code_gen.dump_h, [])
  74. assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n"
  75. def test_simple_c_code():
  76. x, y, z = symbols('x,y,z')
  77. expr = (x + y)*z
  78. routine = make_routine("test", expr)
  79. code_gen = C89CodeGen()
  80. source = get_string(code_gen.dump_c, [routine])
  81. expected = (
  82. "#include \"file.h\"\n"
  83. "#include <math.h>\n"
  84. "double test(double x, double y, double z) {\n"
  85. " double test_result;\n"
  86. " test_result = z*(x + y);\n"
  87. " return test_result;\n"
  88. "}\n"
  89. )
  90. assert source == expected
  91. def test_c_code_reserved_words():
  92. x, y, z = symbols('if, typedef, while')
  93. expr = (x + y) * z
  94. routine = make_routine("test", expr)
  95. code_gen = C99CodeGen()
  96. source = get_string(code_gen.dump_c, [routine])
  97. expected = (
  98. "#include \"file.h\"\n"
  99. "#include <math.h>\n"
  100. "double test(double if_, double typedef_, double while_) {\n"
  101. " double test_result;\n"
  102. " test_result = while_*(if_ + typedef_);\n"
  103. " return test_result;\n"
  104. "}\n"
  105. )
  106. assert source == expected
  107. def test_numbersymbol_c_code():
  108. routine = make_routine("test", pi**Catalan)
  109. code_gen = C89CodeGen()
  110. source = get_string(code_gen.dump_c, [routine])
  111. expected = (
  112. "#include \"file.h\"\n"
  113. "#include <math.h>\n"
  114. "double test() {\n"
  115. " double test_result;\n"
  116. " double const Catalan = %s;\n"
  117. " test_result = pow(M_PI, Catalan);\n"
  118. " return test_result;\n"
  119. "}\n"
  120. ) % Catalan.evalf(17)
  121. assert source == expected
  122. def test_c_code_argument_order():
  123. x, y, z = symbols('x,y,z')
  124. expr = x + y
  125. routine = make_routine("test", expr, argument_sequence=[z, x, y])
  126. code_gen = C89CodeGen()
  127. source = get_string(code_gen.dump_c, [routine])
  128. expected = (
  129. "#include \"file.h\"\n"
  130. "#include <math.h>\n"
  131. "double test(double z, double x, double y) {\n"
  132. " double test_result;\n"
  133. " test_result = x + y;\n"
  134. " return test_result;\n"
  135. "}\n"
  136. )
  137. assert source == expected
  138. def test_simple_c_header():
  139. x, y, z = symbols('x,y,z')
  140. expr = (x + y)*z
  141. routine = make_routine("test", expr)
  142. code_gen = C89CodeGen()
  143. source = get_string(code_gen.dump_h, [routine])
  144. expected = (
  145. "#ifndef PROJECT__FILE__H\n"
  146. "#define PROJECT__FILE__H\n"
  147. "double test(double x, double y, double z);\n"
  148. "#endif\n"
  149. )
  150. assert source == expected
  151. def test_simple_c_codegen():
  152. x, y, z = symbols('x,y,z')
  153. expr = (x + y)*z
  154. expected = [
  155. ("file.c",
  156. "#include \"file.h\"\n"
  157. "#include <math.h>\n"
  158. "double test(double x, double y, double z) {\n"
  159. " double test_result;\n"
  160. " test_result = z*(x + y);\n"
  161. " return test_result;\n"
  162. "}\n"),
  163. ("file.h",
  164. "#ifndef PROJECT__FILE__H\n"
  165. "#define PROJECT__FILE__H\n"
  166. "double test(double x, double y, double z);\n"
  167. "#endif\n")
  168. ]
  169. result = codegen(("test", expr), "C", "file", header=False, empty=False)
  170. assert result == expected
  171. def test_multiple_results_c():
  172. x, y, z = symbols('x,y,z')
  173. expr1 = (x + y)*z
  174. expr2 = (x - y)*z
  175. routine = make_routine(
  176. "test",
  177. [expr1, expr2]
  178. )
  179. code_gen = C99CodeGen()
  180. raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
  181. def test_no_results_c():
  182. raises(ValueError, lambda: make_routine("test", []))
  183. def test_ansi_math1_codegen():
  184. # not included: log10
  185. from sympy.functions.elementary.complexes import Abs
  186. from sympy.functions.elementary.exponential import log
  187. from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
  188. from sympy.functions.elementary.integers import (ceiling, floor)
  189. from sympy.functions.elementary.miscellaneous import sqrt
  190. from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
  191. x = symbols('x')
  192. name_expr = [
  193. ("test_fabs", Abs(x)),
  194. ("test_acos", acos(x)),
  195. ("test_asin", asin(x)),
  196. ("test_atan", atan(x)),
  197. ("test_ceil", ceiling(x)),
  198. ("test_cos", cos(x)),
  199. ("test_cosh", cosh(x)),
  200. ("test_floor", floor(x)),
  201. ("test_log", log(x)),
  202. ("test_ln", log(x)),
  203. ("test_sin", sin(x)),
  204. ("test_sinh", sinh(x)),
  205. ("test_sqrt", sqrt(x)),
  206. ("test_tan", tan(x)),
  207. ("test_tanh", tanh(x)),
  208. ]
  209. result = codegen(name_expr, "C89", "file", header=False, empty=False)
  210. assert result[0][0] == "file.c"
  211. assert result[0][1] == (
  212. '#include "file.h"\n#include <math.h>\n'
  213. 'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n'
  214. 'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n'
  215. 'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n'
  216. 'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n'
  217. 'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n'
  218. 'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n'
  219. 'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n'
  220. 'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n'
  221. 'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n'
  222. 'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n'
  223. 'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n'
  224. 'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n'
  225. 'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n'
  226. 'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n'
  227. 'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n'
  228. )
  229. assert result[1][0] == "file.h"
  230. assert result[1][1] == (
  231. '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
  232. 'double test_fabs(double x);\ndouble test_acos(double x);\n'
  233. 'double test_asin(double x);\ndouble test_atan(double x);\n'
  234. 'double test_ceil(double x);\ndouble test_cos(double x);\n'
  235. 'double test_cosh(double x);\ndouble test_floor(double x);\n'
  236. 'double test_log(double x);\ndouble test_ln(double x);\n'
  237. 'double test_sin(double x);\ndouble test_sinh(double x);\n'
  238. 'double test_sqrt(double x);\ndouble test_tan(double x);\n'
  239. 'double test_tanh(double x);\n#endif\n'
  240. )
  241. def test_ansi_math2_codegen():
  242. # not included: frexp, ldexp, modf, fmod
  243. from sympy.functions.elementary.trigonometric import atan2
  244. x, y = symbols('x,y')
  245. name_expr = [
  246. ("test_atan2", atan2(x, y)),
  247. ("test_pow", x**y),
  248. ]
  249. result = codegen(name_expr, "C89", "file", header=False, empty=False)
  250. assert result[0][0] == "file.c"
  251. assert result[0][1] == (
  252. '#include "file.h"\n#include <math.h>\n'
  253. 'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n'
  254. 'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n'
  255. )
  256. assert result[1][0] == "file.h"
  257. assert result[1][1] == (
  258. '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n'
  259. 'double test_atan2(double x, double y);\n'
  260. 'double test_pow(double x, double y);\n'
  261. '#endif\n'
  262. )
  263. def test_complicated_codegen():
  264. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  265. x, y, z = symbols('x,y,z')
  266. name_expr = [
  267. ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
  268. ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
  269. ]
  270. result = codegen(name_expr, "C89", "file", header=False, empty=False)
  271. assert result[0][0] == "file.c"
  272. assert result[0][1] == (
  273. '#include "file.h"\n#include <math.h>\n'
  274. 'double test1(double x, double y, double z) {\n'
  275. ' double test1_result;\n'
  276. ' test1_result = '
  277. 'pow(sin(x), 7) + '
  278. '7*pow(sin(x), 6)*cos(y) + '
  279. '7*pow(sin(x), 6)*tan(z) + '
  280. '21*pow(sin(x), 5)*pow(cos(y), 2) + '
  281. '42*pow(sin(x), 5)*cos(y)*tan(z) + '
  282. '21*pow(sin(x), 5)*pow(tan(z), 2) + '
  283. '35*pow(sin(x), 4)*pow(cos(y), 3) + '
  284. '105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + '
  285. '105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + '
  286. '35*pow(sin(x), 4)*pow(tan(z), 3) + '
  287. '35*pow(sin(x), 3)*pow(cos(y), 4) + '
  288. '140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + '
  289. '210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + '
  290. '140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + '
  291. '35*pow(sin(x), 3)*pow(tan(z), 4) + '
  292. '21*pow(sin(x), 2)*pow(cos(y), 5) + '
  293. '105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + '
  294. '210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + '
  295. '210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + '
  296. '105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + '
  297. '21*pow(sin(x), 2)*pow(tan(z), 5) + '
  298. '7*sin(x)*pow(cos(y), 6) + '
  299. '42*sin(x)*pow(cos(y), 5)*tan(z) + '
  300. '105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + '
  301. '140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + '
  302. '105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + '
  303. '42*sin(x)*cos(y)*pow(tan(z), 5) + '
  304. '7*sin(x)*pow(tan(z), 6) + '
  305. 'pow(cos(y), 7) + '
  306. '7*pow(cos(y), 6)*tan(z) + '
  307. '21*pow(cos(y), 5)*pow(tan(z), 2) + '
  308. '35*pow(cos(y), 4)*pow(tan(z), 3) + '
  309. '35*pow(cos(y), 3)*pow(tan(z), 4) + '
  310. '21*pow(cos(y), 2)*pow(tan(z), 5) + '
  311. '7*cos(y)*pow(tan(z), 6) + '
  312. 'pow(tan(z), 7);\n'
  313. ' return test1_result;\n'
  314. '}\n'
  315. 'double test2(double x, double y, double z) {\n'
  316. ' double test2_result;\n'
  317. ' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n'
  318. ' return test2_result;\n'
  319. '}\n'
  320. )
  321. assert result[1][0] == "file.h"
  322. assert result[1][1] == (
  323. '#ifndef PROJECT__FILE__H\n'
  324. '#define PROJECT__FILE__H\n'
  325. 'double test1(double x, double y, double z);\n'
  326. 'double test2(double x, double y, double z);\n'
  327. '#endif\n'
  328. )
  329. def test_loops_c():
  330. from sympy.tensor import IndexedBase, Idx
  331. from sympy.core.symbol import symbols
  332. n, m = symbols('n m', integer=True)
  333. A = IndexedBase('A')
  334. x = IndexedBase('x')
  335. y = IndexedBase('y')
  336. i = Idx('i', m)
  337. j = Idx('j', n)
  338. (f1, code), (f2, interface) = codegen(
  339. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
  340. assert f1 == 'file.c'
  341. expected = (
  342. '#include "file.h"\n'
  343. '#include <math.h>\n'
  344. 'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n'
  345. ' for (int i=0; i<m; i++){\n'
  346. ' y[i] = 0;\n'
  347. ' }\n'
  348. ' for (int i=0; i<m; i++){\n'
  349. ' for (int j=0; j<n; j++){\n'
  350. ' y[i] = %(rhs)s + y[i];\n'
  351. ' }\n'
  352. ' }\n'
  353. '}\n'
  354. )
  355. assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*n + j)} or
  356. code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*n)} or
  357. code == expected % {'rhs': 'x[j]*A[%s]' % (i*n + j)} or
  358. code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*n)})
  359. assert f2 == 'file.h'
  360. assert interface == (
  361. '#ifndef PROJECT__FILE__H\n'
  362. '#define PROJECT__FILE__H\n'
  363. 'void matrix_vector(double *A, int m, int n, double *x, double *y);\n'
  364. '#endif\n'
  365. )
  366. def test_dummy_loops_c():
  367. from sympy.tensor import IndexedBase, Idx
  368. i, m = symbols('i m', integer=True, cls=Dummy)
  369. x = IndexedBase('x')
  370. y = IndexedBase('y')
  371. i = Idx(i, m)
  372. expected = (
  373. '#include "file.h"\n'
  374. '#include <math.h>\n'
  375. 'void test_dummies(int m_%(mno)i, double *x, double *y) {\n'
  376. ' for (int i_%(ino)i=0; i_%(ino)i<m_%(mno)i; i_%(ino)i++){\n'
  377. ' y[i_%(ino)i] = x[i_%(ino)i];\n'
  378. ' }\n'
  379. '}\n'
  380. ) % {'ino': i.label.dummy_index, 'mno': m.dummy_index}
  381. r = make_routine('test_dummies', Eq(y[i], x[i]))
  382. c89 = C89CodeGen()
  383. c99 = C99CodeGen()
  384. code = get_string(c99.dump_c, [r])
  385. assert code == expected
  386. with raises(NotImplementedError):
  387. get_string(c89.dump_c, [r])
  388. def test_partial_loops_c():
  389. # check that loop boundaries are determined by Idx, and array strides
  390. # determined by shape of IndexedBase object.
  391. from sympy.tensor import IndexedBase, Idx
  392. from sympy.core.symbol import symbols
  393. n, m, o, p = symbols('n m o p', integer=True)
  394. A = IndexedBase('A', shape=(m, p))
  395. x = IndexedBase('x')
  396. y = IndexedBase('y')
  397. i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
  398. j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
  399. (f1, code), (f2, interface) = codegen(
  400. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False)
  401. assert f1 == 'file.c'
  402. expected = (
  403. '#include "file.h"\n'
  404. '#include <math.h>\n'
  405. 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n'
  406. ' for (int i=o; i<%(upperi)s; i++){\n'
  407. ' y[i] = 0;\n'
  408. ' }\n'
  409. ' for (int i=o; i<%(upperi)s; i++){\n'
  410. ' for (int j=0; j<n; j++){\n'
  411. ' y[i] = %(rhs)s + y[i];\n'
  412. ' }\n'
  413. ' }\n'
  414. '}\n'
  415. ) % {'upperi': m - 4, 'rhs': '%(rhs)s'}
  416. assert (code == expected % {'rhs': 'A[%s]*x[j]' % (i*p + j)} or
  417. code == expected % {'rhs': 'A[%s]*x[j]' % (j + i*p)} or
  418. code == expected % {'rhs': 'x[j]*A[%s]' % (i*p + j)} or
  419. code == expected % {'rhs': 'x[j]*A[%s]' % (j + i*p)})
  420. assert f2 == 'file.h'
  421. assert interface == (
  422. '#ifndef PROJECT__FILE__H\n'
  423. '#define PROJECT__FILE__H\n'
  424. 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y);\n'
  425. '#endif\n'
  426. )
  427. def test_output_arg_c():
  428. from sympy.core.relational import Equality
  429. from sympy.functions.elementary.trigonometric import (cos, sin)
  430. x, y, z = symbols("x,y,z")
  431. r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
  432. c = C89CodeGen()
  433. result = c.write([r], "test", header=False, empty=False)
  434. assert result[0][0] == "test.c"
  435. expected = (
  436. '#include "test.h"\n'
  437. '#include <math.h>\n'
  438. 'double foo(double x, double *y) {\n'
  439. ' (*y) = sin(x);\n'
  440. ' double foo_result;\n'
  441. ' foo_result = cos(x);\n'
  442. ' return foo_result;\n'
  443. '}\n'
  444. )
  445. assert result[0][1] == expected
  446. def test_output_arg_c_reserved_words():
  447. from sympy.core.relational import Equality
  448. from sympy.functions.elementary.trigonometric import (cos, sin)
  449. x, y, z = symbols("if, while, z")
  450. r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
  451. c = C89CodeGen()
  452. result = c.write([r], "test", header=False, empty=False)
  453. assert result[0][0] == "test.c"
  454. expected = (
  455. '#include "test.h"\n'
  456. '#include <math.h>\n'
  457. 'double foo(double if_, double *while_) {\n'
  458. ' (*while_) = sin(if_);\n'
  459. ' double foo_result;\n'
  460. ' foo_result = cos(if_);\n'
  461. ' return foo_result;\n'
  462. '}\n'
  463. )
  464. assert result[0][1] == expected
  465. def test_multidim_c_argument_cse():
  466. A_sym = MatrixSymbol('A', 3, 3)
  467. b_sym = MatrixSymbol('b', 3, 1)
  468. A = Matrix(A_sym)
  469. b = Matrix(b_sym)
  470. c = A*b
  471. cgen = CCodeGen(project="test", cse=True)
  472. r = cgen.routine("c", c)
  473. r.arguments[-1].result_var = "out"
  474. r.arguments[-1]._name = "out"
  475. code = get_string(cgen.dump_c, [r], prefix="test")
  476. expected = (
  477. '#include "test.h"\n'
  478. "#include <math.h>\n"
  479. "void c(double *A, double *b, double *out) {\n"
  480. " out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n"
  481. " out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n"
  482. " out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n"
  483. "}\n"
  484. )
  485. assert code == expected
  486. def test_ccode_results_named_ordered():
  487. x, y, z = symbols('x,y,z')
  488. B, C = symbols('B,C')
  489. A = MatrixSymbol('A', 1, 3)
  490. expr1 = Equality(A, Matrix([[1, 2, x]]))
  491. expr2 = Equality(C, (x + y)*z)
  492. expr3 = Equality(B, 2*x)
  493. name_expr = ("test", [expr1, expr2, expr3])
  494. expected = (
  495. '#include "test.h"\n'
  496. '#include <math.h>\n'
  497. 'void test(double x, double *C, double z, double y, double *A, double *B) {\n'
  498. ' (*C) = z*(x + y);\n'
  499. ' A[0] = 1;\n'
  500. ' A[1] = 2;\n'
  501. ' A[2] = x;\n'
  502. ' (*B) = 2*x;\n'
  503. '}\n'
  504. )
  505. result = codegen(name_expr, "c", "test", header=False, empty=False,
  506. argument_sequence=(x, C, z, y, A, B))
  507. source = result[0][1]
  508. assert source == expected
  509. def test_ccode_matrixsymbol_slice():
  510. A = MatrixSymbol('A', 5, 3)
  511. B = MatrixSymbol('B', 1, 3)
  512. C = MatrixSymbol('C', 1, 3)
  513. D = MatrixSymbol('D', 5, 1)
  514. name_expr = ("test", [Equality(B, A[0, :]),
  515. Equality(C, A[1, :]),
  516. Equality(D, A[:, 2])])
  517. result = codegen(name_expr, "c99", "test", header=False, empty=False)
  518. source = result[0][1]
  519. expected = (
  520. '#include "test.h"\n'
  521. '#include <math.h>\n'
  522. 'void test(double *A, double *B, double *C, double *D) {\n'
  523. ' B[0] = A[0];\n'
  524. ' B[1] = A[1];\n'
  525. ' B[2] = A[2];\n'
  526. ' C[0] = A[3];\n'
  527. ' C[1] = A[4];\n'
  528. ' C[2] = A[5];\n'
  529. ' D[0] = A[2];\n'
  530. ' D[1] = A[5];\n'
  531. ' D[2] = A[8];\n'
  532. ' D[3] = A[11];\n'
  533. ' D[4] = A[14];\n'
  534. '}\n'
  535. )
  536. assert source == expected
  537. def test_ccode_cse():
  538. a, b, c, d = symbols('a b c d')
  539. e = MatrixSymbol('e', 3, 1)
  540. name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))])
  541. generator = CCodeGen(cse=True)
  542. result = codegen(name_expr, code_gen=generator, header=False, empty=False)
  543. source = result[0][1]
  544. expected = (
  545. '#include "test.h"\n'
  546. '#include <math.h>\n'
  547. 'void test(double a, double b, double c, double d, double *e) {\n'
  548. ' const double x0 = a*b;\n'
  549. ' const double x1 = c*d;\n'
  550. ' e[0] = x0;\n'
  551. ' e[1] = x0 + x1;\n'
  552. ' e[2] = x0*x1;\n'
  553. '}\n'
  554. )
  555. assert source == expected
  556. def test_ccode_unused_array_arg():
  557. x = MatrixSymbol('x', 2, 1)
  558. # x does not appear in output
  559. name_expr = ("test", 1.0)
  560. generator = CCodeGen()
  561. result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,))
  562. source = result[0][1]
  563. # note: x should appear as (double *)
  564. expected = (
  565. '#include "test.h"\n'
  566. '#include <math.h>\n'
  567. 'double test(double *x) {\n'
  568. ' double test_result;\n'
  569. ' test_result = 1.0;\n'
  570. ' return test_result;\n'
  571. '}\n'
  572. )
  573. assert source == expected
  574. def test_ccode_unused_array_arg_func():
  575. # issue 16689
  576. X = MatrixSymbol('X',3,1)
  577. Y = MatrixSymbol('Y',3,1)
  578. z = symbols('z',integer = True)
  579. name_expr = ('testBug', X[0] + X[1])
  580. result = codegen(name_expr, language='C', header=False, empty=False, argument_sequence=(X, Y, z))
  581. source = result[0][1]
  582. expected = (
  583. '#include "testBug.h"\n'
  584. '#include <math.h>\n'
  585. 'double testBug(double *X, double *Y, int z) {\n'
  586. ' double testBug_result;\n'
  587. ' testBug_result = X[0] + X[1];\n'
  588. ' return testBug_result;\n'
  589. '}\n'
  590. )
  591. assert source == expected
  592. def test_empty_f_code():
  593. code_gen = FCodeGen()
  594. source = get_string(code_gen.dump_f95, [])
  595. assert source == ""
  596. def test_empty_f_code_with_header():
  597. code_gen = FCodeGen()
  598. source = get_string(code_gen.dump_f95, [], header=True)
  599. assert source[:82] == (
  600. "!******************************************************************************\n!*"
  601. )
  602. # " Code generated with SymPy 0.7.2-git "
  603. assert source[158:] == ( "*\n"
  604. "!* *\n"
  605. "!* See http://www.sympy.org/ for more information. *\n"
  606. "!* *\n"
  607. "!* This file is part of 'project' *\n"
  608. "!******************************************************************************\n"
  609. )
  610. def test_empty_f_header():
  611. code_gen = FCodeGen()
  612. source = get_string(code_gen.dump_h, [])
  613. assert source == ""
  614. def test_simple_f_code():
  615. x, y, z = symbols('x,y,z')
  616. expr = (x + y)*z
  617. routine = make_routine("test", expr)
  618. code_gen = FCodeGen()
  619. source = get_string(code_gen.dump_f95, [routine])
  620. expected = (
  621. "REAL*8 function test(x, y, z)\n"
  622. "implicit none\n"
  623. "REAL*8, intent(in) :: x\n"
  624. "REAL*8, intent(in) :: y\n"
  625. "REAL*8, intent(in) :: z\n"
  626. "test = z*(x + y)\n"
  627. "end function\n"
  628. )
  629. assert source == expected
  630. def test_numbersymbol_f_code():
  631. routine = make_routine("test", pi**Catalan)
  632. code_gen = FCodeGen()
  633. source = get_string(code_gen.dump_f95, [routine])
  634. expected = (
  635. "REAL*8 function test()\n"
  636. "implicit none\n"
  637. "REAL*8, parameter :: Catalan = %sd0\n"
  638. "REAL*8, parameter :: pi = %sd0\n"
  639. "test = pi**Catalan\n"
  640. "end function\n"
  641. ) % (Catalan.evalf(17), pi.evalf(17))
  642. assert source == expected
  643. def test_erf_f_code():
  644. x = symbols('x')
  645. routine = make_routine("test", erf(x) - erf(-2 * x))
  646. code_gen = FCodeGen()
  647. source = get_string(code_gen.dump_f95, [routine])
  648. expected = (
  649. "REAL*8 function test(x)\n"
  650. "implicit none\n"
  651. "REAL*8, intent(in) :: x\n"
  652. "test = erf(x) + erf(2.0d0*x)\n"
  653. "end function\n"
  654. )
  655. assert source == expected, source
  656. def test_f_code_argument_order():
  657. x, y, z = symbols('x,y,z')
  658. expr = x + y
  659. routine = make_routine("test", expr, argument_sequence=[z, x, y])
  660. code_gen = FCodeGen()
  661. source = get_string(code_gen.dump_f95, [routine])
  662. expected = (
  663. "REAL*8 function test(z, x, y)\n"
  664. "implicit none\n"
  665. "REAL*8, intent(in) :: z\n"
  666. "REAL*8, intent(in) :: x\n"
  667. "REAL*8, intent(in) :: y\n"
  668. "test = x + y\n"
  669. "end function\n"
  670. )
  671. assert source == expected
  672. def test_simple_f_header():
  673. x, y, z = symbols('x,y,z')
  674. expr = (x + y)*z
  675. routine = make_routine("test", expr)
  676. code_gen = FCodeGen()
  677. source = get_string(code_gen.dump_h, [routine])
  678. expected = (
  679. "interface\n"
  680. "REAL*8 function test(x, y, z)\n"
  681. "implicit none\n"
  682. "REAL*8, intent(in) :: x\n"
  683. "REAL*8, intent(in) :: y\n"
  684. "REAL*8, intent(in) :: z\n"
  685. "end function\n"
  686. "end interface\n"
  687. )
  688. assert source == expected
  689. def test_simple_f_codegen():
  690. x, y, z = symbols('x,y,z')
  691. expr = (x + y)*z
  692. result = codegen(
  693. ("test", expr), "F95", "file", header=False, empty=False)
  694. expected = [
  695. ("file.f90",
  696. "REAL*8 function test(x, y, z)\n"
  697. "implicit none\n"
  698. "REAL*8, intent(in) :: x\n"
  699. "REAL*8, intent(in) :: y\n"
  700. "REAL*8, intent(in) :: z\n"
  701. "test = z*(x + y)\n"
  702. "end function\n"),
  703. ("file.h",
  704. "interface\n"
  705. "REAL*8 function test(x, y, z)\n"
  706. "implicit none\n"
  707. "REAL*8, intent(in) :: x\n"
  708. "REAL*8, intent(in) :: y\n"
  709. "REAL*8, intent(in) :: z\n"
  710. "end function\n"
  711. "end interface\n")
  712. ]
  713. assert result == expected
  714. def test_multiple_results_f():
  715. x, y, z = symbols('x,y,z')
  716. expr1 = (x + y)*z
  717. expr2 = (x - y)*z
  718. routine = make_routine(
  719. "test",
  720. [expr1, expr2]
  721. )
  722. code_gen = FCodeGen()
  723. raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine]))
  724. def test_no_results_f():
  725. raises(ValueError, lambda: make_routine("test", []))
  726. def test_intrinsic_math_codegen():
  727. # not included: log10
  728. from sympy.functions.elementary.complexes import Abs
  729. from sympy.functions.elementary.exponential import log
  730. from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
  731. from sympy.functions.elementary.miscellaneous import sqrt
  732. from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
  733. x = symbols('x')
  734. name_expr = [
  735. ("test_abs", Abs(x)),
  736. ("test_acos", acos(x)),
  737. ("test_asin", asin(x)),
  738. ("test_atan", atan(x)),
  739. ("test_cos", cos(x)),
  740. ("test_cosh", cosh(x)),
  741. ("test_log", log(x)),
  742. ("test_ln", log(x)),
  743. ("test_sin", sin(x)),
  744. ("test_sinh", sinh(x)),
  745. ("test_sqrt", sqrt(x)),
  746. ("test_tan", tan(x)),
  747. ("test_tanh", tanh(x)),
  748. ]
  749. result = codegen(name_expr, "F95", "file", header=False, empty=False)
  750. assert result[0][0] == "file.f90"
  751. expected = (
  752. 'REAL*8 function test_abs(x)\n'
  753. 'implicit none\n'
  754. 'REAL*8, intent(in) :: x\n'
  755. 'test_abs = abs(x)\n'
  756. 'end function\n'
  757. 'REAL*8 function test_acos(x)\n'
  758. 'implicit none\n'
  759. 'REAL*8, intent(in) :: x\n'
  760. 'test_acos = acos(x)\n'
  761. 'end function\n'
  762. 'REAL*8 function test_asin(x)\n'
  763. 'implicit none\n'
  764. 'REAL*8, intent(in) :: x\n'
  765. 'test_asin = asin(x)\n'
  766. 'end function\n'
  767. 'REAL*8 function test_atan(x)\n'
  768. 'implicit none\n'
  769. 'REAL*8, intent(in) :: x\n'
  770. 'test_atan = atan(x)\n'
  771. 'end function\n'
  772. 'REAL*8 function test_cos(x)\n'
  773. 'implicit none\n'
  774. 'REAL*8, intent(in) :: x\n'
  775. 'test_cos = cos(x)\n'
  776. 'end function\n'
  777. 'REAL*8 function test_cosh(x)\n'
  778. 'implicit none\n'
  779. 'REAL*8, intent(in) :: x\n'
  780. 'test_cosh = cosh(x)\n'
  781. 'end function\n'
  782. 'REAL*8 function test_log(x)\n'
  783. 'implicit none\n'
  784. 'REAL*8, intent(in) :: x\n'
  785. 'test_log = log(x)\n'
  786. 'end function\n'
  787. 'REAL*8 function test_ln(x)\n'
  788. 'implicit none\n'
  789. 'REAL*8, intent(in) :: x\n'
  790. 'test_ln = log(x)\n'
  791. 'end function\n'
  792. 'REAL*8 function test_sin(x)\n'
  793. 'implicit none\n'
  794. 'REAL*8, intent(in) :: x\n'
  795. 'test_sin = sin(x)\n'
  796. 'end function\n'
  797. 'REAL*8 function test_sinh(x)\n'
  798. 'implicit none\n'
  799. 'REAL*8, intent(in) :: x\n'
  800. 'test_sinh = sinh(x)\n'
  801. 'end function\n'
  802. 'REAL*8 function test_sqrt(x)\n'
  803. 'implicit none\n'
  804. 'REAL*8, intent(in) :: x\n'
  805. 'test_sqrt = sqrt(x)\n'
  806. 'end function\n'
  807. 'REAL*8 function test_tan(x)\n'
  808. 'implicit none\n'
  809. 'REAL*8, intent(in) :: x\n'
  810. 'test_tan = tan(x)\n'
  811. 'end function\n'
  812. 'REAL*8 function test_tanh(x)\n'
  813. 'implicit none\n'
  814. 'REAL*8, intent(in) :: x\n'
  815. 'test_tanh = tanh(x)\n'
  816. 'end function\n'
  817. )
  818. assert result[0][1] == expected
  819. assert result[1][0] == "file.h"
  820. expected = (
  821. 'interface\n'
  822. 'REAL*8 function test_abs(x)\n'
  823. 'implicit none\n'
  824. 'REAL*8, intent(in) :: x\n'
  825. 'end function\n'
  826. 'end interface\n'
  827. 'interface\n'
  828. 'REAL*8 function test_acos(x)\n'
  829. 'implicit none\n'
  830. 'REAL*8, intent(in) :: x\n'
  831. 'end function\n'
  832. 'end interface\n'
  833. 'interface\n'
  834. 'REAL*8 function test_asin(x)\n'
  835. 'implicit none\n'
  836. 'REAL*8, intent(in) :: x\n'
  837. 'end function\n'
  838. 'end interface\n'
  839. 'interface\n'
  840. 'REAL*8 function test_atan(x)\n'
  841. 'implicit none\n'
  842. 'REAL*8, intent(in) :: x\n'
  843. 'end function\n'
  844. 'end interface\n'
  845. 'interface\n'
  846. 'REAL*8 function test_cos(x)\n'
  847. 'implicit none\n'
  848. 'REAL*8, intent(in) :: x\n'
  849. 'end function\n'
  850. 'end interface\n'
  851. 'interface\n'
  852. 'REAL*8 function test_cosh(x)\n'
  853. 'implicit none\n'
  854. 'REAL*8, intent(in) :: x\n'
  855. 'end function\n'
  856. 'end interface\n'
  857. 'interface\n'
  858. 'REAL*8 function test_log(x)\n'
  859. 'implicit none\n'
  860. 'REAL*8, intent(in) :: x\n'
  861. 'end function\n'
  862. 'end interface\n'
  863. 'interface\n'
  864. 'REAL*8 function test_ln(x)\n'
  865. 'implicit none\n'
  866. 'REAL*8, intent(in) :: x\n'
  867. 'end function\n'
  868. 'end interface\n'
  869. 'interface\n'
  870. 'REAL*8 function test_sin(x)\n'
  871. 'implicit none\n'
  872. 'REAL*8, intent(in) :: x\n'
  873. 'end function\n'
  874. 'end interface\n'
  875. 'interface\n'
  876. 'REAL*8 function test_sinh(x)\n'
  877. 'implicit none\n'
  878. 'REAL*8, intent(in) :: x\n'
  879. 'end function\n'
  880. 'end interface\n'
  881. 'interface\n'
  882. 'REAL*8 function test_sqrt(x)\n'
  883. 'implicit none\n'
  884. 'REAL*8, intent(in) :: x\n'
  885. 'end function\n'
  886. 'end interface\n'
  887. 'interface\n'
  888. 'REAL*8 function test_tan(x)\n'
  889. 'implicit none\n'
  890. 'REAL*8, intent(in) :: x\n'
  891. 'end function\n'
  892. 'end interface\n'
  893. 'interface\n'
  894. 'REAL*8 function test_tanh(x)\n'
  895. 'implicit none\n'
  896. 'REAL*8, intent(in) :: x\n'
  897. 'end function\n'
  898. 'end interface\n'
  899. )
  900. assert result[1][1] == expected
  901. def test_intrinsic_math2_codegen():
  902. # not included: frexp, ldexp, modf, fmod
  903. from sympy.functions.elementary.trigonometric import atan2
  904. x, y = symbols('x,y')
  905. name_expr = [
  906. ("test_atan2", atan2(x, y)),
  907. ("test_pow", x**y),
  908. ]
  909. result = codegen(name_expr, "F95", "file", header=False, empty=False)
  910. assert result[0][0] == "file.f90"
  911. expected = (
  912. 'REAL*8 function test_atan2(x, y)\n'
  913. 'implicit none\n'
  914. 'REAL*8, intent(in) :: x\n'
  915. 'REAL*8, intent(in) :: y\n'
  916. 'test_atan2 = atan2(x, y)\n'
  917. 'end function\n'
  918. 'REAL*8 function test_pow(x, y)\n'
  919. 'implicit none\n'
  920. 'REAL*8, intent(in) :: x\n'
  921. 'REAL*8, intent(in) :: y\n'
  922. 'test_pow = x**y\n'
  923. 'end function\n'
  924. )
  925. assert result[0][1] == expected
  926. assert result[1][0] == "file.h"
  927. expected = (
  928. 'interface\n'
  929. 'REAL*8 function test_atan2(x, y)\n'
  930. 'implicit none\n'
  931. 'REAL*8, intent(in) :: x\n'
  932. 'REAL*8, intent(in) :: y\n'
  933. 'end function\n'
  934. 'end interface\n'
  935. 'interface\n'
  936. 'REAL*8 function test_pow(x, y)\n'
  937. 'implicit none\n'
  938. 'REAL*8, intent(in) :: x\n'
  939. 'REAL*8, intent(in) :: y\n'
  940. 'end function\n'
  941. 'end interface\n'
  942. )
  943. assert result[1][1] == expected
  944. def test_complicated_codegen_f95():
  945. from sympy.functions.elementary.trigonometric import (cos, sin, tan)
  946. x, y, z = symbols('x,y,z')
  947. name_expr = [
  948. ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
  949. ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
  950. ]
  951. result = codegen(name_expr, "F95", "file", header=False, empty=False)
  952. assert result[0][0] == "file.f90"
  953. expected = (
  954. 'REAL*8 function test1(x, y, z)\n'
  955. 'implicit none\n'
  956. 'REAL*8, intent(in) :: x\n'
  957. 'REAL*8, intent(in) :: y\n'
  958. 'REAL*8, intent(in) :: z\n'
  959. 'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n'
  960. ' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n'
  961. ' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n'
  962. ' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n'
  963. ' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n'
  964. ' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n'
  965. ' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n'
  966. ' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n'
  967. ' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n'
  968. ' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n'
  969. ' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n'
  970. ' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n'
  971. ' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n'
  972. ' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n'
  973. ' y)**2*tan(z)**5 + 7*cos(y)*tan(z)**6 + tan(z)**7\n'
  974. 'end function\n'
  975. 'REAL*8 function test2(x, y, z)\n'
  976. 'implicit none\n'
  977. 'REAL*8, intent(in) :: x\n'
  978. 'REAL*8, intent(in) :: y\n'
  979. 'REAL*8, intent(in) :: z\n'
  980. 'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n'
  981. 'end function\n'
  982. )
  983. assert result[0][1] == expected
  984. assert result[1][0] == "file.h"
  985. expected = (
  986. 'interface\n'
  987. 'REAL*8 function test1(x, y, z)\n'
  988. 'implicit none\n'
  989. 'REAL*8, intent(in) :: x\n'
  990. 'REAL*8, intent(in) :: y\n'
  991. 'REAL*8, intent(in) :: z\n'
  992. 'end function\n'
  993. 'end interface\n'
  994. 'interface\n'
  995. 'REAL*8 function test2(x, y, z)\n'
  996. 'implicit none\n'
  997. 'REAL*8, intent(in) :: x\n'
  998. 'REAL*8, intent(in) :: y\n'
  999. 'REAL*8, intent(in) :: z\n'
  1000. 'end function\n'
  1001. 'end interface\n'
  1002. )
  1003. assert result[1][1] == expected
  1004. def test_loops():
  1005. from sympy.tensor import IndexedBase, Idx
  1006. from sympy.core.symbol import symbols
  1007. n, m = symbols('n,m', integer=True)
  1008. A, x, y = map(IndexedBase, 'Axy')
  1009. i = Idx('i', m)
  1010. j = Idx('j', n)
  1011. (f1, code), (f2, interface) = codegen(
  1012. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
  1013. assert f1 == 'file.f90'
  1014. expected = (
  1015. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1016. 'implicit none\n'
  1017. 'INTEGER*4, intent(in) :: m\n'
  1018. 'INTEGER*4, intent(in) :: n\n'
  1019. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1020. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1021. 'REAL*8, intent(out), dimension(1:m) :: y\n'
  1022. 'INTEGER*4 :: i\n'
  1023. 'INTEGER*4 :: j\n'
  1024. 'do i = 1, m\n'
  1025. ' y(i) = 0\n'
  1026. 'end do\n'
  1027. 'do i = 1, m\n'
  1028. ' do j = 1, n\n'
  1029. ' y(i) = %(rhs)s + y(i)\n'
  1030. ' end do\n'
  1031. 'end do\n'
  1032. 'end subroutine\n'
  1033. )
  1034. assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
  1035. code == expected % {'rhs': 'x(j)*A(i, j)'}
  1036. assert f2 == 'file.h'
  1037. assert interface == (
  1038. 'interface\n'
  1039. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1040. 'implicit none\n'
  1041. 'INTEGER*4, intent(in) :: m\n'
  1042. 'INTEGER*4, intent(in) :: n\n'
  1043. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1044. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1045. 'REAL*8, intent(out), dimension(1:m) :: y\n'
  1046. 'end subroutine\n'
  1047. 'end interface\n'
  1048. )
  1049. def test_dummy_loops_f95():
  1050. from sympy.tensor import IndexedBase, Idx
  1051. i, m = symbols('i m', integer=True, cls=Dummy)
  1052. x = IndexedBase('x')
  1053. y = IndexedBase('y')
  1054. i = Idx(i, m)
  1055. expected = (
  1056. 'subroutine test_dummies(m_%(mcount)i, x, y)\n'
  1057. 'implicit none\n'
  1058. 'INTEGER*4, intent(in) :: m_%(mcount)i\n'
  1059. 'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n'
  1060. 'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n'
  1061. 'INTEGER*4 :: i_%(icount)i\n'
  1062. 'do i_%(icount)i = 1, m_%(mcount)i\n'
  1063. ' y(i_%(icount)i) = x(i_%(icount)i)\n'
  1064. 'end do\n'
  1065. 'end subroutine\n'
  1066. ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index}
  1067. r = make_routine('test_dummies', Eq(y[i], x[i]))
  1068. c = FCodeGen()
  1069. code = get_string(c.dump_f95, [r])
  1070. assert code == expected
  1071. def test_loops_InOut():
  1072. from sympy.tensor import IndexedBase, Idx
  1073. from sympy.core.symbol import symbols
  1074. i, j, n, m = symbols('i,j,n,m', integer=True)
  1075. A, x, y = symbols('A,x,y')
  1076. A = IndexedBase(A)[Idx(i, m), Idx(j, n)]
  1077. x = IndexedBase(x)[Idx(j, n)]
  1078. y = IndexedBase(y)[Idx(i, m)]
  1079. (f1, code), (f2, interface) = codegen(
  1080. ('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False)
  1081. assert f1 == 'file.f90'
  1082. expected = (
  1083. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1084. 'implicit none\n'
  1085. 'INTEGER*4, intent(in) :: m\n'
  1086. 'INTEGER*4, intent(in) :: n\n'
  1087. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1088. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1089. 'REAL*8, intent(inout), dimension(1:m) :: y\n'
  1090. 'INTEGER*4 :: i\n'
  1091. 'INTEGER*4 :: j\n'
  1092. 'do i = 1, m\n'
  1093. ' do j = 1, n\n'
  1094. ' y(i) = %(rhs)s + y(i)\n'
  1095. ' end do\n'
  1096. 'end do\n'
  1097. 'end subroutine\n'
  1098. )
  1099. assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or
  1100. code == expected % {'rhs': 'x(j)*A(i, j)'})
  1101. assert f2 == 'file.h'
  1102. assert interface == (
  1103. 'interface\n'
  1104. 'subroutine matrix_vector(A, m, n, x, y)\n'
  1105. 'implicit none\n'
  1106. 'INTEGER*4, intent(in) :: m\n'
  1107. 'INTEGER*4, intent(in) :: n\n'
  1108. 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n'
  1109. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1110. 'REAL*8, intent(inout), dimension(1:m) :: y\n'
  1111. 'end subroutine\n'
  1112. 'end interface\n'
  1113. )
  1114. def test_partial_loops_f():
  1115. # check that loop boundaries are determined by Idx, and array strides
  1116. # determined by shape of IndexedBase object.
  1117. from sympy.tensor import IndexedBase, Idx
  1118. from sympy.core.symbol import symbols
  1119. n, m, o, p = symbols('n m o p', integer=True)
  1120. A = IndexedBase('A', shape=(m, p))
  1121. x = IndexedBase('x')
  1122. y = IndexedBase('y')
  1123. i = Idx('i', (o, m - 5)) # Note: bounds are inclusive
  1124. j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1)
  1125. (f1, code), (f2, interface) = codegen(
  1126. ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False)
  1127. expected = (
  1128. 'subroutine matrix_vector(A, m, n, o, p, x, y)\n'
  1129. 'implicit none\n'
  1130. 'INTEGER*4, intent(in) :: m\n'
  1131. 'INTEGER*4, intent(in) :: n\n'
  1132. 'INTEGER*4, intent(in) :: o\n'
  1133. 'INTEGER*4, intent(in) :: p\n'
  1134. 'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n'
  1135. 'REAL*8, intent(in), dimension(1:n) :: x\n'
  1136. 'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n'
  1137. 'INTEGER*4 :: i\n'
  1138. 'INTEGER*4 :: j\n'
  1139. 'do i = %(ilow)s, %(iup)s\n'
  1140. ' y(i) = 0\n'
  1141. 'end do\n'
  1142. 'do i = %(ilow)s, %(iup)s\n'
  1143. ' do j = 1, n\n'
  1144. ' y(i) = %(rhs)s + y(i)\n'
  1145. ' end do\n'
  1146. 'end do\n'
  1147. 'end subroutine\n'
  1148. ) % {
  1149. 'rhs': '%(rhs)s',
  1150. 'iup': str(m - 4),
  1151. 'ilow': str(1 + o),
  1152. 'iup-ilow': str(m - 4 - o)
  1153. }
  1154. assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\
  1155. code == expected % {'rhs': 'x(j)*A(i, j)'}
  1156. def test_output_arg_f():
  1157. from sympy.core.relational import Equality
  1158. from sympy.functions.elementary.trigonometric import (cos, sin)
  1159. x, y, z = symbols("x,y,z")
  1160. r = make_routine("foo", [Equality(y, sin(x)), cos(x)])
  1161. c = FCodeGen()
  1162. result = c.write([r], "test", header=False, empty=False)
  1163. assert result[0][0] == "test.f90"
  1164. assert result[0][1] == (
  1165. 'REAL*8 function foo(x, y)\n'
  1166. 'implicit none\n'
  1167. 'REAL*8, intent(in) :: x\n'
  1168. 'REAL*8, intent(out) :: y\n'
  1169. 'y = sin(x)\n'
  1170. 'foo = cos(x)\n'
  1171. 'end function\n'
  1172. )
  1173. def test_inline_function():
  1174. from sympy.tensor import IndexedBase, Idx
  1175. from sympy.core.symbol import symbols
  1176. n, m = symbols('n m', integer=True)
  1177. A, x, y = map(IndexedBase, 'Axy')
  1178. i = Idx('i', m)
  1179. p = FCodeGen()
  1180. func = implemented_function('func', Lambda(n, n*(n + 1)))
  1181. routine = make_routine('test_inline', Eq(y[i], func(x[i])))
  1182. code = get_string(p.dump_f95, [routine])
  1183. expected = (
  1184. 'subroutine test_inline(m, x, y)\n'
  1185. 'implicit none\n'
  1186. 'INTEGER*4, intent(in) :: m\n'
  1187. 'REAL*8, intent(in), dimension(1:m) :: x\n'
  1188. 'REAL*8, intent(out), dimension(1:m) :: y\n'
  1189. 'INTEGER*4 :: i\n'
  1190. 'do i = 1, m\n'
  1191. ' y(i) = %s*%s\n'
  1192. 'end do\n'
  1193. 'end subroutine\n'
  1194. )
  1195. args = ('x(i)', '(x(i) + 1)')
  1196. assert code == expected % args or\
  1197. code == expected % args[::-1]
  1198. def test_f_code_call_signature_wrap():
  1199. # Issue #7934
  1200. x = symbols('x:20')
  1201. expr = 0
  1202. for sym in x:
  1203. expr += sym
  1204. routine = make_routine("test", expr)
  1205. code_gen = FCodeGen()
  1206. source = get_string(code_gen.dump_f95, [routine])
  1207. expected = """\
  1208. REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, &
  1209. x19, x2, x3, x4, x5, x6, x7, x8, x9)
  1210. implicit none
  1211. REAL*8, intent(in) :: x0
  1212. REAL*8, intent(in) :: x1
  1213. REAL*8, intent(in) :: x10
  1214. REAL*8, intent(in) :: x11
  1215. REAL*8, intent(in) :: x12
  1216. REAL*8, intent(in) :: x13
  1217. REAL*8, intent(in) :: x14
  1218. REAL*8, intent(in) :: x15
  1219. REAL*8, intent(in) :: x16
  1220. REAL*8, intent(in) :: x17
  1221. REAL*8, intent(in) :: x18
  1222. REAL*8, intent(in) :: x19
  1223. REAL*8, intent(in) :: x2
  1224. REAL*8, intent(in) :: x3
  1225. REAL*8, intent(in) :: x4
  1226. REAL*8, intent(in) :: x5
  1227. REAL*8, intent(in) :: x6
  1228. REAL*8, intent(in) :: x7
  1229. REAL*8, intent(in) :: x8
  1230. REAL*8, intent(in) :: x9
  1231. test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + &
  1232. x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
  1233. end function
  1234. """
  1235. assert source == expected
  1236. def test_check_case():
  1237. x, X = symbols('x,X')
  1238. raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix'))
  1239. def test_check_case_false_positive():
  1240. # The upper case/lower case exception should not be triggered by SymPy
  1241. # objects that differ only because of assumptions. (It may be useful to
  1242. # have a check for that as well, but here we only want to test against
  1243. # false positives with respect to case checking.)
  1244. x1 = symbols('x')
  1245. x2 = symbols('x', my_assumption=True)
  1246. try:
  1247. codegen(('test', x1*x2), 'f95', 'prefix')
  1248. except CodeGenError as e:
  1249. if e.args[0].startswith("Fortran ignores case."):
  1250. raise AssertionError("This exception should not be raised!")
  1251. def test_c_fortran_omit_routine_name():
  1252. x, y = symbols("x,y")
  1253. name_expr = [("foo", 2*x)]
  1254. result = codegen(name_expr, "F95", header=False, empty=False)
  1255. expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
  1256. assert result[0][1] == expresult[0][1]
  1257. name_expr = ("foo", x*y)
  1258. result = codegen(name_expr, "F95", header=False, empty=False)
  1259. expresult = codegen(name_expr, "F95", "foo", header=False, empty=False)
  1260. assert result[0][1] == expresult[0][1]
  1261. name_expr = ("foo", Matrix([[x, y], [x+y, x-y]]))
  1262. result = codegen(name_expr, "C89", header=False, empty=False)
  1263. expresult = codegen(name_expr, "C89", "foo", header=False, empty=False)
  1264. assert result[0][1] == expresult[0][1]
  1265. def test_fcode_matrix_output():
  1266. x, y, z = symbols('x,y,z')
  1267. e1 = x + y
  1268. e2 = Matrix([[x, y], [z, 16]])
  1269. name_expr = ("test", (e1, e2))
  1270. result = codegen(name_expr, "f95", "test", header=False, empty=False)
  1271. source = result[0][1]
  1272. expected = (
  1273. "REAL*8 function test(x, y, z, out_%(hash)s)\n"
  1274. "implicit none\n"
  1275. "REAL*8, intent(in) :: x\n"
  1276. "REAL*8, intent(in) :: y\n"
  1277. "REAL*8, intent(in) :: z\n"
  1278. "REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n"
  1279. "out_%(hash)s(1, 1) = x\n"
  1280. "out_%(hash)s(2, 1) = z\n"
  1281. "out_%(hash)s(1, 2) = y\n"
  1282. "out_%(hash)s(2, 2) = 16\n"
  1283. "test = x + y\n"
  1284. "end function\n"
  1285. )
  1286. # look for the magic number
  1287. a = source.splitlines()[5]
  1288. b = a.split('_')
  1289. out = b[1]
  1290. expected = expected % {'hash': out}
  1291. assert source == expected
  1292. def test_fcode_results_named_ordered():
  1293. x, y, z = symbols('x,y,z')
  1294. B, C = symbols('B,C')
  1295. A = MatrixSymbol('A', 1, 3)
  1296. expr1 = Equality(A, Matrix([[1, 2, x]]))
  1297. expr2 = Equality(C, (x + y)*z)
  1298. expr3 = Equality(B, 2*x)
  1299. name_expr = ("test", [expr1, expr2, expr3])
  1300. result = codegen(name_expr, "f95", "test", header=False, empty=False,
  1301. argument_sequence=(x, z, y, C, A, B))
  1302. source = result[0][1]
  1303. expected = (
  1304. "subroutine test(x, z, y, C, A, B)\n"
  1305. "implicit none\n"
  1306. "REAL*8, intent(in) :: x\n"
  1307. "REAL*8, intent(in) :: z\n"
  1308. "REAL*8, intent(in) :: y\n"
  1309. "REAL*8, intent(out) :: C\n"
  1310. "REAL*8, intent(out) :: B\n"
  1311. "REAL*8, intent(out), dimension(1:1, 1:3) :: A\n"
  1312. "C = z*(x + y)\n"
  1313. "A(1, 1) = 1\n"
  1314. "A(1, 2) = 2\n"
  1315. "A(1, 3) = x\n"
  1316. "B = 2*x\n"
  1317. "end subroutine\n"
  1318. )
  1319. assert source == expected
  1320. def test_fcode_matrixsymbol_slice():
  1321. A = MatrixSymbol('A', 2, 3)
  1322. B = MatrixSymbol('B', 1, 3)
  1323. C = MatrixSymbol('C', 1, 3)
  1324. D = MatrixSymbol('D', 2, 1)
  1325. name_expr = ("test", [Equality(B, A[0, :]),
  1326. Equality(C, A[1, :]),
  1327. Equality(D, A[:, 2])])
  1328. result = codegen(name_expr, "f95", "test", header=False, empty=False)
  1329. source = result[0][1]
  1330. expected = (
  1331. "subroutine test(A, B, C, D)\n"
  1332. "implicit none\n"
  1333. "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
  1334. "REAL*8, intent(out), dimension(1:1, 1:3) :: B\n"
  1335. "REAL*8, intent(out), dimension(1:1, 1:3) :: C\n"
  1336. "REAL*8, intent(out), dimension(1:2, 1:1) :: D\n"
  1337. "B(1, 1) = A(1, 1)\n"
  1338. "B(1, 2) = A(1, 2)\n"
  1339. "B(1, 3) = A(1, 3)\n"
  1340. "C(1, 1) = A(2, 1)\n"
  1341. "C(1, 2) = A(2, 2)\n"
  1342. "C(1, 3) = A(2, 3)\n"
  1343. "D(1, 1) = A(1, 3)\n"
  1344. "D(2, 1) = A(2, 3)\n"
  1345. "end subroutine\n"
  1346. )
  1347. assert source == expected
  1348. def test_fcode_matrixsymbol_slice_autoname():
  1349. # see issue #8093
  1350. A = MatrixSymbol('A', 2, 3)
  1351. name_expr = ("test", A[:, 1])
  1352. result = codegen(name_expr, "f95", "test", header=False, empty=False)
  1353. source = result[0][1]
  1354. expected = (
  1355. "subroutine test(A, out_%(hash)s)\n"
  1356. "implicit none\n"
  1357. "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n"
  1358. "REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n"
  1359. "out_%(hash)s(1, 1) = A(1, 2)\n"
  1360. "out_%(hash)s(2, 1) = A(2, 2)\n"
  1361. "end subroutine\n"
  1362. )
  1363. # look for the magic number
  1364. a = source.splitlines()[3]
  1365. b = a.split('_')
  1366. out = b[1]
  1367. expected = expected % {'hash': out}
  1368. assert source == expected
  1369. def test_global_vars():
  1370. x, y, z, t = symbols("x y z t")
  1371. result = codegen(('f', x*y), "F95", header=False, empty=False,
  1372. global_vars=(y,))
  1373. source = result[0][1]
  1374. expected = (
  1375. "REAL*8 function f(x)\n"
  1376. "implicit none\n"
  1377. "REAL*8, intent(in) :: x\n"
  1378. "f = x*y\n"
  1379. "end function\n"
  1380. )
  1381. assert source == expected
  1382. expected = (
  1383. '#include "f.h"\n'
  1384. '#include <math.h>\n'
  1385. 'double f(double x, double y) {\n'
  1386. ' double f_result;\n'
  1387. ' f_result = x*y + z;\n'
  1388. ' return f_result;\n'
  1389. '}\n'
  1390. )
  1391. result = codegen(('f', x*y+z), "C", header=False, empty=False,
  1392. global_vars=(z, t))
  1393. source = result[0][1]
  1394. assert source == expected
  1395. def test_custom_codegen():
  1396. from sympy.printing.c import C99CodePrinter
  1397. from sympy.functions.elementary.exponential import exp
  1398. printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}})
  1399. x, y = symbols('x y')
  1400. expr = exp(x + y)
  1401. # replace math.h with a different header
  1402. gen = C99CodeGen(printer=printer,
  1403. preprocessor_statements=['#include "fastexp.h"'])
  1404. expected = (
  1405. '#include "expr.h"\n'
  1406. '#include "fastexp.h"\n'
  1407. 'double expr(double x, double y) {\n'
  1408. ' double expr_result;\n'
  1409. ' expr_result = fastexp(x + y);\n'
  1410. ' return expr_result;\n'
  1411. '}\n'
  1412. )
  1413. result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
  1414. source = result[0][1]
  1415. assert source == expected
  1416. # use both math.h and an external header
  1417. gen = C99CodeGen(printer=printer)
  1418. gen.preprocessor_statements.append('#include "fastexp.h"')
  1419. expected = (
  1420. '#include "expr.h"\n'
  1421. '#include <math.h>\n'
  1422. '#include "fastexp.h"\n'
  1423. 'double expr(double x, double y) {\n'
  1424. ' double expr_result;\n'
  1425. ' expr_result = fastexp(x + y);\n'
  1426. ' return expr_result;\n'
  1427. '}\n'
  1428. )
  1429. result = codegen(('expr', expr), header=False, empty=False, code_gen=gen)
  1430. source = result[0][1]
  1431. assert source == expected
  1432. def test_c_with_printer():
  1433. # issue 13586
  1434. from sympy.printing.c import C99CodePrinter
  1435. class CustomPrinter(C99CodePrinter):
  1436. def _print_Pow(self, expr):
  1437. return "fastpow({}, {})".format(self._print(expr.base),
  1438. self._print(expr.exp))
  1439. x = symbols('x')
  1440. expr = x**3
  1441. expected =[
  1442. ("file.c",
  1443. "#include \"file.h\"\n"
  1444. "#include <math.h>\n"
  1445. "double test(double x) {\n"
  1446. " double test_result;\n"
  1447. " test_result = fastpow(x, 3);\n"
  1448. " return test_result;\n"
  1449. "}\n"),
  1450. ("file.h",
  1451. "#ifndef PROJECT__FILE__H\n"
  1452. "#define PROJECT__FILE__H\n"
  1453. "double test(double x);\n"
  1454. "#endif\n")
  1455. ]
  1456. result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter())
  1457. assert result == expected
  1458. def test_fcode_complex():
  1459. import sympy.utilities.codegen
  1460. sympy.utilities.codegen.COMPLEX_ALLOWED = True
  1461. x = Symbol('x', real=True)
  1462. y = Symbol('y',real=True)
  1463. result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
  1464. source = (result[0][1])
  1465. expected = (
  1466. "REAL*8 function test(x, y)\n"
  1467. "implicit none\n"
  1468. "REAL*8, intent(in) :: x\n"
  1469. "REAL*8, intent(in) :: y\n"
  1470. "test = x + y\n"
  1471. "end function\n")
  1472. assert source == expected
  1473. x = Symbol('x')
  1474. y = Symbol('y',real=True)
  1475. result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False)
  1476. source = (result[0][1])
  1477. expected = (
  1478. "COMPLEX*16 function test(x, y)\n"
  1479. "implicit none\n"
  1480. "COMPLEX*16, intent(in) :: x\n"
  1481. "REAL*8, intent(in) :: y\n"
  1482. "test = x + y\n"
  1483. "end function\n"
  1484. )
  1485. assert source==expected
  1486. sympy.utilities.codegen.COMPLEX_ALLOWED = False