Created
December 16, 2021 17:44
-
-
Save r3domfox/630599498e1e1128818e7b31a7522a0a to your computer and use it in GitHub Desktop.
AOC2021 Day 16 (Python)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import reduce

# Packet type ID of a literal-value packet, and the length-type ID meaning
# "the next 15 bits give the total bit length of the sub-packets"
# (see PacketParser.take_packet).
LITERAL, TOTAL_LENGTH = 4, 0
def hex_convert(c):
    """Yield the four low bits of hex digit *c*, most significant first.

    Each bit is produced as the int 0 or 1. The digit is parsed with
    ``int(c, 16)`` when the generator is first advanced.
    """
    value = int(c, 16)
    for shift in range(3, -1, -1):
        yield (value >> shift) & 1
def to_bitstream(hex_string):
    """Expand *hex_string* into one flat list of bits (ints 0/1).

    Each character contributes four bits, most significant first, in the
    order the characters appear.
    """
    bits = []
    for digit in hex_string:
        bits.extend(hex_convert(digit))
    return bits
class PacketParser:
    """Sequential reader over a BITS transmission held as a sequence of bits.

    ``bit_stream`` is indexed bit by bit; each element is expected to be the
    int 0 or 1 (as produced by ``to_bitstream``). ``pos`` is the index of the
    next unread bit. Packets are returned as ``(version, type_id, payload)``
    tuples whose payload is an int for ``LITERAL`` packets and a list of such
    tuples for operator packets.
    """

    def __init__(self, bit_stream):
        self.bit_stream = bit_stream
        self.pos = 0

    def take_bits(self, bit_count):
        """Consume and return the next ``bit_count`` bits (a slice of the stream).

        Raises:
            ValueError: if fewer than ``bit_count`` bits remain. Returning a
                short slice instead would let ``take_number`` silently decode
                a wrong value from truncated or malformed input.
        """
        end = self.pos + bit_count
        if end > len(self.bit_stream):
            raise ValueError(
                f"cannot read {bit_count} bits at position {self.pos}: "
                f"stream holds only {len(self.bit_stream)} bits"
            )
        bits = self.bit_stream[self.pos:end]
        self.pos = end
        return bits

    def take_number(self, bit_count):
        """Consume ``bit_count`` bits and return them as an unsigned int, MSB first."""
        number = 0
        for bit in self.take_bits(bit_count):
            number = (number << 1) + bit
        return number

    def take_packet_version(self):
        """Consume and return the 3-bit packet version."""
        return self.take_number(3)

    def take_packet_type(self):
        """Consume and return the 3-bit packet type ID."""
        return self.take_number(3)

    def take_literal_digit(self):
        """Consume one 5-bit literal group.

        Returns:
            ``(has_more, digit)``: whether another group follows (the
            group's leading bit is 1) and the group's 4-bit value.
        """
        has_more = self.take_bits(1)[0] == 1
        digit = self.take_number(4)
        return has_more, digit

    def take_literal(self):
        """Consume a literal's 5-bit groups and return the assembled int.

        Reading stops right after the first group whose leading bit is 0.
        Nothing is skipped after that group: the next packet begins on the
        very next bit.
        """
        literal = 0
        has_more = True
        while has_more:
            has_more, digit = self.take_literal_digit()
            literal = (literal << 4) + digit
        return literal

    def take_length_type_id(self):
        """Consume and return the 1-bit length type ID of an operator packet."""
        return self.take_number(1)

    def take_subpacket_length(self):
        """Consume the 15-bit total bit length of an operator's sub-packets."""
        return self.take_number(15)

    def take_subpacket_count(self):
        """Consume the 11-bit number of an operator's sub-packets."""
        return self.take_number(11)

    def take_packet(self):
        """Consume one packet, recursively including its sub-packets.

        Returns:
            ``(version, type_id, payload)``. For ``LITERAL`` packets the
            payload is the decoded int; otherwise it is the list of
            sub-packet tuples.
        """
        packet_version = self.take_packet_version()
        packet_type = self.take_packet_type()
        if packet_type == LITERAL:
            return packet_version, packet_type, self.take_literal()
        if self.take_length_type_id() == TOTAL_LENGTH:
            # The sub-packets occupy exactly the next `subpacket_length` bits;
            # hand that slice to its own parser and read it to the end.
            subpacket_length = self.take_subpacket_length()
            sub_parser = PacketParser(self.take_bits(subpacket_length))
            return packet_version, packet_type, sub_parser.take_all_packets()
        # Otherwise a fixed number of sub-packets follows directly in this stream.
        subpacket_count = self.take_subpacket_count()
        return packet_version, packet_type, [self.take_packet() for _ in range(subpacket_count)]

    def take_all_packets(self):
        """Consume packets until the stream is exhausted and return them as a list."""
        packets = []
        while self.pos < len(self.bit_stream):
            packets.append(self.take_packet())
        return packets
def sum_versions(packet_info):
    """Return the sum of the version numbers of a packet and all its descendants.

    *packet_info* is a ``(version, type_id, payload)`` tuple as returned by
    ``PacketParser.take_packet``.
    """
    version, kind, payload = packet_info
    if kind != LITERAL:
        children_total = 0
        for child in payload:
            children_total += sum_versions(child)
        return version + children_total
    return version
def parse_hexstring(hexstring):
    """Decode *hexstring* and return its outermost packet tuple.

    Only a single packet is read, so any bits left over after it are not
    consumed.
    """
    return PacketParser(to_bitstream(hexstring)).take_packet()
# Maps an operator packet's type ID to the function that folds its already
# evaluated sub-packet values into one value. Type ID 4 (LITERAL) is absent
# because interpret() returns a literal's payload directly.
operators = {
    0: sum,
    1: lambda values: reduce(lambda acc, value: acc * value, values),
    2: min,
    3: max,
    5: lambda values: int(values[0] > values[1]),
    6: lambda values: int(values[0] < values[1]),
    7: lambda values: int(values[0] == values[1]),
}
def interpret(packet):
    """Evaluate a packet tree and return its value.

    A ``LITERAL`` packet evaluates to its stored payload. Any other packet
    first evaluates each sub-packet recursively, then applies the function
    registered for its type ID in ``operators``.
    """
    _, type_id, payload = packet
    if type_id != LITERAL:
        evaluated = list(map(interpret, payload))
        return operators[type_id](evaluated)
    return payload
def test_sum_versions():
    """Check parsing and version sums against the part-one examples."""
    # Exact packet trees for the worked parsing examples.
    assert parse_hexstring('D2FE28') == (6, 4, 2021)
    assert parse_hexstring('38006F45291200') == (1, 6, [(6, 4, 10), (2, 4, 20)])
    assert parse_hexstring('EE00D40C823060') == (7, 3, [(2, 4, 1), (4, 4, 2), (1, 4, 3)])

    expected_sums = {
        '8A004A801A8002F478': 16,
        '620080001611562C8802118E34': 12,
        'C0015000016115A2E0802F182340': 23,
        'A0016C880162017C3686B18A3D4780': 31,
    }
    for hexstring, expected in expected_sums.items():
        # The message names the failing input.
        assert sum_versions(parse_hexstring(hexstring)) == expected, hexstring


def test_interpret():
    """Check expression evaluation (interpret + operators) against the part-two examples."""
    expected_values = {
        'C200B40A82': 3,                   # sum of 1 and 2
        '04005AC33890': 54,                # product of 6 and 9
        '880086C3E88112': 7,               # minimum of 7, 8, 9
        'CE00C43D881120': 9,               # maximum of 7, 8, 9
        'D8005AC2A8F0': 1,                 # 5 < 15
        'F600BC2D8F': 0,                   # 5 > 15 is false
        '9C005AC2F8F0': 0,                 # 5 == 15 is false
        '9C0141080250320F1802104A08': 1,   # 1 + 3 == 2 * 2
    }
    for hexstring, expected in expected_values.items():
        assert interpret(parse_hexstring(hexstring)) == expected, hexstring
# Emit an empty line first (presumably to separate the answers from any
# earlier output, e.g. when run under a test runner).
print()

# Only the first line of the input file is read; surrounding whitespace is stripped.
with open("puzzle_inputs/day16.txt") as file:
    hex_string = next(file).strip()

packet_tree = parse_hexstring(hex_string)
part_one = sum_versions(packet_tree)
print(part_one)
part_two = interpret(packet_tree)
print(part_two)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment