Created
June 2, 2021 13:03
-
-
Save miquelramirez/314efd600d47b7902a3db29676653619 to your computer and use it in GitHub Desktop.
`tarski` to `pysmt` formula rewriting
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import OrderedDict

from tarski import FirstOrderLanguage
from tarski.syntax import *
import pysmt
from pysmt.shortcuts import FreshSymbol, Symbol
from pysmt.shortcuts import Bool, Int, Real, FunctionType
from pysmt.shortcuts import And, Implies, Or, TRUE, FALSE, Not
from pysmt.typing import INT, BOOL, REAL
from pysmt.shortcuts import LE, GE, Equals, NotEquals, LT, GT, Plus, Minus, Times, Div
class CompilationError(Exception):
    """
    Generic error signalling that a Tarski element could not be compiled into PySMT.
    """
class TermIsNotVariable(Exception):
    """
    Signals that a term was asked to be mapped onto a variable, but no such
    mapping was possible.
    """
class TransformationError(Exception):
    """
    Signals that a formula or term could not be transformed into PySMT.
    """
# Dispatch table from Tarski built-in predicate/function symbols to the PySMT
# shortcut constructors that build the corresponding SMT-LIB terms. Every
# constructor here is called with exactly two operands by FormulaRewriter.
tsk_to_pysmt = dict([
    # Built-in relational predicates.
    (BuiltinPredicateSymbol.EQ, Equals),
    (BuiltinPredicateSymbol.NE, NotEquals),
    (BuiltinPredicateSymbol.LT, LT),
    (BuiltinPredicateSymbol.LE, LE),
    (BuiltinPredicateSymbol.GT, GT),
    (BuiltinPredicateSymbol.GE, GE),
    # Built-in arithmetic functions.
    (BuiltinFunctionSymbol.ADD, Plus),
    (BuiltinFunctionSymbol.SUB, Minus),
    (BuiltinFunctionSymbol.MUL, Times),
    (BuiltinFunctionSymbol.DIV, Div),
    # TODO: further Tarski built-in symbols are not mapped yet.
])
def map_sort_to_type(L: FirstOrderLanguage, s: Sort, overrides=None):
    """
    Maps a sort of a Tarski FOL onto a PySMT (SMT-LIB) type.

    An entry in ``overrides`` keyed by the sort's name takes precedence.
    Otherwise the Integer sort maps to INT, the Real sort maps to REAL, and
    every other sort falls back to INT.

    :param L: source language
    :param s: sort to translate
    :param overrides: optional mapping from sort name to PySMT type
    :return: the PySMT type for ``s``
    """
    type_overrides = {} if overrides is None else overrides
    try:
        return type_overrides[s.name]
    except KeyError:
        pass
    if s == L.Integer:
        return INT
    # Non-arithmetic sorts are encoded as integers (object indices).
    return REAL if s == L.Real else INT
def resolve_constant(L: FirstOrderLanguage, phi: Term, target_sort: Sort = None):
    """
    Maps a Tarski constant into a PySMT constant.

    Constants are encoded as follows:
    - Integer-sorted constants become an ``Int`` literal of their symbol.
    - Real-sorted constants become a ``Real`` literal of their symbol.
    - Constants of a sort named ``Bool`` become an ``Int`` literal of their symbol.
    - Constants of any other sort become the ``Int`` index of the object in the
      sort's domain that has the same symbol.

    :param L: Tarski first-order language
    :param phi: term to translate; must be a ``Constant``
    :param target_sort: sort that drives the encoding; defaults to ``phi.sort``
    :return: the PySMT constant, or ``None`` when no object in the domain of
             ``target_sort`` has the same symbol as ``phi``
    :raises CompilationError: if ``phi`` is not a ``Constant``
    """
    if not isinstance(phi, Constant):
        raise CompilationError("pysmt translator: Compilation of static (constant) terms like '{}' not implemented yet!".format(str(phi)))
    if target_sort is None:
        target_sort = phi.sort
    if target_sort == L.Integer:
        return Int(phi.symbol)
    if target_sort == L.Real:
        # Real is not a finite, enumerable sort: encode the value directly
        # instead of falling through to domain() enumeration below.
        return Real(phi.symbol)
    # Compare by name: L.get('Bool') fails for languages that declare no
    # 'Bool' sort, which used to crash every enumerated-sort constant.
    if target_sort.name == 'Bool':
        return Int(phi.symbol)
    for index, obj in enumerate(target_sort.domain()):
        if obj.symbol == phi.symbol:
            return Int(index)
    return None
def create_function_type(L: FirstOrderLanguage, func: Function, smt_function_types, smt_function_terms, overrides=None):
    """
    Declares an uninterpreted PySMT function symbol for a Tarski function.

    The PySMT function type is stored in ``smt_function_types`` and the new
    symbol in ``smt_function_terms``. Both are keyed by ``func.signature``.

    :param L: Tarski first-order language
    :param func: Tarski function symbol to declare
    :param smt_function_types: cache of PySMT function types (updated in place)
    :param smt_function_terms: cache of PySMT function symbols (updated in place)
    :param overrides: optional mapping from sort name to PySMT type
    :return: the PySMT function symbol registered for ``func``
    """
    sort_overrides = {} if overrides is None else overrides
    arg_types = [map_sort_to_type(L, arg_sort, sort_overrides) for arg_sort in func.domain]
    result_type = map_sort_to_type(L, func.codomain, sort_overrides)
    signature_type = FunctionType(result_type, arg_types)
    key = func.signature
    smt_function_types[key] = signature_type
    smt_function_terms[key] = Symbol(key, signature_type)
    return smt_function_terms[key]
def create_function_application_term(L: FirstOrderLanguage, smt_function_types, smt_function_terms, f: Function, args, overrides=None):
    """
    Builds the PySMT term for applying the Tarski function ``f`` to ``args``.

    The PySMT symbol for ``f`` is looked up in ``smt_function_terms`` (keyed by
    ``f.signature``) and created on first use. A nullary function is
    represented by the symbol itself. A function of positive arity is wrapped
    in a PySMT function application over ``args``.

    :param L: Tarski first-order language
    :param smt_function_types: cache of PySMT function types, keyed by signature
    :param smt_function_terms: cache of PySMT function symbols, keyed by signature
    :param f: Tarski function symbol
    :param args: PySMT terms to apply ``f`` to (unused when ``f`` has arity 0)
    :param overrides: optional mapping from sort name to PySMT type, used when
                      a new function type has to be declared
    :return: the PySMT term
    """
    overrides = {} if overrides is None else overrides
    try:
        func_term = smt_function_terms[f.signature]
    except KeyError:
        if f.arity == 0:
            # MRJ: arity zero symbol maps directly to term
            # NOTE(review): create_pysmt_variable is not defined in the visible
            # source; confirm it exists elsewhere in the module.
            placeholder = Variable('x', f.codomain)
            func_term = create_pysmt_variable(L, placeholder)
            smt_function_terms[f.signature] = func_term
            return func_term
        func_term = create_function_type(L, f, smt_function_types, smt_function_terms, overrides)
        smt_function_terms[f.signature] = func_term
    return func_term if f.arity == 0 else pysmt.shortcuts.Function(func_term, args)
def create_pysmt_bool_variable(name):
    """
    Creates a fresh Boolean PySMT symbol whose name is derived from ``name``.

    :param name: prefix of the generated symbol name; a numeric suffix is
                 appended by PySMT so the symbol does not clash with
                 existing ones
    :return: a new PySMT symbol of type BOOL
    """
    # FreshSymbol's signature is (typename, template), and the template must
    # contain a '%d' placeholder. Passing (name, BOOL) used the name as the type.
    # Literal '%' characters in the prefix are escaped so they are not read as
    # format directives.
    template = str(name).replace('%', '%%') + '_%d'
    return FreshSymbol(BOOL, template)
def create_domain_axioms(y, sort):
    """
    Bounds an integer-encoded PySMT term to the index range of a Tarski sort.

    Objects of a non-arithmetic sort are encoded as their index in
    ``sort.domain()``, i.e. 0 .. |domain| - 1.

    :param y: PySMT integer term encoding an object of ``sort``
    :param sort: Tarski sort whose domain is enumerated to obtain its size
    :return: the list ``[y >= 0, y <= |domain| - 1]`` as PySMT formulas
    """
    lower = Int(0)
    upper = Int(sum(1 for _ in sort.domain()) - 1)
    return [GE(y, lower), LE(y, upper)]
def create_var_for_term(L: FirstOrderLanguage, t: CompoundTerm, name: str):
    """
    Creates a PySMT variable standing for a Tarski term, plus the axioms that
    bound its values.

    The sort of ``t`` determines the encoding:
    - Integer maps to an unbounded INT symbol.
    - Real maps to an unbounded REAL symbol.
    - A sort named ``Bool`` maps to an INT symbol constrained to {0, 1}.
    - Any other sort maps to an INT symbol constrained to the index range of
      the sort's domain.

    :param L: Tarski first-order language
    :param t: term whose sort selects the encoding
    :param name: name of the PySMT symbol to create
    :return: tuple ``(y_var, axioms)``, where ``axioms`` is a (possibly empty)
             list of PySMT bound constraints on ``y_var``
    """
    axioms = []
    if t.sort == L.Integer:
        y_var = Symbol(name, INT)
    elif t.sort == L.Real:
        y_var = Symbol(name, REAL)
    elif t.sort.name == 'Bool':
        # Compare by name: L.get('Bool') fails for languages without a 'Bool'
        # sort, which made every object-sorted term crash here.
        y_var = Symbol(name, INT)
        axioms += [GE(y_var, Int(0)), LE(y_var, Int(1))]
    else:
        y_var = Symbol(name, INT)
        axioms += create_domain_axioms(y_var, t.sort)
    return y_var, axioms
# NOTE: TransformationError is defined (with documentation) earlier in this
# module. A second, undocumented redefinition that used to live here silently
# shadowed it and has been removed.
class FormulaRewriter(object):
    """
    Rewrites quantifier-free Tarski formulas and terms into PySMT expressions.

    Variables and non-builtin compound terms are replaced by the decision
    variable ``x[symbol_idx[symref(term)], i]``. Built-in predicates and
    functions are translated through ``tsk_to_pysmt``.
    """

    def __init__(self, L: FirstOrderLanguage):
        """
        :param L: Tarski first-order language the rewritten formulas belong to
        """
        self.L = L
        # Caches of PySMT function types / symbols keyed by Tarski function
        # signature; not read by rewrite() itself.
        self.smt_function_types = OrderedDict()
        self.smt_function_terms = OrderedDict()

    def rewrite(self, phi: CompoundFormula, x: "np.ndarray", i: int, symbol_idx: OrderedDict):
        """
        Rewrites formula or term phi into SMT-LIB; every occurrence of a tracked
        symbol s becomes the decision variable x[symbol_idx[symref(s)], i].

        :param phi: quantifier-free formula or term
        :param x: 2-D indexable array of SMT decision variables
                  (indexed as ``x[symbol_index, i]``)
        :param i: time index
        :param symbol_idx: mapping from ``symref`` of a symbol to its row in ``x``
        :return: the PySMT formula or term
        :raises TransformationError: for quantified formulas, unsupported
                 connectives, non-builtin atoms, non-binary built-ins, and
                 untracked variables
        """
        if isinstance(phi, QuantifiedFormula):
            raise TransformationError("rewrite_formula(): formula {} is not quantifier free!".format(phi))
        elif isinstance(phi, Tautology):
            return TRUE()
        elif isinstance(phi, Contradiction):
            return FALSE()
        elif isinstance(phi, CompoundFormula):
            y_sub = [self.rewrite(psi, x, i, symbol_idx) for psi in phi.subformulas]
            if phi.connective == Connective.Not:
                return Not(y_sub[0])
            elif phi.connective == Connective.And:
                # Flattened conjunctions may carry more than two operands;
                # PySMT's And accepts a list, so keep every operand.
                return And(y_sub)
            elif phi.connective == Connective.Or:
                return Or(y_sub)
            # Previously fell through and returned None silently.
            raise TransformationError("rewrite():", phi, "Unsupported connective '{}'".format(phi.connective))
        elif isinstance(phi, Atom):
            if phi.predicate.builtin:
                y_sub = [self.rewrite(psi, x, i, symbol_idx) for psi in phi.subterms]
                if len(y_sub) != 2:
                    raise TransformationError("rewrite_formula():", phi, "Only built-in binary predicates are supported")
                return tsk_to_pysmt[phi.predicate.symbol](y_sub[0], y_sub[1])
            raise TransformationError("temporal.Theory", phi, "arbitrary atomic formulas are not supported")
        elif isinstance(phi, CompoundTerm):
            if phi.symbol.builtin:
                y_sub = [self.rewrite(psi, x, i, symbol_idx) for psi in phi.subterms]
                if len(y_sub) != 2:
                    raise TransformationError("temporal.Theory", phi, "Only built-in binary functions are supported")
                return tsk_to_pysmt[phi.symbol.symbol](y_sub[0], y_sub[1])
            # MRJ: all terms which are not builtin are supposed to be grounded and already
            # tracked by the theory
            return x[symbol_idx[symref(phi)], i]
        elif isinstance(phi, Variable):
            try:
                return x[symbol_idx[symref(phi)], i]
            except KeyError:
                raise TransformationError("rewrite():", phi,
                                          "Did not know how to translate term of type '{}'".format(type(phi)))
        elif isinstance(phi, Constant):
            return resolve_constant(self.L, phi)
        else:
            raise TransformationError("rewrite():", phi,
                                      "Did not know how to translate formula of type '{}'!".format(type(phi)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment