test_lambdify.py 72 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263
  1. from itertools import product
  2. import math
  3. import inspect
  4. import linecache
  5. import gc
  6. import mpmath
  7. import cmath
  8. from sympy.testing.pytest import raises, warns_deprecated_sympy
  9. from sympy.concrete.summations import Sum
  10. from sympy.core.function import (Function, Lambda, diff)
  11. from sympy.core.numbers import (E, Float, I, Rational, all_close, oo, pi)
  12. from sympy.core.relational import Eq
  13. from sympy.core.singleton import S
  14. from sympy.core.symbol import (Dummy, symbols)
  15. from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
  16. from sympy.functions.combinatorial.numbers import bernoulli, harmonic
  17. from sympy.functions.elementary.complexes import Abs, sign
  18. from sympy.functions.elementary.exponential import exp, log
  19. from sympy.functions.elementary.hyperbolic import asinh,acosh,atanh
  20. from sympy.functions.elementary.integers import floor
  21. from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt)
  22. from sympy.functions.elementary.piecewise import Piecewise
  23. from sympy.functions.elementary.trigonometric import (asin, acos, atan, cos, cot, sin,
  24. sinc, tan)
  25. from sympy.functions import sinh,cosh,tanh
  26. from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, jn, yn)
  27. from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized)
  28. from sympy.functions.special.delta_functions import (Heaviside)
  29. from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels, Si, Ci)
  30. from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma)
  31. from sympy.functions.special.zeta_functions import zeta
  32. from sympy.integrals.integrals import Integral
  33. from sympy.logic.boolalg import (And, false, ITE, Not, Or, true)
  34. from sympy.matrices.expressions.dotproduct import DotProduct
  35. from sympy.simplify.cse_main import cse
  36. from sympy.tensor.array import derive_by_array, Array
  37. from sympy.tensor.array.expressions import ArraySymbol
  38. from sympy.tensor.indexed import IndexedBase, Idx
  39. from sympy.utilities.lambdify import lambdify
  40. from sympy.utilities.iterables import numbered_symbols
  41. from sympy.vector import CoordSys3D
  42. from sympy.core.expr import UnevaluatedExpr
  43. from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, log10, hypot, isnan, isinf
  44. from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, amin, amax, minimum, maximum
  45. from sympy.codegen.scipy_nodes import cosm1, powm1
  46. from sympy.functions.elementary.complexes import re, im, arg
  47. from sympy.functions.special.polynomials import \
  48. chebyshevt, chebyshevu, legendre, hermite, laguerre, gegenbauer, \
  49. assoc_legendre, assoc_laguerre, jacobi
  50. from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix
  51. from sympy.printing.codeprinter import PrintMethodNotImplementedError
  52. from sympy.printing.lambdarepr import LambdaPrinter
  53. from sympy.printing.numpy import NumPyPrinter
  54. from sympy.utilities.lambdify import implemented_function, lambdastr
  55. from sympy.testing.pytest import skip
  56. from sympy.utilities.decorator import conserve_mpmath_dps
  57. from sympy.utilities.exceptions import ignore_warnings
  58. from sympy.external import import_module
  59. from sympy.functions.special.gamma_functions import uppergamma, lowergamma
  60. import sympy
  # Alias: ``MutableDenseMatrix`` refers to the same class as ``Matrix``.
  MutableDenseMatrix = Matrix

  # Optional numerical backends.  The tests below check these names for
  # truthiness and call ``skip(...)`` when a backend is not available.
  numpy = import_module('numpy')
  scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']})
  numexpr = import_module('numexpr')
  tensorflow = import_module('tensorflow')
  cupy = import_module('cupy')
  jax = import_module('jax')
  numba = import_module('numba')

  if tensorflow:
      # Hide Tensorflow warnings
      import os
      os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

  # Symbols shared by the tests in this module.
  w, x, y, z = symbols('w,x,y,z')
  74. #================== Test different arguments =======================
  def test_no_args():
      """A function lambdified with an empty argument list takes no arguments."""
      constant = lambdify([], 1)
      # Passing any positional argument must be rejected.
      raises(TypeError, lambda: constant(-1))
      assert constant() == 1
  def test_single_arg():
      """A bare symbol is accepted as the argument specification."""
      double = lambdify(x, 2*x)
      assert double(1) == 2
  def test_list_args():
      """A flat list of symbols maps to positional arguments."""
      add = lambdify([x, y], x + y)
      assert add(1, 2) == 3
  def test_nested_args():
      """Nested argument specifications are unpacked from nested call arguments."""
      # One argument that is itself a two-element sequence.
      single_pair = lambdify([[w, x]], [w, x])
      assert single_pair([91, 2]) == [91, 2]
      raises(TypeError, lambda: single_pair(1, 2))

      # Two arguments, each a pair.
      two_pairs = lambdify([(w, x), (y, z)], [w, x, y, z])
      assert two_pairs((18, 12), (73, 4)) == [18, 12, 73, 4]
      raises(TypeError, lambda: two_pairs(3, 4))

      # Mixed flat and deeply nested arguments.
      deep = lambdify([w, [[[x]], y], z], [w, x, y, z])
      assert deep(10, [[[52]], 31], 44) == [10, 52, 31, 44]
  def test_str_args():
      """Arguments and expression may both be given as comma-separated strings."""
      reverse = lambdify('x,y,z', 'z,y,x')
      for triple in ((3, 2, 1), (1.0, 2.0, 3.0)):
          assert reverse(*triple) == triple[::-1]
      # make sure correct number of args required
      raises(TypeError, lambda: reverse(0))
  def test_own_namespace_1():
      """A lambda supplied in a namespace dict replaces ``sin``."""
      always_one = lambda _: 1
      f = lambdify(x, sin(x), {"sin": always_one})
      for value in (0.1, 100):
          assert f(value) == 1
  def test_own_namespace_2():
      """A regular function supplied in a namespace dict replaces ``sin``."""
      def always_one(_):
          return 1

      f = lambdify(x, sin(x), {'sin': always_one})
      for value in (0.1, 100):
          assert f(value) == 1
  def test_own_module():
      """A module object (here ``math``) can be passed as the namespace."""
      sine = lambdify(x, sin(x), math)
      assert sine(0) == 0.0

      p, q, r = symbols("p q r", real=True)
      # q + r is wrapped in UnevaluatedExpr; the call below uses q = 1e18 and
      # r = -1e18, and the result is compared against exp(1.0).
      ae = abs(exp(p+UnevaluatedExpr(q+r)))
      twice = lambdify([p, q, r], [ae, ae], modules=math)
      expected = math.exp(1.0)
      for value in twice(1.0, 1e18, -1e18):
          assert abs((value-expected)/expected) < 1e-15
  def test_bad_args():
      """Calling lambdify with a single positional argument raises TypeError."""
      # First case: no vargs given; second case: same with vector exprs.
      for lone_argument in (1, [1, 2]):
          raises(TypeError, lambda: lambdify(lone_argument))
  def test_atoms():
      """Namespace entries for ``pi`` and ``I`` are used in the result."""
      # Non-Symbol atoms should not be pulled out from the expression namespace
      shifted_pi = lambdify(x, pi + x, {"pi": 3.14})
      assert shifted_pi(0) == 3.14

      shifted_i = lambdify(x, I + x, {"I": 1j})
      assert shifted_i(1) == 1 + 1j
  132. #================== Test different modules =========================
  133. # A high-precision value of sin(0.2) is used to detect unwanted loss of precision.
  @conserve_mpmath_dps
  def test_sympy_lambda():
      """With the "sympy" module the lambdified function stays symbolic."""
      mpmath.mp.dps = 50
      sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
      f = lambdify(x, sin(x), "sympy")
      # A symbolic argument round-trips unchanged.
      assert f(x) == sin(x)

      prec = 1e-15
      error = f(Rational(1, 5)).evalf() - Float(str(sin02))
      assert -prec < error < prec
      # NOTE: a check on ``arctan`` was left disabled here with the remark:
      # arctan is in numpy module and should not be available
      # The arctan below gives NameError. What is this supposed to test?
      # raises(NameError, lambda: lambdify(x, arctan(x), "sympy"))
  @conserve_mpmath_dps
  def test_math_lambda():
      """With the "math" module, results are accurate to about 1e-15."""
      mpmath.mp.dps = 50
      sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
      f = lambdify(x, sin(x), "math")

      prec = 1e-15
      error = f(0.2) - sin02
      assert -prec < error < prec

      # if this succeeds, it can't be a Python math function
      raises(TypeError, lambda: f(x))
  @conserve_mpmath_dps
  def test_mpmath_lambda():
      """With the "mpmath" module, results keep the working precision (50 dps)."""
      mpmath.mp.dps = 50
      sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
      f = lambdify(x, sin(x), "mpmath")

      prec = 1e-49  # mpmath precision is around 50 decimal places
      error = f(mpmath.mpf("0.2")) - sin02
      assert -prec < error < prec

      # if this succeeds, it can't be a mpmath function
      raises(TypeError, lambda: f(x))

      # Reference value that the three formulations below are compared against.
      ref2 = (mpmath.mpf("1e-30")
              - mpmath.mpf("1e-45")/2
              + 5*mpmath.mpf("1e-60")/6
              - 3*mpmath.mpf("1e-75")/4
              + 33*mpmath.mpf("1e-90")/40
              )

      # Three ways of writing x**y - 1; the tolerances tighten from
      # the naive form to powm1 to expm1/log1p.
      naive = lambdify((x, y), x**y - 1, "mpmath")
      via_powm1 = lambdify((x, y), powm1(x, y), "mpmath")
      via_expm1 = lambdify((x,), expm1(x*log1p(x)), "mpmath")

      ans_naive = naive(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
      ans_powm1 = via_powm1(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15"))
      ans_expm1 = via_expm1(mpmath.mpf("1e-15"))

      assert abs(ans_naive - ref2) < 1e-51
      assert abs(ans_powm1 - ref2) < 1e-67
      assert abs(ans_expm1 - ref2) < 1e-80
  @conserve_mpmath_dps
  def test_number_precision():
      """A high-precision constant survives lambdification with "mpmath"."""
      mpmath.mp.dps = 50
      sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020")
      constant = lambdify(x, sin02, "mpmath")

      prec = 1e-49  # mpmath precision is around 50 decimal places
      error = constant(0) - sin02
      assert -prec < error < prec
  @conserve_mpmath_dps
  def test_mpmath_precision():
      """A 100-digit Float prints identically after a round trip through mpmath."""
      mpmath.mp.dps = 100
      pi_as_callable = lambdify((), pi.evalf(100), 'mpmath')
      assert str(pi_as_callable()) == str(pi.evalf(100))
  189. #================== Test Translations ==============================
  190. # We can only check that every translated function name is valid. Whether the
  191. # translation tables are complete has to be checked by hand.
  def test_math_transl():
      """Every MATH_TRANSLATIONS entry maps a sympy name to a ``math`` name."""
      from sympy.utilities.lambdify import MATH_TRANSLATIONS

      sympy_names = vars(sympy)
      math_names = vars(math)
      for sympy_name, math_name in MATH_TRANSLATIONS.items():
          assert sympy_name in sympy_names
          assert math_name in math_names
  def test_mpmath_transl():
      """Every MPMATH_TRANSLATIONS entry maps a sympy name to an ``mpmath`` name."""
      from sympy.utilities.lambdify import MPMATH_TRANSLATIONS

      sympy_names = vars(sympy)
      mpmath_names = vars(mpmath)
      for sympy_name, mpmath_name in MPMATH_TRANSLATIONS.items():
          # 'Matrix' is the one key allowed to be missing from sympy's namespace.
          assert sympy_name in sympy_names or sympy_name == 'Matrix'
          assert mpmath_name in mpmath_names
  def test_numpy_transl():
      """Every NUMPY_TRANSLATIONS entry maps a sympy name to a ``numpy`` name."""
      if not numpy:
          skip("numpy not installed.")

      from sympy.utilities.lambdify import NUMPY_TRANSLATIONS

      sympy_names = vars(sympy)
      numpy_names = vars(numpy)
      for sympy_name, numpy_name in NUMPY_TRANSLATIONS.items():
          assert sympy_name in sympy_names
          assert numpy_name in numpy_names
  def test_scipy_transl():
      """Every SCIPY_TRANSLATIONS entry maps a sympy name to a scipy name."""
      if not scipy:
          skip("scipy not installed.")

      from sympy.utilities.lambdify import SCIPY_TRANSLATIONS

      for sympy_name, scipy_name in SCIPY_TRANSLATIONS.items():
          assert sympy_name in sympy.__dict__
          # The target may live in ``scipy`` itself or in ``scipy.special``.
          assert scipy_name in scipy.__dict__ or scipy_name in scipy.special.__dict__
  def test_numpy_translation_abs():
      """``Abs`` lambdified with numpy returns the absolute value."""
      if not numpy:
          skip("numpy not installed.")

      magnitude = lambdify(x, Abs(x), "numpy")
      for value in (-1, 1):
          assert magnitude(value) == 1
  def test_numexpr_printer():
      """Every function known to NumExprPrinter lambdifies and evaluates."""
      if not numexpr:
          skip("numexpr not installed.")

      # if translation/printing is done incorrectly then evaluating
      # a lambdified numexpr expression will throw an exception
      from sympy.printing.lambdarepr import NumExprPrinter

      blacklist = ('where', 'complex', 'contains')
      arg_tuple = (x, y, z)  # some functions take more than one argument

      for name in NumExprPrinter._numexpr_functions:
          if name in blacklist:
              continue

          func = S(name)
          # Use the first entry of ``_nargs`` when present, otherwise one argument.
          nargs = func._nargs[0] if hasattr(func, '_nargs') else 1
          args = arg_tuple[:nargs]

          f = lambdify(args, func(*args), modules='numexpr')
          ones = (1, )*nargs
          assert f(*ones) is not None
  def test_cmath_sqrt():
      """``sqrt`` with the "cmath" module handles negative arguments."""
      root = lambdify(x, sqrt(x), "cmath")

      for value, expected in ((0, 0), (1, 1), (4, 2)):
          assert root(value) == expected

      assert abs(root(2) - 1.414) < 0.001

      # Negative inputs give purely imaginary results.
      for value, expected in ((-1, 1j), (-4, 2j)):
          assert root(value) == expected
  def test_cmath_log():
      """``log`` with the "cmath" module, including a negative argument."""
      logarithm = lambdify(x, log(x), "cmath")
      cases = (
          (1, 0),
          (cmath.e, 1),
          (-1, cmath.log(-1)),
      )
      for value, expected in cases:
          assert abs(logarithm(value) - expected) < 1e-15
  def test_cmath_sinh():
      """Lambdified ``sinh`` agrees with ``cmath.sinh`` at sample points."""
      f = lambdify(x, sinh(x), "cmath")
      for point in (0, pi, -pi, 1j):
          assert abs(f(point) - cmath.sinh(point)) < 1e-15
  def test_cmath_cosh():
      """Lambdified ``cosh`` agrees with ``cmath.cosh`` at sample points."""
      f = lambdify(x, cosh(x), "cmath")
      for point in (0, pi, -pi, 1j):
          assert abs(f(point) - cmath.cosh(point)) < 1e-15
  def test_cmath_tanh():
      """Lambdified ``tanh`` agrees with ``cmath.tanh`` at sample points."""
      f = lambdify(x, tanh(x), "cmath")
      for point in (0, pi, -pi, 1j):
          assert abs(f(point) - cmath.tanh(point)) < 1e-15
  def test_cmath_sin():
      """Lambdified ``sin`` agrees with ``cmath.sin`` at sample points."""
      f = lambdify(x, sin(x), "cmath")
      for point in (0, pi, -pi, 1j):
          assert abs(f(point) - cmath.sin(point)) < 1e-15
  def test_cmath_cos():
      """Lambdified ``cos`` agrees with ``cmath.cos`` at sample points."""
      f = lambdify(x, cos(x), "cmath")
      for point in (0, pi, -pi, 1j):
          assert abs(f(point) - cmath.cos(point)) < 1e-15
  def test_cmath_tan():
      """Lambdified ``tan`` agrees with ``cmath.tan`` at sample points."""
      f = lambdify(x, tan(x), "cmath")
      for point in (0, 1j):
          assert abs(f(point) - cmath.tan(point)) < 1e-15
  def test_cmath_asin():
      """Lambdified ``asin`` agrees with ``cmath.asin`` at sample points."""
      f = lambdify(x, asin(x), "cmath")
      for point in (0, 1, -1, 2, 1j):
          assert abs(f(point) - cmath.asin(point)) < 1e-15
  def test_cmath_acos():
      """Lambdified ``acos`` agrees with ``cmath.acos`` at sample points."""
      f = lambdify(x, acos(x), "cmath")
      for point in (1, -1, 2, 1j):
          assert abs(f(point) - cmath.acos(point)) < 1e-15
  def test_cmath_atan():
      """Lambdified ``atan`` agrees with ``cmath.atan`` at sample points."""
      f = lambdify(x, atan(x), "cmath")
      for point in (0, 1, -1, 2, 2j):
          assert abs(f(point) - cmath.atan(point)) < 1e-15
  def test_cmath_asinh():
      """Lambdified ``asinh`` agrees with ``cmath.asinh`` at sample points."""
      f = lambdify(x, asinh(x), "cmath")
      for point in (0, 1, -1, 2, 2j):
          assert abs(f(point) - cmath.asinh(point)) < 1e-15
  def test_cmath_acosh():
      """Lambdified ``acosh`` agrees with ``cmath.acosh`` at sample points."""
      f = lambdify(x, acosh(x), "cmath")
      for point in (1, 2, -1, 2j):
          assert abs(f(point) - cmath.acosh(point)) < 1e-15
  def test_cmath_atanh():
      """Lambdified ``atanh`` agrees with ``cmath.atanh`` at sample points."""
      f = lambdify(x, atanh(x), "cmath")
      for point in (0, 0.5, -0.5, 2, -2, 2j):
          assert abs(f(point) - cmath.atanh(point)) < 1e-15
  def test_cmath_complex_identities():
      """Check complex identities numerically with the cmath/math modules.

      Every expression below is written as "lhs - rhs" of an identity, so the
      lambdified function must evaluate to a value of magnitude below 4e-16
      at the chosen sample point.
      """
      # Define symbol (local; shadows the module-level ``z``)
      z = symbols('z')
      hpi = math.pi / 2
      tol = 4e-16

      def residual(expr, point):
          # A fresh modules list is built for every lambdify call.
          return abs(lambdify([z], expr, modules=["cmath", "math"])(point))

      identities = [
          # Trigonometric identity using re(z) and im(z)
          (cos(z) - cos(re(z)) * cosh(im(z)) + I * sin(re(z)) * sinh(im(z)),
           hpi + 1j * hpi),
          # Euler's Formula: e^(i*z) = cos(z) + i*sin(z)
          (exp(I * z) - (cos(z) + I * sin(z)),
           hpi),
          # Exponential Identity: e^z = e^(Re(z)) * (cos(Im(z)) + i*sin(Im(z)))
          (exp(z) - exp(re(z)) * (cos(im(z)) + I * sin(im(z))),
           hpi + 1j * hpi),
          # Complex Cosine Identity:
          # cos(z) = cos(Re(z)) * cosh(Im(z)) - i*sin(Re(z)) * sinh(Im(z))
          (cos(z) - (cos(re(z)) * cosh(im(z)) - I * sin(re(z)) * sinh(im(z))),
           hpi + 1j * hpi),
          # Complex Sine Identity:
          # sin(z) = sin(Re(z)) * cosh(Im(z)) + i*cos(Re(z)) * sinh(Im(z))
          (sin(z) - (sin(re(z)) * cosh(im(z)) + I * cos(re(z)) * sinh(im(z))),
           hpi + 1j * hpi),
          # Complex Hyperbolic Cosine Identity:
          # cosh(z) = cosh(Re(z)) * cos(Im(z)) + i*sinh(Re(z)) * sin(Im(z))
          (cosh(z) - (cosh(re(z)) * cos(im(z)) + I * sinh(re(z)) * sin(im(z))),
           hpi + 1j * hpi),
          # Complex Hyperbolic Sine Identity:
          # sinh(z) = sinh(Re(z)) * cos(Im(z)) + i*cosh(Re(z)) * sin(Im(z))
          (sinh(z) - (sinh(re(z)) * cos(im(z)) + I * cosh(re(z)) * sin(im(z))),
           hpi + 1j * hpi),
          # cosh(z) = (e^z + e^(-z)) / 2
          (cosh(z) - (exp(z) + exp(-z)) / 2,
           hpi),
      ]
      for expr, point in identities:
          assert residual(expr, point) < tol

      # Additional expressions testing log and exp with real and imaginary parts
      log_exp_exprs = [
          log(re(z)) + log(im(z)) - log(re(z) * im(z)),
          exp(re(z)) * exp(im(z) * I) - exp(z),
          log(exp(re(z))) - re(z),
          exp(log(re(z))) - re(z),
          log(exp(re(z) + im(z))) - (re(z) + im(z)),
          exp(log(re(z) + im(z))) - (re(z) + im(z)),
      ]
      test_value = 3 + 4j
      for expr in log_exp_exprs:
          assert residual(expr, test_value) < tol
  def test_issue_9334():
      """Regression test: a numexpr-lambdified expression accepts numpy arrays."""
      if not numexpr:
          skip("numexpr not installed.")
      if not numpy:
          skip("numpy not installed.")

      expr = S('b*a - sqrt(a**2)')
      # Sort by name so the symbols unpack as (a, b).
      a, b = sorted(expr.free_symbols, key=lambda s: s.name)
      func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False)

      first, second = numpy.random.random((2, 4))
      # Only checks that the call completes; the result is not inspected.
      func_numexpr(first, second)
  def test_issue_12984():
      """Regression test: a Piecewise without a catch-all branch under numexpr."""
      if not numexpr:
          skip("numexpr not installed.")

      branches = Piecewise((y, x >= 0), (z, x > -1))
      func_numexpr = lambdify((x,y,z), branches, numexpr)
      with ignore_warnings(RuntimeWarning):
          assert func_numexpr(1, 24, 42) == 24
          # x = -1 matches neither condition; the result prints as 'nan'.
          assert str(func_numexpr(-1, 24, 42)) == 'nan'
  def test_empty_modules():
      """``modules=[]`` gives the same result as leaving ``modules`` out."""
      x, y = symbols('x y')
      expr = -(x % y)

      default_backend = lambdify([x, y], expr)
      empty_backend = lambdify([x, y], expr, modules=[])

      assert default_backend(3, 7) == empty_backend(3, 7)
      assert default_backend(3, 7) == -3
  def test_exponentiation():
      """``x**2`` is evaluated correctly for ints and floats."""
      square = lambdify(x, x**2)
      cases = ((-1, 1), (0, 0), (1, 1), (-2, 4), (2, 4), (2.5, 6.25))
      for value, expected in cases:
          assert square(value) == expected
  def test_sqrt():
      """``sqrt`` with the default modules returns the expected roots."""
      root = lambdify(x, sqrt(x))

      for value, expected in ((0, 0.0), (1, 1.0), (4, 2.0)):
          assert root(value) == expected

      assert abs(root(2) - 1.414) < 0.001
      assert root(6.25) == 2.5
  def test_trig():
      """cos/sin near pi via "math", for a symbolic and a float argument."""
      f = lambdify([x], [cos(x), sin(x)], 'math')

      # (argument, tolerance): the truncated float 3.14159 needs a looser bound.
      for argument, prec in ((pi, 1e-11), (3.14159, 1e-5)):
          d = f(argument)
          assert -prec < d[0] + 1 < prec
          assert -prec < d[1] < prec
  def test_integral():
      """An unevaluated Integral is evaluated numerically by the lambdified code."""
      if numpy and not scipy:
          skip("scipy not installed.")

      gaussian = Lambda(x, exp(-x**2))
      tail = lambdify(y, Integral(gaussian(x), (x, y, oo)))

      # Lower limit -oo gives the integral over the whole real line (sqrt(pi)).
      value = tail(-oo)
      assert 1.77245385 < value < 1.772453851
  def test_double_integral():
      """A two-variable Integral is evaluated numerically by the lambdified code."""
      if numpy and not scipy:
          skip("scipy not installed.")

      # example from http://mpmath.org/doc/current/calculus/integration.html
      integrand = 1/(1 - x**2*y**2)
      double = Integral(integrand, (x, 0, 1), (y, 0, z))
      evaluate = lambdify([z], double)

      value = evaluate(1)
      assert 1.23370055 < value < 1.233700551
  def test_spherical_bessel():
      """Lambdified ``jn`` and ``yn`` agree with sympy's own evalf."""
      if numpy and not scipy:
          skip("scipy not installed.")

      test_point = 4.2  # randomly selected
      x = symbols("x")

      for bessel in (jn, yn):
          expr = bessel(2, x)
          numeric = lambdify(x, expr)(test_point)
          reference = expr.subs(x, test_point).evalf()
          assert abs(numeric - reference) < 1e-8
  458. #================== Test vectors ===================================
  def test_vector_simple():
      """A tuple expression yields a tuple result."""
      reverse = lambdify((x, y, z), (z, y, x))
      for triple in ((3, 2, 1), (1.0, 2.0, 3.0)):
          assert reverse(*triple) == triple[::-1]
      # make sure correct number of args required
      raises(TypeError, lambda: reverse(0))
  def test_vector_discontinuous():
      """Division by zero inside a tuple expression raises ZeroDivisionError."""
      f = lambdify(x, (-1/x, 1/x))
      raises(ZeroDivisionError, lambda: f(0))

      cases = (
          (1, (-1.0, 1.0)),
          (2, (-0.5, 0.5)),
          (-2, (0.5, -0.5)),
      )
      for value, expected in cases:
          assert f(value) == expected
  def test_trig_symbolic():
      """cos/sin evaluated at the symbolic ``pi`` via the "math" module."""
      f = lambdify([x], [cos(x), sin(x)], 'math')
      cos_val, sin_val = f(pi)
      assert abs(cos_val + 1) < 0.0001
      assert abs(sin_val - 0) < 0.0001
  def test_trig_float():
      """cos/sin evaluated at a float close to pi with the default modules."""
      f = lambdify([x], [cos(x), sin(x)])
      cos_val, sin_val = f(3.14159)
      assert abs(cos_val + 1) < 0.0001
      assert abs(sin_val - 0) < 0.0001
  def test_docs():
      """Four short lambdify examples: power, list result, sqrt, two arguments."""
      square = lambdify(x, x**2)
      assert square(2) == 4

      reverse = lambdify([x, y, z], [z, y, x])
      assert reverse(1, 2, 3) == [3, 2, 1]

      root = lambdify(x, sqrt(x))
      assert root(4) == 2.0

      sin_squared = lambdify((x, y), sin(x*y)**2)
      assert sin_squared(0, 5) == 0
  def test_math():
      """An argument that the expression does not use is still accepted."""
      sine_of_first = lambdify((x, y), sin(x), modules="math")
      assert sine_of_first(0, 5) == 0
  def test_sin():
      """``sin(x)**2`` returns a Python float by default and with "math"."""
      default_backend = lambdify(x, sin(x)**2)
      assert isinstance(default_backend(2), float)

      math_backend = lambdify(x, sin(x)**2, modules="math")
      assert isinstance(math_backend(2), float)
  def test_matrix():
      """Matrix expressions lambdify with the "sympy" module."""
      A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
      expected = Matrix([[1, 2], [sin(3) + 4, 1]])

      # A plain matrix expression.
      plain = lambdify((x, y, z), A, modules="sympy")
      assert plain(1, 2, 3) == expected

      # A matrix nested inside a tuple and a list.
      nested = lambdify((x, y, z), (A, [A]), modules="sympy")
      assert nested(1, 2, 3) == (expected, [expected])

      # A column or row matrix of symbols may serve as the argument list.
      J = Matrix((x, x + y)).jacobian((x, y))
      v = Matrix((x, y))
      jacobian_value = Matrix([[1, 0], [1, 1]])
      for arg_spec in (v, v.T):
          assert lambdify(arg_spec, J, modules='sympy')(1, 2) == jacobian_value
  def test_numpy_matrix():
      """Matrices become numpy arrays; unknown functions follow printer settings."""
      if not numpy:
          skip("numpy not installed.")

      A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
      expected = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])

      #Lambdify array first, to ensure return to array as default
      as_array = lambdify((x, y, z), A, ['numpy'])
      numpy.testing.assert_allclose(as_array(1, 2, 3), expected)
      #Check that the types are arrays and matrices
      assert isinstance(as_array(1, 2, 3), numpy.ndarray)

      # gh-15071
      class dot(Function):
          pass

      x_dot_mtx = dot(x, Matrix([[2], [1], [0]]))
      inp = numpy.zeros((17, 3))

      # Default printer.
      f_dot1 = lambdify(x, x_dot_mtx)
      assert numpy.all(f_dot1(inp) == 0)

      strict_kw = {"allow_unknown_functions": False, "inline": True, "fully_qualified_modules": False}

      # Strict printer with 'dot' registered as a user function.
      p2 = NumPyPrinter({**strict_kw, 'user_functions': {'dot': 'dot'}})
      f_dot2 = lambdify(x, x_dot_mtx, printer=p2)
      assert numpy.all(f_dot2(inp) == 0)

      # Strict printer without the user function.
      p3 = NumPyPrinter(strict_kw)
      # The line below should probably fail upon construction (before calling with "(inp)"):
      raises(Exception, lambda: lambdify(x, x_dot_mtx, printer=p3)(inp))
  def test_numpy_transpose():
      """A transposed matrix lambdifies to the transposed numpy array."""
      if not numpy:
          skip("numpy not installed.")

      A = Matrix([[1, x], [0, 1]])
      transposed = lambdify(x, A.T, modules="numpy")
      expected = numpy.array([[1, 0], [2, 1]])
      numpy.testing.assert_array_equal(transposed(2), expected)
  540. def test_numpy_dotproduct():
  541. if not numpy:
  542. skip("numpy not installed")
  543. A = Matrix([x, y, z])
  544. f1 = lambdify([x, y, z], DotProduct(A, A), modules='numpy')
  545. f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
  546. f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='numpy')
  547. f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy')
  548. assert f1(1, 2, 3) == \
  549. f2(1, 2, 3) == \
  550. f3(1, 2, 3) == \
  551. f4(1, 2, 3) == \
  552. numpy.array([14])
  553. def test_numpy_inverse():
  554. if not numpy:
  555. skip("numpy not installed.")
  556. A = Matrix([[1, x], [0, 1]])
  557. f = lambdify((x), A**-1, modules="numpy")
  558. numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]]))
  559. def test_numpy_old_matrix():
  560. if not numpy:
  561. skip("numpy not installed.")
  562. A = Matrix([[x, x*y], [sin(z) + 4, x**z]])
  563. sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]])
  564. f = lambdify((x, y, z), A, [{'ImmutableDenseMatrix': numpy.matrix}, 'numpy'])
  565. with ignore_warnings(PendingDeprecationWarning):
  566. numpy.testing.assert_allclose(f(1, 2, 3), sol_arr)
  567. assert isinstance(f(1, 2, 3), numpy.matrix)
  568. def test_scipy_sparse_matrix():
  569. if not scipy:
  570. skip("scipy not installed.")
  571. A = SparseMatrix([[x, 0], [0, y]])
  572. f = lambdify((x, y), A, modules="scipy")
  573. B = f(1, 2)
  574. assert isinstance(B, scipy.sparse.coo_matrix)
  575. def test_python_div_zero_issue_11306():
  576. if not numpy:
  577. skip("numpy not installed.")
  578. p = Piecewise((1 / x, y < -1), (x, y < 1), (1 / x, True))
  579. f = lambdify([x, y], p, modules='numpy')
  580. with numpy.errstate(divide='ignore'):
  581. assert float(f(numpy.array(0), numpy.array(0.5))) == 0
  582. assert float(f(numpy.array(0), numpy.array(1))) == float('inf')
  583. def test_issue9474():
  584. mods = [None, 'math']
  585. if numpy:
  586. mods.append('numpy')
  587. if mpmath:
  588. mods.append('mpmath')
  589. for mod in mods:
  590. f = lambdify(x, S.One/x, modules=mod)
  591. assert f(2) == 0.5
  592. f = lambdify(x, floor(S.One/x), modules=mod)
  593. assert f(2) == 0
  594. for absfunc, modules in product([Abs, abs], mods):
  595. f = lambdify(x, absfunc(x), modules=modules)
  596. assert f(-1) == 1
  597. assert f(1) == 1
  598. assert f(3+4j) == 5
  599. def test_issue_9871():
  600. if not numexpr:
  601. skip("numexpr not installed.")
  602. if not numpy:
  603. skip("numpy not installed.")
  604. r = sqrt(x**2 + y**2)
  605. expr = diff(1/r, x)
  606. xn = yn = numpy.linspace(1, 10, 16)
  607. # expr(xn, xn) = -xn/(sqrt(2)*xn)^3
  608. fv_exact = -numpy.sqrt(2.)**-3 * xn**-2
  609. fv_numpy = lambdify((x, y), expr, modules='numpy')(xn, yn)
  610. fv_numexpr = lambdify((x, y), expr, modules='numexpr')(xn, yn)
  611. numpy.testing.assert_allclose(fv_numpy, fv_exact, rtol=1e-10)
  612. numpy.testing.assert_allclose(fv_numexpr, fv_exact, rtol=1e-10)
  613. def test_numpy_piecewise():
  614. if not numpy:
  615. skip("numpy not installed.")
  616. pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True))
  617. f = lambdify(x, pieces, modules="numpy")
  618. numpy.testing.assert_array_equal(f(numpy.arange(10)),
  619. numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81]))
  620. # If we evaluate somewhere all conditions are False, we should get back NaN
  621. nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0)))
  622. numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])),
  623. numpy.array([1, numpy.nan, 1]))
  624. def test_numpy_logical_ops():
  625. if not numpy:
  626. skip("numpy not installed.")
  627. and_func = lambdify((x, y), And(x, y), modules="numpy")
  628. and_func_3 = lambdify((x, y, z), And(x, y, z), modules="numpy")
  629. or_func = lambdify((x, y), Or(x, y), modules="numpy")
  630. or_func_3 = lambdify((x, y, z), Or(x, y, z), modules="numpy")
  631. not_func = lambdify((x), Not(x), modules="numpy")
  632. arr1 = numpy.array([True, True])
  633. arr2 = numpy.array([False, True])
  634. arr3 = numpy.array([True, False])
  635. numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True]))
  636. numpy.testing.assert_array_equal(and_func_3(arr1, arr2, arr3), numpy.array([False, False]))
  637. numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True]))
  638. numpy.testing.assert_array_equal(or_func_3(arr1, arr2, arr3), numpy.array([True, True]))
  639. numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False]))
  640. def test_numpy_matmul():
  641. if not numpy:
  642. skip("numpy not installed.")
  643. xmat = Matrix([[x, y], [z, 1+z]])
  644. ymat = Matrix([[x**2], [Abs(x)]])
  645. mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy")
  646. numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]]))
  647. numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]]))
  648. # Multiple matrices chained together in multiplication
  649. f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy")
  650. numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25],
  651. [159, 251]]))
  652. def test_numpy_numexpr():
  653. if not numpy:
  654. skip("numpy not installed.")
  655. if not numexpr:
  656. skip("numexpr not installed.")
  657. a, b, c = numpy.random.randn(3, 128, 128)
  658. # ensure that numpy and numexpr return same value for complicated expression
  659. expr = sin(x) + cos(y) + tan(z)**2 + Abs(z-y)*acos(sin(y*z)) + \
  660. Abs(y-z)*acosh(2+exp(y-x))- sqrt(x**2+I*y**2)
  661. npfunc = lambdify((x, y, z), expr, modules='numpy')
  662. nefunc = lambdify((x, y, z), expr, modules='numexpr')
  663. assert numpy.allclose(npfunc(a, b, c), nefunc(a, b, c))
  664. def test_numexpr_userfunctions():
  665. if not numpy:
  666. skip("numpy not installed.")
  667. if not numexpr:
  668. skip("numexpr not installed.")
  669. a, b = numpy.random.randn(2, 10)
  670. uf = type('uf', (Function, ),
  671. {'eval' : classmethod(lambda x, y : y**2+1)})
  672. func = lambdify(x, 1-uf(x), modules='numexpr')
  673. assert numpy.allclose(func(a), -(a**2))
  674. uf = implemented_function(Function('uf'), lambda x, y : 2*x*y+1)
  675. func = lambdify((x, y), uf(x, y), modules='numexpr')
  676. assert numpy.allclose(func(a, b), 2*a*b+1)
  677. def test_tensorflow_basic_math():
  678. if not tensorflow:
  679. skip("tensorflow not installed.")
  680. expr = Max(sin(x), Abs(1/(x+2)))
  681. func = lambdify(x, expr, modules="tensorflow")
  682. with tensorflow.compat.v1.Session() as s:
  683. a = tensorflow.constant(0, dtype=tensorflow.float32)
  684. assert func(a).eval(session=s) == 0.5
  685. def test_tensorflow_placeholders():
  686. if not tensorflow:
  687. skip("tensorflow not installed.")
  688. expr = Max(sin(x), Abs(1/(x+2)))
  689. func = lambdify(x, expr, modules="tensorflow")
  690. with tensorflow.compat.v1.Session() as s:
  691. a = tensorflow.compat.v1.placeholder(dtype=tensorflow.float32)
  692. assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
  693. def test_tensorflow_variables():
  694. if not tensorflow:
  695. skip("tensorflow not installed.")
  696. expr = Max(sin(x), Abs(1/(x+2)))
  697. func = lambdify(x, expr, modules="tensorflow")
  698. with tensorflow.compat.v1.Session() as s:
  699. a = tensorflow.Variable(0, dtype=tensorflow.float32)
  700. s.run(a.initializer)
  701. assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5
  702. def test_tensorflow_logical_operations():
  703. if not tensorflow:
  704. skip("tensorflow not installed.")
  705. expr = Not(And(Or(x, y), y))
  706. func = lambdify([x, y], expr, modules="tensorflow")
  707. with tensorflow.compat.v1.Session() as s:
  708. assert func(False, True).eval(session=s) == False
  709. def test_tensorflow_piecewise():
  710. if not tensorflow:
  711. skip("tensorflow not installed.")
  712. expr = Piecewise((0, Eq(x,0)), (-1, x < 0), (1, x > 0))
  713. func = lambdify(x, expr, modules="tensorflow")
  714. with tensorflow.compat.v1.Session() as s:
  715. assert func(-1).eval(session=s) == -1
  716. assert func(0).eval(session=s) == 0
  717. assert func(1).eval(session=s) == 1
  718. def test_tensorflow_multi_max():
  719. if not tensorflow:
  720. skip("tensorflow not installed.")
  721. expr = Max(x, -x, x**2)
  722. func = lambdify(x, expr, modules="tensorflow")
  723. with tensorflow.compat.v1.Session() as s:
  724. assert func(-2).eval(session=s) == 4
  725. def test_tensorflow_multi_min():
  726. if not tensorflow:
  727. skip("tensorflow not installed.")
  728. expr = Min(x, -x, x**2)
  729. func = lambdify(x, expr, modules="tensorflow")
  730. with tensorflow.compat.v1.Session() as s:
  731. assert func(-2).eval(session=s) == -2
  732. def test_tensorflow_relational():
  733. if not tensorflow:
  734. skip("tensorflow not installed.")
  735. expr = x >= 0
  736. func = lambdify(x, expr, modules="tensorflow")
  737. with tensorflow.compat.v1.Session() as s:
  738. assert func(1).eval(session=s) == True
  739. def test_tensorflow_complexes():
  740. if not tensorflow:
  741. skip("tensorflow not installed")
  742. func1 = lambdify(x, re(x), modules="tensorflow")
  743. func2 = lambdify(x, im(x), modules="tensorflow")
  744. func3 = lambdify(x, Abs(x), modules="tensorflow")
  745. func4 = lambdify(x, arg(x), modules="tensorflow")
  746. with tensorflow.compat.v1.Session() as s:
  747. # For versions before
  748. # https://github.com/tensorflow/tensorflow/issues/30029
  749. # resolved, using Python numeric types may not work
  750. a = tensorflow.constant(1+2j)
  751. assert func1(a).eval(session=s) == 1
  752. assert func2(a).eval(session=s) == 2
  753. tensorflow_result = func3(a).eval(session=s)
  754. sympy_result = Abs(1 + 2j).evalf()
  755. assert abs(tensorflow_result-sympy_result) < 10**-6
  756. tensorflow_result = func4(a).eval(session=s)
  757. sympy_result = arg(1 + 2j).evalf()
  758. assert abs(tensorflow_result-sympy_result) < 10**-6
  759. def test_tensorflow_array_arg():
  760. # Test for issue 14655 (tensorflow part)
  761. if not tensorflow:
  762. skip("tensorflow not installed.")
  763. f = lambdify([[x, y]], x*x + y, 'tensorflow')
  764. with tensorflow.compat.v1.Session() as s:
  765. fcall = f(tensorflow.constant([2.0, 1.0]))
  766. assert fcall.eval(session=s) == 5.0
  767. #================== Test symbolic ==================================
  768. def test_sym_single_arg():
  769. f = lambdify(x, x * y)
  770. assert f(z) == z * y
  771. def test_sym_list_args():
  772. f = lambdify([x, y], x + y + z)
  773. assert f(1, 2) == 3 + z
  774. def test_sym_integral():
  775. f = Lambda(x, exp(-x**2))
  776. l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules="sympy")
  777. assert l(y) == Integral(exp(-y**2), (y, -oo, oo))
  778. assert l(y).doit() == sqrt(pi)
  779. def test_namespace_order():
  780. # lambdify had a bug, such that module dictionaries or cached module
  781. # dictionaries would pull earlier namespaces into themselves.
  782. # Because the module dictionaries form the namespace of the
  783. # generated lambda, this meant that the behavior of a previously
  784. # generated lambda function could change as a result of later calls
  785. # to lambdify.
  786. n1 = {'f': lambda x: 'first f'}
  787. n2 = {'f': lambda x: 'second f',
  788. 'g': lambda x: 'function g'}
  789. f = sympy.Function('f')
  790. g = sympy.Function('g')
  791. if1 = lambdify(x, f(x), modules=(n1, "sympy"))
  792. assert if1(1) == 'first f'
  793. if2 = lambdify(x, g(x), modules=(n2, "sympy"))
  794. # previously gave 'second f'
  795. assert if1(1) == 'first f'
  796. assert if2(1) == 'function g'
  797. def test_imps():
  798. # Here we check if the default returned functions are anonymous - in
  799. # the sense that we can have more than one function with the same name
  800. f = implemented_function('f', lambda x: 2*x)
  801. g = implemented_function('f', lambda x: math.sqrt(x))
  802. l1 = lambdify(x, f(x))
  803. l2 = lambdify(x, g(x))
  804. assert str(f(x)) == str(g(x))
  805. assert l1(3) == 6
  806. assert l2(3) == math.sqrt(3)
  807. # check that we can pass in a Function as input
  808. func = sympy.Function('myfunc')
  809. assert not hasattr(func, '_imp_')
  810. my_f = implemented_function(func, lambda x: 2*x)
  811. assert hasattr(my_f, '_imp_')
  812. # Error for functions with same name and different implementation
  813. f2 = implemented_function("f", lambda x: x + 101)
  814. raises(ValueError, lambda: lambdify(x, f(f2(x))))
  815. def test_imps_errors():
  816. # Test errors that implemented functions can return, and still be able to
  817. # form expressions.
  818. # See: https://github.com/sympy/sympy/issues/10810
  819. #
  820. # XXX: Removed AttributeError here. This test was added due to issue 10810
  821. # but that issue was about ValueError. It doesn't seem reasonable to
  822. # "support" catching AttributeError in the same context...
  823. for val, error_class in product((0, 0., 2, 2.0), (TypeError, ValueError)):
  824. def myfunc(a):
  825. if a == 0:
  826. raise error_class
  827. return 1
  828. f = implemented_function('f', myfunc)
  829. expr = f(val)
  830. assert expr == f(val)
  831. def test_imps_wrong_args():
  832. raises(ValueError, lambda: implemented_function(sin, lambda x: x))
  833. def test_lambdify_imps():
  834. # Test lambdify with implemented functions
  835. # first test basic (sympy) lambdify
  836. f = sympy.cos
  837. assert lambdify(x, f(x))(0) == 1
  838. assert lambdify(x, 1 + f(x))(0) == 2
  839. assert lambdify((x, y), y + f(x))(0, 1) == 2
  840. # make an implemented function and test
  841. f = implemented_function("f", lambda x: x + 100)
  842. assert lambdify(x, f(x))(0) == 100
  843. assert lambdify(x, 1 + f(x))(0) == 101
  844. assert lambdify((x, y), y + f(x))(0, 1) == 101
  845. # Can also handle tuples, lists, dicts as expressions
  846. lam = lambdify(x, (f(x), x))
  847. assert lam(3) == (103, 3)
  848. lam = lambdify(x, [f(x), x])
  849. assert lam(3) == [103, 3]
  850. lam = lambdify(x, [f(x), (f(x), x)])
  851. assert lam(3) == [103, (103, 3)]
  852. lam = lambdify(x, {f(x): x})
  853. assert lam(3) == {103: 3}
  854. lam = lambdify(x, {f(x): x})
  855. assert lam(3) == {103: 3}
  856. lam = lambdify(x, {x: f(x)})
  857. assert lam(3) == {3: 103}
  858. # Check that imp preferred to other namespaces by default
  859. d = {'f': lambda x: x + 99}
  860. lam = lambdify(x, f(x), d)
  861. assert lam(3) == 103
  862. # Unless flag passed
  863. lam = lambdify(x, f(x), d, use_imps=False)
  864. assert lam(3) == 102
  865. def test_dummification():
  866. t = symbols('t')
  867. F = Function('F')
  868. G = Function('G')
  869. #"\alpha" is not a valid Python variable name
  870. #lambdify should sub in a dummy for it, and return
  871. #without a syntax error
  872. alpha = symbols(r'\alpha')
  873. some_expr = 2 * F(t)**2 / G(t)
  874. lam = lambdify((F(t), G(t)), some_expr)
  875. assert lam(3, 9) == 2
  876. lam = lambdify(sin(t), 2 * sin(t)**2)
  877. assert lam(F(t)) == 2 * F(t)**2
  878. #Test that \alpha was properly dummified
  879. lam = lambdify((alpha, t), 2*alpha + t)
  880. assert lam(2, 1) == 5
  881. raises(SyntaxError, lambda: lambdify(F(t) * G(t), F(t) * G(t) + 5))
  882. raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5))
  883. raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5))
  884. def test_lambdify__arguments_with_invalid_python_identifiers():
  885. # see sympy/sympy#26690
  886. N = CoordSys3D('N')
  887. xn, yn, zn = N.base_scalars()
  888. expr = xn + yn
  889. f = lambdify([xn, yn], expr)
  890. res = f(0.2, 0.3)
  891. ref = 0.2 + 0.3
  892. assert abs(res-ref) < 1e-15
  893. def test_curly_matrix_symbol():
  894. # Issue #15009
  895. curlyv = sympy.MatrixSymbol("{v}", 2, 1)
  896. lam = lambdify(curlyv, curlyv)
  897. assert lam(1)==1
  898. lam = lambdify(curlyv, curlyv, dummify=True)
  899. assert lam(1)==1
  900. def test_python_keywords():
  901. # Test for issue 7452. The automatic dummification should ensure use of
  902. # Python reserved keywords as symbol names will create valid lambda
  903. # functions. This is an additional regression test.
  904. python_if = symbols('if')
  905. expr = python_if / 2
  906. f = lambdify(python_if, expr)
  907. assert f(4.0) == 2.0
  908. def test_lambdify_docstring():
  909. func = lambdify((w, x, y, z), w + x + y + z)
  910. ref = (
  911. "Created with lambdify. Signature:\n\n"
  912. "func(w, x, y, z)\n\n"
  913. "Expression:\n\n"
  914. "w + x + y + z"
  915. ).splitlines()
  916. assert func.__doc__.splitlines()[:len(ref)] == ref
  917. syms = symbols('a1:26')
  918. func = lambdify(syms, sum(syms))
  919. ref = (
  920. "Created with lambdify. Signature:\n\n"
  921. "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n"
  922. " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n"
  923. "Expression:\n\n"
  924. "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..."
  925. ).splitlines()
  926. assert func.__doc__.splitlines()[:len(ref)] == ref
  927. def test_lambdify_linecache():
  928. func = lambdify(x, x + 1)
  929. source = 'def _lambdifygenerated(x):\n return x + 1\n'
  930. assert inspect.getsource(func) == source
  931. filename = inspect.getsourcefile(func)
  932. assert filename.startswith('<lambdifygenerated-')
  933. assert filename in linecache.cache
  934. assert linecache.cache[filename] == (len(source), None, source.splitlines(True), filename)
  935. del func
  936. gc.collect()
  937. assert filename not in linecache.cache
  938. #================== Test special printers ==========================
  939. def test_special_printers():
  940. from sympy.printing.lambdarepr import IntervalPrinter
  941. def intervalrepr(expr):
  942. return IntervalPrinter().doprint(expr)
  943. expr = sqrt(sqrt(2) + sqrt(3)) + S.Half
  944. func0 = lambdify((), expr, modules="mpmath", printer=intervalrepr)
  945. func1 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter)
  946. func2 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter())
  947. mpi = type(mpmath.mpi(1, 2))
  948. assert isinstance(func0(), mpi)
  949. assert isinstance(func1(), mpi)
  950. assert isinstance(func2(), mpi)
  951. # To check Is lambdify loggamma works for mpmath or not
  952. exp1 = lambdify(x, loggamma(x), 'mpmath')(5)
  953. exp2 = lambdify(x, loggamma(x), 'mpmath')(1.8)
  954. exp3 = lambdify(x, loggamma(x), 'mpmath')(15)
  955. exp_ls = [exp1, exp2, exp3]
  956. sol1 = mpmath.loggamma(5)
  957. sol2 = mpmath.loggamma(1.8)
  958. sol3 = mpmath.loggamma(15)
  959. sol_ls = [sol1, sol2, sol3]
  960. assert exp_ls == sol_ls
  961. def test_true_false():
  962. # We want exact is comparison here, not just ==
  963. assert lambdify([], true)() is True
  964. assert lambdify([], false)() is False
  965. def test_issue_2790():
  966. assert lambdify((x, (y, z)), x + y)(1, (2, 4)) == 3
  967. assert lambdify((x, (y, (w, z))), w + x + y + z)(1, (2, (3, 4))) == 10
  968. assert lambdify(x, x + 1, dummify=False)(1) == 2
  969. def test_issue_12092():
  970. f = implemented_function('f', lambda x: x**2)
  971. assert f(f(2)).evalf() == Float(16)
  972. def test_issue_14911():
  973. class Variable(sympy.Symbol):
  974. def _sympystr(self, printer):
  975. return printer.doprint(self.name)
  976. _lambdacode = _sympystr
  977. _numpycode = _sympystr
  978. x = Variable('x')
  979. y = 2 * x
  980. code = LambdaPrinter().doprint(y)
  981. assert code.replace(' ', '') == '2*x'
  982. def test_ITE():
  983. assert lambdify((x, y, z), ITE(x, y, z))(True, 5, 3) == 5
  984. assert lambdify((x, y, z), ITE(x, y, z))(False, 5, 3) == 3
  985. def test_Min_Max():
  986. # see gh-10375
  987. assert lambdify((x, y, z), Min(x, y, z))(1, 2, 3) == 1
  988. assert lambdify((x, y, z), Max(x, y, z))(1, 2, 3) == 3
  989. def test_amin_amax_minimum_maximum():
  990. if not numpy:
  991. skip("numpy not installed")
  992. a234 = numpy.array([2, 3, 4])
  993. a152 = numpy.array([1, 5, 2])
  994. a254 = numpy.array([2, 5, 4])
  995. a132 = numpy.array([1, 3, 2])
  996. # 2 args
  997. assert numpy.all(lambdify((x, y), maximum(x, y))(a234, a152) == a254)
  998. assert numpy.all(lambdify((x, y), minimum(x, y))(a234, a152) == a132)
  999. # 3 args
  1000. assert numpy.all(lambdify((x, y, z), maximum(x, y, z))(a234, a152, a234) == a254)
  1001. assert numpy.all(lambdify((x, y, z), minimum(x, y, z))(a234, a152, a234) == a132)
  1002. # 1 arg
  1003. assert numpy.all(lambdify((x,), maximum(x))(a234) == a234)
  1004. assert numpy.all(lambdify((x,), minimum(x))(a234) == a234)
  1005. # 4 args, mixed length
  1006. assert numpy.all(lambdify((x, y, z, w), maximum(x, y, z, w))(a234, a152, a234, 3) == [3, 5, 4])
  1007. assert numpy.all(lambdify((x, y, z, w), minimum(x, y, z, w))(a234, a152, a234, 2) == [1, 2, 2])
  1008. # amin & amax
  1009. assert lambdify((x, y), [amin(x), amax(y)])(a234, a152) == [2, 5]
  1010. A = numpy.array([
  1011. [0, 4, 8],
  1012. [1, 5, 9],
  1013. [2, 6, 10],
  1014. ])
  1015. min_, max_ = lambdify((x,), [amin(x, axis=0), amax(x, axis=1)])(A)
  1016. assert numpy.all(min_ == numpy.amin(A, axis=0))
  1017. assert numpy.all(max_ == numpy.amax(A, axis=1))
  1018. # see gh-25659
  1019. assert numpy.all(lambdify((x, y), Max(x, y))([1, 2, 3], [3, 2, 1]) == [3, 2, 3])
  1020. assert numpy.all(lambdify((x), Min(2, x))([1, 2, 3]) == [1, 2, 2])
  1021. def test_Indexed():
  1022. # Issue #10934
  1023. if not numpy:
  1024. skip("numpy not installed")
  1025. a = IndexedBase('a')
  1026. i, j = symbols('i j')
  1027. b = numpy.array([[1, 2], [3, 4]])
  1028. assert lambdify(a, Sum(a[x, y], (x, 0, 1), (y, 0, 1)))(b) == 10
  1029. def test_Sum():
  1030. e = Sum(z, (y, 0, x), (x, 0, 10))
  1031. ref = 66*z
  1032. assert e.doit() == ref
  1033. assert lambdify([z], e)(7) == ref.subs(z, 7)
  1034. def test_Idx():
  1035. # Issue 26888
  1036. a = IndexedBase('a')
  1037. i = Idx('i')
  1038. b = [1,2,3]
  1039. assert lambdify([a, i], a[i])(b, 2) == 3
  1040. def test_issue_12173():
  1041. #test for issue 12173
  1042. expr1 = lambdify((x, y), uppergamma(x, y),"mpmath")(1, 2)
  1043. expr2 = lambdify((x, y), lowergamma(x, y),"mpmath")(1, 2)
  1044. assert expr1 == uppergamma(1, 2).evalf()
  1045. assert expr2 == lowergamma(1, 2).evalf()
  1046. def test_issue_13642():
  1047. if not numpy:
  1048. skip("numpy not installed")
  1049. f = lambdify(x, sinc(x))
  1050. assert Abs(f(1) - sinc(1)).n() < 1e-15
  1051. def test_sinc_mpmath():
  1052. f = lambdify(x, sinc(x), "mpmath")
  1053. assert Abs(f(1) - sinc(1)).n() < 1e-15
  1054. def test_lambdify_dummy_arg():
  1055. d1 = Dummy()
  1056. f1 = lambdify(d1, d1 + 1, dummify=False)
  1057. assert f1(2) == 3
  1058. f1b = lambdify(d1, d1 + 1)
  1059. assert f1b(2) == 3
  1060. d2 = Dummy('x')
  1061. f2 = lambdify(d2, d2 + 1)
  1062. assert f2(2) == 3
  1063. f3 = lambdify([[d2]], d2 + 1)
  1064. assert f3([2]) == 3
  1065. def test_lambdify_mixed_symbol_dummy_args():
  1066. d = Dummy()
  1067. # Contrived example of name clash
  1068. dsym = symbols(str(d))
  1069. f = lambdify([d, dsym], d - dsym)
  1070. assert f(4, 1) == 3
  1071. def test_numpy_array_arg():
  1072. # Test for issue 14655 (numpy part)
  1073. if not numpy:
  1074. skip("numpy not installed")
  1075. f = lambdify([[x, y]], x*x + y, 'numpy')
  1076. assert f(numpy.array([2.0, 1.0])) == 5
  1077. def test_scipy_fns():
  1078. if not scipy:
  1079. skip("scipy not installed")
  1080. single_arg_sympy_fns = [Ei, erf, erfc, factorial, gamma, loggamma, digamma, Si, Ci]
  1081. single_arg_scipy_fns = [scipy.special.expi, scipy.special.erf, scipy.special.erfc,
  1082. scipy.special.factorial, scipy.special.gamma, scipy.special.gammaln,
  1083. scipy.special.psi, scipy.special.sici, scipy.special.sici]
  1084. numpy.random.seed(0)
  1085. for (sympy_fn, scipy_fn) in zip(single_arg_sympy_fns, single_arg_scipy_fns):
  1086. f = lambdify(x, sympy_fn(x), modules="scipy")
  1087. for i in range(20):
  1088. tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
  1089. # SciPy thinks that factorial(z) is 0 when re(z) < 0 and
  1090. # does not support complex numbers.
  1091. # SymPy does not think so.
  1092. if sympy_fn == factorial:
  1093. tv = numpy.abs(tv)
  1094. # SciPy supports gammaln for real arguments only,
  1095. # and there is also a branch cut along the negative real axis
  1096. if sympy_fn == loggamma:
  1097. tv = numpy.abs(tv)
  1098. # SymPy's digamma evaluates as polygamma(0, z)
  1099. # which SciPy supports for real arguments only
  1100. if sympy_fn == digamma:
  1101. tv = numpy.real(tv)
  1102. sympy_result = sympy_fn(tv).evalf()
  1103. scipy_result = scipy_fn(tv)
  1104. # SciPy's sici returns a tuple with both Si and Ci present in it
  1105. # which needs to be unpacked
  1106. if sympy_fn == Si:
  1107. scipy_result = scipy_fn(tv)[0]
  1108. if sympy_fn == Ci:
  1109. scipy_result = scipy_fn(tv)[1]
  1110. assert abs(f(tv) - sympy_result) < 1e-13*(1 + abs(sympy_result))
  1111. assert abs(f(tv) - scipy_result) < 1e-13*(1 + abs(sympy_result))
  1112. double_arg_sympy_fns = [RisingFactorial, besselj, bessely, besseli,
  1113. besselk, polygamma]
  1114. double_arg_scipy_fns = [scipy.special.poch, scipy.special.jv,
  1115. scipy.special.yv, scipy.special.iv, scipy.special.kv, scipy.special.polygamma]
  1116. for (sympy_fn, scipy_fn) in zip(double_arg_sympy_fns, double_arg_scipy_fns):
  1117. f = lambdify((x, y), sympy_fn(x, y), modules="scipy")
  1118. for i in range(20):
  1119. # SciPy supports only real orders of Bessel functions
  1120. tv1 = numpy.random.uniform(-10, 10)
  1121. tv2 = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
  1122. # SciPy requires a real valued 2nd argument for: poch, polygamma
  1123. if sympy_fn in (RisingFactorial, polygamma):
  1124. tv2 = numpy.real(tv2)
  1125. if sympy_fn == polygamma:
  1126. tv1 = abs(int(tv1)) # first argument to polygamma must be a non-negative integer.
  1127. sympy_result = sympy_fn(tv1, tv2).evalf()
  1128. assert abs(f(tv1, tv2) - sympy_result) < 1e-13*(1 + abs(sympy_result))
  1129. assert abs(f(tv1, tv2) - scipy_fn(tv1, tv2)) < 1e-13*(1 + abs(sympy_result))
  1130. def test_scipy_polys():
  1131. if not scipy:
  1132. skip("scipy not installed")
  1133. numpy.random.seed(0)
  1134. params = symbols('n k a b')
  1135. # list polynomials with the number of parameters
  1136. polys = [
  1137. (chebyshevt, 1),
  1138. (chebyshevu, 1),
  1139. (legendre, 1),
  1140. (hermite, 1),
  1141. (laguerre, 1),
  1142. (gegenbauer, 2),
  1143. (assoc_legendre, 2),
  1144. (assoc_laguerre, 2),
  1145. (jacobi, 3)
  1146. ]
  1147. msg = \
  1148. "The random test of the function {func} with the arguments " \
  1149. "{args} had failed because the SymPy result {sympy_result} " \
  1150. "and SciPy result {scipy_result} had failed to converge " \
  1151. "within the tolerance {tol} " \
  1152. "(Actual absolute difference : {diff})"
  1153. for sympy_fn, num_params in polys:
  1154. args = params[:num_params] + (x,)
  1155. f = lambdify(args, sympy_fn(*args))
  1156. for _ in range(10):
  1157. tn = numpy.random.randint(3, 10)
  1158. tparams = tuple(numpy.random.uniform(0, 5, size=num_params-1))
  1159. tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5)
  1160. # SciPy supports hermite for real arguments only
  1161. if sympy_fn == hermite:
  1162. tv = numpy.real(tv)
  1163. # assoc_legendre needs x in (-1, 1) and integer param at most n
  1164. if sympy_fn == assoc_legendre:
  1165. tv = numpy.random.uniform(-1, 1)
  1166. tparams = tuple(numpy.random.randint(1, tn, size=1))
  1167. vals = (tn,) + tparams + (tv,)
  1168. scipy_result = f(*vals)
  1169. sympy_result = sympy_fn(*vals).evalf()
  1170. atol = 1e-9*(1 + abs(sympy_result))
  1171. diff = abs(scipy_result - sympy_result)
  1172. try:
  1173. assert diff < atol
  1174. except TypeError:
  1175. raise AssertionError(
  1176. msg.format(
  1177. func=repr(sympy_fn),
  1178. args=repr(vals),
  1179. sympy_result=repr(sympy_result),
  1180. scipy_result=repr(scipy_result),
  1181. diff=diff,
  1182. tol=atol)
  1183. )
  1184. def test_lambdify_inspect():
  1185. f = lambdify(x, x**2)
  1186. # Test that inspect.getsource works but don't hard-code implementation
  1187. # details
  1188. assert 'x**2' in inspect.getsource(f)
  1189. def test_issue_14941():
  1190. x, y = Dummy(), Dummy()
  1191. # test dict
  1192. f1 = lambdify([x, y], {x: 3, y: 3}, 'sympy')
  1193. assert f1(2, 3) == {2: 3, 3: 3}
  1194. # test tuple
  1195. f2 = lambdify([x, y], (y, x), 'sympy')
  1196. assert f2(2, 3) == (3, 2)
  1197. f2b = lambdify([], (1,)) # gh-23224
  1198. assert f2b() == (1,)
  1199. # test list
  1200. f3 = lambdify([x, y], [y, x], 'sympy')
  1201. assert f3(2, 3) == [3, 2]
  1202. def test_lambdify_Derivative_arg_issue_16468():
  1203. f = Function('f')(x)
  1204. fx = f.diff()
  1205. assert lambdify((f, fx), f + fx)(10, 5) == 15
  1206. assert eval(lambdastr((f, fx), f/fx))(10, 5) == 2
  1207. raises(Exception, lambda:
  1208. eval(lambdastr((f, fx), f/fx, dummify=False)))
  1209. assert eval(lambdastr((f, fx), f/fx, dummify=True))(10, 5) == 2
  1210. assert eval(lambdastr((fx, f), f/fx, dummify=True))(S(10), 5) == S.Half
  1211. assert lambdify(fx, 1 + fx)(41) == 42
  1212. assert eval(lambdastr(fx, 1 + fx, dummify=True))(41) == 42
  1213. def test_lambdify_Derivative_zeta():
  1214. # This is related to gh-11802 (and to lesser extent gh-26663)
  1215. expr1 = zeta(x).diff(x, evaluate=False)
  1216. f1 = lambdify(x, expr1, modules=['mpmath'])
  1217. ans1 = f1(2)
  1218. ref1 = (zeta(2+1e-8).evalf()-zeta(2).evalf())/1e-8
  1219. assert abs(ans1 - ref1)/abs(ref1) < 1e-7
  1220. expr2 = zeta(x**2).diff(x)
  1221. f2 = lambdify(x, expr2, modules=['mpmath'])
  1222. ans2 = f2(2**0.5)
  1223. ref2 = 2*2**0.5*ref1
  1224. assert abs(ans2-ref2)/abs(ref2) < 1e-7
  1225. def test_lambdify_Derivative_custom_printer():
  1226. func1 = Function('func1')
  1227. func2 = Function('func2')
  1228. class MyPrinter(NumPyPrinter):
  1229. def _print_Derivative_func1(self, args, seq_orders):
  1230. arg, = args
  1231. order, = seq_orders
  1232. return '42'
  1233. expr1 = func1(x).diff(x)
  1234. raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr1))
  1235. f1 = lambdify([x], expr1, printer=MyPrinter)
  1236. assert f1(7) == 42
  1237. expr2 = func2(x).diff(x)
  1238. raises(PrintMethodNotImplementedError, lambda: lambdify([x], expr2, printer=MyPrinter))
  1239. def test_lambdify_derivative_and_functions_as_arguments():
  1240. # see: https://github.com/sympy/sympy/issues/26663#issuecomment-2157179517
  1241. t, a, b = symbols('t, a, b')
  1242. f = Function('f')(t)
  1243. args = f.diff(t, 2), f.diff(t), f, a, b
  1244. expr1 = a*f.diff(t, 2) + b*f.diff(t) + a*b*f + a**2
  1245. num_args = 2.0, 3.0, 4.0, 5.0, 6.0
  1246. ref1 = 5*2 + 6*3 + 5*6*4 + 5**2
  1247. expr2 = a*f.diff(t, 2) + b*f.diff(t) - a*b*f + b**2 - a**2
  1248. ref2 = 5*2 + 6*3 - 5*6*4 + 6**2 - 5**2
  1249. for dummify, _cse in product([False, None, True], [False, True]):
  1250. func1 = lambdify(args, expr1, cse=_cse, dummify=dummify)
  1251. res1 = func1(*num_args)
  1252. assert abs(res1 - ref1) < 1e-12
  1253. func12 = lambdify(args, [expr1, expr2], cse=_cse, dummify=dummify)
  1254. res12 = func12(*num_args)
  1255. assert len(res12) == 2
  1256. assert abs(res12[0] - ref1) < 1e-12
  1257. assert abs(res12[1] - ref2) < 1e-12
  1258. def test_imag_real():
  1259. f_re = lambdify([z], sympy.re(z))
  1260. val = 3+2j
  1261. assert f_re(val) == val.real
  1262. f_im = lambdify([z], sympy.im(z)) # see #15400
  1263. assert f_im(val) == val.imag
  1264. def test_MatrixSymbol_issue_15578():
  1265. if not numpy:
  1266. skip("numpy not installed")
  1267. A = MatrixSymbol('A', 2, 2)
  1268. A0 = numpy.array([[1, 2], [3, 4]])
  1269. f = lambdify(A, A**(-1))
  1270. assert numpy.allclose(f(A0), numpy.array([[-2., 1.], [1.5, -0.5]]))
  1271. g = lambdify(A, A**3)
  1272. assert numpy.allclose(g(A0), numpy.array([[37, 54], [81, 118]]))
  1273. def test_issue_15654():
  1274. if not scipy:
  1275. skip("scipy not installed")
  1276. from sympy.abc import n, l, r, Z
  1277. from sympy.physics import hydrogen
  1278. nv, lv, rv, Zv = 1, 0, 3, 1
  1279. sympy_value = hydrogen.R_nl(nv, lv, rv, Zv).evalf()
  1280. f = lambdify((n, l, r, Z), hydrogen.R_nl(n, l, r, Z))
  1281. scipy_value = f(nv, lv, rv, Zv)
  1282. assert abs(sympy_value - scipy_value) < 1e-15
  1283. def test_issue_15827():
  1284. if not numpy:
  1285. skip("numpy not installed")
  1286. A = MatrixSymbol("A", 3, 3)
  1287. B = MatrixSymbol("B", 2, 3)
  1288. C = MatrixSymbol("C", 3, 4)
  1289. D = MatrixSymbol("D", 4, 5)
  1290. k=symbols("k")
  1291. f = lambdify(A, (2*k)*A)
  1292. g = lambdify(A, (2+k)*A)
  1293. h = lambdify(A, 2*A)
  1294. i = lambdify((B, C, D), 2*B*C*D)
  1295. assert numpy.array_equal(f(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
  1296. numpy.array([[2*k, 4*k, 6*k], [2*k, 4*k, 6*k], [2*k, 4*k, 6*k]], dtype=object))
  1297. assert numpy.array_equal(g(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
  1298. numpy.array([[k + 2, 2*k + 4, 3*k + 6], [k + 2, 2*k + 4, 3*k + 6], \
  1299. [k + 2, 2*k + 4, 3*k + 6]], dtype=object))
  1300. assert numpy.array_equal(h(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \
  1301. numpy.array([[2, 4, 6], [2, 4, 6], [2, 4, 6]]))
  1302. assert numpy.array_equal(i(numpy.array([[1, 2, 3], [1, 2, 3]]), numpy.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]), \
  1303. numpy.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])), numpy.array([[ 120, 240, 360, 480, 600], \
  1304. [ 120, 240, 360, 480, 600]]))
  1305. def test_issue_16930():
  1306. if not scipy:
  1307. skip("scipy not installed")
  1308. x = symbols("x")
  1309. f = lambda x: S.GoldenRatio * x**2
  1310. f_ = lambdify(x, f(x), modules='scipy')
  1311. assert f_(1) == scipy.constants.golden_ratio
  1312. def test_issue_17898():
  1313. if not scipy:
  1314. skip("scipy not installed")
  1315. x = symbols("x")
  1316. f_ = lambdify([x], sympy.LambertW(x,-1), modules='scipy')
  1317. assert f_(0.1) == mpmath.lambertw(0.1, -1)
  1318. def test_issue_13167_21411():
  1319. if not numpy:
  1320. skip("numpy not installed")
  1321. f1 = lambdify(x, sympy.Heaviside(x))
  1322. f2 = lambdify(x, sympy.Heaviside(x, 1))
  1323. res1 = f1([-1, 0, 1])
  1324. res2 = f2([-1, 0, 1])
  1325. assert Abs(res1[0]).n() < 1e-15 # First functionality: only one argument passed
  1326. assert Abs(res1[1] - 1/2).n() < 1e-15
  1327. assert Abs(res1[2] - 1).n() < 1e-15
  1328. assert Abs(res2[0]).n() < 1e-15 # Second functionality: two arguments passed
  1329. assert Abs(res2[1] - 1).n() < 1e-15
  1330. assert Abs(res2[2] - 1).n() < 1e-15
  1331. def test_single_e():
  1332. f = lambdify(x, E)
  1333. assert f(23) == exp(1.0)
  1334. def test_issue_16536():
  1335. if not scipy:
  1336. skip("scipy not installed")
  1337. a = symbols('a')
  1338. f1 = lowergamma(a, x)
  1339. F = lambdify((a, x), f1, modules='scipy')
  1340. assert abs(lowergamma(1, 3) - F(1, 3)) <= 1e-10
  1341. f2 = uppergamma(a, x)
  1342. F = lambdify((a, x), f2, modules='scipy')
  1343. assert abs(uppergamma(1, 3) - F(1, 3)) <= 1e-10
  1344. def test_issue_22726():
  1345. if not numpy:
  1346. skip("numpy not installed")
  1347. x1, x2 = symbols('x1 x2')
  1348. f = Max(S.Zero, Min(x1, x2))
  1349. g = derive_by_array(f, (x1, x2))
  1350. G = lambdify((x1, x2), g, modules='numpy')
  1351. point = {x1: 1, x2: 2}
  1352. assert (abs(g.subs(point) - G(*point.values())) <= 1e-10).all()
  1353. def test_issue_22739():
  1354. if not numpy:
  1355. skip("numpy not installed")
  1356. x1, x2 = symbols('x1 x2')
  1357. f = Heaviside(Min(x1, x2))
  1358. F = lambdify((x1, x2), f, modules='numpy')
  1359. point = {x1: 1, x2: 2}
  1360. assert abs(f.subs(point) - F(*point.values())) <= 1e-10
  1361. def test_issue_22992():
  1362. if not numpy:
  1363. skip("numpy not installed")
  1364. a, t = symbols('a t')
  1365. expr = a*(log(cot(t/2)) - cos(t))
  1366. F = lambdify([a, t], expr, 'numpy')
  1367. point = {a: 10, t: 2}
  1368. assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
  1369. # Standard math
  1370. F = lambdify([a, t], expr)
  1371. assert abs(expr.subs(point) - F(*point.values())) <= 1e-10
  1372. def test_issue_19764():
  1373. if not numpy:
  1374. skip("numpy not installed")
  1375. expr = Array([x, x**2])
  1376. f = lambdify(x, expr, 'numpy')
  1377. assert f(1).__class__ == numpy.ndarray
  1378. def test_issue_20070():
  1379. if not numba:
  1380. skip("numba not installed")
  1381. f = lambdify(x, sin(x), 'numpy')
  1382. assert numba.jit(f, nopython=True)(1)==0.8414709848078965
  1383. def test_fresnel_integrals_scipy():
  1384. if not scipy:
  1385. skip("scipy not installed")
  1386. f1 = fresnelc(x)
  1387. f2 = fresnels(x)
  1388. F1 = lambdify(x, f1, modules='scipy')
  1389. F2 = lambdify(x, f2, modules='scipy')
  1390. assert abs(fresnelc(1.3) - F1(1.3)) <= 1e-10
  1391. assert abs(fresnels(1.3) - F2(1.3)) <= 1e-10
  1392. def test_beta_scipy():
  1393. if not scipy:
  1394. skip("scipy not installed")
  1395. f = beta(x, y)
  1396. F = lambdify((x, y), f, modules='scipy')
  1397. assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
  1398. def test_beta_math():
  1399. f = beta(x, y)
  1400. F = lambdify((x, y), f, modules='math')
  1401. assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10
  1402. def test_betainc_scipy():
  1403. if not scipy:
  1404. skip("scipy not installed")
  1405. f = betainc(w, x, y, z)
  1406. F = lambdify((w, x, y, z), f, modules='scipy')
  1407. assert abs(betainc(1.4, 3.1, 0.1, 0.5) - F(1.4, 3.1, 0.1, 0.5)) <= 1e-10
  1408. def test_betainc_regularized_scipy():
  1409. if not scipy:
  1410. skip("scipy not installed")
  1411. f = betainc_regularized(w, x, y, z)
  1412. F = lambdify((w, x, y, z), f, modules='scipy')
  1413. assert abs(betainc_regularized(0.2, 3.5, 0.1, 1) - F(0.2, 3.5, 0.1, 1)) <= 1e-10
  1414. def test_numpy_special_math():
  1415. if not numpy:
  1416. skip("numpy not installed")
  1417. funcs = [expm1, log1p, exp2, log2, log10, hypot, logaddexp, logaddexp2]
  1418. for func in funcs:
  1419. if 2 in func.nargs:
  1420. expr = func(x, y)
  1421. args = (x, y)
  1422. num_args = (0.3, 0.4)
  1423. elif 1 in func.nargs:
  1424. expr = func(x)
  1425. args = (x,)
  1426. num_args = (0.3,)
  1427. else:
  1428. raise NotImplementedError("Need to handle other than unary & binary functions in test")
  1429. f = lambdify(args, expr)
  1430. result = f(*num_args)
  1431. reference = expr.subs(dict(zip(args, num_args))).evalf()
  1432. assert numpy.allclose(result, float(reference))
  1433. lae2 = lambdify((x, y), logaddexp2(log2(x), log2(y)))
  1434. assert abs(2.0**lae2(1e-50, 2.5e-50) - 3.5e-50) < 1e-62 # from NumPy's docstring
  1435. def test_scipy_special_math():
  1436. if not scipy:
  1437. skip("scipy not installed")
  1438. cm1 = lambdify((x,), cosm1(x), modules='scipy')
  1439. assert abs(cm1(1e-20) + 5e-41) < 1e-200
  1440. have_scipy_1_10plus = tuple(map(int, scipy.version.version.split('.')[:2])) >= (1, 10)
  1441. if have_scipy_1_10plus:
  1442. cm2 = lambdify((x, y), powm1(x, y), modules='scipy')
  1443. assert abs(cm2(1.2, 1e-9) - 1.82321557e-10) < 1e-17
  1444. def test_scipy_bernoulli():
  1445. if not scipy:
  1446. skip("scipy not installed")
  1447. bern = lambdify((x,), bernoulli(x), modules='scipy')
  1448. assert bern(1) == 0.5
  1449. def test_scipy_harmonic():
  1450. if not scipy:
  1451. skip("scipy not installed")
  1452. hn = lambdify((x,), harmonic(x), modules='scipy')
  1453. assert hn(2) == 1.5
  1454. hnm = lambdify((x, y), harmonic(x, y), modules='scipy')
  1455. assert hnm(2, 2) == 1.25
  1456. def test_cupy_array_arg():
  1457. if not cupy:
  1458. skip("CuPy not installed")
  1459. f = lambdify([[x, y]], x*x + y, 'cupy')
  1460. result = f(cupy.array([2.0, 1.0]))
  1461. assert result == 5
  1462. assert "cupy" in str(type(result))
  1463. def test_cupy_array_arg_using_numpy():
  1464. # numpy functions can be run on cupy arrays
  1465. # unclear if we can "officially" support this,
  1466. # depends on numpy __array_function__ support
  1467. if not cupy:
  1468. skip("CuPy not installed")
  1469. f = lambdify([[x, y]], x*x + y, 'numpy')
  1470. result = f(cupy.array([2.0, 1.0]))
  1471. assert result == 5
  1472. assert "cupy" in str(type(result))
  1473. def test_cupy_dotproduct():
  1474. if not cupy:
  1475. skip("CuPy not installed")
  1476. A = Matrix([x, y, z])
  1477. f1 = lambdify([x, y, z], DotProduct(A, A), modules='cupy')
  1478. f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
  1479. f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='cupy')
  1480. f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy')
  1481. assert f1(1, 2, 3) == \
  1482. f2(1, 2, 3) == \
  1483. f3(1, 2, 3) == \
  1484. f4(1, 2, 3) == \
  1485. cupy.array([14])
  1486. def test_jax_array_arg():
  1487. if not jax:
  1488. skip("JAX not installed")
  1489. f = lambdify([[x, y]], x*x + y, 'jax')
  1490. result = f(jax.numpy.array([2.0, 1.0]))
  1491. assert result == 5
  1492. assert "jax" in str(type(result))
  1493. def test_jax_array_arg_using_numpy():
  1494. if not jax:
  1495. skip("JAX not installed")
  1496. f = lambdify([[x, y]], x*x + y, 'numpy')
  1497. result = f(jax.numpy.array([2.0, 1.0]))
  1498. assert result == 5
  1499. assert "jax" in str(type(result))
  1500. def test_jax_dotproduct():
  1501. if not jax:
  1502. skip("JAX not installed")
  1503. A = Matrix([x, y, z])
  1504. f1 = lambdify([x, y, z], DotProduct(A, A), modules='jax')
  1505. f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
  1506. f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='jax')
  1507. f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax')
  1508. assert f1(1, 2, 3) == \
  1509. f2(1, 2, 3) == \
  1510. f3(1, 2, 3) == \
  1511. f4(1, 2, 3) == \
  1512. jax.numpy.array([14])
  1513. def test_lambdify_cse():
  1514. def no_op_cse(exprs):
  1515. return (), exprs
  1516. def dummy_cse(exprs):
  1517. from sympy.simplify.cse_main import cse
  1518. return cse(exprs, symbols=numbered_symbols(cls=Dummy))
  1519. def minmem(exprs):
  1520. from sympy.simplify.cse_main import cse_release_variables, cse
  1521. return cse(exprs, postprocess=cse_release_variables)
  1522. class Case:
  1523. def __init__(self, *, args, exprs, num_args, requires_numpy=False):
  1524. self.args = args
  1525. self.exprs = exprs
  1526. self.num_args = num_args
  1527. subs_dict = dict(zip(self.args, self.num_args))
  1528. self.ref = [e.subs(subs_dict).evalf() for e in exprs]
  1529. self.requires_numpy = requires_numpy
  1530. def lambdify(self, *, cse):
  1531. return lambdify(self.args, self.exprs, cse=cse)
  1532. def assertAllClose(self, result, *, abstol=1e-15, reltol=1e-15):
  1533. if self.requires_numpy:
  1534. assert all(numpy.allclose(result[i], numpy.asarray(r, dtype=float),
  1535. rtol=reltol, atol=abstol)
  1536. for i, r in enumerate(self.ref))
  1537. return
  1538. for i, r in enumerate(self.ref):
  1539. abs_err = abs(result[i] - r)
  1540. if r == 0:
  1541. assert abs_err < abstol
  1542. else:
  1543. assert abs_err/abs(r) < reltol
  1544. cases = [
  1545. Case(
  1546. args=(x, y, z),
  1547. exprs=[
  1548. x + y + z,
  1549. x + y - z,
  1550. 2*x + 2*y - z,
  1551. (x+y)**2 + (y+z)**2,
  1552. ],
  1553. num_args=(2., 3., 4.)
  1554. ),
  1555. Case(
  1556. args=(x, y, z),
  1557. exprs=[
  1558. x + sympy.Heaviside(x),
  1559. y + sympy.Heaviside(x),
  1560. z + sympy.Heaviside(x, 1),
  1561. z/sympy.Heaviside(x, 1)
  1562. ],
  1563. num_args=(0., 3., 4.)
  1564. ),
  1565. Case(
  1566. args=(x, y, z),
  1567. exprs=[
  1568. x + sinc(y),
  1569. y + sinc(y),
  1570. z - sinc(y)
  1571. ],
  1572. num_args=(0.1, 0.2, 0.3)
  1573. ),
  1574. Case(
  1575. args=(x, y, z),
  1576. exprs=[
  1577. Matrix([[x, x*y], [sin(z) + 4, x**z]]),
  1578. x*y+sin(z)-x**z,
  1579. Matrix([x*x, sin(z), x**z])
  1580. ],
  1581. num_args=(1.,2.,3.),
  1582. requires_numpy=True
  1583. ),
  1584. Case(
  1585. args=(x, y),
  1586. exprs=[(x + y - 1)**2, x, x + y,
  1587. (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)],
  1588. num_args=(1,2)
  1589. )
  1590. ]
  1591. for case in cases:
  1592. if not numpy and case.requires_numpy:
  1593. continue
  1594. for _cse in [False, True, minmem, no_op_cse, dummy_cse]:
  1595. f = case.lambdify(cse=_cse)
  1596. result = f(*case.num_args)
  1597. case.assertAllClose(result)
  1598. def test_issue_25288():
  1599. syms = numbered_symbols(cls=Dummy)
  1600. ok = lambdify(x, [x**2, sin(x**2)], cse=lambda e: cse(e, symbols=syms))(2)
  1601. assert ok
  1602. def test_deprecated_set():
  1603. with warns_deprecated_sympy():
  1604. lambdify({x, y}, x + y)
  1605. def test_issue_13881():
  1606. if not numpy:
  1607. skip("numpy not installed.")
  1608. X = MatrixSymbol('X', 3, 1)
  1609. f = lambdify(X, X.T*X, 'numpy')
  1610. assert f(numpy.array([1, 2, 3])) == 14
  1611. assert f(numpy.array([3, 2, 1])) == 14
  1612. f = lambdify(X, X*X.T, 'numpy')
  1613. assert f(numpy.array([1, 2, 3])) == 14
  1614. assert f(numpy.array([3, 2, 1])) == 14
  1615. f = lambdify(X, (X*X.T)*X, 'numpy')
  1616. arr1 = numpy.array([[1], [2], [3]])
  1617. arr2 = numpy.array([[14],[28],[42]])
  1618. assert numpy.array_equal(f(arr1), arr2)
  1619. def test_23536_lambdify_cse_dummy():
  1620. f = Function('x')(y)
  1621. g = Function('w')(y)
  1622. expr = z + (f**4 + g**5)*(f**3 + (g*f)**3)
  1623. expr = expr.expand()
  1624. eval_expr = lambdify(((f, g), z), expr, cse=True)
  1625. ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError
  1626. assert ans == 300.0 # not a list and value is 300
  1627. class LambdifyDocstringTestCase:
  1628. SIGNATURE = None
  1629. EXPR = None
  1630. SRC = None
  1631. def __init__(self, docstring_limit, expected_redacted):
  1632. self.docstring_limit = docstring_limit
  1633. self.expected_redacted = expected_redacted
  1634. @property
  1635. def expected_expr(self):
  1636. expr_redacted_msg = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
  1637. return self.EXPR if not self.expected_redacted else expr_redacted_msg
  1638. @property
  1639. def expected_src(self):
  1640. src_redacted_msg = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)"
  1641. return self.SRC if not self.expected_redacted else src_redacted_msg
  1642. @property
  1643. def expected_docstring(self):
  1644. expected_docstring = (
  1645. f'Created with lambdify. Signature:\n\n'
  1646. f'func({self.SIGNATURE})\n\n'
  1647. f'Expression:\n\n'
  1648. f'{self.expected_expr}\n\n'
  1649. f'Source code:\n\n'
  1650. f'{self.expected_src}\n\n'
  1651. f'Imported modules:\n\n'
  1652. )
  1653. return expected_docstring
  1654. def __len__(self):
  1655. return len(self.expected_docstring)
  1656. def __repr__(self):
  1657. return (
  1658. f'{self.__class__.__name__}('
  1659. f'docstring_limit={self.docstring_limit}, '
  1660. f'expected_redacted={self.expected_redacted})'
  1661. )
  1662. def test_lambdify_docstring_size_limit_simple_symbol():
  1663. class SimpleSymbolTestCase(LambdifyDocstringTestCase):
  1664. SIGNATURE = 'x'
  1665. EXPR = 'x'
  1666. SRC = (
  1667. 'def _lambdifygenerated(x):\n'
  1668. ' return x\n'
  1669. )
  1670. x = symbols('x')
  1671. test_cases = (
  1672. SimpleSymbolTestCase(docstring_limit=None, expected_redacted=False),
  1673. SimpleSymbolTestCase(docstring_limit=100, expected_redacted=False),
  1674. SimpleSymbolTestCase(docstring_limit=1, expected_redacted=False),
  1675. SimpleSymbolTestCase(docstring_limit=0, expected_redacted=True),
  1676. SimpleSymbolTestCase(docstring_limit=-1, expected_redacted=True),
  1677. )
  1678. for test_case in test_cases:
  1679. lambdified_expr = lambdify(
  1680. [x],
  1681. x,
  1682. 'sympy',
  1683. docstring_limit=test_case.docstring_limit,
  1684. )
  1685. assert lambdified_expr.__doc__ == test_case.expected_docstring
  1686. def test_lambdify_docstring_size_limit_nested_expr():
  1687. class ExprListTestCase(LambdifyDocstringTestCase):
  1688. SIGNATURE = 'x, y, z'
  1689. EXPR = (
  1690. '[x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z '
  1691. '+ 3*x*z**2 +...'
  1692. )
  1693. SRC = (
  1694. 'def _lambdifygenerated(x, y, z):\n'
  1695. ' return [x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
  1696. '+ 6*x*y*z + 3*x*z**2 + y**3 + 3*y**2*z + 3*y*z**2 + z**3]\n'
  1697. )
  1698. x, y, z = symbols('x, y, z')
  1699. expr = [x, [y], z, ((x + y + z)**3).expand()]
  1700. test_cases = (
  1701. ExprListTestCase(docstring_limit=None, expected_redacted=False),
  1702. ExprListTestCase(docstring_limit=200, expected_redacted=False),
  1703. ExprListTestCase(docstring_limit=50, expected_redacted=True),
  1704. ExprListTestCase(docstring_limit=0, expected_redacted=True),
  1705. ExprListTestCase(docstring_limit=-1, expected_redacted=True),
  1706. )
  1707. for test_case in test_cases:
  1708. lambdified_expr = lambdify(
  1709. [x, y, z],
  1710. expr,
  1711. 'sympy',
  1712. docstring_limit=test_case.docstring_limit,
  1713. )
  1714. assert lambdified_expr.__doc__ == test_case.expected_docstring
  1715. def test_lambdify_docstring_size_limit_matrix():
  1716. class MatrixTestCase(LambdifyDocstringTestCase):
  1717. SIGNATURE = 'x, y, z'
  1718. EXPR = (
  1719. 'Matrix([[0, x], [x + y + z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 '
  1720. '+ 6*x*y*z...'
  1721. )
  1722. SRC = (
  1723. 'def _lambdifygenerated(x, y, z):\n'
  1724. ' return ImmutableDenseMatrix([[0, x], [x + y + z, x**3 '
  1725. '+ 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z + 3*x*z**2 + y**3 '
  1726. '+ 3*y**2*z + 3*y*z**2 + z**3]])\n'
  1727. )
  1728. x, y, z = symbols('x, y, z')
  1729. expr = Matrix([[S.Zero, x], [x + y + z, ((x + y + z)**3).expand()]])
  1730. test_cases = (
  1731. MatrixTestCase(docstring_limit=None, expected_redacted=False),
  1732. MatrixTestCase(docstring_limit=200, expected_redacted=False),
  1733. MatrixTestCase(docstring_limit=50, expected_redacted=True),
  1734. MatrixTestCase(docstring_limit=0, expected_redacted=True),
  1735. MatrixTestCase(docstring_limit=-1, expected_redacted=True),
  1736. )
  1737. for test_case in test_cases:
  1738. lambdified_expr = lambdify(
  1739. [x, y, z],
  1740. expr,
  1741. 'sympy',
  1742. docstring_limit=test_case.docstring_limit,
  1743. )
  1744. assert lambdified_expr.__doc__ == test_case.expected_docstring
  1745. def test_lambdify_empty_tuple():
  1746. a = symbols("a")
  1747. expr = ((), (a,))
  1748. f = lambdify(a, expr)
  1749. result = f(1)
  1750. assert result == ((), (1,)), "Lambdify did not handle the empty tuple correctly."
  1751. def test_assoc_legendre_numerical_evaluation():
  1752. tol = 1e-10
  1753. sympy_result_integer = assoc_legendre(1, 1/2, 0.1).evalf()
  1754. sympy_result_complex = assoc_legendre(2, 1, 3).evalf()
  1755. mpmath_result_integer = -0.474572528387641
  1756. mpmath_result_complex = -25.45584412271571*I
  1757. assert all_close(sympy_result_integer, mpmath_result_integer, tol)
  1758. assert all_close(sympy_result_complex, mpmath_result_complex, tol)
  1759. def test_Piecewise():
  1760. modules = [math]
  1761. if numpy:
  1762. modules.append('numpy')
  1763. for mod in modules:
  1764. # test isinf
  1765. f = lambdify(x, Piecewise((7.0, isinf(x)), (3.0, True)), mod)
  1766. assert f(+float('inf')) == +7.0
  1767. assert f(-float('inf')) == +7.0
  1768. assert f(42.) == 3.0
  1769. f2 = lambdify(x, Piecewise((7.0*sign(x), isinf(x)), (3.0, True)), mod)
  1770. assert f2(+float('inf')) == +7.0
  1771. assert f2(-float('inf')) == -7.0
  1772. assert f2(42.) == 3.0
  1773. # test isnan (gh-26784)
  1774. g = lambdify(x, Piecewise((7.0, isnan(x)), (3.0, True)), mod)
  1775. assert g(float('nan')) == 7.0
  1776. assert g(42.) == 3.0
  1777. def test_array_symbol():
  1778. if not numpy:
  1779. skip("numpy not installed.")
  1780. a = ArraySymbol('a', (3,))
  1781. f = lambdify((a), a)
  1782. assert numpy.all(f(numpy.array([1,2,3])) == numpy.array([1,2,3]))