Forked from dougallj/gist:9211fd24c3759f7f340dede28929c659
Last active January 13, 2020 07:07
-
-
Save zwegner/0202c60b9410d029b5cfc5c5643e3374 to your computer and use it in GitHub Desktop.
Ternary logic multiplication (0, 1, unknown)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
# Width of every Ternary value; Ternary truncates its masks to MASK.
N_BITS = 8
MASK = ~(-1 << N_BITS)  # the low N_BITS bits set, i.e. 2**N_BITS - 1
# Bit class, which stores values as a truth table of up to two variables
class Bit:
    """A single bit expressed as a function of at most two named variables.

    ``op`` is a 4-bit truth table whose bit ``2*a + b`` gives the result for
    inputs (a, b) -- so 0b1100 is "a" and 0b1010 is "b" -- or ``None`` for a
    fully unknown bit.  ``bit_a`` / ``bit_b`` name the variables; a variable
    the table does not depend on is dropped (stored as ``None``).
    """

    def __init__(self, op, bit_a=None, bit_b=None):
        self.op = op
        if op is not None:
            # Entries with a=0 (bits 0-1) equal those with a=1 (bits 2-3):
            # the table does not depend on a.
            if op & 0x3 == op >> 2 & 0x3:
                bit_a = None
            # Entries with b=0 (bits 0, 2) equal those with b=1 (bits 1, 3):
            # the table does not depend on b.
            if op & 0x5 == op >> 1 & 0x5:
                bit_b = None
        self.bit_a = bit_a
        self.bit_b = bit_b

    def __and__(self, other):
        return check_op(self, other, lambda a, b: a & b)

    def __or__(self, other):
        return check_op(self, other, lambda a, b: a | b)

    def __xor__(self, other):
        return check_op(self, other, lambda a, b: a ^ b)

    def __repr__(self):
        if self.op is None:
            return 'x'
        # OPS may not name every truth table (combining tables with & | ^ can
        # yield e.g. 0b0010 == ~a & b); show the raw table instead of raising
        # KeyError.
        name = OPS.get(self.op, '0b' + format(self.op, '04b'))
        return '%s(%s, %s)' % (name, self.bit_a, self.bit_b)
# Merge two bits with a given operation.  The core step applies the binary op
# (& | ^) directly to the two truth tables; the special cases before it handle
# absorbing constants, fully-unknown inputs and mismatched variables.
def check_op(x, y, fn):
    """Combine Bits ``x`` and ``y`` with the bitwise function ``fn``.

    If either operand is a constant that absorbs ``fn`` (e.g. ZERO under
    ``&``, ONE under ``|``), that constant is returned even when the other
    side is fully unknown.  Otherwise a fully unknown operand, or operands
    that use different variables in the same slot, yield ``UNKNOWN``.
    """
    for operand in (x, y):
        table = operand.op
        if table is None:
            continue
        if fn(table, ZERO) == fn(table, ONE) == table:
            # Expected only for constants; asserted as a sanity check.
            assert table in (ZERO, ONE)
            return Bit(table)
    if x.op is None or y.op is None:
        return UNKNOWN
    # Both operands must be functions of the same (or fewer) two variables.
    merged = []
    for mine, theirs in ((x.bit_a, y.bit_a), (x.bit_b, y.bit_b)):
        if mine is not None and theirs is not None and mine != theirs:
            return UNKNOWN
        merged.append(theirs if mine is None else mine)
    return Bit(fn(x.op, y.op), merged[0], merged[1])
# Full adder of three Bit objects
def add_3(a, b, c):
    """Full adder: return ``(carry, sum)`` for three input bits.

    Works on anything supporting ``^``, ``&`` and ``|`` (Bit objects, or
    plain 0/1 ints).  ``a ^ b`` is computed once and shared by both outputs.
    """
    half = a ^ b
    total = half ^ c
    return (a & b) | (half & c), total
# Ops: four-bit truth tables for binary operations on two bits.  Bit ``2*a + b``
# of a table is the result for inputs (a, b), so A (0b1100) is set exactly when
# a == 1 and B (0b1010) exactly when b == 1.  All 16 tables are named so that
# Bit.__repr__ can look up any table that & | ^ on Bits can produce.
OPS = {
    0b0000: 'ZERO',
    0b1111: 'ONE',
    0b0110: 'XOR',
    0b1001: 'NXOR',
    0b1110: 'OR',
    0b0001: 'NOR',
    0b1000: 'AND',
    0b0111: 'NAND',
    0b1100: 'A',
    0b0011: 'NA',
    0b1010: 'B',
    0b0101: 'NB',
    # The four asymmetric tables; e.g. XOR & B == 0b0010 == ~a & b.
    0b0100: 'A_AND_NB',
    0b0010: 'NA_AND_B',
    0b1101: 'A_OR_NB',
    0b1011: 'NA_OR_B',
}
# Expose each table as a module-level constant named after its op (ZERO, AND, ...).
for k, v in OPS.items():
    globals()[v] = k
# Shared fully-unknown Bit, returned by check_op when a result can't be tracked.
UNKNOWN = Bit(op=None)
class Ternary:
    """An N_BITS-wide integer whose bits are each 0, 1, or unknown.

    ``ones`` marks the bits known to be 1 and ``unknowns`` the bits whose
    value is not known; every other bit is known to be 0.  Both masks are
    truncated to MASK and must not overlap.
    """

    def __init__(self, ones, unknowns):
        self.ones = MASK & ones
        self.unknowns = MASK & unknowns
        assert not self.ones & self.unknowns, (bin(self.ones), bin(self.unknowns))

    def __add__(self, other):
        # Smallest possible sum: every unknown bit taken as 0.
        low = self.ones + other.ones
        # Largest possible sum: every unknown bit taken as 1 (ones and unknowns
        # are disjoint, so adding them is the same as OR-ing them).
        high = low + self.unknowns + other.unknowns
        # Bits that differ between the two extremes, or were unknown in an
        # input, are unknown in the result.
        unknown = (low ^ high) | self.unknowns | other.unknowns
        return Ternary(low & ~unknown, unknown)

    def __or__(self, other):
        set_bits = self.ones | other.ones
        # A bit that is 1 on either side is 1 regardless of the other side.
        maybe = (self.unknowns | other.unknowns) & ~set_bits
        return Ternary(set_bits, maybe)

    def __lshift__(self, count):
        return Ternary(ones=self.ones << count, unknowns=self.unknowns << count)

    def dumb_mul(self, other):
        """Multiply by shift-and-add; quick (linear in N_BITS) but imprecise.

        A known-1 bit of ``self`` adds ``other`` shifted into place.  An
        unknown bit adds a partial product in which every possibly-set bit
        of the shifted ``other`` is marked unknown.
        """
        total = Ternary(0, 0)
        for shift in range(N_BITS):
            if (self.ones >> shift) & 1:
                total = total + (other << shift)
            elif (self.unknowns >> shift) & 1:
                shifted = other << shift
                total = total + Ternary(0, shifted.ones | shifted.unknowns)
        return total

    def __mul__(self, other):
        """Multiply by tracking each result bit as a symbolic Bit.

        Each unknown bit k of ``other`` becomes a variable ``'b<k>'``, and each
        unknown bit i of ``self`` a variable ``'a<i>'``.  The shift-and-add is
        then run bit by bit with add_3 (at most N_BITS rows of up to N_BITS
        full-adder steps).  A result bit that cannot be expressed over one
        ``a`` and one ``b`` variable becomes unknown (see check_op).
        """
        rhs_bits = [
            Bit(B, bit_b='b%s' % k) if (other.unknowns >> k) & 1
            else Bit(ONE if (other.ones >> k) & 1 else ZERO)
            for k in range(N_BITS)
        ]
        acc = [Bit(ZERO) for _ in range(N_BITS)]
        for shift in range(N_BITS):
            if (self.ones >> shift) & 1:
                row = rhs_bits
            elif (self.unknowns >> shift) & 1:
                # Unknown multiplier bit: each partial-product bit is a<shift> & b.
                selector = Bit(A, bit_a='a%s' % shift)
                row = [selector & bit for bit in rhs_bits[:N_BITS - shift]]
            else:
                continue
            carry = Bit(ZERO)
            for pos in range(shift, N_BITS):
                # Argument order matches the adder chain: accumulator, addend, carry.
                carry, acc[pos] = add_3(acc[pos], row[pos - shift], carry)
        # Collapse the symbolic bits back into known-1 / unknown masks.
        ones = unknowns = 0
        for pos, bit in enumerate(acc):
            if bit.op == ONE:
                ones |= 1 << pos
            elif bit.op != ZERO:
                unknowns |= 1 << pos
        return Ternary(ones, unknowns)

    def __repr__(self):
        """Render most-significant bit first, with unknown bits shown as 'x'."""
        chars = []
        for pos in range(N_BITS - 1, -1, -1):
            if (self.unknowns >> pos) & 1:
                chars.append('x')
            else:
                chars.append(str((self.ones >> pos) & 1))
        return ''.join(chars)

    def iter_values(self):
        """Yield every concrete integer this Ternary could represent."""
        base = self.ones
        for extra in iter_subsets(self.unknowns):
            yield base | extra
def iter_subsets(mask):
    """Yield every subset of the bits in ``mask``, starting with 0.

    Steps with ``(subset - mask) & mask``, which for a non-negative mask
    visits the subsets in increasing numeric order.
    """
    yield 0
    subset = -mask & mask
    while subset:
        yield subset
        subset = (subset - mask) & mask
def union(a, b):
    """Return a Ternary covering every value of both ``a`` and ``b``.

    A bit stays known only when it is known and equal in both inputs;
    bits that are 1 in one and 0 in the other, or unknown in either,
    become unknown.
    """
    disagree = a.ones ^ b.ones
    return Ternary(a.ones & b.ones, disagree | a.unknowns | b.unknowns)
def slow_op(a, b, op):
    """Brute-force reference: apply ``op`` to every concrete pair of values.

    Each concrete result is joined into a Ternary with union(), giving the
    most precise Ternary the fast implementations can be checked against.
    """
    acc = Ternary(op(a.ones, b.ones), 0)
    for lhs, rhs in itertools.product(a.iter_values(), b.iter_values()):
        acc = union(acc, Ternary(op(lhs, rhs), 0))
    return acc
def iter_ternary_values(input_size=4):
    """Yield every Ternary whose known and unknown bits fit in ``input_size`` bits."""
    full = (1 << input_size) - 1
    for ones in range(full + 1):
        # Unknown bits may be any subset of the bits not already known to be 1.
        for unknowns in iter_subsets(full & ~ones):
            yield Ternary(ones, unknowns)
def test_op(f, name, input_size=4, dumb=None):
    """Exhaustively compare a fast Ternary operation against slow_op.

    ``f`` is applied both to Ternary operands (the fast path) and, through
    slow_op, to plain ints for every concrete input pair.  Results are
    compared by repr, so a less precise fast result also counts as bad.

    Args:
        f: binary function accepting either two Ternary or two int values.
        name: label used in printed output.
        input_size: bit width of the enumerated test operands.
        dumb: optional callable ``dumb(a, b)``; if given, its result is printed
            next to each mismatch (e.g. Ternary.dumb_mul for ``*``).

    Returns:
        ``(good, bad)`` counts of matching and mismatching operand pairs.
    """
    good = 0
    bad = 0
    for a, b in itertools.product(iter_ternary_values(input_size), repeat=2):
        r0 = slow_op(a, b, f)
        r1 = f(a, b)
        if repr(r0) == repr(r1):
            good += 1
            continue
        msg = '%r %s %r -> slow=%r, fast=%r' % (a, name, b, r0, r1)
        # Only show an approximation that relates to the op under test.
        if dumb is not None:
            msg += ', dumb=%r' % (dumb(a, b),)
        print(msg)
        bad += 1
    print('testing %r: %d good, %d bad' % (name, good, bad))
    return good, bad


test_op(lambda a, b: (a * b), "*", dumb=Ternary.dumb_mul)
test_op(lambda a, b: (a + b), "+")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.