pycode.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852
  1. """
  2. Python code printers
  3. This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code.
  4. """
  5. from collections import defaultdict
  6. from itertools import chain
  7. from sympy.core import S
  8. from sympy.core.mod import Mod
  9. from .precedence import precedence
  10. from .codeprinter import CodePrinter
  # Names that must not be emitted as identifiers in generated Python code:
  # the language keywords plus the ``None`` / ``True`` / ``False`` constants.
  _kw = set(
      'and as assert break class continue def del elif '
      'else except finally for from global if import in '
      'is lambda not or pass raise return try while '
      'with yield None False nonlocal True'.split()
  )
  # SymPy class name -> name of the Python builtin used to print it.
  _known_functions = dict(Abs='abs', Min='min', Max='max')
  # SymPy function name -> name of the matching function in the ``math``
  # module.  ``math`` functions that have no entry in this table:
  # copysign isclose isfinite ldexp frexp pow modf radians trunc fmod fsum
  # gcd degrees fabs
  _known_functions_math = dict([
      ('acos', 'acos'),
      ('acosh', 'acosh'),
      ('asin', 'asin'),
      ('asinh', 'asinh'),
      ('atan', 'atan'),
      ('atan2', 'atan2'),
      ('atanh', 'atanh'),
      ('ceiling', 'ceil'),
      ('cos', 'cos'),
      ('cosh', 'cosh'),
      ('erf', 'erf'),
      ('erfc', 'erfc'),
      ('exp', 'exp'),
      ('expm1', 'expm1'),
      ('factorial', 'factorial'),
      ('floor', 'floor'),
      ('gamma', 'gamma'),
      ('hypot', 'hypot'),
      ('isinf', 'isinf'),
      ('isnan', 'isnan'),
      ('loggamma', 'lgamma'),
      ('log', 'log'),
      ('ln', 'log'),
      ('log10', 'log10'),
      ('log1p', 'log1p'),
      ('log2', 'log2'),
      ('sin', 'sin'),
      ('sinh', 'sinh'),
      ('Sqrt', 'sqrt'),
      ('tan', 'tan'),
      ('tanh', 'tanh'),
  ])
  # SymPy constant class name -> attribute name in the ``math`` module.
  _known_constants_math = dict([
      ('Exp1', 'e'),
      ('Pi', 'pi'),
      ('E', 'e'),
      ('Infinity', 'inf'),
      ('NaN', 'nan'),
      ('ComplexInfinity', 'nan'),
  ])
  def _print_known_func(self, expr):
      """Print ``expr`` as a call to the function registered for its class.

      The target name is looked up in ``self.known_functions`` under the class
      name of ``expr`` and passed through ``self._module_format``; every
      argument is printed with ``self._print`` and the results are joined
      with ``', '``.
      """
      target = self.known_functions[expr.__class__.__name__]
      callee = self._module_format(target)
      printed_args = ', '.join(self._print(arg) for arg in expr.args)
      return f'{callee}({printed_args})'
  def _print_known_const(self, expr):
      """Print ``expr`` as the constant registered for its class.

      The name is looked up in ``self.known_constants`` under the class name
      of ``expr`` and returned after passing through ``self._module_format``.
      """
      class_name = expr.__class__.__name__
      return self._module_format(self.known_constants[class_name])
  class AbstractPythonCodePrinter(CodePrinter):
      """Common base class of the Python-targeting printers in this module.

      It holds the tables of known functions and constants, records in
      ``module_imports`` which module members the generated code refers to,
      and prints the constructs that do not depend on a numeric backend.
      """
      printmethod = "_pythoncode"
      language = "Python"
      reserved_words = _kw
      # NOTE(review): the ``__init__`` below never assigns ``modules``
      # although the previous comment said it is initialized there -- confirm.
      modules = None
      # Prefix put in front of every line by ``_indent_codestring``.
      # NOTE(review): a single space may be the result of collapsed
      # whitespace in this copy of the file -- confirm the intended width.
      tab = ' '
      # Builtins are used unqualified, ``math`` functions as ``math.<name>``.
      _kf = dict(chain(
          _known_functions.items(),
          [(k, 'math.' + v) for k, v in _known_functions_math.items()]
      ))
      _kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
      _operators = {'and': 'and', 'or': 'or', 'not': 'not'}
      _default_settings = dict(
          CodePrinter._default_settings,
          user_functions={},
          precision=17,
          inline=True,
          fully_qualified_modules=True,
          contract=False,
          standard='python3',
      )

      def __init__(self, settings=None):
          """Create the printer.

          ``settings`` is forwarded to ``CodePrinter``.  Its optional
          ``user_functions`` and ``user_constants`` entries are merged over
          the class-level ``_kf`` / ``_kc`` tables.

          Raises ``ValueError`` when the ``standard`` setting resolves to
          anything other than ``'python3'``.
          """
          super().__init__(settings)

          # Python standard handler
          std = self._settings['standard']
          if std is None:
              import sys
              std = 'python{}'.format(sys.version_info.major)
          if std != 'python3':
              raise ValueError('Only Python 3 is supported.')
          self.standard = std

          # module path -> set of member names used by the printed code
          self.module_imports = defaultdict(set)

          # Known functions and constants handler
          self.known_functions = dict(self._kf, **(settings or {}).get(
              'user_functions', {}))
          self.known_constants = dict(self._kc, **(settings or {}).get(
              'user_constants', {}))

      def _declare_number_const(self, name, value):
          """Return the assignment statement ``name = value``."""
          return "%s = %s" % (name, value)

      def _module_format(self, fqn, register=True):
          """Return the text to emit for the fully qualified name ``fqn``.

          If ``register`` is true and ``fqn`` is dotted, its last component is
          recorded in ``self.module_imports`` under the module path.  With the
          ``fully_qualified_modules`` setting enabled ``fqn`` is returned
          unchanged; otherwise only the last dotted component of the part
          before any ``(`` or ``[`` is returned.
          """
          parts = fqn.split('.')
          if register and len(parts) > 1:
              self.module_imports['.'.join(parts[:-1])].add(parts[-1])

          if self._settings['fully_qualified_modules']:
              return fqn
          else:
              return fqn.split('(')[0].split('[')[0].split('.')[-1]

      def _format_code(self, lines):
          """Return ``lines`` unchanged."""
          return lines

      def _get_statement(self, codestring):
          """Return ``codestring`` formatted as a statement (no terminator)."""
          return "{}".format(codestring)

      def _get_comment(self, text):
          """Return ``text`` formatted as a trailing ``#`` comment."""
          return " # {}".format(text)

      def _expand_fold_binary_op(self, op, args):
          """
          This method expands a fold on binary operations.

          ``functools.reduce`` is an example of a folded operation.

          For example, the expression

          `A + B + C + D`

          is folded into

          `((A + B) + C) + D`
          """
          if len(args) == 1:
              return self._print(args[0])
          else:
              return "%s(%s, %s)" % (
                  self._module_format(op),
                  self._expand_fold_binary_op(op, args[:-1]),
                  self._print(args[-1]),
              )

      def _expand_reduce_binary_op(self, op, args):
          """
          This method expands a reduction on binary operations.

          Notice: this is NOT the same as ``functools.reduce``.

          For example, the expression

          `A + B + C + D`

          is reduced into:

          `(A + B) + (C + D)`
          """
          if len(args) == 1:
              return self._print(args[0])
          else:
              N = len(args)
              Nhalf = N // 2
              # ``op`` has to be forwarded to both halves; leaving it out
              # made every call with more than one argument fail with a
              # ``TypeError`` for the missing positional parameter.
              return "%s(%s, %s)" % (
                  self._module_format(op),
                  self._expand_reduce_binary_op(op, args[:Nhalf]),
                  self._expand_reduce_binary_op(op, args[Nhalf:]),
              )

      def _print_NaN(self, expr):
          return "float('nan')"

      def _print_Infinity(self, expr):
          return "float('inf')"

      def _print_NegativeInfinity(self, expr):
          return "float('-inf')"

      def _print_ComplexInfinity(self, expr):
          # Complex infinity is printed the same way as NaN.
          return self._print_NaN(expr)

      def _print_Mod(self, expr):
          """Print ``Mod(a, b)`` as ``a % b``, parenthesizing where needed."""
          PREC = precedence(expr)
          return ('{} % {}'.format(*(self.parenthesize(x, PREC) for x in expr.args)))

      def _print_Piecewise(self, expr):
          """Print a ``Piecewise`` as a chain of conditional expressions.

          Each branch becomes ``(<expr>) if <cond> else ...``.  When the last
          condition prints as ``True`` its expression is used as the final
          ``else`` value; otherwise the chain ends in ``else None``.
          """
          result = []
          i = 0
          for arg in expr.args:
              e = arg.expr
              c = arg.cond
              if i == 0:
                  # One extra parenthesis wraps the whole chain.
                  result.append('(')
              result.append('(')
              result.append(self._print(e))
              result.append(')')
              result.append(' if ')
              result.append(self._print(c))
              result.append(' else ')
              i += 1
          # Drop the dangling ' else ' left by the last branch.
          result = result[:-1]
          if result[-1] == 'True':
              # Drop "' if ', 'True'" so the last expression is the default.
              result = result[:-2]
              result.append(')')
          else:
              result.append(' else None)')
          return ''.join(result)

      def _print_Relational(self, expr):
          "Relational printer for Equality and Unequality"
          # Only the keys are consulted here; the values are not used.
          op = {
              '==' :'equal',
              '!=' :'not_equal',
              '<' :'less',
              '<=' :'less_equal',
              '>' :'greater',
              '>=' :'greater_equal',
          }
          if expr.rel_op in op:
              lhs = self._print(expr.lhs)
              rhs = self._print(expr.rhs)
              return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
          return super()._print_Relational(expr)

      def _print_ITE(self, expr):
          """Print an ``ITE`` by rewriting it as a ``Piecewise`` first."""
          from sympy.functions.elementary.piecewise import Piecewise
          return self._print(expr.rewrite(Piecewise))

      def _print_Sum(self, expr):
          """Print a ``Sum`` as ``builtins.sum`` over a generator expression.

          The limits are emitted in reverse order as ``for i in range(a, b+1)``
          clauses, so the upper bound is inclusive.
          """
          loops = (
              'for {i} in range({a}, {b}+1)'.format(
                  i=self._print(i),
                  a=self._print(a),
                  b=self._print(b))
              for i, a, b in expr.limits[::-1])
          return '(builtins.sum({function} {loops}))'.format(
              function=self._print(expr.function),
              loops=' '.join(loops))

      def _print_ImaginaryUnit(self, expr):
          return '1j'

      def _print_KroneckerDelta(self, expr):
          """Print ``KroneckerDelta(a, b)`` as ``(1 if a == b else 0)``."""
          a, b = expr.args

          return '(1 if {a} == {b} else 0)'.format(
              a = self._print(a),
              b = self._print(b)
          )

      def _print_MatrixBase(self, expr):
          """Print a matrix as ``<name>(<nested list>)``.

          ``<name>`` is the entry of ``self.known_functions`` for the class
          name of ``expr`` or, failing that, the class name itself.
          """
          name = expr.__class__.__name__
          func = self.known_functions.get(name, name)
          return "%s(%s)" % (func, self._print(expr.tolist()))

      # Every concrete matrix class is printed through ``_print_MatrixBase``.
      _print_SparseRepMatrix = \
          _print_MutableSparseMatrix = \
          _print_ImmutableSparseMatrix = \
          _print_Matrix = \
          _print_DenseMatrix = \
          _print_MutableDenseMatrix = \
          _print_ImmutableMatrix = \
          _print_ImmutableDenseMatrix = \
          lambda self, expr: self._print_MatrixBase(expr)

      def _indent_codestring(self, codestring):
          """Prefix every line of ``codestring`` with ``self.tab``."""
          return '\n'.join([self.tab + line for line in codestring.split('\n')])

      def _print_FunctionDefinition(self, fd):
          """Print a function definition as a ``def`` with an indented body."""
          body = '\n'.join((self._print(arg) for arg in fd.body))
          return "def {name}({parameters}):\n{body}".format(
              name=self._print(fd.name),
              parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
              body=self._indent_codestring(body)
          )

      def _print_While(self, whl):
          """Print a while loop with an indented body."""
          body = '\n'.join((self._print(arg) for arg in whl.body))
          return "while {cond}:\n{body}".format(
              cond=self._print(whl.condition),
              body=self._indent_codestring(body)
          )

      def _print_Declaration(self, decl):
          """Print a declaration as ``<symbol> = <value>``."""
          return '%s = %s' % (
              self._print(decl.variable.symbol),
              self._print(decl.variable.value)
          )

      def _print_BreakToken(self, bt):
          return 'break'

      def _print_Return(self, ret):
          arg, = ret.args
          return 'return %s' % self._print(arg)

      def _print_Raise(self, rs):
          arg, = rs.args
          return 'raise %s' % self._print(arg)

      def _print_RuntimeError_(self, re):
          message, = re.args
          return "RuntimeError(%s)" % self._print(message)

      def _print_Print(self, prnt):
          """Print a ``Print`` node as a call to ``print``.

          With a format string the arguments are emitted as
          ``<fmt> % (<args>), end=""``; a ``file=`` keyword is appended when
          the node has a file.
          """
          print_args = ', '.join((self._print(arg) for arg in prnt.print_args))
          from sympy.codegen.ast import none
          if prnt.format_string != none:
              print_args = '{} % ({}), end=""'.format(
                  self._print(prnt.format_string),
                  print_args
              )
          if prnt.file != None: # Must be '!= None', cannot be 'is not None'
              print_args += ', file=%s' % self._print(prnt.file)

          return 'print(%s)' % print_args

      def _print_Stream(self, strm):
          """Print ``stdout`` / ``stderr`` as ``sys`` members, else the name."""
          if str(strm.name) == 'stdout':
              return self._module_format('sys.stdout')
          elif str(strm.name) == 'stderr':
              return self._module_format('sys.stderr')
          else:
              return self._print(strm.name)

      def _print_NoneToken(self, arg):
          return 'None'

      def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'):
          """Printing helper function for ``Pow``

          Notes
          =====

          This preprocesses the ``sqrt`` as math formatter and prints division

          Examples
          ========

          >>> from sympy import sqrt
          >>> from sympy.printing.pycode import PythonCodePrinter
          >>> from sympy.abc import x

          Python code printer automatically looks up ``math.sqrt``.

          >>> printer = PythonCodePrinter()
          >>> printer._hprint_Pow(sqrt(x), rational=True)
          'x**(1/2)'
          >>> printer._hprint_Pow(sqrt(x), rational=False)
          'math.sqrt(x)'
          >>> printer._hprint_Pow(1/sqrt(x), rational=True)
          'x**(-1/2)'
          >>> printer._hprint_Pow(1/sqrt(x), rational=False)
          '1/math.sqrt(x)'
          >>> printer._hprint_Pow(1/x, rational=False)
          '1/x'
          >>> printer._hprint_Pow(1/x, rational=True)
          'x**(-1)'

          Using sqrt from numpy or mpmath

          >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt')
          'numpy.sqrt(x)'
          >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt')
          'mpmath.sqrt(x)'

          See Also
          ========

          sympy.printing.str.StrPrinter._print_Pow
          """
          PREC = precedence(expr)

          # x**(1/2) -> sqrt(x)
          if expr.exp == S.Half and not rational:
              func = self._module_format(sqrt)
              arg = self._print(expr.base)
              return '{func}({arg})'.format(func=func, arg=arg)

          if expr.is_commutative and not rational:
              # x**(-1/2) -> 1/sqrt(x)
              if -expr.exp is S.Half:
                  func = self._module_format(sqrt)
                  num = self._print(S.One)
                  arg = self._print(expr.base)
                  return f"{num}/{func}({arg})"
              # x**(-1) -> 1/x
              if expr.exp is S.NegativeOne:
                  num = self._print(S.One)
                  arg = self.parenthesize(expr.base, PREC, strict=False)
                  return f"{num}/{arg}"

          base_str = self.parenthesize(expr.base, PREC, strict=False)
          exp_str = self.parenthesize(expr.exp, PREC, strict=False)
          return "{}**{}".format(base_str, exp_str)
  class ArrayPrinter:
      """Mixin that prints array expressions as einsum-style calls.

      The methods refer to ``self._module``, ``self._einsum``,
      ``self._transpose``, ``self._add``, ``self._ones``, ``self._zeros``,
      ``self._module_format``, ``self._print``,
      ``self._expand_fold_binary_op`` and ``self._print_ArraySymbol``; none
      of them is defined here, so the class that mixes this in has to
      provide them.
      """

      def _arrayify(self, indexed):
          """Convert an indexed expression to array form when possible.

          Any exception raised by the conversion is swallowed on purpose and
          ``indexed`` is returned unchanged in that case.
          """
          from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array
          try:
              return convert_indexed_to_array(indexed)
          except Exception:
              return indexed

      def _get_einsum_string(self, subranks, contraction_indices):
          """Build the operand part of an einsum subscript string.

          ``subranks`` holds the rank of every operand; positions are numbered
          consecutively across operands.  ``contraction_indices`` is an
          iterable of groups of positions that share one letter.

          Returns ``(subscripts, letters_free, letters_dum)``: the
          comma-separated subscripts, the letters of ungrouped positions in
          order of appearance, and the distinct letters of grouped positions.
          """
          letters = self._get_letter_generator_for_einsum()
          contraction_string = ""
          counter = 0
          # Every grouped position is represented by the smallest position
          # of its group.
          d = {j: min(i) for i in contraction_indices for j in i}
          indices = []
          for rank_arg in subranks:
              lindices = []
              for i in range(rank_arg):
                  if counter in d:
                      lindices.append(d[counter])
                  else:
                      lindices.append(counter)
                  counter += 1
              indices.append(lindices)
          mapping = {}
          letters_free = []
          letters_dum = []
          for i in indices:
              for j in i:
                  if j not in mapping:
                      l = next(letters)
                      mapping[j] = l
                  else:
                      l = mapping[j]
                  contraction_string += l
                  if j in d:
                      if l not in letters_dum:
                          letters_dum.append(l)
                  else:
                      letters_free.append(l)
              contraction_string += ","
          # Strip the trailing comma (a no-op when there are no operands).
          contraction_string = contraction_string[:-1]
          return contraction_string, letters_free, letters_dum

      def _get_letter_generator_for_einsum(self):
          """Yield ``a``-``z`` then ``A``-``Z``; raise ``ValueError`` after."""
          for i in range(97, 123):
              yield chr(i)
          for i in range(65, 91):
              yield chr(i)
          raise ValueError("out of letters")

      def _print_ArrayTensorProduct(self, expr):
          """Print a tensor product as an einsum call with distinct letters."""
          letters = self._get_letter_generator_for_einsum()
          contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks])
          return '%s("%s", %s)' % (
                  self._module_format(self._module + "." + self._einsum),
                  contraction_string,
                  ", ".join([self._print(arg) for arg in expr.args])
          )

      def _print_ArrayContraction(self, expr):
          """Print a contraction as an einsum call.

          Without contraction indices the contracted expression itself is
          printed.  Otherwise the output subscripts are the free letters in
          sorted order.
          """
          from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
          base = expr.expr
          contraction_indices = expr.contraction_indices

          # Nothing to contract: decide this before doing any other work.
          if not contraction_indices:
              return self._print(base)

          if isinstance(base, ArrayTensorProduct):
              elems = ",".join(["%s" % (self._print(arg)) for arg in base.args])
              ranks = base.subranks
          else:
              elems = self._print(base)
              ranks = [len(base.shape)]

          contraction_string, letters_free, letters_dum = self._get_einsum_string(ranks, contraction_indices)

          return "%s(\"%s\", %s)" % (
              self._module_format(self._module + "." + self._einsum),
              "{}->{}".format(contraction_string, "".join(sorted(letters_free))),
              elems,
          )

      def _print_ArrayDiagonal(self, expr):
          """Print a diagonal as an einsum call.

          The output subscripts are the free letters followed by the letters
          of the diagonalized groups.
          """
          from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
          diagonal_indices = list(expr.diagonal_indices)
          if isinstance(expr.expr, ArrayTensorProduct):
              subranks = expr.expr.subranks
              elems = expr.expr.args
          else:
              subranks = expr.subranks
              elems = [expr.expr]
          diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices)
          elems = [self._print(i) for i in elems]
          return '%s("%s", %s)' % (
              self._module_format(self._module + "." + self._einsum),
              "{}->{}".format(diagonal_string, "".join(letters_free+letters_dum)),
              ", ".join(elems)
          )

      def _print_PermuteDims(self, expr):
          """Print an axis permutation as a call to the transpose function."""
          return "%s(%s, %s)" % (
              self._module_format(self._module + "." + self._transpose),
              self._print(expr.expr),
              self._print(expr.permutation.array_form),
          )

      def _print_ArrayAdd(self, expr):
          """Print an array sum as nested binary calls to the add function."""
          return self._expand_fold_binary_op(self._module + "." + self._add, expr.args)

      def _print_OneArray(self, expr):
          """Print as a call to the ones function with a shape tuple."""
          return "%s((%s,))" % (
              self._module_format(self._module+ "." + self._ones),
              ','.join(map(self._print,expr.args))
          )

      def _print_ZeroArray(self, expr):
          """Print as a call to the zeros function with a shape tuple."""
          return "%s((%s,))" % (
              self._module_format(self._module+ "." + self._zeros),
              ','.join(map(self._print,expr.args))
          )

      def _print_Assignment(self, expr):
          """Print ``lhs = rhs`` after trying to turn both sides into arrays."""
          #XXX: maybe this needs to happen at a higher level e.g. at _print or
          #doprint?
          lhs = self._print(self._arrayify(expr.lhs))
          rhs = self._print(self._arrayify(expr.rhs))
          return "%s = %s" % ( lhs, rhs )

      def _print_IndexedBase(self, expr):
          return self._print_ArraySymbol(expr)
  class PythonCodePrinter(AbstractPythonCodePrinter):
      """Printer that emits plain Python using builtins and ``math``."""

      def _print_sign(self, e):
          """Print ``sign(x)`` as ``0.0`` for zero, else ``copysign(1, x)``."""
          copysign = self._module_format('math.copysign')
          operand = self._print(e.args[0])
          return f'(0.0 if {operand} == 0 else {copysign}(1, {operand}))'

      def _print_Not(self, expr):
          """Print a logical negation with the configured ``not`` operator."""
          operand = self.parenthesize(expr.args[0], precedence(expr))
          return self._operators['not'] + ' ' + operand

      def _print_IndexedBase(self, expr):
          """An indexed base is printed as its bare name."""
          return expr.name

      def _print_Indexed(self, expr):
          """Print ``base[i, j, ...]`` with the base converted via ``str``."""
          base, *subscripts = expr.args
          printed = [self._print(sub) for sub in subscripts]
          return "{}[{}]".format(str(base), ", ".join(printed))

      def _print_Pow(self, expr, rational=False):
          """Delegate to ``_hprint_Pow`` with its default square root."""
          return self._hprint_Pow(expr, rational=rational)

      def _print_Rational(self, expr):
          """Print a rational as ``p/q``."""
          return f'{expr.p}/{expr.q}'

      def _print_Half(self, expr):
          """One half is printed like any other rational."""
          return self._print_Rational(expr)

      def _print_frac(self, expr):
          """Print the fractional part as the argument modulo one."""
          remainder = Mod(expr.args[0], 1)
          return self._print_Mod(remainder)

      def _print_Symbol(self, expr):
          """Print a symbol name, avoiding reserved words.

          A reserved name raises ``ValueError`` when the ``error_on_reserved``
          setting is on and otherwise gets ``reserved_word_suffix`` appended.
          Curly braces are stripped from any other name.
          """
          name = super()._print_Symbol(expr)

          if name not in self.reserved_words:
              if '{' in name:   # Remove curly braces from subscripted variables
                  return name.replace('{', '').replace('}', '')
              return name

          if self._settings['error_on_reserved']:
              msg = ('This expression includes the symbol "{}" which is a '
                     'reserved keyword in this language.')
              raise ValueError(msg.format(name))
          return name + self._settings['reserved_word_suffix']

      # Special functions for which no plain-Python printing is provided.
      _print_lowergamma = CodePrinter._print_not_supported
      _print_uppergamma = CodePrinter._print_not_supported
      _print_fresnelc = CodePrinter._print_not_supported
      _print_fresnels = CodePrinter._print_not_supported
  # Attach a ``_print_<name>`` method to ``PythonCodePrinter`` for every
  # known function and for every known ``math`` constant.
  for k in PythonCodePrinter._kf:
      setattr(PythonCodePrinter, f'_print_{k}', _print_known_func)

  for k in _known_constants_math:
      setattr(PythonCodePrinter, f'_print_{k}', _print_known_const)
  def pycode(expr, **settings):
      """ Converts an expr to a string of Python code

      Parameters
      ==========

      expr : Expr
          A SymPy expression.
      fully_qualified_modules : bool
          Whether or not to write out full module names of functions
          (``math.sin`` vs. ``sin``). default: ``True``.
      standard : str or None, optional
          Only 'python3' (default) is supported.
          This parameter may be removed in the future.

      Examples
      ========

      >>> from sympy import pycode, tan, Symbol
      >>> pycode(tan(Symbol('x')) + 1)
      'math.tan(x) + 1'

      """
      # The keyword arguments become the printer's settings dictionary.
      printer = PythonCodePrinter(settings)
      return printer.doprint(expr)
  523. from itertools import chain
  524. from sympy.printing.pycode import PythonCodePrinter
  # SymPy function name -> ``cmath`` function name; every entry maps a name
  # onto itself.
  _known_functions_cmath = {
      name: name
      for name in (
          'exp', 'sqrt', 'log',
          'cos', 'sin', 'tan',
          'acos', 'asin', 'atan',
          'cosh', 'sinh', 'tanh',
          'acosh', 'asinh', 'atanh',
      )
  }
  # SymPy constant class name -> text appended to ``cmath.`` when printing.
  _known_constants_cmath = dict(
      Pi='pi',
      E='e',
      Infinity='inf',
      NegativeInfinity='-inf',
  )
  class CmathPrinter(PythonCodePrinter):
      """ Printer for Python's cmath module """
      printmethod = "_cmathcode"
      language = "Python with cmath"

      _kf = dict(chain(
          _known_functions_cmath.items()
      ))

      # A leading minus sign has to stay in front of the module prefix:
      # ``'cmath.' + '-inf'`` would give ``cmath.-inf``, which is not valid
      # Python, so such entries are emitted as ``-cmath.<name>``.
      _kc = {
          k: ('-cmath.' + v[1:] if v.startswith('-') else 'cmath.' + v)
          for k, v in _known_constants_cmath.items()
      }

      def _print_Pow(self, expr, rational=False):
          """Delegate to ``_hprint_Pow`` using ``cmath.sqrt`` for roots."""
          return self._hprint_Pow(expr, rational=rational, sqrt='cmath.sqrt')

      def _print_Float(self, e):
          """Print a float with the inherited float printer.

          The former body called ``self._print(e)``, which dispatches straight
          back to this method and never terminates, and wrapped the result in
          ``cmath.mpf``, a name the ``cmath`` module does not provide.
          """
          return super()._print_Float(e)

      def _print_known_func(self, expr):
          """Print a function listed in ``_kf`` as ``cmath.<name>(<args>)``.

          Anything else is handed to the inherited ``_print_Function``.
          """
          func_name = expr.func.__name__
          if func_name in self._kf:
              return f"cmath.{self._kf[func_name]}({', '.join(map(self._print, expr.args))})"
          return super()._print_Function(expr)

      def _print_known_const(self, expr):
          """Return the ``_kc`` entry for the class name of ``expr``."""
          return self._kc[expr.__class__.__name__]

      def _print_re(self, expr):
          """Prints `re(z)` as `z.real`"""
          return f"({self._print(expr.args[0])}).real"

      def _print_im(self, expr):
          """Prints `im(z)` as `z.imag`"""
          return f"({self._print(expr.args[0])}).imag"
  # Attach a ``_print_<name>`` method to ``CmathPrinter`` for every known
  # ``cmath`` function and constant, using the class's own helpers.
  for k in CmathPrinter._kf:
      setattr(CmathPrinter, f'_print_{k}', CmathPrinter._print_known_func)

  for k in _known_constants_cmath:
      setattr(CmathPrinter, f'_print_{k}', CmathPrinter._print_known_const)
  # Entries of the ``math`` table that are not carried over to mpmath.
  _not_in_mpmath = ['log1p', 'log2']
  _in_mpmath = [
      item for item in _known_functions_math.items()
      if item[0] not in _not_in_mpmath
  ]
  # Start from the shared ``math`` entries, then add or override the
  # mpmath-specific names (``loggamma`` replaces the ``lgamma`` mapping).
  _known_functions_mpmath = dict(_in_mpmath)
  _known_functions_mpmath.update({
      'beta': 'beta',
      'frac': 'frac',
      'fresnelc': 'fresnelc',
      'fresnels': 'fresnels',
      'sign': 'sign',
      'loggamma': 'loggamma',
      'hyper': 'hyper',
      'meijerg': 'meijerg',
      'besselj': 'besselj',
      'bessely': 'bessely',
      'besseli': 'besseli',
      'besselk': 'besselk',
  })
  # SymPy constant class name -> attribute name in the ``mpmath`` module.
  _known_constants_mpmath = dict(
      Exp1='e',
      Pi='pi',
      GoldenRatio='phi',
      EulerGamma='euler',
      Catalan='catalan',
      NaN='nan',
      Infinity='inf',
      NegativeInfinity='ninf',
  )
  def _unpack_integral_limits(integral_expr):
      """Split the limits of an integral into variables and bounds.

      Helper for ``_print_Integral``.  Every entry of ``integral_expr.limits``
      must have exactly three items: ``(variable, lower, upper)``.

      Returns a tuple ``(integration_vars, limits)`` where ``integration_vars``
      is the list of variables and ``limits`` the list of ``(lower, upper)``
      pairs, both in the order of ``integral_expr.limits``.

      Raises ``NotImplementedError`` for an entry of any other length.
      """
      integration_vars, limits = [], []
      for integration_range in integral_expr.limits:
          if len(integration_range) != 3:
              raise NotImplementedError("Only definite integrals are supported")
          variable, lower, upper = integration_range
          integration_vars.append(variable)
          limits.append((lower, upper))
      return integration_vars, limits
  class MpmathPrinter(PythonCodePrinter):
      """
      Lambda printer for mpmath which maintains precision for floats
      """
      printmethod = "_mpmathcode"

      language = "Python with mpmath"

      # Builtins stay unqualified; everything else is prefixed with
      # ``mpmath.``.
      _kf = dict(chain(
          _known_functions.items(),
          [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()]
      ))
      _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()}

      def _print_Float(self, e):
          """Print a float as ``mpmath.mpf`` applied to its ``_mpf_`` tuple."""
          # XXX: This does not handle setting mpmath.mp.dps. It is assumed that
          # the caller of the lambdified function will have set it to sufficient
          # precision to match the Floats in the expression.

          # Remove 'mpz' if gmpy is installed.
          mpf_fields = tuple(int(field) for field in e._mpf_)
          args = str(mpf_fields)
          func = self._module_format('mpmath.mpf')
          return f'{func}({args})'

      def _print_Rational(self, e):
          """Print a rational as a quotient of two ``mpmath.mpf`` values."""
          func = self._module_format('mpmath.mpf')
          q = self._print(e.q)
          p = self._print(e.p)
          return f"{func}({p})/{func}({q})"

      def _print_Half(self, e):
          """One half is printed like any other rational."""
          return self._print_Rational(e)

      def _print_uppergamma(self, e):
          """Print as ``gammainc`` with the upper bound ``mpmath.inf``."""
          gammainc = self._module_format('mpmath.gammainc')
          first = self._print(e.args[0])
          second = self._print(e.args[1])
          infinity = self._module_format('mpmath.inf')
          return f"{gammainc}({first}, {second}, {infinity})"

      def _print_lowergamma(self, e):
          """Print as ``gammainc`` with the lower bound ``0``."""
          gammainc = self._module_format('mpmath.gammainc')
          first = self._print(e.args[0])
          second = self._print(e.args[1])
          return f"{gammainc}({first}, 0, {second})"

      def _print_log2(self, e):
          """Print ``log2(x)`` as ``log(x)/log(2)``."""
          log = self._module_format('mpmath.log')
          operand = self._print(e.args[0])
          return f'{log}({operand})/{log}(2)'

      def _print_log1p(self, e):
          """Print ``log1p(x)`` as a call to ``mpmath.log1p``."""
          log1p = self._module_format('mpmath.log1p')
          operand = self._print(e.args[0])
          return f'{log1p}({operand})'

      def _print_Pow(self, expr, rational=False):
          """Delegate to ``_hprint_Pow`` using ``mpmath.sqrt`` for roots."""
          return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt')

      def _print_Integral(self, e):
          """Print a definite integral as ``mpmath.quad`` over a lambda."""
          integration_vars, limits = _unpack_integral_limits(e)

          quad = self._module_format("mpmath.quad")
          parameters = ", ".join(map(self._print, integration_vars))
          integrand = self._print(e.args[0])
          ranges = ", ".join(
              "(%s, %s)" % tuple(map(self._print, l)) for l in limits)
          return "{}(lambda {}: {}, {})".format(
              quad, parameters, integrand, ranges)

      def _print_Derivative_zeta(self, args, seq_orders):
          """Print a zeta derivative as ``mpmath.zeta(x, derivative=n)``."""
          arg, = args
          deriv_order, = seq_orders
          zeta = self._module_format('mpmath.zeta')
          printed_arg = self._print(arg)
          return f'{zeta}({printed_arg}, derivative={deriv_order})'
  # Attach a ``_print_<name>`` method to ``MpmathPrinter`` for every known
  # function and for every known mpmath constant.
  for k in MpmathPrinter._kf:
      setattr(MpmathPrinter, f'_print_{k}', _print_known_func)

  for k in _known_constants_mpmath:
      setattr(MpmathPrinter, f'_print_{k}', _print_known_const)
  class SymPyPrinter(AbstractPythonCodePrinter):
      """Printer whose output refers to functions by their defining module."""

      language = "Python with SymPy"

      _default_settings = dict(
          AbstractPythonCodePrinter._default_settings,
          # any class name will per definition be what we target in
          # SymPyPrinter.
          strict=False
      )

      def _print_Function(self, expr):
          """Print a call qualified with the module defining the function.

          The name is ``<module>.<function name>``; the module part and the
          dot are left out when ``__module__`` is empty or ``None``.
          """
          module_name = expr.func.__module__ or ''
          prefix = module_name + '.' if module_name else ''
          callee = self._module_format(prefix + expr.func.__name__)
          printed_args = ', '.join(self._print(arg) for arg in expr.args)
          return '%s(%s)' % (callee, printed_args)

      def _print_Pow(self, expr, rational=False):
          """Delegate to ``_hprint_Pow`` using ``sympy.sqrt`` for roots."""
          return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')