printers.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. import sys
  2. import sympy
  3. from sympy.printing.precedence import PRECEDENCE, precedence
  4. from sympy.printing.str import StrPrinter
  5. INDEX_TYPE = "int64_t"
  6. INDEX_TYPE_MAX = (1 << 63) - 1
  7. INDEX_TYPE_MIN = -1 << 63
  8. # This printer contains rules that are supposed to be generic for both C/C++ and
  9. # Python
  10. class ExprPrinter(StrPrinter):
  11. # override this so that _print_FloorDiv is used
  12. printmethod = "_torch_sympystr"
  13. def _print_Mul(self, expr: sympy.Expr) -> str:
  14. return self.stringify(expr.args, "*", precedence(expr))
  15. def _print_Not(self, expr: sympy.Expr) -> str:
  16. # pyrefly: ignore [missing-attribute]
  17. return f"not ({self._print(expr.args[0])})"
  18. def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
  19. return self.stringify(expr.args, " + ", precedence(expr))
  20. def _print_Relational(self, expr: sympy.Expr) -> str:
  21. return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))
  22. def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
  23. return self.stringify(expr.args, " & ", PRECEDENCE["BitwiseAnd"])
  24. def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
  25. return self.stringify(expr.args, " | ", PRECEDENCE["BitwiseOr"])
  26. def _print_BitwiseFn_bitwise_xor(self, expr: sympy.Expr) -> str:
  27. return self.stringify(expr.args, " ^ ", PRECEDENCE["BitwiseXor"])
  28. # NB: this is OK to put here, because Mod is only defined for positive
  29. # numbers, and so across C/Python its behavior is consistent
  30. def _print_Mod(self, expr: sympy.Expr) -> str:
  31. return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
  32. def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
  33. s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
  34. return f"({s})"
  35. def _print_CleanDiv(self, expr: sympy.Expr) -> str:
  36. return self._print_FloorDiv(expr)
  37. def _print_Identity(self, expr: sympy.Expr) -> str:
  38. # pyrefly: ignore [missing-attribute]
  39. return self._print(expr.args[0])
  40. def _print_Float(self, expr: sympy.Expr) -> str:
  41. if expr._prec == 53:
  42. # IEEE-754 double precision have 53 bits. SymPy prints them with
  43. # 15 digits, but we need 17 for round-trip correctness
  44. return str(sympy.Float(expr, dps=17))
  45. else:
  46. # We don't use other precisions in pytorch
  47. return str(expr)
  48. # This must be implemented because sympy will collect x * x into Pow(x, 2), without
  49. # any explicit intervention. We print it just like x * x, notably, we
  50. # never generate sympy.Pow with floats.
  51. #
  52. # NB: this pow by natural, you should never have used builtin sympy.pow
  53. # for FloatPow, and a symbolic exponent should be PowByNatural. These
  54. # means exp is guaranteed to be integer.
  55. # pyrefly: ignore [bad-override]
  56. def _print_Pow(self, expr: sympy.Expr) -> str:
  57. base, exp = expr.args
  58. if exp != int(exp):
  59. raise AssertionError(exp)
  60. exp = int(exp)
  61. if exp < 0:
  62. raise AssertionError(f"exponent must be non-negative, got {exp}")
  63. if exp > 0:
  64. return self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
  65. return "1"
  66. # Explicit NotImplemented functions are to prevent default sympy printing
  67. # behavior, which will just barf out ToFloat(...) to your IR. The error
  68. # message is better here because it tells you which printer class it needs
  69. # to go in.
  70. def _print_ToFloat(self, expr: sympy.Expr) -> str:
  71. raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
  72. def _print_Infinity(self, expr: sympy.Expr) -> str:
  73. raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
  74. def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
  75. raise NotImplementedError(
  76. f"_print_NegativeInfinity not implemented for {type(self)}"
  77. )
  78. def _print_FloorDiv(self, expr: sympy.Expr) -> str:
  79. raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
  80. def _print_PythonMod(self, expr: sympy.Expr) -> str:
  81. raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
  82. def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
  83. raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
  84. def _print_PowByNatural(self, expr: sympy.Expr) -> str:
  85. raise NotImplementedError(
  86. f"_print_PowByNatural not implemented for {type(self)}"
  87. )
  88. def _print_FloatPow(self, expr: sympy.Expr) -> str:
  89. raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
  90. def _print_TruncToInt(self, expr: sympy.Expr) -> str:
  91. raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
  92. def _print_RoundToInt(self, expr: sympy.Expr) -> str:
  93. raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
  94. def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
  95. raise NotImplementedError(
  96. f"_print_RoundDecimal not implemented for {type(self)}"
  97. )
  98. # NB: Some float operations are INTENTIONALLY not implemented for
  99. # printers. You can implement them as a quick unblock, but it is better
  100. # to ask yourself why we haven't done this computation in the Tensor
  101. # universe instead
  102. def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
  103. raise NotImplementedError(
  104. f"_print_TruncToFloat not implemented for {type(self)}"
  105. )
  106. class PythonPrinter(ExprPrinter):
  107. def _print_ToFloat(self, expr: sympy.Expr) -> str:
  108. if len(expr.args) != 1:
  109. raise AssertionError("ToFloat expects exactly one argument")
  110. # NB: We use sym_float here because the printer is used for cache
  111. # serialization, and cache guards get evaluated with SymInt to
  112. # propagate guards to the parent ShapeEnv. However, this comes at a
  113. # runtime cost for guards involving float. If this is unacceptable
  114. # overhead, what you want to do is have two separate printers for
  115. # SymInt, one for when the inputs are guaranteed to be int, and
  116. # another for when they could be SymInt.
  117. #
  118. # NB: sym_min/sym_max also have this problem, but I chose not to fix
  119. # those.
  120. #
  121. # See https://github.com/pytorch/pytorch/issues/142507 for more
  122. # context.
  123. # pyrefly: ignore [missing-attribute]
  124. return f"torch.sym_float({self._print(expr.args[0])})"
  125. def _print_And(self, expr: sympy.Expr) -> str:
  126. return self.stringify(expr.args, " and ", precedence(expr))
  127. def _print_Or(self, expr: sympy.Expr) -> str:
  128. return self.stringify(expr.args, " or ", precedence(expr))
  129. def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
  130. x, div, mod = (
  131. self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
  132. )
  133. if div != "1":
  134. x = f"({x} // {div})"
  135. return f"({x} % {mod})"
  136. def _print_Infinity(self, expr: sympy.Expr) -> str:
  137. return "math.inf"
  138. def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
  139. return "-math.inf"
  140. # WARNING: this is dangerous for Triton, which has C-style modulus
  141. def _print_PythonMod(self, expr: sympy.Expr) -> str:
  142. return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
  143. # WARNING: this is dangerous for Triton, which has C-style modulus
  144. def _print_FloorDiv(self, expr: sympy.Expr) -> str:
  145. x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
  146. return f"{x} // {div}"
  147. # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
  148. # does a special algorithm
  149. def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
  150. return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
  151. def _helper_sqrt(self, expr: sympy.Expr) -> str:
  152. # pyrefly: ignore [missing-attribute]
  153. return f"math.sqrt({self._print(expr)})"
  154. def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
  155. return self._helper_sqrt(expr.args[0])
  156. def _print_FloatPow(self, expr: sympy.Expr) -> str:
  157. return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
  158. # TODO: Not sure this works with Triton, even when base/exp are integral
  159. def _print_PowByNatural(self, expr: sympy.Expr) -> str:
  160. return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
  161. def _print_floor(self, expr: sympy.Expr) -> str:
  162. if len(expr.args) != 1:
  163. raise AssertionError("floor expects exactly one argument")
  164. # pyrefly: ignore [missing-attribute]
  165. return f"math.floor({self._print(expr.args[0])})"
  166. def _print_FloorToInt(self, expr: sympy.Expr) -> str:
  167. if len(expr.args) != 1:
  168. raise AssertionError("FloorToInt expects exactly one argument")
  169. # pyrefly: ignore [missing-attribute]
  170. return f"math.floor({self._print(expr.args[0])})"
  171. def _print_TruncToInt(self, expr: sympy.Expr) -> str:
  172. if len(expr.args) != 1:
  173. raise AssertionError("TruncToInt expects exactly one argument")
  174. # This also could have been int(), they'll do the same thing for float
  175. # pyrefly: ignore [missing-attribute]
  176. return f"math.trunc({self._print(expr.args[0])})"
  177. def _print_ceiling(self, expr: sympy.Expr) -> str:
  178. if len(expr.args) != 1:
  179. raise AssertionError("ceiling expects exactly one argument")
  180. # pyrefly: ignore [missing-attribute]
  181. return f"math.ceil({self._print(expr.args[0])})"
  182. def _print_CeilToInt(self, expr: sympy.Expr) -> str:
  183. if len(expr.args) != 1:
  184. raise AssertionError("CeilToInt expects exactly one argument")
  185. # pyrefly: ignore [missing-attribute]
  186. return f"math.ceil({self._print(expr.args[0])})"
  187. def _print_Abs(self, expr: sympy.Expr) -> str:
  188. if len(expr.args) != 1:
  189. raise AssertionError("Abs expects exactly one argument")
  190. # pyrefly: ignore [missing-attribute]
  191. return f"abs({self._print(expr.args[0])})"
  192. # NB: It's expected that we've made explicit any promotion in the sympy
  193. # expression, so it doesn't matter that Python max/min doesn't perform
  194. # promotion
  195. def _print_Max(self, expr: sympy.Expr) -> str:
  196. if len(expr.args) < 2:
  197. raise AssertionError("Max expects at least two arguments")
  198. # pyrefly: ignore [missing-attribute]
  199. return f"max({', '.join(map(self._print, expr.args))})"
  200. def _print_Min(self, expr: sympy.Expr) -> str:
  201. if len(expr.args) < 2:
  202. raise AssertionError("Min expects at least two arguments")
  203. # pyrefly: ignore [missing-attribute]
  204. return f"min({', '.join(map(self._print, expr.args))})"
  205. def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
  206. if len(expr.args) != 1:
  207. raise AssertionError("cos expects exactly one argument")
  208. # pyrefly: ignore [missing-attribute]
  209. return f"math.cos({self._print(expr.args[0])})"
  210. def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
  211. if len(expr.args) != 1:
  212. raise AssertionError("cosh expects exactly one argument")
  213. # pyrefly: ignore [missing-attribute]
  214. return f"math.cosh({self._print(expr.args[0])})"
  215. def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
  216. if len(expr.args) != 1:
  217. raise AssertionError("acos expects exactly one argument")
  218. # pyrefly: ignore [missing-attribute]
  219. return f"math.acos({self._print(expr.args[0])})"
  220. def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
  221. if len(expr.args) != 1:
  222. raise AssertionError("sin expects exactly one argument")
  223. # pyrefly: ignore [missing-attribute]
  224. return f"math.sin({self._print(expr.args[0])})"
  225. def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
  226. if len(expr.args) != 1:
  227. raise AssertionError("sinh expects exactly one argument")
  228. # pyrefly: ignore [missing-attribute]
  229. return f"math.sinh({self._print(expr.args[0])})"
  230. def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
  231. if len(expr.args) != 1:
  232. raise AssertionError("asin expects exactly one argument")
  233. # pyrefly: ignore [missing-attribute]
  234. return f"math.asin({self._print(expr.args[0])})"
  235. def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
  236. if len(expr.args) != 1:
  237. raise AssertionError("tan expects exactly one argument")
  238. # pyrefly: ignore [missing-attribute]
  239. return f"math.tan({self._print(expr.args[0])})"
  240. def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
  241. if len(expr.args) != 1:
  242. raise AssertionError("tanh expects exactly one argument")
  243. # pyrefly: ignore [missing-attribute]
  244. return f"math.tanh({self._print(expr.args[0])})"
  245. def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
  246. if len(expr.args) != 1:
  247. raise AssertionError("atan expects exactly one argument")
  248. # pyrefly: ignore [missing-attribute]
  249. return f"math.atan({self._print(expr.args[0])})"
  250. def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
  251. if len(expr.args) != 1:
  252. raise AssertionError("log2 expects exactly one argument")
  253. # pyrefly: ignore [missing-attribute]
  254. return f"math.log2({self._print(expr.args[0])})"
  255. def _print_RoundToInt(self, expr: sympy.Expr) -> str:
  256. if len(expr.args) != 1:
  257. raise AssertionError("RoundToInt expects exactly one argument")
  258. # pyrefly: ignore [missing-attribute]
  259. return f"round({self._print(expr.args[0])})"
  260. def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
  261. if len(expr.args) != 2:
  262. raise AssertionError("RoundDecimal expects exactly two arguments")
  263. number, ndigits = expr.args
  264. if not isinstance(ndigits, sympy.Integer):
  265. raise TypeError("ndigits must be an instance of sympy.Integer")
  266. # pyrefly: ignore [missing-attribute]
  267. return f"round({self._print(number)}, {ndigits})"
  268. def _print_Piecewise(self, expr: sympy.Expr) -> str:
  269. # Convert Piecewise(expr_cond_pairs) to nested ternary expressions
  270. # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
  271. # becomes: e1 if c1 else (e2 if c2 else (... else eN))
  272. result: str | None = None
  273. for expr_i, cond_i in reversed(expr.args):
  274. # pyrefly: ignore [missing-attribute]
  275. expr_str = self._print(expr_i)
  276. if cond_i == True: # noqa: E712
  277. # This is the default case
  278. result = expr_str
  279. else:
  280. # pyrefly: ignore [missing-attribute]
  281. cond_str = self._print(cond_i)
  282. if result is None:
  283. result = expr_str
  284. else:
  285. result = f"({expr_str} if {cond_str} else {result})"
  286. return result if result else "0"
  287. class CppPrinter(ExprPrinter):
  288. def _print_Integer(self, expr: sympy.Expr) -> str:
  289. suffix = "LL" if sys.platform in ["darwin", "win32"] else "L"
  290. i = int(expr)
  291. if i > INDEX_TYPE_MAX or i < INDEX_TYPE_MIN:
  292. raise OverflowError(f"{i} too big to convert to {INDEX_TYPE}")
  293. elif i == INDEX_TYPE_MIN:
  294. if i != (-1) << 63:
  295. raise AssertionError("unexpected minimum index type value")
  296. # Writing -9223372036854775808L makes the value overflow
  297. # as it is parsed as -(9223372036854775808L) by the C/C++ compiler
  298. return f"(-1{suffix} << 63)"
  299. return f"{i}{suffix}"
  300. def _print_Where(self, expr: sympy.Expr) -> str:
  301. c, p, q = (
  302. self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
  303. )
  304. return f"{c} ? {p} : {q}"
  305. def _print_Piecewise(self, expr: sympy.Expr) -> str:
  306. # Convert Piecewise(expr_cond_pairs) to nested ternary operators
  307. # Piecewise((e1, c1), (e2, c2), ..., (eN, cN))
  308. # becomes: c1 ? e1 : (c2 ? e2 : (... : eN))
  309. result: str | None = None
  310. for expr_i, cond_i in reversed(expr.args):
  311. expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5)
  312. if cond_i == True: # noqa: E712
  313. # This is the default case
  314. result = expr_str
  315. else:
  316. cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5)
  317. if result is None:
  318. result = expr_str
  319. else:
  320. result = f"{cond_str} ? {expr_str} : {result}"
  321. return f"({result})" if result else "0"
  322. def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
  323. x, div, mod = expr.args
  324. x = self.doprint(x)
  325. if div != 1:
  326. div = self.doprint(div)
  327. if expr.is_integer:
  328. x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
  329. else:
  330. x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
  331. mod = self.doprint(mod)
  332. return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"
  333. def _print_FloorDiv(self, expr: sympy.Expr) -> str:
  334. x, div = expr.args
  335. x = self.doprint(x)
  336. div = self.doprint(div)
  337. if expr.is_integer:
  338. return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
  339. return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
  340. def _print_floor(self, expr: sympy.Expr) -> str:
  341. if len(expr.args) != 1:
  342. raise AssertionError("floor expects exactly one argument")
  343. # pyrefly: ignore [missing-attribute]
  344. r = f"std::floor({self._print(expr.args[0])})"
  345. return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
  346. def _print_FloorToInt(self, expr: sympy.Expr) -> str:
  347. if len(expr.args) != 1:
  348. raise AssertionError("FloorToInt expects exactly one argument")
  349. # pyrefly: ignore [missing-attribute]
  350. r = f"std::floor({self._print(expr.args[0])})"
  351. return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
  352. def _print_TruncToInt(self, expr: sympy.Expr) -> str:
  353. if len(expr.args) != 1:
  354. raise AssertionError("TruncToInt expects exactly one argument")
  355. # pyrefly: ignore [missing-attribute]
  356. r = f"std::trunc({self._print(expr.args[0])})"
  357. return f"static_cast<{INDEX_TYPE}>({r})"
  358. def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
  359. if len(expr.args) != 1:
  360. raise AssertionError("TruncToFloat expects exactly one argument")
  361. # pyrefly: ignore [missing-attribute]
  362. return f"std::trunc({self._print(expr.args[0])})"
  363. def _print_ToFloat(self, expr: sympy.Expr) -> str:
  364. if len(expr.args) != 1:
  365. raise AssertionError("ToFloat expects exactly one argument")
  366. # pyrefly: ignore [missing-attribute]
  367. return f"static_cast<double>({self._print(expr.args[0])})"
  368. def _print_PythonMod(self, expr: sympy.Expr) -> str:
  369. x, div = expr.args
  370. x = self.doprint(x)
  371. div = self.doprint(div)
  372. return f"c10::div_mod({x}, {div})"
  373. def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
  374. lhs, rhs = expr.args
  375. # TODO: This is only accurate up to 2**53
  376. # pyrefly: ignore [missing-attribute]
  377. return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
  378. # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
  379. # use std::pow, that operates on floats
  380. def _print_PowByNatural(self, expr: sympy.Expr) -> str:
  381. # Implement the special-case of 2**x for now
  382. base, exp = expr.args
  383. if base == 2:
  384. # pyrefly: ignore [missing-attribute]
  385. return f"(1 << ({self._print(exp)}))"
  386. raise NotImplementedError(
  387. f"_print_PowByNatural not implemented for {type(self)}"
  388. )
  389. def _print_FloatPow(self, expr: sympy.Expr) -> str:
  390. base, exp = expr.args
  391. # pyrefly: ignore [missing-attribute]
  392. return f"std::pow({self._print(base)}, {self._print(exp)})"
  393. def _print_Pow(self, expr: sympy.Expr) -> str:
  394. # Uses float constants to perform FP div
  395. base, exp = expr.args
  396. if exp == 0.5 or exp == -0.5:
  397. # pyrefly: ignore [missing-attribute]
  398. base = self._print(base)
  399. return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
  400. if exp.is_integer:
  401. exp = int(exp)
  402. if exp > 0:
  403. r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
  404. elif exp < -1:
  405. r = (
  406. "1.0/("
  407. + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
  408. + ")"
  409. )
  410. elif exp == -1:
  411. # pyrefly: ignore [missing-attribute]
  412. r = "1.0/" + self._print(base)
  413. else: # exp == 0
  414. r = "1.0"
  415. return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
  416. else:
  417. # TODO: float vs double
  418. return f"std::pow({base}, {float(exp)})"
  419. def _print_Rational(self, expr: sympy.Expr) -> str:
  420. # Uses float constants to perform FP div
  421. if expr.q == 1:
  422. r = f"{expr.p}"
  423. else:
  424. r = f"{expr.p}.0/{expr.q}.0"
  425. return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
  426. def _print_ceiling(self, expr: sympy.Expr) -> str:
  427. if len(expr.args) != 1:
  428. raise AssertionError("ceiling expects exactly one argument")
  429. # pyrefly: ignore [missing-attribute]
  430. r = f"std::ceil({self._print(expr.args[0])})"
  431. return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
  432. def _print_CeilToInt(self, expr: sympy.Expr) -> str:
  433. if len(expr.args) != 1:
  434. raise AssertionError("CeilToInt expects exactly one argument")
  435. # pyrefly: ignore [missing-attribute]
  436. r = f"std::ceil({self._print(expr.args[0])})"
  437. return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
  438. def _print_Min(self, expr: sympy.Expr) -> str:
  439. # pyrefly: ignore [missing-attribute]
  440. args = [self._print(a) for a in expr.args]
  441. if len(args) == 2:
  442. return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
  443. else:
  444. # Initializer list overload
  445. il = "{" + ", ".join(args) + "}"
  446. return f"std::min<{INDEX_TYPE}>({il})"
  447. def _print_Max(self, expr: sympy.Expr) -> str:
  448. # pyrefly: ignore [missing-attribute]
  449. args = [self._print(a) for a in expr.args]
  450. if len(args) == 2:
  451. return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
  452. else:
  453. # Initializer list overload
  454. il = "{" + ", ".join(args) + "}"
  455. return f"std::max<{INDEX_TYPE}>({il})"
  456. def _print_Abs(self, expr: sympy.Expr) -> str:
  457. if len(expr.args) != 1:
  458. raise AssertionError("Abs expects exactly one argument")
  459. # pyrefly: ignore [missing-attribute]
  460. return f"std::abs({self._print(expr.args[0])})"
  461. def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
  462. if len(expr.args) != 1:
  463. raise AssertionError("cos expects exactly one argument")
  464. # pyrefly: ignore [missing-attribute]
  465. return f"std::cos({self._print(expr.args[0])})"
  466. def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
  467. if len(expr.args) != 1:
  468. raise AssertionError("cosh expects exactly one argument")
  469. # pyrefly: ignore [missing-attribute]
  470. return f"std::cosh({self._print(expr.args[0])})"
  471. def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
  472. if len(expr.args) != 1:
  473. raise AssertionError("acos expects exactly one argument")
  474. # pyrefly: ignore [missing-attribute]
  475. return f"std::acos({self._print(expr.args[0])})"
  476. def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
  477. if len(expr.args) != 1:
  478. raise AssertionError("sin expects exactly one argument")
  479. # pyrefly: ignore [missing-attribute]
  480. return f"math.sin({self._print(expr.args[0])})"
  481. def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
  482. if len(expr.args) != 1:
  483. raise AssertionError("sinh expects exactly one argument")
  484. # pyrefly: ignore [missing-attribute]
  485. return f"std::sinh({self._print(expr.args[0])})"
  486. def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
  487. if len(expr.args) != 1:
  488. raise AssertionError("asin expects exactly one argument")
  489. # pyrefly: ignore [missing-attribute]
  490. return f"std::asin({self._print(expr.args[0])})"
  491. def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
  492. if len(expr.args) != 1:
  493. raise AssertionError("tan expects exactly one argument")
  494. # pyrefly: ignore [missing-attribute]
  495. return f"std::tan({self._print(expr.args[0])})"
  496. def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
  497. if len(expr.args) != 1:
  498. raise AssertionError("tanh expects exactly one argument")
  499. # pyrefly: ignore [missing-attribute]
  500. return f"std::tanh({self._print(expr.args[0])})"
  501. def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
  502. if len(expr.args) != 1:
  503. raise AssertionError("atan expects exactly one argument")
  504. # pyrefly: ignore [missing-attribute]
  505. return f"std::atan({self._print(expr.args[0])})"
  506. def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
  507. # pyrefly: ignore [missing-attribute]
  508. return f"std::sqrt({self._print(expr.args[0])})"
  509. def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
  510. # pyrefly: ignore [missing-attribute]
  511. return f"std::log2({self._print(expr.args[0])})"
  512. def _print_RoundToInt(self, expr: sympy.Expr) -> str:
  513. if len(expr.args) != 1:
  514. raise AssertionError("RoundToInt expects exactly one argument")
  515. # TODO: dispatch to llrint depending on index type
  516. # pyrefly: ignore [missing-attribute]
  517. return f"std::lrint({self._print(expr.args[0])})"
  518. def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
  519. if len(expr.args) != 2:
  520. raise AssertionError("RoundDecimal expects exactly two arguments")
  521. number, ndigits = expr.args
  522. if number.is_integer:
  523. # ndigits < 0 should have been filtered by the sympy function
  524. if ndigits >= 0:
  525. raise AssertionError("ndigits must be negative for integer inputs")
  526. raise ValueError(
  527. f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
  528. )
  529. number_str = self.parenthesize(number, PRECEDENCE["Mul"])
  530. return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
  531. def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
  532. return "true"
  533. def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
  534. return "false"
  535. def _print_Infinity(self, expr: sympy.Expr) -> str:
  536. return "std::numeric_limits<double>::infinity()"
  537. def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
  538. return f"-{self._print_Infinity(expr)}"