Created
December 16, 2021 17:04
-
-
Save unbibium/16d0ca712266ef3abd35357cb2fe0768 to your computer and use it in GitHub Desktop.
AoC 2021 day 16: BITS interpreter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import sys, os,math | |
from collections import deque | |
class Bitstream:
    """Read individual bits, most significant first, from a hex string.

    Attributes:
        chars: deque of hex digits that have not been expanded yet.
        bits: deque of 0/1 values left over from the current hex digit.
        length: number of bits that may still be read before ``eof`` is
            forced; ``math.inf`` means "no limit".
        eof: True once the stream holds no more meaningful data, i.e. the
            bit budget is used up or only zero bits (padding) remain.

    Once ``eof`` is set every further read yields 0 and consumes nothing,
    so reading past the end behaves like reading zero padding.
    """

    _HEX_DIGITS = "0123456789ABCDEF"

    def __init__(self, hexstring, length=math.inf):
        """Create a stream over *hexstring*.

        Args:
            hexstring: iterable of hexadecimal digit characters.
            length: maximum number of bits to serve; it only controls
                when the ``eof`` flag is set.
        """
        self.chars = deque(hexstring)
        self.bits = deque()
        self.length = length
        # An empty or zero-length stream is exhausted from the start, so
        # ``while not stream.eof`` loops never try to read from it.
        self.eof = length <= 0 or not self.chars

    def read_bit(self):
        """Return the next bit (0 or 1), or 0 once ``eof`` is set.

        Raises:
            ValueError: if the next character is not a hex digit.
        """
        if self.eof:
            return 0
        if not self.bits:
            # Not at EOF with no buffered bits means a digit is available.
            nibble = int(self.chars.popleft(), 16)
            self.bits.extend((nibble >> shift) & 1 for shift in (3, 2, 1, 0))
        bit = self.bits.popleft()
        self.length -= 1
        # Flag EOF eagerly, right after the last meaningful bit, so callers
        # polling ``eof`` between packets see it without an extra read.
        if self.length <= 0 or not (self.chars or any(self.bits)):
            self.eof = True
        return bit

    def read_hex(self, bitcount):
        """Read *bitcount* bits and return them as upper-case hex digits.

        A trailing group of fewer than four bits is left-aligned within
        its digit, i.e. padded with zero bits on the right.
        """
        whole, part = divmod(bitcount, 4)
        digits = [self._HEX_DIGITS[self.read_int(4)] for _ in range(whole)]
        if part:
            digits.append(self._HEX_DIGITS[self.read_int(part) << (4 - part)])
        return "".join(digits)

    def read_int(self, bitcount):
        """Read *bitcount* bits and return them as an unsigned integer."""
        result = 0
        for _ in range(bitcount):
            result = result * 2 + self.read_bit()
        return result
class Packet:
    """A decoded BITS packet: either a literal value or an operator.

    Attributes set by the constructor:
        source: the bit stream the packet was read from.
        version: 3-bit packet version.
        type_id: 3-bit packet type; 4 marks a literal, any other value is
            an operator and indexes into ``funcs`` / ``fmts``.
        literal_int: decoded value (literal packets only).
        length_type_id: 0 when the sub-packets are delimited by a bit
            count, 1 when delimited by a packet count (operators only).
        subpackets: list of child packets, or None for a literal.
        version_sum: this packet's version plus that of every descendant.
    """

    def __init__(self, source):
        """Parse one packet, including all of its sub-packets.

        Args:
            source: a hex string, or a stream object providing
                ``read_bit``, ``read_int``, ``read_hex`` and ``eof``.
        """
        if isinstance(source, str):
            source = Bitstream(source)
        self.source = source
        self.version = source.read_int(3)
        self.type_id = source.read_int(3)
        if self.type_id == 4:  # literal
            # Groups of five bits: a "more groups follow" flag, then four
            # value bits, most significant group first.
            more, self.literal_int = 1, 0
            while more:
                more = source.read_bit()
                self.literal_int = self.literal_int * 16 + source.read_int(4)
            self.version_sum = self.version
            self.subpackets = None
        else:
            self.subpackets = []
            self.length_type_id = source.read_int(1)
            if self.length_type_id == 0:
                bits_in_subpacket = source.read_int(15)
                # Very inefficient converting back and forth, but it lets a
                # length-limited child stream delimit the sub-packets.
                hexsource = source.read_hex(bits_in_subpacket)
                # A zero-length payload holds no sub-packets; do not try
                # to parse one out of an empty stream.
                if bits_in_subpacket:
                    stream = Bitstream(hexsource, bits_in_subpacket)
                    while not stream.eof:
                        self.subpackets.append(Packet(stream))
            else:
                subpacket_count = source.read_int(11)
                self.subpackets = [
                    Packet(source) for _ in range(subpacket_count)
                ]
            self.version_sum = self.version + sum(
                sub.version_sum for sub in self.subpackets
            )

    def __getitem__(self, i):
        """Return the *i*-th sub-packet."""
        return self.subpackets[i]

    def __iter__(self):
        """Iterate over the sub-packets (nothing for a literal)."""
        return iter(self.subpackets or ())

    # Evaluation functions indexed by type_id; each receives an iterable
    # of sub-packet values. Comparisons are wrapped in int() so calc()
    # always yields 1 or 0 rather than True or False.
    funcs = [
        sum,
        math.prod,
        min,
        max,
        int,  # unused: type 4 is handled as a literal
        lambda a: int(int.__gt__(*a)),
        lambda a: int(int.__lt__(*a)),
        lambda a: int(int.__eq__(*a)),
    ]

    # Formatters indexed by type_id; each receives an iterable of the
    # sub-packets' string forms.
    fmts = [
        lambda a: "(" + ("+".join(a)) + ")",
        lambda a: "(" + ("*".join(a)) + ")",
        lambda a: "min([" + (",".join(a)) + "])",
        lambda a: "max([" + (",".join(a)) + "])",
        str,
        lambda a: "(%s > %s)" % tuple(a),
        lambda a: "(%s < %s)" % tuple(a),
        lambda a: "(%s == %s)" % tuple(a),
    ]

    def __str__(self):
        """Render the packet tree as an expression string."""
        if self.type_id == 4:
            return str(self.literal_int)
        fmt = Packet.fmts[self.type_id]
        return fmt(str(packet) for packet in self.subpackets)

    def calc(self):
        """Evaluate the packet tree and return its integer value."""
        if self.type_id == 4:
            return self.literal_int
        func = Packet.funcs[self.type_id]
        return func(packet.calc() for packet in self.subpackets)
def part1(lines):
    """Return the sum of all packet versions in the transmission.

    Args:
        lines: input lines; the hex transmission is on the first one.

    The transmission holds exactly one outermost packet, so only one
    packet is parsed. Anything after it is padding and is ignored rather
    than being parsed as further packets.
    """
    return Packet(lines[0].rstrip()).version_sum
def part2(lines):
    """Evaluate the transmission's outermost packet and return its value.

    Args:
        lines: input lines; the hex transmission is on the first one.

    Also prints the packet as an expression and, for diagnostics, up to
    32 bits of whatever follows the packet in the stream.
    """
    bs = Bitstream(lines[0].rstrip())
    p = Packet(bs)
    print(str(p))
    result = p.calc()
    # The leftover dump is purely diagnostic: running off the end of the
    # stream here must not discard an answer that is already computed.
    try:
        leftover = bs.read_hex(32)
    except EOFError:
        leftover = ""
    print("leftover:", leftover)
    return result
if __name__ == "__main__":
    # The input file path must be given as the first command-line argument.
    args = sys.argv
    if not len(args) >= 2:
        print("Usage", args[0], "filename")
        sys.exit(1)
    with open(args[1]) as handle:
        lines = list(handle)
    # Both parts work from the same list of input lines.
    print("part1", part1(lines))
    print("part2", part2(lines))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import unittest | |
from bits import * | |
class BitstreamTest(unittest.TestCase):
    # Low-level reads from a Bitstream built over the hex string D2FE28.

    def test_odd_stuff(self):
        # Two consecutive 3-bit reads, which do not line up with hex digits.
        stream = Bitstream("D2FE28")
        first = stream.read_int(3)
        self.assertEqual(first, 6)
        second = stream.read_int(3)
        self.assertEqual(second, 4)

    def test_popstr(self):
        # 21 bits is five whole digits plus one extra bit; 20 bits is
        # exactly five digits. Each case reads from a fresh stream.
        for bitcount, expected in ((21, "D2FE28"), (20, "D2FE2")):
            stream = Bitstream("D2FE28")
            self.assertEqual(stream.read_hex(bitcount), expected)
class Part1Test(unittest.TestCase):
    # Packet header, literal and nesting checks on fixed hex transmissions.

    def test_foo(self):
        # A single literal packet.
        packet = Packet("D2FE28")
        for attribute, expected in (
            ("version", 6),
            ("type_id", 4),
            ("literal_int", 2021),
        ):
            self.assertEqual(getattr(packet, attribute), expected)

    def test_example_two(self):
        # Operator whose two children are literals.
        packet = Packet("38006F45291200")
        self.assertEqual(packet.version, 1)
        self.assertEqual(packet.type_id, 6)
        for index, value in enumerate((10, 20)):
            self.assertEqual(packet[index].type_id, 4)
            self.assertEqual(packet[index].literal_int, value)

    def test_example_three(self):
        # Operator whose three children are literals.
        packet = Packet("EE00D40C823060")
        self.assertEqual(packet.version, 7)
        self.assertEqual(packet.type_id, 3)
        self.assertEqual(len(packet.subpackets), 3)
        for index, value in enumerate((1, 2, 3)):
            self.assertEqual(packet[index].type_id, 4)
            self.assertEqual(packet[index].literal_int, value)
            self.assertEqual(packet[index].calc(), value)

    def test_example_four_a(self):
        # Walk down the chain of first children, checking each version.
        node = Packet("8A004A801A8002F478")
        for expected_version in (4, 1, 5):
            self.assertEqual(node.version, expected_version)
            node = node[0]
        self.assertEqual(node.version, 6)

    def test_version_sums(self):
        cases = (
            ("8A004A801A8002F478", 16),
            ("620080001611562C8802118E34", 12),
            ("C0015000016115A2E0802F182340", 23),
            ("A0016C880162017C3686B18A3D4780", 31),
        )
        for hexstring, total in cases:
            self.assertEqual(Packet(hexstring).version_sum, total)
class TestPart2(unittest.TestCase):
    # Evaluation checks: each test decodes one transmission and compares
    # the value returned by Packet.calc().

    def assertCalcEquals(self, source, expected):
        # The hex source doubles as the failure message so that a failing
        # case identifies its own input.
        actual = Packet(source).calc()
        self.assertEqual(actual, expected, source)

    def test_list(self):
        # Type 0: sum.
        self.assertCalcEquals("C200B40A82", 3)

    def test_product(self):
        # Type 1: product.
        self.assertCalcEquals("04005AC33890", 54)

    def test_min(self):
        # Type 2: minimum.
        self.assertCalcEquals("880086C3E88112", 7)

    def test_max(self):
        # Type 3: maximum.
        self.assertCalcEquals("CE00C43D881120", 9)

    def test_lt(self):
        # Type 6: less-than.
        self.assertCalcEquals("D8005AC2A8F0", 1)

    def test_gt(self):
        # Type 5: greater-than.
        self.assertCalcEquals("F600BC2D8F", 0)

    def test_eq(self):
        # Type 7: equality.
        self.assertCalcEquals("9C005AC2F8F0", 0)

    def test_sums_eq(self):
        # Type 7: equality, applied to nested operator packets.
        self.assertCalcEquals("9C0141080250320F1802104A08", 1)
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment