mathematica.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353
  1. """
  2. Mathematica code printer
  3. """
  4. from __future__ import annotations
  5. from typing import Any
  6. from sympy.core import Basic, Expr, Float
  7. from sympy.core.sorting import default_sort_key
  8. from sympy.printing.codeprinter import CodePrinter
  9. from sympy.printing.precedence import precedence
  10. # Used in MCodePrinter._print_Function(self)
  11. known_functions = {
  12. "exp": [(lambda x: True, "Exp")],
  13. "log": [(lambda x: True, "Log")],
  14. "sin": [(lambda x: True, "Sin")],
  15. "cos": [(lambda x: True, "Cos")],
  16. "tan": [(lambda x: True, "Tan")],
  17. "cot": [(lambda x: True, "Cot")],
  18. "sec": [(lambda x: True, "Sec")],
  19. "csc": [(lambda x: True, "Csc")],
  20. "asin": [(lambda x: True, "ArcSin")],
  21. "acos": [(lambda x: True, "ArcCos")],
  22. "atan": [(lambda x: True, "ArcTan")],
  23. "acot": [(lambda x: True, "ArcCot")],
  24. "asec": [(lambda x: True, "ArcSec")],
  25. "acsc": [(lambda x: True, "ArcCsc")],
  26. "sinh": [(lambda x: True, "Sinh")],
  27. "cosh": [(lambda x: True, "Cosh")],
  28. "tanh": [(lambda x: True, "Tanh")],
  29. "coth": [(lambda x: True, "Coth")],
  30. "sech": [(lambda x: True, "Sech")],
  31. "csch": [(lambda x: True, "Csch")],
  32. "asinh": [(lambda x: True, "ArcSinh")],
  33. "acosh": [(lambda x: True, "ArcCosh")],
  34. "atanh": [(lambda x: True, "ArcTanh")],
  35. "acoth": [(lambda x: True, "ArcCoth")],
  36. "asech": [(lambda x: True, "ArcSech")],
  37. "acsch": [(lambda x: True, "ArcCsch")],
  38. "sinc": [(lambda x: True, "Sinc")],
  39. "conjugate": [(lambda x: True, "Conjugate")],
  40. "Max": [(lambda *x: True, "Max")],
  41. "Min": [(lambda *x: True, "Min")],
  42. "erf": [(lambda x: True, "Erf")],
  43. "erf2": [(lambda *x: True, "Erf")],
  44. "erfc": [(lambda x: True, "Erfc")],
  45. "erfi": [(lambda x: True, "Erfi")],
  46. "erfinv": [(lambda x: True, "InverseErf")],
  47. "erfcinv": [(lambda x: True, "InverseErfc")],
  48. "erf2inv": [(lambda *x: True, "InverseErf")],
  49. "expint": [(lambda *x: True, "ExpIntegralE")],
  50. "Ei": [(lambda x: True, "ExpIntegralEi")],
  51. "fresnelc": [(lambda x: True, "FresnelC")],
  52. "fresnels": [(lambda x: True, "FresnelS")],
  53. "gamma": [(lambda x: True, "Gamma")],
  54. "uppergamma": [(lambda *x: True, "Gamma")],
  55. "polygamma": [(lambda *x: True, "PolyGamma")],
  56. "loggamma": [(lambda x: True, "LogGamma")],
  57. "beta": [(lambda *x: True, "Beta")],
  58. "Ci": [(lambda x: True, "CosIntegral")],
  59. "Si": [(lambda x: True, "SinIntegral")],
  60. "Chi": [(lambda x: True, "CoshIntegral")],
  61. "Shi": [(lambda x: True, "SinhIntegral")],
  62. "li": [(lambda x: True, "LogIntegral")],
  63. "factorial": [(lambda x: True, "Factorial")],
  64. "factorial2": [(lambda x: True, "Factorial2")],
  65. "subfactorial": [(lambda x: True, "Subfactorial")],
  66. "catalan": [(lambda x: True, "CatalanNumber")],
  67. "harmonic": [(lambda *x: True, "HarmonicNumber")],
  68. "lucas": [(lambda x: True, "LucasL")],
  69. "RisingFactorial": [(lambda *x: True, "Pochhammer")],
  70. "FallingFactorial": [(lambda *x: True, "FactorialPower")],
  71. "laguerre": [(lambda *x: True, "LaguerreL")],
  72. "assoc_laguerre": [(lambda *x: True, "LaguerreL")],
  73. "hermite": [(lambda *x: True, "HermiteH")],
  74. "jacobi": [(lambda *x: True, "JacobiP")],
  75. "gegenbauer": [(lambda *x: True, "GegenbauerC")],
  76. "chebyshevt": [(lambda *x: True, "ChebyshevT")],
  77. "chebyshevu": [(lambda *x: True, "ChebyshevU")],
  78. "legendre": [(lambda *x: True, "LegendreP")],
  79. "assoc_legendre": [(lambda *x: True, "LegendreP")],
  80. "mathieuc": [(lambda *x: True, "MathieuC")],
  81. "mathieus": [(lambda *x: True, "MathieuS")],
  82. "mathieucprime": [(lambda *x: True, "MathieuCPrime")],
  83. "mathieusprime": [(lambda *x: True, "MathieuSPrime")],
  84. "stieltjes": [(lambda x: True, "StieltjesGamma")],
  85. "elliptic_e": [(lambda *x: True, "EllipticE")],
  86. "elliptic_f": [(lambda *x: True, "EllipticE")],
  87. "elliptic_k": [(lambda x: True, "EllipticK")],
  88. "elliptic_pi": [(lambda *x: True, "EllipticPi")],
  89. "zeta": [(lambda *x: True, "Zeta")],
  90. "dirichlet_eta": [(lambda x: True, "DirichletEta")],
  91. "riemann_xi": [(lambda x: True, "RiemannXi")],
  92. "besseli": [(lambda *x: True, "BesselI")],
  93. "besselj": [(lambda *x: True, "BesselJ")],
  94. "besselk": [(lambda *x: True, "BesselK")],
  95. "bessely": [(lambda *x: True, "BesselY")],
  96. "hankel1": [(lambda *x: True, "HankelH1")],
  97. "hankel2": [(lambda *x: True, "HankelH2")],
  98. "airyai": [(lambda x: True, "AiryAi")],
  99. "airybi": [(lambda x: True, "AiryBi")],
  100. "airyaiprime": [(lambda x: True, "AiryAiPrime")],
  101. "airybiprime": [(lambda x: True, "AiryBiPrime")],
  102. "polylog": [(lambda *x: True, "PolyLog")],
  103. "lerchphi": [(lambda *x: True, "LerchPhi")],
  104. "gcd": [(lambda *x: True, "GCD")],
  105. "lcm": [(lambda *x: True, "LCM")],
  106. "jn": [(lambda *x: True, "SphericalBesselJ")],
  107. "yn": [(lambda *x: True, "SphericalBesselY")],
  108. "hyper": [(lambda *x: True, "HypergeometricPFQ")],
  109. "meijerg": [(lambda *x: True, "MeijerG")],
  110. "appellf1": [(lambda *x: True, "AppellF1")],
  111. "DiracDelta": [(lambda x: True, "DiracDelta")],
  112. "Heaviside": [(lambda x: True, "HeavisideTheta")],
  113. "KroneckerDelta": [(lambda *x: True, "KroneckerDelta")],
  114. "sqrt": [(lambda x: True, "Sqrt")], # For automatic rewrites
  115. }
  116. class MCodePrinter(CodePrinter):
  117. """A printer to convert Python expressions to
  118. strings of the Wolfram's Mathematica code
  119. """
  120. printmethod = "_mcode"
  121. language = "Wolfram Language"
  122. _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
  123. 'precision': 15,
  124. 'user_functions': {},
  125. })
  126. _number_symbols: set[tuple[Expr, Float]] = set()
  127. _not_supported: set[Basic] = set()
  128. def __init__(self, settings={}):
  129. """Register function mappings supplied by user"""
  130. CodePrinter.__init__(self, settings)
  131. self.known_functions = dict(known_functions)
  132. userfuncs = settings.get('user_functions', {}).copy()
  133. for k, v in userfuncs.items():
  134. if not isinstance(v, list):
  135. userfuncs[k] = [(lambda *x: True, v)]
  136. self.known_functions.update(userfuncs)
  137. def _format_code(self, lines):
  138. return lines
  139. def _print_Pow(self, expr):
  140. PREC = precedence(expr)
  141. return '%s^%s' % (self.parenthesize(expr.base, PREC),
  142. self.parenthesize(expr.exp, PREC))
  143. def _print_Mul(self, expr):
  144. PREC = precedence(expr)
  145. c, nc = expr.args_cnc()
  146. res = super()._print_Mul(expr.func(*c))
  147. if nc:
  148. res += '*'
  149. res += '**'.join(self.parenthesize(a, PREC) for a in nc)
  150. return res
  151. def _print_Relational(self, expr):
  152. lhs_code = self._print(expr.lhs)
  153. rhs_code = self._print(expr.rhs)
  154. op = expr.rel_op
  155. return "{} {} {}".format(lhs_code, op, rhs_code)
  156. # Primitive numbers
  157. def _print_Zero(self, expr):
  158. return '0'
  159. def _print_One(self, expr):
  160. return '1'
  161. def _print_NegativeOne(self, expr):
  162. return '-1'
  163. def _print_Half(self, expr):
  164. return '1/2'
  165. def _print_ImaginaryUnit(self, expr):
  166. return 'I'
  167. # Infinity and invalid numbers
  168. def _print_Infinity(self, expr):
  169. return 'Infinity'
  170. def _print_NegativeInfinity(self, expr):
  171. return '-Infinity'
  172. def _print_ComplexInfinity(self, expr):
  173. return 'ComplexInfinity'
  174. def _print_NaN(self, expr):
  175. return 'Indeterminate'
  176. # Mathematical constants
  177. def _print_Exp1(self, expr):
  178. return 'E'
  179. def _print_Pi(self, expr):
  180. return 'Pi'
  181. def _print_GoldenRatio(self, expr):
  182. return 'GoldenRatio'
  183. def _print_TribonacciConstant(self, expr):
  184. expanded = expr.expand(func=True)
  185. PREC = precedence(expr)
  186. return self.parenthesize(expanded, PREC)
  187. def _print_EulerGamma(self, expr):
  188. return 'EulerGamma'
  189. def _print_Catalan(self, expr):
  190. return 'Catalan'
  191. def _print_list(self, expr):
  192. return '{' + ', '.join(self.doprint(a) for a in expr) + '}'
  193. _print_tuple = _print_list
  194. _print_Tuple = _print_list
  195. def _print_ImmutableDenseMatrix(self, expr):
  196. return self.doprint(expr.tolist())
  197. def _print_ImmutableSparseMatrix(self, expr):
  198. def print_rule(pos, val):
  199. return '{} -> {}'.format(
  200. self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val))
  201. def print_data():
  202. items = sorted(expr.todok().items(), key=default_sort_key)
  203. return '{' + \
  204. ', '.join(print_rule(k, v) for k, v in items) + \
  205. '}'
  206. def print_dims():
  207. return self.doprint(expr.shape)
  208. return 'SparseArray[{}, {}]'.format(print_data(), print_dims())
  209. def _print_ImmutableDenseNDimArray(self, expr):
  210. return self.doprint(expr.tolist())
  211. def _print_ImmutableSparseNDimArray(self, expr):
  212. def print_string_list(string_list):
  213. return '{' + ', '.join(a for a in string_list) + '}'
  214. def to_mathematica_index(*args):
  215. """Helper function to change Python style indexing to
  216. Pathematica indexing.
  217. Python indexing (0, 1 ... n-1)
  218. -> Mathematica indexing (1, 2 ... n)
  219. """
  220. return tuple(i + 1 for i in args)
  221. def print_rule(pos, val):
  222. """Helper function to print a rule of Mathematica"""
  223. return '{} -> {}'.format(self.doprint(pos), self.doprint(val))
  224. def print_data():
  225. """Helper function to print data part of Mathematica
  226. sparse array.
  227. It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
  228. from
  229. https://reference.wolfram.com/language/ref/SparseArray.html
  230. ``data`` must be formatted with rule.
  231. """
  232. return print_string_list(
  233. [print_rule(
  234. to_mathematica_index(*(expr._get_tuple_index(key))),
  235. value)
  236. for key, value in sorted(expr._sparse_array.items())]
  237. )
  238. def print_dims():
  239. """Helper function to print dimensions part of Mathematica
  240. sparse array.
  241. It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
  242. from
  243. https://reference.wolfram.com/language/ref/SparseArray.html
  244. """
  245. return self.doprint(expr.shape)
  246. return 'SparseArray[{}, {}]'.format(print_data(), print_dims())
  247. def _print_Function(self, expr):
  248. if expr.func.__name__ in self.known_functions:
  249. cond_mfunc = self.known_functions[expr.func.__name__]
  250. for cond, mfunc in cond_mfunc:
  251. if cond(*expr.args):
  252. return "%s[%s]" % (mfunc, self.stringify(expr.args, ", "))
  253. elif expr.func.__name__ in self._rewriteable_functions:
  254. # Simple rewrite to supported function possible
  255. target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
  256. if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
  257. return self._print(expr.rewrite(target_f))
  258. return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ")
  259. _print_MinMaxBase = _print_Function
  260. def _print_LambertW(self, expr):
  261. if len(expr.args) == 1:
  262. return "ProductLog[{}]".format(self._print(expr.args[0]))
  263. return "ProductLog[{}, {}]".format(
  264. self._print(expr.args[1]), self._print(expr.args[0]))
  265. def _print_atan2(self, expr):
  266. return "ArcTan[{}, {}]".format(
  267. self._print(expr.args[1]), self._print(expr.args[0]))
  268. def _print_Integral(self, expr):
  269. if len(expr.variables) == 1 and not expr.limits[0][1:]:
  270. args = [expr.args[0], expr.variables[0]]
  271. else:
  272. args = expr.args
  273. return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]"
  274. def _print_Sum(self, expr):
  275. return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]"
  276. def _print_Derivative(self, expr):
  277. dexpr = expr.expr
  278. dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count]
  279. return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]"
  280. def _get_comment(self, text):
  281. return "(* {} *)".format(text)
  282. def mathematica_code(expr, **settings):
  283. r"""Converts an expr to a string of the Wolfram Mathematica code
  284. Examples
  285. ========
  286. >>> from sympy import mathematica_code as mcode, symbols, sin
  287. >>> x = symbols('x')
  288. >>> mcode(sin(x).series(x).removeO())
  289. '(1/120)*x^5 - 1/6*x^3 + x'
  290. """
  291. return MCodePrinter(settings).doprint(expr)