| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- from sympy.printing.smtlib import smtlib_code
- from sympy.assumptions.assume import AppliedPredicate
- from sympy.assumptions.cnf import EncodedCNF
- from sympy.assumptions.ask import Q
- from sympy.core import Add, Mul
- from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan
- from sympy.functions.elementary.complexes import Abs
- from sympy.functions.elementary.exponential import Pow
- from sympy.functions.elementary.miscellaneous import Min, Max
- from sympy.logic.boolalg import And, Or, Xor, Implies
- from sympy.logic.boolalg import Not, ITE
- from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate
- from sympy.external import import_module
def z3_satisfiable(expr, all_models=False):
    """Decide satisfiability of ``expr`` with the Z3 solver.

    ``expr`` is used directly when it is already an ``EncodedCNF``;
    otherwise it is added to a fresh ``EncodedCNF`` first.

    Returns ``False`` when Z3 answers ``unsat``, the mapping built by
    ``z3_model_to_sympy_model`` when it answers ``sat``, and ``None`` for
    any other answer.  ``all_models`` is accepted but not used here.

    Raises ``ImportError`` when the ``z3`` module cannot be imported.
    """
    if isinstance(expr, EncodedCNF):
        enc_cnf = expr
    else:
        enc_cnf = EncodedCNF()
        enc_cnf.add_prop(expr)

    z3 = import_module("z3")
    if z3 is None:
        raise ImportError("z3 is not installed")

    solver = encoded_cnf_to_z3_solver(enc_cnf, z3)

    # The check result is compared through its string form.
    verdict = str(solver.check())
    if verdict == "sat":
        return z3_model_to_sympy_model(solver.model(), enc_cnf)
    if verdict == "unsat":
        return False
    return None
def z3_model_to_sympy_model(z3_model, enc_cnf):
    """Translate a Z3 model into a ``{proposition: bool}`` mapping.

    Parameters
    ==========

    z3_model : iterable of model entries; each entry exposes ``name()`` and
        can be used to index ``z3_model`` for its value.
    enc_cnf : object whose ``encoding`` maps each proposition to the integer
        ``n`` used in the boolean constant name ``d<n>``.

    Returns
    =======

    dict mapping every proposition whose ``d<n>`` constant appears in the
    model to the truth value of that constant.

    Entries that are not encoded boolean literals are skipped.  The solver
    built in this module also declares ``Real`` constants named after free
    symbols; decoding those by blindly stripping the first character would
    either fail (``x``) or be mistaken for literal ``1`` (``x1``).
    """
    decoded = {code: prop for prop, code in enc_cnf.encoding.items()}

    model = {}
    for var in z3_model:
        name = var.name()
        suffix = name[1:]
        # Only names of the exact form d<ascii digits> are boolean literals.
        if not (name.startswith("d") and suffix.isascii() and suffix.isdigit()):
            continue
        code = int(suffix)
        if code in decoded:
            model[decoded[code]] = bool(z3_model[var])
    return model
def clause_to_assertion(clause):
    """Render one clause of integer literals as an SMT-LIB ``assert``.

    A positive literal ``n`` becomes ``d<n>``; any other literal becomes
    ``(not d<|n|>)``.  The rendered literals are joined with single spaces
    inside ``(assert (or ...))``.
    """
    def render(literal):
        name = f"d{abs(literal)}"
        return name if literal > 0 else f"(not {name})"

    disjuncts = " ".join(map(render, clause))
    return f"(assert (or {disjuncts}))"
def encoded_cnf_to_z3_solver(enc_cnf, z3):
    """Build a Z3 solver loaded with an SMT-LIB rendering of ``enc_cnf``.

    Parameters
    ==========

    enc_cnf : object providing ``variables`` (literal numbers), ``data``
        (clauses of integer literals) and ``encoding`` (proposition ->
        literal number).
    z3 : the imported ``z3`` module; only ``z3.Solver`` is used.

    Returns
    =======

    The ``z3.Solver`` instance after the declarations and then the
    assertions have been fed to it through ``from_string``.

    Every literal ``n`` is declared as a ``Bool`` constant ``d<n>`` and every
    clause becomes one assertion.  For each encoded ``AppliedPredicate`` whose
    function is one of the supported relational/sign predicates, the free
    symbols are declared as ``Real`` constants and ``(implies d<n> <pred>)``
    is asserted.
    """
    solver = z3.Solver()

    declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables]
    assertions = [clause_to_assertion(clause) for clause in enc_cnf.data]

    symbols = set()
    for pred, enc in enc_cnf.encoding.items():
        if not isinstance(pred, AppliedPredicate):
            continue
        if pred.function not in (
            Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq,
            Q.positive, Q.negative,
            Q.extended_negative, Q.extended_positive,
            Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive,
            Q.extended_nonzero, Q.extended_nonnegative,
            Q.extended_nonpositive,
        ):
            continue

        pred_str = smtlib_code(
            pred,
            auto_declare=False,
            auto_assert=False,
            known_functions=known_functions,
        )
        symbols |= pred.free_symbols

        # NOTE(review): only ``d<n> -> pred`` is asserted, so a false literal
        # leaves the predicate unconstrained -- confirm this is intended.
        assertions.append(f"(assert (implies d{enc} {pred_str}))")

    declarations.extend(f"(declare-const {sym} Real)" for sym in symbols)

    # Declarations are loaded before the assertions that refer to them.
    solver.from_string("\n".join(declarations))
    solver.from_string("\n".join(assertions))

    return solver
# Operator names handed to ``smtlib_code`` (via ``known_functions=``) when
# predicates are rendered for Z3.
known_functions = {
    # Arithmetic
    Add: "+",
    Mul: "*",

    # Relational classes
    Equality: "=",
    LessThan: "<=",
    GreaterThan: ">=",
    StrictLessThan: "<",
    StrictGreaterThan: ">",

    # Binary relation predicates (keyed by instance)
    EqualityPredicate(): "=",
    LessThanPredicate(): "<=",
    GreaterThanPredicate(): ">=",
    StrictLessThanPredicate(): "<",
    StrictGreaterThanPredicate(): ">",

    # Elementary functions
    Abs: "abs",
    Min: "min",
    Max: "max",
    Pow: "^",

    # Boolean connectives
    And: "and",
    Or: "or",
    Xor: "xor",
    Not: "not",
    ITE: "ite",
    Implies: "=>",
}
|