Last active
October 18, 2015 15:41
-
-
Save 5nizza/773c08a6dbddb82d92da to your computer and use it in GitHub Desktop.
Convert a CUDD BDD into an SMT formula. Naive approach: defines one SMT function per BDD node. Caution: not tested.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Function(object):
    """
    A Boolean SMT-LIB function over Boolean arguments.

    When ``body`` is truthy the function is emitted with ``define-fun``;
    otherwise it is emitted as an uninterpreted ``declare-fun``.
    """

    def __init__(self, name, input_args, body):
        """
        :arg name -- SMT identifier of the function
        :arg input_args -- ordered argument names; every argument has sort Bool
        :arg body -- SMT term defining the function, or a falsy value for an
             uninterpreted (declared-only) function
        """
        self.body = body
        self.input_args = input_args
        self.name = name

    def smt_call(self, val_by_arg=None):
        """
        Return an SMT term applying this function.

        :arg val_by_arg -- optional dict: argument name -> SMT term to pass for
             it; when omitted (or empty) the argument names themselves are used
        :return '(name a1 a2 ...)', or just 'name' for a nullary function,
                because SMT-LIB has no zero-argument application form '(f)'
        """
        if not self.input_args:
            return self.name
        args_ = ' '.join(val_by_arg[a] if val_by_arg else a
                         for a in self.input_args)
        return '({name} {args})'.format(name=self.name, args=args_)

    def smt_declare(self):
        """
        Return the SMT-LIB command that introduces this function.

        :return a multi-line '(define-fun ...)' when a body is present,
                otherwise a one-line '(declare-fun name (Bool ...) Bool)'
        """
        if self.body:
            arg_ty = ' '.join('({a} Bool)'.format(a=a)
                              for a in self.input_args)
            header = 'define-fun {name} ({arg_ty}) Bool'.format(name=self.name,
                                                                arg_ty=arg_ty)
            # Adjacent literals are concatenated before .format is applied.
            return '(\n' \
                   ' {header}\n' \
                   ' {body}\n' \
                   ')'.format(header=header, body=self.body)
        else:
            ty = ' '.join(['Bool'] * len(self.input_args))
            smt = '(declare-fun {name} ({ty}) Bool)'.format(name=self.name,
                                                            ty=ty)
            return smt

    def __str__(self):
        return 'function:{name}({args})'.format(name=self.name,
                                                args=self.input_args)
def _smt_bool(b): | |
return ['false', 'true'][b] | |
def _smt_negated(expr): | |
return '(not {expr})'.format(expr=expr) | |
def _smt_ite(guard_expr, then_expr, else_expr): | |
return '(ite {g} {t} {e})'.format(g=guard_expr, | |
t=then_expr, | |
e=else_expr) | |
def _get_smt_var(var_index): | |
return 'var'+str(var_index) | |
def _get_func_name(reg_bdd_node, _name_by_bdd=dict(), _last_by_idx=dict()): | |
""" | |
Each node gets name: {depth}_{num}, where {num} is unique but arbitrary | |
""" | |
if reg_bdd_node in _name_by_bdd: | |
return _name_by_bdd[reg_bdd_node] | |
idx = reg_bdd_node.NodeReadIndex() | |
if idx in _last_by_idx: | |
last = _last_by_idx[idx] + 1 | |
_last_by_idx[idx] = last | |
else: | |
last = 0 | |
_last_by_idx[idx] = 0 | |
name = 'func_' + str(idx) + '_' + str(last) | |
_name_by_bdd[reg_bdd_node] = name | |
return name | |
def smt_walk(sign_bdd_node,
             ordered_func_args,
             name_by_node_index,
             funcs_by_depth,
             get_var_depth,
             _func_by_name=dict()):
    """
    Convert a (possibly complemented) BDD node into an SMT term.

    Constants become 'true'/'false'. Every internal node reached becomes a
    Function whose body is '(ite <guard var> <then term> <else term>)'. The
    node itself is then referenced by a call to that function, wrapped in
    '(not ...)' when the incoming edge is complemented.

    :arg sign_bdd_node -- the node to be converted into SMT formula
    :arg ordered_func_args -- argument names of every generated function
    :arg name_by_node_index -- dict: NodeReadIndex -> SMT name of the guard variable
    :arg funcs_by_depth -- dict: depth -> set of Function; every newly created
         function is added under get_var_depth(node).
         NOTE(review): SMT-LIB requires a function to be defined before it is
         referenced, so children must be declared before their parents.
         Declaring "in increasing depth order" is only correct if
         get_var_depth gives children smaller values than parents -- confirm.
    :arg get_var_depth -- callable: node -> key used in funcs_by_depth
    :arg _func_by_name -- cache: function name -> Function (internal use).
         The default dict is shared by all calls. A node converted by an
         earlier call is therefore NOT added again to a new funcs_by_depth.
    :return SMT function call as string (may be a negated expression),
            or 'true'/'false' for a constant node

    Recursion depth grows with the BDD's height; very tall BDDs may hit
    Python's recursion limit.
    """
    def _signed(call):
        # A complemented edge denotes the negation of its regular node.
        return _smt_negated(call) if sign_bdd_node.IsComplement() else call

    def _walk(child):
        # Forward the same cache so children and parents share one
        # _func_by_name. The previous lambda dropped it, so children always
        # fell back to the shared default dict.
        return smt_walk(child, ordered_func_args, name_by_node_index,
                        funcs_by_depth, get_var_depth, _func_by_name)

    if sign_bdd_node.IsConstant():
        # one is the first-class pointer, zero - is the complement
        return _smt_bool(not sign_bdd_node.IsComplement())

    func_name = _get_func_name(sign_bdd_node.Regular())
    cached = _func_by_name.get(func_name)
    if cached is not None:
        return _signed(cached.smt_call())

    guard_var = name_by_node_index[sign_bdd_node.NodeReadIndex()]
    # Presumably T()/E() return the children of the *regular* node (as CUDD's
    # Cudd_T/Cudd_E do). That is why the complement is applied only once, via
    # _signed, below. Confirm for the binding in use.
    then_expr = _walk(sign_bdd_node.T())
    else_expr = _walk(sign_bdd_node.E())
    func = Function(func_name, ordered_func_args,
                    _smt_ite(guard_var, then_expr, else_expr))
    _func_by_name[func_name] = func

    depth = get_var_depth(sign_bdd_node)
    funcs_by_depth.setdefault(depth, set()).add(func)

    return _signed(func.smt_call())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment