Skip to content

Instantly share code, notes, and snippets.

@miquelramirez
Created June 2, 2021 13:03
Show Gist options
  • Save miquelramirez/314efd600d47b7902a3db29676653619 to your computer and use it in GitHub Desktop.
Save miquelramirez/314efd600d47b7902a3db29676653619 to your computer and use it in GitHub Desktop.
`tarski` to `pysmt` formula rewriting
from collections import OrderedDict

from tarski import FirstOrderLanguage
from tarski.syntax import *
import pysmt
from pysmt.shortcuts import FreshSymbol, Symbol
from pysmt.shortcuts import Bool, Int, Real, FunctionType
from pysmt.shortcuts import And, Implies, Or, TRUE, FALSE, Not
from pysmt.typing import INT, BOOL, REAL
from pysmt.shortcuts import LE, GE, Equals, NotEquals, LT, GT, Plus, Minus, Times, Div
class CompilationError(Exception):
    """
    Raised when a Tarski term cannot be compiled into a PySMT expression.
    """
class TermIsNotVariable(Exception):
    """
    Signals that a term was expected to be mapped onto a variable, but no such
    mapping could be made.
    """
class TransformationError(Exception):
    """
    Raised when a formula or term cannot be transformed into its PySMT form.
    """
# Dispatch table from Tarski built-in symbols to the PySMT shortcut that builds
# the corresponding SMT term. FormulaRewriter.rewrite() looks entries up by the
# symbol of a built-in predicate/function and calls them with two operands.
# The table is incomplete: a built-in symbol without an entry here makes that
# lookup raise KeyError.
tsk_to_pysmt = {
    # relational built-ins -> SMT comparisons
    BuiltinPredicateSymbol.EQ: Equals,
    BuiltinPredicateSymbol.NE: NotEquals,
    BuiltinPredicateSymbol.LT: LT,
    BuiltinPredicateSymbol.LE: LE,
    BuiltinPredicateSymbol.GT: GT,
    BuiltinPredicateSymbol.GE: GE,
    # arithmetic built-ins -> SMT arithmetic
    BuiltinFunctionSymbol.ADD: Plus,
    BuiltinFunctionSymbol.SUB: Minus,
    BuiltinFunctionSymbol.MUL: Times,
    BuiltinFunctionSymbol.DIV: Div,
}
def map_sort_to_type(L: FirstOrderLanguage, s: Sort, overrides=None):
    """
    Translate a sort of a Tarski first-order language into a PySMT (SMT-LIB) type.

    :param L: source language; its ``Integer`` and ``Real`` sorts get native types
    :param s: sort to translate
    :param overrides: optional mapping keyed by sort name; a hit is returned as-is
    :return: ``overrides[s.name]`` when present, ``REAL`` for ``L.Real``, and
        ``INT`` for ``L.Integer`` and every other sort (other sorts are encoded
        as integers)
    """
    lookup = {} if overrides is None else overrides
    try:
        forced = lookup[s.name]
    except KeyError:
        if s == L.Integer:
            return INT
        return REAL if s == L.Real else INT
    else:
        return forced
def resolve_constant(L: FirstOrderLanguage, phi: Term, target_sort: Sort = None):
    """
    Translate a Tarski constant into a PySMT integer constant.

    Constants of ``L.Integer`` or of the sort returned by ``L.get('Bool')`` become
    ``Int(phi.symbol)``. Constants of any other sort are encoded as their position
    in ``target_sort.domain()``.

    :param L: first-order language the constant belongs to
    :param phi: term to translate; must be a Tarski ``Constant``
    :param target_sort: sort that selects the encoding; defaults to ``phi.sort``
    :return: a PySMT ``Int`` term, or ``None`` if no domain element of
        ``target_sort`` has the same symbol as ``phi``
    :raises CompilationError: if ``phi`` is not a ``Constant``
    """
    if not isinstance(phi, Constant):
        raise CompilationError("pysmt translator: Compilation of static (constant) terms like '{}' not implemented yet!".format(str(phi)))
    sort = phi.sort if target_sort is None else target_sort
    # NOTE(review): L.get('Bool') is evaluated for every non-Integer sort; presumably
    # it fails when the language declares no 'Bool' sort -- confirm.
    if sort == L.Integer or sort == L.get('Bool'):
        return Int(phi.symbol)
    members = list(sort.domain())
    return next(
        (Int(position) for position, member in enumerate(members) if member.symbol == phi.symbol),
        None,
    )
def create_function_type(L: FirstOrderLanguage, func: Function, smt_function_types, smt_function_terms, overrides=None):
    """
    Declare an uninterpreted PySMT function symbol for a Tarski function.

    Both caches are keyed by ``func.signature`` and updated in place.

    :param L: first-order language providing the sorts
    :param func: Tarski function to declare
    :param smt_function_types: cache receiving the PySMT ``FunctionType``
    :param smt_function_terms: cache receiving the PySMT function symbol
    :param overrides: sort-name -> type overrides forwarded to ``map_sort_to_type``
    :return: the PySMT function symbol stored in ``smt_function_terms``
    """
    if overrides is None:
        overrides = {}
    arg_types = [map_sort_to_type(L, sort, overrides) for sort in func.domain]
    result_type = map_sort_to_type(L, func.codomain, overrides)
    signature = func.signature
    smt_type = FunctionType(result_type, arg_types)
    smt_function_types[signature] = smt_type
    symbol = Symbol(signature, smt_type)
    smt_function_terms[signature] = symbol
    return symbol
def create_function_application_term(L: FirstOrderLanguage, smt_function_types, smt_function_terms, f: Function, args, overrides=None):
    """
    Build the PySMT term for the application of Tarski function ``f`` to ``args``.

    PySMT symbols are created lazily and cached in ``smt_function_terms`` under
    ``f.signature``:

    * nullary functions map to a plain symbol named after ``f.symbol``, whose
      type is the codomain's type;
    * functions of positive arity map to an uninterpreted function symbol
      declared via ``create_function_type``.

    :param L: first-order language providing the sorts
    :param smt_function_types: cache of PySMT function types (updated in place)
    :param smt_function_terms: cache of PySMT function symbols (updated in place)
    :param f: Tarski function being applied
    :param args: PySMT argument terms (ignored when ``f`` is nullary)
    :param overrides: sort-name -> type overrides forwarded to ``map_sort_to_type``
    :return: the symbol itself when ``f.arity == 0``, otherwise
        ``pysmt.shortcuts.Function(symbol, args)``
    """
    if overrides is None:
        overrides = {}
    try:
        func_term = smt_function_terms[f.signature]
    except KeyError:
        if f.arity > 0:
            # create_function_type() also records the symbol in smt_function_terms.
            func_term = create_function_type(L, f, smt_function_types, smt_function_terms, overrides)
        else:
            # A nullary function is a constant symbol of the codomain's type. Name it
            # after the function so distinct nullary functions get distinct symbols.
            func_term = Symbol(str(f.symbol), map_sort_to_type(L, f.codomain, overrides))
            smt_function_terms[f.signature] = func_term
    # Fall through on a cache miss too, so the first call also returns the application.
    if f.arity == 0:
        return func_term
    return pysmt.shortcuts.Function(func_term, args)
def create_pysmt_bool_variable(name):
    """
    Create a fresh Boolean PySMT symbol whose name starts with ``name``.

    ``pysmt.shortcuts.FreshSymbol`` takes ``(typename, template)``, where
    ``template`` is a %-format string containing ``%d`` that PySMT fills with a
    counter to produce a name not already in use.

    :param name: prefix for the generated symbol name; ``None`` falls back to
        PySMT's default template
    :return: a new PySMT symbol of type ``BOOL``
    """
    if name is None:
        return FreshSymbol(BOOL)
    # Escape '%' so a caller-supplied prefix is not read as a format directive.
    template = str(name).replace("%", "%%") + "%d"
    return FreshSymbol(BOOL, template=template)
def create_domain_axioms(y, sort):
    """
    Bound an integer-encoded PySMT term to the index range of a Tarski sort.

    Values of ``sort`` are encoded as indices ``0 .. |domain| - 1`` (the same
    index encoding ``resolve_constant`` uses), so the axioms are
    ``y >= 0`` and ``y <= |domain| - 1``. An empty domain yields the
    unsatisfiable pair ``y >= 0``, ``y <= -1``.

    :param y: PySMT integer term encoding a value of ``sort``
    :param sort: Tarski sort providing ``domain()``
    :return: list ``[GE(y, 0), LE(y, |domain| - 1)]`` of PySMT formulas
    """
    lower = Int(0)
    size = sum(1 for _ in sort.domain())
    upper = Int(size - 1)
    return [GE(y, lower), LE(y, upper)]
def create_var_for_term(L: FirstOrderLanguage, t: CompoundTerm, name: str):
    """
    Declare a PySMT variable standing for term ``t`` and collect its range axioms.

    The encoding depends on ``t.sort``:

    * ``L.Integer`` -> ``INT`` symbol, no axioms;
    * ``L.Real``    -> ``REAL`` symbol, no axioms;
    * ``L.get('Bool')`` -> ``INT`` symbol constrained to ``0 <= y <= 1``;
    * anything else -> ``INT`` symbol bounded by ``create_domain_axioms``.

    :param L: Tarski first-order language the term belongs to
    :param t: term whose sort selects the encoding
    :param name: name of the PySMT symbol to create
    :return: pair ``(symbol, axioms)``, where ``axioms`` is a list of PySMT formulas
    """
    sort = t.sort
    if sort == L.Integer:
        return Symbol(name, INT), []
    if sort == L.Real:
        return Symbol(name, REAL), []
    # NOTE(review): L.get('Bool') presumably fails for languages without a 'Bool' sort -- confirm.
    if sort == L.get('Bool'):
        flag = Symbol(name, INT)
        return flag, [GE(flag, Int(0)), LE(flag, Int(1))]
    index = Symbol(name, INT)
    return index, create_domain_axioms(index, sort)
# TransformationError is already defined alongside CompilationError and
# TermIsNotVariable. It is intentionally not redefined here: a second ``class``
# statement would silently rebind the name to a new, undocumented class.
class FormulaRewriter(object):
    """
    Rewrites quantifier-free Tarski formulas and terms into PySMT formulas and terms.

    Built-in predicates and functions are translated through ``tsk_to_pysmt``.
    Non-builtin terms and variables are replaced by entries of a caller-supplied
    array of decision variables, selected through ``symbol_idx``.
    """

    def __init__(self, L: FirstOrderLanguage):
        """
        :param L: first-order language of the formulas to rewrite
        """
        self.L = L
        # Caches for uninterpreted function declarations, keyed by Tarski function
        # signature. NOTE(review): rewrite() does not currently use them.
        self.smt_function_types = OrderedDict()
        self.smt_function_terms = OrderedDict()

    def rewrite(self, phi: CompoundFormula, x: "np.ndarray", i: int, symbol_idx: OrderedDict):
        """
        Rewrite ``phi`` into PySMT. Every tracked symbol ``s`` becomes ``x[symbol_idx[symref(s)], i]``.

        :param phi: quantifier-free Tarski formula or term
        :param x: array-like of PySMT decision variables, indexed as ``x[row, i]``
        :param i: time index (second index into ``x``)
        :param symbol_idx: maps ``symref`` of tracked terms/variables to their row in ``x``
        :return: the equivalent PySMT formula or term
        :raises TransformationError: for quantified formulas, non-builtin atoms,
            non-binary built-ins, unsupported connectives, untracked variables,
            or unsupported node types
        """
        if isinstance(phi, QuantifiedFormula):
            raise TransformationError("rewrite_formula(): formula {} is not quantifier free!".format(phi))
        if isinstance(phi, Tautology):
            return TRUE()
        if isinstance(phi, Contradiction):
            return FALSE()

        if isinstance(phi, CompoundFormula):
            y_sub = [self.rewrite(psi, x, i, symbol_idx) for psi in phi.subformulas]
            if phi.connective == Connective.Not:
                return Not(y_sub[0])
            if phi.connective == Connective.And:
                # PySMT's And/Or are variadic, so every operand is kept.
                return And(*y_sub)
            if phi.connective == Connective.Or:
                return Or(*y_sub)
            # Reject unknown connectives explicitly rather than returning None.
            raise TransformationError("rewrite():", phi,
                                      "Unsupported connective '{}'".format(phi.connective))

        if isinstance(phi, Atom):
            if not phi.predicate.builtin:
                raise TransformationError("temporal.Theory", phi, "arbitrary atomic formulas are not supported")
            y_sub = [self.rewrite(psi, x, i, symbol_idx) for psi in phi.subterms]
            if len(y_sub) != 2:
                raise TransformationError("rewrite_formula():", phi, "Only built-in binary predicates are supported")
            return tsk_to_pysmt[phi.predicate.symbol](y_sub[0], y_sub[1])

        if isinstance(phi, CompoundTerm):
            if not phi.symbol.builtin:
                # Non-builtin terms are assumed to be ground and already tracked in symbol_idx.
                return x[symbol_idx[symref(phi)], i]
            y_sub = [self.rewrite(psi, x, i, symbol_idx) for psi in phi.subterms]
            if len(y_sub) != 2:
                raise TransformationError("temporal.Theory", phi, "Only built-in binary functions are supported")
            return tsk_to_pysmt[phi.symbol.symbol](y_sub[0], y_sub[1])

        if isinstance(phi, Variable):
            try:
                row = symbol_idx[symref(phi)]
            except KeyError as err:
                raise TransformationError("rewrite():", phi,
                                          "Did not know how to translate term of type '{}'".format(type(phi))) from err
            return x[row, i]

        if isinstance(phi, Constant):
            return resolve_constant(self.L, phi)

        raise TransformationError("rewrite():", phi,
                                  "Did not know how to translate formula of type '{}'!".format(type(phi)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment