Last active
September 17, 2021 19:02
-
-
Save abadams/9eaf42972f17d8dc6e3f4ad28d18cd98 to your computer and use it in GitHub Desktop.
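z3py query to solve for an averaging tree: a sequence of two-input rounding averages whose final output equals the rounded kernel-weighted average of the inputs.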
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import z3 | |
# Filter taps to reproduce and the number of pairwise-average ops the tree may use.
# A smaller configuration to try: kernel = [1, 2, 1] with ops = 2.
kernel = [1, 4, 6, 4, 1]
ops = 10
kernel_sum = sum(kernel)
# Smallest bit width with 2 ** bits >= kernel_sum.
bits = max(kernel_sum - 1, 0).bit_length()
# The reference result divides by kernel_sum by keeping the high `bits` bits
# of a 2*bits-wide sum, i.e. it really divides by 2 ** bits. That is only
# exact when kernel_sum is itself a power of two, so reject anything else
# instead of silently building a wrong query.
if kernel_sum != 2 ** bits:
    raise ValueError(
        "kernel taps must sum to a power of two, got %d" % kernel_sum)
print("Using %d bits" % bits)
# One symbolic bits-wide input per kernel tap; these are universally
# quantified when the final constraint is built.
values = [z3.BitVec('v%d' % i, bits) for i in range(len(kernel))]
indicators = []  # indicators[i][j]: op i reads values[j]
round_up = []    # round_up[i]: op i adds 1 before halving
constraints = []
# Each op averages exactly two of the values available so far (the inputs
# plus the outputs of earlier ops). There is one Boolean indicator per
# candidate input, and exactly two of them must be true, so the sum of the
# two selected bits-wide values needs only one extra bit. One more Boolean
# per op picks the rounding mode (+1 before halving, or not).
for i in range(ops):
    # bits + 1 wide so a + b + 1 cannot wrap. Named `acc` (not `sum`) so the
    # builtin sum() is not shadowed for the rest of the script.
    zero = z3.BitVecVal(0, bits + 1)
    acc = zero
    inds = []
    # All previous values are candidate inputs.
    for j, v in enumerate(values):
        # Indicator: this op uses values[j] as one of its two inputs.
        d = z3.Bool('b_%d_%d' % (i, j))
        inds.append(d)
        acc += z3.If(d, z3.ZeroExt(1, v), zero)
    indicators.append(inds)
    # Indicator: this op rounds up.
    r = z3.Bool('r_%d' % i)
    round_up.append(r)
    acc = z3.If(r, acc + z3.BitVecVal(1, bits + 1), acc)
    # Drop the low bit: (a + b + r) >> 1, back to bits wide.
    avg = z3.Extract(bits, 1, acc)
    # Exactly two of the input indicators are true.
    constraints.append(z3.AtMost(*inds, 2))
    constraints.append(z3.AtLeast(*inds, 2))
    values.append(avg)
# Reference result: the kernel-weighted sum of the inputs, computed 2*bits wide
# (given 2 ** bits >= kernel_sum this cannot overflow), then divided by
# 2 ** bits by keeping the high half. That equals division by kernel_sum only
# when kernel_sum == 2 ** bits.
wide = bits * 2
correct = z3.BitVecVal(0, wide)
for i, k in enumerate(kernel):
    correct += z3.BitVecVal(k, wide) * z3.ZeroExt(bits, values[i])
# Adding half the divisor before truncating rounds to nearest: kernel_sum // 2
# sends exact ties up, kernel_sum // 2 - 1 sends them down. Integer division
# keeps the constants ints instead of relying on z3 truncating floats.
correct_ties_up = correct + z3.BitVecVal(kernel_sum // 2, wide)
correct_ties_down = correct + z3.BitVecVal(kernel_sum // 2 - 1, wide)
correct_ties_up = z3.Extract(wide - 1, bits, correct_ties_up)
correct_ties_down = z3.Extract(wide - 1, bits, correct_ties_down)
# For now don't worry about bias: accept a tree that rounds ties either way.
c = z3.Or(values[-1] == correct_ties_up, values[-1] == correct_ties_down)
# The tree's indicator Booleans stay free (existential); the result must be
# correct for every assignment of the input values.
constraints.append(z3.ForAll(values[:len(kernel)], c))
s = z3.Solver()
s.add(*constraints)
print(s.sexpr())
result = s.check()
print(result)
if result == z3.unsat:
    # Constraints were added with add(), not assert_and_track(), and check()
    # got no assumptions, so this core is expected to be empty.
    print(s.unsat_core())
    raise SystemExit("No averaging tree with %d ops exists for this kernel" % ops)
if result != z3.sat:
    # check() returned unknown: no model is available to print.
    raise SystemExit("Solver returned unknown: %s" % s.reason_unknown())
model = s.model()
print(model)
# Print one line per op: output index, its two input indices, rounding flag.
for op, (inds, r) in enumerate(zip(indicators, round_up)):
    # model_completion=True gives every variable a concrete value, even ones
    # the solver left unconstrained, so each indicator is a definite True/False.
    args = [j for j, ind in enumerate(inds)
            if z3.is_true(model.eval(ind, model_completion=True))]
    # Evaluate the rounding flag in the model; the bare symbolic constant
    # round_up[op] is never is_true(), which made this always print 0.
    rounds_up = z3.is_true(model.eval(r, model_completion=True))
    print('v%d = avg(%d, %d, %d)' % (len(kernel) + op, args[0], args[1], rounds_up))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment