This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import collections | |
from z3 import Int, And, Or, solve, PbEq, Solver, Implies | |
# Seven-segment display: segment indices 0..6 are named A..G.
A, B, C, D, E, F, G = range(7)

# DIGITS[n] is the set of segment indices lit to display digit n. Each entry is
# written as the letters of its lit segments; "ABCDEFG".index maps a letter to
# its index, matching the A..G == 0..6 assignment above.
DIGITS = [
    {"ABCDEFG".index(letter) for letter in lit}
    for lit in (
        "ABCEFG",   # 0
        "CF",       # 1
        "ACDEG",    # 2
        "ACDFG",    # 3
        "BCDF",     # 4
        "ABDFG",    # 5
        "ABDEFG",   # 6
        "ACF",      # 7
        "ABCDEFG",  # 8
        "ABCDFG",   # 9
    )
]

# Every segment index on the display.
ALL_SEGMENTS = set(range(7))
def exactly_one_must_be_true(equalities):
    """Encode the constraint that exactly one of a set of booleans is true.

    Args:
        equalities: sequence (or any iterable) of Z3 boolean expressions;
            expected to be non-empty.

    Returns:
        The lone expression itself when only one is given (no pseudo-Boolean
        wrapper is needed), otherwise a ``PbEq`` constraint requiring that
        exactly one of the expressions holds.
    """
    # Materialise so generators are accepted and the single case can be indexed.
    equalities = list(equalities)
    if len(equalities) == 1:
        return equalities[0]
    return PbEq([(o, 1) for o in equalities], 1)
def solve_signal_patterns(signal_patterns):
    """Decode one display's scrambled signal patterns with Z3.

    Each pattern, normalised by sorting its wire letters, becomes an integer
    variable holding the digit it shows. Each wire letter becomes an integer
    variable holding the segment index (A..G == 0..6) it drives.

    Args:
        signal_patterns: iterable of wire-letter strings, e.g. ``["ab", "dab"]``.

    Returns:
        dict mapping each Z3 variable name in the model to its integer value:
        sorted pattern strings map to digits, single wire letters map to
        segment indices.

    Raises:
        ValueError: if the solver does not report a satisfying assignment.
    """
    from z3 import sat  # not included in the module-level z3 import

    solver = Solver()
    # Candidate "pattern shows digit i" equalities, grouped twice: by digit
    # (int key) and by pattern (str key). Each group later gets an exactly-one
    # constraint, so every digit is shown by exactly one pattern and every
    # pattern shows exactly one digit.
    all_10_digit_clauses = collections.defaultdict(list)
    for pattern in signal_patterns:
        pattern = ''.join(sorted(pattern))
        for i, digit in enumerate(DIGITS):
            # A pattern can only show a digit with the same number of segments.
            if len(digit) != len(pattern):
                continue
            eq = Int(pattern) == i
            all_10_digit_clauses[i].append(eq)
            all_10_digit_clauses[pattern].append(eq)
            # Constrain the wiring implied by "pattern shows digit i".
            if len(digit) >= 7:
                # Enumerating all 7! permutations is slow, so instead forbid
                # each wire from driving a segment the digit leaves unlit. (For
                # 8, the only 7-segment digit, that set is empty.)
                for segment in ALL_SEGMENTS - digit:
                    for p in pattern:
                        # Constrain the wire's Z3 variable; comparing the raw
                        # letter to an int would always be Python True.
                        solver.add(Implies(eq, Int(p) != segment))
            else:
                # The pattern's wires map onto the digit's segments in exactly
                # one order.
                digit_permutations_clauses = [
                    And(*[Int(p) == d for p, d in zip(pattern, digit_perm)])
                    for digit_perm in itertools.permutations(digit)
                ]
                solver.add(Implies(eq, exactly_one_must_be_true(digit_permutations_clauses)))
    for clause in all_10_digit_clauses.values():
        solver.add(exactly_one_must_be_true(clause))
    result = solver.check()
    if result != sat:
        # solver.model() would otherwise fail with an opaque Z3 error.
        raise ValueError(f"signal patterns could not be decoded: solver returned {result}")
    model = solver.model()
    return {str(k): model[k].as_long() for k in model}
def solve_line(line):
    """Decode the output value of one puzzle-input line.

    The line is split on ``' | '``. Whitespace-separated patterns before the
    separator are the signal patterns used to solve the wiring; those after it
    are the output digits. Once every output digit has been decoded, the line
    (trailing whitespace removed) is printed. The decoded digits are
    concatenated and returned as an int.
    """
    halves = line.split(' | ')
    patterns, outputs = halves[0].split(), halves[1].split()
    lookup = solve_signal_patterns(patterns)
    # Output words are looked up in the same sorted-letter form used as keys.
    decoded = [lookup[''.join(sorted(word))] for word in outputs]
    print(line.rstrip())
    return int(''.join(str(value) for value in decoded))
from multiprocessing import Pool

if __name__ == "__main__":
    # The guard stops worker processes, which re-import this module under the
    # "spawn" start method, from starting their own pools recursively.
    with open('./input/8.txt') as f:
        # Skip blank lines (e.g. trailing empty lines); solve_line would fail
        # with IndexError on a line without ' | '.
        lines = [line for line in f if line.strip()]
    with Pool(12) as pool:
        outputs = pool.map(solve_line, lines)
    print(sum(outputs))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment