z3_wrapper.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from sympy.printing.smtlib import smtlib_code
  2. from sympy.assumptions.assume import AppliedPredicate
  3. from sympy.assumptions.cnf import EncodedCNF
  4. from sympy.assumptions.ask import Q
  5. from sympy.core import Add, Mul
  6. from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
  7. from sympy.functions.elementary.complexes import Abs
  8. from sympy.functions.elementary.exponential import Pow
  9. from sympy.functions.elementary.miscellaneous import Min, Max
  10. from sympy.logic.boolalg import And, Or, Xor, Implies
  11. from sympy.logic.boolalg import Not, ITE
  12. from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate
  13. from sympy.external import import_module
  14. def z3_satisfiable(expr, all_models=False):
  15. if not isinstance(expr, EncodedCNF):
  16. exprs = EncodedCNF()
  17. exprs.add_prop(expr)
  18. expr = exprs
  19. z3 = import_module("z3")
  20. if z3 is None:
  21. raise ImportError("z3 is not installed")
  22. s = encoded_cnf_to_z3_solver(expr, z3)
  23. res = str(s.check())
  24. if res == "unsat":
  25. return False
  26. elif res == "sat":
  27. return z3_model_to_sympy_model(s.model(), expr)
  28. else:
  29. return None
  30. def z3_model_to_sympy_model(z3_model, enc_cnf):
  31. rev_enc = {value : key for key, value in enc_cnf.encoding.items()}
  32. return {rev_enc[int(var.name()[1:])] : bool(z3_model[var]) for var in z3_model}
  33. def clause_to_assertion(clause):
  34. clause_strings = [f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})" for lit in clause]
  35. return "(assert (or " + " ".join(clause_strings) + "))"
  36. def encoded_cnf_to_z3_solver(enc_cnf, z3):
  37. def dummify_bool(pred):
  38. return False
  39. assert isinstance(pred, AppliedPredicate)
  40. if pred.function in [Q.positive, Q.negative, Q.zero]:
  41. return pred
  42. else:
  43. return False
  44. s = z3.Solver()
  45. declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
  46. assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]
  47. symbols = set()
  48. for pred, enc in enc_cnf.encoding.items():
  49. if not isinstance(pred, AppliedPredicate):
  50. continue
  51. if pred.function not in (Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq, Q.positive, Q.negative, Q.extended_negative, Q.extended_positive, Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive, Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive):
  52. continue
  53. pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False, known_functions=known_functions)
  54. symbols |= pred.free_symbols
  55. pred = pred_str
  56. clause = f"(implies d{enc} {pred})"
  57. assertion = "(assert " + clause + ")"
  58. assertions.append(assertion)
  59. for sym in symbols:
  60. declarations.append(f"(declare-const {sym} Real)")
  61. declarations = "\n".join(declarations)
  62. assertions = "\n".join(assertions)
  63. s.from_string(declarations)
  64. s.from_string(assertions)
  65. return s
  66. known_functions = {
  67. Add: '+',
  68. Mul: '*',
  69. Equality: '=',
  70. LessThan: '<=',
  71. GreaterThan: '>=',
  72. StrictLessThan: '<',
  73. StrictGreaterThan: '>',
  74. EqualityPredicate(): '=',
  75. LessThanPredicate(): '<=',
  76. GreaterThanPredicate(): '>=',
  77. StrictLessThanPredicate(): '<',
  78. StrictGreaterThanPredicate(): '>',
  79. Abs: 'abs',
  80. Min: 'min',
  81. Max: 'max',
  82. Pow: '^',
  83. And: 'and',
  84. Or: 'or',
  85. Xor: 'xor',
  86. Not: 'not',
  87. ITE: 'ite',
  88. Implies: '=>',
  89. }