satask.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. """
  2. Module to evaluate the proposition with assumptions using SAT algorithm.
  3. """
  4. from sympy.core.singleton import S
  5. from sympy.core.symbol import Symbol
  6. from sympy.core.kind import NumberKind, UndefinedKind
  7. from sympy.assumptions.ask_generated import get_all_known_matrix_facts, get_all_known_number_facts
  8. from sympy.assumptions.assume import global_assumptions, AppliedPredicate
  9. from sympy.assumptions.sathandlers import class_fact_registry
  10. from sympy.core import oo
  11. from sympy.logic.inference import satisfiable
  12. from sympy.assumptions.cnf import CNF, EncodedCNF
  13. from sympy.matrices.kind import MatrixKind
  14. def satask(proposition, assumptions=True, context=global_assumptions,
  15. use_known_facts=True, iterations=oo):
  16. """
  17. Function to evaluate the proposition with assumptions using SAT algorithm.
  18. This function extracts every fact relevant to the expressions composing
  19. proposition and assumptions. For example, if a predicate containing
  20. ``Abs(x)`` is proposed, then ``Q.zero(Abs(x)) | Q.positive(Abs(x))``
  21. will be found and passed to SAT solver because ``Q.nonnegative`` is
  22. registered as a fact for ``Abs``.
  23. Proposition is evaluated to ``True`` or ``False`` if the truth value can be
  24. determined. If not, ``None`` is returned.
  25. Parameters
  26. ==========
  27. proposition : Any boolean expression.
  28. Proposition which will be evaluated to boolean value.
  29. assumptions : Any boolean expression, optional.
  30. Local assumptions to evaluate the *proposition*.
  31. context : AssumptionsContext, optional.
  32. Default assumptions to evaluate the *proposition*. By default,
  33. this is ``sympy.assumptions.global_assumptions`` variable.
  34. use_known_facts : bool, optional.
  35. If ``True``, facts from ``sympy.assumptions.ask_generated``
  36. module are passed to SAT solver as well.
  37. iterations : int, optional.
  38. Number of times that relevant facts are recursively extracted.
  39. Default is infinite times until no new fact is found.
  40. Returns
  41. =======
  42. ``True``, ``False``, or ``None``
  43. Examples
  44. ========
  45. >>> from sympy import Abs, Q
  46. >>> from sympy.assumptions.satask import satask
  47. >>> from sympy.abc import x
  48. >>> satask(Q.zero(Abs(x)), Q.zero(x))
  49. True
  50. """
  51. props = CNF.from_prop(proposition)
  52. _props = CNF.from_prop(~proposition)
  53. assumptions = CNF.from_prop(assumptions)
  54. context_cnf = CNF()
  55. if context:
  56. context_cnf = context_cnf.extend(context)
  57. sat = get_all_relevant_facts(props, assumptions, context_cnf,
  58. use_known_facts=use_known_facts, iterations=iterations)
  59. sat.add_from_cnf(assumptions)
  60. if context:
  61. sat.add_from_cnf(context_cnf)
  62. return check_satisfiability(props, _props, sat)
  63. def check_satisfiability(prop, _prop, factbase):
  64. sat_true = factbase.copy()
  65. sat_false = factbase.copy()
  66. sat_true.add_from_cnf(prop)
  67. sat_false.add_from_cnf(_prop)
  68. can_be_true = satisfiable(sat_true)
  69. can_be_false = satisfiable(sat_false)
  70. if can_be_true and can_be_false:
  71. return None
  72. if can_be_true and not can_be_false:
  73. return True
  74. if not can_be_true and can_be_false:
  75. return False
  76. if not can_be_true and not can_be_false:
  77. # TODO: Run additional checks to see which combination of the
  78. # assumptions, global_assumptions, and relevant_facts are
  79. # inconsistent.
  80. raise ValueError("Inconsistent assumptions")
  81. def extract_predargs(proposition, assumptions=None, context=None):
  82. """
  83. Extract every expression in the argument of predicates from *proposition*,
  84. *assumptions* and *context*.
  85. Parameters
  86. ==========
  87. proposition : sympy.assumptions.cnf.CNF
  88. assumptions : sympy.assumptions.cnf.CNF, optional.
  89. context : sympy.assumptions.cnf.CNF, optional.
  90. CNF generated from assumptions context.
  91. Examples
  92. ========
  93. >>> from sympy import Q, Abs
  94. >>> from sympy.assumptions.cnf import CNF
  95. >>> from sympy.assumptions.satask import extract_predargs
  96. >>> from sympy.abc import x, y
  97. >>> props = CNF.from_prop(Q.zero(Abs(x*y)))
  98. >>> assump = CNF.from_prop(Q.zero(x) & Q.zero(y))
  99. >>> extract_predargs(props, assump)
  100. {x, y, Abs(x*y)}
  101. """
  102. req_keys = find_symbols(proposition)
  103. keys = proposition.all_predicates()
  104. # XXX: We need this since True/False are not Basic
  105. lkeys = set()
  106. if assumptions:
  107. lkeys |= assumptions.all_predicates()
  108. if context:
  109. lkeys |= context.all_predicates()
  110. lkeys = lkeys - {S.true, S.false}
  111. tmp_keys = None
  112. while tmp_keys != set():
  113. tmp = set()
  114. for l in lkeys:
  115. syms = find_symbols(l)
  116. if (syms & req_keys) != set():
  117. tmp |= syms
  118. tmp_keys = tmp - req_keys
  119. req_keys |= tmp_keys
  120. keys |= {l for l in lkeys if find_symbols(l) & req_keys != set()}
  121. exprs = set()
  122. for key in keys:
  123. if isinstance(key, AppliedPredicate):
  124. exprs |= set(key.arguments)
  125. else:
  126. exprs.add(key)
  127. return exprs
  128. def find_symbols(pred):
  129. """
  130. Find every :obj:`~.Symbol` in *pred*.
  131. Parameters
  132. ==========
  133. pred : sympy.assumptions.cnf.CNF, or any Expr.
  134. """
  135. if isinstance(pred, CNF):
  136. symbols = set()
  137. for a in pred.all_predicates():
  138. symbols |= find_symbols(a)
  139. return symbols
  140. return pred.atoms(Symbol)
  141. def get_relevant_clsfacts(exprs, relevant_facts=None):
  142. """
  143. Extract relevant facts from the items in *exprs*. Facts are defined in
  144. ``assumptions.sathandlers`` module.
  145. This function is recursively called by ``get_all_relevant_facts()``.
  146. Parameters
  147. ==========
  148. exprs : set
  149. Expressions whose relevant facts are searched.
  150. relevant_facts : sympy.assumptions.cnf.CNF, optional.
  151. Pre-discovered relevant facts.
  152. Returns
  153. =======
  154. exprs : set
  155. Candidates for next relevant fact searching.
  156. relevant_facts : sympy.assumptions.cnf.CNF
  157. Updated relevant facts.
  158. Examples
  159. ========
  160. Here, we will see how facts relevant to ``Abs(x*y)`` are recursively
  161. extracted. On the first run, set containing the expression is passed
  162. without pre-discovered relevant facts. The result is a set containing
  163. candidates for next run, and ``CNF()`` instance containing facts
  164. which are relevant to ``Abs`` and its argument.
  165. >>> from sympy import Abs
  166. >>> from sympy.assumptions.satask import get_relevant_clsfacts
  167. >>> from sympy.abc import x, y
  168. >>> exprs = {Abs(x*y)}
  169. >>> exprs, facts = get_relevant_clsfacts(exprs)
  170. >>> exprs
  171. {x*y}
  172. >>> facts.clauses #doctest: +SKIP
  173. {frozenset({Literal(Q.odd(Abs(x*y)), False), Literal(Q.odd(x*y), True)}),
  174. frozenset({Literal(Q.zero(Abs(x*y)), False), Literal(Q.zero(x*y), True)}),
  175. frozenset({Literal(Q.even(Abs(x*y)), False), Literal(Q.even(x*y), True)}),
  176. frozenset({Literal(Q.zero(Abs(x*y)), True), Literal(Q.zero(x*y), False)}),
  177. frozenset({Literal(Q.even(Abs(x*y)), False),
  178. Literal(Q.odd(Abs(x*y)), False),
  179. Literal(Q.odd(x*y), True)}),
  180. frozenset({Literal(Q.even(Abs(x*y)), False),
  181. Literal(Q.even(x*y), True),
  182. Literal(Q.odd(Abs(x*y)), False)}),
  183. frozenset({Literal(Q.positive(Abs(x*y)), False),
  184. Literal(Q.zero(Abs(x*y)), False)})}
  185. We pass the first run's results to the second run, and get the expressions
  186. for next run and updated facts.
  187. >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts)
  188. >>> exprs
  189. {x, y}
  190. On final run, no more candidate is returned thus we know that all
  191. relevant facts are successfully retrieved.
  192. >>> exprs, facts = get_relevant_clsfacts(exprs, relevant_facts=facts)
  193. >>> exprs
  194. set()
  195. """
  196. if not relevant_facts:
  197. relevant_facts = CNF()
  198. newexprs = set()
  199. for expr in exprs:
  200. for fact in class_fact_registry(expr):
  201. newfact = CNF.to_CNF(fact)
  202. relevant_facts = relevant_facts._and(newfact)
  203. for key in newfact.all_predicates():
  204. if isinstance(key, AppliedPredicate):
  205. newexprs |= set(key.arguments)
  206. return newexprs - exprs, relevant_facts
  207. def get_all_relevant_facts(proposition, assumptions, context,
  208. use_known_facts=True, iterations=oo):
  209. """
  210. Extract all relevant facts from *proposition* and *assumptions*.
  211. This function extracts the facts by recursively calling
  212. ``get_relevant_clsfacts()``. Extracted facts are converted to
  213. ``EncodedCNF`` and returned.
  214. Parameters
  215. ==========
  216. proposition : sympy.assumptions.cnf.CNF
  217. CNF generated from proposition expression.
  218. assumptions : sympy.assumptions.cnf.CNF
  219. CNF generated from assumption expression.
  220. context : sympy.assumptions.cnf.CNF
  221. CNF generated from assumptions context.
  222. use_known_facts : bool, optional.
  223. If ``True``, facts from ``sympy.assumptions.ask_generated``
  224. module are encoded as well.
  225. iterations : int, optional.
  226. Number of times that relevant facts are recursively extracted.
  227. Default is infinite times until no new fact is found.
  228. Returns
  229. =======
  230. sympy.assumptions.cnf.EncodedCNF
  231. Examples
  232. ========
  233. >>> from sympy import Q
  234. >>> from sympy.assumptions.cnf import CNF
  235. >>> from sympy.assumptions.satask import get_all_relevant_facts
  236. >>> from sympy.abc import x, y
  237. >>> props = CNF.from_prop(Q.nonzero(x*y))
  238. >>> assump = CNF.from_prop(Q.nonzero(x))
  239. >>> context = CNF.from_prop(Q.nonzero(y))
  240. >>> get_all_relevant_facts(props, assump, context) #doctest: +SKIP
  241. <sympy.assumptions.cnf.EncodedCNF at 0x7f09faa6ccd0>
  242. """
  243. # The relevant facts might introduce new keys, e.g., Q.zero(x*y) will
  244. # introduce the keys Q.zero(x) and Q.zero(y), so we need to run it until
  245. # we stop getting new things. Hopefully this strategy won't lead to an
  246. # infinite loop in the future.
  247. i = 0
  248. relevant_facts = CNF()
  249. all_exprs = set()
  250. while True:
  251. if i == 0:
  252. exprs = extract_predargs(proposition, assumptions, context)
  253. all_exprs |= exprs
  254. exprs, relevant_facts = get_relevant_clsfacts(exprs, relevant_facts)
  255. i += 1
  256. if i >= iterations:
  257. break
  258. if not exprs:
  259. break
  260. if use_known_facts:
  261. known_facts_CNF = CNF()
  262. if any(expr.kind == MatrixKind(NumberKind) for expr in all_exprs):
  263. known_facts_CNF.add_clauses(get_all_known_matrix_facts())
  264. # check for undefinedKind since kind system isn't fully implemented
  265. if any(((expr.kind == NumberKind) or (expr.kind == UndefinedKind)) for expr in all_exprs):
  266. known_facts_CNF.add_clauses(get_all_known_number_facts())
  267. kf_encoded = EncodedCNF()
  268. kf_encoded.from_cnf(known_facts_CNF)
  269. def translate_literal(lit, delta):
  270. if lit > 0:
  271. return lit + delta
  272. else:
  273. return lit - delta
  274. def translate_data(data, delta):
  275. return [{translate_literal(i, delta) for i in clause} for clause in data]
  276. data = []
  277. symbols = []
  278. n_lit = len(kf_encoded.symbols)
  279. for i, expr in enumerate(all_exprs):
  280. symbols += [pred(expr) for pred in kf_encoded.symbols]
  281. data += translate_data(kf_encoded.data, i * n_lit)
  282. encoding = dict(list(zip(symbols, range(1, len(symbols)+1))))
  283. ctx = EncodedCNF(data, encoding)
  284. else:
  285. ctx = EncodedCNF()
  286. ctx.add_from_cnf(relevant_facts)
  287. return ctx