mathematica.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085
  1. from __future__ import annotations
  2. import re
  3. import typing
  4. from itertools import product
  5. from typing import Any, Callable
  6. import sympy
  7. from sympy import Mul, Add, Pow, Rational, log, exp, sqrt, cos, sin, tan, asin, acos, acot, asec, acsc, sinh, cosh, tanh, asinh, \
  8. acosh, atanh, acoth, asech, acsch, expand, im, flatten, polylog, cancel, expand_trig, sign, simplify, \
  9. UnevaluatedExpr, S, atan, atan2, Mod, Max, Min, rf, Ei, Si, Ci, airyai, airyaiprime, airybi, primepi, prime, \
  10. isprime, cot, sec, csc, csch, sech, coth, Function, I, pi, Tuple, GreaterThan, StrictGreaterThan, StrictLessThan, \
  11. LessThan, Equality, Or, And, Lambda, Integer, Dummy, symbols
  12. from sympy.core.sympify import sympify, _sympify
  13. from sympy.functions.special.bessel import airybiprime
  14. from sympy.functions.special.error_functions import li
  15. from sympy.utilities.exceptions import sympy_deprecation_warning
  16. def mathematica(s, additional_translations=None):
  17. sympy_deprecation_warning(
  18. """The ``mathematica`` function for the Mathematica parser is now
  19. deprecated. Use ``parse_mathematica`` instead.
  20. The parameter ``additional_translation`` can be replaced by SymPy's
  21. .replace( ) or .subs( ) methods on the output expression instead.""",
  22. deprecated_since_version="1.11",
  23. active_deprecations_target="mathematica-parser-new",
  24. )
  25. parser = MathematicaParser(additional_translations)
  26. return sympify(parser._parse_old(s))
  27. def parse_mathematica(s):
  28. """
  29. Translate a string containing a Wolfram Mathematica expression to a SymPy
  30. expression.
  31. If the translator is unable to find a suitable SymPy expression, the
  32. ``FullForm`` of the Mathematica expression will be output, using SymPy
  33. ``Function`` objects as nodes of the syntax tree.
  34. Examples
  35. ========
  36. >>> from sympy.parsing.mathematica import parse_mathematica
  37. >>> parse_mathematica("Sin[x]^2 Tan[y]")
  38. sin(x)**2*tan(y)
  39. >>> e = parse_mathematica("F[7,5,3]")
  40. >>> e
  41. F(7, 5, 3)
  42. >>> from sympy import Function, Max, Min
  43. >>> e.replace(Function("F"), lambda *x: Max(*x)*Min(*x))
  44. 21
  45. Both standard input form and Mathematica full form are supported:
  46. >>> parse_mathematica("x*(a + b)")
  47. x*(a + b)
  48. >>> parse_mathematica("Times[x, Plus[a, b]]")
  49. x*(a + b)
  50. To get a matrix from Wolfram's code:
  51. >>> m = parse_mathematica("{{a, b}, {c, d}}")
  52. >>> m
  53. ((a, b), (c, d))
  54. >>> from sympy import Matrix
  55. >>> Matrix(m)
  56. Matrix([
  57. [a, b],
  58. [c, d]])
  59. If the translation into equivalent SymPy expressions fails, an SymPy
  60. expression equivalent to Wolfram Mathematica's "FullForm" will be created:
  61. >>> parse_mathematica("x_.")
  62. Optional(Pattern(x, Blank()))
  63. >>> parse_mathematica("Plus @@ {x, y, z}")
  64. Apply(Plus, (x, y, z))
  65. >>> parse_mathematica("f[x_, 3] := x^3 /; x > 0")
  66. SetDelayed(f(Pattern(x, Blank()), 3), Condition(x**3, x > 0))
  67. """
  68. parser = MathematicaParser()
  69. return parser.parse(s)
  70. def _parse_Function(*args):
  71. if len(args) == 1:
  72. arg = args[0]
  73. Slot = Function("Slot")
  74. slots = arg.atoms(Slot)
  75. numbers = [a.args[0] for a in slots]
  76. number_of_arguments = max(numbers)
  77. if isinstance(number_of_arguments, Integer):
  78. variables = symbols(f"dummy0:{number_of_arguments}", cls=Dummy)
  79. return Lambda(variables, arg.xreplace({Slot(i+1): v for i, v in enumerate(variables)}))
  80. return Lambda((), arg)
  81. elif len(args) == 2:
  82. variables = args[0]
  83. body = args[1]
  84. return Lambda(variables, body)
  85. else:
  86. raise SyntaxError("Function node expects 1 or 2 arguments")
  87. def _deco(cls):
  88. cls._initialize_class()
  89. return cls
  90. @_deco
  91. class MathematicaParser:
  92. """
  93. An instance of this class converts a string of a Wolfram Mathematica
  94. expression to a SymPy expression.
  95. The main parser acts internally in three stages:
  96. 1. tokenizer: tokenizes the Mathematica expression and adds the missing *
  97. operators. Handled by ``_from_mathematica_to_tokens(...)``
  98. 2. full form list: sort the list of strings output by the tokenizer into a
  99. syntax tree of nested lists and strings, equivalent to Mathematica's
  100. ``FullForm`` expression output. This is handled by the function
  101. ``_from_tokens_to_fullformlist(...)``.
  102. 3. SymPy expression: the syntax tree expressed as full form list is visited
  103. and the nodes with equivalent classes in SymPy are replaced. Unknown
  104. syntax tree nodes are cast to SymPy ``Function`` objects. This is
  105. handled by ``_from_fullformlist_to_sympy(...)``.
  106. """
  107. # left: Mathematica, right: SymPy
  108. CORRESPONDENCES = {
  109. 'Sqrt[x]': 'sqrt(x)',
  110. 'Rational[x,y]': 'Rational(x,y)',
  111. 'Exp[x]': 'exp(x)',
  112. 'Log[x]': 'log(x)',
  113. 'Log[x,y]': 'log(y,x)',
  114. 'Log2[x]': 'log(x,2)',
  115. 'Log10[x]': 'log(x,10)',
  116. 'Mod[x,y]': 'Mod(x,y)',
  117. 'Max[*x]': 'Max(*x)',
  118. 'Min[*x]': 'Min(*x)',
  119. 'Pochhammer[x,y]':'rf(x,y)',
  120. 'ArcTan[x,y]':'atan2(y,x)',
  121. 'ExpIntegralEi[x]': 'Ei(x)',
  122. 'SinIntegral[x]': 'Si(x)',
  123. 'CosIntegral[x]': 'Ci(x)',
  124. 'AiryAi[x]': 'airyai(x)',
  125. 'AiryAiPrime[x]': 'airyaiprime(x)',
  126. 'AiryBi[x]' :'airybi(x)',
  127. 'AiryBiPrime[x]' :'airybiprime(x)',
  128. 'LogIntegral[x]':' li(x)',
  129. 'PrimePi[x]': 'primepi(x)',
  130. 'Prime[x]': 'prime(x)',
  131. 'PrimeQ[x]': 'isprime(x)'
  132. }
  133. # trigonometric, e.t.c.
  134. for arc, tri, h in product(('', 'Arc'), (
  135. 'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')):
  136. fm = arc + tri + h + '[x]'
  137. if arc: # arc func
  138. fs = 'a' + tri.lower() + h + '(x)'
  139. else: # non-arc func
  140. fs = tri.lower() + h + '(x)'
  141. CORRESPONDENCES.update({fm: fs})
  142. REPLACEMENTS = {
  143. ' ': '',
  144. '^': '**',
  145. '{': '[',
  146. '}': ']',
  147. }
  148. RULES = {
  149. # a single whitespace to '*'
  150. 'whitespace': (
  151. re.compile(r'''
  152. (?:(?<=[a-zA-Z\d])|(?<=\d\.)) # a letter or a number
  153. \s+ # any number of whitespaces
  154. (?:(?=[a-zA-Z\d])|(?=\.\d)) # a letter or a number
  155. ''', re.VERBOSE),
  156. '*'),
  157. # add omitted '*' character
  158. 'add*_1': (
  159. re.compile(r'''
  160. (?:(?<=[])\d])|(?<=\d\.)) # ], ) or a number
  161. # ''
  162. (?=[(a-zA-Z]) # ( or a single letter
  163. ''', re.VERBOSE),
  164. '*'),
  165. # add omitted '*' character (variable letter preceding)
  166. 'add*_2': (
  167. re.compile(r'''
  168. (?<=[a-zA-Z]) # a letter
  169. \( # ( as a character
  170. (?=.) # any characters
  171. ''', re.VERBOSE),
  172. '*('),
  173. # convert 'Pi' to 'pi'
  174. 'Pi': (
  175. re.compile(r'''
  176. (?:
  177. \A|(?<=[^a-zA-Z])
  178. )
  179. Pi # 'Pi' is 3.14159... in Mathematica
  180. (?=[^a-zA-Z])
  181. ''', re.VERBOSE),
  182. 'pi'),
  183. }
  184. # Mathematica function name pattern
  185. FM_PATTERN = re.compile(r'''
  186. (?:
  187. \A|(?<=[^a-zA-Z]) # at the top or a non-letter
  188. )
  189. [A-Z][a-zA-Z\d]* # Function
  190. (?=\[) # [ as a character
  191. ''', re.VERBOSE)
  192. # list or matrix pattern (for future usage)
  193. ARG_MTRX_PATTERN = re.compile(r'''
  194. \{.*\}
  195. ''', re.VERBOSE)
  196. # regex string for function argument pattern
  197. ARGS_PATTERN_TEMPLATE = r'''
  198. (?:
  199. \A|(?<=[^a-zA-Z])
  200. )
  201. {arguments} # model argument like x, y,...
  202. (?=[^a-zA-Z])
  203. '''
  204. # will contain transformed CORRESPONDENCES dictionary
  205. TRANSLATIONS: dict[tuple[str, int], dict[str, Any]] = {}
  206. # cache for a raw users' translation dictionary
  207. cache_original: dict[tuple[str, int], dict[str, Any]] = {}
  208. # cache for a compiled users' translation dictionary
  209. cache_compiled: dict[tuple[str, int], dict[str, Any]] = {}
  210. @classmethod
  211. def _initialize_class(cls):
  212. # get a transformed CORRESPONDENCES dictionary
  213. d = cls._compile_dictionary(cls.CORRESPONDENCES)
  214. cls.TRANSLATIONS.update(d)
  215. def __init__(self, additional_translations=None):
  216. self.translations = {}
  217. # update with TRANSLATIONS (class constant)
  218. self.translations.update(self.TRANSLATIONS)
  219. if additional_translations is None:
  220. additional_translations = {}
  221. # check the latest added translations
  222. if self.__class__.cache_original != additional_translations:
  223. if not isinstance(additional_translations, dict):
  224. raise ValueError('The argument must be dict type')
  225. # get a transformed additional_translations dictionary
  226. d = self._compile_dictionary(additional_translations)
  227. # update cache
  228. self.__class__.cache_original = additional_translations
  229. self.__class__.cache_compiled = d
  230. # merge user's own translations
  231. self.translations.update(self.__class__.cache_compiled)
  232. @classmethod
  233. def _compile_dictionary(cls, dic):
  234. # for return
  235. d = {}
  236. for fm, fs in dic.items():
  237. # check function form
  238. cls._check_input(fm)
  239. cls._check_input(fs)
  240. # uncover '*' hiding behind a whitespace
  241. fm = cls._apply_rules(fm, 'whitespace')
  242. fs = cls._apply_rules(fs, 'whitespace')
  243. # remove whitespace(s)
  244. fm = cls._replace(fm, ' ')
  245. fs = cls._replace(fs, ' ')
  246. # search Mathematica function name
  247. m = cls.FM_PATTERN.search(fm)
  248. # if no-hit
  249. if m is None:
  250. err = "'{f}' function form is invalid.".format(f=fm)
  251. raise ValueError(err)
  252. # get Mathematica function name like 'Log'
  253. fm_name = m.group()
  254. # get arguments of Mathematica function
  255. args, end = cls._get_args(m)
  256. # function side check. (e.g.) '2*Func[x]' is invalid.
  257. if m.start() != 0 or end != len(fm):
  258. err = "'{f}' function form is invalid.".format(f=fm)
  259. raise ValueError(err)
  260. # check the last argument's 1st character
  261. if args[-1][0] == '*':
  262. key_arg = '*'
  263. else:
  264. key_arg = len(args)
  265. key = (fm_name, key_arg)
  266. # convert '*x' to '\\*x' for regex
  267. re_args = [x if x[0] != '*' else '\\' + x for x in args]
  268. # for regex. Example: (?:(x|y|z))
  269. xyz = '(?:(' + '|'.join(re_args) + '))'
  270. # string for regex compile
  271. patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz)
  272. pat = re.compile(patStr, re.VERBOSE)
  273. # update dictionary
  274. d[key] = {}
  275. d[key]['fs'] = fs # SymPy function template
  276. d[key]['args'] = args # args are ['x', 'y'] for example
  277. d[key]['pat'] = pat
  278. return d
  279. def _convert_function(self, s):
  280. '''Parse Mathematica function to SymPy one'''
  281. # compiled regex object
  282. pat = self.FM_PATTERN
  283. scanned = '' # converted string
  284. cur = 0 # position cursor
  285. while True:
  286. m = pat.search(s)
  287. if m is None:
  288. # append the rest of string
  289. scanned += s
  290. break
  291. # get Mathematica function name
  292. fm = m.group()
  293. # get arguments, and the end position of fm function
  294. args, end = self._get_args(m)
  295. # the start position of fm function
  296. bgn = m.start()
  297. # convert Mathematica function to SymPy one
  298. s = self._convert_one_function(s, fm, args, bgn, end)
  299. # update cursor
  300. cur = bgn
  301. # append converted part
  302. scanned += s[:cur]
  303. # shrink s
  304. s = s[cur:]
  305. return scanned
  306. def _convert_one_function(self, s, fm, args, bgn, end):
  307. # no variable-length argument
  308. if (fm, len(args)) in self.translations:
  309. key = (fm, len(args))
  310. # x, y,... model arguments
  311. x_args = self.translations[key]['args']
  312. # make CORRESPONDENCES between model arguments and actual ones
  313. d = dict(zip(x_args, args))
  314. # with variable-length argument
  315. elif (fm, '*') in self.translations:
  316. key = (fm, '*')
  317. # x, y,..*args (model arguments)
  318. x_args = self.translations[key]['args']
  319. # make CORRESPONDENCES between model arguments and actual ones
  320. d = {}
  321. for i, x in enumerate(x_args):
  322. if x[0] == '*':
  323. d[x] = ','.join(args[i:])
  324. break
  325. d[x] = args[i]
  326. # out of self.translations
  327. else:
  328. err = "'{f}' is out of the whitelist.".format(f=fm)
  329. raise ValueError(err)
  330. # template string of converted function
  331. template = self.translations[key]['fs']
  332. # regex pattern for x_args
  333. pat = self.translations[key]['pat']
  334. scanned = ''
  335. cur = 0
  336. while True:
  337. m = pat.search(template)
  338. if m is None:
  339. scanned += template
  340. break
  341. # get model argument
  342. x = m.group()
  343. # get a start position of the model argument
  344. xbgn = m.start()
  345. # add the corresponding actual argument
  346. scanned += template[:xbgn] + d[x]
  347. # update cursor to the end of the model argument
  348. cur = m.end()
  349. # shrink template
  350. template = template[cur:]
  351. # update to swapped string
  352. s = s[:bgn] + scanned + s[end:]
  353. return s
  354. @classmethod
  355. def _get_args(cls, m):
  356. '''Get arguments of a Mathematica function'''
  357. s = m.string # whole string
  358. anc = m.end() + 1 # pointing the first letter of arguments
  359. square, curly = [], [] # stack for brackets
  360. args = []
  361. # current cursor
  362. cur = anc
  363. for i, c in enumerate(s[anc:], anc):
  364. # extract one argument
  365. if c == ',' and (not square) and (not curly):
  366. args.append(s[cur:i]) # add an argument
  367. cur = i + 1 # move cursor
  368. # handle list or matrix (for future usage)
  369. if c == '{':
  370. curly.append(c)
  371. elif c == '}':
  372. curly.pop()
  373. # seek corresponding ']' with skipping irrevant ones
  374. if c == '[':
  375. square.append(c)
  376. elif c == ']':
  377. if square:
  378. square.pop()
  379. else: # empty stack
  380. args.append(s[cur:i])
  381. break
  382. # the next position to ']' bracket (the function end)
  383. func_end = i + 1
  384. return args, func_end
  385. @classmethod
  386. def _replace(cls, s, bef):
  387. aft = cls.REPLACEMENTS[bef]
  388. s = s.replace(bef, aft)
  389. return s
  390. @classmethod
  391. def _apply_rules(cls, s, bef):
  392. pat, aft = cls.RULES[bef]
  393. return pat.sub(aft, s)
  394. @classmethod
  395. def _check_input(cls, s):
  396. for bracket in (('[', ']'), ('{', '}'), ('(', ')')):
  397. if s.count(bracket[0]) != s.count(bracket[1]):
  398. err = "'{f}' function form is invalid.".format(f=s)
  399. raise ValueError(err)
  400. if '{' in s:
  401. err = "Currently list is not supported."
  402. raise ValueError(err)
  403. def _parse_old(self, s):
  404. # input check
  405. self._check_input(s)
  406. # uncover '*' hiding behind a whitespace
  407. s = self._apply_rules(s, 'whitespace')
  408. # remove whitespace(s)
  409. s = self._replace(s, ' ')
  410. # add omitted '*' character
  411. s = self._apply_rules(s, 'add*_1')
  412. s = self._apply_rules(s, 'add*_2')
  413. # translate function
  414. s = self._convert_function(s)
  415. # '^' to '**'
  416. s = self._replace(s, '^')
  417. # 'Pi' to 'pi'
  418. s = self._apply_rules(s, 'Pi')
  419. # '{', '}' to '[', ']', respectively
  420. # s = cls._replace(s, '{') # currently list is not taken into account
  421. # s = cls._replace(s, '}')
  422. return s
  423. def parse(self, s):
  424. s2 = self._from_mathematica_to_tokens(s)
  425. s3 = self._from_tokens_to_fullformlist(s2)
  426. s4 = self._from_fullformlist_to_sympy(s3)
  427. return s4
    # Operator kinds used in the first field of _mathematica_op_precedence.
    INFIX = "Infix"
    PREFIX = "Prefix"
    POSTFIX = "Postfix"
    # Grouping strategies for infix operators (see _parse_after_braces):
    # FLAT collects a chain of compatible operators into one node
    # (a + b - c -> Plus[a, b, Times[-1, c]]), RIGHT nests to the right
    # (a ^ b ^ c -> Power[a, Power[b, c]]) and LEFT nests to the left.
    FLAT = "Flat"
    RIGHT = "Right"
    LEFT = "Left"

    # Operator table, from the lowest to the highest precedence.
    # _parse_after_braces walks it in reverse, so tighter-binding operators
    # are grouped first. Each entry is (kind, grouping strategy, mapping).
    # The mapping sends an operator token either to the head name of the
    # resulting full-form node or to a callable that builds the node from the
    # operand(s).
    _mathematica_op_precedence: list[tuple[str, str | None, dict[str, str | Callable]]] = [
        # A trailing ";" appends Null to the (compound) expression.
        (POSTFIX, None, {";": lambda x: x + ["Null"] if isinstance(x, list) and x and x[0] == "CompoundExpression" else ["CompoundExpression", x, "Null"]}),
        (INFIX, FLAT, {";": "CompoundExpression"}),
        (INFIX, RIGHT, {"=": "Set", ":=": "SetDelayed", "+=": "AddTo", "-=": "SubtractFrom", "*=": "TimesBy", "/=": "DivideBy"}),
        # NOTE(review): this builds [x, y], i.e. x[y] in full-form-list terms
        # (the head comes first, cf. the "[" entry below), whereas
        # Mathematica reads `x // f` as f[x]. Confirm the intended operand
        # order.
        (INFIX, LEFT, {"//": lambda x, y: [x, y]}),
        (POSTFIX, None, {"&": "Function"}),
        (INFIX, LEFT, {"/.": "ReplaceAll"}),
        (INFIX, RIGHT, {"->": "Rule", ":>": "RuleDelayed"}),
        (INFIX, LEFT, {"/;": "Condition"}),
        (INFIX, FLAT, {"|": "Alternatives"}),
        (POSTFIX, None, {"..": "Repeated", "...": "RepeatedNull"}),
        (INFIX, FLAT, {"||": "Or"}),
        (INFIX, FLAT, {"&&": "And"}),
        (PREFIX, None, {"!": "Not"}),
        (INFIX, FLAT, {"===": "SameQ", "=!=": "UnsameQ"}),
        (INFIX, FLAT, {"==": "Equal", "!=": "Unequal", "<=": "LessEqual", "<": "Less", ">=": "GreaterEqual", ">": "Greater"}),
        (INFIX, None, {";;": "Span"}),
        # "-" and "/" share their head with "+" and "*": _parse_after_braces
        # negates / inverts the right operand (_get_neg / _get_inv).
        (INFIX, FLAT, {"+": "Plus", "-": "Plus"}),
        (INFIX, FLAT, {"*": "Times", "/": "Times"}),
        (INFIX, FLAT, {".": "Dot"}),
        (PREFIX, None, {"-": lambda x: MathematicaParser._get_neg(x),
                        "+": lambda x: x}),
        (INFIX, RIGHT, {"^": "Power"}),
        (INFIX, RIGHT, {"@@": "Apply", "/@": "Map", "//@": "MapAll", "@@@": lambda x, y: ["Apply", x, y, ["List", "1"]]}),
        (POSTFIX, None, {"'": "Derivative", "!": "Factorial", "!!": "Factorial2", "--": "Decrement"}),
        # f[args] -> [f, *args]; expr[[i]] -> [Part, expr, *i]
        (INFIX, None, {"[": lambda x, y: [x, *y], "[[": lambda x, y: ["Part", x, *y]}),
        # {a, b} -> [List, a, b]; (expr) -> the single enclosed expression
        (PREFIX, None, {"{": lambda x: ["List", *x], "(": lambda x: x[0]}),
        (INFIX, None, {"?": "PatternTest"}),
        (POSTFIX, None, {
            "_": lambda x: ["Pattern", x, ["Blank"]],
            "_.": lambda x: ["Optional", ["Pattern", x, ["Blank"]]],
            "__": lambda x: ["Pattern", x, ["BlankSequence"]],
            "___": lambda x: ["Pattern", x, ["BlankNullSequence"]],
        }),
        (INFIX, None, {"_": lambda x, y: ["Pattern", x, ["Blank", y]]}),
        (PREFIX, None, {"#": "Slot", "##": "SlotSequence"}),
    ]

    # Node used when a prefix/postfix operator has no operand, e.g. a bare "#"
    # stands for Slot[1].
    _missing_arguments_default = {
        "#": lambda: ["Slot", "1"],
        "##": lambda: ["SlotSequence", "1"],
    }

    # identifiers: a letter followed by letters/digits
    _literal = r"[A-Za-z][A-Za-z0-9]*"
    # unsigned integers and decimals such as 12, 1.5, 2. and .5
    _number = r"(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)"

    # opening and closing enclosure tokens, paired by index
    _enclosure_open = ["(", "[", "[[", "{"]
    _enclosure_close = [")", "]", "]]", "}"]
  479. @classmethod
  480. def _get_neg(cls, x):
  481. return f"-{x}" if isinstance(x, str) and re.match(MathematicaParser._number, x) else ["Times", "-1", x]
  482. @classmethod
  483. def _get_inv(cls, x):
  484. return ["Power", x, "-1"]
  485. _regex_tokenizer = None
  486. def _get_tokenizer(self):
  487. if self._regex_tokenizer is not None:
  488. # Check if the regular expression has already been compiled:
  489. return self._regex_tokenizer
  490. tokens = [self._literal, self._number]
  491. tokens_escape = self._enclosure_open[:] + self._enclosure_close[:]
  492. for typ, strat, symdict in self._mathematica_op_precedence:
  493. for k in symdict:
  494. tokens_escape.append(k)
  495. tokens_escape.sort(key=lambda x: -len(x))
  496. tokens.extend(map(re.escape, tokens_escape))
  497. tokens.append(",")
  498. tokens.append("\n")
  499. tokenizer = re.compile("(" + "|".join(tokens) + ")")
  500. self._regex_tokenizer = tokenizer
  501. return self._regex_tokenizer
  502. def _from_mathematica_to_tokens(self, code: str):
  503. tokenizer = self._get_tokenizer()
  504. # Find strings:
  505. code_splits: list[str | list] = []
  506. while True:
  507. string_start = code.find("\"")
  508. if string_start == -1:
  509. if len(code) > 0:
  510. code_splits.append(code)
  511. break
  512. match_end = re.search(r'(?<!\\)"', code[string_start+1:])
  513. if match_end is None:
  514. raise SyntaxError('mismatch in string " " expression')
  515. string_end = string_start + match_end.start() + 1
  516. if string_start > 0:
  517. code_splits.append(code[:string_start])
  518. code_splits.append(["_Str", code[string_start+1:string_end].replace('\\"', '"')])
  519. code = code[string_end+1:]
  520. # Remove comments:
  521. for i, code_split in enumerate(code_splits):
  522. if isinstance(code_split, list):
  523. continue
  524. while True:
  525. pos_comment_start = code_split.find("(*")
  526. if pos_comment_start == -1:
  527. break
  528. pos_comment_end = code_split.find("*)")
  529. if pos_comment_end == -1 or pos_comment_end < pos_comment_start:
  530. raise SyntaxError("mismatch in comment (* *) code")
  531. code_split = code_split[:pos_comment_start] + code_split[pos_comment_end+2:]
  532. code_splits[i] = code_split
  533. # Tokenize the input strings with a regular expression:
  534. token_lists = [tokenizer.findall(i) if isinstance(i, str) and i.isascii() else [i] for i in code_splits]
  535. tokens = [j for i in token_lists for j in i]
  536. # Remove newlines at the beginning
  537. while tokens and tokens[0] == "\n":
  538. tokens.pop(0)
  539. # Remove newlines at the end
  540. while tokens and tokens[-1] == "\n":
  541. tokens.pop(-1)
  542. return tokens
  543. def _is_op(self, token: str | list) -> bool:
  544. if isinstance(token, list):
  545. return False
  546. if re.match(self._literal, token):
  547. return False
  548. if re.match("-?" + self._number, token):
  549. return False
  550. return True
  551. def _is_valid_star1(self, token: str | list) -> bool:
  552. if token in (")", "}"):
  553. return True
  554. return not self._is_op(token)
  555. def _is_valid_star2(self, token: str | list) -> bool:
  556. if token in ("(", "{"):
  557. return True
  558. return not self._is_op(token)
  559. def _from_tokens_to_fullformlist(self, tokens: list):
  560. stack: list[list] = [[]]
  561. open_seq = []
  562. pointer: int = 0
  563. while pointer < len(tokens):
  564. token = tokens[pointer]
  565. if token in self._enclosure_open:
  566. stack[-1].append(token)
  567. open_seq.append(token)
  568. stack.append([])
  569. elif token == ",":
  570. if len(stack[-1]) == 0 and stack[-2][-1] == open_seq[-1]:
  571. raise SyntaxError("%s cannot be followed by comma ," % open_seq[-1])
  572. stack[-1] = self._parse_after_braces(stack[-1])
  573. stack.append([])
  574. elif token in self._enclosure_close:
  575. ind = self._enclosure_close.index(token)
  576. if self._enclosure_open[ind] != open_seq[-1]:
  577. unmatched_enclosure = SyntaxError("unmatched enclosure")
  578. if token == "]]" and open_seq[-1] == "[":
  579. if open_seq[-2] == "[":
  580. # These two lines would be logically correct, but are
  581. # unnecessary:
  582. # token = "]"
  583. # tokens[pointer] = "]"
  584. tokens.insert(pointer+1, "]")
  585. elif open_seq[-2] == "[[":
  586. if tokens[pointer+1] == "]":
  587. tokens[pointer+1] = "]]"
  588. elif tokens[pointer+1] == "]]":
  589. tokens[pointer+1] = "]]"
  590. tokens.insert(pointer+2, "]")
  591. else:
  592. raise unmatched_enclosure
  593. else:
  594. raise unmatched_enclosure
  595. if len(stack[-1]) == 0 and stack[-2][-1] == "(":
  596. raise SyntaxError("( ) not valid syntax")
  597. last_stack = self._parse_after_braces(stack[-1], True)
  598. stack[-1] = last_stack
  599. new_stack_element = []
  600. while stack[-1][-1] != open_seq[-1]:
  601. new_stack_element.append(stack.pop())
  602. new_stack_element.reverse()
  603. if open_seq[-1] == "(" and len(new_stack_element) != 1:
  604. raise SyntaxError("( must be followed by one expression, %i detected" % len(new_stack_element))
  605. stack[-1].append(new_stack_element)
  606. open_seq.pop(-1)
  607. else:
  608. stack[-1].append(token)
  609. pointer += 1
  610. if len(stack) != 1:
  611. raise RuntimeError("Stack should have only one element")
  612. return self._parse_after_braces(stack[0])
    def _util_remove_newlines(self, lines: list, tokens: list, inside_enclosure: bool):
        """Move complete newline-separated statements from *tokens* into *lines*.

        Both lists are modified in place. Inside an enclosure every newline
        token is simply dropped.

        At top level, the tokens before a newline are parsed with
        ``_parse_after_braces``:

        - If they form an expression, it is appended to *lines* (a
          ``CompoundExpression`` is flattened into its parts) and the
          consumed tokens are removed from *tokens*.
        - If parsing raises SyntaxError, the newline is treated as a line
          continuation and dropped.
        """
        pointer = 0
        size = len(tokens)
        while pointer < size:
            token = tokens[pointer]
            if token == "\n":
                if inside_enclosure:
                    # Ignore newlines inside enclosures
                    tokens.pop(pointer)
                    size -= 1
                    continue
                if pointer == 0:
                    # A leading newline separates nothing: drop it.
                    tokens.pop(0)
                    size -= 1
                    continue
                if pointer > 1:
                    try:
                        prev_expr = self._parse_after_braces(tokens[:pointer], inside_enclosure)
                    except SyntaxError:
                        # Incomplete statement before the newline (e.g. "a +"):
                        # treat the newline as a continuation.
                        tokens.pop(pointer)
                        size -= 1
                        continue
                else:
                    # A single token before the newline is taken as-is.
                    prev_expr = tokens[0]
                if len(prev_expr) > 0 and prev_expr[0] == "CompoundExpression":
                    lines.extend(prev_expr[1:])
                else:
                    lines.append(prev_expr)
                # Drop the consumed statement; the newline is now at index 0
                # and is removed by the next iteration.
                for i in range(pointer):
                    tokens.pop(0)
                size -= pointer
                pointer = 0
                continue
            pointer += 1
    def _util_add_missing_asterisks(self, tokens: list):
        """Insert the implicit multiplication operator ``*`` into *tokens* (in place).

        A ``*`` is placed between two consecutive tokens when the first can end
        a factor (``_is_valid_star1``) and the second can start one
        (``_is_valid_star2``), e.g. ``["2", "x"]`` becomes ``["2", "*", "x"]``.
        """
        size: int = len(tokens)
        pointer: int = 0
        while pointer < size:
            if (pointer > 0 and
                    self._is_valid_star1(tokens[pointer - 1]) and
                    self._is_valid_star2(tokens[pointer])):
                # This is a trick to add missing * operators in the expression,
                # `"*" in op_dict` makes sure the precedence level is the same as "*",
                # while `not self._is_op( ... )` makes sure this and the previous
                # expression are not operators.
                if tokens[pointer] == "(":
                    # ( has already been processed by now, replace:
                    # NOTE(review): _parse_after_braces handles the PREFIX "("
                    # level before the "*" level, so a literal "(" token seems
                    # unable to reach this point; confirm whether this branch
                    # is still reachable.
                    tokens[pointer] = "*"
                    tokens[pointer + 1] = tokens[pointer + 1][0]
                else:
                    tokens.insert(pointer, "*")
                    # skip over the inserted "*"
                    pointer += 1
                    size += 1
            pointer += 1
  667. def _parse_after_braces(self, tokens: list, inside_enclosure: bool = False):
  668. op_dict: dict
  669. changed: bool = False
  670. lines: list = []
  671. self._util_remove_newlines(lines, tokens, inside_enclosure)
  672. for op_type, grouping_strat, op_dict in reversed(self._mathematica_op_precedence):
  673. if "*" in op_dict:
  674. self._util_add_missing_asterisks(tokens)
  675. size: int = len(tokens)
  676. pointer: int = 0
  677. while pointer < size:
  678. token = tokens[pointer]
  679. if isinstance(token, str) and token in op_dict:
  680. op_name: str | Callable = op_dict[token]
  681. node: list
  682. first_index: int
  683. if isinstance(op_name, str):
  684. node = [op_name]
  685. first_index = 1
  686. else:
  687. node = []
  688. first_index = 0
  689. if token in ("+", "-") and op_type == self.PREFIX and pointer > 0 and not self._is_op(tokens[pointer - 1]):
  690. # Make sure that PREFIX + - don't match expressions like a + b or a - b,
  691. # the INFIX + - are supposed to match that expression:
  692. pointer += 1
  693. continue
  694. if op_type == self.INFIX:
  695. if pointer == 0 or pointer == size - 1 or self._is_op(tokens[pointer - 1]) or self._is_op(tokens[pointer + 1]):
  696. pointer += 1
  697. continue
  698. changed = True
  699. tokens[pointer] = node
  700. if op_type == self.INFIX:
  701. arg1 = tokens.pop(pointer-1)
  702. arg2 = tokens.pop(pointer)
  703. if token == "/":
  704. arg2 = self._get_inv(arg2)
  705. elif token == "-":
  706. arg2 = self._get_neg(arg2)
  707. pointer -= 1
  708. size -= 2
  709. node.append(arg1)
  710. node_p = node
  711. if grouping_strat == self.FLAT:
  712. while pointer + 2 < size and self._check_op_compatible(tokens[pointer+1], token):
  713. node_p.append(arg2)
  714. other_op = tokens.pop(pointer+1)
  715. arg2 = tokens.pop(pointer+1)
  716. if other_op == "/":
  717. arg2 = self._get_inv(arg2)
  718. elif other_op == "-":
  719. arg2 = self._get_neg(arg2)
  720. size -= 2
  721. node_p.append(arg2)
  722. elif grouping_strat == self.RIGHT:
  723. while pointer + 2 < size and tokens[pointer+1] == token:
  724. node_p.append([op_name, arg2])
  725. node_p = node_p[-1]
  726. tokens.pop(pointer+1)
  727. arg2 = tokens.pop(pointer+1)
  728. size -= 2
  729. node_p.append(arg2)
  730. elif grouping_strat == self.LEFT:
  731. while pointer + 1 < size and tokens[pointer+1] == token:
  732. if isinstance(op_name, str):
  733. node_p[first_index] = [op_name, node_p[first_index], arg2]
  734. else:
  735. node_p[first_index] = op_name(node_p[first_index], arg2)
  736. tokens.pop(pointer+1)
  737. arg2 = tokens.pop(pointer+1)
  738. size -= 2
  739. node_p.append(arg2)
  740. else:
  741. node.append(arg2)
  742. elif op_type == self.PREFIX:
  743. if grouping_strat is not None:
  744. raise TypeError("'Prefix' op_type should not have a grouping strat")
  745. if pointer == size - 1 or self._is_op(tokens[pointer + 1]):
  746. tokens[pointer] = self._missing_arguments_default[token]()
  747. else:
  748. node.append(tokens.pop(pointer+1))
  749. size -= 1
  750. elif op_type == self.POSTFIX:
  751. if grouping_strat is not None:
  752. raise TypeError("'Prefix' op_type should not have a grouping strat")
  753. if pointer == 0 or self._is_op(tokens[pointer - 1]):
  754. tokens[pointer] = self._missing_arguments_default[token]()
  755. else:
  756. node.append(tokens.pop(pointer-1))
  757. pointer -= 1
  758. size -= 1
  759. if isinstance(op_name, Callable): # type: ignore
  760. op_call: Callable = typing.cast(Callable, op_name)
  761. new_node = op_call(*node)
  762. node.clear()
  763. if isinstance(new_node, list):
  764. node.extend(new_node)
  765. else:
  766. tokens[pointer] = new_node
  767. pointer += 1
  768. if len(tokens) > 1 or (len(lines) == 0 and len(tokens) == 0):
  769. if changed:
  770. # Trick to deal with cases in which an operator with lower
  771. # precedence should be transformed before an operator of higher
  772. # precedence. Such as in the case of `#&[x]` (that is
  773. # equivalent to `Lambda(d_, d_)(x)` in SymPy). In this case the
  774. # operator `&` has lower precedence than `[`, but needs to be
  775. # evaluated first because otherwise `# (&[x])` is not a valid
  776. # expression:
  777. return self._parse_after_braces(tokens, inside_enclosure)
  778. raise SyntaxError("unable to create a single AST for the expression")
  779. if len(lines) > 0:
  780. if tokens[0] and tokens[0][0] == "CompoundExpression":
  781. tokens = tokens[0][1:]
  782. compound_expression = ["CompoundExpression", *lines, *tokens]
  783. return compound_expression
  784. return tokens[0]
  785. def _check_op_compatible(self, op1: str, op2: str):
  786. if op1 == op2:
  787. return True
  788. muldiv = {"*", "/"}
  789. addsub = {"+", "-"}
  790. if op1 in muldiv and op2 in muldiv:
  791. return True
  792. if op1 in addsub and op2 in addsub:
  793. return True
  794. return False
  795. def _from_fullform_to_fullformlist(self, wmexpr: str):
  796. """
  797. Parses FullForm[Downvalues[]] generated by Mathematica
  798. """
  799. out: list = []
  800. stack = [out]
  801. generator = re.finditer(r'[\[\],]', wmexpr)
  802. last_pos = 0
  803. for match in generator:
  804. if match is None:
  805. break
  806. position = match.start()
  807. last_expr = wmexpr[last_pos:position].replace(',', '').replace(']', '').replace('[', '').strip()
  808. if match.group() == ',':
  809. if last_expr != '':
  810. stack[-1].append(last_expr)
  811. elif match.group() == ']':
  812. if last_expr != '':
  813. stack[-1].append(last_expr)
  814. stack.pop()
  815. elif match.group() == '[':
  816. stack[-1].append([last_expr])
  817. stack.append(stack[-1][-1])
  818. last_pos = match.end()
  819. return out[0]
  def _from_fullformlist_to_fullformsympy(self, pylist: list):
      """
      Turn a full-form list into SymPy objects without interpreting heads.

      Every ``[head, *args]`` list becomes the undefined SymPy function
      ``Function(head)`` applied to its converted arguments.  Every string
      becomes a ``Symbol``, and any other value goes through ``_sympify``.

      Raises ValueError when an empty list is encountered.
      """
      from sympy import Function, Symbol

      def to_sympy(node):
          if isinstance(node, str):
              return Symbol(node)
          if not isinstance(node, list):
              return _sympify(node)
          if not node:
              raise ValueError("Empty list of expressions")
          head, *operands = node
          converted = [to_sympy(operand) for operand in operands]
          return Function(head)(*converted)

      return to_sympy(pylist)
  # Mapping from Mathematica head names to the SymPy class or callable that
  # replaces them.  _from_fullformlist_to_sympy looks heads up here and
  # falls back to an undefined Function for unknown names.
  # _from_fullformsympy_to_sympy applies these entries one after another, in
  # insertion order, via .replace(), so reordering entries may change results.
  _node_conversions = {
      "Times": Mul,
      "Plus": Add,
      "Power": Pow,
      "Rational": Rational,
      # Arguments are reversed: Mathematica writes the base first (Log[b, z])
      # whereas SymPy takes it last (log(z, b)).
      "Log": lambda *a: log(*reversed(a)),
      "Log2": lambda x: log(x, 2),
      "Log10": lambda x: log(x, 10),
      "Exp": exp,
      "Sqrt": sqrt,
      "Sin": sin,
      "Cos": cos,
      "Tan": tan,
      "Cot": cot,
      "Sec": sec,
      "Csc": csc,
      "ArcSin": asin,
      "ArcCos": acos,
      # Two arguments ArcTan[a, b] -> atan2(b, a); one argument -> atan.
      "ArcTan": lambda *a: atan2(*reversed(a)) if len(a) == 2 else atan(*a),
      "ArcCot": acot,
      "ArcSec": asec,
      "ArcCsc": acsc,
      "Sinh": sinh,
      "Cosh": cosh,
      "Tanh": tanh,
      "Coth": coth,
      "Sech": sech,
      "Csch": csch,
      "ArcSinh": asinh,
      "ArcCosh": acosh,
      "ArcTanh": atanh,
      "ArcCoth": acoth,
      "ArcSech": asech,
      "ArcCsch": acsch,
      "Expand": expand,
      "Im": im,
      "Re": sympy.re,
      "Flatten": flatten,
      "Polylog": polylog,
      "Cancel": cancel,
      # Disabled mapping (not active): Gamma -> gamma.
      # Gamma=gamma,
      "TrigExpand": expand_trig,
      "Sign": sign,
      "Simplify": simplify,
      "Defer": UnevaluatedExpr,
      "Identity": S,
      # Disabled mappings (not active): Sum, Module, Block.
      # Sum=Sum_doit,
      # Module=With,
      # Block=With,
      # Null ignores any arguments and evaluates to SymPy zero.
      "Null": lambda *a: S.Zero,
      "Mod": Mod,
      "Max": Max,
      "Min": Min,
      # Pochhammer symbol -> SymPy rising factorial.
      "Pochhammer": rf,
      "ExpIntegralEi": Ei,
      "SinIntegral": Si,
      "CosIntegral": Ci,
      "AiryAi": airyai,
      "AiryAiPrime": airyaiprime,
      "AiryBi": airybi,
      "AiryBiPrime": airybiprime,
      "LogIntegral": li,
      "PrimePi": primepi,
      "Prime": prime,
      "PrimeQ": isprime,
      "List": Tuple,
      "Greater": StrictGreaterThan,
      "GreaterEqual": GreaterThan,
      "Less": StrictLessThan,
      "LessEqual": LessThan,
      "Equal": Equality,
      "Or": Or,
      "And": And,
      # Pure functions are handled by a dedicated module-level helper.
      "Function": _parse_Function,
  }
  # Mathematica atoms that map directly to SymPy constants.  Any atom not
  # listed here is handed to sympify() by _from_fullformlist_to_sympy.
  _atom_conversions = dict(
      I=I,
      Pi=pi,
  )
  914. def _from_fullformlist_to_sympy(self, full_form_list):
  915. def recurse(expr):
  916. if isinstance(expr, list):
  917. if isinstance(expr[0], list):
  918. head = recurse(expr[0])
  919. else:
  920. head = self._node_conversions.get(expr[0], Function(expr[0]))
  921. return head(*[recurse(arg) for arg in expr[1:]])
  922. else:
  923. return self._atom_conversions.get(expr, sympify(expr))
  924. return recurse(full_form_list)
  925. def _from_fullformsympy_to_sympy(self, mform):
  926. expr = mform
  927. for mma_form, sympy_node in self._node_conversions.items():
  928. expr = expr.replace(Function(mma_form), sympy_node)
  929. return expr