Skip to content

Instantly share code, notes, and snippets.

@r3domfox
Created December 16, 2021 17:44
Show Gist options
  • Save r3domfox/630599498e1e1128818e7b31a7522a0a to your computer and use it in GitHub Desktop.
Save r3domfox/630599498e1e1128818e7b31a7522a0a to your computer and use it in GitHub Desktop.
AOC2021 Day 16 (Python)
from functools import reduce
# Packet type ID marking a literal-value packet; every other type ID is an operator.
LITERAL = 0b100
# Length type ID meaning the next 15 bits give the total bit length of the sub-packets
# (the other value, 1, means the next 11 bits give the number of sub-packets).
TOTAL_LENGTH = 0b0
def hex_convert(c):
    """Yield the low four bits of hex digit *c* as ints (0 or 1), most significant first.

    ``int(c, 16)`` runs when the generator is first advanced, so an invalid
    digit raises ValueError at that point rather than at call time.
    """
    value = int(c, 16)
    for shift in (3, 2, 1, 0):
        yield (value >> shift) & 1
def to_bitstream(hex_string):
    """Expand *hex_string* into a flat list of 0/1 ints, four per hex digit, MSB first."""
    bits = []
    for hex_digit in hex_string:
        bits.extend(hex_convert(hex_digit))
    return bits
class PacketParser(object):
    """Recursive-descent parser over a sequence of bits (ints 0/1).

    Fields are consumed left to right and ``pos`` tracks the next unread bit.
    ``take_packet`` returns ``(version, type_id, payload)``, where ``payload`` is
    the literal's value for literal packets and a list of sub-packet tuples for
    operator packets.
    """

    def __init__(self, bit_stream):
        self.bit_stream = bit_stream
        # Index of the next unread bit in bit_stream.
        self.pos = 0

    def take_bits(self, bit_count):
        """Consume and return the next *bit_count* bits as a slice of the stream.

        Raises:
            ValueError: if fewer than *bit_count* bits remain. ``pos`` is left
                unchanged, so truncated input fails loudly instead of producing
                a silently wrong field value from a short slice.
        """
        end = self.pos + bit_count
        if end > len(self.bit_stream):
            raise ValueError(
                "need {} bits at position {}, but only {} remain".format(
                    bit_count, self.pos, len(self.bit_stream) - self.pos
                )
            )
        bits = self.bit_stream[self.pos:end]
        self.pos = end
        return bits

    def take_number(self, bit_count):
        """Consume *bit_count* bits and return them as an unsigned int, MSB first."""
        number = 0
        for bit in self.take_bits(bit_count):
            number = (number << 1) + bit
        return number

    def take_packet_version(self):
        """Consume the 3-bit packet version."""
        return self.take_number(3)

    def take_packet_type(self):
        """Consume the 3-bit packet type ID."""
        return self.take_number(3)

    def take_literal_digit(self):
        """Consume one 5-bit literal group.

        Returns:
            ``(has_more, digit)``: whether the group's leading bit is 1 (more
            groups follow) and the 4-bit value carried by the group.
        """
        has_more = self.take_bits(1)[0] == 1
        digit = self.take_number(4)
        return has_more, digit

    def take_literal(self):
        """Consume literal groups until one has a 0 leading bit; return the value.

        Each group contributes 4 bits, most significant group first. Nothing is
        consumed after the final group, so a literal sub-packet is immediately
        followed by whatever comes next in the stream.
        """
        literal = 0
        has_more = True
        while has_more:
            has_more, digit = self.take_literal_digit()
            literal = (literal << 4) + digit
        return literal

    def take_length_type_id(self):
        """Consume the 1-bit length type ID of an operator packet."""
        return self.take_number(1)

    def take_subpacket_length(self):
        """Consume the 15-bit total bit length of an operator's sub-packets."""
        return self.take_number(15)

    def take_subpacket_count(self):
        """Consume the 11-bit number of an operator's immediate sub-packets."""
        return self.take_number(11)

    def take_packet(self):
        """Consume one packet (including any nested sub-packets) and return it."""
        packet_version = self.take_packet_version()
        packet_type = self.take_packet_type()
        if packet_type == LITERAL:
            return packet_version, packet_type, self.take_literal()
        length_type = self.take_length_type_id()
        if length_type == TOTAL_LENGTH:
            subpacket_length = self.take_subpacket_length()
            # Parse exactly subpacket_length bits with an independent parser.
            sub_parser = PacketParser(self.take_bits(subpacket_length))
            return packet_version, packet_type, sub_parser.take_all_packets()
        subpacket_count = self.take_subpacket_count()
        return packet_version, packet_type, [self.take_packet() for _ in range(subpacket_count)]

    def take_all_packets(self):
        """Parse packets until the stream is exhausted and return them as a list.

        Every remaining bit must belong to a packet. Trailing padding is not
        skipped, so this suits an exact sub-packet region rather than a whole
        padded transmission.
        """
        packets = []
        while self.pos < len(self.bit_stream):
            packets.append(self.take_packet())
        return packets
def sum_versions(packet_info):
    """Return the sum of the versions of *packet_info* and every nested sub-packet.

    *packet_info* is a ``(version, type_id, payload)`` tuple. For non-literal
    packets, ``payload`` is an iterable of such tuples.
    """
    total = 0
    pending = [packet_info]
    while pending:
        version, type_id, payload = pending.pop()
        total += version
        if type_id != LITERAL:
            pending.extend(payload)
    return total
def parse_hexstring(hexstring):
    """Decode *hexstring* to bits and parse the single outermost packet.

    Any bits after that packet (e.g. trailing padding) are left unread.
    """
    return PacketParser(to_bitstream(hexstring)).take_packet()
# Operator type ID -> function applied to the list of evaluated sub-packet values.
# The comparison operators (5, 6, 7) inspect only the first two values and return 1 or 0.
operators = {
    0: sum,
    1: lambda values: reduce(lambda acc, value: acc * value, values),
    2: min,
    3: max,
    5: lambda values: int(values[0] > values[1]),
    6: lambda values: int(values[0] < values[1]),
    7: lambda values: int(values[0] == values[1]),
}
def interpret(packet):
    """Recursively evaluate *packet* and return its value.

    A literal packet evaluates to its payload. An operator packet first
    evaluates all of its sub-packets, then applies ``operators[type_id]`` to
    the resulting list. An unregistered type ID raises KeyError.
    """
    _, type_id, payload = packet
    if type_id == LITERAL:
        return payload
    sub_values = list(map(interpret, payload))
    return operators[type_id](sub_values)
def test_sum_versions():
    """Check packet parsing and version sums against the worked examples below."""
    parse_cases = [
        ('D2FE28', (6, 4, 2021)),
        ('38006F45291200', (1, 6, [(6, 4, 10), (2, 4, 20)])),
        ('EE00D40C823060', (7, 3, [(2, 4, 1), (4, 4, 2), (1, 4, 3)])),
    ]
    for hex_input, expected_packet in parse_cases:
        assert parse_hexstring(hex_input) == expected_packet

    version_sum_cases = [
        ('8A004A801A8002F478', 16),
        ('620080001611562C8802118E34', 12),
        ('C0015000016115A2E0802F182340', 23),
        ('A0016C880162017C3686B18A3D4780', 31),
    ]
    for hex_input, expected_sum in version_sum_cases:
        assert sum_versions(parse_hexstring(hex_input)) == expected_sum
# Solve the puzzle input: print a blank separator line, then the sum of all
# packet versions, then the evaluated value of the outermost packet.
print()
with open("puzzle_inputs/day16.txt") as puzzle_file:
    first_line = next(puzzle_file)
packet_tree = parse_hexstring(first_line.strip())
print(sum_versions(packet_tree))
print(interpret(packet_tree))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment