test_pycode.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. from sympy import Not
  2. from sympy.codegen import Assignment
  3. from sympy.codegen.ast import none
  4. from sympy.codegen.cfunctions import expm1, log1p
  5. from sympy.codegen.scipy_nodes import cosm1
  6. from sympy.codegen.matrix_nodes import MatrixSolve
  7. from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow
  8. from sympy.core.function import Derivative
  9. from sympy.core.numbers import pi
  10. from sympy.core.singleton import S
  11. from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec, log, sin, cos, tan, asin, atan, sinh, cosh, tanh, asinh, acosh, atanh
  12. from sympy.functions.elementary.trigonometric import atan2
  13. from sympy.logic import And, Or
  14. from sympy.matrices import SparseMatrix, MatrixSymbol, Identity
  15. from sympy.printing.codeprinter import PrintMethodNotImplementedError
  16. from sympy.printing.pycode import (
  17. MpmathPrinter, CmathPrinter, PythonCodePrinter, pycode, SymPyPrinter
  18. )
  19. from sympy.printing.tensorflow import TensorflowPrinter
  20. from sympy.printing.numpy import NumPyPrinter, SciPyPrinter
  21. from sympy.testing.pytest import raises, skip
  22. from sympy.tensor import IndexedBase, Idx
  23. from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray
  24. from sympy.external import import_module
  25. from sympy.functions.special.gamma_functions import loggamma
  26. x, y, z = symbols('x y z')
  27. p = IndexedBase("p")
  28. def test_PythonCodePrinter():
  29. prntr = PythonCodePrinter()
  30. assert not prntr.module_imports
  31. assert prntr.doprint(x**y) == 'x**y'
  32. assert prntr.doprint(Mod(x, 2)) == 'x % 2'
  33. assert prntr.doprint(-Mod(x, y)) == '-(x % y)'
  34. assert prntr.doprint(Mod(-x, y)) == '(-x) % y'
  35. assert prntr.doprint(And(x, y)) == 'x and y'
  36. assert prntr.doprint(Or(x, y)) == 'x or y'
  37. assert prntr.doprint(1/(x+y)) == '1/(x + y)'
  38. assert prntr.doprint(Not(x)) == 'not x'
  39. assert not prntr.module_imports
  40. assert prntr.doprint(pi) == 'math.pi'
  41. assert prntr.module_imports == {'math': {'pi'}}
  42. assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)'
  43. assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)'
  44. assert prntr.module_imports == {'math': {'pi', 'sqrt'}}
  45. assert prntr.doprint(acos(x)) == 'math.acos(x)'
  46. assert prntr.doprint(cot(x)) == '(1/math.tan(x))'
  47. assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))'
  48. assert prntr.doprint(asec(x)) == '(math.acos(1/x))'
  49. assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))'
  50. assert prntr.doprint(Assignment(x, 2)) == 'x = 2'
  51. assert prntr.doprint(Piecewise((1, Eq(x, 0)),
  52. (2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)'
  53. assert prntr.doprint(Piecewise((2, Le(x, 0)),
  54. (3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\
  55. ' (3) if (x > 0) else None)'
  56. assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))'
  57. assert prntr.doprint(p[0, 1]) == 'p[0, 1]'
  58. assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)'
  59. assert prntr.doprint((2,3)) == "(2, 3)"
  60. assert prntr.doprint([2,3]) == "[2, 3]"
  61. assert prntr.doprint(Min(x, y)) == "min(x, y)"
  62. assert prntr.doprint(Max(x, y)) == "max(x, y)"
  63. def test_PythonCodePrinter_standard():
  64. prntr = PythonCodePrinter()
  65. assert prntr.standard == 'python3'
  66. raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'}))
  67. def test_CmathPrinter():
  68. p = CmathPrinter()
  69. assert p.doprint(sqrt(x)) == 'cmath.sqrt(x)'
  70. assert p.doprint(log(x)) == 'cmath.log(x)'
  71. assert p.doprint(sin(x)) == 'cmath.sin(x)'
  72. assert p.doprint(cos(x)) == 'cmath.cos(x)'
  73. assert p.doprint(tan(x)) == 'cmath.tan(x)'
  74. assert p.doprint(asin(x)) == 'cmath.asin(x)'
  75. assert p.doprint(acos(x)) == 'cmath.acos(x)'
  76. assert p.doprint(atan(x)) == 'cmath.atan(x)'
  77. assert p.doprint(sinh(x)) == 'cmath.sinh(x)'
  78. assert p.doprint(cosh(x)) == 'cmath.cosh(x)'
  79. assert p.doprint(tanh(x)) == 'cmath.tanh(x)'
  80. assert p.doprint(asinh(x)) == 'cmath.asinh(x)'
  81. assert p.doprint(acosh(x)) == 'cmath.acosh(x)'
  82. assert p.doprint(atanh(x)) == 'cmath.atanh(x)'
  83. def test_MpmathPrinter():
  84. p = MpmathPrinter()
  85. assert p.doprint(sign(x)) == 'mpmath.sign(x)'
  86. assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)'
  87. assert p.doprint(S.Exp1) == 'mpmath.e'
  88. assert p.doprint(S.Pi) == 'mpmath.pi'
  89. assert p.doprint(S.GoldenRatio) == 'mpmath.phi'
  90. assert p.doprint(S.EulerGamma) == 'mpmath.euler'
  91. assert p.doprint(S.NaN) == 'mpmath.nan'
  92. assert p.doprint(S.Infinity) == 'mpmath.inf'
  93. assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf'
  94. assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)'
  95. def test_NumPyPrinter():
  96. from sympy.core.function import Lambda
  97. from sympy.matrices.expressions.adjoint import Adjoint
  98. from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf)
  99. from sympy.matrices.expressions.funcmatrix import FunctionMatrix
  100. from sympy.matrices.expressions.hadamard import HadamardProduct
  101. from sympy.matrices.expressions.kronecker import KroneckerProduct
  102. from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix)
  103. from sympy.abc import a, b
  104. p = NumPyPrinter()
  105. assert p.doprint(sign(x)) == 'numpy.sign(x)'
  106. A = MatrixSymbol("A", 2, 2)
  107. B = MatrixSymbol("B", 2, 2)
  108. C = MatrixSymbol("C", 1, 5)
  109. D = MatrixSymbol("D", 3, 4)
  110. assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)"
  111. assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)"
  112. assert p.doprint(Identity(3)) == "numpy.eye(3)"
  113. u = MatrixSymbol('x', 2, 1)
  114. v = MatrixSymbol('y', 2, 1)
  115. assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)'
  116. assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y'
  117. assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))"
  118. assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))"
  119. assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \
  120. "numpy.fromfunction(lambda a, b: a + b, (4, 5))"
  121. assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)"
  122. assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)"
  123. assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))"
  124. assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))"
  125. assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)"
  126. assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))"
  127. # Workaround for numpy negative integer power errors
  128. assert p.doprint(x**-1) == 'x**(-1.0)'
  129. assert p.doprint(x**-2) == 'x**(-2.0)'
  130. expr = Pow(2, -1, evaluate=False)
  131. assert p.doprint(expr) == "2**(-1.0)"
  132. assert p.doprint(S.Exp1) == 'numpy.e'
  133. assert p.doprint(S.Pi) == 'numpy.pi'
  134. assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma'
  135. assert p.doprint(S.NaN) == 'numpy.nan'
  136. assert p.doprint(S.Infinity) == 'numpy.inf'
  137. assert p.doprint(S.NegativeInfinity) == '-numpy.inf'
  138. # Function rewriting operator precedence fix
  139. assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2'
  140. def test_issue_18770():
  141. numpy = import_module('numpy')
  142. if not numpy:
  143. skip("numpy not installed.")
  144. from sympy.functions.elementary.miscellaneous import (Max, Min)
  145. from sympy.utilities.lambdify import lambdify
  146. expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1)
  147. func = lambdify(x, expr1, "numpy")
  148. assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all()
  149. assert func(4) == 3
  150. expr1 = Max(x**2, x**3)
  151. func = lambdify(x,expr1, "numpy")
  152. assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all()
  153. assert func(4) == 64
  154. def test_SciPyPrinter():
  155. p = SciPyPrinter()
  156. expr = acos(x)
  157. assert 'numpy' not in p.module_imports
  158. assert p.doprint(expr) == 'numpy.arccos(x)'
  159. assert 'numpy' in p.module_imports
  160. assert not any(m.startswith('scipy') for m in p.module_imports)
  161. smat = SparseMatrix(2, 5, {(0, 1): 3})
  162. assert p.doprint(smat) == \
  163. 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))'
  164. assert 'scipy.sparse' in p.module_imports
  165. assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio'
  166. assert p.doprint(S.Pi) == 'scipy.constants.pi'
  167. assert p.doprint(S.Exp1) == 'numpy.e'
  168. def test_pycode_reserved_words():
  169. s1, s2 = symbols('if else')
  170. raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True))
  171. py_str = pycode(s1 + s2)
  172. assert py_str in ('else_ + if_', 'if_ + else_')
  173. def test_issue_20762():
  174. # Make sure pycode removes curly braces from subscripted variables
  175. a_b, b, a_11 = symbols('a_{b} b a_{11}')
  176. expr = a_b*b
  177. assert pycode(expr) == 'a_b*b'
  178. expr = a_11*b
  179. assert pycode(expr) == 'a_11*b'
  180. def test_sqrt():
  181. prntr = PythonCodePrinter()
  182. assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)'
  183. assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)'
  184. prntr = PythonCodePrinter({'standard' : 'python3'})
  185. assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
  186. assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)'
  187. prntr = MpmathPrinter()
  188. assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)'
  189. assert prntr._print_Pow(sqrt(x), rational=True) == \
  190. "x**(mpmath.mpf(1)/mpmath.mpf(2))"
  191. prntr = NumPyPrinter()
  192. assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
  193. assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
  194. prntr = SciPyPrinter()
  195. assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)'
  196. assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
  197. prntr = SymPyPrinter()
  198. assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)'
  199. assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)'
  200. def test_frac():
  201. from sympy.functions.elementary.integers import frac
  202. expr = frac(x)
  203. prntr = NumPyPrinter()
  204. assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
  205. prntr = SciPyPrinter()
  206. assert prntr.doprint(expr) == 'numpy.mod(x, 1)'
  207. prntr = PythonCodePrinter()
  208. assert prntr.doprint(expr) == 'x % 1'
  209. prntr = MpmathPrinter()
  210. assert prntr.doprint(expr) == 'mpmath.frac(x)'
  211. prntr = SymPyPrinter()
  212. assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)'
  213. class CustomPrintedObject(Expr):
  214. def _numpycode(self, printer):
  215. return 'numpy'
  216. def _mpmathcode(self, printer):
  217. return 'mpmath'
  218. def test_printmethod():
  219. obj = CustomPrintedObject()
  220. assert NumPyPrinter().doprint(obj) == 'numpy'
  221. assert MpmathPrinter().doprint(obj) == 'mpmath'
  222. def test_codegen_ast_nodes():
  223. assert pycode(none) == 'None'
  224. def test_issue_14283():
  225. prntr = PythonCodePrinter()
  226. assert prntr.doprint(zoo) == "math.nan"
  227. assert prntr.doprint(-oo) == "float('-inf')"
  228. def test_NumPyPrinter_print_seq():
  229. n = NumPyPrinter()
  230. assert n._print_seq(range(2)) == '(0, 1,)'
  231. def test_issue_16535_16536():
  232. from sympy.functions.special.gamma_functions import (lowergamma, uppergamma)
  233. a = symbols('a')
  234. expr1 = lowergamma(a, x)
  235. expr2 = uppergamma(a, x)
  236. prntr = SciPyPrinter()
  237. assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)'
  238. assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)'
  239. p_numpy = NumPyPrinter()
  240. p_pycode = PythonCodePrinter({'strict': False})
  241. for expr in [expr1, expr2]:
  242. with raises(NotImplementedError):
  243. p_numpy.doprint(expr1)
  244. assert "Not supported" in p_pycode.doprint(expr)
  245. def test_Integral():
  246. from sympy.functions.elementary.exponential import exp
  247. from sympy.integrals.integrals import Integral
  248. single = Integral(exp(-x), (x, 0, oo))
  249. double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z))
  250. indefinite = Integral(x**2, x)
  251. evaluateat = Integral(x**2, (x, 1))
  252. prntr = SciPyPrinter()
  253. assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]'
  254. assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]'
  255. raises(NotImplementedError, lambda: prntr.doprint(indefinite))
  256. raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
  257. prntr = MpmathPrinter()
  258. assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))'
  259. assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))'
  260. raises(NotImplementedError, lambda: prntr.doprint(indefinite))
  261. raises(NotImplementedError, lambda: prntr.doprint(evaluateat))
  262. def test_fresnel_integrals():
  263. from sympy.functions.special.error_functions import (fresnelc, fresnels)
  264. expr1 = fresnelc(x)
  265. expr2 = fresnels(x)
  266. prntr = SciPyPrinter()
  267. assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]'
  268. assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]'
  269. p_numpy = NumPyPrinter()
  270. p_pycode = PythonCodePrinter()
  271. p_mpmath = MpmathPrinter()
  272. for expr in [expr1, expr2]:
  273. with raises(NotImplementedError):
  274. p_numpy.doprint(expr)
  275. with raises(NotImplementedError):
  276. p_pycode.doprint(expr)
  277. assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)'
  278. assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)'
  279. def test_beta():
  280. from sympy.functions.special.beta_functions import beta
  281. expr = beta(x, y)
  282. prntr = SciPyPrinter()
  283. assert prntr.doprint(expr) == 'scipy.special.beta(x, y)'
  284. prntr = NumPyPrinter()
  285. assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
  286. prntr = PythonCodePrinter()
  287. assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
  288. prntr = PythonCodePrinter({'allow_unknown_functions': True})
  289. assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))'
  290. prntr = MpmathPrinter()
  291. assert prntr.doprint(expr) == 'mpmath.beta(x, y)'
  292. def test_airy():
  293. from sympy.functions.special.bessel import (airyai, airybi)
  294. expr1 = airyai(x)
  295. expr2 = airybi(x)
  296. prntr = SciPyPrinter()
  297. assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]'
  298. assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]'
  299. prntr = NumPyPrinter({'strict': False})
  300. assert "Not supported" in prntr.doprint(expr1)
  301. assert "Not supported" in prntr.doprint(expr2)
  302. prntr = PythonCodePrinter({'strict': False})
  303. assert "Not supported" in prntr.doprint(expr1)
  304. assert "Not supported" in prntr.doprint(expr2)
  305. def test_airy_prime():
  306. from sympy.functions.special.bessel import (airyaiprime, airybiprime)
  307. expr1 = airyaiprime(x)
  308. expr2 = airybiprime(x)
  309. prntr = SciPyPrinter()
  310. assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]'
  311. assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]'
  312. prntr = NumPyPrinter({'strict': False})
  313. assert "Not supported" in prntr.doprint(expr1)
  314. assert "Not supported" in prntr.doprint(expr2)
  315. prntr = PythonCodePrinter({'strict': False})
  316. assert "Not supported" in prntr.doprint(expr1)
  317. assert "Not supported" in prntr.doprint(expr2)
  318. def test_numerical_accuracy_functions():
  319. prntr = SciPyPrinter()
  320. assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)'
  321. assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)'
  322. assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)'
  323. def test_array_printer():
  324. A = ArraySymbol('A', (4,4,6,6,6))
  325. I = IndexedBase('I')
  326. i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5))
  327. prntr = NumPyPrinter()
  328. assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))'
  329. assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))'
  330. assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)'
  331. assert prntr.doprint(I) == 'I'
  332. assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)'
  333. assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)'
  334. assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)'
  335. assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
  336. prntr = TensorflowPrinter()
  337. assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))'
  338. assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))'
  339. assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)'
  340. assert prntr.doprint(I) == 'I'
  341. assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)'
  342. assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)'
  343. assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)'
  344. assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I'
  345. def test_custom_Derivative_methods():
  346. class MyPrinter(SciPyPrinter):
  347. def _print_Derivative_cosm1(self, args, seq_orders):
  348. arg, = args
  349. order, = seq_orders
  350. return 'my_custom_cosm1(%s, deriv_order=%d)' % (self._print(arg), order)
  351. def _print_Derivative_atan2(self, args, seq_orders):
  352. arg1, arg2 = args
  353. ord1, ord2 = seq_orders
  354. return 'my_custom_atan2(%s, %s, deriv1=%d, deriv2=%d)' % (
  355. self._print(arg1), self._print(arg2), ord1, ord2
  356. )
  357. p = MyPrinter()
  358. cosm1_1 = cosm1(x).diff(x, evaluate=False)
  359. assert p.doprint(cosm1_1) == 'my_custom_cosm1(x, deriv_order=1)'
  360. atan2_2_3 = atan2(x, y).diff(x, 2, y, 3, evaluate=False)
  361. assert p.doprint(atan2_2_3) == 'my_custom_atan2(x, y, deriv1=2, deriv2=3)'
  362. try:
  363. p.doprint(expm1(x).diff(x, evaluate=False))
  364. except PrintMethodNotImplementedError as e:
  365. assert '_print_Derivative_expm1' in repr(e)
  366. else:
  367. assert False # should have thrown
  368. try:
  369. p.doprint(Derivative(cosm1(x**2),x))
  370. except ValueError as e:
  371. assert '_print_Derivative(' in repr(e)
  372. else:
  373. assert False # should have thrown