Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created Dec 9, 2021
Embed
What would you like to do?
import collections
import itertools

from z3 import And, Implies, Int, Or, PbEq, Solver, sat, solve
# Integer identifiers (0..6) for the seven segments of the display.
A, B, C, D, E, F, G = range(7)

# DIGITS[n] is the set of segments that are lit when digit n is shown.
DIGITS = [
    {A, B, C, E, F, G},     # 0
    {C, F},                 # 1
    {A, C, D, E, G},        # 2
    {A, C, D, F, G},        # 3
    {B, C, D, F},           # 4
    {A, B, D, F, G},        # 5
    {A, B, D, E, F, G},     # 6
    {A, C, F},              # 7
    {A, B, C, D, E, F, G},  # 8
    {A, B, C, D, F, G},     # 9
]

# Every segment identifier; equal to the set of segments lit for digit 8.
ALL_SEGMENTS = set(range(7))
def exactly_one_must_be_true(equalities):
    """Build a constraint stating that exactly one of the given booleans holds.

    Args:
        equalities: list of boolean terms.

    Returns:
        The single term itself when the list has exactly one element;
        otherwise a ``PbEq`` constraint in which every term has weight 1
        and the weighted sum must equal 1.
    """
    if len(equalities) != 1:
        weighted_terms = [(term, 1) for term in equalities]
        return PbEq(weighted_terms, 1)
    # A lone term is "exactly one true" precisely when it is true itself.
    return equalities[0]
def solve_signal_patterns(signal_patterns):
    """Determine which digit each scrambled signal pattern represents.

    Two families of z3 integer variables are created:

    * one per pattern, named by the pattern's sorted letters, whose value
      is the digit (index into ``DIGITS``) the pattern displays;
    * one per wire letter, whose value is the segment (``A``..``G``) that
      the wire drives.

    Args:
        signal_patterns: iterable of strings; each string is the set of
            wire letters that are on for one displayed digit.

    Returns:
        dict mapping the name of every variable in the solver's model to
        its integer value. Sorted pattern strings map to their digit;
        wire-letter variables present in the model map to their segment.

    Raises:
        ValueError: if the solver does not report the constraints as
            satisfiable, so no model is available.
    """
    solver = Solver()
    # Keyed both by digit (int) and by pattern (str): every group collected
    # here must end up with exactly one true equality.
    all_10_digit_clauses = collections.defaultdict(list)
    for pattern in signal_patterns:
        pattern = ''.join(sorted(pattern))
        for i, digit in enumerate(DIGITS):
            # A pattern can only be a digit with the same number of segments.
            if len(digit) != len(pattern):
                continue
            eq = Int(pattern) == i
            # Add constraints on having each digit appear exactly once
            all_10_digit_clauses[i].append(eq)
            all_10_digit_clauses[pattern].append(eq)
            # Add constraints related to the segment wirings
            if len(digit) >= 7:
                # adding these permutations is slow,
                # but we can add the inverse constraints:
                # no wire in the pattern may drive a segment the digit lacks.
                for segment in ALL_SEGMENTS - digit:
                    for p in pattern:
                        # Compare the wire's z3 variable, not the letter
                        # itself (str != int is always True in Python).
                        solver.add(Implies(eq, Int(p) != segment))
            else:
                # If the pattern is this digit, its wires map onto the
                # digit's segments in exactly one of the possible orders.
                digit_permutations_clauses = [
                    And(*[Int(p) == d for p, d in zip(pattern, digit_perm)])
                    for digit_perm in itertools.permutations(digit)
                ]
                solver.add(Implies(
                    eq,
                    exactly_one_must_be_true(digit_permutations_clauses),
                ))
    # Add constraints on having each digit appear exactly once
    for clause in all_10_digit_clauses.values():
        solver.add(exactly_one_must_be_true(clause))
    # Without a satisfiable result there is no model to read back.
    if solver.check() != sat:
        raise ValueError(
            'no consistent wiring found for signal patterns: %r'
            % (signal_patterns,)
        )
    model = solver.model()
    return {str(k): model[k].as_long() for k in model}
def solve_line(line):
    """Decode the output value of a single puzzle line.

    Args:
        line: text of the form ``"<patterns> | <outputs>"`` where both
            sides are whitespace-separated groups of wire letters.

    Returns:
        The integer formed by concatenating the decoded output digits in
        order.

    Side effects:
        Prints the line (trailing whitespace removed) once it is decoded.
    """
    halves = line.split(' | ')
    signal_patterns = halves[0].split()
    four_digits = halves[1].split()
    decode = solve_signal_patterns(signal_patterns)
    # The decode table is keyed by sorted letters, so normalise each output
    # group the same way before looking it up.
    decoded = [decode[''.join(sorted(group))] for group in four_digits]
    print(line.rstrip())
    return int(''.join(str(value) for value in decoded))
from multiprocessing import Pool

# Guard the entry point: with the "spawn" start method every worker process
# re-imports this module, and an unguarded Pool at top level would make each
# worker try to start its own pool.
if __name__ == '__main__':
    # Read all puzzle lines up front and close the file promptly.
    with open('./input/8.txt') as input_file:
        lines = list(input_file)
    # Every line is an independent solver problem, so decode them in parallel.
    with Pool(12) as pool:
        outputs = pool.map(solve_line, lines)
    # The answer is the sum of all decoded output values.
    print(sum(outputs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment