rust.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637
  1. """
  2. Rust code printer
  3. The `RustCodePrinter` converts SymPy expressions into Rust expressions.
  4. A complete code generator, which uses `rust_code` extensively, can be found
  5. in `sympy.utilities.codegen`. The `codegen` module can be used to generate
  6. complete source code files.
  7. """
  8. # Possible Improvement
  9. #
  10. # * make sure we follow Rust Style Guidelines_
  11. # * make use of pattern matching
  12. # * better support for reference
  13. # * generate generic code and use trait to make sure they have specific methods
  14. # * use crates_ to get more math support
  15. # - num_
  16. # + BigInt_, BigUint_
  17. # + Complex_
  18. # + Rational64_, Rational32_, BigRational_
  19. #
  20. # .. _crates: https://crates.io/
  21. # .. _Guidelines: https://github.com/rust-lang/rust/tree/master/src/doc/style
  22. # .. _num: http://rust-num.github.io/num/num/
  23. # .. _BigInt: http://rust-num.github.io/num/num/bigint/struct.BigInt.html
  24. # .. _BigUint: http://rust-num.github.io/num/num/bigint/struct.BigUint.html
  25. # .. _Complex: http://rust-num.github.io/num/num/complex/struct.Complex.html
  26. # .. _Rational32: http://rust-num.github.io/num/num/rational/type.Rational32.html
  27. # .. _Rational64: http://rust-num.github.io/num/num/rational/type.Rational64.html
  28. # .. _BigRational: http://rust-num.github.io/num/num/rational/type.BigRational.html
  29. from __future__ import annotations
  30. from functools import reduce
  31. import operator
  32. from typing import Any
  33. from sympy.codegen.ast import (
  34. float32, float64, int32,
  35. real, integer, bool_
  36. )
  37. from sympy.core import S, Rational, Float, Lambda
  38. from sympy.core.expr import Expr
  39. from sympy.core.numbers import equal_valued
  40. from sympy.functions.elementary.integers import ceiling, floor
  41. from sympy.printing.codeprinter import CodePrinter
  42. from sympy.printing.precedence import PRECEDENCE
# Rust's methods for integers and floats are documented here:
  44. #
  45. # * `Rust - Primitive Type f64 <https://doc.rust-lang.org/std/primitive.f64.html>`_
  46. # * `Rust - Primitive Type i64 <https://doc.rust-lang.org/std/primitive.i64.html>`_
  47. #
  48. # Function Style :
  49. #
  50. # 1. args[0].func(args[1:]), method with arguments
  51. # 2. args[0].func(), method without arguments
  52. # 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp())
  53. # 4. func(args), function with arguments
  54. # dictionary mapping SymPy function to (argument_conditions, Rust_function).
  55. # Used in RustCodePrinter._print_Function(self)
  56. class float_floor(floor):
  57. """
  58. Same as `sympy.floor`, but mimics the Rust behavior of returning a float rather than an integer
  59. """
  60. def _eval_is_integer(self):
  61. return False
  62. class float_ceiling(ceiling):
  63. """
  64. Same as `sympy.ceiling`, but mimics the Rust behavior of returning a float rather than an integer
  65. """
  66. def _eval_is_integer(self):
  67. return False
  68. function_overrides = {
  69. "floor": (floor, float_floor),
  70. "ceiling": (ceiling, float_ceiling),
  71. }
  72. # f64 method in Rust
  73. known_functions = {
  74. # "": "is_nan",
  75. # "": "is_infinite",
  76. # "": "is_finite",
  77. # "": "is_normal",
  78. # "": "classify",
  79. "float_floor": "floor",
  80. "float_ceiling": "ceil",
  81. # "": "round",
  82. # "": "trunc",
  83. # "": "fract",
  84. "Abs": "abs",
  85. # "": "signum",
  86. # "": "is_sign_positive",
  87. # "": "is_sign_negative",
  88. # "": "mul_add",
  89. "Pow": [(lambda base, exp: equal_valued(exp, -1), "recip", 2), # 1.0/x
  90. (lambda base, exp: equal_valued(exp, 0.5), "sqrt", 2), # x ** 0.5
  91. (lambda base, exp: equal_valued(exp, -0.5), "sqrt().recip", 2), # 1/(x ** 0.5)
  92. (lambda base, exp: exp == Rational(1, 3), "cbrt", 2), # x ** (1/3)
  93. (lambda base, exp: equal_valued(base, 2), "exp2", 3), # 2 ** x
  94. (lambda base, exp: exp.is_integer, "powi", 1), # x ** y, for i32
  95. (lambda base, exp: not exp.is_integer, "powf", 1)], # x ** y, for f64
  96. "exp": [(lambda exp: True, "exp", 2)], # e ** x
  97. "log": "ln",
  98. # "": "log", # number.log(base)
  99. # "": "log2",
  100. # "": "log10",
  101. # "": "to_degrees",
  102. # "": "to_radians",
  103. "Max": "max",
  104. "Min": "min",
  105. # "": "hypot", # (x**2 + y**2) ** 0.5
  106. "sin": "sin",
  107. "cos": "cos",
  108. "tan": "tan",
  109. "asin": "asin",
  110. "acos": "acos",
  111. "atan": "atan",
  112. "atan2": "atan2",
  113. # "": "sin_cos",
  114. # "": "exp_m1", # e ** x - 1
  115. # "": "ln_1p", # ln(1 + x)
  116. "sinh": "sinh",
  117. "cosh": "cosh",
  118. "tanh": "tanh",
  119. "asinh": "asinh",
  120. "acosh": "acosh",
  121. "atanh": "atanh",
  122. "sqrt": "sqrt", # To enable automatic rewrites
  123. }
  124. # i64 method in Rust
  125. # known_functions_i64 = {
  126. # "": "min_value",
  127. # "": "max_value",
  128. # "": "from_str_radix",
  129. # "": "count_ones",
  130. # "": "count_zeros",
  131. # "": "leading_zeros",
# "": "trailing_zeros",
  133. # "": "rotate_left",
  134. # "": "rotate_right",
  135. # "": "swap_bytes",
  136. # "": "from_be",
  137. # "": "from_le",
  138. # "": "to_be", # to big endian
  139. # "": "to_le", # to little endian
  140. # "": "checked_add",
  141. # "": "checked_sub",
  142. # "": "checked_mul",
  143. # "": "checked_div",
  144. # "": "checked_rem",
  145. # "": "checked_neg",
  146. # "": "checked_shl",
  147. # "": "checked_shr",
  148. # "": "checked_abs",
  149. # "": "saturating_add",
  150. # "": "saturating_sub",
  151. # "": "saturating_mul",
  152. # "": "wrapping_add",
  153. # "": "wrapping_sub",
  154. # "": "wrapping_mul",
  155. # "": "wrapping_div",
  156. # "": "wrapping_rem",
  157. # "": "wrapping_neg",
  158. # "": "wrapping_shl",
  159. # "": "wrapping_shr",
  160. # "": "wrapping_abs",
  161. # "": "overflowing_add",
  162. # "": "overflowing_sub",
  163. # "": "overflowing_mul",
  164. # "": "overflowing_div",
  165. # "": "overflowing_rem",
  166. # "": "overflowing_neg",
  167. # "": "overflowing_shl",
  168. # "": "overflowing_shr",
  169. # "": "overflowing_abs",
  170. # "Pow": "pow",
  171. # "Abs": "abs",
  172. # "sign": "signum",
  173. # "": "is_positive",
# "": "is_negative",
  175. # }
  176. # These are the core reserved words in the Rust language. Taken from:
  177. # https://doc.rust-lang.org/reference/keywords.html
  178. reserved_words = ['abstract',
  179. 'as',
  180. 'async',
  181. 'await',
  182. 'become',
  183. 'box',
  184. 'break',
  185. 'const',
  186. 'continue',
  187. 'crate',
  188. 'do',
  189. 'dyn',
  190. 'else',
  191. 'enum',
  192. 'extern',
  193. 'false',
  194. 'final',
  195. 'fn',
  196. 'for',
  197. 'gen',
  198. 'if',
  199. 'impl',
  200. 'in',
  201. 'let',
  202. 'loop',
  203. 'macro',
  204. 'match',
  205. 'mod',
  206. 'move',
  207. 'mut',
  208. 'override',
  209. 'priv',
  210. 'pub',
  211. 'ref',
  212. 'return',
  213. 'Self',
  214. 'self',
  215. 'static',
  216. 'struct',
  217. 'super',
  218. 'trait',
  219. 'true',
  220. 'try',
  221. 'type',
  222. 'typeof',
  223. 'unsafe',
  224. 'unsized',
  225. 'use',
  226. 'virtual',
  227. 'where',
  228. 'while',
  229. 'yield']
  230. class TypeCast(Expr):
  231. """
  232. The type casting operator of the Rust language.
  233. """
  234. def __init__(self, expr, type_) -> None:
  235. super().__init__()
  236. self.explicit = expr.is_integer and type_ is not integer
  237. self._assumptions = expr._assumptions
  238. if self.explicit:
  239. setattr(self, 'precedence', PRECEDENCE["Func"] + 10)
  240. @property
  241. def expr(self):
  242. return self.args[0]
  243. @property
  244. def type_(self):
  245. return self.args[1]
  246. def sort_key(self, order=None):
  247. return self.args[0].sort_key(order=order)
  248. class RustCodePrinter(CodePrinter):
  249. """A printer to convert SymPy expressions to strings of Rust code"""
  250. printmethod = "_rust_code"
  251. language = "Rust"
  252. type_aliases = {
  253. integer: int32,
  254. real: float64,
  255. }
  256. type_mappings = {
  257. int32: 'i32',
  258. float32: 'f32',
  259. float64: 'f64',
  260. bool_: 'bool'
  261. }
  262. _default_settings: dict[str, Any] = dict(CodePrinter._default_settings, **{
  263. 'precision': 17,
  264. 'user_functions': {},
  265. 'contract': True,
  266. 'dereference': set(),
  267. })
  268. def __init__(self, settings={}):
  269. CodePrinter.__init__(self, settings)
  270. self.known_functions = dict(known_functions)
  271. userfuncs = settings.get('user_functions', {})
  272. self.known_functions.update(userfuncs)
  273. self._dereference = set(settings.get('dereference', []))
  274. self.reserved_words = set(reserved_words)
  275. self.function_overrides = function_overrides
  276. def _rate_index_position(self, p):
  277. return p*5
  278. def _get_statement(self, codestring):
  279. return "%s;" % codestring
  280. def _get_comment(self, text):
  281. return "// %s" % text
  282. def _declare_number_const(self, name, value):
  283. type_ = self.type_mappings[self.type_aliases[real]]
  284. return "const %s: %s = %s;" % (name, type_, value)
  285. def _format_code(self, lines):
  286. return self.indent_code(lines)
  287. def _traverse_matrix_indices(self, mat):
  288. rows, cols = mat.shape
  289. return ((i, j) for i in range(rows) for j in range(cols))
  290. def _get_loop_opening_ending(self, indices):
  291. open_lines = []
  292. close_lines = []
  293. loopstart = "for %(var)s in %(start)s..%(end)s {"
  294. for i in indices:
  295. # Rust arrays start at 0 and end at dimension-1
  296. open_lines.append(loopstart % {
  297. 'var': self._print(i),
  298. 'start': self._print(i.lower),
  299. 'end': self._print(i.upper + 1)})
  300. close_lines.append("}")
  301. return open_lines, close_lines
  302. def _print_caller_var(self, expr):
  303. if len(expr.args) > 1:
  304. # for something like `sin(x + y + z)`,
  305. # make sure we can get '(x + y + z).sin()'
  306. # instead of 'x + y + z.sin()'
  307. return '(' + self._print(expr) + ')'
  308. elif expr.is_number:
  309. return self._print(expr, _type=True)
  310. else:
  311. return self._print(expr)
  312. def _print_Function(self, expr):
  313. """
  314. basic function for printing `Function`
  315. Function Style :
  316. 1. args[0].func(args[1:]), method with arguments
  317. 2. args[0].func(), method without arguments
  318. 3. args[1].func(), method without arguments (e.g. (e, x) => x.exp())
  319. 4. func(args), function with arguments
  320. """
  321. if expr.func.__name__ in self.known_functions:
  322. cond_func = self.known_functions[expr.func.__name__]
  323. func = None
  324. style = 1
  325. if isinstance(cond_func, str):
  326. func = cond_func
  327. else:
  328. for cond, func, style in cond_func:
  329. if cond(*expr.args):
  330. break
  331. if func is not None:
  332. if style == 1:
  333. ret = "%(var)s.%(method)s(%(args)s)" % {
  334. 'var': self._print_caller_var(expr.args[0]),
  335. 'method': func,
  336. 'args': self.stringify(expr.args[1:], ", ") if len(expr.args) > 1 else ''
  337. }
  338. elif style == 2:
  339. ret = "%(var)s.%(method)s()" % {
  340. 'var': self._print_caller_var(expr.args[0]),
  341. 'method': func,
  342. }
  343. elif style == 3:
  344. ret = "%(var)s.%(method)s()" % {
  345. 'var': self._print_caller_var(expr.args[1]),
  346. 'method': func,
  347. }
  348. else:
  349. ret = "%(func)s(%(args)s)" % {
  350. 'func': func,
  351. 'args': self.stringify(expr.args, ", "),
  352. }
  353. return ret
  354. elif hasattr(expr, '_imp_') and isinstance(expr._imp_, Lambda):
  355. # inlined function
  356. return self._print(expr._imp_(*expr.args))
  357. else:
  358. return self._print_not_supported(expr)
  359. def _print_Mul(self, expr):
  360. contains_floats = any(arg.is_real and not arg.is_integer for arg in expr.args)
  361. if contains_floats:
  362. expr = reduce(operator.mul,(self._cast_to_float(arg) if arg != -1 else arg for arg in expr.args))
  363. return super()._print_Mul(expr)
  364. def _print_Add(self, expr, order=None):
  365. contains_floats = any(arg.is_real and not arg.is_integer for arg in expr.args)
  366. if contains_floats:
  367. expr = reduce(operator.add, (self._cast_to_float(arg) for arg in expr.args))
  368. return super()._print_Add(expr, order)
  369. def _print_Pow(self, expr):
  370. if expr.base.is_integer and not expr.exp.is_integer:
  371. expr = type(expr)(Float(expr.base), expr.exp)
  372. return self._print(expr)
  373. return self._print_Function(expr)
  374. def _print_TypeCast(self, expr):
  375. if not expr.explicit:
  376. return self._print(expr.expr)
  377. else:
  378. return self._print(expr.expr) + ' as %s' % self.type_mappings[self.type_aliases[expr.type_]]
  379. def _print_Float(self, expr, _type=False):
  380. ret = super()._print_Float(expr)
  381. if _type:
  382. return ret + '_%s' % self.type_mappings[self.type_aliases[real]]
  383. else:
  384. return ret
  385. def _print_Integer(self, expr, _type=False):
  386. ret = super()._print_Integer(expr)
  387. if _type:
  388. return ret + '_%s' % self.type_mappings[self.type_aliases[integer]]
  389. else:
  390. return ret
  391. def _print_Rational(self, expr):
  392. p, q = int(expr.p), int(expr.q)
  393. float_suffix = self.type_mappings[self.type_aliases[real]]
  394. return '%d_%s/%d.0' % (p, float_suffix, q)
  395. def _print_Relational(self, expr):
  396. if (expr.lhs.is_integer and not expr.rhs.is_integer) or (expr.rhs.is_integer and not expr.lhs.is_integer):
  397. lhs = self._cast_to_float(expr.lhs)
  398. rhs = self._cast_to_float(expr.rhs)
  399. else:
  400. lhs = expr.lhs
  401. rhs = expr.rhs
  402. lhs_code = self._print(lhs)
  403. rhs_code = self._print(rhs)
  404. op = expr.rel_op
  405. return "{} {} {}".format(lhs_code, op, rhs_code)
  406. def _print_Indexed(self, expr):
  407. # calculate index for 1d array
  408. dims = expr.shape
  409. elem = S.Zero
  410. offset = S.One
  411. for i in reversed(range(expr.rank)):
  412. elem += expr.indices[i]*offset
  413. offset *= dims[i]
  414. return "%s[%s]" % (self._print(expr.base.label), self._print(elem))
  415. def _print_Idx(self, expr):
  416. return expr.label.name
  417. def _print_Dummy(self, expr):
  418. return expr.name
  419. def _print_Exp1(self, expr, _type=False):
  420. return "E"
  421. def _print_Pi(self, expr, _type=False):
  422. return 'PI'
  423. def _print_Infinity(self, expr, _type=False):
  424. return 'INFINITY'
  425. def _print_NegativeInfinity(self, expr, _type=False):
  426. return 'NEG_INFINITY'
  427. def _print_BooleanTrue(self, expr, _type=False):
  428. return "true"
  429. def _print_BooleanFalse(self, expr, _type=False):
  430. return "false"
  431. def _print_bool(self, expr, _type=False):
  432. return str(expr).lower()
  433. def _print_NaN(self, expr, _type=False):
  434. return "NAN"
  435. def _print_Piecewise(self, expr):
  436. if expr.args[-1].cond != True:
  437. # We need the last conditional to be a True, otherwise the resulting
  438. # function may not return a result.
  439. raise ValueError("All Piecewise expressions must contain an "
  440. "(expr, True) statement to be used as a default "
  441. "condition. Without one, the generated "
  442. "expression may not evaluate to anything under "
  443. "some condition.")
  444. lines = []
  445. for i, (e, c) in enumerate(expr.args):
  446. if i == 0:
  447. lines.append("if (%s) {" % self._print(c))
  448. elif i == len(expr.args) - 1 and c == True:
  449. lines[-1] += " else {"
  450. else:
  451. lines[-1] += " else if (%s) {" % self._print(c)
  452. code0 = self._print(e)
  453. lines.append(code0)
  454. lines.append("}")
  455. if self._settings['inline']:
  456. return " ".join(lines)
  457. else:
  458. return "\n".join(lines)
  459. def _print_ITE(self, expr):
  460. from sympy.functions import Piecewise
  461. return self._print(expr.rewrite(Piecewise, deep=False))
  462. def _print_MatrixBase(self, A):
  463. if A.cols == 1:
  464. return "[%s]" % ", ".join(self._print(a) for a in A)
  465. else:
  466. raise ValueError("Full Matrix Support in Rust need Crates (https://crates.io/keywords/matrix).")
  467. def _print_SparseRepMatrix(self, mat):
  468. # do not allow sparse matrices to be made dense
  469. return self._print_not_supported(mat)
  470. def _print_MatrixElement(self, expr):
  471. return "%s[%s]" % (expr.parent,
  472. expr.j + expr.i*expr.parent.shape[1])
  473. def _print_Symbol(self, expr):
  474. name = super()._print_Symbol(expr)
  475. if expr in self._dereference:
  476. return '(*%s)' % name
  477. else:
  478. return name
  479. def _print_Assignment(self, expr):
  480. from sympy.tensor.indexed import IndexedBase
  481. lhs = expr.lhs
  482. rhs = expr.rhs
  483. if self._settings["contract"] and (lhs.has(IndexedBase) or
  484. rhs.has(IndexedBase)):
  485. # Here we check if there is looping to be done, and if so
  486. # print the required loops.
  487. return self._doprint_loops(rhs, lhs)
  488. else:
  489. lhs_code = self._print(lhs)
  490. rhs_code = self._print(rhs)
  491. return self._get_statement("%s = %s" % (lhs_code, rhs_code))
  492. def _print_sign(self, expr):
  493. arg = self._print(expr.args[0])
  494. return "(if (%s == 0.0) { 0.0 } else { (%s).signum() })" % (arg, arg)
  495. def _cast_to_float(self, expr):
  496. if not expr.is_number:
  497. return TypeCast(expr, real)
  498. elif expr.is_integer:
  499. return Float(expr)
  500. return expr
  501. def _can_print(self, name):
  502. """ Check if function ``name`` is either a known function or has its own
  503. printing method. Used to check if rewriting is possible."""
  504. # since the whole point of function_overrides is to enable proper printing,
  505. # we presume they all are printable
  506. return name in self.known_functions or name in function_overrides or getattr(self, '_print_{}'.format(name), False)
  507. def _collect_functions(self, expr):
  508. functions = set()
  509. if isinstance(expr, Expr):
  510. if expr.is_Function:
  511. functions.add(expr.func)
  512. for arg in expr.args:
  513. functions = functions.union(self._collect_functions(arg))
  514. return functions
  515. def _rewrite_known_functions(self, expr):
  516. if not isinstance(expr, Expr):
  517. return expr
  518. expression_functions = self._collect_functions(expr)
  519. rewriteable_functions = {
  520. name: (target_f, required_fs)
  521. for name, (target_f, required_fs) in self._rewriteable_functions.items()
  522. if self._can_print(target_f)
  523. and all(self._can_print(f) for f in required_fs)
  524. }
  525. for func in expression_functions:
  526. target_f, _ = rewriteable_functions.get(func.__name__, (None, None))
  527. if target_f:
  528. expr = expr.rewrite(target_f)
  529. return expr
  530. def indent_code(self, code):
  531. """Accepts a string of code or a list of code lines"""
  532. if isinstance(code, str):
  533. code_lines = self.indent_code(code.splitlines(True))
  534. return ''.join(code_lines)
  535. tab = " "
  536. inc_token = ('{', '(', '{\n', '(\n')
  537. dec_token = ('}', ')')
  538. code = [ line.lstrip(' \t') for line in code ]
  539. increase = [ int(any(map(line.endswith, inc_token))) for line in code ]
  540. decrease = [ int(any(map(line.startswith, dec_token)))
  541. for line in code ]
  542. pretty = []
  543. level = 0
  544. for n, line in enumerate(code):
  545. if line in ('', '\n'):
  546. pretty.append(line)
  547. continue
  548. level -= decrease[n]
  549. pretty.append("%s%s" % (tab*level, line))
  550. level += increase[n]
  551. return pretty