Skip to content

Instantly share code, notes, and snippets.

@llandsmeer
Created December 9, 2021 00:56
Show Gist options
  • Save llandsmeer/6b38036a1decdc04eea008d552d47bfd to your computer and use it in GitHub Desktop.
Save llandsmeer/6b38036a1decdc04eea008d552d47bfd to your computer and use it in GitHub Desktop.
import itertools
import collections
from z3 import Int, And, Or, solve, PbEq, Solver, Implies
# Segment indices of a seven-segment display, in the puzzle's lettering:
#
#      AAAA
#     B    C
#     B    C
#      DDDD
#     E    F
#     E    F
#      GGGG
A, B, C, D, E, F, G = range(7)

# Lit segments for each digit 0-9, written with the puzzle's lowercase letters
# and converted to the integer segment indices defined above.
DIGITS = [
    {'abcdefg'.index(letter) for letter in lit}
    for lit in (
        'abcefg',   # 0
        'cf',       # 1
        'acdeg',    # 2
        'acdfg',    # 3
        'bcdf',     # 4
        'abdfg',    # 5
        'abdefg',   # 6
        'acf',      # 7
        'abcdefg',  # 8
        'abcdfg',   # 9
    )
]

# Every segment index; also equal to the lit set of digit 8.
ALL_SEGMENTS = set(range(7))
def exactly_one_must_be_true(equalities):
    """Encode the constraint that exactly one of a set of booleans is true.

    A lone candidate is returned as-is (it simply has to hold). Otherwise
    the candidates are combined into a z3 pseudo-boolean equality in which
    each term has weight 1 and the weights must sum to exactly 1.
    """
    if len(equalities) != 1:
        weighted_terms = [(term, 1) for term in equalities]
        return PbEq(weighted_terms, 1)
    return equalities[0]
def solve_signal_patterns(signal_patterns):
    """Decode one display's scrambled signal patterns with z3.

    Two families of z3 integer variables are created:

    * ``Int(pattern)``: for each (sorted) pattern, the digit 0-9 it shows;
    * ``Int(letter)``: for each wire letter, the segment index (A-G) it drives.

    Args:
        signal_patterns: iterable of wire-letter strings such as ``'acedgfb'``.
            Each is sorted before use, so keys do not depend on letter order.

    Returns:
        dict mapping the name of every variable in the model (sorted
        patterns and single wire letters) to its integer value.

    Raises:
        ValueError: if the solver does not report ``sat`` (no consistent
            wiring exists, e.g. because the input is malformed).
    """
    from z3 import sat  # status constant used only by this function

    solver = Solver()
    # Candidate equalities grouped two ways in one dict: by digit (int key)
    # so that each digit is shown by exactly one pattern, and by pattern
    # (str key) so that each pattern shows exactly one digit.
    all_10_digit_clauses = collections.defaultdict(list)
    for pattern in signal_patterns:
        pattern = ''.join(sorted(pattern))
        for i, digit in enumerate(DIGITS):
            # A pattern can only show a digit with the same number of segments.
            if len(digit) != len(pattern):
                continue
            eq = Int(pattern) == i
            all_10_digit_clauses[i].append(eq)
            all_10_digit_clauses[pattern].append(eq)
            # Add constraints related to the segment wirings.
            if len(digit) >= 7:
                # Enumerating all 7! wirings is slow. Instead, forbid every
                # wire of the pattern from driving a segment the digit does
                # not light. (For the 7-segment digit 8 this set is empty.)
                for segment in ALL_SEGMENTS - digit:
                    for p in pattern:
                        # Must constrain the z3 variable, not the letter itself:
                        # a plain ``p != segment`` is always Python True.
                        solver.add(Implies(eq, Int(p) != segment))
            else:
                # If the pattern shows this digit, its wires map onto the
                # digit's segments via exactly one permutation.
                digit_permutations_clauses = [
                    And(*[Int(p) == d for p, d in zip(pattern, digit_perm)])
                    for digit_perm in itertools.permutations(digit)
                ]
                solver.add(
                    Implies(eq, exactly_one_must_be_true(digit_permutations_clauses))
                )
    for clause in all_10_digit_clauses.values():
        solver.add(exactly_one_must_be_true(clause))
    status = solver.check()
    if status != sat:
        raise ValueError(
            f'no consistent wiring for signal patterns (solver returned {status})'
        )
    model = solver.model()
    return {str(k): model[k].as_long() for k in model}
def solve_line(line):
    """Return the four-digit output value encoded on one puzzle input line.

    *line* looks like ``'<ten signal patterns> | <four output patterns>'``.
    Output patterns are sorted before lookup so that letter order does not
    matter. The line (minus trailing whitespace) is printed as progress output.
    """
    halves = line.split(' | ')
    signal_patterns, four_digits = halves[0].split(), halves[1].split()
    decode = solve_signal_patterns(signal_patterns)
    shown = [decode[''.join(sorted(word))] for word in four_digits]
    print(line.rstrip())
    return int(''.join(str(value) for value in shown))
from multiprocessing import Pool

if __name__ == '__main__':
    # The guard is required for multiprocessing: with the 'spawn' start
    # method (Windows, and macOS from Python 3.8) every worker re-imports this
    # module, and unguarded top-level code would re-read the input and try to
    # start a new Pool inside each worker.
    with open('./input/8.txt', encoding='utf-8') as input_file:
        # Skip blank lines (e.g. a trailing empty line) so that solve_line
        # only receives "patterns | outputs" records.
        lines = [line for line in input_file if line.strip()]
    with Pool(12) as pool:
        outputs = pool.map(solve_line, lines)
    print(sum(outputs))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment