test_smtlib.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. import contextlib
  2. import itertools
  3. import re
  4. import typing
  5. from enum import Enum
  6. from typing import Callable
  7. import sympy
  8. from sympy import Add, Implies, sqrt
  9. from sympy.core import Mul, Pow
  10. from sympy.core import (S, pi, symbols, Function, Rational, Integer,
  11. Symbol, Eq, Ne, Le, Lt, Gt, Ge)
  12. from sympy.functions import Piecewise, exp, sin, cos
  13. from sympy.assumptions.ask import Q
  14. from sympy.printing.smtlib import smtlib_code
  15. from sympy.testing.pytest import raises, Failed
# Shared symbols used by the tests below; created without any assumptions.
x, y, z = symbols('x,y,z')
  17. class _W(Enum):
  18. DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.IGNORECASE)
  19. WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.IGNORECASE)
  20. WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.IGNORECASE)
  21. @contextlib.contextmanager
  22. def _check_warns(expected: typing.Iterable[_W]):
  23. warns: typing.List[str] = []
  24. log_warn = warns.append
  25. yield log_warn
  26. errors = []
  27. for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)):
  28. if not e:
  29. errors += [f"[{i}] Received unexpected warning `{w}`."]
  30. elif not w:
  31. errors += [f"[{i}] Did not receive expected warning `{e.name}`."]
  32. elif not e.value.match(w):
  33. errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."]
  34. if errors: raise Failed('\n'.join(errors))
  35. def test_Integer():
  36. with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w:
  37. assert smtlib_code(Integer(67), log_warn=w) == "67"
  38. assert smtlib_code(Integer(-1), log_warn=w) == "-1"
  39. with _check_warns([]) as w:
  40. assert smtlib_code(Integer(67)) == "67"
  41. assert smtlib_code(Integer(-1)) == "-1"
  42. def test_Rational():
  43. with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w:
  44. assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)"
  45. assert smtlib_code(Rational(18, 9), log_warn=w) == "2"
  46. assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)"
  47. assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)"
  48. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w:
  49. assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)"
  50. assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \
  51. "(* (/ 3 7) x)"
  52. def test_Relational():
  53. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
  54. assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
  55. assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
  56. assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
  57. assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
  58. assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
  59. assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
  60. def test_AppliedBinaryRelation():
  61. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w:
  62. assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))"
  63. assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))"
  64. assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))"
  65. assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))"
  66. assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))"
  67. assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))"
  68. raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w))
  69. def test_AppliedPredicate():
  70. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w:
  71. assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))"
  72. assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))"
  73. assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))"
  74. assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))"
  75. assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))"
  76. assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))"
  77. def test_Function():
  78. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  79. assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))"
  80. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  81. assert smtlib_code(
  82. abs(x),
  83. symbol_table={x: int, y: bool},
  84. known_types={int: "INTEGER_TYPE"},
  85. known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"},
  86. log_warn=w
  87. ) == "(declare-const x INTEGER_TYPE)\n" \
  88. "(ABSOLUTE_VALUE_OF x)"
  89. my_fun1 = Function('f1')
  90. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  91. assert smtlib_code(
  92. my_fun1(x),
  93. symbol_table={my_fun1: Callable[[bool], float]},
  94. log_warn=w
  95. ) == "(declare-const x Bool)\n" \
  96. "(declare-fun f1 (Bool) Real)\n" \
  97. "(f1 x)"
  98. with _check_warns([]) as w:
  99. assert smtlib_code(
  100. my_fun1(x),
  101. symbol_table={my_fun1: Callable[[bool], bool]},
  102. log_warn=w
  103. ) == "(declare-const x Bool)\n" \
  104. "(declare-fun f1 (Bool) Bool)\n" \
  105. "(assert (f1 x))"
  106. assert smtlib_code(
  107. Eq(my_fun1(x, z), y),
  108. symbol_table={my_fun1: Callable[[int, bool], bool]},
  109. log_warn=w
  110. ) == "(declare-const x Int)\n" \
  111. "(declare-const y Bool)\n" \
  112. "(declare-const z Bool)\n" \
  113. "(declare-fun f1 (Int Bool) Bool)\n" \
  114. "(assert (= (f1 x z) y))"
  115. assert smtlib_code(
  116. Eq(my_fun1(x, z), y),
  117. symbol_table={my_fun1: Callable[[int, bool], bool]},
  118. known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
  119. log_warn=w
  120. ) == "(declare-const x Int)\n" \
  121. "(declare-const y Bool)\n" \
  122. "(declare-const z Bool)\n" \
  123. "(assert (== (MY_KNOWN_FUN x z) y))"
  124. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w:
  125. assert smtlib_code(
  126. Eq(my_fun1(x, z), y),
  127. known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='},
  128. log_warn=w
  129. ) == "(declare-const x Real)\n" \
  130. "(declare-const y Real)\n" \
  131. "(declare-const z Real)\n" \
  132. "(assert (== (MY_KNOWN_FUN x z) y))"
  133. def test_Pow():
  134. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  135. assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)"
  136. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  137. assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))"
  138. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  139. assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))'
  140. a = Symbol('a', integer=True)
  141. b = Symbol('b', real=True)
  142. c = Symbol('c')
  143. def g(x): return 2 * x
  144. # if x=1, y=2, then expr=2.333...
  145. expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b)
  146. with _check_warns([]) as w:
  147. assert smtlib_code(
  148. [
  149. Eq(a < 2, c),
  150. Eq(b > a, c),
  151. c & True,
  152. Eq(expr, 2 + Rational(1, 3))
  153. ],
  154. log_warn=w
  155. ) == '(declare-const a Int)\n' \
  156. '(declare-const b Real)\n' \
  157. '(declare-const c Bool)\n' \
  158. '(assert (= (< a 2) c))\n' \
  159. '(assert (= (> b a) c))\n' \
  160. '(assert c)\n' \
  161. '(assert (= ' \
  162. '(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \
  163. '(/ 7 3)' \
  164. '))'
  165. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  166. assert smtlib_code(
  167. Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False),
  168. log_warn=w
  169. ) == '(declare-const b Real)\n' \
  170. '(declare-const c Real)\n' \
  171. '(* -2 c (pow (* b b) -1))'
  172. def test_basic_ops():
  173. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  174. assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)"
  175. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  176. assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)"
  177. # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w:
  178. # todo: implement re-write, currently does '(+ x (* -1 y))' instead
  179. # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)"
  180. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  181. assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)"
  182. def test_quantifier_extensions():
  183. from sympy.logic.boolalg import Boolean
  184. from sympy import Interval, Tuple, sympify
  185. # start For-all quantifier class example
  186. class ForAll(Boolean):
  187. def _smtlib(self, printer):
  188. bound_symbol_declarations = [
  189. printer._s_expr(sym.name, [
  190. printer._known_types[printer.symbol_table[sym]],
  191. Interval(start, end)
  192. ]) for sym, start, end in self.limits
  193. ]
  194. return printer._s_expr('forall', [
  195. printer._s_expr('', bound_symbol_declarations),
  196. self.function
  197. ])
  198. @property
  199. def bound_symbols(self):
  200. return {s for s, _, _ in self.limits}
  201. @property
  202. def free_symbols(self):
  203. bound_symbol_names = {s.name for s in self.bound_symbols}
  204. return {
  205. s for s in self.function.free_symbols
  206. if s.name not in bound_symbol_names
  207. }
  208. def __new__(cls, *args):
  209. limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))]
  210. function = [sympify(a) for a in args if isinstance(a, Boolean)]
  211. assert len(limits) + len(function) == len(args)
  212. assert len(function) == 1
  213. function = function[0]
  214. if isinstance(function, ForAll): return ForAll.__new__(
  215. ForAll, *(limits + function.limits), function.function
  216. )
  217. inst = Boolean.__new__(cls)
  218. inst._args = tuple(limits + [function])
  219. inst.limits = limits
  220. inst.function = function
  221. return inst
  222. # end For-All Quantifier class example
  223. f = Function('f')
  224. with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
  225. assert smtlib_code(
  226. ForAll((x, -42, +21), Eq(f(x), f(x))),
  227. symbol_table={f: Callable[[float], float]},
  228. log_warn=w
  229. ) == '(assert (forall ( (x Real [-42, 21])) true))'
  230. with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w:
  231. assert smtlib_code(
  232. ForAll(
  233. (x, -42, +21), (y, -100, 3),
  234. Implies(Eq(x, y), Eq(f(x), f(y)))
  235. ),
  236. symbol_table={f: Callable[[float], float]},
  237. log_warn=w
  238. ) == '(declare-fun f (Real) Real)\n' \
  239. '(assert (' \
  240. 'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \
  241. '(=> (= x y) (= (f x) (f y)))' \
  242. '))'
  243. a = Symbol('a', integer=True)
  244. b = Symbol('b', real=True)
  245. c = Symbol('c')
  246. with _check_warns([]) as w:
  247. assert smtlib_code(
  248. ForAll(
  249. (a, 2, 100), ForAll(
  250. (b, 2, 100),
  251. Implies(a < b, sqrt(a) < b) | c
  252. )),
  253. log_warn=w
  254. ) == '(declare-const c Bool)\n' \
  255. '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \
  256. '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \
  257. '))'
  258. def test_mix_number_mult_symbols():
  259. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  260. assert smtlib_code(
  261. 1 / pi,
  262. known_constants={pi: "MY_PI"},
  263. log_warn=w
  264. ) == '(pow MY_PI -1)'
  265. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  266. assert smtlib_code(
  267. [
  268. Eq(pi, 3.14, evaluate=False),
  269. 1 / pi,
  270. ],
  271. known_constants={pi: "MY_PI"},
  272. log_warn=w
  273. ) == '(assert (= MY_PI 3.14))\n' \
  274. '(pow MY_PI -1)'
  275. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  276. assert smtlib_code(
  277. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  278. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  279. known_constants={
  280. S.Pi: 'p', S.GoldenRatio: 'g',
  281. S.Exp1: 'e'
  282. },
  283. known_functions={
  284. Add: 'plus',
  285. exp: 'exp'
  286. },
  287. precision=3,
  288. log_warn=w
  289. ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)'
  290. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  291. assert smtlib_code(
  292. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  293. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  294. known_constants={
  295. S.Pi: 'p'
  296. },
  297. known_functions={
  298. Add: 'plus',
  299. exp: 'exp'
  300. },
  301. precision=3,
  302. log_warn=w
  303. ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)'
  304. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  305. assert smtlib_code(
  306. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  307. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  308. known_functions={Add: 'plus'},
  309. precision=3,
  310. log_warn=w
  311. ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)'
  312. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  313. assert smtlib_code(
  314. Add(S.Zero, S.One, S.NegativeOne, S.Half,
  315. S.Exp1, S.Pi, S.GoldenRatio, evaluate=False),
  316. known_constants={S.Exp1: 'e'},
  317. known_functions={Add: 'plus'},
  318. precision=3,
  319. log_warn=w
  320. ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)'
  321. def test_boolean():
  322. with _check_warns([]) as w:
  323. assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \
  324. '(declare-const y Bool)\n' \
  325. '(assert (and x y))'
  326. assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \
  327. '(declare-const y Bool)\n' \
  328. '(assert (or x y))'
  329. assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \
  330. '(assert (not x))'
  331. assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \
  332. '(declare-const y Bool)\n' \
  333. '(declare-const z Bool)\n' \
  334. '(assert (and x y z))'
  335. with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
  336. assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \
  337. '(declare-const y Bool)\n' \
  338. '(declare-const z Real)\n' \
  339. '(assert (or (> z 3) (and x (not y))))'
  340. f = Function('f')
  341. g = Function('g')
  342. h = Function('h')
  343. with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w:
  344. assert smtlib_code(
  345. [Gt(f(x), y),
  346. Lt(y, g(z))],
  347. symbol_table={
  348. f: Callable[[bool], int], g: Callable[[bool], int],
  349. }, log_warn=w
  350. ) == '(declare-const x Bool)\n' \
  351. '(declare-const y Real)\n' \
  352. '(declare-const z Bool)\n' \
  353. '(declare-fun f (Bool) Int)\n' \
  354. '(declare-fun g (Bool) Int)\n' \
  355. '(assert (> (f x) y))\n' \
  356. '(assert (< y (g z)))'
  357. with _check_warns([]) as w:
  358. assert smtlib_code(
  359. [Eq(f(x), y),
  360. Lt(y, g(z))],
  361. symbol_table={
  362. f: Callable[[bool], int], g: Callable[[bool], int],
  363. }, log_warn=w
  364. ) == '(declare-const x Bool)\n' \
  365. '(declare-const y Int)\n' \
  366. '(declare-const z Bool)\n' \
  367. '(declare-fun f (Bool) Int)\n' \
  368. '(declare-fun g (Bool) Int)\n' \
  369. '(assert (= (f x) y))\n' \
  370. '(assert (< y (g z)))'
  371. with _check_warns([]) as w:
  372. assert smtlib_code(
  373. [Eq(f(x), y),
  374. Eq(g(f(x)), z),
  375. Eq(h(g(f(x))), x)],
  376. symbol_table={
  377. f: Callable[[float], int],
  378. g: Callable[[int], bool],
  379. h: Callable[[bool], float]
  380. },
  381. log_warn=w
  382. ) == '(declare-const x Real)\n' \
  383. '(declare-const y Int)\n' \
  384. '(declare-const z Bool)\n' \
  385. '(declare-fun f (Real) Int)\n' \
  386. '(declare-fun g (Int) Bool)\n' \
  387. '(declare-fun h (Bool) Real)\n' \
  388. '(assert (= (f x) y))\n' \
  389. '(assert (= (g (f x)) z))\n' \
  390. '(assert (= (h (g (f x))) x))'
# TODO: make smtlib_code support arrays. The commented-out cases below still call julia_code and must be ported before they can be enabled.
  392. # def test_containers():
  393. # assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \
  394. # "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]"
  395. # assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))"
  396. # assert julia_code([1]) == "Any[1]"
  397. # assert julia_code((1,)) == "(1,)"
  398. # assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)"
  399. # assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))"
  400. # # scalar, matrix, empty matrix and empty list
  401. # assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])"
  402. def test_smtlib_piecewise():
  403. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  404. assert smtlib_code(
  405. Piecewise((x, x < 1),
  406. (x ** 2, True)),
  407. auto_declare=False,
  408. log_warn=w
  409. ) == '(ite (< x 1) x (pow x 2))'
  410. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  411. assert smtlib_code(
  412. Piecewise((x ** 2, x < 1),
  413. (x ** 3, x < 2),
  414. (x ** 4, x < 3),
  415. (x ** 5, True)),
  416. auto_declare=False,
  417. log_warn=w
  418. ) == '(ite (< x 1) (pow x 2) ' \
  419. '(ite (< x 2) (pow x 3) ' \
  420. '(ite (< x 3) (pow x 4) ' \
  421. '(pow x 5))))'
  422. # Check that Piecewise without a True (default) condition error
  423. expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0))
  424. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  425. raises(AssertionError, lambda: smtlib_code(expr, log_warn=w))
  426. def test_smtlib_piecewise_times_const():
  427. pw = Piecewise((x, x < 1), (x ** 2, True))
  428. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  429. assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))'
  430. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  431. assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))'
  432. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  433. assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))'
  434. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  435. assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))'
# TODO: decide whether smtlib_code should support arrays / matrices. The commented-out tests below are unported templates; several still call julia_code.
  437. # def test_smtlib_matrix_assign_to():
  438. # A = Matrix([[1, 2, 3]])
  439. # assert smtlib_code(A, assign_to='a') == "a = [1 2 3]"
  440. # A = Matrix([[1, 2], [3, 4]])
  441. # assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]"
  442. # def test_julia_matrix_1x1():
  443. # A = Matrix([[3]])
  444. # B = MatrixSymbol('B', 1, 1)
  445. # C = MatrixSymbol('C', 1, 2)
  446. # assert julia_code(A, assign_to=B) == "B = [3]"
  447. # raises(ValueError, lambda: julia_code(A, assign_to=C))
  448. # def test_julia_matrix_elements():
  449. # A = Matrix([[x, 2, x * y]])
  450. # assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2"
  451. # A = MatrixSymbol('AA', 1, 3)
  452. # assert julia_code(A) == "AA"
  453. # assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \
  454. # "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]"
  455. # assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]"
  456. def test_smtlib_boolean():
  457. with _check_warns([]) as w:
  458. assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true'
  459. assert smtlib_code(True, log_warn=w) == '(assert true)'
  460. assert smtlib_code(S.true, log_warn=w) == '(assert true)'
  461. assert smtlib_code(S.false, log_warn=w) == '(assert false)'
  462. assert smtlib_code(False, log_warn=w) == '(assert false)'
  463. assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false'
  464. def test_not_supported():
  465. f = Function('f')
  466. with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w:
  467. raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w))
  468. with _check_warns([_W.WILL_NOT_ASSERT]) as w:
  469. raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w))
  470. def test_Float():
  471. assert smtlib_code(0.0) == "0.0"
  472. assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))'
  473. assert smtlib_code(5.3) == "5.3"