Skip to content

Instantly share code, notes, and snippets.

@bonzini
Last active August 29, 2015 14:10
Show Gist options
  • Save bonzini/7942680bf9818938a259 to your computer and use it in GitHub Desktop.
Save bonzini/7942680bf9818938a259 to your computer and use it in GitHub Desktop.
BDD library in Python
# Copyright 2014 Paolo Bonzini
# API inspired by py-simple-bdd.
# License: X11 (MIT)
class Variable(object):
    """A named boolean variable that BDD nodes can branch on.

    Two variables compare equal when their names are equal (the cached
    hash of the name is checked first as a quick filter).  ``node`` is the
    single-decision node ``name ? T : F`` built when the variable is made.
    """

    def __init__(self, name):
        self._name = name
        self._hash = hash(name)
        self._node = Node(self, Node.T, Node.F)

    def __str__(self):
        return str(self._name)

    def __repr__(self):
        return '%s.Variable(%r)' % (__name__, self._name)

    def __hash__(self):
        return self._hash

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Variable):
            return False
        return self._hash == other._hash and self._name == other._name

    def __ne__(self, other):
        return not self.__eq__(other)

    @property
    def name(self):
        """The name this variable was created with."""
        return self._name

    @property
    def node(self):
        """The node ``self ? Node.T : Node.F``."""
        return self._node
class Ordering(object):
    """A total order on variables; earlier variables sit nearer the BDD root.

    ``ordering[var]`` is the 0-based position of ``var``.  ``ordering[None]``
    (``None`` is the ``var`` of a terminal node) is ``len(ordering)``, i.e.
    below every real variable.
    """

    def __init__(self, vars=()):
        # Immutable default: a shared mutable [] default is a latent hazard.
        self._vars = []
        self._order = dict()
        self._n = 0
        self.extend(vars)

    def extend(self, vars):
        """Append ``vars``, in order, below all variables already present."""
        for var in vars:
            # Keep the list in sync with the position map; previously the
            # list was never filled, so the ``vars`` property was always [].
            self._vars.append(var)
            self._order[var] = self._n
            self._n += 1

    def __len__(self):
        return self._n

    def __getitem__(self, var):
        """Return the position of ``var``; ``None`` maps to ``len(self)``.

        Raises KeyError for a variable that is not in the ordering.
        """
        if var is None:
            return self._n
        return self._order[var]

    def sort(self, vars):
        """Return ``vars`` as a new list sorted top-to-bottom by this order."""
        return sorted(vars, key=self._order.__getitem__)

    @property
    def vars(self):
        """A copy of the variables, in order (top of the BDD first)."""
        return list(self._vars)

    @property
    def comparator(self):
        """Return a two-argument function that takes two variables and
        returns True if the first is above the second in the BDD.

        ``None`` (the terminal "variable") is below everything.
        """
        return (lambda x, y:
                x is not None and
                (y is None or self._order[x] < self._order[y]))
class _AbstractNode(object):
    """Traversal helpers shared by decision nodes and the two terminals."""

    def visitLazy(self, compose, t=None, f=None):
        """Fold the DAG rooted here, letting ``compose`` decide what to visit.

        ``compose(node, visit_t, visit_f)`` gets two zero-argument callables
        that return the folded value of ``node.t`` / ``node.f`` when called.
        Results are cached per node object.  If ``t``/``f`` are given they
        are used directly as the values of Node.T / Node.F.
        """
        memo = {}
        if t is not None:
            memo[id(Node.T)] = t
        if f is not None:
            memo[id(Node.F)] = f

        def walk(node):
            key = id(node)
            if key not in memo:
                memo[key] = compose(node,
                                    lambda: walk(node.t),
                                    lambda: walk(node.f))
            return memo[key]

        return walk(self)

    def visit(self, compose, t, f):
        """Fold the DAG rooted here bottom-up.

        ``compose(node, value_t, value_f)`` receives the already-folded
        values of both children; Node.T folds to ``t`` and Node.F to ``f``.
        Each node object is folded at most once.
        """
        memo = {id(Node.T): t, id(Node.F): f}

        def walk(node):
            key = id(node)
            if key not in memo:
                value_t = walk(node.t)
                value_f = walk(node.f)
                memo[key] = compose(node, value_t, value_f)
            return memo[key]

        return walk(self)

    def countNodes(self):
        """Returns the number of distinct node objects in the BDD"""
        seen = [0]

        def count(node, visit_t, visit_f):
            seen[0] += 1
            if not Node.isTerminal(node):
                visit_t()
                visit_f()
            return seen[0]

        return self.visitLazy(count)

    @property
    def root(self):
        """A node is its own root (lets nodes stand in where BDDs are used)."""
        return self
class Node(_AbstractNode):
    """A decision node ``var ? t : f`` of a binary decision diagram.

    The class also holds the two terminal singletons ``Node.T`` (true) and
    ``Node.F`` (false).  Each decision node builds its complement (same
    variable, negated children) in ``__init__``, so ``~node`` just returns
    a stored reference.
    """
    @staticmethod
    def isTerminal(p):
        """Return True if p is not a decision Node.

        Meant for the terminals Node.T and Node.F, but any object that is
        not a Node instance (e.g. None) also yields True.
        """
        return not isinstance(p, Node)
    # Shared base of the two terminal singletons.  The double-underscore
    # class names are private (mangled to _Node__TerminalNode etc.).
    # Terminals define no __eq__, so they compare by identity.
    class __TerminalNode(_AbstractNode):
        def __repr__(self):
            # Produces '<module>.Node.T' or '<module>.Node.F'.
            return __name__ + '.Node.' + str(self)
        # A terminal tests no variable; None sorts below every variable
        # in an Ordering.
        @property
        def var(self):
            return None
        # Both branches of a terminal lead back to itself.
        @property
        def t(self):
            return self
        @property
        def f(self):
            return self
    class __TrueNode(__TerminalNode):
        def __str__(self):
            return 'T'
        def __hash__(self):
            return 1
        def __invert__(self):
            return Node.F
    class __FalseNode(__TerminalNode):
        def __str__(self):
            return 'F'
        def __hash__(self):
            return 0
        def __invert__(self):
            return Node.T
    # The only two terminal instances.
    T = __TrueNode()
    F = __FalseNode()
    @staticmethod
    def terminal(value):
        """Returns the terminal corresponding to the boolean interpretation of value

        Node.F itself maps to Node.F; any other truthy value maps to Node.T.
        """
        if value and value != Node.F:
            return Node.T
        else:
            return Node.F
    def __init__(self, var, t, f, negated = None):
        """Create ``var ? t : f``.

        negated: the already-built complement, if any.  When omitted the
        complement ``var ? ~t : ~f`` is built here, passing this node as
        its complement so construction stops after one level.
        """
        # Structural hash, combining the hashes of var and both children.
        self._hash = (hash(var) + hash(t) - hash(f)) % 0xFFFFFFFF
        self._var = var
        self._t = t
        self._f = f
        self._negated = negated or Node(var, ~t, ~f, self)
    @property
    def var(self):
        return self._var
    @property
    def t(self):
        return self._t
    @property
    def f(self):
        return self._f
    def __str__(self):
        # Render the common shapes compactly: x, x', (x | f), (x & t).
        if self.t == Node.T and self.f == Node.F:
            return str(self.var)
        elif self.t == Node.F and self.f == Node.T:
            return str(self.var)+"'"
        elif self.t == Node.T:
            return '(%s | %s)' % (str(self.var), str(self.f))
        elif self.f == Node.F:
            return '(%s & %s)' % (str(self.var), str(self.t))
        else:
            return '(%s ? %s : %s)' % (str(self.var), str(self.t), str(self.f))
    def __repr__(self):
        return '%s.Node(%s, %s, %s)' % \
            (__name__ , repr(self.var), repr(self.t), repr(self.f))
    def __hash__(self):
        return self._hash
    def __eq__(self, other):
        # Structural equality; the cached hash is compared first as a
        # cheap filter before recursing into the children.
        return (self is other or
                (isinstance(other, Node) and
                 self._hash == other._hash and
                 self.var == other.var and
                 self.t == other.t and
                 self.f == other.f))
    def __ne__(self, other):
        return not self.__eq__(other)
    def __invert__(self):
        return self._negated
class BDD(object):
    """A binary decision diagram: a root node together with its Ordering.

    BDDs derived from one another share one ``__Cache`` (a unique table),
    so structurally equal nodes are represented by one canonical object.
    Each BDD builds its complement on construction, so ``~bdd`` is a
    stored reference.  ``&``, ``|`` and ``^`` accept another BDD or a bare
    node (anything with a ``root``) on either side.
    """

    class __Cache(dict):
        """Unique table mapping each node to its canonical instance.

        It also implements the recursive algorithms (restriction,
        reduction, apply) that must canonicalise the nodes they build.
        """

        def __init__(self, ordering):
            dict.__init__(self)
            self[Node.T] = Node.T
            self[Node.F] = Node.F
            # above(a, b) is True when a is tested closer to the root than b.
            self._above = ordering.comparator

        def replace(self, t, f, p, q=None):
            """Return the canonical node testing p.var with children t / f.

            It collapses the test when ``t == f``.  It reuses p (or q) when
            their children already are t and f.
            """
            if t == f:
                result = t
            elif t is p.t and f is p.f:
                result = p
            elif q is not None and t is q.t and f is q.f:
                result = q
            else:
                result = bddNode(p.var, t, f)
            if result in self:
                return self[result]
            # If the complement is stored, reuse its precomputed negation.
            negated = ~result
            if negated in self:
                return ~self[negated]
            self[result] = result
            return result

        def evaluate(self, p, vars, i, assignments, memo):
            """Restrict p by ``assignments`` for the sorted variables vars[i:].

            Returns the canonical node for p with each assigned variable
            replaced by the branch its value selects.  ``memo`` caches
            results per ``(node, i)``.
            """
            # Assigned variables above p.var cannot occur below p, because
            # the diagram is ordered.  Later (lower) assigned variables still
            # may occur, so skip past these rather than stop restricting.
            # Terminals (var None) are below everything and skip them all.
            while i < len(vars) and self._above(vars[i], p.var):
                i += 1
            if i >= len(vars):
                return p
            key = (p, i)
            if key in memo:
                return memo[key]
            if p.var == vars[i]:
                if Node.terminal(assignments[p.var]) == Node.T:
                    branch = p.t
                else:
                    branch = p.f
                result = self.evaluate(branch, vars, i + 1, assignments, memo)
            else:
                t = self.evaluate(p.t, vars, i, assignments, memo)
                f = self.evaluate(p.f, vars, i, assignments, memo)
                result = self.replace(t, f, p)
            memo[key] = result
            return result

        def reduce(self, p, memo=None):
            """Rebuild p bottom-up through the unique table.

            Duplicate subgraphs are merged and redundant tests are dropped.
            ``memo`` (keyed by node identity) makes shared subgraphs be
            processed once instead of once per path.
            """
            if memo is None:
                memo = {}
            key = id(p)
            if key in memo:
                return memo[key]
            t = p.t if Node.isTerminal(p.t) else self.reduce(p.t, memo)
            f = p.f if Node.isTerminal(p.f) else self.reduce(p.f, memo)
            result = memo[key] = self.replace(t, f, p)
            return result

        def apply(self, p, q, binop, memo):
            """Combine p and q with ``binop`` by Shannon expansion.

            ``binop(p, q)`` returns a node when the answer is immediate
            (e.g. a terminal operand) and None otherwise.  Results are
            memoised on the pair ``(p, q)``.
            """
            easy = binop(p, q)
            if easy is not None:
                return easy
            key = (p, q)
            if key in memo:
                return memo[key]
            if p.var == q.var:
                t = self.apply(p.t, q.t, binop, memo)
                f = self.apply(p.f, q.f, binop, memo)
                result = self.replace(t, f, p, q)
            elif self._above(p.var, q.var):
                t = self.apply(p.t, q, binop, memo)
                f = self.apply(p.f, q, binop, memo)
                result = self.replace(t, f, p)
            else:
                t = self.apply(p, q.t, binop, memo)
                f = self.apply(p, q.f, binop, memo)
                result = self.replace(t, f, q)
            memo[key] = result
            return result

    def __init__(self, root, ordering, cache=None, negated=None):
        """Wrap ``root`` under ``ordering``.

        cache: unique table to share with related BDDs (a fresh one is
            created when not given).
        negated: the complement BDD, if already built; otherwise it is
            built here with this BDD as its complement.
        """
        self._root = root
        self._ordering = ordering
        self._cache = cache or BDD.__Cache(ordering)
        if negated is None:
            negated = BDD(~root, ordering, self._cache, self)
        self._negated = negated

    @staticmethod
    def conjunction(vars, ordering, cache=None):
        """Form a bdd that is the AND of all the variables."""
        result = Node.T
        for var in reversed(ordering.sort(vars)):
            result = Node(var, result, Node.F)
        return BDD(result, ordering, cache)

    @staticmethod
    def disjunction(vars, ordering, cache=None):
        """Form a bdd that is the OR of all the variables."""
        result = Node.F
        for var in reversed(ordering.sort(vars)):
            result = Node(var, Node.T, result)
        return BDD(result, ordering, cache)

    @staticmethod
    def __balanced(bdds, unit, combine, ordering, cache):
        """Fold ``bdds`` with ``combine``, starting from the terminal ``unit``.

        Operands are paired like the carries of a binary counter.  As a
        result, each combination joins two sub-results built from similar
        numbers of inputs.
        """
        n = 1
        stack = [BDD(unit, ordering, cache)]
        for item in bdds:
            n += 1
            # One merge per trailing zero bit of n.
            m = n & -n
            while m != 1:
                item = combine(stack.pop(), item)
                m >>= 1
            stack.append(item)
        top = unit
        while stack:
            top = combine(top, stack.pop())
        return top

    @staticmethod
    def andAll(bdds, ordering, cache=None):
        """Return the AND of all ``bdds`` (Node.T's BDD when empty)."""
        return BDD.__balanced(bdds, Node.T, lambda a, b: a & b,
                              ordering, cache)

    @staticmethod
    def orAll(bdds, ordering, cache=None):
        """Return the OR of all ``bdds`` (Node.F's BDD when empty)."""
        return BDD.__balanced(bdds, Node.F, lambda a, b: a | b,
                              ordering, cache)

    def countTrue(self):
        """Return the number of satisfying assignments to all variables
        of the ordering."""
        order = self._ordering

        def count(p, vt, vf):
            # Each variable skipped between p and a child can take either
            # value, doubling that child's count.
            return ((vt << (order[p.t.var] - order[p.var] - 1)) +
                    (vf << (order[p.f.var] - order[p.var] - 1)))

        # Variables above the root are likewise unconstrained.
        return self.visit(count, 1, 0) << order[self._root.var]

    def evaluate(self, assignments):
        """Return this BDD restricted by ``assignments`` (variable -> value).

        Each value is read as a boolean via ``Node.terminal``.
        """
        vars = self._ordering.sort(assignments)
        node = self._cache.evaluate(self._root, vars, 0, assignments, dict())
        return BDD(node, self._ordering, self._cache)

    def __xnor(self, a, b):
        # a XNOR b, computed as a XOR (NOT b).
        return BDD(a.root, self._ordering, self._cache) ^ ~b

    def force(self, assignments):
        """Return self AND (v <-> assignments[v]) for every assigned v.

        ``assignments`` maps Variables to BDDs.
        """
        vars = self._ordering.sort(assignments)
        iffs = (self.__xnor(v.node, assignments[v]) for v in vars)
        return self & BDD.andAll(iffs, self._ordering, self._cache)

    def apply(self, binop, other):
        """Combine with ``other`` (a BDD or node) using a node-level binop."""
        node = self._cache.apply(self._root, other.root, binop, dict())
        return BDD(node, self._ordering, self._cache)

    def reduce(self):
        """Return an equivalent BDD whose nodes are canonical and reduced."""
        node = self._cache.reduce(self._root)
        return BDD(node, self._ordering, self._cache)

    def visitLazy(self, compose, t=None, f=None):
        """Lazy fold over the root; see ``_AbstractNode.visitLazy``."""
        return self._root.visitLazy(compose, t, f)

    def visit(self, compose, t, f):
        """Bottom-up fold over the root; see ``_AbstractNode.visit``."""
        return self._root.visit(compose, t, f)

    @property
    def root(self):
        return self._root

    @property
    def ordering(self):
        return self._ordering

    def __str__(self):
        return str(self._root)

    def __hash__(self):
        return hash(self._root)

    def __eq__(self, other):
        return (self is other or
                (isinstance(other, BDD) and
                 self._root == other._root and
                 self._ordering is other._ordering))

    def __ne__(self, other):
        return not self.__eq__(other)

    def __invert__(self):
        return self._negated

    # Node-level binops for Cache.apply: return the result node when it
    # follows directly from a terminal operand, otherwise None.
    @staticmethod
    def __AND(p, q):
        if p == Node.T:
            return q
        if q == Node.T:
            return p
        if p == Node.F:
            return p
        if q == Node.F:
            return q
        return None

    def __and__(self, other):
        return self.apply(BDD.__AND, other)

    def __rand__(self, other):
        return self & other

    @staticmethod
    def __OR(p, q):
        if p == Node.F:
            return q
        if q == Node.F:
            return p
        if p == Node.T:
            return p
        if q == Node.T:
            return q
        return None

    def __or__(self, other):
        return self.apply(BDD.__OR, other)

    def __ror__(self, other):
        return self | other

    @staticmethod
    def __XOR(p, q):
        if p == Node.T:
            return ~q
        if q == Node.T:
            return ~p
        if p == Node.F:
            return q
        if q == Node.F:
            return p
        return None

    def __xor__(self, other):
        return self.apply(BDD.__XOR, other)

    def __rxor__(self, other):
        return self ^ other
def bddNode(var, t, f):
    """Build the node ``var ? t : f``, simplified where possible.

    A true child that itself tests ``var`` is replaced by its own true
    branch.  Likewise, a false child that tests ``var`` is replaced by its
    false branch.  If the two resulting children are equal, that child is
    returned instead of a new Node.
    """
    high = t.t if t.var == var else t
    low = f.f if f.var == var else f
    return high if high == low else Node(var, high, low)
# Copyright 2011 Craig Eales
# Copyright 2014 Paolo Bonzini
# Based on the py-simple-bdd unit tests
# This file is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This file is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this file. If not, see <http://www.gnu.org/licenses/>.
from bdd import Variable, Node, BDD, bddNode, Ordering
import bdd # for eval
import unittest
# Shared fixtures: five variables ordered x, y, z, w, u (top to bottom),
# plus short aliases for the two terminals.
x, y, z, w, u = (Variable('x'), Variable('y'), Variable('z'),
                 Variable('w'), Variable('u'))
ordering = Ordering([x, y, z, w, u])
T, F = Node.T, Node.F
def as_bdd(n):
    """Wrap node ``n`` in a BDD under the shared test ``ordering``."""
    wrapped = BDD(n, ordering)
    return wrapped
class TestNode(unittest.TestCase):
    """Tests for Node, bddNode, Ordering and the BDD operations.

    All tests share the module-level variables x, y, z, w, u, their
    ``ordering`` (x above y above z above w above u) and the terminal
    aliases T and F.
    """

    def testTrue(self):
        self.assertEqual(T, T)
        self.assertNotEqual(T, F)
        self.assertEqual(T, eval(repr(T)))
        self.assertEqual(as_bdd(T).evaluate({x: False}).root, T)
        self.assertEqual(as_bdd(T).evaluate({x: True}).root, T)
        self.assertTrue(Node.isTerminal(T))
        self.assertEqual(Node.terminal(True), T)
        self.assertEqual(Node.terminal([x]), T)
        self.assertEqual(~T, F)

    def testFalse(self):
        self.assertEqual(F, F)
        self.assertNotEqual(F, T)
        self.assertEqual(F, eval(repr(F)))
        self.assertEqual(as_bdd(F).evaluate({x: False}).root, F)
        self.assertEqual(as_bdd(F).evaluate({x: True}).root, F)
        self.assertTrue(Node.isTerminal(F))
        self.assertEqual(Node.terminal(False), F)
        self.assertEqual(Node.terminal([]), F)
        self.assertEqual(~F, T)

    def testSingleNode(self):
        n1 = bddNode(x, T, F)
        self.assertFalse(Node.isTerminal(n1))
        self.assertEqual(n1.countNodes(), 3)
        self.assertEqual(n1, n1)
        self.assertEqual(n1, eval(repr(n1)))
        self.assertEqual(hash(n1), hash(eval(repr(n1))))
        self.assertEqual(as_bdd(n1).evaluate({x: True}).root, T)
        self.assertEqual(as_bdd(n1).evaluate({x: False}).root, F)
        self.assertEqual(as_bdd(n1).evaluate({y: True}).root, n1)
        n2 = bddNode(y, T, F)
        self.assertNotEqual(n1, n2)
        self.assertEqual(n2.countNodes(), 3)
        n3 = bddNode(x, T, T)
        self.assertNotEqual(n1, n3)
        self.assertEqual(n3, T)
        self.assertEqual(n3.countNodes(), 1)
        n4 = bddNode(x, F, T)
        self.assertNotEqual(n1, n4)
        self.assertEqual(n4.countNodes(), 3)
        self.assertEqual(~n1, n4)
        self.assertEqual(~n4, n1)
        n5 = bddNode(x, F, F)
        self.assertNotEqual(n1, n5)
        self.assertEqual(n5, F)
        self.assertEqual(n5.countNodes(), 1)

    def testNestedNode(self):
        n1 = bddNode(z, T, F)
        n2 = bddNode(z, F, T)
        n3 = bddNode(z, T, F)
        cn1 = bddNode(y, n1, n2)
        cn2 = bddNode(y, n2, n1)
        cn3 = bddNode(y, n3, n2)
        self.assertEqual(cn1, cn1)
        self.assertNotEqual(cn1, cn2)
        self.assertEqual(cn1, cn3)
        self.assertEqual(cn1.countNodes(), 5)
        self.assertEqual(hash(cn1), hash(cn3))
        self.assertEqual(cn1, eval(repr(cn1)))
        bdd1 = as_bdd(cn1)
        self.assertEqual(bdd1.evaluate({y: True}).root, n1)
        self.assertEqual(bdd1.evaluate({y: False}).root, n2)
        # Restricting one variable at a time, in either order.
        self.assertEqual(bdd1.evaluate({z: True}).evaluate({y: True}).root, T)
        self.assertEqual(bdd1.evaluate({z: True}).evaluate({y: False}).root, F)
        self.assertEqual(bdd1.evaluate({z: False}).evaluate({y: True}).root, F)
        self.assertEqual(bdd1.evaluate({z: False}).evaluate({y: False}).root, T)
        self.assertEqual(bdd1.evaluate({y: True}).evaluate({z: True}).root, T)
        self.assertEqual(bdd1.evaluate({y: False}).evaluate({z: True}).root, F)
        self.assertEqual(bdd1.evaluate({y: True}).evaluate({z: False}).root, F)
        self.assertEqual(bdd1.evaluate({y: False}).evaluate({z: False}).root, T)
        # Restricting both variables at once.
        self.assertEqual(bdd1.evaluate({z: True, y: True}).root, T)
        self.assertEqual(bdd1.evaluate({z: True, y: False}).root, F)
        self.assertEqual(bdd1.evaluate({z: False, y: True}).root, F)
        self.assertEqual(bdd1.evaluate({z: False, y: False}).root, T)
        cn4 = bddNode(z, bddNode(z, T, F), bddNode(z, T, F))
        bdd4 = as_bdd(cn4)
        self.assertEqual(bdd4.evaluate({z: True}).root, T)
        self.assertEqual(bdd4.evaluate({z: False}).root, F)
        self.assertEqual(~bdd4.evaluate({z: True}).root, F)
        self.assertEqual(~bdd4.evaluate({z: False}).root, T)
        self.assertEqual(cn4.countNodes(), 3)
        self.assertEqual(~n1, n2)
        self.assertEqual(~n2, n1)
        self.assertEqual(~~cn1, cn1)
        self.assertNotEqual(~cn1, cn1)

    def testOrderings(self):
        above = ordering.comparator
        self.assertTrue(above(y, w))
        self.assertFalse(above(z, y))
        self.assertEqual(ordering.sort([y, w]), [y, w])
        self.assertEqual(ordering.sort([w, y]), [y, w])
        self.assertEqual(ordering.sort([x, w, y]), [x, y, w])

    def testSimple(self):
        self.assertEqual(x.node, bddNode(x, T, F))
        self.assertEqual(~(x.node), bddNode(x, F, T))

    def testConjunction(self):
        self.assertEqual(BDD.conjunction([], ordering).root, T)
        self.assertEqual(BDD.conjunction([x], ordering).root, x.node)
        self.assertEqual(BDD.conjunction([x, y], ordering).root,
                         bddNode(x, y.node, F))
        self.assertEqual(BDD.conjunction([x, y, z], ordering).root,
                         bddNode(x, bddNode(y, z.node, F), F))

    def testDisjunction(self):
        self.assertEqual(BDD.disjunction([], ordering).root, F)
        self.assertEqual(BDD.disjunction([x], ordering).root, x.node)
        self.assertEqual(BDD.disjunction([x, y], ordering).root,
                         bddNode(x, T, y.node))
        self.assertEqual(BDD.disjunction([x, y, z], ordering).root,
                         bddNode(x, T, bddNode(y, T, z.node)))

    def testOr(self):
        bddx = as_bdd(x.node)
        bddy = as_bdd(y.node)
        n1 = BDD.disjunction([x, y], ordering)
        # BDD | node, BDD | BDD and node | BDD must all agree.
        self.assertEqual(n1, bddx | y.node)
        self.assertEqual(n1, bddx | bddy)
        self.assertEqual(n1, x.node | bddy)

    def testAnd(self):
        bddx = as_bdd(x.node)
        bddy = as_bdd(y.node)
        n1 = BDD.conjunction([x, y], ordering)
        # BDD & node, BDD & BDD and node & BDD must all agree.
        self.assertEqual(n1, bddx & y.node)
        self.assertEqual(n1, bddx & bddy)
        self.assertEqual(n1, x.node & bddy)

    def testComplexApply(self):
        bddx = as_bdd(x.node)
        bddy = as_bdd(y.node)
        xnor = (bddx | ~bddy) & (bddy | ~bddx)
        xor = (bddx & ~bddy) | (bddy & ~bddx)
        self.assertEqual(xnor.evaluate({x: True, y: True}).root, T)
        self.assertEqual(xnor.evaluate({x: True, y: False}).root, F)
        self.assertEqual(xnor.evaluate({x: False, y: True}).root, F)
        self.assertEqual(xnor.evaluate({x: False, y: False}).root, T)
        self.assertEqual(xor.evaluate({x: True, y: True}).root, F)
        self.assertEqual(xor.evaluate({x: True, y: False}).root, T)
        self.assertEqual(xor.evaluate({x: False, y: True}).root, T)
        self.assertEqual(xor.evaluate({x: False, y: False}).root, F)
        self.assertEqual(xor, ~xnor)
        self.assertEqual((xor | xnor).root, T)
        self.assertEqual((xor & xnor).root, F)

    def testCountTrue(self):
        self.assertEqual(as_bdd(Node.T).countTrue(), 1 << len(ordering))
        self.assertEqual(as_bdd(Node.F).countTrue(), 0)
        bdd1 = BDD.conjunction([x, y, z, w, u], ordering)
        self.assertEqual(bdd1.countTrue(), 1)
        bdd2 = BDD.disjunction([x, y, z, w, u], ordering)
        self.assertEqual(bdd2.countTrue(), (1 << len(ordering)) - 1)

    def testAndAll(self):
        bdd1 = BDD.conjunction([x, y, z, w], ordering)
        bdd2 = BDD.andAll((as_bdd(v.node) for v in [x, y, z, w]), ordering)
        self.assertEqual(bdd1, bdd2)

    def testOrAll(self):
        bdd1 = BDD.disjunction([x, y, z, w], ordering)
        bdd2 = BDD.orAll((as_bdd(v.node) for v in [x, y, z, w]), ordering)
        self.assertEqual(bdd1, bdd2)

    def testForce(self):
        bdd1 = (as_bdd(x.node) & as_bdd(w.node)) | as_bdd(u.node)
        bdd2 = (as_bdd(y.node) | as_bdd(z.node))
        bdd3 = (as_bdd(y.node) & as_bdd(z.node))
        # Named 'forced' so the imported bdd module is not shadowed.
        forced = bdd1.force({w: bdd2, u: bdd3})
        self.assertEqual(forced.countTrue(),
                         ((as_bdd(x.node) & bdd2) | bdd3).countTrue() >> 2)

    def testReduce(self):
        # Two distinct but equal z nodes under a redundant x test.
        z1 = bdd.Node(z, Node.T, Node.F)
        z2 = bdd.Node(z, Node.T, Node.F)
        redundant = bdd.Node(x, z1, z2)
        self.assertEqual(redundant.countNodes(), 5)
        reduced = as_bdd(redundant).reduce()
        self.assertEqual(reduced.root, z.node)
        self.assertEqual(reduced.root.countNodes(), 3)
# Run the test suite when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment