Skip to content

Instantly share code, notes, and snippets.

@aledesole
Created December 22, 2020 00:31
Show Gist options
  • Save aledesole/27538ab9ced0a39af03b93bdec882a08 to your computer and use it in GitHub Desktop.
Save aledesole/27538ab9ced0a39af03b93bdec882a08 to your computer and use it in GitHub Desktop.
Advent of code 2020, Day 20, Python3.9
from sys import stdin
from functools import reduce
from itertools import product
from collections import defaultdict
# Return all 8 transformations of the image
def tr(l, img):
    """Return the 8 orientations of the square grid *img*.

    Args:
        l: Side length of the grid; only the first ``l`` rows and columns
            are read.
        img: Grid given as a list of rows (each row indexable).

    Returns:
        A list of 8 grids: ``img`` itself (same object, not a copy), its
        three successive clockwise quarter-turns, the top/bottom flip of
        the last of those, and three clockwise quarter-turns of the flip.
    """
    def rotate(grid):
        # New row j is old column j read from the bottom up (clockwise turn).
        return [[grid[l - i - 1][j] for i in range(l)] for j in range(l)]

    def flip(grid):
        # Reverse the order of the rows, copying each row.
        return [[grid[l - i - 1][j] for j in range(l)] for i in range(l)]

    views = [img]
    for _ in range(3):
        views.append(rotate(views[-1]))
    views.append(flip(views[-1]))
    for _ in range(3):
        views.append(rotate(views[-1]))
    return views
# Test each tile against every other and return the graph of matches
def analize(tiles):
    """Compare every ordered pair of distinct tiles in all orientations.

    Args:
        tiles: Mapping of tile id to a sequence of 8 orientations, each a
            grid given as a list of rows.

    Returns:
        A tuple ``(v, h)`` of ``defaultdict(set)`` keyed by the first tile's
        id. Each entry is ``(other_id, idx, other_idx)`` where ``idx`` and
        ``other_idx`` are the orientation indices of the two tiles.
        ``v`` records pairs where the last row of the first tile equals the
        first row of the other; ``h`` records pairs where the last column of
        the first equals the first column of the other. A pair that matches
        vertically is recorded in ``v`` only, never in both.
    """
    below = defaultdict(set)
    beside = defaultdict(set)
    for a_id, a_views in tiles.items():
        for b_id, b_views in tiles.items():
            if a_id == b_id:
                continue
            for a_idx, b_idx in product(range(8), repeat=2):
                a_img = a_views[a_idx]
                b_img = b_views[b_idx]
                if a_img[-1] == b_img[0]:
                    below[a_id].add((b_id, a_idx, b_idx))
                    continue
                # Only reached when the vertical test failed.
                edges_agree = all(p[-1] == q[0] for p, q in zip(a_img, b_img))
                if edges_agree:
                    beside[a_id].add((b_id, a_idx, b_idx))
    return (below, beside)
# Return correct sequence of tiles of size n*n for the puzzle
def sequence(n, rel):
    """Return an ordering of all tile ids consistent with the match graph.

    Only ``rel[0]`` is consulted, and only as an adjacency relation: the
    orientation indices stored in its entries are ignored.

    The ordering is built one tile at a time, keeping every partial ordering
    that is still valid. A tile may be appended only if it is related to the
    previously placed tile and to the tile at the "mirror" index in the
    preceding run of ``n`` tiles (runs alternate direction, so position ``p``
    sits next to index ``n*(p//n - 1) + n - p%n - 1``). Within the first run
    the second check falls back to the previous tile.

    Args:
        n: Length of one run (side of the square puzzle).
        rel: Pair whose first item maps tile id -> set of tuples whose first
            element is a related tile id.

    Returns:
        The first complete ordering found, as a list of tile ids. It starts
        with a tile that has the fewest entries in ``rel[0]`` (presumably a
        corner tile).

    Raises:
        ValueError: If ``rel[0]`` has no keys.
        IndexError: If no complete ordering exists.
    """
    tile_ids = set(rel[0].keys())
    start = min(tile_ids, key=lambda tid: len(rel[0][tid]))

    def mirror_index(pos):
        # Index of the tile next to ``pos`` in the preceding run, or -1
        # (i.e. "the previous tile") while still inside the first run.
        if pos >= n:
            return n * (pos // n - 1) + n - (pos % n) - 1
        return -1

    def related(tid, other):
        for entry in rel[0][tid]:
            if entry[0] == other:
                return True
        return False

    partial = [([start], tile_ids - {start})]
    for step in range(len(tile_ids) - 1):
        grown = []
        for placed, unused in partial:
            for tid in unused:
                if (related(tid, placed[-1])
                        and related(tid, placed[mirror_index(step + 1)])):
                    grown.append((placed + [tid], unused - {tid}))
        partial = grown
    return partial[0][0]
# Return the final picture without tile borders
def assemble(n, tiles, seq, rel):
    """Build the combined picture from the ordered tiles, dropping borders.

    ``seq`` is read as a serpentine path: each run of ``n`` tiles fills one
    column of the grid, even columns top to bottom and odd columns bottom to
    top. For every tile, the orientation is picked from the orientations
    that agree with each existing neighbour (left/right via ``rel[1]``,
    above/below via ``rel[0]``).

    Args:
        n: Side of the tile grid.
        tiles: Mapping of tile id -> list of 8 orientations (list of rows).
        seq: Tile ids in serpentine order.
        rel: Pair ``(v, h)`` of mappings tile id -> set of
            ``(other_id, idx, other_idx)`` entries.

    Returns:
        An ``8*n`` by ``8*n`` list of lists, pre-filled with ``' '`` and
        overwritten with each tile's interior (first and last row and column
        removed); each tile occupies an 8x8 cell of the output.

    Raises:
        StopIteration: If no orientation satisfies all neighbours of a tile.
    """
    side = 8 * n
    canvas = [[' '] * side for _ in range(side)]
    for pos, tid in enumerate(seq):
        col, step = divmod(pos, n)
        # Odd columns are walked bottom-up, so the row is mirrored there.
        row = n - step - 1 if col % 2 else step
        # Offset of the same-row tile inside an adjacent column's run.
        mirror = n - step - 1

        options = set(range(8))
        if col > 0:
            left = seq[n * (col - 1) + mirror]
            options &= {b for other, a, b in rel[1][left] if other == tid}
        if col + 1 < n:
            right = seq[n * (col + 1) + mirror]
            options &= {a for other, a, b in rel[1][tid] if other == right}
        if row > 0:
            above = seq[n * col + step + (1 if col % 2 else -1)]
            options &= {b for other, a, b in rel[0][above] if other == tid}
        if row + 1 < n:
            under = seq[n * col + step + (-1 if col % 2 else 1)]
            options &= {a for other, a, b in rel[0][tid] if other == under}

        picture = tiles[tid][next(iter(options))]
        for dy, line in enumerate(picture[1:-1]):
            for dx, cell in enumerate(line[1:-1]):
                canvas[8 * row + dy][8 * col + dx] = cell
    return canvas
# Count all non-overlapping points inside the image that match the pattern
def match(img, patt):
    """Count the image cells covered by at least one occurrence of *patt*.

    The pattern is slid over every position where its bounding box fits in
    the image. It occurs at a position when every image cell under a ``'#'``
    of the pattern is anything other than ``'.'``; other pattern characters
    are "don't care".

    Cells shared by overlapping occurrences are counted once, so the result
    can be subtracted from the total number of filled cells safely.

    Args:
        img: Image as a sequence of rows; all rows are assumed to be as long
            as the first one.
        patt: Pattern as a sequence of strings (rows may differ in length).

    Returns:
        Number of distinct image cells covered by pattern occurrences;
        0 when the image or pattern is empty or the pattern does not fit.
    """
    # (row, col) offsets of the cells the pattern requires to be filled.
    marks = [(i, j) for i, r in enumerate(patt)
             for j, v in enumerate(r) if v == '#']
    if not img or not marks:
        return 0
    height = len(patt)
    # Use the widest row so ragged patterns never index past the image edge.
    width = max(len(r) for r in patt)
    covered = set()
    for dx, dy in product(range(len(img[0]) - width + 1),
                          range(len(img) - height + 1)):
        if all(img[dy + i][dx + j] != '.' for i, j in marks):
            covered.update((dy + i, dx + j) for i, j in marks)
    return len(covered)
def pretty_print(img):
    """Print the image to stdout, one row per line.

    Args:
        img: Iterable of rows, each an iterable of strings that are
            concatenated without a separator.
    """
    for line in map(''.join, img):
        print(line)
# Read the puzzle from stdin. Chunks are separated by blank lines; each chunk
# is a "Tile <id>:" header followed by the 10 rows of the tile. Every tile is
# stored as the list of its 8 orientations.
tiles = {}
for chunk in stdin.read().split('\n\n'):
    lines = chunk.strip().splitlines()
    if not lines:
        continue  # tolerate trailing blank lines at the end of the input
    tiles[int(lines[0][5:].rstrip(':'))] = tr(10, [list(z) for z in lines[1:]])

# Side of the square tile grid, derived from the number of tiles so that both
# the full input (144 tiles -> 12) and smaller examples (9 tiles -> 3) work.
L = round(len(tiles) ** 0.5)
rel = analize(tiles)
seq = sequence(L, rel)

# Part1: product of the ids at both ends of the first and last run of L
# tiles in the sequence, i.e. the four corner tiles.
print(reduce(lambda x, y: x * y,
             [seq[i] for i in (0, L - 1, L * (L - 1), L * L - 1)], 1))

# Sea monster pattern. Spelled out row by row, each 20 characters wide, so
# that significant leading/trailing spaces cannot be lost to whitespace
# normalisation.
monster = [
    '                  # ',
    '#    ##    ##    ###',
    ' #  #  #  #  #  #   ',
]

# Try all 8 orientations of the assembled picture; report each orientation
# in which the monster pattern is found.
for img in tr(8 * L, assemble(L, tiles, seq, rel)):
    if res := match(img, monster):
        pretty_print(img)
        # Part 2: filled cells that are not part of any monster.
        print(sum(v == '#' for x in img for v in x) - res)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment