# test_inference.py (16 KB)
  1. """For more tests on satisfiability, see test_dimacs"""
  2. from sympy.assumptions.ask import Q
  3. from sympy.core.symbol import symbols
  4. from sympy.core.relational import Unequality
  5. from sympy.logic.boolalg import And, Or, Implies, Equivalent, true, false
  6. from sympy.logic.inference import literal_symbol, \
  7. pl_true, satisfiable, valid, entails, PropKB
  8. from sympy.logic.algorithms.dpll import dpll, dpll_satisfiable, \
  9. find_pure_symbol, find_unit_clause, unit_propagate, \
  10. find_pure_symbol_int_repr, find_unit_clause_int_repr, \
  11. unit_propagate_int_repr
  12. from sympy.logic.algorithms.dpll2 import dpll_satisfiable as dpll2_satisfiable
  13. from sympy.logic.algorithms.z3_wrapper import z3_satisfiable
  14. from sympy.assumptions.cnf import CNF, EncodedCNF
  15. from sympy.logic.tests.test_lra_theory import make_random_problem
  16. from sympy.core.random import randint
  17. from sympy.testing.pytest import raises, skip
  18. from sympy.external import import_module
  19. def test_literal():
  20. A, B = symbols('A,B')
  21. assert literal_symbol(True) is True
  22. assert literal_symbol(False) is False
  23. assert literal_symbol(A) is A
  24. assert literal_symbol(~A) is A
  25. def test_find_pure_symbol():
  26. A, B, C = symbols('A,B,C')
  27. assert find_pure_symbol([A], [A]) == (A, True)
  28. assert find_pure_symbol([A, B], [~A | B, ~B | A]) == (None, None)
  29. assert find_pure_symbol([A, B, C], [ A | ~B, ~B | ~C, C | A]) == (A, True)
  30. assert find_pure_symbol([A, B, C], [~A | B, B | ~C, C | A]) == (B, True)
  31. assert find_pure_symbol([A, B, C], [~A | ~B, ~B | ~C, C | A]) == (B, False)
  32. assert find_pure_symbol(
  33. [A, B, C], [~A | B, ~B | ~C, C | A]) == (None, None)
  34. def test_find_pure_symbol_int_repr():
  35. assert find_pure_symbol_int_repr([1], [{1}]) == (1, True)
  36. assert find_pure_symbol_int_repr([1, 2],
  37. [{-1, 2}, {-2, 1}]) == (None, None)
  38. assert find_pure_symbol_int_repr([1, 2, 3],
  39. [{1, -2}, {-2, -3}, {3, 1}]) == (1, True)
  40. assert find_pure_symbol_int_repr([1, 2, 3],
  41. [{-1, 2}, {2, -3}, {3, 1}]) == (2, True)
  42. assert find_pure_symbol_int_repr([1, 2, 3],
  43. [{-1, -2}, {-2, -3}, {3, 1}]) == (2, False)
  44. assert find_pure_symbol_int_repr([1, 2, 3],
  45. [{-1, 2}, {-2, -3}, {3, 1}]) == (None, None)
  46. def test_unit_clause():
  47. A, B, C = symbols('A,B,C')
  48. assert find_unit_clause([A], {}) == (A, True)
  49. assert find_unit_clause([A, ~A], {}) == (A, True) # Wrong ??
  50. assert find_unit_clause([A | B], {A: True}) == (B, True)
  51. assert find_unit_clause([A | B], {B: True}) == (A, True)
  52. assert find_unit_clause(
  53. [A | B | C, B | ~C, A | ~B], {A: True}) == (B, False)
  54. assert find_unit_clause([A | B | C, B | ~C, A | B], {A: True}) == (B, True)
  55. assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True)
  56. def test_unit_clause_int_repr():
  57. assert find_unit_clause_int_repr(map(set, [[1]]), {}) == (1, True)
  58. assert find_unit_clause_int_repr(map(set, [[1], [-1]]), {}) == (1, True)
  59. assert find_unit_clause_int_repr([{1, 2}], {1: True}) == (2, True)
  60. assert find_unit_clause_int_repr([{1, 2}], {2: True}) == (1, True)
  61. assert find_unit_clause_int_repr(map(set,
  62. [[1, 2, 3], [2, -3], [1, -2]]), {1: True}) == (2, False)
  63. assert find_unit_clause_int_repr(map(set,
  64. [[1, 2, 3], [3, -3], [1, 2]]), {1: True}) == (2, True)
  65. A, B, C = symbols('A,B,C')
  66. assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True)
  67. def test_unit_propagate():
  68. A, B, C = symbols('A,B,C')
  69. assert unit_propagate([A | B], A) == []
  70. assert unit_propagate([A | B, ~A | C, ~C | B, A], A) == [C, ~C | B, A]
  71. def test_unit_propagate_int_repr():
  72. assert unit_propagate_int_repr([{1, 2}], 1) == []
  73. assert unit_propagate_int_repr(map(set,
  74. [[1, 2], [-1, 3], [-3, 2], [1]]), 1) == [{3}, {-3, 2}]
  75. def test_dpll():
  76. """This is also tested in test_dimacs"""
  77. A, B, C = symbols('A,B,C')
  78. assert dpll([A | B], [A, B], {A: True, B: True}) == {A: True, B: True}
  79. def test_dpll_satisfiable():
  80. A, B, C = symbols('A,B,C')
  81. assert dpll_satisfiable( A & ~A ) is False
  82. assert dpll_satisfiable( A & ~B ) == {A: True, B: False}
  83. assert dpll_satisfiable(
  84. A | B ) in ({A: True}, {B: True}, {A: True, B: True})
  85. assert dpll_satisfiable(
  86. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  87. assert dpll_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False},
  88. {A: True, C: True}, {B: True, C: True})
  89. assert dpll_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  90. assert dpll_satisfiable( (A | B) & (A >> B) ) == {B: True}
  91. assert dpll_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  92. assert dpll_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  93. def test_dpll2_satisfiable():
  94. A, B, C = symbols('A,B,C')
  95. assert dpll2_satisfiable( A & ~A ) is False
  96. assert dpll2_satisfiable( A & ~B ) == {A: True, B: False}
  97. assert dpll2_satisfiable(
  98. A | B ) in ({A: True}, {B: True}, {A: True, B: True})
  99. assert dpll2_satisfiable(
  100. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  101. assert dpll2_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
  102. {A: True, B: True, C: True})
  103. assert dpll2_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  104. assert dpll2_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
  105. {B: True, A: True})
  106. assert dpll2_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  107. assert dpll2_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  108. def test_minisat22_satisfiable():
  109. A, B, C = symbols('A,B,C')
  110. minisat22_satisfiable = lambda expr: satisfiable(expr, algorithm="minisat22")
  111. assert minisat22_satisfiable( A & ~A ) is False
  112. assert minisat22_satisfiable( A & ~B ) == {A: True, B: False}
  113. assert minisat22_satisfiable(
  114. A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False})
  115. assert minisat22_satisfiable(
  116. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  117. assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
  118. {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False})
  119. assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  120. assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
  121. {B: True, A: True})
  122. assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  123. assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  124. def test_minisat22_minimal_satisfiable():
  125. A, B, C = symbols('A,B,C')
  126. minisat22_satisfiable = lambda expr, minimal=True: satisfiable(expr, algorithm="minisat22", minimal=True)
  127. assert minisat22_satisfiable( A & ~A ) is False
  128. assert minisat22_satisfiable( A & ~B ) == {A: True, B: False}
  129. assert minisat22_satisfiable(
  130. A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False})
  131. assert minisat22_satisfiable(
  132. (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False})
  133. assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True},
  134. {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False})
  135. assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True}
  136. assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False},
  137. {B: True, A: True})
  138. assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True}
  139. assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False}
  140. g = satisfiable((A | B | C),algorithm="minisat22",minimal=True,all_models=True)
  141. sol = next(g)
  142. first_solution = {key for key, value in sol.items() if value}
  143. sol=next(g)
  144. second_solution = {key for key, value in sol.items() if value}
  145. sol=next(g)
  146. third_solution = {key for key, value in sol.items() if value}
  147. assert not first_solution <= second_solution
  148. assert not second_solution <= third_solution
  149. assert not first_solution <= third_solution
  150. def test_satisfiable():
  151. A, B, C = symbols('A,B,C')
  152. assert satisfiable(A & (A >> B) & ~B) is False
  153. def test_valid():
  154. A, B, C = symbols('A,B,C')
  155. assert valid(A >> (B >> A)) is True
  156. assert valid((A >> (B >> C)) >> ((A >> B) >> (A >> C))) is True
  157. assert valid((~B >> ~A) >> (A >> B)) is True
  158. assert valid(A | B | C) is False
  159. assert valid(A >> B) is False
  160. def test_pl_true():
  161. A, B, C = symbols('A,B,C')
  162. assert pl_true(True) is True
  163. assert pl_true( A & B, {A: True, B: True}) is True
  164. assert pl_true( A | B, {A: True}) is True
  165. assert pl_true( A | B, {B: True}) is True
  166. assert pl_true( A | B, {A: None, B: True}) is True
  167. assert pl_true( A >> B, {A: False}) is True
  168. assert pl_true( A | B | ~C, {A: False, B: True, C: True}) is True
  169. assert pl_true(Equivalent(A, B), {A: False, B: False}) is True
  170. # test for false
  171. assert pl_true(False) is False
  172. assert pl_true( A & B, {A: False, B: False}) is False
  173. assert pl_true( A & B, {A: False}) is False
  174. assert pl_true( A & B, {B: False}) is False
  175. assert pl_true( A | B, {A: False, B: False}) is False
  176. #test for None
  177. assert pl_true(B, {B: None}) is None
  178. assert pl_true( A & B, {A: True, B: None}) is None
  179. assert pl_true( A >> B, {A: True, B: None}) is None
  180. assert pl_true(Equivalent(A, B), {A: None}) is None
  181. assert pl_true(Equivalent(A, B), {A: True, B: None}) is None
  182. # Test for deep
  183. assert pl_true(A | B, {A: False}, deep=True) is None
  184. assert pl_true(~A & ~B, {A: False}, deep=True) is None
  185. assert pl_true(A | B, {A: False, B: False}, deep=True) is False
  186. assert pl_true(A & B & (~A | ~B), {A: True}, deep=True) is False
  187. assert pl_true((C >> A) >> (B >> A), {C: True}, deep=True) is True
  188. def test_pl_true_wrong_input():
  189. from sympy.core.numbers import pi
  190. raises(ValueError, lambda: pl_true('John Cleese'))
  191. raises(ValueError, lambda: pl_true(42 + pi + pi ** 2))
  192. raises(ValueError, lambda: pl_true(42))
  193. def test_entails():
  194. A, B, C = symbols('A, B, C')
  195. assert entails(A, [A >> B, ~B]) is False
  196. assert entails(B, [Equivalent(A, B), A]) is True
  197. assert entails((A >> B) >> (~A >> ~B)) is False
  198. assert entails((A >> B) >> (~B >> ~A)) is True
  199. def test_PropKB():
  200. A, B, C = symbols('A,B,C')
  201. kb = PropKB()
  202. assert kb.ask(A >> B) is False
  203. assert kb.ask(A >> (B >> A)) is True
  204. kb.tell(A >> B)
  205. kb.tell(B >> C)
  206. assert kb.ask(A) is False
  207. assert kb.ask(B) is False
  208. assert kb.ask(C) is False
  209. assert kb.ask(~A) is False
  210. assert kb.ask(~B) is False
  211. assert kb.ask(~C) is False
  212. assert kb.ask(A >> C) is True
  213. kb.tell(A)
  214. assert kb.ask(A) is True
  215. assert kb.ask(B) is True
  216. assert kb.ask(C) is True
  217. assert kb.ask(~C) is False
  218. kb.retract(A)
  219. assert kb.ask(C) is False
  220. def test_propKB_tolerant():
  221. """"tolerant to bad input"""
  222. kb = PropKB()
  223. A, B, C = symbols('A,B,C')
  224. assert kb.ask(B) is False
  225. def test_satisfiable_non_symbols():
  226. x, y = symbols('x y')
  227. assumptions = Q.zero(x*y)
  228. facts = Implies(Q.zero(x*y), Q.zero(x) | Q.zero(y))
  229. query = ~Q.zero(x) & ~Q.zero(y)
  230. refutations = [
  231. {Q.zero(x): True, Q.zero(x*y): True},
  232. {Q.zero(y): True, Q.zero(x*y): True},
  233. {Q.zero(x): True, Q.zero(y): True, Q.zero(x*y): True},
  234. {Q.zero(x): True, Q.zero(y): False, Q.zero(x*y): True},
  235. {Q.zero(x): False, Q.zero(y): True, Q.zero(x*y): True}]
  236. assert not satisfiable(And(assumptions, facts, query), algorithm='dpll')
  237. assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll') in refutations
  238. assert not satisfiable(And(assumptions, facts, query), algorithm='dpll2')
  239. assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll2') in refutations
  240. def test_satisfiable_bool():
  241. from sympy.core.singleton import S
  242. assert satisfiable(true) == {true: true}
  243. assert satisfiable(S.true) == {true: true}
  244. assert satisfiable(false) is False
  245. assert satisfiable(S.false) is False
  246. def test_satisfiable_all_models():
  247. from sympy.abc import A, B
  248. assert next(satisfiable(False, all_models=True)) is False
  249. assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False]
  250. assert list(satisfiable(True, all_models=True)) == [{true: true}]
  251. models = [{A: True, B: False}, {A: False, B: True}]
  252. result = satisfiable(A ^ B, all_models=True)
  253. models.remove(next(result))
  254. models.remove(next(result))
  255. raises(StopIteration, lambda: next(result))
  256. assert not models
  257. assert list(satisfiable(Equivalent(A, B), all_models=True)) == \
  258. [{A: False, B: False}, {A: True, B: True}]
  259. models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}]
  260. for model in satisfiable(A >> B, all_models=True):
  261. models.remove(model)
  262. assert not models
  263. # This is a santiy test to check that only the required number
  264. # of solutions are generated. The expr below has 2**100 - 1 models
  265. # which would time out the test if all are generated at once.
  266. from sympy.utilities.iterables import numbered_symbols
  267. from sympy.logic.boolalg import Or
  268. sym = numbered_symbols()
  269. X = [next(sym) for i in range(100)]
  270. result = satisfiable(Or(*X), all_models=True)
  271. for i in range(10):
  272. assert next(result)
  273. def test_z3():
  274. z3 = import_module("z3")
  275. if not z3:
  276. skip("z3 not installed.")
  277. A, B, C = symbols('A,B,C')
  278. x, y, z = symbols('x,y,z')
  279. assert z3_satisfiable((x >= 2) & (x < 1)) is False
  280. assert z3_satisfiable( A & ~A ) is False
  281. model = z3_satisfiable(A & (~A | B | C))
  282. assert bool(model) is True
  283. assert model[A] is True
  284. # test nonlinear function
  285. assert z3_satisfiable((x ** 2 >= 2) & (x < 1) & (x > -1)) is False
  286. def test_z3_vs_lra_dpll2():
  287. z3 = import_module("z3")
  288. if z3 is None:
  289. skip("z3 not installed.")
  290. def boolean_formula_to_encoded_cnf(bf):
  291. cnf = CNF.from_prop(bf)
  292. enc = EncodedCNF()
  293. enc.from_cnf(cnf)
  294. return enc
  295. def make_random_cnf(num_clauses=5, num_constraints=10, num_var=2):
  296. assert num_clauses <= num_constraints
  297. constraints = make_random_problem(num_variables=num_var, num_constraints=num_constraints, rational=False)
  298. clauses = [[cons] for cons in constraints[:num_clauses]]
  299. for cons in constraints[num_clauses:]:
  300. if isinstance(cons, Unequality):
  301. cons = ~cons
  302. i = randint(0, num_clauses-1)
  303. clauses[i].append(cons)
  304. clauses = [Or(*clause) for clause in clauses]
  305. cnf = And(*clauses)
  306. return boolean_formula_to_encoded_cnf(cnf)
  307. lra_dpll2_satisfiable = lambda x: dpll2_satisfiable(x, use_lra_theory=True)
  308. for _ in range(50):
  309. cnf = make_random_cnf(num_clauses=10, num_constraints=15, num_var=2)
  310. try:
  311. z3_sat = z3_satisfiable(cnf)
  312. except z3.z3types.Z3Exception:
  313. continue
  314. lra_dpll2_sat = lra_dpll2_satisfiable(cnf) is not False
  315. assert z3_sat == lra_dpll2_sat
  316. def test_issue_27733():
  317. x, y = symbols('x,y')
  318. clauses = [[1, -3, -2], [5, 7, -8, -6, -4], [-10, -9, 10, 11, -4], [-12, 13, 14], [-10, 9, -6, 11, -4],
  319. [16, -15, 18, -19, -17], [11, -6, 10, -9], [9, 11, -10, -9], [2, -3, -1], [-13, 12], [-15, 3, -17],
  320. [-16, -15, 19, -17], [-6, -9, 10, 11, -4], [20, -1, -2], [-23, -22, -21], [10, 11, -10, -9],
  321. [9, 11, -4, -10], [24, -6, -4], [-14, 12], [-10, -9, 9, -6, 11], [25, -27, -26], [-15, 19, -18, -17],
  322. [5, 8, -7, -6, -4], [-30, -29, 28], [12], [14]]
  323. encoding = {Q.gt(y, i): i for i in range(1, 31) if i != 11 and i != 12}
  324. encoding[Q.gt(x, 0)] = 11
  325. encoding[Q.lt(x, 0)] = 12
  326. cnf = EncodedCNF(clauses, encoding)
  327. assert satisfiable(cnf, use_lra_theory=True) is False