lra_satask.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
  1. from sympy.assumptions.assume import global_assumptions
  2. from sympy.assumptions.cnf import CNF, EncodedCNF
  3. from sympy.assumptions.ask import Q
  4. from sympy.logic.inference import satisfiable
  5. from sympy.logic.algorithms.lra_theory import UnhandledInput, ALLOWED_PRED
  6. from sympy.matrices.kind import MatrixKind
  7. from sympy.core.kind import NumberKind
  8. from sympy.assumptions.assume import AppliedPredicate
  9. from sympy.core.mul import Mul
  10. from sympy.core.singleton import S
  11. def lra_satask(proposition, assumptions=True, context=global_assumptions):
  12. """
  13. Function to evaluate the proposition with assumptions using SAT algorithm
  14. in conjunction with an Linear Real Arithmetic theory solver.
  15. Used to handle inequalities. Should eventually be depreciated and combined
  16. into satask, but infinity handling and other things need to be implemented
  17. before that can happen.
  18. """
  19. props = CNF.from_prop(proposition)
  20. _props = CNF.from_prop(~proposition)
  21. cnf = CNF.from_prop(assumptions)
  22. assumptions = EncodedCNF()
  23. assumptions.from_cnf(cnf)
  24. context_cnf = CNF()
  25. if context:
  26. context_cnf = context_cnf.extend(context)
  27. assumptions.add_from_cnf(context_cnf)
  28. return check_satisfiability(props, _props, assumptions)
  29. # Some predicates such as Q.prime can't be handled by lra_satask.
  30. # For example, (x > 0) & (x < 1) & Q.prime(x) is unsat but lra_satask would think it was sat.
  31. # WHITE_LIST is a list of predicates that can always be handled.
  32. WHITE_LIST = ALLOWED_PRED | {Q.positive, Q.negative, Q.zero, Q.nonzero, Q.nonpositive, Q.nonnegative,
  33. Q.extended_positive, Q.extended_negative, Q.extended_nonpositive,
  34. Q.extended_negative, Q.extended_nonzero, Q.negative_infinite,
  35. Q.positive_infinite}
  36. def check_satisfiability(prop, _prop, factbase):
  37. sat_true = factbase.copy()
  38. sat_false = factbase.copy()
  39. sat_true.add_from_cnf(prop)
  40. sat_false.add_from_cnf(_prop)
  41. all_pred, all_exprs = get_all_pred_and_expr_from_enc_cnf(sat_true)
  42. for pred in all_pred:
  43. if pred.function not in WHITE_LIST and pred.function != Q.ne:
  44. raise UnhandledInput(f"LRASolver: {pred} is an unhandled predicate")
  45. for expr in all_exprs:
  46. if expr.kind == MatrixKind(NumberKind):
  47. raise UnhandledInput(f"LRASolver: {expr} is of MatrixKind")
  48. if expr == S.NaN:
  49. raise UnhandledInput("LRASolver: nan")
  50. # convert old assumptions into predicates and add them to sat_true and sat_false
  51. # also check for unhandled predicates
  52. for assm in extract_pred_from_old_assum(all_exprs):
  53. n = len(sat_true.encoding)
  54. if assm not in sat_true.encoding:
  55. sat_true.encoding[assm] = n+1
  56. sat_true.data.append([sat_true.encoding[assm]])
  57. n = len(sat_false.encoding)
  58. if assm not in sat_false.encoding:
  59. sat_false.encoding[assm] = n+1
  60. sat_false.data.append([sat_false.encoding[assm]])
  61. sat_true = _preprocess(sat_true)
  62. sat_false = _preprocess(sat_false)
  63. can_be_true = satisfiable(sat_true, use_lra_theory=True) is not False
  64. can_be_false = satisfiable(sat_false, use_lra_theory=True) is not False
  65. if can_be_true and can_be_false:
  66. return None
  67. if can_be_true and not can_be_false:
  68. return True
  69. if not can_be_true and can_be_false:
  70. return False
  71. if not can_be_true and not can_be_false:
  72. raise ValueError("Inconsistent assumptions")
def _preprocess(enc_cnf):
    """
    Return a new EncodedCNF whose applied predicates are meant to be only
    Q.eq, Q.gt, Q.lt, Q.ge and Q.le.

    Every literal is first passed through _pred_to_binrel, so sign
    predicates become binary relations against zero. Then:

    - a non-negated disequality is replaced, inside the same clause, by a
      disjunction of strict inequalities; for example, x != 3 becomes
      x > 3 OR x < 3;
    - a negated equality is treated as a disequality and split the same way;
    - a negated Q.ne becomes the (non-negated) equality Q.eq.

    Any other applied predicate is re-encoded unchanged with its original
    sign. The output only contains the five relations above if the input's
    predicates are all ones that _pred_to_binrel rewrites, Q.ne, or already
    one of those relations. Non-AppliedPredicate values returned by
    _pred_to_binrel (e.g. the False it returns for the infinity predicates)
    are also re-encoded as ordinary literals with their original sign.
    Literal 0 is kept as is.

    Literals are renumbered from 1 in order of first appearance. The input
    is copied first and is not modified.
    """
    # loops through each literal in each clause
    # to construct a new, preprocessed encodedCNF
    enc_cnf = enc_cnf.copy()
    cur_enc = 1  # next unused literal code in the new encoding
    rev_encoding = {value: key for key, value in enc_cnf.encoding.items()}
    new_encoding = {}
    new_data = []
    for clause in enc_cnf.data:
        new_clause = []
        for lit in clause:
            if lit == 0:
                # Literal 0 is copied verbatim and recorded as False in the
                # new encoding (the code treats 0 as the constant-false
                # literal).
                # NOTE(review): in Python 0 == False and hash(0) == hash(False),
                # so this key is the same dict key as the False that
                # _pred_to_binrel returns for infinity predicates. If False was
                # encoded earlier, this line overwrites its integer code, which
                # leaves that code without an entry. If it is encoded later,
                # the lookup below yields False, i.e. literal 0, whatever the
                # sign. Confirm how the LRA solver should see these.
                new_clause.append(lit)
                new_encoding[lit] = False
                continue
            prop = rev_encoding[abs(lit)]
            negated = lit < 0
            sign = (lit > 0) - (lit < 0)  # +1 for a positive literal, -1 for a negated one
            prop = _pred_to_binrel(prop)
            if not isinstance(prop, AppliedPredicate):
                # Constants such as False: re-encode and keep the sign.
                if prop not in new_encoding:
                    new_encoding[prop] = cur_enc
                    cur_enc += 1
                lit = new_encoding[prop]
                new_clause.append(sign*lit)
                continue
            if negated and prop.function == Q.eq:
                # not (a == b) is a != b; rewrite it so the Q.ne branch
                # below splits it into strict inequalities.
                negated = False
                prop = Q.ne(*prop.arguments)
            if prop.function == Q.ne:
                arg1, arg2 = prop.arguments
                if negated:
                    # not (a != b) is a == b, emitted as a positive literal.
                    new_prop = Q.eq(arg1, arg2)
                    if new_prop not in new_encoding:
                        new_encoding[new_prop] = cur_enc
                        cur_enc += 1
                    new_enc = new_encoding[new_prop]
                    new_clause.append(new_enc)
                    continue
                else:
                    # a != b is (a > b) OR (a < b); a clause is a disjunction,
                    # so both literals go into the current clause.
                    new_props = (Q.gt(arg1, arg2), Q.lt(arg1, arg2))
                    for new_prop in new_props:
                        if new_prop not in new_encoding:
                            new_encoding[new_prop] = cur_enc
                            cur_enc += 1
                        new_enc = new_encoding[new_prop]
                        new_clause.append(new_enc)
                    continue
            if prop.function == Q.eq and negated:
                # Unreachable: negated equalities were rewritten to Q.ne above.
                assert False
            # Remaining relations keep their original sign.
            if prop not in new_encoding:
                new_encoding[prop] = cur_enc
                cur_enc += 1
            new_clause.append(new_encoding[prop]*sign)
        new_data.append(new_clause)
    # Every code handed out (1 .. cur_enc - 1) has a key; the encoding may
    # also hold the extra 0 -> False entry added for literal 0.
    assert len(new_encoding) >= cur_enc - 1
    enc_cnf = EncodedCNF(new_data, new_encoding)
    return enc_cnf
  140. def _pred_to_binrel(pred):
  141. if not isinstance(pred, AppliedPredicate):
  142. return pred
  143. if pred.function in pred_to_pos_neg_zero:
  144. f = pred_to_pos_neg_zero[pred.function]
  145. if f is False:
  146. return False
  147. pred = f(pred.arguments[0])
  148. if pred.function == Q.positive:
  149. pred = Q.gt(pred.arguments[0], 0)
  150. elif pred.function == Q.negative:
  151. pred = Q.lt(pred.arguments[0], 0)
  152. elif pred.function == Q.zero:
  153. pred = Q.eq(pred.arguments[0], 0)
  154. elif pred.function == Q.nonpositive:
  155. pred = Q.le(pred.arguments[0], 0)
  156. elif pred.function == Q.nonnegative:
  157. pred = Q.ge(pred.arguments[0], 0)
  158. elif pred.function == Q.nonzero:
  159. pred = Q.ne(pred.arguments[0], 0)
  160. return pred
  161. pred_to_pos_neg_zero = {
  162. Q.extended_positive: Q.positive,
  163. Q.extended_negative: Q.negative,
  164. Q.extended_nonpositive: Q.nonpositive,
  165. Q.extended_negative: Q.negative,
  166. Q.extended_nonzero: Q.nonzero,
  167. Q.negative_infinite: False,
  168. Q.positive_infinite: False
  169. }
  170. def get_all_pred_and_expr_from_enc_cnf(enc_cnf):
  171. all_exprs = set()
  172. all_pred = set()
  173. for pred in enc_cnf.encoding.keys():
  174. if isinstance(pred, AppliedPredicate):
  175. all_pred.add(pred)
  176. all_exprs.update(pred.arguments)
  177. return all_pred, all_exprs
  178. def extract_pred_from_old_assum(all_exprs):
  179. """
  180. Returns a list of relevant new assumption predicate
  181. based on any old assumptions.
  182. Raises an UnhandledInput exception if any of the assumptions are
  183. unhandled.
  184. Ignored predicate:
  185. - commutative
  186. - complex
  187. - algebraic
  188. - transcendental
  189. - extended_real
  190. - real
  191. - all matrix predicate
  192. - rational
  193. - irrational
  194. Example
  195. =======
  196. >>> from sympy.assumptions.lra_satask import extract_pred_from_old_assum
  197. >>> from sympy import symbols
  198. >>> x, y = symbols("x y", positive=True)
  199. >>> extract_pred_from_old_assum([x, y, 2])
  200. [Q.positive(x), Q.positive(y)]
  201. """
  202. ret = []
  203. for expr in all_exprs:
  204. if not hasattr(expr, "free_symbols"):
  205. continue
  206. if len(expr.free_symbols) == 0:
  207. continue
  208. if expr.is_real is not True:
  209. raise UnhandledInput(f"LRASolver: {expr} must be real")
  210. # test for I times imaginary variable; such expressions are considered real
  211. if isinstance(expr, Mul) and any(arg.is_real is not True for arg in expr.args):
  212. raise UnhandledInput(f"LRASolver: {expr} must be real")
  213. if expr.is_integer == True and expr.is_zero != True:
  214. raise UnhandledInput(f"LRASolver: {expr} is an integer")
  215. if expr.is_integer == False:
  216. raise UnhandledInput(f"LRASolver: {expr} can't be an integer")
  217. if expr.is_rational == False:
  218. raise UnhandledInput(f"LRASolver: {expr} is irational")
  219. if expr.is_zero:
  220. ret.append(Q.zero(expr))
  221. elif expr.is_positive:
  222. ret.append(Q.positive(expr))
  223. elif expr.is_negative:
  224. ret.append(Q.negative(expr))
  225. elif expr.is_nonzero:
  226. ret.append(Q.nonzero(expr))
  227. elif expr.is_nonpositive:
  228. ret.append(Q.nonpositive(expr))
  229. elif expr.is_nonnegative:
  230. ret.append(Q.nonnegative(expr))
  231. return ret