codeprinter.py 41 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039
  1. from __future__ import annotations
  2. from typing import Any
  3. from functools import wraps
  4. from sympy.core import Add, Mul, Pow, S, sympify, Float
  5. from sympy.core.basic import Basic
  6. from sympy.core.expr import Expr, UnevaluatedExpr
  7. from sympy.core.function import Lambda
  8. from sympy.core.mul import _keep_coeff
  9. from sympy.core.sorting import default_sort_key
  10. from sympy.core.symbol import Symbol
  11. from sympy.functions.elementary.complexes import re
  12. from sympy.printing.str import StrPrinter
  13. from sympy.printing.precedence import precedence, PRECEDENCE
  14. class requires:
  15. """ Decorator for registering requirements on print methods. """
  16. def __init__(self, **kwargs):
  17. self._req = kwargs
  18. def __call__(self, method):
  19. def _method_wrapper(self_, *args, **kwargs):
  20. for k, v in self._req.items():
  21. getattr(self_, k).update(v)
  22. return method(self_, *args, **kwargs)
  23. return wraps(method)(_method_wrapper)
  24. class AssignmentError(Exception):
  25. """
  26. Raised if an assignment variable for a loop is missing.
  27. """
  28. pass
  29. class PrintMethodNotImplementedError(NotImplementedError):
  30. """
  31. Raised if a _print_* method is missing in the Printer.
  32. """
  33. pass
  34. def _convert_python_lists(arg):
  35. if isinstance(arg, list):
  36. from sympy.codegen.abstract_nodes import List
  37. return List(*(_convert_python_lists(e) for e in arg))
  38. elif isinstance(arg, tuple):
  39. return tuple(_convert_python_lists(e) for e in arg)
  40. else:
  41. return arg
  42. class CodePrinter(StrPrinter):
  43. """
  44. The base class for code-printing subclasses.
  45. """
  46. _operators = {
  47. 'and': '&&',
  48. 'or': '||',
  49. 'not': '!',
  50. }
  51. _default_settings: dict[str, Any] = {
  52. 'order': None,
  53. 'full_prec': 'auto',
  54. 'error_on_reserved': False,
  55. 'reserved_word_suffix': '_',
  56. 'human': True,
  57. 'inline': False,
  58. 'allow_unknown_functions': False,
  59. 'strict': None # True or False; None => True if human == True
  60. }
  61. # Functions which are "simple" to rewrite to other functions that
  62. # may be supported
  63. # function_to_rewrite : (function_to_rewrite_to, iterable_with_other_functions_required)
  64. _rewriteable_functions = {
  65. 'cot': ('tan', []),
  66. 'csc': ('sin', []),
  67. 'sec': ('cos', []),
  68. 'acot': ('atan', []),
  69. 'acsc': ('asin', []),
  70. 'asec': ('acos', []),
  71. 'coth': ('exp', []),
  72. 'csch': ('exp', []),
  73. 'sech': ('exp', []),
  74. 'acoth': ('log', []),
  75. 'acsch': ('log', []),
  76. 'asech': ('log', []),
  77. 'catalan': ('gamma', []),
  78. 'fibonacci': ('sqrt', []),
  79. 'lucas': ('sqrt', []),
  80. 'beta': ('gamma', []),
  81. 'sinc': ('sin', ['Piecewise']),
  82. 'Mod': ('floor', []),
  83. 'factorial': ('gamma', []),
  84. 'factorial2': ('gamma', ['Piecewise']),
  85. 'subfactorial': ('uppergamma', []),
  86. 'RisingFactorial': ('gamma', ['Piecewise']),
  87. 'FallingFactorial': ('gamma', ['Piecewise']),
  88. 'binomial': ('gamma', []),
  89. 'frac': ('floor', []),
  90. 'Max': ('Piecewise', []),
  91. 'Min': ('Piecewise', []),
  92. 'Heaviside': ('Piecewise', []),
  93. 'erf2': ('erf', []),
  94. 'erfc': ('erf', []),
  95. 'Li': ('li', []),
  96. 'Ei': ('li', []),
  97. 'dirichlet_eta': ('zeta', []),
  98. 'riemann_xi': ('zeta', ['gamma']),
  99. 'SingularityFunction': ('Piecewise', []),
  100. }
  101. def __init__(self, settings=None):
  102. super().__init__(settings=settings)
  103. if self._settings.get('strict', True) == None:
  104. # for backwards compatibility, human=False need not to throw:
  105. self._settings['strict'] = self._settings.get('human', True) == True
  106. if not hasattr(self, 'reserved_words'):
  107. self.reserved_words = set()
  108. def _handle_UnevaluatedExpr(self, expr):
  109. return expr.replace(re, lambda arg: arg if isinstance(
  110. arg, UnevaluatedExpr) and arg.args[0].is_real else re(arg))
  111. def doprint(self, expr, assign_to=None):
  112. """
  113. Print the expression as code.
  114. Parameters
  115. ----------
  116. expr : Expression
  117. The expression to be printed.
  118. assign_to : Symbol, string, MatrixSymbol, list of strings or Symbols (optional)
  119. If provided, the printed code will set the expression to a variable or multiple variables
  120. with the name or names given in ``assign_to``.
  121. """
  122. from sympy.matrices.expressions.matexpr import MatrixSymbol
  123. from sympy.codegen.ast import CodeBlock, Assignment
  124. def _handle_assign_to(expr, assign_to):
  125. if assign_to is None:
  126. return sympify(expr)
  127. if isinstance(assign_to, (list, tuple)):
  128. if len(expr) != len(assign_to):
  129. raise ValueError('Failed to assign an expression of length {} to {} variables'.format(len(expr), len(assign_to)))
  130. return CodeBlock(*[_handle_assign_to(lhs, rhs) for lhs, rhs in zip(expr, assign_to)])
  131. if isinstance(assign_to, str):
  132. if expr.is_Matrix:
  133. assign_to = MatrixSymbol(assign_to, *expr.shape)
  134. else:
  135. assign_to = Symbol(assign_to)
  136. elif not isinstance(assign_to, Basic):
  137. raise TypeError("{} cannot assign to object of type {}".format(
  138. type(self).__name__, type(assign_to)))
  139. return Assignment(assign_to, expr)
  140. expr = _convert_python_lists(expr)
  141. expr = _handle_assign_to(expr, assign_to)
  142. # Remove re(...) nodes due to UnevaluatedExpr.is_real always is None:
  143. expr = self._handle_UnevaluatedExpr(expr)
  144. # keep a set of expressions that are not strictly translatable to Code
  145. # and number constants that must be declared and initialized
  146. self._not_supported = set()
  147. self._number_symbols = set()
  148. lines = self._print(expr).splitlines()
  149. # format the output
  150. if self._settings["human"]:
  151. frontlines = []
  152. if self._not_supported:
  153. frontlines.append(self._get_comment(
  154. "Not supported in {}:".format(self.language)))
  155. for expr in sorted(self._not_supported, key=str):
  156. frontlines.append(self._get_comment(type(expr).__name__))
  157. for name, value in sorted(self._number_symbols, key=str):
  158. frontlines.append(self._declare_number_const(name, value))
  159. lines = frontlines + lines
  160. lines = self._format_code(lines)
  161. result = "\n".join(lines)
  162. else:
  163. lines = self._format_code(lines)
  164. num_syms = {(k, self._print(v)) for k, v in self._number_symbols}
  165. result = (num_syms, self._not_supported, "\n".join(lines))
  166. self._not_supported = set()
  167. self._number_symbols = set()
  168. return result
  169. def _doprint_loops(self, expr, assign_to=None):
  170. # Here we print an expression that contains Indexed objects, they
  171. # correspond to arrays in the generated code. The low-level implementation
  172. # involves looping over array elements and possibly storing results in temporary
  173. # variables or accumulate it in the assign_to object.
  174. if self._settings.get('contract', True):
  175. from sympy.tensor import get_contraction_structure
  176. # Setup loops over non-dummy indices -- all terms need these
  177. indices = self._get_expression_indices(expr, assign_to)
  178. # Setup loops over dummy indices -- each term needs separate treatment
  179. dummies = get_contraction_structure(expr)
  180. else:
  181. indices = []
  182. dummies = {None: (expr,)}
  183. openloop, closeloop = self._get_loop_opening_ending(indices)
  184. # terms with no summations first
  185. if None in dummies:
  186. text = StrPrinter.doprint(self, Add(*dummies[None]))
  187. else:
  188. # If all terms have summations we must initialize array to Zero
  189. text = StrPrinter.doprint(self, 0)
  190. # skip redundant assignments (where lhs == rhs)
  191. lhs_printed = self._print(assign_to)
  192. lines = []
  193. if text != lhs_printed:
  194. lines.extend(openloop)
  195. if assign_to is not None:
  196. text = self._get_statement("%s = %s" % (lhs_printed, text))
  197. lines.append(text)
  198. lines.extend(closeloop)
  199. # then terms with summations
  200. for d in dummies:
  201. if isinstance(d, tuple):
  202. indices = self._sort_optimized(d, expr)
  203. openloop_d, closeloop_d = self._get_loop_opening_ending(
  204. indices)
  205. for term in dummies[d]:
  206. if term in dummies and not ([list(f.keys()) for f in dummies[term]]
  207. == [[None] for f in dummies[term]]):
  208. # If one factor in the term has it's own internal
  209. # contractions, those must be computed first.
  210. # (temporary variables?)
  211. raise NotImplementedError(
  212. "FIXME: no support for contractions in factor yet")
  213. else:
  214. # We need the lhs expression as an accumulator for
  215. # the loops, i.e
  216. #
  217. # for (int d=0; d < dim; d++){
  218. # lhs[] = lhs[] + term[][d]
  219. # } ^.................. the accumulator
  220. #
  221. # We check if the expression already contains the
  222. # lhs, and raise an exception if it does, as that
  223. # syntax is currently undefined. FIXME: What would be
  224. # a good interpretation?
  225. if assign_to is None:
  226. raise AssignmentError(
  227. "need assignment variable for loops")
  228. if term.has(assign_to):
  229. raise ValueError("FIXME: lhs present in rhs,\
  230. this is undefined in CodePrinter")
  231. lines.extend(openloop)
  232. lines.extend(openloop_d)
  233. text = "%s = %s" % (lhs_printed, StrPrinter.doprint(
  234. self, assign_to + term))
  235. lines.append(self._get_statement(text))
  236. lines.extend(closeloop_d)
  237. lines.extend(closeloop)
  238. return "\n".join(lines)
  239. def _get_expression_indices(self, expr, assign_to):
  240. from sympy.tensor import get_indices
  241. rinds, junk = get_indices(expr)
  242. linds, junk = get_indices(assign_to)
  243. # support broadcast of scalar
  244. if linds and not rinds:
  245. rinds = linds
  246. if rinds != linds:
  247. raise ValueError("lhs indices must match non-dummy"
  248. " rhs indices in %s" % expr)
  249. return self._sort_optimized(rinds, assign_to)
  250. def _sort_optimized(self, indices, expr):
  251. from sympy.tensor.indexed import Indexed
  252. if not indices:
  253. return []
  254. # determine optimized loop order by giving a score to each index
  255. # the index with the highest score are put in the innermost loop.
  256. score_table = {}
  257. for i in indices:
  258. score_table[i] = 0
  259. arrays = expr.atoms(Indexed)
  260. for arr in arrays:
  261. for p, ind in enumerate(arr.indices):
  262. try:
  263. score_table[ind] += self._rate_index_position(p)
  264. except KeyError:
  265. pass
  266. return sorted(indices, key=lambda x: score_table[x])
  267. def _rate_index_position(self, p):
  268. """function to calculate score based on position among indices
  269. This method is used to sort loops in an optimized order, see
  270. CodePrinter._sort_optimized()
  271. """
  272. raise NotImplementedError("This function must be implemented by "
  273. "subclass of CodePrinter.")
  274. def _get_statement(self, codestring):
  275. """Formats a codestring with the proper line ending."""
  276. raise NotImplementedError("This function must be implemented by "
  277. "subclass of CodePrinter.")
  278. def _get_comment(self, text):
  279. """Formats a text string as a comment."""
  280. raise NotImplementedError("This function must be implemented by "
  281. "subclass of CodePrinter.")
  282. def _declare_number_const(self, name, value):
  283. """Declare a numeric constant at the top of a function"""
  284. raise NotImplementedError("This function must be implemented by "
  285. "subclass of CodePrinter.")
  286. def _format_code(self, lines):
  287. """Take in a list of lines of code, and format them accordingly.
  288. This may include indenting, wrapping long lines, etc..."""
  289. raise NotImplementedError("This function must be implemented by "
  290. "subclass of CodePrinter.")
  291. def _get_loop_opening_ending(self, indices):
  292. """Returns a tuple (open_lines, close_lines) containing lists
  293. of codelines"""
  294. raise NotImplementedError("This function must be implemented by "
  295. "subclass of CodePrinter.")
  296. def _print_Dummy(self, expr):
  297. if expr.name.startswith('Dummy_'):
  298. return '_' + expr.name
  299. else:
  300. return '%s_%d' % (expr.name, expr.dummy_index)
  301. def _print_Idx(self, expr):
  302. return self._print(expr.label)
  303. def _print_CodeBlock(self, expr):
  304. return '\n'.join([self._print(i) for i in expr.args])
  305. def _print_String(self, string):
  306. return str(string)
  307. def _print_QuotedString(self, arg):
  308. return '"%s"' % arg.text
  309. def _print_Comment(self, string):
  310. return self._get_comment(str(string))
  311. def _print_Assignment(self, expr):
  312. from sympy.codegen.ast import Assignment
  313. from sympy.functions.elementary.piecewise import Piecewise
  314. from sympy.matrices.expressions.matexpr import MatrixSymbol
  315. from sympy.tensor.indexed import IndexedBase
  316. lhs = expr.lhs
  317. rhs = expr.rhs
  318. # We special case assignments that take multiple lines
  319. if isinstance(expr.rhs, Piecewise):
  320. # Here we modify Piecewise so each expression is now
  321. # an Assignment, and then continue on the print.
  322. expressions = []
  323. conditions = []
  324. for (e, c) in rhs.args:
  325. expressions.append(Assignment(lhs, e))
  326. conditions.append(c)
  327. temp = Piecewise(*zip(expressions, conditions))
  328. return self._print(temp)
  329. elif isinstance(lhs, MatrixSymbol):
  330. # Here we form an Assignment for each element in the array,
  331. # printing each one.
  332. lines = []
  333. for (i, j) in self._traverse_matrix_indices(lhs):
  334. temp = Assignment(lhs[i, j], rhs[i, j])
  335. code0 = self._print(temp)
  336. lines.append(code0)
  337. return "\n".join(lines)
  338. elif self._settings.get("contract", False) and (lhs.has(IndexedBase) or
  339. rhs.has(IndexedBase)):
  340. # Here we check if there is looping to be done, and if so
  341. # print the required loops.
  342. return self._doprint_loops(rhs, lhs)
  343. else:
  344. lhs_code = self._print(lhs)
  345. rhs_code = self._print(rhs)
  346. return self._get_statement("%s = %s" % (lhs_code, rhs_code))
  347. def _print_AugmentedAssignment(self, expr):
  348. lhs_code = self._print(expr.lhs)
  349. rhs_code = self._print(expr.rhs)
  350. return self._get_statement("{} {} {}".format(
  351. *(self._print(arg) for arg in [lhs_code, expr.op, rhs_code])))
  352. def _print_FunctionCall(self, expr):
  353. return '%s(%s)' % (
  354. expr.name,
  355. ', '.join((self._print(arg) for arg in expr.function_args)))
  356. def _print_Variable(self, expr):
  357. return self._print(expr.symbol)
  358. def _print_Symbol(self, expr):
  359. name = super()._print_Symbol(expr)
  360. if name in self.reserved_words:
  361. if self._settings['error_on_reserved']:
  362. msg = ('This expression includes the symbol "{}" which is a '
  363. 'reserved keyword in this language.')
  364. raise ValueError(msg.format(name))
  365. return name + self._settings['reserved_word_suffix']
  366. else:
  367. return name
  368. def _can_print(self, name):
  369. """ Check if function ``name`` is either a known function or has its own
  370. printing method. Used to check if rewriting is possible."""
  371. return name in self.known_functions or getattr(self, '_print_{}'.format(name), False)
  372. def _print_Function(self, expr):
  373. if expr.func.__name__ in self.known_functions:
  374. cond_func = self.known_functions[expr.func.__name__]
  375. if isinstance(cond_func, str):
  376. return "%s(%s)" % (cond_func, self.stringify(expr.args, ", "))
  377. else:
  378. for cond, func in cond_func:
  379. if cond(*expr.args):
  380. break
  381. if func is not None:
  382. try:
  383. return func(*[self.parenthesize(item, 0) for item in expr.args])
  384. except TypeError:
  385. return "%s(%s)" % (func, self.stringify(expr.args, ", "))
  386. elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
  387. # inlined function
  388. return self._print(expr._imp_(*expr.args))
  389. elif expr.func.__name__ in self._rewriteable_functions:
  390. # Simple rewrite to supported function possible
  391. target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
  392. if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
  393. return '(' + self._print(expr.rewrite(target_f)) + ')'
  394. if expr.is_Function and self._settings.get('allow_unknown_functions', False):
  395. return '%s(%s)' % (self._print(expr.func), ', '.join(map(self._print, expr.args)))
  396. else:
  397. return self._print_not_supported(expr)
  398. _print_Expr = _print_Function
  399. def _print_Derivative(self, expr):
  400. obj, *wrt_order_pairs = expr.args
  401. for func_arg in obj.args:
  402. if not func_arg.is_Symbol:
  403. raise ValueError("%s._print_Derivative(...) only supports functions with symbols as arguments." %
  404. self.__class__.__name__)
  405. meth_name = '_print_Derivative_%s' % obj.func.__name__
  406. pmeth = getattr(self, meth_name, None)
  407. if pmeth is None:
  408. if self._settings.get('strict', False):
  409. raise PrintMethodNotImplementedError(
  410. f"Unsupported by {type(self)}: {type(expr)}" +
  411. f"\nPrinter has no method: {meth_name}" +
  412. "\nSet the printer option 'strict' to False in order to generate partially printed code."
  413. )
  414. return self._print_not_supported(expr)
  415. orders = dict(wrt_order_pairs)
  416. seq_orders = [orders[arg] for arg in obj.args]
  417. return pmeth(obj.args, seq_orders)
  418. # Don't inherit the str-printer method for Heaviside to the code printers
  419. _print_Heaviside = None
  420. def _print_NumberSymbol(self, expr):
  421. if self._settings.get("inline", False):
  422. return self._print(Float(expr.evalf(self._settings["precision"])))
  423. else:
  424. # A Number symbol that is not implemented here or with _printmethod
  425. # is registered and evaluated
  426. self._number_symbols.add((expr,
  427. Float(expr.evalf(self._settings["precision"]))))
  428. return str(expr)
  429. def _print_Catalan(self, expr):
  430. return self._print_NumberSymbol(expr)
  431. def _print_EulerGamma(self, expr):
  432. return self._print_NumberSymbol(expr)
  433. def _print_GoldenRatio(self, expr):
  434. return self._print_NumberSymbol(expr)
  435. def _print_TribonacciConstant(self, expr):
  436. return self._print_NumberSymbol(expr)
  437. def _print_Exp1(self, expr):
  438. return self._print_NumberSymbol(expr)
  439. def _print_Pi(self, expr):
  440. return self._print_NumberSymbol(expr)
  441. def _print_And(self, expr):
  442. PREC = precedence(expr)
  443. return (" %s " % self._operators['and']).join(self.parenthesize(a, PREC)
  444. for a in sorted(expr.args, key=default_sort_key))
  445. def _print_Or(self, expr):
  446. PREC = precedence(expr)
  447. return (" %s " % self._operators['or']).join(self.parenthesize(a, PREC)
  448. for a in sorted(expr.args, key=default_sort_key))
  449. def _print_Xor(self, expr):
  450. if self._operators.get('xor') is None:
  451. return self._print(expr.to_nnf())
  452. PREC = precedence(expr)
  453. return (" %s " % self._operators['xor']).join(self.parenthesize(a, PREC)
  454. for a in expr.args)
  455. def _print_Equivalent(self, expr):
  456. if self._operators.get('equivalent') is None:
  457. return self._print(expr.to_nnf())
  458. PREC = precedence(expr)
  459. return (" %s " % self._operators['equivalent']).join(self.parenthesize(a, PREC)
  460. for a in expr.args)
  461. def _print_Not(self, expr):
  462. PREC = precedence(expr)
  463. return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
  464. def _print_BooleanFunction(self, expr):
  465. return self._print(expr.to_nnf())
  466. def _print_isnan(self, arg):
  467. return 'isnan(%s)' % self._print(*arg.args)
  468. def _print_isinf(self, arg):
  469. return 'isinf(%s)' % self._print(*arg.args)
  470. def _print_Mul(self, expr):
  471. prec = precedence(expr)
  472. c, e = expr.as_coeff_Mul()
  473. if c < 0:
  474. expr = _keep_coeff(-c, e)
  475. sign = "-"
  476. else:
  477. sign = ""
  478. a = [] # items in the numerator
  479. b = [] # items that are in the denominator (if any)
  480. pow_paren = [] # Will collect all pow with more than one base element and exp = -1
  481. if self.order not in ('old', 'none'):
  482. args = expr.as_ordered_factors()
  483. else:
  484. # use make_args in case expr was something like -x -> x
  485. args = Mul.make_args(expr)
  486. # Gather args for numerator/denominator
  487. for item in args:
  488. if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative:
  489. if item.exp != -1:
  490. b.append(Pow(item.base, -item.exp, evaluate=False))
  491. else:
  492. if len(item.args[0].args) != 1 and isinstance(item.base, Mul): # To avoid situations like #14160
  493. pow_paren.append(item)
  494. b.append(Pow(item.base, -item.exp))
  495. else:
  496. a.append(item)
  497. a = a or [S.One]
  498. if len(a) == 1 and sign == "-":
  499. # Unary minus does not have a SymPy class, and hence there's no
  500. # precedence weight associated with it, Python's unary minus has
  501. # an operator precedence between multiplication and exponentiation,
  502. # so we use this to compute a weight.
  503. a_str = [self.parenthesize(a[0], 0.5*(PRECEDENCE["Pow"]+PRECEDENCE["Mul"]))]
  504. else:
  505. a_str = [self.parenthesize(x, prec) for x in a]
  506. b_str = [self.parenthesize(x, prec) for x in b]
  507. # To parenthesize Pow with exp = -1 and having more than one Symbol
  508. for item in pow_paren:
  509. if item.base in b:
  510. b_str[b.index(item.base)] = "(%s)" % b_str[b.index(item.base)]
  511. if not b:
  512. return sign + '*'.join(a_str)
  513. elif len(b) == 1:
  514. return sign + '*'.join(a_str) + "/" + b_str[0]
  515. else:
  516. return sign + '*'.join(a_str) + "/(%s)" % '*'.join(b_str)
  517. def _print_not_supported(self, expr):
  518. if self._settings.get('strict', False):
  519. raise PrintMethodNotImplementedError(
  520. f"Unsupported by {type(self)}: {type(expr)}" +
  521. "\nSet the printer option 'strict' to False in order to generate partially printed code."
  522. )
  523. try:
  524. self._not_supported.add(expr)
  525. except TypeError:
  526. # not hashable
  527. pass
  528. return self.emptyPrinter(expr)
  529. # The following can not be simply translated into C or Fortran
  530. _print_Basic = _print_not_supported
  531. _print_ComplexInfinity = _print_not_supported
  532. _print_ExprCondPair = _print_not_supported
  533. _print_GeometryEntity = _print_not_supported
  534. _print_Infinity = _print_not_supported
  535. _print_Integral = _print_not_supported
  536. _print_Interval = _print_not_supported
  537. _print_AccumulationBounds = _print_not_supported
  538. _print_Limit = _print_not_supported
  539. _print_MatrixBase = _print_not_supported
  540. _print_DeferredVector = _print_not_supported
  541. _print_NaN = _print_not_supported
  542. _print_NegativeInfinity = _print_not_supported
  543. _print_Order = _print_not_supported
  544. _print_RootOf = _print_not_supported
  545. _print_RootsOf = _print_not_supported
  546. _print_RootSum = _print_not_supported
  547. _print_Uniform = _print_not_supported
  548. _print_Unit = _print_not_supported
  549. _print_Wild = _print_not_supported
  550. _print_WildFunction = _print_not_supported
  551. _print_Relational = _print_not_supported
  552. # Code printer functions. These are included in this file so that they can be
  553. # imported in the top-level __init__.py without importing the sympy.codegen
  554. # module.
  555. def ccode(expr, assign_to=None, standard='c99', **settings):
  556. """Converts an expr to a string of c code
  557. Parameters
  558. ==========
  559. expr : Expr
  560. A SymPy expression to be converted.
  561. assign_to : optional
  562. When given, the argument is used as the name of the variable to which
  563. the expression is assigned. Can be a string, ``Symbol``,
  564. ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
  565. line-wrapping, or for expressions that generate multi-line statements.
  566. standard : str, optional
  567. String specifying the standard. If your compiler supports a more modern
  568. standard you may set this to 'c99' to allow the printer to use more math
  569. functions. [default='c89'].
  570. precision : integer, optional
  571. The precision for numbers such as pi [default=17].
  572. user_functions : dict, optional
  573. A dictionary where the keys are string representations of either
  574. ``FunctionClass`` or ``UndefinedFunction`` instances and the values
  575. are their desired C string representations. Alternatively, the
  576. dictionary value can be a list of tuples i.e. [(argument_test,
  577. cfunction_string)] or [(argument_test, cfunction_formater)]. See below
  578. for examples.
  579. dereference : iterable, optional
  580. An iterable of symbols that should be dereferenced in the printed code
  581. expression. These would be values passed by address to the function.
  582. For example, if ``dereference=[a]``, the resulting code would print
  583. ``(*a)`` instead of ``a``.
  584. human : bool, optional
  585. If True, the result is a single string that may contain some constant
  586. declarations for the number symbols. If False, the same information is
  587. returned in a tuple of (symbols_to_declare, not_supported_functions,
  588. code_text). [default=True].
  589. contract: bool, optional
  590. If True, ``Indexed`` instances are assumed to obey tensor contraction
  591. rules and the corresponding nested loops over indices are generated.
  592. Setting contract=False will not generate loops, instead the user is
  593. responsible to provide values for the indices in the code.
  594. [default=True].
  595. Examples
  596. ========
  597. >>> from sympy import ccode, symbols, Rational, sin, ceiling, Abs, Function
  598. >>> x, tau = symbols("x, tau")
  599. >>> expr = (2*tau)**Rational(7, 2)
  600. >>> ccode(expr)
  601. '8*M_SQRT2*pow(tau, 7.0/2.0)'
  602. >>> ccode(expr, math_macros={})
  603. '8*sqrt(2)*pow(tau, 7.0/2.0)'
  604. >>> ccode(sin(x), assign_to="s")
  605. 's = sin(x);'
  606. >>> from sympy.codegen.ast import real, float80
  607. >>> ccode(expr, type_aliases={real: float80})
  608. '8*M_SQRT2l*powl(tau, 7.0L/2.0L)'
  609. Simple custom printing can be defined for certain types by passing a
  610. dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
  611. Alternatively, the dictionary value can be a list of tuples i.e.
  612. [(argument_test, cfunction_string)].
  613. >>> custom_functions = {
  614. ... "ceiling": "CEIL",
  615. ... "Abs": [(lambda x: not x.is_integer, "fabs"),
  616. ... (lambda x: x.is_integer, "ABS")],
  617. ... "func": "f"
  618. ... }
  619. >>> func = Function('func')
  620. >>> ccode(func(Abs(x) + ceiling(x)), standard='C89', user_functions=custom_functions)
  621. 'f(fabs(x) + CEIL(x))'
  622. or if the C-function takes a subset of the original arguments:
  623. >>> ccode(2**x + 3**x, standard='C99', user_functions={'Pow': [
  624. ... (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e),
  625. ... (lambda b, e: b != 2, 'pow')]})
  626. 'exp2(x) + pow(3, x)'
  627. ``Piecewise`` expressions are converted into conditionals. If an
  628. ``assign_to`` variable is provided an if statement is created, otherwise
  629. the ternary operator is used. Note that if the ``Piecewise`` lacks a
  630. default term, represented by ``(expr, True)`` then an error will be thrown.
  631. This is to prevent generating an expression that may not evaluate to
  632. anything.
  633. >>> from sympy import Piecewise
  634. >>> expr = Piecewise((x + 1, x > 0), (x, True))
  635. >>> print(ccode(expr, tau, standard='C89'))
  636. if (x > 0) {
  637. tau = x + 1;
  638. }
  639. else {
  640. tau = x;
  641. }
  642. Support for loops is provided through ``Indexed`` types. With
  643. ``contract=True`` these expressions will be turned into loops, whereas
  644. ``contract=False`` will just print the assignment expression that should be
  645. looped over:
  646. >>> from sympy import Eq, IndexedBase, Idx
  647. >>> len_y = 5
  648. >>> y = IndexedBase('y', shape=(len_y,))
  649. >>> t = IndexedBase('t', shape=(len_y,))
  650. >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
  651. >>> i = Idx('i', len_y-1)
  652. >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
  653. >>> ccode(e.rhs, assign_to=e.lhs, contract=False, standard='C89')
  654. 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
  655. Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
  656. must be provided to ``assign_to``. Note that any expression that can be
  657. generated normally can also exist inside a Matrix:
  658. >>> from sympy import Matrix, MatrixSymbol
  659. >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
  660. >>> A = MatrixSymbol('A', 3, 1)
  661. >>> print(ccode(mat, A, standard='C89'))
  662. A[0] = pow(x, 2);
  663. if (x > 0) {
  664. A[1] = x + 1;
  665. }
  666. else {
  667. A[1] = x;
  668. }
  669. A[2] = sin(x);
  670. """
  671. from sympy.printing.c import c_code_printers
  672. return c_code_printers[standard.lower()](settings).doprint(expr, assign_to)
  673. def print_ccode(expr, **settings):
  674. """Prints C representation of the given expression."""
  675. print(ccode(expr, **settings))
  676. def fcode(expr, assign_to=None, **settings):
  677. """Converts an expr to a string of fortran code
  678. Parameters
  679. ==========
  680. expr : Expr
  681. A SymPy expression to be converted.
  682. assign_to : optional
  683. When given, the argument is used as the name of the variable to which
  684. the expression is assigned. Can be a string, ``Symbol``,
  685. ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
  686. line-wrapping, or for expressions that generate multi-line statements.
  687. precision : integer, optional
  688. DEPRECATED. Use type_mappings instead. The precision for numbers such
  689. as pi [default=17].
  690. user_functions : dict, optional
  691. A dictionary where keys are ``FunctionClass`` instances and values are
  692. their string representations. Alternatively, the dictionary value can
  693. be a list of tuples i.e. [(argument_test, cfunction_string)]. See below
  694. for examples.
  695. human : bool, optional
  696. If True, the result is a single string that may contain some constant
  697. declarations for the number symbols. If False, the same information is
  698. returned in a tuple of (symbols_to_declare, not_supported_functions,
  699. code_text). [default=True].
  700. contract: bool, optional
  701. If True, ``Indexed`` instances are assumed to obey tensor contraction
  702. rules and the corresponding nested loops over indices are generated.
  703. Setting contract=False will not generate loops, instead the user is
  704. responsible to provide values for the indices in the code.
  705. [default=True].
  706. source_format : optional
  707. The source format can be either 'fixed' or 'free'. [default='fixed']
  708. standard : integer, optional
  709. The Fortran standard to be followed. This is specified as an integer.
  710. Acceptable standards are 66, 77, 90, 95, 2003, and 2008. Default is 77.
  711. Note that currently the only distinction internally is between
  712. standards before 95, and those 95 and after. This may change later as
  713. more features are added.
  714. name_mangling : bool, optional
  715. If True, then the variables that would become identical in
  716. case-insensitive Fortran are mangled by appending different number
  717. of ``_`` at the end. If False, SymPy Will not interfere with naming of
  718. variables. [default=True]
  719. Examples
  720. ========
  721. >>> from sympy import fcode, symbols, Rational, sin, ceiling, floor
  722. >>> x, tau = symbols("x, tau")
  723. >>> fcode((2*tau)**Rational(7, 2))
  724. ' 8*sqrt(2.0d0)*tau**(7.0d0/2.0d0)'
  725. >>> fcode(sin(x), assign_to="s")
  726. ' s = sin(x)'
  727. Custom printing can be defined for certain types by passing a dictionary of
  728. "type" : "function" to the ``user_functions`` kwarg. Alternatively, the
  729. dictionary value can be a list of tuples i.e. [(argument_test,
  730. cfunction_string)].
  731. >>> custom_functions = {
  732. ... "ceiling": "CEIL",
  733. ... "floor": [(lambda x: not x.is_integer, "FLOOR1"),
  734. ... (lambda x: x.is_integer, "FLOOR2")]
  735. ... }
  736. >>> fcode(floor(x) + ceiling(x), user_functions=custom_functions)
  737. ' CEIL(x) + FLOOR1(x)'
  738. ``Piecewise`` expressions are converted into conditionals. If an
  739. ``assign_to`` variable is provided an if statement is created, otherwise
  740. the ternary operator is used. Note that if the ``Piecewise`` lacks a
  741. default term, represented by ``(expr, True)`` then an error will be thrown.
  742. This is to prevent generating an expression that may not evaluate to
  743. anything.
  744. >>> from sympy import Piecewise
  745. >>> expr = Piecewise((x + 1, x > 0), (x, True))
  746. >>> print(fcode(expr, tau))
  747. if (x > 0) then
  748. tau = x + 1
  749. else
  750. tau = x
  751. end if
  752. Support for loops is provided through ``Indexed`` types. With
  753. ``contract=True`` these expressions will be turned into loops, whereas
  754. ``contract=False`` will just print the assignment expression that should be
  755. looped over:
  756. >>> from sympy import Eq, IndexedBase, Idx
  757. >>> len_y = 5
  758. >>> y = IndexedBase('y', shape=(len_y,))
  759. >>> t = IndexedBase('t', shape=(len_y,))
  760. >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
  761. >>> i = Idx('i', len_y-1)
  762. >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
  763. >>> fcode(e.rhs, assign_to=e.lhs, contract=False)
  764. ' Dy(i) = (y(i + 1) - y(i))/(t(i + 1) - t(i))'
  765. Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
  766. must be provided to ``assign_to``. Note that any expression that can be
  767. generated normally can also exist inside a Matrix:
  768. >>> from sympy import Matrix, MatrixSymbol
  769. >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
  770. >>> A = MatrixSymbol('A', 3, 1)
  771. >>> print(fcode(mat, A))
  772. A(1, 1) = x**2
  773. if (x > 0) then
  774. A(2, 1) = x + 1
  775. else
  776. A(2, 1) = x
  777. end if
  778. A(3, 1) = sin(x)
  779. """
  780. from sympy.printing.fortran import FCodePrinter
  781. return FCodePrinter(settings).doprint(expr, assign_to)
  782. def print_fcode(expr, **settings):
  783. """Prints the Fortran representation of the given expression.
  784. See fcode for the meaning of the optional arguments.
  785. """
  786. print(fcode(expr, **settings))
  787. def cxxcode(expr, assign_to=None, standard='c++11', **settings):
  788. """ C++ equivalent of :func:`~.ccode`. """
  789. from sympy.printing.cxx import cxx_code_printers
  790. return cxx_code_printers[standard.lower()](settings).doprint(expr, assign_to)
  791. def rust_code(expr, assign_to=None, **settings):
  792. """Converts an expr to a string of Rust code
  793. Parameters
  794. ==========
  795. expr : Expr
  796. A SymPy expression to be converted.
  797. assign_to : optional
  798. When given, the argument is used as the name of the variable to which
  799. the expression is assigned. Can be a string, ``Symbol``,
  800. ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
  801. line-wrapping, or for expressions that generate multi-line statements.
  802. precision : integer, optional
  803. The precision for numbers such as pi [default=15].
  804. user_functions : dict, optional
  805. A dictionary where the keys are string representations of either
  806. ``FunctionClass`` or ``UndefinedFunction`` instances and the values
  807. are their desired C string representations. Alternatively, the
  808. dictionary value can be a list of tuples i.e. [(argument_test,
  809. cfunction_string)]. See below for examples.
  810. dereference : iterable, optional
  811. An iterable of symbols that should be dereferenced in the printed code
  812. expression. These would be values passed by address to the function.
  813. For example, if ``dereference=[a]``, the resulting code would print
  814. ``(*a)`` instead of ``a``.
  815. human : bool, optional
  816. If True, the result is a single string that may contain some constant
  817. declarations for the number symbols. If False, the same information is
  818. returned in a tuple of (symbols_to_declare, not_supported_functions,
  819. code_text). [default=True].
  820. contract: bool, optional
  821. If True, ``Indexed`` instances are assumed to obey tensor contraction
  822. rules and the corresponding nested loops over indices are generated.
  823. Setting contract=False will not generate loops, instead the user is
  824. responsible to provide values for the indices in the code.
  825. [default=True].
  826. Examples
  827. ========
  828. >>> from sympy import rust_code, symbols, Rational, sin, ceiling, Abs, Function
  829. >>> x, tau = symbols("x, tau")
  830. >>> rust_code((2*tau)**Rational(7, 2))
  831. '8.0*1.4142135623731*tau.powf(7_f64/2.0)'
  832. >>> rust_code(sin(x), assign_to="s")
  833. 's = x.sin();'
  834. Simple custom printing can be defined for certain types by passing a
  835. dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
  836. Alternatively, the dictionary value can be a list of tuples i.e.
  837. [(argument_test, cfunction_string)].
  838. >>> custom_functions = {
  839. ... "ceiling": "CEIL",
  840. ... "Abs": [(lambda x: not x.is_integer, "fabs", 4),
  841. ... (lambda x: x.is_integer, "ABS", 4)],
  842. ... "func": "f"
  843. ... }
  844. >>> func = Function('func')
  845. >>> rust_code(func(Abs(x) + ceiling(x)), user_functions=custom_functions)
  846. '(fabs(x) + x.ceil()).f()'
  847. ``Piecewise`` expressions are converted into conditionals. If an
  848. ``assign_to`` variable is provided an if statement is created, otherwise
  849. the ternary operator is used. Note that if the ``Piecewise`` lacks a
  850. default term, represented by ``(expr, True)`` then an error will be thrown.
  851. This is to prevent generating an expression that may not evaluate to
  852. anything.
  853. >>> from sympy import Piecewise
  854. >>> expr = Piecewise((x + 1, x > 0), (x, True))
  855. >>> print(rust_code(expr, tau))
  856. tau = if (x > 0.0) {
  857. x + 1
  858. } else {
  859. x
  860. };
  861. Support for loops is provided through ``Indexed`` types. With
  862. ``contract=True`` these expressions will be turned into loops, whereas
  863. ``contract=False`` will just print the assignment expression that should be
  864. looped over:
  865. >>> from sympy import Eq, IndexedBase, Idx
  866. >>> len_y = 5
  867. >>> y = IndexedBase('y', shape=(len_y,))
  868. >>> t = IndexedBase('t', shape=(len_y,))
  869. >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
  870. >>> i = Idx('i', len_y-1)
  871. >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
  872. >>> rust_code(e.rhs, assign_to=e.lhs, contract=False)
  873. 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
  874. Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
  875. must be provided to ``assign_to``. Note that any expression that can be
  876. generated normally can also exist inside a Matrix:
  877. >>> from sympy import Matrix, MatrixSymbol
  878. >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
  879. >>> A = MatrixSymbol('A', 3, 1)
  880. >>> print(rust_code(mat, A))
  881. A = [x.powi(2), if (x > 0.0) {
  882. x + 1
  883. } else {
  884. x
  885. }, x.sin()];
  886. """
  887. from sympy.printing.rust import RustCodePrinter
  888. printer = RustCodePrinter(settings)
  889. expr = printer._rewrite_known_functions(expr)
  890. if isinstance(expr, Expr):
  891. for src_func, dst_func in printer.function_overrides.values():
  892. expr = expr.replace(src_func, dst_func)
  893. return printer.doprint(expr, assign_to)
  894. def print_rust_code(expr, **settings):
  895. """Prints Rust representation of the given expression."""
  896. print(rust_code(expr, **settings))