Skip to content

Instantly share code, notes, and snippets.

@abadams
Last active September 17, 2021 19:02
Show Gist options
  • Save abadams/9eaf42972f17d8dc6e3f4ad28d18cd98 to your computer and use it in GitHub Desktop.
Save abadams/9eaf42972f17d8dc6e3f4ad28d18cd98 to your computer and use it in GitHub Desktop.
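z3py query that solves for an averaging tree: a sequence of two-input rounding averages that reproduces a given integer weighted-average kernel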
import z3
# Kernel to realise as a tree of rounding averages. The commented-out pair is
# a smaller alternative configuration.
# kernel = [1, 2, 1]
# ops = 2
kernel = [1, 4, 6, 4, 1]
# Number of avg() operations the solver may use to build the tree.
ops = 10
kernel_sum = sum(kernel)
# Smallest bit width with 2 ** bits >= kernel_sum, i.e. ceil(log2(kernel_sum)).
bits = (kernel_sum - 1).bit_length()
# Each avg() halves its input, so a tree of them can only divide by a power of
# two. The specification divides by 2 ** bits (via Extract), which is the
# kernel's normalisation only when kernel_sum is exactly 2 ** bits. A sum of 1
# would also give zero-width bitvectors.
if kernel_sum < 2 or kernel_sum != 2 ** bits:
    raise ValueError(
        "kernel must sum to a power of two >= 2, got %d" % kernel_sum)
print("Using %d bits" % bits)
# One free bits-wide bitvector per kernel tap. Each op below appends its
# result to this same list, so later ops can consume earlier results.
values = [z3.BitVec('v%d' % i, bits) for i in range(len(kernel))]
# indicators[i][j] is the Bool saying whether op i reads values[j].
indicators = []
# round_up[i] is the Bool saying whether op i adds 1 before halving.
round_up = []
constraints = []
# Each op computes (a + b + r) >> 1. Here a and b are two distinct earlier
# values, chosen by per-candidate indicator Bools of which exactly two must be
# true, and r is a rounding bit. With only two bits-wide terms plus the
# rounding 1, the sum needs just one extra bit, which the final shift drops.
for i in range(ops):
    # Accumulate in bits + 1 so the two-term sum plus rounding cannot wrap.
    total = z3.BitVecVal(0, bits + 1)
    inds = []
    # All previous values (kernel inputs and earlier op results) are candidates.
    for j, v in enumerate(values):
        uses_j = z3.Bool('b_%d_%d' % (i, j))
        inds.append(uses_j)
        total += z3.If(uses_j, z3.ZeroExt(1, v), z3.BitVecVal(0, bits + 1))
    indicators.append(inds)
    # Rounding-mode indicator for this op.
    rounds = z3.Bool('r_%d' % i)
    round_up.append(rounds)
    total = z3.If(rounds, total + z3.BitVecVal(1, bits + 1), total)
    # Keep bits [bits..1]: this drops the low bit, i.e. divides by two.
    total = z3.Extract(bits, 1, total)
    # Exactly two of the indicator variables are true.
    constraints.append(z3.AtMost(*inds, 2))
    constraints.append(z3.AtLeast(*inds, 2))
    values.append(total)
# Reference result: the exact weighted sum of the kernel inputs, computed in
# 2 * bits. With kernel_sum <= 2 ** bits the maximum,
# kernel_sum * (2 ** bits - 1) + kernel_sum // 2, stays below 2 ** (2 * bits).
correct = z3.BitVecVal(0, bits * 2)
for i, k in enumerate(kernel):
    correct += z3.BitVecVal(k, bits * 2) * z3.ZeroExt(bits, values[i])
# Divide by 2 ** bits with both tie-breaking rules: adding half rounds ties up,
# adding half - 1 rounds them down. Integer division keeps the constant an int
# instead of handing z3 a float to truncate.
half = kernel_sum // 2
correct_ties_up = correct + z3.BitVecVal(half, bits * 2)
correct_ties_down = correct + z3.BitVecVal(half - 1, bits * 2)
# Taking the high `bits` bits is the shift right by `bits`.
correct_ties_up = z3.Extract(2 * bits - 1, bits, correct_ties_up)
correct_ties_down = z3.Extract(2 * bits - 1, bits, correct_ties_down)
# For now don't worry about bias, and just assert that the averaging tree
# (whose output is the last op's value) either rounds ties up or down.
c = z3.Or(values[-1] == correct_ties_up, values[-1] == correct_ties_down)
# The tree must be correct for every possible input, hence ForAll over inputs.
constraints.append(z3.ForAll(values[:len(kernel)], c))
s = z3.Solver()
s.add(*constraints)
# Dump the whole query in SMT-LIB form so it can be inspected or re-run.
print(s.sexpr())
result = s.check()
print(result)
# s.model() raises on unsat/unknown. An unsat core is meaningless here because
# no assumptions or tracked assertions were given, so stop with a clear message.
if result != z3.sat:
    raise SystemExit(
        "no averaging tree with %d ops found (solver returned %s)"
        % (ops, result))
model = s.model()
print(model)
# Decode the model into readable ops. Op number `op` defines value
# v<len(kernel) + op>; its arguments are the values whose indicators are true.
for op in range(ops):
    idx = len(kernel) + op
    # Evaluate the symbolic Bools in the model. model_completion=True yields a
    # concrete True/False even for variables the solver left unconstrained.
    # Calling is_true() on the bare symbol would always be False.
    args = [j for j, ind in enumerate(indicators[op])
            if z3.is_true(model.eval(ind, model_completion=True))]
    rounds_up = z3.is_true(model.eval(round_up[op], model_completion=True))
    print('v%d = avg(v%d, v%d, %d)' % (idx, args[0], args[1], rounds_up))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment